Skip to content

Commit add37bf

Browse files
Map X-Amzn-Bedrock token counts (#537)
* Map X-Amzn-Bedrock token counts * Type fixes
1 parent b13f5c5 commit add37bf

File tree

2 files changed

+36
-11
lines changed

2 files changed

+36
-11
lines changed

src/cohere/aws_client.py

Lines changed: 24 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,8 @@
1111
from tokenizers import Tokenizer # type: ignore
1212

1313
from . import GenerateStreamedResponse, Generation, \
14-
NonStreamedChatResponse, EmbedResponse, StreamedChatResponse, RerankResponse
14+
NonStreamedChatResponse, EmbedResponse, StreamedChatResponse, RerankResponse, ApiMeta, ApiMetaTokens, \
15+
ApiMetaBilledUnits
1516
from .client import Client, ClientEnvironment
1617
from .core import construct_type
1718

@@ -120,7 +121,16 @@ def stream_generator(response: httpx.Response, endpoint: str) -> typing.Iterator
120121
response_type = stream_response_mapping[endpoint]
121122
parsed = typing.cast(response_type, # type: ignore
122123
construct_type(type_=response_type, object_=streamed_obj))
123-
yield (json.dumps(parsed.dict()) + "\n").encode("utf-8") # type: ignore
124+
yield (json.dumps(parsed.dict()) + "\n").encode("utf-8") # type: ignore
125+
126+
127+
def map_token_counts(response: httpx.Response) -> ApiMeta:
    """Build an ``ApiMeta`` from the token-count headers of *response*.

    Reads the ``X-Amzn-Bedrock-Input-Token-Count`` and
    ``X-Amzn-Bedrock-Output-Token-Count`` response headers and reports the
    same pair of counts as both ``tokens`` and ``billed_units``.

    A header that is absent is reported as ``-1`` rather than raising
    ``KeyError``, so a response without these headers can still be mapped.
    NOTE(review): presumably this matters for non-Bedrock responses, since
    this file also builds SageMaker URLs -- confirm which platforms reach here.

    Args:
        response: The HTTP response whose headers carry the token counts.

    Returns:
        An ``ApiMeta`` whose ``tokens`` and ``billed_units`` hold the input
        and output token counts (``-1`` for any count that was not supplied).

    Raises:
        ValueError: If a header is present but is not a valid integer.
    """
    headers = response.headers
    # Use .get() with a sentinel: these headers are not guaranteed to exist.
    input_tokens = int(headers.get("X-Amzn-Bedrock-Input-Token-Count", -1))
    output_tokens = int(headers.get("X-Amzn-Bedrock-Output-Token-Count", -1))
    return ApiMeta(
        tokens=ApiMetaTokens(input_tokens=input_tokens, output_tokens=output_tokens),
        billed_units=ApiMetaBilledUnits(input_tokens=input_tokens, output_tokens=output_tokens),
    )
124134

125135

126136
def map_response_from_bedrock():
@@ -138,15 +148,20 @@ def _hook(
138148
), endpoint)
139149
else:
140150
response_type = response_mapping[endpoint]
141-
output = iter([json.dumps(typing.cast(response_type, # type: ignore
142-
construct_type(
143-
type_=response_type,
144-
# type: ignore
145-
object_=json.loads(response.read()))).dict()
146-
).encode(
147-
"utf-8")])
151+
response_obj = json.loads(response.read())
152+
response_obj["meta"] = map_token_counts(response).dict()
153+
cast_obj: typing.Any = typing.cast(response_type, # type: ignore
154+
construct_type(
155+
type_=response_type,
156+
# type: ignore
157+
object_=response_obj))
158+
159+
output = iter([json.dumps(cast_obj.dict()).encode("utf-8")])
148160

149161
response.stream = Streamer(output)
162+
163+
# reset response object to allow for re-reading
164+
del response._content
150165
response.is_stream_consumed = False
151166
response.is_closed = False
152167

@@ -239,5 +254,3 @@ def get_url(
239254
endpoint = "invocations" if not stream else "invocations-response-stream"
240255
return f"https://runtime.sagemaker.{aws_region}.amazonaws.com/endpoints/{model}/{endpoint}"
241256
return ""
242-
243-

tests/test_aws_client.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -108,6 +108,18 @@ def test_chat(self) -> None:
108108
self.assertIsNotNone(response.text)
109109
self.assertIsNotNone(response.generation_id)
110110
self.assertIsNotNone(response.finish_reason)
111+
112+
self.assertIsNotNone(response.meta)
113+
if response.meta is not None:
114+
self.assertIsNotNone(response.meta.tokens)
115+
if response.meta.tokens is not None:
116+
self.assertIsNotNone(response.meta.tokens.input_tokens)
117+
self.assertIsNotNone(response.meta.tokens.output_tokens)
118+
119+
self.assertIsNotNone(response.meta.billed_units)
120+
if response.meta.billed_units is not None:
121+
self.assertIsNotNone(response.meta.billed_units.input_tokens)
122+
self.assertIsNotNone(response.meta.billed_units.input_tokens)
111123

112124
def test_chat_stream(self) -> None:
113125
response_types = set()

0 commit comments

Comments
 (0)