1111from tokenizers import Tokenizer # type: ignore
1212
1313from . import GenerateStreamedResponse , Generation , \
14- NonStreamedChatResponse , EmbedResponse , StreamedChatResponse , RerankResponse
14+ NonStreamedChatResponse , EmbedResponse , StreamedChatResponse , RerankResponse , ApiMeta , ApiMetaTokens , \
15+ ApiMetaBilledUnits
1516from .client import Client , ClientEnvironment
1617from .core import construct_type
1718
@@ -120,7 +121,16 @@ def stream_generator(response: httpx.Response, endpoint: str) -> typing.Iterator
120121 response_type = stream_response_mapping [endpoint ]
121122 parsed = typing .cast (response_type , # type: ignore
122123 construct_type (type_ = response_type , object_ = streamed_obj ))
123- yield (json .dumps (parsed .dict ()) + "\n " ).encode ("utf-8" ) # type: ignore
124+ yield (json .dumps (parsed .dict ()) + "\n " ).encode ("utf-8" ) # type: ignore
125+
126+
127+ def map_token_counts (response : httpx .Response ) -> ApiMeta :
128+ input_tokens = int (response .headers ["X-Amzn-Bedrock-Input-Token-Count" ])
129+ output_tokens = int (response .headers ["X-Amzn-Bedrock-Output-Token-Count" ])
130+ return ApiMeta (
131+ tokens = ApiMetaTokens (input_tokens = input_tokens , output_tokens = output_tokens ),
132+ billed_units = ApiMetaBilledUnits (input_tokens = input_tokens , output_tokens = output_tokens ),
133+ )
124134
125135
126136def map_response_from_bedrock ():
@@ -138,15 +148,20 @@ def _hook(
138148 ), endpoint )
139149 else :
140150 response_type = response_mapping [endpoint ]
141- output = iter ([json .dumps (typing .cast (response_type , # type: ignore
142- construct_type (
143- type_ = response_type ,
144- # type: ignore
145- object_ = json .loads (response .read ()))).dict ()
146- ).encode (
147- "utf-8" )])
151+ response_obj = json .loads (response .read ())
152+ response_obj ["meta" ] = map_token_counts (response ).dict ()
153+ cast_obj : typing .Any = typing .cast (response_type , # type: ignore
154+ construct_type (
155+ type_ = response_type ,
156+ # type: ignore
157+ object_ = response_obj ))
158+
159+ output = iter ([json .dumps (cast_obj .dict ()).encode ("utf-8" )])
148160
149161 response .stream = Streamer (output )
162+
163+ # reset response object to allow for re-reading
164+ del response ._content
150165 response .is_stream_consumed = False
151166 response .is_closed = False
152167
@@ -239,5 +254,3 @@ def get_url(
239254 endpoint = "invocations" if not stream else "invocations-response-stream"
240255 return f"https://runtime.sagemaker.{ aws_region } .amazonaws.com/endpoints/{ model } /{ endpoint } "
241256 return ""
242-
243-
0 commit comments