Skip to content

Commit bd906fa

Browse files
Add bedrock test and v2 clis (#609)
* Update bedrock cli * br * Add v2 clients * extras * None ==
1 parent b1dbf95 commit bd906fa

File tree

6 files changed

+191
-144
lines changed

6 files changed

+191
-144
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
version: 1.5.1
3636
virtualenvs-in-project: false
3737
- name: Install dependencies
38-
run: poetry install
38+
run: poetry install --extras aws
3939
- name: Test
4040
run: poetry run pytest .
4141
env:

src/cohere/aws_client.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .client import Client, ClientEnvironment
1313
from .core import construct_type
1414
from .manually_maintained.lazy_aws_deps import lazy_boto3, lazy_botocore
15-
15+
from .client_v2 import ClientV2
1616

1717
class AwsClient(Client):
1818
def __init__(
@@ -45,6 +45,37 @@ def __init__(
4545
)
4646

4747

48+
class AwsClientV2(ClientV2):
49+
def __init__(
50+
self,
51+
*,
52+
aws_access_key: typing.Optional[str] = None,
53+
aws_secret_key: typing.Optional[str] = None,
54+
aws_session_token: typing.Optional[str] = None,
55+
aws_region: typing.Optional[str] = None,
56+
timeout: typing.Optional[float] = None,
57+
service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
58+
):
59+
Client.__init__(
60+
self,
61+
base_url="https://api.cohere.com", # this url is unused for BedrockClient
62+
environment=ClientEnvironment.PRODUCTION,
63+
client_name="n/a",
64+
timeout=timeout,
65+
api_key="n/a",
66+
httpx_client=httpx.Client(
67+
event_hooks=get_event_hooks(
68+
service=service,
69+
aws_access_key=aws_access_key,
70+
aws_secret_key=aws_secret_key,
71+
aws_session_token=aws_session_token,
72+
aws_region=aws_region,
73+
),
74+
timeout=timeout,
75+
),
76+
)
77+
78+
4879
EventHook = typing.Callable[..., typing.Any]
4980

5081

src/cohere/bedrock_client.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from tokenizers import Tokenizer # type: ignore
44

5-
from .aws_client import AwsClient
5+
from .aws_client import AwsClient, AwsClientV2
66

77

88
class BedrockClient(AwsClient):
@@ -24,3 +24,24 @@ def __init__(
2424
aws_region=aws_region,
2525
timeout=timeout,
2626
)
27+
28+
29+
class BedrockClientV2(AwsClientV2):
30+
def __init__(
31+
self,
32+
*,
33+
aws_access_key: typing.Optional[str] = None,
34+
aws_secret_key: typing.Optional[str] = None,
35+
aws_session_token: typing.Optional[str] = None,
36+
aws_region: typing.Optional[str] = None,
37+
timeout: typing.Optional[float] = None,
38+
):
39+
AwsClientV2.__init__(
40+
self,
41+
service="bedrock",
42+
aws_access_key=aws_access_key,
43+
aws_secret_key=aws_secret_key,
44+
aws_session_token=aws_session_token,
45+
aws_region=aws_region,
46+
timeout=timeout,
47+
)

src/cohere/sagemaker_client.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import typing
22

3-
from .aws_client import AwsClient
3+
from .aws_client import AwsClient, AwsClientV2
44
from .manually_maintained.cohere_aws.client import Client
55
from .manually_maintained.cohere_aws.mode import Mode
66

@@ -26,4 +26,28 @@ def __init__(
2626
aws_region=aws_region,
2727
timeout=timeout,
2828
)
29+
self.sagemaker_finetuning = Client(aws_region=aws_region)
30+
31+
32+
class SagemakerClientV2(AwsClientV2):
33+
sagemaker_finetuning: Client
34+
35+
def __init__(
36+
self,
37+
*,
38+
aws_access_key: typing.Optional[str] = None,
39+
aws_secret_key: typing.Optional[str] = None,
40+
aws_session_token: typing.Optional[str] = None,
41+
aws_region: typing.Optional[str] = None,
42+
timeout: typing.Optional[float] = None,
43+
):
44+
AwsClientV2.__init__(
45+
self,
46+
service="sagemaker",
47+
aws_access_key=aws_access_key,
48+
aws_secret_key=aws_secret_key,
49+
aws_session_token=aws_session_token,
50+
aws_region=aws_region,
51+
timeout=timeout,
52+
)
2953
self.sagemaker_finetuning = Client(aws_region=aws_region)

tests/test_aws_client.py

Lines changed: 0 additions & 140 deletions
This file was deleted.

tests/test_bedrock_client.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import os
2+
import unittest
3+
4+
import typing
5+
import cohere
6+
7+
aws_access_key = os.getenv("AWS_ACCESS_KEY")
8+
aws_secret_key = os.getenv("AWS_SECRET_KEY")
9+
aws_session_token = os.getenv("AWS_SESSION_TOKEN")
10+
aws_region = os.getenv("AWS_REGION")
11+
endpoint_type = os.getenv("ENDPOINT_TYPE")
12+
13+
@unittest.skipIf(None == os.getenv("TEST_AWS"), "tests skipped because TEST_AWS is not set")
14+
class TestClient(unittest.TestCase):
15+
platform: str = "bedrock"
16+
client: cohere.AwsClient = cohere.BedrockClient(
17+
aws_access_key=aws_access_key,
18+
aws_secret_key=aws_secret_key,
19+
aws_session_token=aws_session_token,
20+
aws_region=aws_region,
21+
)
22+
models: typing.Dict[str, str] = {
23+
"chat_model": "cohere.command-r-plus-v1:0",
24+
"embed_model": "cohere.embed-multilingual-v3",
25+
"generate_model": "cohere.command-text-v14",
26+
}
27+
28+
def test_rerank(self) -> None:
29+
if self.platform != "sagemaker":
30+
self.skipTest("Only sagemaker supports rerank")
31+
32+
docs = [
33+
'Carson City is the capital city of the American state of Nevada.',
34+
'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.',
35+
'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.',
36+
'Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.']
37+
38+
response = self.client.rerank(
39+
model=self.models["rerank_model"],
40+
query='What is the capital of the United States?',
41+
documents=docs,
42+
top_n=3,
43+
)
44+
45+
self.assertEqual(len(response.results), 3)
46+
47+
def test_embed(self) -> None:
48+
response = self.client.embed(
49+
model=self.models["embed_model"],
50+
texts=["I love Cohere!"],
51+
input_type="search_document",
52+
)
53+
print(response)
54+
55+
def test_generate(self) -> None:
56+
response = self.client.generate(
57+
model=self.models["generate_model"],
58+
prompt='Please explain to me how LLMs work',
59+
)
60+
print(response)
61+
62+
def test_generate_stream(self) -> None:
63+
response = self.client.generate_stream(
64+
model=self.models["generate_model"],
65+
prompt='Please explain to me how LLMs work',
66+
)
67+
for event in response:
68+
print(event)
69+
if event.event_type == "text-generation":
70+
print(event.text, end='')
71+
72+
def test_chat(self) -> None:
73+
response = self.client.chat(
74+
model=self.models["chat_model"],
75+
message='Please explain to me how LLMs work',
76+
)
77+
print(response)
78+
79+
self.assertIsNotNone(response.text)
80+
self.assertIsNotNone(response.generation_id)
81+
self.assertIsNotNone(response.finish_reason)
82+
83+
self.assertIsNotNone(response.meta)
84+
if response.meta is not None:
85+
self.assertIsNotNone(response.meta.tokens)
86+
if response.meta.tokens is not None:
87+
self.assertIsNotNone(response.meta.tokens.input_tokens)
88+
self.assertIsNotNone(response.meta.tokens.output_tokens)
89+
90+
self.assertIsNotNone(response.meta.billed_units)
91+
if response.meta.billed_units is not None:
92+
self.assertIsNotNone(response.meta.billed_units.input_tokens)
93+
self.assertIsNotNone(response.meta.billed_units.input_tokens)
94+
95+
def test_chat_stream(self) -> None:
96+
response_types = set()
97+
response = self.client.chat_stream(
98+
model=self.models["chat_model"],
99+
message='Please explain to me how LLMs work',
100+
)
101+
for event in response:
102+
response_types.add(event.event_type)
103+
if event.event_type == "text-generation":
104+
print(event.text, end='')
105+
self.assertIsNotNone(event.text)
106+
if event.event_type == "stream-end":
107+
self.assertIsNotNone(event.finish_reason)
108+
self.assertIsNotNone(event.response)
109+
self.assertIsNotNone(event.response.text)
110+
111+
self.assertSetEqual(response_types, {"text-generation", "stream-end"})

0 commit comments

Comments
 (0)