Skip to content

Commit 7033692

Browse files
committed
langchain now supports raw JSON schema for tools
langchain-ai/langchain#29812
1 parent fcceaeb commit 7033692

File tree

1 file changed

+1
-58
lines changed

1 file changed

+1
-58
lines changed

src/langchain_mcp/toolkit.py

Lines changed: 1 addition & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
from langchain_core.tools.base import BaseTool, BaseToolkit, ToolException
1212
from mcp import ClientSession, ListToolsResult
1313
from mcp.types import EmbeddedResource, ImageContent, TextContent
14-
from pydantic.json_schema import JsonSchemaValue
15-
from pydantic_core import core_schema as cs
1614

1715

1816
class MCPToolkit(BaseToolkit):
@@ -43,62 +41,13 @@ def get_tools(self) -> list[BaseTool]:
4341
session=self.session,
4442
name=tool.name,
4543
description=tool.description or "",
46-
args_schema=create_schema_model(tool.inputSchema),
44+
args_schema=tool.inputSchema,
4745
)
4846
# list_tools returns a PaginatedResult, but I don't see a way to pass the cursor to retrieve more tools
4947
for tool in self._tools.tools
5048
]
5149

5250

53-
TYPEMAP = {
54-
"integer": int,
55-
"number": float,
56-
"array": list,
57-
"boolean": bool,
58-
"string": str,
59-
"object": object,
60-
"null": type(None),
61-
}
62-
63-
FIELD_DEFAULTS = {
64-
int: 0,
65-
float: 0.0,
66-
list: [],
67-
bool: False,
68-
str: "",
69-
type(None): None,
70-
}
71-
72-
73-
def configure_field(name: str, type_: dict[str, t.Any], required: list[str]) -> tuple[type, t.Any]:
74-
field_type = TYPEMAP[type_["type"]]
75-
default_ = FIELD_DEFAULTS.get(field_type) if name not in required else ...
76-
return field_type, default_
77-
78-
79-
def create_schema_model(schema: dict[str, t.Any]) -> type[pydantic.BaseModel]:
80-
# Create a new model class that returns our JSON schema.
81-
# LangChain requires a BaseModel class.
82-
class SchemaBase(pydantic.BaseModel):
83-
model_config = pydantic.ConfigDict(extra="allow")
84-
85-
@t.override
86-
@classmethod
87-
def __get_pydantic_json_schema__(
88-
cls, core_schema: cs.CoreSchema, handler: pydantic.GetJsonSchemaHandler
89-
) -> JsonSchemaValue:
90-
return schema
91-
92-
# Since this langchain patch, we need to synthesize pydantic fields from the schema
93-
# https://github.com/langchain-ai/langchain/commit/033ac417609297369eb0525794d8b48a425b8b33
94-
required = schema.get("required", [])
95-
fields: dict[str, t.Any] = {
96-
name: configure_field(name, type_, required) for name, type_ in schema["properties"].items()
97-
}
98-
99-
return pydantic.create_model("Schema", __base__=SchemaBase, **fields)
100-
101-
10251
class MCPTool(BaseTool):
10352
"""
10453
MCP server tool
@@ -124,9 +73,3 @@ async def _arun(self, *args: t.Any, **kwargs: t.Any) -> tuple[str, list[ImageCon
12473
text_content = [block for block in result.content if isinstance(block, TextContent)]
12574
artifacts = [block for block in result.content if not isinstance(block, TextContent)]
12675
return pydantic_core.to_json(text_content).decode(), artifacts
127-
128-
@property
129-
@t.override
130-
def tool_call_schema(self) -> type[pydantic.BaseModel]:
131-
assert self.args_schema is not None # noqa: S101
132-
return self.args_schema

0 commit comments

Comments
 (0)