Skip to content

Commit 189ce39

Browse files
Fix bug when partial embeddings by type requested (#440)
* Fix bug when partial embeddings by type requested
* Update src/cohere/utils.py

Co-authored-by: harry-cohere <[email protected]>

---------

Co-authored-by: harry-cohere <[email protected]>
1 parent ac7e7c4 commit 189ce39

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

src/cohere/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,13 +211,16 @@ def merge_embed_responses(responses: typing.List[EmbedResponse]) -> EmbedRespons
211211
for response in embeddings_type
212212
]
213213

214+
# only use the fields that were explicitly set on the pydantic model (i.e. skip fields that were never set and merely default to None)
215+
fields = embeddings_type[0].embeddings.dict(exclude_unset=True).keys()
216+
214217
merged_dicts = {
215218
field: [
216219
embedding
217220
for embedding_by_type in embeddings_by_type
218221
for embedding in getattr(embedding_by_type, field)
219222
]
220-
for field in EmbedByTypeResponseEmbeddings.__fields__
223+
for field in fields
221224
}
222225

223226
embeddings_by_type_merged = EmbedByTypeResponseEmbeddings.parse_obj(merged_dicts)

tests/test_embed_utils.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,48 @@
5050
)
5151
)
5252

53+
ebt_partial_1 = EmbedResponse_EmbeddingsByType(
54+
response_type="embeddings_by_type",
55+
id="1",
56+
embeddings=EmbedByTypeResponseEmbeddings(
57+
float=[[0, 1, 2], [3, 4, 5]],
58+
int8=[[0, 1, 2], [3, 4, 5]],
59+
binary=[[5, 6, 7], [8, 9, 10]],
60+
),
61+
texts=["hello", "goodbye"],
62+
meta=ApiMeta(
63+
api_version=ApiMetaApiVersion(version="1"),
64+
billed_units=ApiMetaBilledUnits(
65+
input_tokens=1,
66+
output_tokens=1,
67+
search_units=1,
68+
classifications=1
69+
),
70+
warnings=["test_warning_1"]
71+
)
72+
)
73+
74+
ebt_partial_2 = EmbedResponse_EmbeddingsByType(
75+
response_type="embeddings_by_type",
76+
id="2",
77+
embeddings=EmbedByTypeResponseEmbeddings(
78+
float=[[7, 8, 9], [10, 11, 12]],
79+
int8=[[7, 8, 9], [10, 11, 12]],
80+
binary=[[14, 15, 16], [17, 18, 19]],
81+
),
82+
texts=["bye", "seeya"],
83+
meta=ApiMeta(
84+
api_version=ApiMetaApiVersion(version="1"),
85+
billed_units=ApiMetaBilledUnits(
86+
input_tokens=2,
87+
output_tokens=2,
88+
search_units=2,
89+
classifications=2
90+
),
91+
warnings=["test_warning_1", "test_warning_2"]
92+
)
93+
)
94+
5395
ebf_1 = EmbedResponse_EmbeddingsFloats(
5496
response_type="embeddings_floats",
5597
id="1",
@@ -93,7 +135,6 @@ def test_merge_embeddings_by_type(self) -> None:
93135
ebt_2
94136
])
95137

96-
97138
if resp.meta is None:
98139
raise Exception("this is just for mpy")
99140

@@ -147,3 +188,34 @@ def test_merge_embeddings_floats(self) -> None:
147188
warnings=resp.meta.warnings # order ignored
148189
)
149190
))
191+
192+
def test_merge_partial_embeddings_floats(self) -> None:
193+
resp = merge_embed_responses([
194+
ebt_partial_1,
195+
ebt_partial_2
196+
])
197+
198+
if resp.meta is None:
199+
raise Exception("this is just for mpy")
200+
201+
self.assertEqual(set(resp.meta.warnings or []), {"test_warning_1", "test_warning_2"})
202+
self.assertEqual(resp, EmbedResponse_EmbeddingsByType(
203+
response_type="embeddings_by_type",
204+
id="1, 2",
205+
embeddings=EmbedByTypeResponseEmbeddings(
206+
float=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]],
207+
int8=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]],
208+
binary=[[5, 6, 7], [8, 9, 10], [14, 15, 16], [17, 18, 19]],
209+
),
210+
texts=["hello", "goodbye", "bye", "seeya"],
211+
meta=ApiMeta(
212+
api_version=ApiMetaApiVersion(version="1"),
213+
billed_units=ApiMetaBilledUnits(
214+
input_tokens=3,
215+
output_tokens=3,
216+
search_units=3,
217+
classifications=3
218+
),
219+
warnings=resp.meta.warnings # order ignored
220+
)
221+
))

0 commit comments

Comments
 (0)