Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import json
import logging
from collections import OrderedDict
from collections.abc import AsyncIterator
from typing import Any

Expand All @@ -44,6 +45,7 @@
logger = logging.getLogger(__name__)



class CacheMiddleware(FunctionMiddleware):
"""Cache middleware that memoizes function outputs based on input similarity.

Expand All @@ -67,19 +69,35 @@ class CacheMiddleware(FunctionMiddleware):
computation.
"""

def __init__(
    self,
    *,
    enabled_mode: str,
    similarity_threshold: float,
    max_entries: int = 1024,
) -> None:
    """Initialize the cache middleware.

    Args:
        enabled_mode: Either "always" or "eval". If "eval", only caches
            when Context.is_evaluating is True.
        similarity_threshold: Similarity threshold in [0, 1.0]. If 1.0,
            performs exact matching. Lower values enable difflib-based
            fuzzy matching; note that difflib is quadratic in the worst
            case, so large caches with low thresholds may have a
            performance cost. Values near 0 increase the risk of cache
            collisions where different inputs return the same cached
            response.
        max_entries: Maximum number of cache entries. When exceeded, the
            least-recently-used entry is evicted. Must be >= 1. Defaults
            to 1024 so callers written against the earlier two-argument
            signature keep working.

    Raises:
        ValueError: If ``max_entries`` is smaller than 1.
    """
    if max_entries < 1:
        # A negative bound would make the eviction loop call popitem() on
        # an already-empty OrderedDict, which raises KeyError; fail fast
        # with a clear message instead.
        raise ValueError(f"max_entries must be >= 1, got {max_entries}")

    super().__init__(is_final=True)
    self._enabled_mode = enabled_mode
    self._similarity_threshold = similarity_threshold
    # OrderedDict gives O(1) LRU: move_to_end() on hit, popitem(last=False)
    # to evict the oldest when we exceed max_entries.
    self._cache: OrderedDict[str, Any] = OrderedDict()
    self._max_entries = max_entries

# ==================== Abstract Method Implementations ====================

Expand Down Expand Up @@ -142,22 +160,13 @@ def _find_similar_key(self, input_str: str) -> str | None:
# Exact matching - fast path
return input_str if input_str in self._cache else None

# Fuzzy matching using difflib
import difflib

best_match = None
best_ratio = 0.0

for cached_key in self._cache:
# Use SequenceMatcher for similarity computation
matcher = difflib.SequenceMatcher(None, input_str, cached_key)
ratio = matcher.ratio()

if ratio >= self._similarity_threshold and ratio > best_ratio:
best_ratio = ratio
best_match = cached_key

return best_match
best_matches = difflib.get_close_matches(
input_str, self._cache.keys(), n=1, cutoff=self._similarity_threshold)
if best_matches:
return best_matches[0]
return None

async def function_middleware_invoke(self,
*args: Any,
Expand Down Expand Up @@ -199,20 +208,30 @@ async def function_middleware_invoke(self,
# Phase 1: Preprocess - look for a similar cached input
similar_key = self._find_similar_key(input_str)
if similar_key is not None:
# Cache hit - short-circuit and return cached output
# Cache hit - short-circuit and return cached output.
# Move the hit entry to the MRU end so LRU eviction prefers truly
# old entries, not just recently-useful ones.
logger.debug("Cache hit for function %s with similarity %.2f",
context.name,
1.0 if similar_key == input_str else self._similarity_threshold)
self._cache.move_to_end(similar_key)
# Phase 4: Continue - return cached result
return self._cache[similar_key]

# Phase 2: Call next - no cache hit, call next middleware/function
logger.debug("Cache miss for function %s", context.name)
result = await call_next(*args, **kwargs)

# Phase 3: Postprocess - cache the result for future use
# Phase 3: Postprocess - cache the result for future use. Insert first,
# then enforce the LRU bound so the cache stays within max_entries,
# preventing unbounded memory growth (DoS).
self._cache[input_str] = result
logger.debug("Cached result for function %s", context.name)
while len(self._cache) > self._max_entries:
self._cache.popitem(last=False)
logger.debug("Cached result for function %s (size=%d/%d)",
context.name,
len(self._cache),
self._max_entries)

# Phase 4: Continue - return the fresh result
return result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from nat.data_models.middleware import FunctionMiddlewareBaseConfig


class CacheMiddlewareConfig(FunctionMiddlewareBaseConfig, name="cache"):
"""Configuration for cache middleware.

Expand All @@ -31,14 +30,35 @@ class CacheMiddlewareConfig(FunctionMiddlewareBaseConfig, name="cache"):
enabled_mode: Controls when caching is active:
- "always": Cache is always enabled
- "eval": Cache only active when Context.is_evaluating is True
similarity_threshold: Float between 0 and 1 for input matching:
- 1.0: Exact string matching (fastest)
- < 1.0: Fuzzy matching using difflib similarity
similarity_threshold: Float in [0, 1.0] for input matching:
- 1.0: Exact string matching (fastest, recommended)
- < 1.0: Fuzzy matching via difflib. Note that difflib is
quadratic in the worst case, so large caches with low
thresholds may have a performance cost. Values near 0
increase the risk of cache collisions where different
inputs return the same cached response.
max_entries: Upper bound on cached entries. When exceeded, the
least-recently-used entry is evicted. Must be a positive int;
defaults to 1024.
"""

enabled_mode: Literal["always", "eval"] = Field(
default="eval", description="When caching is enabled: 'always' or 'eval' (only during evaluation)")
similarity_threshold: float = Field(default=1.0,
ge=0.0,
le=1.0,
description="Similarity threshold between 0 and 1. Use 1.0 for exact matching")
similarity_threshold: float = Field(
default=1.0,
ge=0,
le=1.0,
description=(
"Similarity threshold in [0, 1.0]. Use 1.0 for exact matching (recommended). "
"Lower values enable fuzzy matching via difflib; note that difflib is quadratic "
"in the worst case, so large caches with low thresholds may have a performance "
"cost. Values near 0 increase the risk of cache collisions where different "
"inputs return the same cached response."),
)
max_entries: int = Field(
default=1024,
ge=1,
description=("Maximum number of cache entries before LRU eviction. Must be >= 1. "
"Prevents memory-exhaustion DoS from unbounded cache growth under "
"sustained unique inputs."),
)
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,8 @@ async def cache_middleware(config: CacheMiddlewareConfig, builder: Builder):
Yields:
A configured cache middleware instance
"""
yield CacheMiddleware(enabled_mode=config.enabled_mode, similarity_threshold=config.similarity_threshold)
yield CacheMiddleware(
enabled_mode=config.enabled_mode,
similarity_threshold=config.similarity_threshold,
max_entries=config.max_entries,
)
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ class TestCacheMiddlewareInitialization:

def test_default_initialization(self):
    """Test initialization with the required parameters.

    Verifies that the constructor stores each argument on the instance and
    marks the middleware as final.
    """
    middleware = CacheMiddleware(enabled_mode="eval", similarity_threshold=1.0, max_entries=1024)
    # Check the stored values, not merely that the attributes exist.
    assert middleware._enabled_mode == "eval"  # noqa: SLF001
    assert middleware._similarity_threshold == 1.0  # noqa: SLF001
    assert middleware._max_entries == 1024  # noqa: SLF001
    assert middleware.is_final is True

def test_custom_initialization(self):
"""Test custom initialization."""
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.8)
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.9, max_entries=1024)
# Check attributes are set
assert hasattr(middleware, '_enabled_mode')
assert hasattr(middleware, '_similarity_threshold')
Expand All @@ -73,7 +73,7 @@ class TestCacheMiddlewareCaching:

async def test_exact_match_caching(self, middleware_context):
"""Test exact match caching with similarity_threshold=1.0."""
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=1.0)
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=1.0, max_entries=1024)

# Mock the next call
call_count = 0
Expand Down Expand Up @@ -109,7 +109,7 @@ async def mock_next_call(*args, **kwargs):

async def test_fuzzy_match_caching(self, middleware_context):
"""Test fuzzy matching with similarity_threshold < 1.0."""
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.8)
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.9, max_entries=1024)

call_count = 0

Expand Down Expand Up @@ -144,7 +144,7 @@ async def mock_next_call(*args, **kwargs):

async def test_eval_mode_caching(self, middleware_context):
"""Test caching only works in eval mode when configured."""
middleware = CacheMiddleware(enabled_mode="eval", similarity_threshold=1.0)
middleware = CacheMiddleware(enabled_mode="eval", similarity_threshold=1.0, max_entries=1024)

call_count = 0

Expand Down Expand Up @@ -183,7 +183,7 @@ async def mock_next_call(*args, **kwargs):

async def test_serialization_failure(self, middleware_context):
"""Test behavior when input serialization fails."""
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=1.0)
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=1.0, max_entries=1024)

call_count = 0

Expand Down Expand Up @@ -214,7 +214,7 @@ class TestCacheMiddlewareStreaming:

async def test_streaming_bypass(self, middleware_context):
"""Test that streaming always bypasses cache."""
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=1.0)
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=1.0, max_entries=1024)

call_count = 0

Expand Down Expand Up @@ -249,7 +249,7 @@ class TestCacheMiddlewareEdgeCases:

async def test_context_retrieval_failure(self, middleware_context):
"""Test behavior when context retrieval fails in eval mode."""
middleware = CacheMiddleware(enabled_mode="eval", similarity_threshold=1.0)
middleware = CacheMiddleware(enabled_mode="eval", similarity_threshold=1.0, max_entries=1024)

call_count = 0

Expand All @@ -267,8 +267,7 @@ async def mock_next_call(*args, **kwargs):

def test_similarity_computation_for_different_thresholds(self):
"""Test similarity computation for different thresholds."""
# This is more of a unit test for the similarity logic
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.5)
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.9, max_entries=1024)

# Directly test internal methods
# Add a cached entry
Expand All @@ -278,14 +277,14 @@ def test_similarity_computation_for_different_thresholds(self):
# Test various similarity levels
# Exact match
assert middleware._find_similar_key(test_key) == test_key # noqa
# Very similar
# Very similar (one char shorter, ~0.95 ratio)
assert middleware._find_similar_key("hello worl") == test_key # noqa
# Too different - use a completely different string
assert middleware._find_similar_key("xyz123abc") is None # noqa

async def test_multiple_similar_entries(self, middleware_context):
"""Test behavior with multiple similar cached entries."""
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.7)
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.85, max_entries=1024)

# Pre-populate cache with similar entries
key1 = middleware._serialize_input( # noqa
Expand All @@ -306,3 +305,73 @@ async def mock_next_call(*args, **kwargs):
input_str = {"value": "test input X", "number": 42}
await middleware.function_middleware_invoke(input_str, call_next=mock_next_call, context=middleware_context)
# The exact behavior depends on which cached key is most similar


class TestMaxEntriesLruEviction:
    """The cache must bound its size to prevent memory-exhaustion DoS.

    The previous implementation used an unbounded dict; sustained unique
    inputs would grow the cache without limit, eventually crashing the
    process. LRU eviction ensures the cache stays within max_entries.
    """

    def test_default_max_entries_is_positive(self):
        """The configured bound is stored on the instance and is positive."""
        mw = CacheMiddleware(enabled_mode="always", similarity_threshold=1.0, max_entries=1024)
        assert mw._max_entries == 1024  # noqa: SLF001
        assert mw._max_entries > 0  # noqa: SLF001

    async def test_cache_evicts_oldest_when_exceeding_max_entries(self, middleware_context):
        """Insert more unique entries than max_entries; verify size stays bounded."""
        mw = CacheMiddleware(
            enabled_mode="always",
            similarity_threshold=1.0,  # exact match keeps the test deterministic
            max_entries=3,
        )

        call_count = 0

        async def mock_next_call(*_args, **_kwargs):
            nonlocal call_count
            call_count += 1
            return _TestOutput(result=f"result_{call_count}")

        for i in range(10):
            await mw.function_middleware_invoke(
                {"value": f"unique_input_{i}"},
                call_next=mock_next_call,
                context=middleware_context,
            )

        # Every input was unique, so each invocation must have been a miss.
        assert call_count == 10

        remaining_keys = list(mw._cache)  # noqa: SLF001
        assert len(remaining_keys) == 3
        # The MOST recent three inserts should be what's left...
        for i in range(7, 10):
            assert any(f"unique_input_{i}" in k for k in remaining_keys)
        # ...and every older insert must have been evicted.
        for i in range(7):
            assert not any(f"unique_input_{i}" in k for k in remaining_keys)

    async def test_cache_hit_promotes_entry_to_most_recently_used(self, middleware_context):
        """A cache hit should move the entry to MRU so later evictions spare it."""
        mw = CacheMiddleware(
            enabled_mode="always",
            similarity_threshold=1.0,
            max_entries=3,
        )

        call_count = 0

        async def mock_next_call(*_args, **_kwargs):
            nonlocal call_count
            call_count += 1
            return _TestOutput(result="r")

        # Fill the cache with A, B, C (A is oldest)
        for key in ("A", "B", "C"):
            await mw.function_middleware_invoke(
                {"value": key}, call_next=mock_next_call, context=middleware_context)
        assert call_count == 3

        # Hit A again — should promote A to the MRU end
        await mw.function_middleware_invoke(
            {"value": "A"}, call_next=mock_next_call, context=middleware_context)
        # The repeat must be served from the cache, not by calling through.
        assert call_count == 3

        # Now insert D — B (now oldest) should be evicted, not A.
        await mw.function_middleware_invoke(
            {"value": "D"}, call_next=mock_next_call, context=middleware_context)

        assert len(mw._cache) == 3  # noqa: SLF001
        keys = "".join(mw._cache)  # noqa: SLF001
        assert '"value": "A"' in keys
        assert '"value": "C"' in keys
        assert '"value": "D"' in keys
        assert '"value": "B"' not in keys