Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions litellm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4015,7 +4015,11 @@ def embedding( # noqa: PLR0915
azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)
aembedding: Optional[bool] = kwargs.get("aembedding", None)
extra_headers = kwargs.get("extra_headers", None)
headers = kwargs.get("headers", None)
headers = kwargs.get("headers", None) or extra_headers
if headers is None:
headers = {}
if extra_headers is not None:
headers.update(extra_headers)
### CUSTOM MODEL COST ###
input_cost_per_token = kwargs.get("input_cost_per_token", None)
output_cost_per_token = kwargs.get("output_cost_per_token", None)
Expand Down Expand Up @@ -4326,7 +4330,7 @@ def embedding( # noqa: PLR0915
litellm_params={},
api_base=api_base,
print_verbose=print_verbose,
extra_headers=extra_headers,
extra_headers=headers,
api_key=api_key,
)
elif custom_llm_provider == "triton":
Expand Down
152 changes: 151 additions & 1 deletion tests/test_litellm/llms/bedrock/embed/test_bedrock_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,4 +404,154 @@ def test_twelvelabs_missing_input_type_error():
)

# Should succeed without input_type
assert isinstance(response, litellm.EmbeddingResponse)
assert isinstance(response, litellm.EmbeddingResponse)


@pytest.mark.parametrize(
"model,embed_response",
[
("bedrock/amazon.titan-embed-text-v1", titan_embedding_response),
("bedrock/amazon.titan-embed-text-v2:0", titan_embedding_response),
("bedrock/cohere.embed-english-v3", cohere_embedding_response),
],
)
def test_bedrock_embedding_header_forwarding(model, embed_response):
"""
Test that custom headers are correctly forwarded to Bedrock embedding API calls.

This test verifies the fix for the issue where headers configured via
forward_client_headers_to_llm_api were not being passed to Bedrock embedding provider.

Relevant Issue: https://github.com/BerriAI/litellm/pull/16042
"""
litellm.set_verbose = True
client = HTTPHandler()
test_api_key = "test-bearer-token-12345"

# Headers that would be set by the proxy when forwarding client headers
custom_headers = {
"X-Custom-Header": "CustomValue",
"X-BYOK-Token": "secret-token",
"Extra-Header": "foobar",
}

with patch.object(client, "post") as mock_post:
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = json.dumps(embed_response)
mock_response.json = lambda: json.loads(mock_response.text)
mock_post.return_value = mock_response

try:
# Call embedding with custom headers via kwargs
# This simulates what the proxy does when forward_client_headers_to_llm_api is set
response = litellm.embedding(
model=model,
input=test_input,
client=client,
headers=custom_headers, # This is how proxy passes forwarded headers
aws_region_name="us-east-1",
aws_bedrock_runtime_endpoint="https://bedrock-runtime.us-east-1.amazonaws.com",
api_key=test_api_key,
)

assert isinstance(response, litellm.EmbeddingResponse)

# Verify that the request was made
assert mock_post.called, "HTTP client post should be called"

# Get the actual call arguments
call_kwargs = mock_post.call_args.kwargs
headers = call_kwargs.get("headers", {})

# Verify our custom headers are present in the request headers
# Note: AWS SigV4 signing may modify header names to lowercase
for header_key, header_value in custom_headers.items():
header_found = (
header_key in headers
or header_key.lower() in headers
or any(k.lower() == header_key.lower() for k in headers.keys())
)
assert header_found, (
f"Header {header_key} should be in request headers. "
f"Found headers: {list(headers.keys())}"
)

print(f"✓ Test passed for {model}")
print(f" Headers correctly forwarded: {list(headers.keys())}")

except Exception as e:
pytest.fail(f"Failed to forward headers to {model}: {str(e)}")


def test_bedrock_embedding_extra_headers_and_headers_merge():
"""
Test that both extra_headers and headers parameters are correctly merged for Bedrock embeddings.

This ensures that headers from kwargs (forwarded by proxy) and extra_headers
(passed explicitly) are both included in the final headers sent to the provider.
"""
litellm.set_verbose = True
client = HTTPHandler()
test_api_key = "test-bearer-token-12345"
model = "bedrock/amazon.titan-embed-text-v1"

# Headers from proxy (via kwargs["headers"])
proxy_headers = {"X-Forwarded-Header": "ProxyValue"}

# Explicit extra_headers
explicit_headers = {"X-Explicit-Header": "ExplicitValue"}

# Mock response
embed_response = {
"embedding": [0.1, 0.2, 0.3],
"inputTextTokenCount": 10
}

with patch.object(client, "post") as mock_post:
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = json.dumps(embed_response)
mock_response.json = lambda: json.loads(mock_response.text)
mock_post.return_value = mock_response

try:
response = litellm.embedding(
model=model,
input=test_input,
client=client,
headers=proxy_headers, # From proxy forwarding
extra_headers=explicit_headers, # Explicitly passed
aws_region_name="us-east-1",
aws_bedrock_runtime_endpoint="https://bedrock-runtime.us-east-1.amazonaws.com",
api_key=test_api_key,
)

assert isinstance(response, litellm.EmbeddingResponse)

call_kwargs = mock_post.call_args.kwargs
headers = call_kwargs.get("headers", {})

# Both sets of headers should be present
# Note: AWS SigV4 signing may modify header names to lowercase
proxy_header_found = any(
k.lower() == "x-forwarded-header" for k in headers.keys()
)
assert proxy_header_found, (
"Proxy forwarded header should be present. "
f"Found headers: {list(headers.keys())}"
)

explicit_header_found = any(
k.lower() == "x-explicit-header" for k in headers.keys()
)
assert explicit_header_found, (
"Explicitly passed header should be present. "
f"Found headers: {list(headers.keys())}"
)

print("✓ Both header sources correctly merged and forwarded")
print(f" Final headers: {list(headers.keys())}")

except Exception as e:
pytest.fail(f"Failed to merge and forward headers: {str(e)}")
Loading