src/smolagents/_function_type_hints_utils.py (104 additions, 0 deletions)
@@ -43,6 +43,103 @@
}


def _is_pydantic_model(type_hint: type) -> bool:
    """
    Check if a type hint represents a Pydantic BaseModel.

    Args:
        type_hint: The type to check

    Returns:
        bool: True if the type is a Pydantic BaseModel, False otherwise
    """
    try:
        # Check if pydantic is available
        import pydantic

        # Check if the type is a class and inherits from BaseModel
        return inspect.isclass(type_hint) and issubclass(type_hint, pydantic.BaseModel)
    except ImportError:
        # pydantic not available
        return False
    except TypeError:
        # Not a class or other type error
        return False
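
A quick illustration of the check above (not part of the diff): assuming pydantic is installed, only classes that inherit from BaseModel pass; plain types, instances, and non-classes do not. The Location model below is hypothetical and used only for this example.

from pydantic import BaseModel

from smolagents._function_type_hints_utils import _is_pydantic_model


class Location(BaseModel):  # hypothetical example model
    city: str
    country: str


assert _is_pydantic_model(Location)                                   # BaseModel subclass
assert not _is_pydantic_model(str)                                    # plain built-in type
assert not _is_pydantic_model(Location(city="Paris", country="FR"))   # instance, not a class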


def _get_pydantic_json_schema(model_class: type) -> dict:
    """
    Get JSON schema from a Pydantic BaseModel.

    Args:
        model_class: The Pydantic model class

    Returns:
        dict: JSON schema for the model
    """
    try:
        # Get the schema using Pydantic's built-in method
        schema = model_class.model_json_schema()
        return schema
    except Exception as e:
        raise TypeHintParsingException(f"Failed to get Pydantic schema for {model_class}: {e}")
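
For reference (also not part of the diff), Pydantic v2's model_json_schema() already returns a plain dict in JSON Schema form, so this helper is mostly a thin wrapper that converts failures into TypeHintParsingException. A rough sketch of the output for the hypothetical Location model from the previous example:

schema = _get_pydantic_json_schema(Location)
# Approximately:
# {
#     "title": "Location",
#     "type": "object",
#     "properties": {
#         "city": {"title": "City", "type": "string"},
#         "country": {"title": "Country", "type": "string"},
#     },
#     "required": ["city", "country"],
# }
assert schema["type"] == "object"
assert set(schema["required"]) == {"city", "country"}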


def _process_pydantic_schema(schema: dict) -> dict:
    """
    Process a Pydantic JSON schema to make it compatible with smolagents.

    This function handles:
    - Resolving $refs to inline definitions from $defs (including nested model and enum definitions)
    - Removing the $defs section once everything has been inlined

    Required fields and enum values are already present in the schema Pydantic
    generates, so resolving the references is enough to produce a self-contained schema.

    Args:
        schema: Raw Pydantic JSON schema

    Returns:
        dict: Processed schema compatible with smolagents
    """
    # Make a copy to avoid modifying the original
    processed_schema = copy(schema)

    # Get definitions if they exist
    definitions = schema.get("$defs", {})

    def resolve_refs(obj):
        """Recursively resolve $ref references in the schema."""
        if isinstance(obj, dict):
            if "$ref" in obj:
                # Extract the reference path
                ref_path = obj["$ref"]
                if ref_path.startswith("#/$defs/"):
                    def_name = ref_path.split("/")[-1]
                    if def_name in definitions:
                        # Replace the $ref with the actual definition
                        resolved = resolve_refs(definitions[def_name])
                        return resolved
                # If we cannot resolve the ref, return the object as is
                return obj
            else:
                # Recursively process all values in the dictionary
                return {k: resolve_refs(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            # Recursively process all items in the list
            return [resolve_refs(item) for item in obj]
        else:
            # Return primitive values as-is
            return obj

    # Resolve all $refs in the schema
    processed_schema = resolve_refs(processed_schema)

    # Remove $defs since we have inlined everything
    if "$defs" in processed_schema:
        del processed_schema["$defs"]

    return processed_schema
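
To make the $ref handling concrete (illustrative only, assuming pydantic v2 and the hypothetical models below): a nested model appears in the raw schema as a "$ref" pointing into "$defs", and the processed schema carries the nested object definition inline instead.

class Address(BaseModel):  # hypothetical nested model
    street: str


class Person(BaseModel):  # hypothetical outer model
    name: str
    address: Address


raw = Person.model_json_schema()
assert "$defs" in raw and "$ref" in raw["properties"]["address"]

flat = _process_pydantic_schema(raw)
assert "$defs" not in flat
assert flat["properties"]["address"]["type"] == "object"        # definition inlined in place of the $ref
assert flat["properties"]["address"]["required"] == ["street"]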


def get_package_name(import_name: str) -> str:
"""
Return the package name for a given import name.
@@ -328,6 +425,13 @@ def _parse_type_hint(hint: type) -> dict:
    args = get_args(hint)

    if origin is None:
        # Check if this is a Pydantic model before falling back to regular type parsing
        if _is_pydantic_model(hint):
            # Get the Pydantic schema and process it
            pydantic_schema = _get_pydantic_json_schema(hint)
            processed_schema = _process_pydantic_schema(pydantic_schema)
            return processed_schema

        try:
            return _get_json_schema_type(hint)
        except KeyError:
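
Taken together, the new branch means a parameter annotated with a BaseModel subclass resolves to an inlined object schema instead of falling through to the plain-type lookup, so the tool-schema helpers built on top of _parse_type_hint (e.g. get_json_schema) should pick up nested Pydantic parameters. A minimal end-to-end sketch, reusing the hypothetical Person model from the previous example:

from smolagents._function_type_hints_utils import _parse_type_hint

schema = _parse_type_hint(Person)
assert schema["type"] == "object"
assert "address" in schema["properties"]     # nested model inlined, no "$defs" left behind
assert "$defs" not in schema

# Non-Pydantic hints keep the existing behaviour.
assert _parse_type_hint(str)["type"] == "string"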