Commit aa296a8

feat: Support Qwen3 models on ChatBedrock (+ gpt-oss streaming) (#679)
Adds support for the new [Qwen3](https://www.aboutamazon.com/news/aws/alibaba-qwen3-deepseek-v3-amazon-bedrock) serverless models through the InvokeModel APIs in ChatBedrock. This PR also implements ChatBedrock streaming support for both Qwen3 and OpenAI GPT-OSS models.
1 parent da653c3 commit aa296a8
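
For context, a minimal usage sketch of the new path (not part of this commit): the model id `qwen.qwen3-32b-v1:0` is taken from the unit-test mocks below, while the region and generation parameters are placeholder assumptions.

```python
# Minimal sketch, not part of this commit. The model id comes from the test
# mocks below; the region and model_kwargs are placeholder assumptions.
from langchain_aws import ChatBedrock

chat = ChatBedrock(
    model_id="qwen.qwen3-32b-v1:0",  # "qwen." prefix routes to the new provider branch
    region_name="us-west-2",         # assumed region where the model is enabled
    model_kwargs={"temperature": 0.2, "max_tokens": 256},
)

# One-shot call: goes through LLMInputOutputAdapter.prepare_input / prepare_output.
print(chat.invoke("Say hello, then say goodbye.").content)

# Streaming call: exercises prepare_output_stream, which now ends the stream when
# choices[0].finish_reason == "stop" for the qwen/openai providers.
for chunk in chat.stream("Say hello, then say goodbye."):
    print(chunk.content, end="", flush=True)
```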

File tree: 3 files changed (+191, -19 lines)

- libs/aws/langchain_aws/chat_models/bedrock.py
- libs/aws/langchain_aws/llms/bedrock.py
- libs/aws/tests/unit_tests/llms/test_bedrock.py

libs/aws/langchain_aws/chat_models/bedrock.py

Lines changed: 4 additions & 4 deletions

@@ -279,7 +279,7 @@ def convert_messages_to_prompt_writer(messages: List[BaseMessage]) -> str:
     """Convert a list of messages to a prompt for Writer."""

     return "\n".join(
-        [_convert_one_message_to_text_llama(message) for message in messages]
+        [_convert_one_message_to_text_writer(message) for message in messages]
     )


@@ -741,7 +741,7 @@ def format_messages(
     ]:
         if provider == "anthropic":
             return _format_anthropic_messages(messages)
-        elif provider == "openai":
+        elif provider in ("openai", "qwen"):
             return cast(List[Dict[str, Any]], convert_to_openai_messages(messages))
         raise NotImplementedError(
             f"Provider {provider} not supported for format_messages"
@@ -914,7 +914,7 @@ def _stream(
                 system = self.system_prompt_with_tools
             else:
                 system = system_str
-        elif provider == "openai":
+        elif provider in ("openai", "qwen"):
             formatted_messages = cast(
                 List[Dict[str, Any]],
                 ChatPromptAdapter.format_messages(provider, messages),
@@ -1060,7 +1060,7 @@ def _generate(
             else:
                 system = system_str
             citations_enabled = _citations_enabled(formatted_messages)
-        elif provider == "openai":
+        elif provider in ("openai", "qwen"):
             formatted_messages = cast(
                 List[Dict[str, Any]],
                 ChatPromptAdapter.format_messages(provider, messages),
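
The effect of the `format_messages` change is that the "qwen" provider reuses the OpenAI-style chat message format. A small illustration of what the routing now produces, assuming `convert_to_openai_messages` is imported from `langchain_core.messages` as in this module:

```python
# Illustration only: what ChatPromptAdapter.format_messages("qwen", ...) now returns.
from langchain_core.messages import HumanMessage, SystemMessage, convert_to_openai_messages

messages = [
    SystemMessage(content="You are terse."),
    HumanMessage(content="Say hello, then say goodbye."),
]

# Expected result:
# [{"role": "system", "content": "You are terse."},
#  {"role": "user", "content": "Say hello, then say goodbye."}]
print(convert_to_openai_messages(messages))
```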

libs/aws/langchain_aws/llms/bedrock.py

Lines changed: 38 additions & 15 deletions

@@ -185,12 +185,16 @@ def _stream_response_to_generation_chunk(
         if k
         not in [output_key, "prompt_token_count", "generation_token_count", "created"]
     }
+
+    if provider in ["mistral", "deepseek", "writer"]:
+        text = stream_response[output_key][0]["text"]
+    elif provider in ["openai", "qwen"]:
+        text = stream_response[output_key][0]["delta"].get("content", "")
+    else:
+        text = stream_response[output_key]
+
     return GenerationChunk(
-        text=(
-            stream_response[output_key]
-            if provider not in ["mistral", "deepseek", "writer"]
-            else stream_response[output_key][0]["text"]
-        ),
+        text=text,
         generation_info=generation_info,
     )

@@ -297,6 +301,8 @@ class LLMInputOutputAdapter:
         "deepseek": "choices",
         "meta": "generation",
         "mistral": "outputs",
+        "openai": "choices",
+        "qwen": "choices",
         "writer": "choices",
     }

@@ -402,14 +408,19 @@ def prepare_input(
                     input_body["max_tokens"] = max_tokens
                 elif provider == "writer":
                     input_body["max_tokens"] = max_tokens
-                elif provider == "openai":
-                    input_body["max_output_tokens"] = max_tokens
                 else:
                     # TODO: Add AI21 support, param depends on specific model.
                     pass
             if temperature is not None:
                 input_body["temperature"] = temperature

+        elif provider in ("openai", "qwen"):
+            input_body["messages"] = messages
+            if max_tokens:
+                input_body["max_tokens"] = max_tokens
+            if temperature is not None:
+                input_body["temperature"] = temperature
+
         elif provider == "amazon":
             input_body = dict()
             input_body["inputText"] = prompt
@@ -419,12 +430,6 @@ def prepare_input(
             if temperature is not None:
                 input_body["textGenerationConfig"]["temperature"] = temperature

-        elif provider == "openai":
-            input_body["messages"] = messages
-            if max_tokens:
-                input_body["max_tokens"] = max_tokens
-            if temperature is not None:
-                input_body["temperature"] = temperature
         else:
             input_body["inputText"] = prompt

@@ -478,6 +483,8 @@ def prepare_output(cls, provider: str, response: Any) -> dict:
             text = response_body.get("outputs")[0].get("text")
         elif provider == "openai":
             text = response_body.get("choices")[0].get("message").get("content")
+        elif provider == "qwen":
+            text = response_body.get("choices")[0].get("message").get("content")
         else:
             text = response_body.get("results")[0].get("outputText")

@@ -576,6 +583,14 @@ def prepare_output_stream(
                 yield _get_invocation_metrics_chunk(chunk_obj)
                 return

+            elif (
+                provider in ("qwen", "openai")
+                and chunk_obj.get(output_key, [{}])[0].get("finish_reason", "")
+                == "stop"
+            ):
+                yield _get_invocation_metrics_chunk(chunk_obj)
+                return
+
             elif messages_api and (chunk_obj.get("type") == "message_stop"):
                 yield _get_invocation_metrics_chunk(chunk_obj)
                 return
@@ -619,6 +634,14 @@ async def aprepare_output_stream(
             ):
                 return

+            elif (
+                provider in ("qwen", "openai")
+                and chunk_obj.get(output_key, [{}])[0].get("finish_reason", "")
+                == "stop"
+            ):
+                yield _get_invocation_metrics_chunk(chunk_obj)
+                return
+
             generation_chunk = _stream_response_to_generation_chunk(
                 chunk_obj,
                 provider=provider,
@@ -1219,7 +1242,7 @@ def _prepare_input_and_invoke_stream(
             provider,
             response,
             stop,
-            True if messages else False,
+            True if (messages and provider == "anthropic") else False,
             coerce_content_to_string=coerce_content_to_string,
         ):
             yield chunk
@@ -1288,7 +1311,7 @@ async def _aprepare_input_and_invoke_stream(
             provider,
             response,
             stop,
-            True if messages else False,
+            True if (messages and provider == "anthropic") else False,
         ):
             yield chunk
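
The qwen and openai streaming chunks arrive as OpenAI-style `chat.completion.chunk` payloads, and the adapter pulls text from `choices[0].delta.content`, stopping when `finish_reason` is "stop". A minimal standalone sketch of that parse (payload shape taken from the unit-test mocks in the next file):

```python
# Standalone sketch of the parsing the new branch performs; the payload shape
# mirrors MOCK_STREAMING_RESPONSE_QWEN / MOCK_STREAMING_RESPONSE_OPENAI below.
import json

raw = (
    b'{"choices": [{"delta": {"content": "Hello."}, '
    b'"finish_reason": null, "index": 0}]}'
)
chunk_obj = json.loads(raw)

output_key = "choices"  # registered for both "openai" and "qwen" in the output-key map
text = chunk_obj[output_key][0]["delta"].get("content", "")
finished = chunk_obj[output_key][0].get("finish_reason", "") == "stop"

print(text)      # Hello.
print(finished)  # False; a later chunk with finish_reason == "stop" ends the stream
```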

libs/aws/tests/unit_tests/llms/test_bedrock.py

Lines changed: 149 additions & 0 deletions

@@ -304,6 +304,63 @@ def test__human_assistant_format() -> None:
     {"chunk": {"bytes": b'"[DONE]"'}},
 ]

+MOCK_STREAMING_RESPONSE_QWEN = [
+    {
+        "chunk": {
+            "bytes": b'{"choices": [{"delta": {"content": "", "role": "assistant"}, '
+            b'"finish_reason": null, "index": 0}], '
+            b'"created": 1759875373, '
+            b'"id": "chatcmpl-a069cbda08ce4599afae798c4d2de095", '
+            b'"model": "qwen.qwen3-32b-v1:0", '
+            b'"object": "chat.completion.chunk", '
+            b'"service_tier": "auto"}'
+        }
+    },
+    {
+        "chunk": {
+            "bytes": b'{"choices": [{"delta": {"content": "Hello. \\nGoodbye."}, '
+            b'"finish_reason": "stop", "index": 0}], '
+            b'"created": 1759875373, '
+            b'"id": "chatcmpl-a069cbda08ce4599afae798c4d2de095", '
+            b'"model": "qwen.qwen3-32b-v1:0", '
+            b'"object": "chat.completion.chunk", '
+            b'"service_tier": "auto", '
+            b'"amazon-bedrock-invocationMetrics": {'
+            b'"inputTokenCount": 35, "outputTokenCount": 7, '
+            b'"invocationLatency": 225, "firstByteLatency": 191}}'
+        }
+    },
+]
+
+MOCK_STREAMING_RESPONSE_OPENAI = [
+    {
+        "chunk": {
+            "bytes": b'{"choices": [{"delta": {"content": "Hello."}, '
+            b'"finish_reason": null, "index": 0}], '
+            b'"created": 1759813667, '
+            b'"id": "chatcmpl-fa6fb768b71046eeb3880cbb4a1b07c1", '
+            b'"model": "openai.gpt-oss-20b-1:0", '
+            b'"object": "chat.completion.chunk", "service_tier": "auto"}'
+        }
+    },
+    {
+        "chunk": {
+            "bytes": b'{"choices": [{"delta": {}, '
+            b'"finish_reason": "stop", "index": 0}],'
+            b' "created": 1759813667, '
+            b'"id": "chatcmpl-fa6fb768b71046eeb3880cbb4a1b07c1", '
+            b'"model": "openai.gpt-oss-20b-1:0", '
+            b'"object": "chat.completion.chunk", '
+            b'"service_tier": "auto", '
+            b'"amazon-bedrock-invocationMetrics": {'
+            b'"inputTokenCount": 84, '
+            b'"outputTokenCount": 87, '
+            b'"invocationLatency": 3981, '
+            b'"firstByteLatency": 3615}}'
+        }
+    },
+]
+

 async def async_gen_mock_streaming_response() -> AsyncGenerator[Dict, None]:
     for item in MOCK_STREAMING_RESPONSE:
@@ -421,6 +478,56 @@ def writer_streaming_response():
     return response


+@pytest.fixture
+def qwen_response():
+    body = MagicMock()
+    body.read.return_value = json.dumps(
+        {"choices": [{"message": {"content": "This is the Qwen output text."}}]}
+    ).encode()
+    response = dict(
+        body=body,
+        ResponseMetadata={
+            "HTTPHeaders": {
+                "x-amzn-bedrock-input-token-count": "35",
+                "x-amzn-bedrock-output-token-count": "42",
+            }
+        },
+    )
+
+    return response
+
+
+@pytest.fixture
+def qwen_streaming_response():
+    response = dict(body=MOCK_STREAMING_RESPONSE_QWEN)
+    return response
+
+
+@pytest.fixture
+def openai_response():
+    body = MagicMock()
+    body.read.return_value = json.dumps(
+        {"choices": [{"message": {"content": "This is the OpenAI output text."}}]}
+    ).encode()
+    response = dict(
+        body=body,
+        ResponseMetadata={
+            "HTTPHeaders": {
+                "x-amzn-bedrock-input-token-count": "85",
+                "x-amzn-bedrock-output-token-count": "80",
+            }
+        },
+    )
+
+    return response
+
+
+@pytest.fixture
+def openai_streaming_response():
+    response = dict(body=MOCK_STREAMING_RESPONSE_OPENAI)
+    return response
+
+
 @pytest.fixture
 def cohere_response():
     body = MagicMock()
@@ -556,6 +663,48 @@ def test_prepare_output_stream_for_writer(writer_streaming_response) -> None:
     assert results[1] == "lo."


+def test_prepare_output_for_qwen(qwen_response):
+    result = LLMInputOutputAdapter.prepare_output("qwen", qwen_response)
+    assert result["text"] == "This is the Qwen output text."
+    assert result["usage"]["prompt_tokens"] == 35
+    assert result["usage"]["completion_tokens"] == 42
+    assert result["usage"]["total_tokens"] == 77
+    assert result["stop_reason"] is None
+
+
+def test_prepare_output_stream_for_qwen(qwen_streaming_response) -> None:
+    results = [
+        chunk.text
+        for chunk in LLMInputOutputAdapter.prepare_output_stream(
+            "qwen", qwen_streaming_response
+        )
+    ]
+
+    assert results[0] == ""
+    assert results[1] == "Hello. \nGoodbye."
+
+
+def test_prepare_output_for_openai(openai_response):
+    result = LLMInputOutputAdapter.prepare_output("openai", openai_response)
+    assert result["text"] == "This is the OpenAI output text."
+    assert result["usage"]["prompt_tokens"] == 85
+    assert result["usage"]["completion_tokens"] == 80
+    assert result["usage"]["total_tokens"] == 165
+    assert result["stop_reason"] is None
+
+
+def test_prepare_output_stream_for_openai(openai_streaming_response) -> None:
+    results = [
+        chunk.text
+        for chunk in LLMInputOutputAdapter.prepare_output_stream(
+            "openai", openai_streaming_response
+        )
+    ]
+
+    assert results[0] == "Hello."
+    assert results[1] == ""
+
+
 def test_prepare_output_for_cohere(cohere_response):
     result = LLMInputOutputAdapter.prepare_output("cohere", cohere_response)
     assert result["text"] == "This is the Cohere output text."