From 6c9b9100710a1b7e7c305535d47cdebc0a599642 Mon Sep 17 00:00:00 2001 From: Isaac Wasserman Date: Tue, 28 Jan 2025 18:52:48 -0500 Subject: [PATCH 1/4] more robust model generation --- src/langchain_mcp/toolkit.py | 173 +++++++++++++++++++++++++---------- 1 file changed, 126 insertions(+), 47 deletions(-) diff --git a/src/langchain_mcp/toolkit.py b/src/langchain_mcp/toolkit.py index dc0da0a..aea67f7 100644 --- a/src/langchain_mcp/toolkit.py +++ b/src/langchain_mcp/toolkit.py @@ -2,16 +2,18 @@ # SPDX-License-Identifier: MIT import asyncio +import sys import warnings from collections.abc import Callable +from enum import Enum +from typing import Any, Union import pydantic import pydantic_core import typing_extensions as t from langchain_core.tools.base import BaseTool, BaseToolkit, ToolException from mcp import ClientSession, ListToolsResult -from pydantic.json_schema import JsonSchemaValue -from pydantic_core import core_schema as cs +from pydantic import BaseModel, Field, create_model class MCPToolkit(BaseToolkit): @@ -42,59 +44,136 @@ def get_tools(self) -> list[BaseTool]: session=self.session, name=tool.name, description=tool.description or "", - args_schema=create_schema_model(tool.inputSchema), + args_schema=create_model_from_schema(tool.inputSchema, tool.name), ) # list_tools returns a PaginatedResult, but I don't see a way to pass the cursor to retrieve more tools for tool in self._tools.tools ] -TYPEMAP = { - "integer": int, - "number": float, - "array": list, - "boolean": bool, - "string": str, - "null": type(None), -} - -FIELD_DEFAULTS = { - int: 0, - float: 0.0, - list: [], - bool: False, - str: "", - type(None): None, -} - - -def configure_field(name: str, type_: dict[str, t.Any], required: list[str]) -> tuple[type, t.Any]: - field_type = TYPEMAP[type_["type"]] - default_ = FIELD_DEFAULTS.get(field_type) if name not in required else ... - return field_type, default_ - - -def create_schema_model(schema: dict[str, t.Any]) -> type[pydantic.BaseModel]: - # Create a new model class that returns our JSON schema. - # LangChain requires a BaseModel class. - class SchemaBase(pydantic.BaseModel): - model_config = pydantic.ConfigDict(extra="allow") - - @t.override - @classmethod - def __get_pydantic_json_schema__( - cls, core_schema: cs.CoreSchema, handler: pydantic.GetJsonSchemaHandler - ) -> JsonSchemaValue: - return schema - - # Since this langchain patch, we need to synthesize pydantic fields from the schema - # https://github.com/langchain-ai/langchain/commit/033ac417609297369eb0525794d8b48a425b8b33 +TYPEMAP = {"string": str, "integer": int, "number": float, "boolean": bool, "array": list, "object": dict, "null": None} + + +def resolve_ref(root_schema: dict[str, Any], ref: str) -> dict[str, Any]: + """Resolve a $ref pointer in the schema""" + if not ref.startswith("#/"): + raise ValueError(f"Only local references supported: {ref}") + + path = ref.lstrip("#/").split("/") + current = root_schema + + for part in path: + if part not in current: + raise ValueError(f"Could not find {part} in schema. Available keys: {list(current.keys())}") + current = current[part] + + return current + + +def get_field_type(root_schema: dict[str, Any], type_def: dict[str, Any]) -> Any: + """Convert JSON schema type definition to Python/Pydantic type""" + # Handle non-dict type definitions (like when additionalProperties is a boolean) + if not isinstance(type_def, dict): + return Any + + if "$ref" in type_def: + referenced_schema = resolve_ref(root_schema, type_def["$ref"]) + # Create a forward reference since the model might not exist yet + return f"'{referenced_schema.get('title', 'UntitledModel')}'" + + if "enum" in type_def: + # Create an Enum class for this field + enum_name = f"Enum_{hash(str(type_def['enum']))}" + enum_values = {str(v): v for v in type_def["enum"]} + return Enum(enum_name, enum_values) + + if "anyOf" in type_def: + types = [get_field_type(root_schema, t) for t in type_def["anyOf"]] + # Remove None from types list to handle it separately + types = [t for t in types if t is not None] + if None in [get_field_type(root_schema, t) for t in type_def["anyOf"]]: + # If None is one of the possible types, make the field optional + if len(types) == 1: + return types[0] | None + return Union[tuple(types)] | None # noqa: UP007 + if len(types) == 1: + return types[0] + return Union[tuple(types)] # noqa: UP007 + + if "type" not in type_def: + return Any + + type_name = type_def["type"] + if type_name == "array": + if "items" in type_def: + item_type = get_field_type(root_schema, type_def["items"]) + return list[item_type] + return list[Any] + + if type_name == "object": + if "additionalProperties" in type_def: + additional_props = type_def["additionalProperties"] + # Handle case where additionalProperties is a boolean + if isinstance(additional_props, bool): + return dict[str, Any] if additional_props else dict[str, Any] + # Handle case where additionalProperties is a schema + value_type = get_field_type(root_schema, additional_props) + return dict[str, value_type] + return dict[str, Any] + + return TYPEMAP.get(type_name, Any) + + +def create_model_from_schema( + schema: dict[str, Any], name: str, root_schema: dict[str, Any] | None = None, created_models: set[str] | None = None +) -> type[BaseModel]: + """Create a Pydantic model from a JSON schema definition + + Args: + schema: The schema for this specific model + name: Name for the model + root_schema: The complete schema containing all definitions + created_models: Set to track which models have already been created + """ + # Initialize tracking of created models + if created_models is None: + created_models = set() + + # If root_schema is not provided, use the current schema as root + if root_schema is None: + root_schema = schema + + # If we've already created this model, return its class from the module + if name in created_models: + return getattr(sys.modules[__name__], name) + + # Add this model to created_models before processing to handle circular references + created_models.add(name) + + # Create referenced models first if we have definitions + if "$defs" in root_schema: + for model_name, model_schema in root_schema["$defs"].items(): + if model_schema.get("type") == "object" and model_name not in created_models: + create_model_from_schema(model_schema, model_name, root_schema, created_models) + + properties = schema.get("properties", {}) required = schema.get("required", []) - fields: dict[str, t.Any] = { - name: configure_field(name, type_, required) for name, type_ in schema["properties"].items() - } - return pydantic.create_model("Schema", __base__=SchemaBase, **fields) + fields = {} + for field_name, field_schema in properties.items(): + field_type = get_field_type(root_schema, field_schema) + default = field_schema.get("default", ...) + if field_name not in required and default is ...: + field_type = field_type | None + default = None + + description = field_schema.get("description", "") + fields[field_name] = (field_type, Field(default=default, description=description)) + + model = create_model(name, **fields) + # Add model to the module's namespace so it can be referenced + setattr(sys.modules[__name__], name, model) + return model class MCPTool(BaseTool): From d0e929398247b6cbb372c736c689ad7d056c5901 Mon Sep 17 00:00:00 2001 From: Isaac Wasserman Date: Tue, 28 Jan 2025 19:04:10 -0500 Subject: [PATCH 2/4] fixed types --- src/langchain_mcp/toolkit.py | 55 ++++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 21 deletions(-) diff --git a/src/langchain_mcp/toolkit.py b/src/langchain_mcp/toolkit.py index aea67f7..5dec7cf 100644 --- a/src/langchain_mcp/toolkit.py +++ b/src/langchain_mcp/toolkit.py @@ -6,7 +6,7 @@ import warnings from collections.abc import Callable from enum import Enum -from typing import Any, Union +from typing import Any, Dict, List, Type, TypeVar, Union import pydantic import pydantic_core @@ -46,12 +46,22 @@ def get_tools(self) -> list[BaseTool]: description=tool.description or "", args_schema=create_model_from_schema(tool.inputSchema, tool.name), ) - # list_tools returns a PaginatedResult, but I don't see a way to pass the cursor to retrieve more tools for tool in self._tools.tools ] -TYPEMAP = {"string": str, "integer": int, "number": float, "boolean": bool, "array": list, "object": dict, "null": None} +# Define type alias for clarity +JsonSchemaType = Type[Any] + +TYPEMAP: dict[str, JsonSchemaType] = { + "string": str, + "integer": int, + "number": float, + "boolean": bool, + "array": List, + "object": Dict, + "null": type(None), +} def resolve_ref(root_schema: dict[str, Any], ref: str) -> dict[str, Any]: @@ -79,7 +89,7 @@ def get_field_type(root_schema: dict[str, Any], type_def: dict[str, Any]) -> Any if "$ref" in type_def: referenced_schema = resolve_ref(root_schema, type_def["$ref"]) # Create a forward reference since the model might not exist yet - return f"'{referenced_schema.get('title', 'UntitledModel')}'" + return referenced_schema.get("title", "UntitledModel") if "enum" in type_def: # Create an Enum class for this field @@ -90,15 +100,15 @@ def get_field_type(root_schema: dict[str, Any], type_def: dict[str, Any]) -> Any if "anyOf" in type_def: types = [get_field_type(root_schema, t) for t in type_def["anyOf"]] # Remove None from types list to handle it separately - types = [t for t in types if t is not None] - if None in [get_field_type(root_schema, t) for t in type_def["anyOf"]]: + types = [t for t in types if t is not type(None)] # noqa: E721 + if type(None) in [get_field_type(root_schema, t) for t in type_def["anyOf"]]: # If None is one of the possible types, make the field optional if len(types) == 1: - return types[0] | None - return Union[tuple(types)] | None # noqa: UP007 + return Union[types[0], type(None)] + return Union[tuple(types + [type(None)])] if len(types) == 1: return types[0] - return Union[tuple(types)] # noqa: UP007 + return Union[tuple(types)] if "type" not in type_def: return Any @@ -107,26 +117,29 @@ def get_field_type(root_schema: dict[str, Any], type_def: dict[str, Any]) -> Any if type_name == "array": if "items" in type_def: item_type = get_field_type(root_schema, type_def["items"]) - return list[item_type] - return list[Any] + return List[item_type] # type: ignore + return List[Any] if type_name == "object": if "additionalProperties" in type_def: additional_props = type_def["additionalProperties"] # Handle case where additionalProperties is a boolean if isinstance(additional_props, bool): - return dict[str, Any] if additional_props else dict[str, Any] + return Dict[str, Any] # Handle case where additionalProperties is a schema value_type = get_field_type(root_schema, additional_props) - return dict[str, value_type] - return dict[str, Any] + return Dict[str, value_type] # type: ignore + return Dict[str, Any] return TYPEMAP.get(type_name, Any) +ModelType = TypeVar("ModelType", bound=BaseModel) + + def create_model_from_schema( schema: dict[str, Any], name: str, root_schema: dict[str, Any] | None = None, created_models: set[str] | None = None -) -> type[BaseModel]: +) -> Type[ModelType]: """Create a Pydantic model from a JSON schema definition Args: @@ -159,18 +172,18 @@ def create_model_from_schema( properties = schema.get("properties", {}) required = schema.get("required", []) - fields = {} + fields: dict[str, tuple[Any, Any]] = {} for field_name, field_schema in properties.items(): field_type = get_field_type(root_schema, field_schema) default = field_schema.get("default", ...) if field_name not in required and default is ...: - field_type = field_type | None + field_type = Union[field_type, type(None)] default = None description = field_schema.get("description", "") fields[field_name] = (field_type, Field(default=default, description=description)) - model = create_model(name, **fields) + model = create_model(name, **fields) # type: ignore # Add model to the module's namespace so it can be referenced setattr(sys.modules[__name__], name, model) return model @@ -185,7 +198,7 @@ class MCPTool(BaseTool): handle_tool_error: bool | str | Callable[[ToolException], str] | None = True @t.override - def _run(self, *args: t.Any, **kwargs: t.Any) -> t.Any: + def _run(self, *args: Any, **kwargs: Any) -> Any: warnings.warn( "Invoke this tool asynchronousely using `ainvoke`. This method exists only to satisfy standard tests.", stacklevel=1, @@ -193,7 +206,7 @@ def _run(self, *args: t.Any, **kwargs: t.Any) -> t.Any: return asyncio.run(self._arun(*args, **kwargs)) @t.override - async def _arun(self, *args: t.Any, **kwargs: t.Any) -> t.Any: + async def _arun(self, *args: Any, **kwargs: Any) -> Any: result = await self.session.call_tool(self.name, arguments=kwargs) content = pydantic_core.to_json(result.content).decode() if result.isError: @@ -202,6 +215,6 @@ async def _arun(self, *args: t.Any, **kwargs: t.Any) -> t.Any: @t.override @property - def tool_call_schema(self) -> type[pydantic.BaseModel]: + def tool_call_schema(self) -> Type[pydantic.BaseModel]: assert self.args_schema is not None # noqa: S101 return self.args_schema From 86a9aeec5de328a4f3e15061080aa3562f10f9b1 Mon Sep 17 00:00:00 2001 From: Isaac Wasserman Date: Tue, 28 Jan 2025 19:05:37 -0500 Subject: [PATCH 3/4] fixed types --- src/langchain_mcp/toolkit.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/langchain_mcp/toolkit.py b/src/langchain_mcp/toolkit.py index 5dec7cf..019a91c 100644 --- a/src/langchain_mcp/toolkit.py +++ b/src/langchain_mcp/toolkit.py @@ -6,7 +6,7 @@ import warnings from collections.abc import Callable from enum import Enum -from typing import Any, Dict, List, Type, TypeVar, Union +from typing import Any, TypeVar, Union import pydantic import pydantic_core @@ -51,15 +51,15 @@ def get_tools(self) -> list[BaseTool]: # Define type alias for clarity -JsonSchemaType = Type[Any] +JsonSchemaType = type[Any] TYPEMAP: dict[str, JsonSchemaType] = { "string": str, "integer": int, "number": float, "boolean": bool, - "array": List, - "object": Dict, + "array": list, + "object": dict, "null": type(None), } @@ -117,19 +117,19 @@ def get_field_type(root_schema: dict[str, Any], type_def: dict[str, Any]) -> Any if type_name == "array": if "items" in type_def: item_type = get_field_type(root_schema, type_def["items"]) - return List[item_type] # type: ignore - return List[Any] + return list[item_type] # type: ignore + return list[Any] if type_name == "object": if "additionalProperties" in type_def: additional_props = type_def["additionalProperties"] # Handle case where additionalProperties is a boolean if isinstance(additional_props, bool): - return Dict[str, Any] + return dict[str, Any] # Handle case where additionalProperties is a schema value_type = get_field_type(root_schema, additional_props) - return Dict[str, value_type] # type: ignore - return Dict[str, Any] + return dict[str, value_type] # type: ignore + return dict[str, Any] return TYPEMAP.get(type_name, Any) @@ -139,7 +139,7 @@ def get_field_type(root_schema: dict[str, Any], type_def: dict[str, Any]) -> Any def create_model_from_schema( schema: dict[str, Any], name: str, root_schema: dict[str, Any] | None = None, created_models: set[str] | None = None -) -> Type[ModelType]: +) -> type[ModelType]: """Create a Pydantic model from a JSON schema definition Args: @@ -215,6 +215,6 @@ async def _arun(self, *args: Any, **kwargs: Any) -> Any: @t.override @property - def tool_call_schema(self) -> Type[pydantic.BaseModel]: + def tool_call_schema(self) -> type[pydantic.BaseModel]: assert self.args_schema is not None # noqa: S101 return self.args_schema From 00ac05b843fcff4e526b95b14ca9714e784bcc86 Mon Sep 17 00:00:00 2001 From: Isaac Wasserman Date: Tue, 28 Jan 2025 19:12:53 -0500 Subject: [PATCH 4/4] more type fixing --- src/langchain_mcp/toolkit.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/langchain_mcp/toolkit.py b/src/langchain_mcp/toolkit.py index 019a91c..0a7b311 100644 --- a/src/langchain_mcp/toolkit.py +++ b/src/langchain_mcp/toolkit.py @@ -104,11 +104,11 @@ def get_field_type(root_schema: dict[str, Any], type_def: dict[str, Any]) -> Any if type(None) in [get_field_type(root_schema, t) for t in type_def["anyOf"]]: # If None is one of the possible types, make the field optional if len(types) == 1: - return Union[types[0], type(None)] - return Union[tuple(types + [type(None)])] + return types[0] | type(None) + return Union[tuple(types + [type(None)])] # noqa: UP007 if len(types) == 1: return types[0] - return Union[tuple(types)] + return Union[tuple(types)] # noqa: UP007 if "type" not in type_def: return Any @@ -177,7 +177,7 @@ def create_model_from_schema( field_type = get_field_type(root_schema, field_schema) default = field_schema.get("default", ...) if field_name not in required and default is ...: - field_type = Union[field_type, type(None)] + field_type = field_type | type(None) default = None description = field_schema.get("description", "")