-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Add ToolAwareContextFilterPlugin to preserve tool call sequences #4074
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
f755dca
7e130a3
1213b65
f1677d9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,269 @@ | ||||||
| # Copyright 2025 Google LLC | ||||||
| # | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
| # you may not use this file except in compliance with the License. | ||||||
| # You may obtain a copy of the License at | ||||||
| # | ||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||
| # | ||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | ||||||
|
|
||||||
| """Tool-aware context filter plugin for managing conversation history. | ||||||
|
|
||||||
| This plugin extends the standard context filtering to properly handle function | ||||||
| call/response sequences, ensuring they remain atomic during history trimming. | ||||||
|
|
||||||
| PROBLEM WITH STANDARD ContextFilterPlugin: | ||||||
| ========================================== | ||||||
| The standard ContextFilterPlugin treats each model message as a separate | ||||||
| "invocation", but when a model makes a tool call, it creates MULTIPLE model | ||||||
| messages in sequence: | ||||||
| 1. Model message with function_call | ||||||
| 2. User message with function_response (tool result) | ||||||
| 3. Model message with final text response | ||||||
|
|
||||||
| When filtering to keep N "invocations", the standard plugin can split these | ||||||
| related messages apart, creating orphaned function_responses without their | ||||||
| corresponding function_calls, which violates OpenAI API requirements. | ||||||
|
|
||||||
| HOW THIS PLUGIN SOLVES IT: | ||||||
| =========================== | ||||||
| This plugin groups messages into LOGICAL invocations where a complete cycle is: | ||||||
| - User query (one or more messages) | ||||||
| - Model response (possibly with function_call) | ||||||
| - Function response(s) (if tool was called) | ||||||
| - Model final response (after tool execution) | ||||||
|
|
||||||
| All messages in a tool call sequence are kept together as an atomic unit. | ||||||
| """ | ||||||
|
|
||||||
| from __future__ import annotations | ||||||
|
|
||||||
| import logging | ||||||
| from typing import Callable, List, Optional | ||||||
|
|
||||||
| from google.adk.agents.callback_context import CallbackContext | ||||||
| from google.adk.events.event import Event | ||||||
| from google.adk.models.llm_request import LlmRequest | ||||||
| from google.adk.models.llm_response import LlmResponse | ||||||
| from google.adk.plugins.base_plugin import BasePlugin | ||||||
|
|
||||||
| logger = logging.getLogger("google_adk." + __name__) | ||||||
|
|
||||||
|
|
||||||
| class ToolAwareContextFilterPlugin(BasePlugin): | ||||||
| """A plugin that filters LLM context while preserving tool call sequences. | ||||||
|
|
||||||
| This plugin extends context filtering to handle function call/response pairs | ||||||
| correctly, ensuring they are never split during history trimming. | ||||||
| """ | ||||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
| num_invocations_to_keep: Optional[int] = None, | ||||||
| custom_filter: Optional[Callable[[List[Event]], List[Event]]] = None, | ||||||
| name: str = "tool_aware_context_filter_plugin", | ||||||
| ): | ||||||
| """Initializes the tool-aware context filter plugin. | ||||||
|
|
||||||
| Args: | ||||||
| num_invocations_to_keep: The number of last invocations to keep. An | ||||||
| invocation is defined as a complete user-model interaction cycle, | ||||||
| including any tool calls and their responses. | ||||||
| custom_filter: A function to apply additional filtering to the context. | ||||||
| name: The name of the plugin instance. | ||||||
| """ | ||||||
| super().__init__(name) | ||||||
| self._num_invocations_to_keep = num_invocations_to_keep | ||||||
| self._custom_filter = custom_filter | ||||||
|
|
||||||
| @staticmethod | ||||||
| def _has_function_call(content) -> bool: | ||||||
| """Check if a content has a function_call part.""" | ||||||
| if not content.parts: | ||||||
| return False | ||||||
| return any( | ||||||
| hasattr(part, "function_call") and part.function_call | ||||||
| for part in content.parts | ||||||
| ) | ||||||
|
|
||||||
| @staticmethod | ||||||
| def _has_function_response(content) -> bool: | ||||||
|
Comment on lines
+86
to
+97
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The methods |
||||||
| """Check if a content has a function_response part.""" | ||||||
| if not content.parts: | ||||||
| return False | ||||||
| return any( | ||||||
| hasattr(part, "function_response") and part.function_response | ||||||
| for part in content.parts | ||||||
| ) | ||||||
|
|
||||||
| def _group_into_invocations(self, contents: List) -> List[List[int]]: | ||||||
|
||||||
| def _group_into_invocations(self, contents: List) -> List[List[int]]: | |
| def _group_into_invocations(self, contents: List[types.Content]) -> List[List[int]]: |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The _group_into_invocations method is quite long and contains complex nested logic, making it hard to follow. For improved maintainability and readability, consider refactoring it by breaking it down into smaller, more focused helper methods. For example, you could have separate methods for processing 'user' messages and 'model' messages. This would make the main loop simpler and the logic for each case easier to follow and test in isolation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type hint for
custom_filterisOptional[Callable[[List[Event]], List[Event]]]. However, the filter is applied tollm_request.contents, which is of typeList[types.Content]. The type hint should beOptional[Callable[[List[types.Content]], List[types.Content]]]to match the actual usage. You will also need to addfrom google.genai import typesat the top of the file.