Skip to content

Commit 045ceba

Browse files
Throw if stream=True (#404)
* Throw if stream=True * Fix types
1 parent d5546a7 commit 045ceba

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

src/cohere/client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,23 @@
66
from .environment import CohereEnvironment
77

88

9+
def validate_args(obj: typing.Any, method_name: str, check_fn: typing.Callable[[typing.Any], typing.Any]) -> None:
10+
method = getattr(obj, method_name)
11+
12+
def wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
13+
check_fn(*args, **kwargs)
14+
return method(*args, **kwargs)
15+
16+
setattr(obj, method_name, wrapped)
17+
18+
19+
def throw_if_stream_is_true(*args, **kwargs) -> None:
20+
if kwargs.get("stream") is True:
21+
raise ValueError(
22+
"Since python sdk cohere==5.0.0, you must now use chat_stream(...) instead of chat(stream=True, ...)"
23+
)
24+
25+
926
class Client(BaseCohere):
1027
def __init__(
1128
self,
@@ -27,6 +44,8 @@ def __init__(
2744
httpx_client=httpx_client,
2845
)
2946

47+
validate_args(self, "chat", throw_if_stream_is_true)
48+
3049

3150
class AsyncClient(AsyncBaseCohere):
3251
def __init__(
@@ -48,3 +67,5 @@ def __init__(
4867
timeout=timeout,
4968
httpx_client=httpx_client,
5069
)
70+
71+
validate_args(self, "chat", throw_if_stream_is_true)

tests/test_client.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ def test_chat(self) -> None:
2828

2929
print(chat)
3030

31+
def test_stream_equals_true(self) -> None:
32+
with self.assertRaises(ValueError):
33+
co.chat(
34+
stream=True, # type: ignore
35+
message="What year was he born?",
36+
)
37+
3138
def test_generate(self) -> None:
3239
response = co.generate(
3340
prompt='Please explain to me how LLMs work',

0 commit comments

Comments
 (0)