Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

import json
import logging
import math
from collections import OrderedDict
from collections.abc import AsyncIterator
from typing import Any

Expand All @@ -43,6 +45,21 @@

logger = logging.getLogger(__name__)

# Lower bound on fuzzy-match similarity to reduce the cache-poisoning surface.
# A threshold below this makes it trivial to craft an input that is "similar
# enough" to a legitimate user's cached key to hijack their response (for the
# current process, which is how the in-memory cache is scoped). 0.85 is the
# smallest value that we're comfortable shipping as an unconditional default;
# operators with strict needs should use 1.0 (exact match only).
_MIN_FUZZY_THRESHOLD = 0.85

# Default bound on cache size. The previous implementation used an unbounded
# dict which, under sustained unique input, grew without limit — a memory-
# exhaustion DoS and, combined with fuzzy matching, a long-lived surface for
# cross-request confusion. OrderedDict-backed LRU evicts the oldest entry
# when the cache exceeds this bound.
_DEFAULT_MAX_CACHE_ENTRIES = 1024
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is needed as a global const, the default value should be in the config class, since we only construct CacheMiddleware from the config object, the constructor similarly doesn't need a default value.



class CacheMiddleware(FunctionMiddleware):
"""Cache middleware that memoizes function outputs based on input similarity.
Expand All @@ -67,19 +84,62 @@ class CacheMiddleware(FunctionMiddleware):
computation.
"""

def __init__(self, *, enabled_mode: str, similarity_threshold: float) -> None:
def __init__(
self,
*,
enabled_mode: str,
similarity_threshold: float,
max_entries: int = _DEFAULT_MAX_CACHE_ENTRIES,
) -> None:
"""Initialize the cache middleware.

Args:
enabled_mode: Either "always" or "eval". If "eval", only caches
when Context.is_evaluating is True.
similarity_threshold: Similarity threshold between 0 and 1.
If 1.0, performs exact matching. Otherwise uses fuzzy matching.
similarity_threshold: Similarity threshold in [_MIN_FUZZY_THRESHOLD, 1.0].
If 1.0, performs exact matching. Lower values enable difflib-
based fuzzy matching. Values below _MIN_FUZZY_THRESHOLD are
rejected to prevent cache-poisoning where a crafted input
collides with a legitimate user's cached key.
max_entries: Maximum number of cache entries. When exceeded, the
oldest entry is evicted (LRU). Defaults to
_DEFAULT_MAX_CACHE_ENTRIES. Must be >= 1.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

Raises:
ValueError: If similarity_threshold is outside [_MIN_FUZZY_THRESHOLD, 1.0]
or max_entries is not a positive integer.
"""
# Reject bool explicitly — `isinstance(True, int)` is True in Python,
# and `True`/`False` silently sneaking through as numeric is a classic
# config bug (user passes the wrong key, gets no error). Check bool
# FIRST so the "must be a number" message doesn't lie.
if isinstance(similarity_threshold, bool):
raise ValueError(
f"similarity_threshold must be a number, got bool ({similarity_threshold!r})")
if not isinstance(similarity_threshold, (int, float)):
raise ValueError(
f"similarity_threshold must be a number, got {type(similarity_threshold).__name__}")
if not math.isfinite(similarity_threshold):
raise ValueError(
f"similarity_threshold must be finite, got {similarity_threshold!r}")
if similarity_threshold < _MIN_FUZZY_THRESHOLD or similarity_threshold > 1.0:
raise ValueError(
f"similarity_threshold={similarity_threshold} is outside the safe range "
f"[{_MIN_FUZZY_THRESHOLD}, 1.0]. Lower values make cache-poisoning trivial — "
"a crafted input can collide with a legitimate user's cached key. Use 1.0 "
"for exact matching (recommended), or a value >= "
f"{_MIN_FUZZY_THRESHOLD} for fuzzy matching.")
# Same bool-as-int foot-gun applies to max_entries.
if isinstance(max_entries, bool) or not isinstance(max_entries, int) or max_entries < 1:
raise ValueError(f"max_entries must be a positive integer, got {max_entries!r}")

super().__init__(is_final=True)
self._enabled_mode = enabled_mode
self._similarity_threshold = similarity_threshold
self._cache: dict[str, Any] = {}
# OrderedDict gives O(1) LRU: move_to_end() on hit, popitem(last=False)
# to evict the oldest when we exceed max_entries.
self._cache: OrderedDict[str, Any] = OrderedDict()
self._max_entries = max_entries

# ==================== Abstract Method Implementations ====================

Expand Down Expand Up @@ -199,20 +259,31 @@ async def function_middleware_invoke(self,
# Phase 1: Preprocess - look for a similar cached input
similar_key = self._find_similar_key(input_str)
if similar_key is not None:
# Cache hit - short-circuit and return cached output
# Cache hit - short-circuit and return cached output.
# Move the hit entry to the MRU end so LRU eviction prefers truly
# old entries, not just recently-useful ones.
logger.debug("Cache hit for function %s with similarity %.2f",
context.name,
1.0 if similar_key == input_str else self._similarity_threshold)
self._cache.move_to_end(similar_key)
# Phase 4: Continue - return cached result
return self._cache[similar_key]

# Phase 2: Call next - no cache hit, call next middleware/function
logger.debug("Cache miss for function %s", context.name)
result = await call_next(*args, **kwargs)

# Phase 3: Postprocess - cache the result for future use
# Phase 3: Postprocess - cache the result for future use. Enforce the
# LRU bound BEFORE insert so the new entry always lands in a cache of
# size <= max_entries, preventing unbounded memory growth (DoS).
self._cache[input_str] = result
logger.debug("Cached result for function %s", context.name)
self._cache.move_to_end(input_str)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this call to self._cache.move_to_end(input_str) needed? If it was just inserted on line 236, shouldn't it already be at the end?

while len(self._cache) > self._max_entries:
self._cache.popitem(last=False)
logger.debug("Cached result for function %s (size=%d/%d)",
context.name,
len(self._cache),
self._max_entries)

# Phase 4: Continue - return the fresh result
return result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from pydantic import Field

from nat.data_models.middleware import FunctionMiddlewareBaseConfig
from nat.middleware.cache.cache_middleware import _DEFAULT_MAX_CACHE_ENTRIES
from nat.middleware.cache.cache_middleware import _MIN_FUZZY_THRESHOLD


class CacheMiddlewareConfig(FunctionMiddlewareBaseConfig, name="cache"):
Expand All @@ -31,14 +33,33 @@ class CacheMiddlewareConfig(FunctionMiddlewareBaseConfig, name="cache"):
enabled_mode: Controls when caching is active:
- "always": Cache is always enabled
- "eval": Cache only active when Context.is_evaluating is True
similarity_threshold: Float between 0 and 1 for input matching:
- 1.0: Exact string matching (fastest)
- < 1.0: Fuzzy matching using difflib similarity
similarity_threshold: Float in [_MIN_FUZZY_THRESHOLD, 1.0] for input
matching:
- 1.0: Exact string matching (fastest, recommended)
- >= _MIN_FUZZY_THRESHOLD: Fuzzy matching via difflib. Values
below this bound are rejected as a cache-poisoning risk —
crafted inputs at lower thresholds can collide with a
legitimate user's cached key.
max_entries: Upper bound on cached entries. When exceeded, the
least-recently-used entry is evicted. Must be a positive int;
defaults to _DEFAULT_MAX_CACHE_ENTRIES.
"""

enabled_mode: Literal["always", "eval"] = Field(
default="eval", description="When caching is enabled: 'always' or 'eval' (only during evaluation)")
similarity_threshold: float = Field(default=1.0,
ge=0.0,
le=1.0,
description="Similarity threshold between 0 and 1. Use 1.0 for exact matching")
similarity_threshold: float = Field(
default=1.0,
ge=_MIN_FUZZY_THRESHOLD,
le=1.0,
description=(
f"Similarity threshold in [{_MIN_FUZZY_THRESHOLD}, 1.0]. Use 1.0 for exact matching "
"(recommended). Lower values enable fuzzy matching but are bounded below to prevent "
"cache-poisoning collisions with legitimate cached keys."),
)
max_entries: int = Field(
default=_DEFAULT_MAX_CACHE_ENTRIES,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
default=_DEFAULT_MAX_CACHE_ENTRIES,
default=1024,

Remove the const global, 1024 isn't a number that needs to be a const.

ge=1,
description=("Maximum number of cache entries before LRU eviction. Must be >= 1. "
"Prevents memory-exhaustion DoS from unbounded cache growth under "
"sustained unique inputs."),
)
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def test_default_initialization(self):

def test_custom_initialization(self):
"""Test custom initialization."""
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.8)
# Use 0.9 (above the enforced minimum) to exercise non-default fuzzy mode.
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.9)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert, I don't think this is needed

# Check attributes are set
assert hasattr(middleware, '_enabled_mode')
assert hasattr(middleware, '_similarity_threshold')
Expand Down Expand Up @@ -108,8 +109,12 @@ async def mock_next_call(*args, **kwargs):
assert result3.result == "Result for test"

async def test_fuzzy_match_caching(self, middleware_context):
"""Test fuzzy matching with similarity_threshold < 1.0."""
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.8)
"""Test fuzzy matching with similarity_threshold < 1.0.

Uses 0.9 (above the enforced minimum) — 0.8 is no longer a valid
threshold after the cache-poisoning hardening.
"""
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.9)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, not needed


call_count = 0

Expand Down Expand Up @@ -267,8 +272,10 @@ async def mock_next_call(*args, **kwargs):

def test_similarity_computation_for_different_thresholds(self):
"""Test similarity computation for different thresholds."""
# This is more of a unit test for the similarity logic
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.5)
# This is more of a unit test for the similarity logic.
# Uses 0.9 (above the enforced minimum) to exercise fuzzy matching
# without enabling cache-poisoning-prone low thresholds.
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.9)

# Directly test internal methods
# Add a cached entry
Expand All @@ -278,14 +285,18 @@ def test_similarity_computation_for_different_thresholds(self):
# Test various similarity levels
# Exact match
assert middleware._find_similar_key(test_key) == test_key # noqa
# Very similar
# Very similar (one char shorter, ~0.95 ratio)
assert middleware._find_similar_key("hello worl") == test_key # noqa
# Too different - use a completely different string
assert middleware._find_similar_key("xyz123abc") is None # noqa

async def test_multiple_similar_entries(self, middleware_context):
"""Test behavior with multiple similar cached entries."""
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.7)
"""Test behavior with multiple similar cached entries.

Uses 0.85 (the enforced minimum) instead of the original 0.7 —
below 0.85 is now rejected as a cache-poisoning risk.
"""
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.85)

# Pre-populate cache with similar entries
key1 = middleware._serialize_input( # noqa
Expand All @@ -306,3 +317,132 @@ async def mock_next_call(*args, **kwargs):
input_str = {"value": "test input X", "number": 42}
await middleware.function_middleware_invoke(input_str, call_next=mock_next_call, context=middleware_context)
# The exact behavior depends on which cached key is most similar


class TestSimilarityThresholdFloor:
"""The constructor must reject similarity thresholds below the safe floor.

Below ~0.85, crafting an input whose difflib ratio exceeds the threshold
against a legitimate cached key is trivial (small edits, common prefixes,
shared structural tokens). Accepting those values silently produces a
cache where one caller can hijack another caller's response.
"""

@pytest.mark.parametrize("threshold", [0.0, 0.3, 0.5, 0.7, 0.84])
def test_below_floor_is_rejected(self, threshold):
with pytest.raises(ValueError, match="outside the safe range"):
CacheMiddleware(enabled_mode="always", similarity_threshold=threshold)

@pytest.mark.parametrize("threshold", [0.85, 0.9, 0.95, 1.0])
def test_at_or_above_floor_is_allowed(self, threshold):
mw = CacheMiddleware(enabled_mode="always", similarity_threshold=threshold)
assert mw._similarity_threshold == threshold # noqa: SLF001

def test_threshold_above_one_is_rejected(self):
with pytest.raises(ValueError, match="outside the safe range"):
CacheMiddleware(enabled_mode="always", similarity_threshold=1.5)

def test_threshold_non_numeric_is_rejected(self):
with pytest.raises(ValueError, match="must be a number"):
CacheMiddleware(enabled_mode="always", similarity_threshold="high") # type: ignore[arg-type]

@pytest.mark.parametrize("bad_bool", [True, False])
def test_threshold_bool_is_rejected(self, bad_bool):
"""`isinstance(True, int)` is True in Python — reject bools explicitly
so a config with the wrong key type doesn't silently become 1.0 or 0.0."""
with pytest.raises(ValueError, match="got bool"):
CacheMiddleware(enabled_mode="always", similarity_threshold=bad_bool) # type: ignore[arg-type]

@pytest.mark.parametrize("bad_value", [float("nan"), float("inf"), float("-inf")])
def test_threshold_non_finite_is_rejected(self, bad_value):
"""NaN, +inf, -inf must be rejected before the range comparison."""
with pytest.raises(ValueError, match="must be finite"):
CacheMiddleware(enabled_mode="always", similarity_threshold=bad_value)


class TestMaxEntriesLruEviction:
"""The cache must bound its size to prevent memory-exhaustion DoS.

The previous implementation used an unbounded dict; sustained unique
inputs would grow the cache without limit, eventually crashing the
process. LRU eviction ensures the cache stays within max_entries.
"""

async def test_default_max_entries_is_positive(self):
mw = CacheMiddleware(enabled_mode="always", similarity_threshold=1.0)
assert mw._max_entries > 0 # noqa: SLF001

def test_zero_max_entries_is_rejected(self):
with pytest.raises(ValueError, match="positive integer"):
CacheMiddleware(enabled_mode="always", similarity_threshold=1.0, max_entries=0)

def test_negative_max_entries_is_rejected(self):
with pytest.raises(ValueError, match="positive integer"):
CacheMiddleware(enabled_mode="always", similarity_threshold=1.0, max_entries=-5)

@pytest.mark.parametrize("bad_bool", [True, False])
def test_bool_max_entries_is_rejected(self, bad_bool):
"""Same bool-as-int foot-gun protection as similarity_threshold."""
with pytest.raises(ValueError, match="positive integer"):
CacheMiddleware(
enabled_mode="always",
similarity_threshold=1.0,
max_entries=bad_bool, # type: ignore[arg-type]
)

async def test_cache_evicts_oldest_when_exceeding_max_entries(self, middleware_context):
"""Insert more unique entries than max_entries; verify size stays bounded."""
mw = CacheMiddleware(
enabled_mode="always",
similarity_threshold=1.0, # exact match keeps the test deterministic
max_entries=3,
)

call_count = 0

async def mock_next_call(*_args, **_kwargs):
nonlocal call_count
call_count += 1
return _TestOutput(result=f"result_{call_count}")

for i in range(10):
await mw.function_middleware_invoke(
{"value": f"unique_input_{i}"},
call_next=mock_next_call,
context=middleware_context,
)

assert len(mw._cache) == 3 # noqa: SLF001
# The MOST recent three inserts should be what's left.
latest_keys = list(mw._cache.keys()) # noqa: SLF001
for i in range(7, 10):
assert any(f"unique_input_{i}" in k for k in latest_keys)

async def test_cache_hit_promotes_entry_to_most_recently_used(self, middleware_context):
"""A cache hit should move the entry to MRU so later evictions spare it."""
mw = CacheMiddleware(
enabled_mode="always",
similarity_threshold=1.0,
max_entries=3,
)

async def mock_next_call(*_args, **_kwargs):
return _TestOutput(result="r")

# Fill the cache with A, B, C (A is oldest)
for key in ("A", "B", "C"):
await mw.function_middleware_invoke(
{"value": key}, call_next=mock_next_call, context=middleware_context)

# Hit A again — should promote A to the MRU end
await mw.function_middleware_invoke(
{"value": "A"}, call_next=mock_next_call, context=middleware_context)

# Now insert D — B (now oldest) should be evicted, not A.
await mw.function_middleware_invoke(
{"value": "D"}, call_next=mock_next_call, context=middleware_context)

keys = "".join(list(mw._cache.keys())) # noqa: SLF001
assert '"value": "A"' in keys
assert '"value": "D"' in keys
assert '"value": "B"' not in keys