Skip to content

Commit 0ca914e

Browse files
authored
fix(langchain): preserve anthropic cache metrics (#455)
LangChain Anthropic responses report cache reads and cache writes separately from normal input tokens, including TTL-specific cache creation buckets. The previous cached-token fix avoided OpenAI double counting, but it could drop Anthropic cache-write detail from spans and produce totals that were less useful for cost analysis. Preserve the cache creation metrics users need to understand prompt-cache spend and keep token totals aligned with the prompt-cache semantics, while continuing to avoid double counting OpenAI cached input tokens. ref https://github.com/braintrustdata/braintrust-spec/blob/main/docs/features/prompt-cache.md
1 parent 4079ffa commit 0ca914e

5 files changed

Lines changed: 135 additions & 58 deletions

File tree

examples/langchain/README.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,20 @@ export OPENAI_API_KEY=...
1111
uv sync
1212
uv run python example.py
1313
```
14+
15+
## Anthropic prompt-cache metrics demo
16+
17+
Run `anthropic_prompt_cache.py` to verify Anthropic cache reads/writes and token totals on real Braintrust spans.
18+
19+
```bash
20+
# Loads BRAINTRUST_API_KEY and ANTHROPIC_API_KEY from ../../.env automatically.
21+
uv sync
22+
uv run python anthropic_prompt_cache.py
23+
```
24+
25+
To inspect the logged spans with the Braintrust CLI:
26+
27+
```bash
28+
bt projects list --json | jq '.[] | select(.name == "py-sdk-demo-langchain-anthropic-cache")'
29+
bt view logs --object-ref project_logs:<project-id> --list-mode spans --limit 10 --json
30+
```
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#!/usr/bin/env python
2+
"""Verify LangChain Anthropic prompt-cache metrics in Braintrust.
3+
4+
This sends two Anthropic requests through LangChain with a cacheable system
5+
prompt. The resulting Braintrust spans should show Anthropic cache reads and
6+
cache writes, including TTL-specific cache creation metrics when Anthropic
7+
returns them.
8+
"""
9+
10+
import os
11+
import uuid
12+
from pathlib import Path
13+
14+
import braintrust
15+
from braintrust.integrations.langchain import BraintrustCallbackHandler
16+
from dotenv import load_dotenv
17+
from langchain_anthropic import ChatAnthropic
18+
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
19+
20+
21+
ROOT = Path(__file__).resolve().parents[2]
22+
load_dotenv(ROOT / ".env")
23+
24+
PROJECT_NAME = os.environ.get("BRAINTRUST_PROJECT", "py-sdk-demo-langchain-anthropic-cache")
25+
MODEL = os.environ.get("ANTHROPIC_MODEL", "claude-sonnet-4-5-20250929")
26+
27+
# Anthropic prompt caching requires a sufficiently long cacheable prefix.
28+
CACHEABLE_SYSTEM_PROMPT = "\n".join(
29+
[
30+
"You are helping validate prompt-cache accounting in an SDK integration.",
31+
"Always answer briefly and mention the requested section title.",
32+
"",
33+
"Reference document:",
34+
*[
35+
f"Section {i}: This paragraph describes stable product guidance, tracing semantics, "
36+
"token accounting, and prompt-cache behavior for repeat requests."
37+
for i in range(1, 90)
38+
],
39+
f"Stable cache key: {os.environ.get('CACHE_DEMO_KEY', 'langchain-anthropic-cache-demo')}",
40+
]
41+
)
42+
43+
44+
def main() -> None:
45+
logger = braintrust.init_logger(project=PROJECT_NAME)
46+
handler = BraintrustCallbackHandler(logger=logger)
47+
model = ChatAnthropic(model=MODEL, max_tokens=64)
48+
49+
messages: list[BaseMessage] = [
50+
SystemMessage(
51+
content=[
52+
{
53+
"type": "text",
54+
"text": CACHEABLE_SYSTEM_PROMPT,
55+
"cache_control": {"type": "ephemeral"},
56+
}
57+
]
58+
),
59+
HumanMessage(content=f"What is this document for? Run id: {uuid.uuid4().hex}"),
60+
]
61+
62+
for label in ("cache write", "cache read"):
63+
result = model.invoke(messages, config={"callbacks": [handler]})
64+
print(f"{label}: {result.content}")
65+
66+
braintrust.flush()
67+
print(f"Logged demo spans to Braintrust project: {PROJECT_NAME}")
68+
69+
70+
if __name__ == "__main__":
71+
main()

examples/langchain/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ description = "LangChain chain traced via BraintrustCallbackHandler"
55
requires-python = ">=3.10"
66
dependencies = [
77
"braintrust",
8+
"langchain-anthropic",
89
"langchain-core",
910
"langchain-openai",
11+
"python-dotenv",
1012
]
1113

1214
[tool.uv.sources]

py/src/braintrust/integrations/langchain/callbacks.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -661,36 +661,40 @@ def _get_metrics_from_response(response: LLMResult):
661661
input_token_details = usage_metadata.get("input_token_details")
662662
if input_token_details and isinstance(input_token_details, dict):
663663
cache_read = input_token_details.get("cache_read")
664-
# langchain-anthropic >= 1.4.0 maps cache_creation_input_tokens to
665-
# ephemeral tier fields (ephemeral_5m_input_tokens, ephemeral_1h_input_tokens)
666-
# rather than the top-level cache_creation field. Sum both for compat.
667664
cache_creation = input_token_details.get("cache_creation")
668-
if not cache_creation and (
669-
"ephemeral_5m_input_tokens" in input_token_details
670-
or "ephemeral_1h_input_tokens" in input_token_details
671-
):
672-
cache_creation = input_token_details.get("ephemeral_5m_input_tokens", 0) + input_token_details.get(
673-
"ephemeral_1h_input_tokens", 0
674-
)
665+
cache_creation_5m = input_token_details.get("ephemeral_5m_input_tokens")
666+
cache_creation_1h = input_token_details.get("ephemeral_1h_input_tokens")
667+
has_cache_creation_breakdown = cache_creation_5m is not None or cache_creation_1h is not None
675668

676669
if cache_read is not None:
677670
metrics["prompt_cached_tokens"] = cache_read
678-
if cache_creation is not None:
679-
metrics["prompt_cache_creation_tokens"] = cache_creation
680-
681-
cache_tokens = (cache_read or 0) + (cache_creation or 0)
671+
if has_cache_creation_breakdown:
672+
# Anthropic exposes TTL-specific cache creation buckets. Preserve the
673+
# split so downstream cost tooling can price 5m vs 1h writes correctly.
674+
if cache_creation_5m is not None:
675+
metrics["prompt_cache_creation_5m_tokens"] = cache_creation_5m
676+
if cache_creation_1h is not None:
677+
metrics["prompt_cache_creation_1h_tokens"] = cache_creation_1h
678+
effective_cache_creation = (cache_creation_5m or 0) + (cache_creation_1h or 0)
679+
else:
680+
if cache_creation is not None:
681+
metrics["prompt_cache_creation_tokens"] = cache_creation
682+
effective_cache_creation = cache_creation or 0
683+
cache_tokens = (cache_read or 0) + effective_cache_creation
682684
prompt_tokens = metrics.get("prompt_tokens")
683685
completion_tokens = metrics.get("completion_tokens")
684686
total_tokens = metrics.get("total_tokens")
685-
if (
686-
cache_tokens
687-
and prompt_tokens is not None
688-
and completion_tokens is not None
689-
and total_tokens == prompt_tokens + completion_tokens
690-
and _cache_tokens_are_separate_from_input_tokens(input_token_details)
691-
):
692-
metrics["prompt_tokens"] = prompt_tokens + cache_tokens
693-
metrics["total_tokens"] = total_tokens + cache_tokens
687+
if prompt_tokens is not None and completion_tokens is not None:
688+
if (
689+
cache_tokens
690+
and total_tokens == prompt_tokens + completion_tokens
691+
and _cache_tokens_are_separate_from_input_tokens(input_token_details)
692+
):
693+
prompt_tokens += cache_tokens
694+
metrics["prompt_tokens"] = prompt_tokens
695+
if total_tokens is not None:
696+
metrics["total_tokens"] = total_tokens + cache_tokens
697+
metrics["tokens"] = prompt_tokens + completion_tokens
694698

695699
if not metrics or not any(metrics.values()):
696700
llm_output: dict[str, Any] = response.llm_output or {}

py/src/braintrust/integrations/langchain/test_callbacks.py

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,10 @@
88
import pytest
99
from braintrust import logger
1010
from braintrust.integrations.langchain import BraintrustCallbackHandler
11-
from braintrust.integrations.langchain.callbacks import _get_metrics_from_response
1211
from braintrust.logger import flush
1312
from braintrust.test_helpers import init_test_logger
1413
from langchain_core.callbacks import BaseCallbackHandler
1514
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
16-
from langchain_core.outputs import ChatGeneration, LLMResult
1715
from langchain_core.prompts import ChatPromptTemplate
1816
from langchain_core.prompts.prompt import PromptTemplate
1917
from langchain_core.runnables import RunnableMap, RunnableSerializable
@@ -908,34 +906,6 @@ def test_streaming_ttft(logger_memory_logger):
908906
)
909907

910908

911-
def test_openai_cached_tokens_are_not_folded_into_prompt_tokens():
912-
response = LLMResult(
913-
generations=[
914-
[
915-
ChatGeneration(
916-
message=AIMessage(
917-
content="Done",
918-
response_metadata={"model_name": "gpt-4o-mini-2024-07-18"},
919-
usage_metadata={
920-
"input_tokens": 1000,
921-
"output_tokens": 200,
922-
"total_tokens": 1200,
923-
"input_token_details": {"cache_read": 500},
924-
},
925-
)
926-
)
927-
]
928-
]
929-
)
930-
931-
assert _get_metrics_from_response(response) == {
932-
"prompt_tokens": 1000,
933-
"completion_tokens": 200,
934-
"total_tokens": 1200,
935-
"prompt_cached_tokens": 500,
936-
}
937-
938-
939909
@pytest.mark.vcr
940910
def test_prompt_caching_tokens(logger_memory_logger):
941911
from langchain_anthropic import ChatAnthropic
@@ -1114,11 +1084,19 @@ def test_prompt_caching_tokens(logger_memory_logger):
11141084
assert "prompt_tokens" in first_metrics
11151085
assert first_metrics["prompt_tokens"] > 0
11161086

1117-
assert "prompt_cache_creation_tokens" in first_metrics
1118-
assert first_metrics["prompt_cache_creation_tokens"] > 0
1087+
first_has_cache_creation_split = (
1088+
"prompt_cache_creation_5m_tokens" in first_metrics or "prompt_cache_creation_1h_tokens" in first_metrics
1089+
)
1090+
first_cache_creation_split = first_metrics.get("prompt_cache_creation_5m_tokens", 0) + first_metrics.get(
1091+
"prompt_cache_creation_1h_tokens", 0
1092+
)
1093+
first_cache_creation_tokens = first_cache_creation_split or first_metrics.get("prompt_cache_creation_tokens", 0)
1094+
assert first_cache_creation_tokens > 0
1095+
if first_has_cache_creation_split:
1096+
assert "prompt_cache_creation_tokens" not in first_metrics
11191097
assert first_metrics["prompt_cached_tokens"] == 0
1120-
assert first_metrics["prompt_tokens"] >= first_metrics["prompt_cache_creation_tokens"]
1121-
assert first_metrics["total_tokens"] == first_metrics["prompt_tokens"] + first_metrics["completion_tokens"]
1098+
assert first_metrics["prompt_tokens"] >= first_cache_creation_tokens
1099+
assert first_metrics["tokens"] == first_metrics["prompt_tokens"] + first_metrics["completion_tokens"]
11221100

11231101
second_metrics = None
11241102
for attempt in range(3):
@@ -1147,9 +1125,14 @@ def test_prompt_caching_tokens(logger_memory_logger):
11471125
time.sleep(1)
11481126

11491127
assert second_metrics is not None
1128+
second_has_cache_creation_split = (
1129+
"prompt_cache_creation_5m_tokens" in second_metrics or "prompt_cache_creation_1h_tokens" in second_metrics
1130+
)
1131+
if second_has_cache_creation_split:
1132+
assert "prompt_cache_creation_tokens" not in second_metrics
11501133
assert second_metrics["prompt_cached_tokens"] > 0
11511134
assert second_metrics["prompt_tokens"] >= second_metrics["prompt_cached_tokens"]
1152-
assert second_metrics["total_tokens"] == second_metrics["prompt_tokens"] + second_metrics["completion_tokens"]
1135+
assert second_metrics["tokens"] == second_metrics["prompt_tokens"] + second_metrics["completion_tokens"]
11531136

11541137

11551138
@pytest.mark.vcr

0 commit comments

Comments
 (0)