@@ -25,9 +25,6 @@ def __init__(
2525 aws_session_token : typing .Optional [str ] = None ,
2626 aws_region : typing .Optional [str ] = None ,
2727 timeout : typing .Optional [float ] = None ,
28- chat_model : typing .Optional [str ] = None ,
29- embed_model : typing .Optional [str ] = None ,
30- generate_model : typing .Optional [str ] = None ,
3128 service : typing .Union [typing .Literal ["bedrock" ], typing .Literal ["sagemaker" ]],
3229 ):
3330 Client .__init__ (
@@ -44,9 +41,6 @@ def __init__(
4441 aws_secret_key = aws_secret_key ,
4542 aws_session_token = aws_session_token ,
4643 aws_region = aws_region ,
47- chat_model = chat_model ,
48- embed_model = embed_model ,
49- generate_model = generate_model ,
5044 ),
5145 timeout = timeout ,
5246 ),
@@ -62,9 +56,6 @@ def get_event_hooks(
6256 aws_secret_key : typing .Optional [str ] = None ,
6357 aws_session_token : typing .Optional [str ] = None ,
6458 aws_region : typing .Optional [str ] = None ,
65- chat_model : typing .Optional [str ] = None ,
66- embed_model : typing .Optional [str ] = None ,
67- generate_model : typing .Optional [str ] = None ,
6859) -> typing .Dict [str , typing .List [EventHook ]]:
6960 return {
7061 "request" : [
@@ -74,17 +65,10 @@ def get_event_hooks(
7465 aws_secret_key = aws_secret_key ,
7566 aws_session_token = aws_session_token ,
7667 aws_region = aws_region ,
77- chat_model = chat_model ,
78- embed_model = embed_model ,
79- generate_model = generate_model ,
8068 ),
8169 ],
8270 "response" : [
83- map_response_from_bedrock (
84- chat_model = chat_model ,
85- embed_model = embed_model ,
86- generate_model = generate_model ,
87- )
71+ map_response_from_bedrock ()
8872 ],
8973 }
9074
@@ -138,17 +122,12 @@ def stream_generator(response: httpx.Response, endpoint: str) -> typing.Iterator
138122 yield (json .dumps (parsed .dict ()) + "\n " ).encode ("utf-8" ) # type: ignore
139123
140124
141- def map_response_from_bedrock (
142- chat_model : typing .Optional [str ] = None ,
143- embed_model : typing .Optional [str ] = None ,
144- generate_model : typing .Optional [str ] = None ,
145- ):
125+ def map_response_from_bedrock ():
146126 def _hook (
147127 response : httpx .Response ,
148128 ) -> None :
149129 stream = response .headers ["content-type" ] == "application/vnd.amazon.eventstream"
150- endpoint = get_endpoint_from_url (
151- response .url .path , chat_model , embed_model , generate_model )
130+ endpoint = response .request .extensions ["endpoint" ]
152131 output : typing .Iterator [bytes ]
153132
154133 if stream :
@@ -179,9 +158,6 @@ def map_request_to_bedrock(
179158 aws_secret_key : typing .Optional [str ] = None ,
180159 aws_session_token : typing .Optional [str ] = None ,
181160 aws_region : typing .Optional [str ] = None ,
182- chat_model : typing .Optional [str ] = None ,
183- embed_model : typing .Optional [str ] = None ,
184- generate_model : typing .Optional [str ] = None ,
185161) -> EventHook :
186162 session = boto3 .Session (
187163 region_name = aws_region ,
@@ -192,23 +168,18 @@ def map_request_to_bedrock(
192168 credentials = session .get_credentials ()
193169 signer = SigV4Auth (credentials , service , session .region_name )
194170
195- model_lookup = {
196- "embed" : embed_model ,
197- "chat" : chat_model ,
198- "generate" : generate_model ,
199- }
200-
201171 def _event_hook (request : httpx .Request ) -> None :
202172 headers = request .headers .copy ()
203173 del headers ["connection" ]
204174
205175 endpoint = request .url .path .split ("/" )[- 1 ]
206176 body = json .loads (request .read ())
177+ model = body ["model" ]
207178
208179 url = get_url (
209180 platform = service ,
210181 aws_region = aws_region ,
211- model = model_lookup [ endpoint ] , # type: ignore
182+ model = model , # type: ignore
212183 stream = "stream" in body and body ["stream" ],
213184 )
214185 request .url = URL (url )
@@ -217,6 +188,9 @@ def _event_hook(request: httpx.Request) -> None:
217188 if "stream" in body :
218189 del body ["stream" ]
219190
191+ if "model" in body :
192+ del body ["model" ]
193+
220194 new_body = json .dumps (body ).encode ("utf-8" )
221195 request .stream = ByteStream (new_body )
222196 request ._content = new_body
@@ -231,6 +205,7 @@ def _event_hook(request: httpx.Request) -> None:
231205 signer .add_auth (aws_request )
232206
233207 request .headers = httpx .Headers (aws_request .prepare ().headers )
208+ request .extensions ["endpoint" ] = endpoint
234209
235210 return _event_hook
236211
0 commit comments