Skip to content

Commit 77e0087

Browse files
Fixes rc and add async tests (#407)
* Fix type * Add async tests * Drop tests.yml * Add chat stream tests
1 parent 372bb05 commit 77e0087

File tree

5 files changed

+332
-45
lines changed

5 files changed

+332
-45
lines changed

.fernignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ README.md
44
src/cohere/client.py
55
tests
66
.github/workflows/ci.yml
7-
LICENSE
7+
LICENSE
8+
.github/workflows/tests.yml

.github/workflows/tests.yml

Lines changed: 0 additions & 41 deletions
This file was deleted.

src/cohere/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import httpx
44

55
from .base_client import BaseCohere, AsyncBaseCohere
6-
from .environment import CohereEnvironment
6+
from .environment import ClientEnvironment
77

88
# Use NoReturn as Never type for compatibility
99
Never = typing.NoReturn
@@ -59,7 +59,7 @@ def __init__(
5959
api_key: typing.Union[str, typing.Callable[[], str]],
6060
*,
6161
base_url: typing.Optional[str] = None,
62-
environment: CohereEnvironment = CohereEnvironment.PRODUCTION,
62+
environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
6363
client_name: typing.Optional[str] = None,
6464
timeout: typing.Optional[float] = 60,
6565
httpx_client: typing.Optional[httpx.Client] = None,
@@ -128,7 +128,7 @@ def __init__(
128128
api_key: typing.Union[str, typing.Callable[[], str]],
129129
*,
130130
base_url: typing.Optional[str] = None,
131-
environment: CohereEnvironment = CohereEnvironment.PRODUCTION,
131+
environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
132132
client_name: typing.Optional[str] = None,
133133
timeout: typing.Optional[float] = 60,
134134
httpx_client: typing.Optional[httpx.AsyncClient] = None,

tests/test_async_client.py

Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
import os
2+
import unittest
3+
from time import sleep
4+
5+
import cohere
6+
from cohere import ChatMessage, ChatConnector, ClassifyExample, CreateConnectorServiceAuth, Tool, \
7+
ToolParameterDefinitionsValue, ChatRequestToolResultsItem
8+
9+
co = cohere.AsyncClient(os.environ['COHERE_API_KEY'], timeout=10000)
10+
11+
package_dir = os.path.dirname(os.path.abspath(__file__))
12+
embed_job = os.path.join(package_dir, 'embed_job.jsonl')
13+
14+
15+
class TestClient(unittest.TestCase):
16+
17+
async def test_chat(self) -> None:
18+
chat = await co.chat(
19+
chat_history=[
20+
ChatMessage(role="USER",
21+
message="Who discovered gravity?"),
22+
ChatMessage(role="CHATBOT", message="The man who is widely credited with discovering "
23+
"gravity is Sir Isaac Newton")
24+
],
25+
message="What year was he born?",
26+
connectors=[ChatConnector(id="web-search")]
27+
)
28+
29+
print(chat)
30+
31+
async def test_chat_stream(self) -> None:
32+
stream = co.chat_stream(
33+
chat_history=[
34+
ChatMessage(role="USER",
35+
message="Who discovered gravity?"),
36+
ChatMessage(role="CHATBOT", message="The man who is widely credited with discovering "
37+
"gravity is Sir Isaac Newton")
38+
],
39+
message="What year was he born?",
40+
connectors=[ChatConnector(id="web-search")]
41+
)
42+
43+
async for chat_event in stream:
44+
if chat_event.event_type == "text-generation":
45+
print(chat_event.text)
46+
47+
async def test_stream_equals_true(self) -> None:
48+
with self.assertRaises(ValueError):
49+
await co.chat(
50+
stream=True, # type: ignore
51+
message="What year was he born?",
52+
)
53+
54+
async def test_deprecated_fn(self) -> None:
55+
with self.assertRaises(ValueError):
56+
await co.check_api_key("dummy", dummy="dummy") # type: ignore
57+
58+
async def test_moved_fn(self) -> None:
59+
with self.assertRaises(ValueError):
60+
await co.list_connectors("dummy", dummy="dummy") # type: ignore
61+
62+
async def test_generate(self) -> None:
63+
response = await co.generate(
64+
prompt='Please explain to me how LLMs work',
65+
)
66+
print(response)
67+
68+
async def test_embed(self) -> None:
69+
response = await co.embed(
70+
texts=['hello', 'goodbye'],
71+
model='embed-english-v3.0',
72+
input_type="classification"
73+
)
74+
print(response)
75+
76+
async def test_embed_job_crud(self) -> None:
77+
dataset = await co.datasets.create(
78+
name="test",
79+
type="embed-input",
80+
data=open(embed_job, 'rb'),
81+
)
82+
83+
while True:
84+
ds = await co.datasets.get(dataset.id or "")
85+
sleep(2)
86+
print(ds, flush=True)
87+
if ds.dataset.validation_status != "processing":
88+
break
89+
90+
# start an embed job
91+
job = await co.embed_jobs.create(
92+
dataset_id=dataset.id or "",
93+
input_type="search_document",
94+
model='embed-english-v3.0')
95+
96+
print(job)
97+
98+
# list embed jobs
99+
my_embed_jobs = await co.embed_jobs.list()
100+
101+
print(my_embed_jobs)
102+
103+
while True:
104+
em = await co.embed_jobs.get(job.job_id)
105+
sleep(2)
106+
print(em, flush=True)
107+
if em.status != "processing":
108+
break
109+
110+
await co.embed_jobs.cancel(job.job_id)
111+
112+
await co.datasets.delete(dataset.id or "")
113+
114+
async def test_rerank(self) -> None:
115+
docs = [
116+
'Carson City is the capital city of the American state of Nevada.',
117+
'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.',
118+
'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.',
119+
'Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.']
120+
121+
response = await co.rerank(
122+
model='rerank-english-v2.0',
123+
query='What is the capital of the United States?',
124+
documents=docs,
125+
top_n=3,
126+
)
127+
128+
print(response)
129+
130+
async def test_classify(self) -> None:
131+
examples = [
132+
ClassifyExample(text="Dermatologists don't like her!", label="Spam"),
133+
ClassifyExample(text="'Hello, open to this?'", label="Spam"),
134+
ClassifyExample(
135+
text="I need help please wire me $1000 right now", label="Spam"),
136+
ClassifyExample(text="Nice to know you ;)", label="Spam"),
137+
ClassifyExample(text="Please help me?", label="Spam"),
138+
ClassifyExample(
139+
text="Your parcel will be delivered today", label="Not spam"),
140+
ClassifyExample(
141+
text="Review changes to our Terms and Conditions", label="Not spam"),
142+
ClassifyExample(text="Weekly sync notes", label="Not spam"),
143+
ClassifyExample(
144+
text="'Re: Follow up from today's meeting'", label="Not spam"),
145+
ClassifyExample(text="Pre-read for tomorrow", label="Not spam"),
146+
]
147+
inputs = [
148+
"Confirm your email address",
149+
"hey i need u to send some $",
150+
]
151+
response = await co.classify(
152+
inputs=inputs,
153+
examples=examples,
154+
)
155+
print(response)
156+
157+
async def test_datasets_crud(self) -> None:
158+
my_dataset = await co.datasets.create(
159+
name="test",
160+
type="embed-input",
161+
data=open(embed_job, 'rb'),
162+
)
163+
164+
print(my_dataset)
165+
166+
my_datasets = await co.datasets.list()
167+
168+
print(my_datasets)
169+
170+
dataset = await co.datasets.get(my_dataset.id or "")
171+
172+
print(dataset)
173+
174+
await co.datasets.delete(my_dataset.id or "")
175+
176+
async def test_summarize(self) -> None:
177+
text = (
178+
"Ice cream is a sweetened frozen food typically eaten as a snack or dessert. "
179+
"It may be made from milk or cream and is flavoured with a sweetener, "
180+
"either sugar or an alternative, and a spice, such as cocoa or vanilla, "
181+
"or with fruit such as strawberries or peaches. "
182+
"It can also be made by whisking a flavored cream base and liquid nitrogen together. "
183+
"Food coloring is sometimes added, in addition to stabilizers. "
184+
"The mixture is cooled below the freezing point of water and stirred to incorporate air spaces "
185+
"and to prevent detectable ice crystals from forming. The result is a smooth, "
186+
"semi-solid foam that is solid at very low temperatures (below 2 °C or 35 °F). "
187+
"It becomes more malleable as its temperature increases.\n\n"
188+
"The meaning of the name \"ice cream\" varies from one country to another. "
189+
"In some countries, such as the United States, \"ice cream\" applies only to a specific variety, "
190+
"and most governments regulate the commercial use of the various terms according to the "
191+
"relative quantities of the main ingredients, notably the amount of cream. "
192+
"Products that do not meet the criteria to be called ice cream are sometimes labelled "
193+
"\"frozen dairy dessert\" instead. In other countries, such as Italy and Argentina, "
194+
"one word is used fo\r all variants. Analogues made from dairy alternatives, "
195+
"such as goat's or sheep's milk, or milk substitutes "
196+
"(e.g., soy, cashew, coconut, almond milk or tofu), are available for those who are "
197+
"lactose intolerant, allergic to dairy protein or vegan."
198+
)
199+
200+
response = await co.summarize(
201+
text=text,
202+
)
203+
204+
print(response)
205+
206+
async def test_tokenize(self) -> None:
207+
response = await co.tokenize(
208+
text='tokenize me! :D',
209+
model='command'
210+
)
211+
print(response)
212+
213+
async def test_detokenize(self) -> None:
214+
response = await co.detokenize(
215+
tokens=[10104, 12221, 1315, 34, 1420, 69],
216+
model="command"
217+
)
218+
print(response)
219+
220+
async def test_connectors_crud(self) -> None:
221+
created_connector = await co.connectors.create(
222+
name="Example connector",
223+
url="https://dummy-connector-o5btz7ucgq-uc.a.run.app/search",
224+
service_auth=CreateConnectorServiceAuth(
225+
token="dummy-connector-token",
226+
type="bearer",
227+
)
228+
)
229+
print(created_connector)
230+
231+
connector = await co.connectors.get(created_connector.connector.id)
232+
233+
print(connector)
234+
235+
updated_connector = await co.connectors.update(
236+
id=connector.connector.id, name="new name")
237+
238+
print(updated_connector)
239+
240+
await co.connectors.delete(created_connector.connector.id)
241+
242+
async def test_tool_use(self) -> None:
243+
tools = [
244+
Tool(
245+
name="sales_database",
246+
description="Connects to a database about sales volumes",
247+
parameter_definitions={
248+
"day": ToolParameterDefinitionsValue(
249+
description="Retrieves sales data from this day, formatted as YYYY-MM-DD.",
250+
type="str",
251+
required=True
252+
)}
253+
)
254+
]
255+
256+
tool_parameters_response = await co.chat(
257+
message="How good were the sales on September 29?",
258+
tools=tools,
259+
model="command-nightly",
260+
preamble="""
261+
## Task Description
262+
You help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user's needs as best you can, which will be wide-ranging.
263+
264+
## Style Guide
265+
Unless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.
266+
"""
267+
)
268+
269+
if tool_parameters_response.tool_calls is not None:
270+
self.assertEqual(tool_parameters_response.tool_calls[0].name, "sales_database")
271+
self.assertEqual(tool_parameters_response.tool_calls[0].parameters, {"day": "2023-09-29"})
272+
else:
273+
raise ValueError("Expected tool calls to be present")
274+
275+
local_tools = {
276+
"sales_database": lambda day: {
277+
"number_of_sales": 120,
278+
"total_revenue": 48500,
279+
"average_sale_value": 404.17,
280+
"date": "2023-09-29"
281+
}
282+
}
283+
284+
tool_results = []
285+
for tool_call in tool_parameters_response.tool_calls:
286+
output = local_tools[tool_call.name](**tool_call.parameters)
287+
outputs = [output]
288+
289+
tool_results.append(ChatRequestToolResultsItem(
290+
call=tool_call,
291+
outputs=outputs
292+
))
293+
294+
cited_response = await co.chat(
295+
message="How good were the sales on September 29?",
296+
tools=tools,
297+
tool_results=tool_results,
298+
model="command-nightly",
299+
)
300+
301+
self.assertEqual(cited_response.documents, [
302+
{
303+
"tool_name": "sales_database",
304+
"average_sale_value": "404.17",
305+
"date": "2023-09-29",
306+
"id": "sales_database:0:0",
307+
"number_of_sales": "120",
308+
"total_revenue": "48500",
309+
}
310+
])

0 commit comments

Comments
 (0)