Skip to content

Commit d476799

Browse files
Add wait and fixes types (#410)
* Add wait and fixes types * Fix types * Fix imports * Fix type
1 parent 77e0087 commit d476799

File tree

4 files changed

+205
-58
lines changed

4 files changed

+205
-58
lines changed

src/cohere/client.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from .base_client import BaseCohere, AsyncBaseCohere
66
from .environment import ClientEnvironment
7+
from .utils import wait, async_wait
8+
79

810
# Use NoReturn as Never type for compatibility
911
Never = typing.NoReturn
@@ -25,6 +27,7 @@ def throw_if_stream_is_true(*args, **kwargs) -> None:
2527
"Since python sdk cohere==5.0.0, you must now use chat_stream(...) instead of chat(stream=True, ...)"
2628
)
2729

30+
2831
def moved_function(fn_name: str, new_fn_name: str) -> typing.Any:
2932
"""
3033
This method is moved. Please update usage.
@@ -56,7 +59,7 @@ def fn(*args, **kwargs):
5659
class Client(BaseCohere):
5760
def __init__(
5861
self,
59-
api_key: typing.Union[str, typing.Callable[[], str]],
62+
api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
6063
*,
6164
base_url: typing.Optional[str] = None,
6265
environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
@@ -76,6 +79,8 @@ def __init__(
7679

7780
validate_args(self, "chat", throw_if_stream_is_true)
7881

82+
wait = wait
83+
7984
"""
8085
The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage.
8186
Issues may be filed in https://github.com/cohere-ai/cohere-python/issues.
@@ -125,7 +130,7 @@ def __init__(
125130
class AsyncClient(AsyncBaseCohere):
126131
def __init__(
127132
self,
128-
api_key: typing.Union[str, typing.Callable[[], str]],
133+
api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
129134
*,
130135
base_url: typing.Optional[str] = None,
131136
environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
@@ -145,6 +150,8 @@ def __init__(
145150

146151
validate_args(self, "chat", throw_if_stream_is_true)
147152

153+
wait = async_wait
154+
148155
"""
149156
The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage.
150157
Issues may be filed in https://github.com/cohere-ai/cohere-python/issues.

src/cohere/utils.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import asyncio
2+
import time
3+
import typing
4+
from typing import Optional
5+
6+
from .types import EmbedJob, CreateEmbedJobResponse
7+
from .datasets import DatasetsCreateResponse, DatasetsGetResponse
8+
9+
10+
def get_terminal_states():
11+
return get_success_states() | get_failed_states()
12+
13+
14+
def get_success_states():
15+
return {"complete", "validated"}
16+
17+
18+
def get_failed_states():
19+
return {"unknown", "failed", "skipped", "cancelled", "failed"}
20+
21+
22+
def get_id(
23+
awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]):
24+
return getattr(awaitable, "job_id", None) or getattr(awaitable, "id", None) or getattr(
25+
getattr(awaitable, "dataset", None), "id", None)
26+
27+
28+
def get_validation_status(awaitable: typing.Union[EmbedJob, DatasetsGetResponse]):
29+
return getattr(awaitable, "status", None) or getattr(getattr(awaitable, "dataset", None), "validation_status", None)
30+
31+
32+
def get_job(cohere: typing.Any,
33+
awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]) -> \
34+
typing.Union[
35+
EmbedJob, DatasetsGetResponse]:
36+
if awaitable.__class__.__name__ == "EmbedJob" or awaitable.__class__.__name__ == "CreateEmbedJobResponse":
37+
return cohere.embed_jobs.get(id=get_id(awaitable))
38+
elif awaitable.__class__.__name__ == "DatasetsGetResponse" or awaitable.__class__.__name__ == "DatasetsCreateResponse":
39+
return cohere.datasets.get(id=get_id(awaitable))
40+
else:
41+
raise ValueError(f"Unexpected awaitable type {awaitable}")
42+
43+
44+
async def async_get_job(cohere: typing.Any, awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse]) -> \
45+
typing.Union[
46+
EmbedJob, DatasetsGetResponse]:
47+
if awaitable.__class__.__name__ == "EmbedJob" or awaitable.__class__.__name__ == "CreateEmbedJobResponse":
48+
return await cohere.embed_jobs.get(id=get_id(awaitable))
49+
elif awaitable.__class__.__name__ == "DatasetsGetResponse" or awaitable.__class__.__name__ == "DatasetsCreateResponse":
50+
return await cohere.datasets.get(id=get_id(awaitable))
51+
else:
52+
raise ValueError(f"Unexpected awaitable type {awaitable}")
53+
54+
55+
def get_failure_reason(job: typing.Union[EmbedJob, DatasetsGetResponse]) -> Optional[str]:
56+
if isinstance(job, EmbedJob):
57+
return f"Embed job {job.job_id} failed with status {job.status}"
58+
elif isinstance(job, DatasetsGetResponse):
59+
return f"Dataset creation {job.dataset.validation_status} failed with status {job.dataset.validation_status}"
60+
return None
61+
62+
63+
@typing.overload
64+
def wait(
65+
cohere: typing.Any,
66+
awaitable: CreateEmbedJobResponse,
67+
timeout: Optional[float] = None,
68+
interval: float = 10,
69+
) -> EmbedJob:
70+
...
71+
72+
73+
@typing.overload
74+
def wait(
75+
cohere: typing.Any,
76+
awaitable: DatasetsCreateResponse,
77+
timeout: Optional[float] = None,
78+
interval: float = 10,
79+
) -> DatasetsGetResponse:
80+
...
81+
82+
83+
def wait(
84+
cohere: typing.Any,
85+
awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse],
86+
timeout: Optional[float] = None,
87+
interval: float = 2,
88+
) -> typing.Union[EmbedJob, DatasetsGetResponse]:
89+
start_time = time.time()
90+
terminal_states = get_terminal_states()
91+
failed_states = get_failed_states()
92+
93+
job = get_job(cohere, awaitable)
94+
while get_validation_status(job) not in terminal_states:
95+
if timeout is not None and time.time() - start_time > timeout:
96+
raise TimeoutError(f"wait timed out after {timeout} seconds")
97+
98+
time.sleep(interval)
99+
print("...")
100+
101+
job = get_job(cohere, awaitable)
102+
103+
if get_validation_status(job) in failed_states:
104+
raise Exception(get_failure_reason(job))
105+
106+
return job
107+
108+
109+
@typing.overload
110+
async def async_wait(
111+
cohere: typing.Any,
112+
awaitable: CreateEmbedJobResponse,
113+
timeout: Optional[float] = None,
114+
interval: float = 10,
115+
) -> EmbedJob:
116+
...
117+
118+
119+
@typing.overload
120+
async def async_wait(
121+
cohere: typing.Any,
122+
awaitable: DatasetsCreateResponse,
123+
timeout: Optional[float] = None,
124+
interval: float = 10,
125+
) -> DatasetsGetResponse:
126+
...
127+
128+
129+
async def async_wait(
130+
cohere: typing.Any,
131+
awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse],
132+
timeout: Optional[float] = None,
133+
interval: float = 10,
134+
) -> typing.Union[EmbedJob, DatasetsGetResponse]:
135+
start_time = time.time()
136+
terminal_states = get_terminal_states()
137+
failed_states = get_failed_states()
138+
139+
job = await async_get_job(cohere, awaitable)
140+
while get_validation_status(job) not in terminal_states:
141+
if timeout is not None and time.time() - start_time > timeout:
142+
raise TimeoutError(f"wait timed out after {timeout} seconds")
143+
144+
await asyncio.sleep(interval)
145+
print("...")
146+
147+
job = await async_get_job(cohere, awaitable)
148+
149+
if get_validation_status(job) in failed_states:
150+
raise Exception(get_failure_reason(job))
151+
152+
return job

0 commit comments

Comments
 (0)