Skip to content

Commit 2ed7109

Browse files
Move models param to the endpoint (#516)
1 parent db1255f commit 2ed7109

File tree

4 files changed

+35
-55
lines changed

4 files changed

+35
-55
lines changed

src/cohere/aws_client.py

Lines changed: 9 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,6 @@ def __init__(
2525
aws_session_token: typing.Optional[str] = None,
2626
aws_region: typing.Optional[str] = None,
2727
timeout: typing.Optional[float] = None,
28-
chat_model: typing.Optional[str] = None,
29-
embed_model: typing.Optional[str] = None,
30-
generate_model: typing.Optional[str] = None,
3128
service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
3229
):
3330
Client.__init__(
@@ -44,9 +41,6 @@ def __init__(
4441
aws_secret_key=aws_secret_key,
4542
aws_session_token=aws_session_token,
4643
aws_region=aws_region,
47-
chat_model=chat_model,
48-
embed_model=embed_model,
49-
generate_model=generate_model,
5044
),
5145
timeout=timeout,
5246
),
@@ -62,9 +56,6 @@ def get_event_hooks(
6256
aws_secret_key: typing.Optional[str] = None,
6357
aws_session_token: typing.Optional[str] = None,
6458
aws_region: typing.Optional[str] = None,
65-
chat_model: typing.Optional[str] = None,
66-
embed_model: typing.Optional[str] = None,
67-
generate_model: typing.Optional[str] = None,
6859
) -> typing.Dict[str, typing.List[EventHook]]:
6960
return {
7061
"request": [
@@ -74,17 +65,10 @@ def get_event_hooks(
7465
aws_secret_key=aws_secret_key,
7566
aws_session_token=aws_session_token,
7667
aws_region=aws_region,
77-
chat_model=chat_model,
78-
embed_model=embed_model,
79-
generate_model=generate_model,
8068
),
8169
],
8270
"response": [
83-
map_response_from_bedrock(
84-
chat_model=chat_model,
85-
embed_model=embed_model,
86-
generate_model=generate_model,
87-
)
71+
map_response_from_bedrock()
8872
],
8973
}
9074

@@ -138,17 +122,12 @@ def stream_generator(response: httpx.Response, endpoint: str) -> typing.Iterator
138122
yield (json.dumps(parsed.dict()) + "\n").encode("utf-8") # type: ignore
139123

140124

141-
def map_response_from_bedrock(
142-
chat_model: typing.Optional[str] = None,
143-
embed_model: typing.Optional[str] = None,
144-
generate_model: typing.Optional[str] = None,
145-
):
125+
def map_response_from_bedrock():
146126
def _hook(
147127
response: httpx.Response,
148128
) -> None:
149129
stream = response.headers["content-type"] == "application/vnd.amazon.eventstream"
150-
endpoint = get_endpoint_from_url(
151-
response.url.path, chat_model, embed_model, generate_model)
130+
endpoint = response.request.extensions["endpoint"]
152131
output: typing.Iterator[bytes]
153132

154133
if stream:
@@ -179,9 +158,6 @@ def map_request_to_bedrock(
179158
aws_secret_key: typing.Optional[str] = None,
180159
aws_session_token: typing.Optional[str] = None,
181160
aws_region: typing.Optional[str] = None,
182-
chat_model: typing.Optional[str] = None,
183-
embed_model: typing.Optional[str] = None,
184-
generate_model: typing.Optional[str] = None,
185161
) -> EventHook:
186162
session = boto3.Session(
187163
region_name=aws_region,
@@ -192,23 +168,18 @@ def map_request_to_bedrock(
192168
credentials = session.get_credentials()
193169
signer = SigV4Auth(credentials, service, session.region_name)
194170

195-
model_lookup = {
196-
"embed": embed_model,
197-
"chat": chat_model,
198-
"generate": generate_model,
199-
}
200-
201171
def _event_hook(request: httpx.Request) -> None:
202172
headers = request.headers.copy()
203173
del headers["connection"]
204174

205175
endpoint = request.url.path.split("/")[-1]
206176
body = json.loads(request.read())
177+
model = body["model"]
207178

208179
url = get_url(
209180
platform=service,
210181
aws_region=aws_region,
211-
model=model_lookup[endpoint], # type: ignore
182+
model=model, # type: ignore
212183
stream="stream" in body and body["stream"],
213184
)
214185
request.url = URL(url)
@@ -217,6 +188,9 @@ def _event_hook(request: httpx.Request) -> None:
217188
if "stream" in body:
218189
del body["stream"]
219190

191+
if "model" in body:
192+
del body["model"]
193+
220194
new_body = json.dumps(body).encode("utf-8")
221195
request.stream = ByteStream(new_body)
222196
request._content = new_body
@@ -231,6 +205,7 @@ def _event_hook(request: httpx.Request) -> None:
231205
signer.add_auth(aws_request)
232206

233207
request.headers = httpx.Headers(aws_request.prepare().headers)
208+
request.extensions["endpoint"] = endpoint
234209

235210
return _event_hook
236211

src/cohere/bedrock_client.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@ def __init__(
1717
aws_session_token: typing.Optional[str] = None,
1818
aws_region: typing.Optional[str] = None,
1919
timeout: typing.Optional[float] = None,
20-
chat_model: typing.Optional[str] = None,
21-
embed_model: typing.Optional[str] = None,
22-
generate_model: typing.Optional[str] = None,
2320
):
2421
AwsClient.__init__(
2522
self,
@@ -29,7 +26,4 @@ def __init__(
2926
aws_session_token=aws_session_token,
3027
aws_region=aws_region,
3128
timeout=timeout,
32-
chat_model=chat_model,
33-
embed_model=embed_model,
34-
generate_model=generate_model,
3529
)

src/cohere/sagemaker_client.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@ def __init__(
1414
aws_session_token: typing.Optional[str] = None,
1515
aws_region: typing.Optional[str] = None,
1616
timeout: typing.Optional[float] = None,
17-
chat_model: typing.Optional[str] = None,
18-
embed_model: typing.Optional[str] = None,
19-
generate_model: typing.Optional[str] = None,
2017
):
2118
AwsClient.__init__(
2219
self,
@@ -26,7 +23,4 @@ def __init__(
2623
aws_session_token=aws_session_token,
2724
aws_region=aws_region,
2825
timeout=timeout,
29-
chat_model=chat_model,
30-
embed_model=embed_model,
31-
generate_model=generate_model,
3226
)

tests/test_aws_client.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,73 @@
11
import os
22
import unittest
33

4+
import typing
45
import cohere
56
from parameterized import parameterized_class # type: ignore
67

78
package_dir = os.path.dirname(os.path.abspath(__file__))
89
embed_job = os.path.join(package_dir, 'embed_job.jsonl')
910

1011

12+
models = {
13+
"bedrock": {
14+
"chat_model": "cohere.command-r-plus-v1:0",
15+
"embed_model": "cohere.embed-multilingual-v3",
16+
"generate_model": "cohere.command-text-v14",
17+
},
18+
"sagemaker": {
19+
"chat_model": "cohere.command-r-plus-v1:0",
20+
"embed_model": "cohere.embed-multilingual-v3",
21+
"generate_model": "cohere-command-light",
22+
},
23+
}
24+
25+
1126
@parameterized_class([
1227
{
1328
"client": cohere.BedrockClient(
1429
timeout=10000,
1530
aws_region="us-east-1",
16-
chat_model="cohere.command-r-plus-v1:0",
17-
embed_model="cohere.embed-multilingual-v3",
18-
generate_model="cohere.command-text-v14",
1931
aws_access_key="...",
2032
aws_secret_key="...",
2133
aws_session_token="...",
22-
)
34+
),
35+
"models": models["bedrock"],
2336
},
2437
{
2538
"client": cohere.SagemakerClient(
2639
timeout=10000,
2740
aws_region="us-east-1",
28-
chat_model="cohere.command-r-plus-v1:0",
29-
embed_model="cohere.embed-multilingual-v3",
30-
generate_model="cohere-command-light",
3141
aws_access_key="...",
3242
aws_secret_key="...",
3343
aws_session_token="...",
34-
)
44+
),
45+
"models": models["sagemaker"],
3546
}
3647
])
3748
@unittest.skip("skip tests until they work in CI")
3849
class TestClient(unittest.TestCase):
39-
client: cohere.AwsClient;
50+
client: cohere.AwsClient
51+
models: typing.Dict[str, str]
4052

4153
def test_embed(self) -> None:
4254
response = self.client.embed(
55+
model=self.models["embed_model"],
4356
texts=["I love Cohere!"],
4457
input_type="search_document",
4558
)
4659
print(response)
4760

4861
def test_generate(self) -> None:
4962
response = self.client.generate(
63+
model=self.models["generate_model"],
5064
prompt='Please explain to me how LLMs work',
5165
)
5266
print(response)
5367

5468
def test_generate_stream(self) -> None:
5569
response = self.client.generate_stream(
70+
model=self.models["generate_model"],
5671
prompt='Please explain to me how LLMs work',
5772
)
5873
for event in response:
@@ -62,6 +77,7 @@ def test_generate_stream(self) -> None:
6277

6378
def test_chat(self) -> None:
6479
response = self.client.chat(
80+
model=self.models["chat_model"],
6581
message='Please explain to me how LLMs work',
6682
)
6783
print(response)
@@ -73,6 +89,7 @@ def test_chat(self) -> None:
7389
def test_chat_stream(self) -> None:
7490
response_types = set()
7591
response = self.client.chat_stream(
92+
model=self.models["chat_model"],
7693
message='Please explain to me how LLMs work',
7794
)
7895
for event in response:

0 commit comments

Comments
 (0)