Skip to content

Commit 37a2746

Browse files
Restore utils
1 parent 2a4cdea commit 37a2746

File tree

2 files changed

+153
-0
lines changed

2 files changed

+153
-0
lines changed

.fernignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ tests
66
.github/workflows/ci.yml
77
LICENSE
88
.github/workflows/tests.yml
9+
src/cohere/utils.py

src/cohere/utils.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import asyncio
2+
import time
3+
import typing
4+
from typing import Optional
5+
6+
from .types import EmbedJob, CreateEmbedJobResponse
7+
from .datasets import DatasetsCreateResponse, DatasetsGetResponse
8+
9+
10+
def get_terminal_states():
    """Return every status at which polling should stop.

    The result is the union of the sets produced by ``get_success_states()``
    and ``get_failed_states()``.
    """
    finished_ok = get_success_states()
    finished_badly = get_failed_states()
    return finished_ok.union(finished_badly)
12+
13+
14+
def get_success_states():
    """Return the set of status strings that count as a successful outcome.

    A new set is built on each call.
    """
    success_states = {"validated", "complete"}
    return success_states
16+
17+
18+
def get_failed_states():
    """Return the set of status strings that count as a failed outcome.

    A new set is built on each call.
    """
    return {"unknown", "failed", "skipped", "cancelled"}
20+
21+
22+
def get_id(
    awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse],
):
    """Extract an identifier from a job or dataset object.

    Lookup order: ``awaitable.job_id``, then ``awaitable.id``, then
    ``awaitable.dataset.id``. A missing or falsy attribute falls through to
    the next candidate.

    Returns:
        The first truthy identifier found; otherwise whatever
        ``awaitable.dataset.id`` resolves to (``None`` when it is absent).
    """
    for attribute in ("job_id", "id"):
        candidate = getattr(awaitable, attribute, None)
        if candidate:
            return candidate

    # Last resort: the identifier nested under ``dataset``; getattr on None yields the default.
    nested_dataset = getattr(awaitable, "dataset", None)
    return getattr(nested_dataset, "id", None)
26+
27+
28+
def get_validation_status(awaitable: typing.Union[EmbedJob, DatasetsGetResponse]):
    """Read the current status from a job or dataset object.

    ``awaitable.status`` is used when it is present and truthy; otherwise
    the value of ``awaitable.dataset.validation_status`` is returned
    (``None`` when that attribute chain is absent).
    """
    direct_status = getattr(awaitable, "status", None)
    if direct_status:
        return direct_status

    nested_dataset = getattr(awaitable, "dataset", None)
    return getattr(nested_dataset, "validation_status", None)
30+
31+
32+
def get_job(
    cohere: typing.Any,
    awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse],
) -> typing.Union[EmbedJob, DatasetsGetResponse]:
    """Fetch the current state of the job or dataset behind ``awaitable``.

    Dispatch is by the class *name* of ``awaitable``: ``EmbedJob`` and
    ``CreateEmbedJobResponse`` go to ``cohere.embed_jobs.get``, while
    ``DatasetsGetResponse`` and ``DatasetsCreateResponse`` go to
    ``cohere.datasets.get``. The identifier is obtained with ``get_id``.

    Raises:
        ValueError: If the class name of ``awaitable`` is none of the above.
    """
    kind = awaitable.__class__.__name__

    if kind in ("EmbedJob", "CreateEmbedJobResponse"):
        return cohere.embed_jobs.get(id=get_id(awaitable))

    if kind in ("DatasetsGetResponse", "DatasetsCreateResponse"):
        return cohere.datasets.get(id=get_id(awaitable))

    raise ValueError(f"Unexpected awaitable type {awaitable}")
42+
43+
44+
async def async_get_job(
    cohere: typing.Any,
    awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse],
) -> typing.Union[EmbedJob, DatasetsGetResponse]:
    """Asynchronously fetch the current state of the job or dataset behind ``awaitable``.

    Dispatch is by the class *name* of ``awaitable``: ``EmbedJob`` and
    ``CreateEmbedJobResponse`` go to ``cohere.embed_jobs.get``, while
    ``DatasetsGetResponse`` and ``DatasetsCreateResponse`` go to
    ``cohere.datasets.get``. The result of the selected call is awaited.
    The identifier is obtained with ``get_id``.

    The ``awaitable`` annotation lists all four accepted classes, matching
    what the body actually handles.

    Raises:
        ValueError: If the class name of ``awaitable`` is none of the above.
    """
    kind = awaitable.__class__.__name__

    if kind in ("EmbedJob", "CreateEmbedJobResponse"):
        return await cohere.embed_jobs.get(id=get_id(awaitable))

    if kind in ("DatasetsGetResponse", "DatasetsCreateResponse"):
        return await cohere.datasets.get(id=get_id(awaitable))

    raise ValueError(f"Unexpected awaitable type {awaitable}")
53+
54+
55+
def get_failure_reason(job: typing.Union[EmbedJob, DatasetsGetResponse]) -> Optional[str]:
    """Build a human-readable failure message for a finished job.

    Args:
        job: An ``EmbedJob`` or ``DatasetsGetResponse`` instance.

    Returns:
        A message naming the embed job (``job_id``) or the dataset
        (``dataset.id``) together with its status, or ``None`` when ``job``
        is an instance of neither class.
    """
    if isinstance(job, EmbedJob):
        return f"Embed job {job.job_id} failed with status {job.status}"

    if isinstance(job, DatasetsGetResponse):
        dataset = job.dataset
        # Identify the dataset by its id; the status belongs only in the trailing slot.
        return f"Dataset creation {dataset.id} failed with status {dataset.validation_status}"

    return None
61+
62+
63+
# Typing-only signature: an embed-job creation response resolves to an EmbedJob.
@typing.overload
def wait(
    cohere: typing.Any,
    awaitable: CreateEmbedJobResponse,
    timeout: Optional[float] = None,
    interval: float = 2,
) -> EmbedJob:
    ...


# Typing-only signature: a dataset creation response resolves to a DatasetsGetResponse.
@typing.overload
def wait(
    cohere: typing.Any,
    awaitable: DatasetsCreateResponse,
    timeout: Optional[float] = None,
    interval: float = 2,
) -> DatasetsGetResponse:
    ...


def wait(
    cohere: typing.Any,
    awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse],
    timeout: Optional[float] = None,
    interval: float = 2,
) -> typing.Union[EmbedJob, DatasetsGetResponse]:
    """Block until the job or dataset behind ``awaitable`` reaches a terminal state.

    The job is re-fetched with ``get_job`` every ``interval`` seconds until
    its status (as read by ``get_validation_status``) is one of
    ``get_terminal_states()``. A line containing ``...`` is printed after
    each sleep.

    Args:
        cohere: Client object handed to ``get_job``.
        awaitable: The creation response identifying what to wait for.
        timeout: Maximum number of seconds to keep polling, or ``None`` for
            no limit. The limit is checked before each sleep, so the call
            can run past it by up to one polling cycle.
        interval: Seconds to sleep between polls.

    Returns:
        The most recently fetched job object, once its status is terminal
        and not a failure state.

    Raises:
        TimeoutError: If ``timeout`` elapses before a terminal state is seen.
        Exception: If the terminal state is one of ``get_failed_states()``;
            the message comes from ``get_failure_reason``.
    """
    # Monotonic clock: elapsed time must not be skewed by wall-clock adjustments.
    start_time = time.monotonic()
    terminal_states = get_terminal_states()
    failed_states = get_failed_states()

    job = get_job(cohere, awaitable)
    while get_validation_status(job) not in terminal_states:
        if timeout is not None and time.monotonic() - start_time > timeout:
            raise TimeoutError(f"wait timed out after {timeout} seconds")

        time.sleep(interval)
        print("...")

        job = get_job(cohere, awaitable)

    if get_validation_status(job) in failed_states:
        raise Exception(get_failure_reason(job))

    return job
107+
108+
109+
# Typing-only signature: an embed-job creation response resolves to an EmbedJob.
@typing.overload
async def async_wait(
    cohere: typing.Any,
    awaitable: CreateEmbedJobResponse,
    timeout: Optional[float] = None,
    interval: float = 10,
) -> EmbedJob:
    ...


# Typing-only signature: a dataset creation response resolves to a DatasetsGetResponse.
@typing.overload
async def async_wait(
    cohere: typing.Any,
    awaitable: DatasetsCreateResponse,
    timeout: Optional[float] = None,
    interval: float = 10,
) -> DatasetsGetResponse:
    ...


async def async_wait(
    cohere: typing.Any,
    awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse],
    timeout: Optional[float] = None,
    interval: float = 10,
) -> typing.Union[EmbedJob, DatasetsGetResponse]:
    """Await until the job or dataset behind ``awaitable`` reaches a terminal state.

    The job is re-fetched with ``async_get_job`` every ``interval`` seconds
    (sleeping with ``asyncio.sleep``) until its status, as read by
    ``get_validation_status``, is one of ``get_terminal_states()``. A line
    containing ``...`` is printed after each sleep.

    Args:
        cohere: Client object handed to ``async_get_job``.
        awaitable: The creation response identifying what to wait for.
        timeout: Maximum number of seconds to keep polling, or ``None`` for
            no limit. The limit is checked before each sleep, so the call
            can run past it by up to one polling cycle.
        interval: Seconds to sleep between polls.

    Returns:
        The most recently fetched job object, once its status is terminal
        and not a failure state.

    Raises:
        TimeoutError: If ``timeout`` elapses before a terminal state is seen.
        Exception: If the terminal state is one of ``get_failed_states()``;
            the message comes from ``get_failure_reason``.
    """
    # Monotonic clock: elapsed time must not be skewed by wall-clock adjustments.
    start_time = time.monotonic()
    terminal_states = get_terminal_states()
    failed_states = get_failed_states()

    job = await async_get_job(cohere, awaitable)
    while get_validation_status(job) not in terminal_states:
        if timeout is not None and time.monotonic() - start_time > timeout:
            raise TimeoutError(f"wait timed out after {timeout} seconds")

        await asyncio.sleep(interval)
        print("...")

        job = await async_get_job(cohere, awaitable)

    if get_validation_status(job) in failed_states:
        raise Exception(get_failure_reason(job))

    return job

0 commit comments

Comments
 (0)