|
7 | 7 | from codex.types.projects.access_key_retrieve_project_id_response import ( |
8 | 8 | AccessKeyRetrieveProjectIDResponse, |
9 | 9 | ) |
10 | | -from codex.types.projects.entry_query_response import ( |
11 | | - Entry as SDKEntry, |
12 | | -) |
13 | | -from codex.types.projects.entry_query_response import ( |
14 | | - EntryManagedMetadata, |
15 | | - EntryManagedMetadataTrustworthiness, |
16 | | - EntryQueryResponse, |
17 | | -) |
18 | 10 |
|
19 | 11 | from cleanlab_codex.project import MissingProjectError, Project |
20 | | -from cleanlab_codex.types.entry import EntryCreate |
21 | 12 |
|
22 | 13 | FAKE_PROJECT_ID = str(uuid.uuid4()) |
23 | 14 | FAKE_USER_ID = "Test User" |
@@ -71,44 +62,6 @@ def test_create_project(mock_client_from_api_key: MagicMock, default_headers: di |
71 | 62 | assert mock_client_from_api_key.projects.retrieve.call_count == 0 |
72 | 63 |
|
73 | 64 |
|
74 | | -def test_add_entries(mock_client_from_api_key: MagicMock) -> None: |
75 | | - answered_entry_create = EntryCreate( |
76 | | - question="What is the capital of France?", |
77 | | - answer="Paris", |
78 | | - ) |
79 | | - unanswered_entry_create = EntryCreate( |
80 | | - question="What is the capital of Germany?", |
81 | | - ) |
82 | | - project = Project(mock_client_from_api_key, FAKE_PROJECT_ID) |
83 | | - project.add_entries([answered_entry_create, unanswered_entry_create]) |
84 | | - |
85 | | - for call, entry in zip( |
86 | | - mock_client_from_api_key.projects.entries.create.call_args_list, |
87 | | - [answered_entry_create, unanswered_entry_create], |
88 | | - ): |
89 | | - assert call.args[0] == FAKE_PROJECT_ID |
90 | | - assert call.kwargs["question"] == entry["question"] |
91 | | - assert call.kwargs["answer"] == entry.get("answer") |
92 | | - |
93 | | - |
94 | | -def test_add_entries_no_access_key(mock_client_from_access_key: MagicMock) -> None: |
95 | | - mock_error = Mock(response=Mock(status=401), body={"error": "Unauthorized"}) |
96 | | - |
97 | | - mock_client_from_access_key.projects.entries.create.side_effect = AuthenticationError( |
98 | | - "test", response=mock_error.response, body=mock_error.body |
99 | | - ) |
100 | | - |
101 | | - answered_entry_create = EntryCreate( |
102 | | - question="What is the capital of France?", |
103 | | - answer="Paris", |
104 | | - ) |
105 | | - |
106 | | - project = Project.from_access_key(DUMMY_ACCESS_KEY) |
107 | | - |
108 | | - with pytest.raises(AuthenticationError, match="See cleanlab_codex.Client.get_project"): |
109 | | - project.add_entries([answered_entry_create]) |
110 | | - |
111 | | - |
112 | 65 | def test_create_access_key(mock_client_from_api_key: MagicMock, default_headers: dict[str, str]) -> None: |
113 | 66 | project = Project(mock_client_from_api_key, FAKE_PROJECT_ID) |
114 | 67 | access_key_name = "Test Access Key" |
@@ -144,83 +97,3 @@ def test_init_nonexistent_project_id(mock_client_from_access_key: MagicMock) -> |
144 | 97 | with pytest.raises(MissingProjectError): |
145 | 98 | Project(mock_client_from_access_key, FAKE_PROJECT_ID) |
146 | 99 | assert mock_client_from_access_key.projects.retrieve.call_count == 1 |
147 | | - |
148 | | - |
149 | | -def test_query_question_found_fallback_answer( |
150 | | - mock_client_from_access_key: MagicMock, |
151 | | -) -> None: |
152 | | - unanswered_entry = SDKEntry( |
153 | | - id=str(uuid.uuid4()), |
154 | | - question="What is the capital of France?", |
155 | | - answer=None, |
156 | | - managed_metadata=EntryManagedMetadata(trustworthiness=EntryManagedMetadataTrustworthiness(scores=[0.95])), |
157 | | - ) |
158 | | - |
159 | | - mock_client_from_access_key.projects.entries.query.return_value = EntryQueryResponse( |
160 | | - entry=unanswered_entry, answer=None |
161 | | - ) |
162 | | - project = Project(mock_client_from_access_key, FAKE_PROJECT_ID) |
163 | | - res = project.query("What is the capital of France?") |
164 | | - assert res[0] is None |
165 | | - assert res[1] is not None |
166 | | - assert res[1].model_dump() == unanswered_entry.model_dump() |
167 | | - |
168 | | - |
169 | | -def test_query_question_not_found_fallback_answer( |
170 | | - mock_client_from_access_key: MagicMock, |
171 | | -) -> None: |
172 | | - mock_entry = SDKEntry( |
173 | | - id="fake-id", |
174 | | - question="What is the capital of France?", |
175 | | - answer=None, |
176 | | - managed_metadata=EntryManagedMetadata(trustworthiness=EntryManagedMetadataTrustworthiness(scores=[0.95])), |
177 | | - ) |
178 | | - mock_client_from_access_key.projects.entries.query.return_value = EntryQueryResponse(entry=mock_entry, answer=None) |
179 | | - |
180 | | - project = Project(mock_client_from_access_key, FAKE_PROJECT_ID) |
181 | | - res = project.query("What is the capital of France?", fallback_answer="Paris") |
182 | | - assert res[0] == "Paris" |
183 | | - assert res[1] is not None |
184 | | - assert res[1].model_dump() == mock_entry.model_dump() |
185 | | - |
186 | | - |
187 | | -def test_query_answer_found(mock_client_from_access_key: MagicMock) -> None: |
188 | | - answered_entry = SDKEntry( |
189 | | - id=str(uuid.uuid4()), |
190 | | - question="What is the capital of France?", |
191 | | - answer="Paris", |
192 | | - managed_metadata=EntryManagedMetadata(trustworthiness=EntryManagedMetadataTrustworthiness(scores=[0.95])), |
193 | | - ) |
194 | | - mock_client_from_access_key.projects.entries.query.return_value = EntryQueryResponse( |
195 | | - answer="Paris", entry=answered_entry |
196 | | - ) |
197 | | - project = Project(mock_client_from_access_key, FAKE_PROJECT_ID) |
198 | | - res = project.query("What is the capital of France?") |
199 | | - assert res[0] == answered_entry.answer |
200 | | - assert res[1] is not None |
201 | | - assert res[1].model_dump() == answered_entry.model_dump() |
202 | | - |
203 | | - |
204 | | -def test_query_answer_found_with_metadata(mock_client_from_access_key: MagicMock) -> None: |
205 | | - answered_entry = SDKEntry( |
206 | | - id=str(uuid.uuid4()), |
207 | | - question="What is the capital of France?", |
208 | | - answer="Paris", |
209 | | - client_query_metadata=[{"trustworthiness_score": 0.95}], |
210 | | - managed_metadata=EntryManagedMetadata(trustworthiness=EntryManagedMetadataTrustworthiness(scores=[0.95])), |
211 | | - ) |
212 | | - mock_client_from_access_key.projects.entries.query.return_value = EntryQueryResponse( |
213 | | - answer="Paris", entry=answered_entry |
214 | | - ) |
215 | | - project = Project(mock_client_from_access_key, FAKE_PROJECT_ID) |
216 | | - res = project.query("What is the capital of France?", metadata={"trustworthiness_score": 0.95}) |
217 | | - assert res[0] == answered_entry.answer |
218 | | - assert res[1] is not None |
219 | | - assert res[1].model_dump() == answered_entry.model_dump() # metadata should be included in the entry |
220 | | - |
221 | | - |
222 | | -def test_add_entries_empty_list(mock_client_from_access_key: MagicMock) -> None: |
223 | | - """Test adding an empty list of entries""" |
224 | | - project = Project(mock_client_from_access_key, FAKE_PROJECT_ID) |
225 | | - project.add_entries([]) |
226 | | - mock_client_from_access_key.projects.entries.create.assert_not_called() |
0 commit comments