|
50 | 50 | ) |
51 | 51 | ) |
52 | 52 |
|
| 53 | +ebt_partial_1 = EmbedResponse_EmbeddingsByType( |
| 54 | + response_type="embeddings_by_type", |
| 55 | + id="1", |
| 56 | + embeddings=EmbedByTypeResponseEmbeddings( |
| 57 | + float=[[0, 1, 2], [3, 4, 5]], |
| 58 | + int8=[[0, 1, 2], [3, 4, 5]], |
| 59 | + binary=[[5, 6, 7], [8, 9, 10]], |
| 60 | + ), |
| 61 | + texts=["hello", "goodbye"], |
| 62 | + meta=ApiMeta( |
| 63 | + api_version=ApiMetaApiVersion(version="1"), |
| 64 | + billed_units=ApiMetaBilledUnits( |
| 65 | + input_tokens=1, |
| 66 | + output_tokens=1, |
| 67 | + search_units=1, |
| 68 | + classifications=1 |
| 69 | + ), |
| 70 | + warnings=["test_warning_1"] |
| 71 | + ) |
| 72 | +) |
| 73 | + |
| 74 | +ebt_partial_2 = EmbedResponse_EmbeddingsByType( |
| 75 | + response_type="embeddings_by_type", |
| 76 | + id="2", |
| 77 | + embeddings=EmbedByTypeResponseEmbeddings( |
| 78 | + float=[[7, 8, 9], [10, 11, 12]], |
| 79 | + int8=[[7, 8, 9], [10, 11, 12]], |
| 80 | + binary=[[14, 15, 16], [17, 18, 19]], |
| 81 | + ), |
| 82 | + texts=["bye", "seeya"], |
| 83 | + meta=ApiMeta( |
| 84 | + api_version=ApiMetaApiVersion(version="1"), |
| 85 | + billed_units=ApiMetaBilledUnits( |
| 86 | + input_tokens=2, |
| 87 | + output_tokens=2, |
| 88 | + search_units=2, |
| 89 | + classifications=2 |
| 90 | + ), |
| 91 | + warnings=["test_warning_1", "test_warning_2"] |
| 92 | + ) |
| 93 | +) |
| 94 | + |
53 | 95 | ebf_1 = EmbedResponse_EmbeddingsFloats( |
54 | 96 | response_type="embeddings_floats", |
55 | 97 | id="1", |
@@ -93,7 +135,6 @@ def test_merge_embeddings_by_type(self) -> None: |
93 | 135 | ebt_2 |
94 | 136 | ]) |
95 | 137 |
|
96 | | - |
97 | 138 | if resp.meta is None: |
98 | 139 | raise Exception("this is just for mpy") |
99 | 140 |
|
@@ -147,3 +188,34 @@ def test_merge_embeddings_floats(self) -> None: |
147 | 188 | warnings=resp.meta.warnings # order ignored |
148 | 189 | ) |
149 | 190 | )) |
| 191 | + |
| 192 | + def test_merge_partial_embeddings_floats(self) -> None: |
| 193 | + resp = merge_embed_responses([ |
| 194 | + ebt_partial_1, |
| 195 | + ebt_partial_2 |
| 196 | + ]) |
| 197 | + |
| 198 | + if resp.meta is None: |
| 199 | + raise Exception("this is just for mpy") |
| 200 | + |
| 201 | + self.assertEqual(set(resp.meta.warnings or []), {"test_warning_1", "test_warning_2"}) |
| 202 | + self.assertEqual(resp, EmbedResponse_EmbeddingsByType( |
| 203 | + response_type="embeddings_by_type", |
| 204 | + id="1, 2", |
| 205 | + embeddings=EmbedByTypeResponseEmbeddings( |
| 206 | + float=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]], |
| 207 | + int8=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]], |
| 208 | + binary=[[5, 6, 7], [8, 9, 10], [14, 15, 16], [17, 18, 19]], |
| 209 | + ), |
| 210 | + texts=["hello", "goodbye", "bye", "seeya"], |
| 211 | + meta=ApiMeta( |
| 212 | + api_version=ApiMetaApiVersion(version="1"), |
| 213 | + billed_units=ApiMetaBilledUnits( |
| 214 | + input_tokens=3, |
| 215 | + output_tokens=3, |
| 216 | + search_units=3, |
| 217 | + classifications=3 |
| 218 | + ), |
| 219 | + warnings=resp.meta.warnings # order ignored |
| 220 | + ) |
| 221 | + )) |
0 commit comments