Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,30 @@ async def ainvoke(self, **kwargs: Any) -> str:
message = f"Failed to invoke tool '{self.name}' with args: {kwargs} , got error: {e!s}"
raise MCPInvocationError(message, self.name, kwargs) from e

def _get_valid_inputs(self) -> set[str]:
"""
Return the set of valid input parameter names from the MCP tool schema.

Used to validate that `inputs_from_state` only references parameters that actually exist.
Unlike the default implementation that introspects the function signature,
this returns parameters from the MCP tool's JSON schema.

When eager_connect=False and we have placeholder parameters, returns an empty set
to skip validation until warm_up() is called.

:returns: Set of valid input parameter names from the MCP tool schema.
"""
# Get parameters from the JSON schema (not from function introspection)
# MCPTool uses _invoke_tool(**kwargs) so introspection would only find 'kwargs'
properties = self.parameters.get("properties", {})

# If we have placeholder parameters (eager_connect=False), return empty set to skip validation
# Validation will happen during warm_up when real schema is fetched
if not properties:
return set()

return set(properties.keys())

def warm_up(self) -> None:
"""Connect and fetch the tool schema if eager_connect is turned off."""
with self._lock:
Expand All @@ -1111,6 +1135,19 @@ def warm_up(self) -> None:
tool = self._connect_and_initialize(self.name)
self.parameters = tool.inputSchema

# Validate inputs_from_state now that we have the real schema
# Note: Duplicates Tool.__post_init__() logic, but needed here for early error detection
# when eager_connect=False (validation was skipped during __init__ via empty _get_valid_inputs())
if self._inputs_from_state:
valid_inputs = set(self.parameters.get("properties", {}).keys())
for state_key, param_name in self._inputs_from_state.items():
if param_name not in valid_inputs:
msg = (
f"inputs_from_state maps '{state_key}' to unknown parameter '{param_name}'. "
f"Valid parameters are: {valid_inputs}."
)
raise ValueError(msg)

# Remove inputs_from_state keys from parameters schema if present
# This matches the behavior of ComponentTool
if self._inputs_from_state and "properties" in self.parameters:
Expand Down
7 changes: 4 additions & 3 deletions integrations/mcp/tests/test_mcp_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,13 @@ def test_mcp_tool_serde_with_state_mapping(self, mcp_tool_cleanup):
server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server)

# Create tool with state-mapping parameters
# The 'add' tool has parameters 'a' and 'b', so we map to 'a'
tool = MCPTool(
name="add",
server_info=server_info,
eager_connect=False,
outputs_to_string={"source": "result"},
inputs_from_state={"filter": "query_filter"},
inputs_from_state={"state_a": "a"},
outputs_to_state={"result": {"source": "output"}},
)
mcp_tool_cleanup(tool)
Expand All @@ -184,7 +185,7 @@ def test_mcp_tool_serde_with_state_mapping(self, mcp_tool_cleanup):

# Verify state-mapping parameters are serialized
assert tool_dict["data"]["outputs_to_string"] == {"source": "result"}
assert tool_dict["data"]["inputs_from_state"] == {"filter": "query_filter"}
assert tool_dict["data"]["inputs_from_state"] == {"state_a": "a"}
assert tool_dict["data"]["outputs_to_state"] == {"result": {"source": "output"}}

# Test deserialization (from_dict)
Expand All @@ -193,7 +194,7 @@ def test_mcp_tool_serde_with_state_mapping(self, mcp_tool_cleanup):

# Verify state-mapping parameters are restored
assert new_tool._outputs_to_string == {"source": "result"}
assert new_tool._inputs_from_state == {"filter": "query_filter"}
assert new_tool._inputs_from_state == {"state_a": "a"}
assert new_tool._outputs_to_state == {"result": {"source": "output"}}

@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set")
Expand Down