1111from langchain_core .tools .base import BaseTool , BaseToolkit , ToolException
1212from mcp import ClientSession , ListToolsResult
1313from mcp .types import EmbeddedResource , ImageContent , TextContent
14- from pydantic .json_schema import JsonSchemaValue
15- from pydantic_core import core_schema as cs
1614
1715
1816class MCPToolkit (BaseToolkit ):
@@ -43,62 +41,13 @@ def get_tools(self) -> list[BaseTool]:
4341 session = self .session ,
4442 name = tool .name ,
4543 description = tool .description or "" ,
46- args_schema = create_schema_model ( tool .inputSchema ) ,
44+ args_schema = tool .inputSchema ,
4745 )
4846 # list_tools returns a PaginatedResult, but I don't see a way to pass the cursor to retrieve more tools
4947 for tool in self ._tools .tools
5048 ]
5149
5250
53- TYPEMAP = {
54- "integer" : int ,
55- "number" : float ,
56- "array" : list ,
57- "boolean" : bool ,
58- "string" : str ,
59- "object" : object ,
60- "null" : type (None ),
61- }
62-
63- FIELD_DEFAULTS = {
64- int : 0 ,
65- float : 0.0 ,
66- list : [],
67- bool : False ,
68- str : "" ,
69- type (None ): None ,
70- }
71-
72-
73- def configure_field (name : str , type_ : dict [str , t .Any ], required : list [str ]) -> tuple [type , t .Any ]:
74- field_type = TYPEMAP [type_ ["type" ]]
75- default_ = FIELD_DEFAULTS .get (field_type ) if name not in required else ...
76- return field_type , default_
77-
78-
79- def create_schema_model (schema : dict [str , t .Any ]) -> type [pydantic .BaseModel ]:
80- # Create a new model class that returns our JSON schema.
81- # LangChain requires a BaseModel class.
82- class SchemaBase (pydantic .BaseModel ):
83- model_config = pydantic .ConfigDict (extra = "allow" )
84-
85- @t .override
86- @classmethod
87- def __get_pydantic_json_schema__ (
88- cls , core_schema : cs .CoreSchema , handler : pydantic .GetJsonSchemaHandler
89- ) -> JsonSchemaValue :
90- return schema
91-
92- # Since this langchain patch, we need to synthesize pydantic fields from the schema
93- # https://github.com/langchain-ai/langchain/commit/033ac417609297369eb0525794d8b48a425b8b33
94- required = schema .get ("required" , [])
95- fields : dict [str , t .Any ] = {
96- name : configure_field (name , type_ , required ) for name , type_ in schema ["properties" ].items ()
97- }
98-
99- return pydantic .create_model ("Schema" , __base__ = SchemaBase , ** fields )
100-
101-
10251class MCPTool (BaseTool ):
10352 """
10453 MCP server tool
@@ -124,9 +73,3 @@ async def _arun(self, *args: t.Any, **kwargs: t.Any) -> tuple[str, list[ImageCon
12473 text_content = [block for block in result .content if isinstance (block , TextContent )]
12574 artifacts = [block for block in result .content if not isinstance (block , TextContent )]
12675 return pydantic_core .to_json (text_content ).decode (), artifacts
127-
128- @property
129- @t .override
130- def tool_call_schema (self ) -> type [pydantic .BaseModel ]:
131- assert self .args_schema is not None # noqa: S101
132- return self .args_schema
0 commit comments