diff --git a/packages/server/engine-lib/rocketlib-python/lib/rocketlib/filters.py b/packages/server/engine-lib/rocketlib-python/lib/rocketlib/filters.py index c1842397..66b5cbcd 100644 --- a/packages/server/engine-lib/rocketlib-python/lib/rocketlib/filters.py +++ b/packages/server/engine-lib/rocketlib-python/lib/rocketlib/filters.py @@ -108,6 +108,7 @@ def normalize_tool_input( input_obj: Any, *, extra_envelope_keys: Iterable[str] = (), + strip_keys: Iterable[str] = ('security_context',), parse_json_strings: bool = True, unwrap_pydantic: bool = True, tool_name: str = 'tool', @@ -126,12 +127,16 @@ def normalize_tool_input( whose value is a dict is merged into the top level. Top-level keys win on conflict (so a sibling key beside the envelope overrides the one inside it). - 6. Strip ``security_context`` (engine-injected, never tool args). + 6. Strip every key listed in ``strip_keys``. Args: input_obj: Raw tool input as delivered by the engine's invoke chain. extra_envelope_keys: Additional keys that, like ``input``, wrap the real arguments and should be unwrapped/merged. + strip_keys: Keys to drop from the final dict before returning. + Defaults to ``('security_context',)`` — engine-injected and + never a tool arg. Pass ``()`` to disable stripping, or a list + to add more (e.g. ``('security_context', 'trace_id')``). parse_json_strings: Try ``json.loads`` on string inputs. Set False for tools where the engine path is known never to deliver a JSON-encoded string. @@ -174,7 +179,7 @@ def normalize_tool_input( warning(f'{tool_name}: unexpected input type {type(input_obj).__name__}') return {} - # Shallow-copy so the envelope-merge and ``security_context`` pop below + # Shallow-copy so the envelope-merge and the strip_keys pop below # never mutate a caller-owned dict. input_obj = dict(input_obj) @@ -184,7 +189,8 @@ def normalize_tool_input( extras = {k: v for k, v in input_obj.items() if k != key} input_obj = {**wrapped, **extras} - input_obj.pop('security_context', None) + for key in strip_keys: + input_obj.pop(key, None) return input_obj diff --git a/packages/server/engine-lib/rocketlib-python/tests/test_tool_input_helpers.py b/packages/server/engine-lib/rocketlib-python/tests/test_tool_input_helpers.py index dee85c48..fc0eeb6b 100644 --- a/packages/server/engine-lib/rocketlib-python/tests/test_tool_input_helpers.py +++ b/packages/server/engine-lib/rocketlib-python/tests/test_tool_input_helpers.py @@ -90,7 +90,9 @@ def test_top_level_keys_win_on_conflict(self): result = normalize_tool_input({'input': {'q': 'inner'}, 'q': 'outer'}) assert result == {'q': 'outer'} - def test_security_context_stripped(self): + def test_security_context_stripped_by_default(self): + # ``security_context`` is in the default ``strip_keys`` so callers + # don't have to opt in to engine-injected-key removal. result = normalize_tool_input({'q': 'x', 'security_context': {'user': 'a'}}) assert result == {'q': 'x'} @@ -98,6 +100,37 @@ def test_security_context_stripped_from_inside_input_envelope(self): result = normalize_tool_input({'input': {'q': 'x', 'security_context': {'user': 'a'}}}) assert result == {'q': 'x'} + def test_strip_keys_disabled_keeps_security_context(self): + # Pass an empty ``strip_keys`` to disable the default stripping — + # ``security_context`` is preserved verbatim. + result = normalize_tool_input( + {'q': 'x', 'security_context': {'user': 'a'}}, + strip_keys=(), + ) + assert result == {'q': 'x', 'security_context': {'user': 'a'}} + + def test_strip_keys_custom_replaces_default(self): + # ``strip_keys`` is a replacement, not additive: when the caller + # supplies their own list, ``security_context`` is no longer + # stripped unless the caller includes it. + result = normalize_tool_input( + {'q': 'x', 'security_context': {'user': 'a'}, 'trace_id': 'abc'}, + strip_keys=('trace_id',), + ) + assert result == {'q': 'x', 'security_context': {'user': 'a'}} + + def test_strip_keys_can_drop_multiple(self): + result = normalize_tool_input( + {'q': 'x', 'security_context': {}, 'trace_id': 'abc', 'session': 'z'}, + strip_keys=('security_context', 'trace_id', 'session'), + ) + assert result == {'q': 'x'} + + def test_strip_keys_missing_keys_silently_ignored(self): + # pop(key, None) — listing a key that isn't present is a no-op. + result = normalize_tool_input({'q': 'x'}, strip_keys=('not_present',)) + assert result == {'q': 'x'} + def test_non_dict_input_envelope_left_alone(self): # If 'input' isn't a dict (e.g. an int), the helper should not crash # and should not unwrap — the value stays at the top level.