Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,17 +538,25 @@ def create_cache_point(cls, cache_type: str = "default") -> Dict[str, Any]:
@classmethod
def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names)

# Merge model_kwargs (name assumed in langchain-core) and
# additional_model_request_fields (name used in ChatBedrockConverse)
model_kwargs = values.pop("model_kwargs", {})
additional_model_request_fields = values.pop(
"additional_model_request_fields", {}
)
if additional_model_request_fields or model_kwargs:
if model_kwargs:
model_kwargs_msg = (
"Please use additional_model_request_fields instead of "
"model_kwargs for any extra inference parameters."
)
logger.warning(model_kwargs_msg)
warnings.warn(model_kwargs_msg)

all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names)
base_model_kwargs = values.pop("model_kwargs", {})

if additional_model_request_fields or model_kwargs or base_model_kwargs:
values["additional_model_request_fields"] = {
**base_model_kwargs,
**model_kwargs,
**additional_model_request_fields,
}
Expand Down
10 changes: 10 additions & 0 deletions libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1627,6 +1627,16 @@ def test_model_kwargs() -> None:
assert llm.region_name == "us-west-2"
assert llm.additional_model_request_fields == {"foo": "bar"}

with pytest.warns(match="additional_model_request_fields instead of model_kwargs"):
llm = ChatBedrockConverse(
model="my-model",
region_name="us-west-2",
model_kwargs={"foo": "bar"}, # type: ignore[call-arg]
)
assert llm.model_id == "my-model"
assert llm.region_name == "us-west-2"
assert llm.additional_model_request_fields == {"foo": "bar"}

with pytest.warns(match="transferred to model_kwargs"):
llm = ChatBedrockConverse( # type: ignore[call-arg]
model="my-model",
Expand Down