Skip to content

Commit 2eac37d

Browse files
Add support for context manager (#435)
* Add support for context manager * Fix
1 parent b6b5090 commit 2eac37d

File tree

3 files changed

+24
-0
lines changed

3 files changed

+24
-0
lines changed

src/cohere/client.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,14 @@ def __init__(
8787

8888
validate_args(self, "chat", throw_if_stream_is_true)
8989

90+
# support context manager until Fern upstreams
91+
# https://linear.app/buildwithfern/issue/FER-1242/expose-a-context-manager-interface-or-the-http-client-easily
92+
def __enter__(self):
93+
return self
94+
95+
def __exit__(self, exc_type, exc_value, traceback):
96+
self._client_wrapper.httpx_client.httpx_client.close()
97+
9098
wait = wait
9199

92100
"""
@@ -161,6 +169,14 @@ def __init__(
161169

162170
validate_args(self, "chat", throw_if_stream_is_true)
163171

172+
# support context manager until Fern upstreams
173+
# https://linear.app/buildwithfern/issue/FER-1242/expose-a-context-manager-interface-or-the-http-client-easily
174+
async def __aenter__(self):
175+
return self
176+
177+
async def __aexit__(self, exc_type, exc_value, traceback):
178+
await self._client_wrapper.httpx_client.httpx_client.aclose()
179+
164180
wait = async_wait
165181

166182
"""

tests/test_async_client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ async def test_token_falls_back_on_env_variable(self) -> None:
1919
cohere.AsyncClient(api_key=None)
2020
cohere.AsyncClient(None)
2121

22+
async def test_context_manager(self) -> None:
23+
async with cohere.AsyncClient(api_key="xxx") as client:
24+
self.assertIsNotNone(client)
25+
2226
async def test_chat(self) -> None:
2327
chat = await self.co.chat(
2428
chat_history=[

tests/test_client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ def test_token_falls_back_on_env_variable(self) -> None:
1717
cohere.Client(api_key=None)
1818
cohere.Client(None)
1919

20+
def test_context_manager(self) -> None:
21+
with cohere.Client(api_key="xxx") as client:
22+
self.assertIsNotNone(client)
23+
2024
def test_chat(self) -> None:
2125
chat = co.chat(
2226
chat_history=[

0 commit comments

Comments
 (0)