Skip to content

Model generation #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 130 additions & 38 deletions src/langchain_mcp/toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
# SPDX-License-Identifier: MIT

import asyncio
import sys
import warnings
from collections.abc import Callable
from enum import Enum
from typing import Any, TypeVar, Union

import pydantic
import pydantic_core
import typing_extensions as t
from langchain_core.tools.base import BaseTool, BaseToolkit, ToolException
from mcp import ClientSession, ListToolsResult
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema as cs
from pydantic import BaseModel, Field, create_model


class MCPToolkit(BaseToolkit):
Expand Down Expand Up @@ -42,59 +44,149 @@ def get_tools(self) -> list[BaseTool]:
session=self.session,
name=tool.name,
description=tool.description or "",
args_schema=create_schema_model(tool.inputSchema),
args_schema=create_model_from_schema(tool.inputSchema, tool.name),
)
# list_tools returns a PaginatedResult, but I don't see a way to pass the cursor to retrieve more tools
for tool in self._tools.tools
]


TYPEMAP = {
# Define type alias for clarity
JsonSchemaType = type[Any]

TYPEMAP: dict[str, JsonSchemaType] = {
"string": str,
"integer": int,
"number": float,
"array": list,
"boolean": bool,
"string": str,
"array": list,
"object": dict,
"null": type(None),
}

FIELD_DEFAULTS = {
int: 0,
float: 0.0,
list: [],
bool: False,
str: "",
type(None): None,
}

def resolve_ref(root_schema: dict[str, Any], ref: str) -> dict[str, Any]:
    """Resolve a local JSON-pointer ``$ref`` (e.g. ``#/$defs/Foo``) within *root_schema*.

    Args:
        root_schema: The complete schema containing all definitions.
        ref: A local reference string; must start with ``#/``.

    Returns:
        The schema fragment the pointer designates.

    Raises:
        ValueError: If the reference is not local or a path segment is absent.
    """
    if not ref.startswith("#/"):
        raise ValueError(f"Only local references supported: {ref}")

    # Drop exactly the leading "#/" prefix. The previous str.lstrip("#/")
    # stripped a *character set*, so any path segment beginning with '#' or
    # '/' was silently corrupted.
    # NOTE(review): JSON-pointer escape sequences (~0 / ~1) are not decoded here.
    path = ref[2:].split("/")
    current: Any = root_schema

    for part in path:
        if part not in current:
            raise ValueError(f"Could not find {part} in schema. Available keys: {list(current.keys())}")
        current = current[part]

    return current


def get_field_type(root_schema: dict[str, Any], type_def: dict[str, Any]) -> Any:
    """Convert a JSON-schema type definition into a Python/Pydantic annotation.

    Args:
        root_schema: The complete schema, used to resolve ``$ref`` pointers.
        type_def: The schema fragment describing a single field.

    Returns:
        A type, ``Enum`` class, ``Union``, or forward-reference string usable
        as a Pydantic field annotation; ``Any`` when the fragment is opaque.
    """
    # additionalProperties may legally be a bare bool; treat any non-dict
    # fragment as unconstrained.
    if not isinstance(type_def, dict):
        return Any

    if "$ref" in type_def:
        referenced_schema = resolve_ref(root_schema, type_def["$ref"])
        # Return the title as a forward reference: the referenced model may
        # not have been created yet (see create_model_from_schema).
        return referenced_schema.get("title", "UntitledModel")

    if "enum" in type_def:
        # Synthesize an Enum class for the closed value set.
        # NOTE(review): hash() of a str varies across processes, so this name
        # is only stable within a single run.
        enum_name = f"Enum_{hash(str(type_def['enum']))}"
        enum_values = {str(v): v for v in type_def["enum"]}
        return Enum(enum_name, enum_values)

    if "anyOf" in type_def:
        # Resolve every alternative exactly once. The previous implementation
        # resolved the list a second time just to test for None, doubling the
        # work per level (exponential on nested anyOf) and creating duplicate
        # Enum classes as a side effect.
        resolved = [get_field_type(root_schema, alt) for alt in type_def["anyOf"]]
        allows_none = type(None) in resolved
        types = [t_ for t_ in resolved if t_ is not type(None)]  # noqa: E721
        if allows_none:
            # None among the alternatives makes the field optional.
            if len(types) == 1:
                return types[0] | type(None)
            return Union[tuple(types + [type(None)])]  # noqa: UP007
        if len(types) == 1:
            return types[0]
        return Union[tuple(types)]  # noqa: UP007

    if "type" not in type_def:
        return Any

    type_name = type_def["type"]
    if type_name == "array":
        if "items" in type_def:
            item_type = get_field_type(root_schema, type_def["items"])
            return list[item_type]  # type: ignore
        return list[Any]

    if type_name == "object":
        if "additionalProperties" in type_def:
            additional_props = type_def["additionalProperties"]
            # A boolean only says whether extra keys are allowed; it carries
            # no value-type information.
            if isinstance(additional_props, bool):
                return dict[str, Any]
            value_type = get_field_type(root_schema, additional_props)
            return dict[str, value_type]  # type: ignore
        return dict[str, Any]

    return TYPEMAP.get(type_name, Any)


ModelType = TypeVar("ModelType", bound=BaseModel)


def create_model_from_schema(
    schema: dict[str, Any], name: str, root_schema: dict[str, Any] | None = None, created_models: set[str] | None = None
) -> type[ModelType]:
    """Create a Pydantic model from a JSON schema definition.

    Args:
        schema: The schema for this specific model.
        name: Name for the model.
        root_schema: The complete schema containing all definitions; defaults
            to *schema* itself on the top-level call.
        created_models: Set tracking model names already built, used to break
            circular ``$ref`` chains across recursive calls.

    Returns:
        The newly created (or previously cached) Pydantic model class.
    """
    if created_models is None:
        created_models = set()

    # On the top-level call the schema being processed is also the root.
    if root_schema is None:
        root_schema = schema

    # Already built earlier in this traversal: fetch it from this module's
    # namespace, where finished models are parked (see setattr below).
    if name in created_models:
        return getattr(sys.modules[__name__], name)

    # Register before descending so circular references terminate.
    created_models.add(name)

    # Build referenced models first so forward-reference strings produced by
    # get_field_type can resolve.
    if "$defs" in root_schema:
        for model_name, model_schema in root_schema["$defs"].items():
            if model_schema.get("type") == "object" and model_name not in created_models:
                create_model_from_schema(model_schema, model_name, root_schema, created_models)

    properties = schema.get("properties", {})
    required = schema.get("required", [])

    fields: dict[str, tuple[Any, Any]] = {}
    for field_name, field_schema in properties.items():
        field_type = get_field_type(root_schema, field_schema)
        default = field_schema.get("default", ...)
        if field_name not in required and default is ...:
            # Optional field with no schema default: allow None. Union[...]
            # rather than `| None` because field_type may be a forward-
            # reference *string*, which does not support the | operator.
            field_type = Union[field_type, type(None)]  # noqa: UP007
            default = None

        description = field_schema.get("description", "")
        fields[field_name] = (field_type, Field(default=default, description=description))

    model = create_model(name, **fields)  # type: ignore
    # Park the model in this module's namespace so forward references (and
    # the created_models fast path above) can find it by name.
    setattr(sys.modules[__name__], name, model)
    return model


class MCPTool(BaseTool):
Expand All @@ -106,15 +198,15 @@ class MCPTool(BaseTool):
handle_tool_error: bool | str | Callable[[ToolException], str] | None = True

@t.override
def _run(self, *args: Any, **kwargs: Any) -> Any:
    """Synchronous entry point required by BaseTool.

    The MCP session is async-only, so this spins up an event loop and
    delegates to ``_arun``; callers should prefer ``ainvoke`` directly.
    """
    warnings.warn(
        "Invoke this tool asynchronously using `ainvoke`. This method exists only to satisfy standard tests.",
        stacklevel=1,
    )
    return asyncio.run(self._arun(*args, **kwargs))

@t.override
async def _arun(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
result = await self.session.call_tool(self.name, arguments=kwargs)
content = pydantic_core.to_json(result.content).decode()
if result.isError:
Expand Down