Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

build function schemas with pydantic create_model #93

Merged
merged 3 commits into from
Sep 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 86 additions & 76 deletions chatlab/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,22 @@ class WhatTime(BaseModel):
import inspect
import json
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Type, Union, get_args, get_origin, overload

from pydantic import BaseModel
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Type,
Union,
get_args,
get_origin,
overload,
)

from pydantic import BaseModel, create_model

from .decorators import ChatlabMetadata

Expand Down Expand Up @@ -87,48 +100,10 @@ def is_union_type(t):
return get_origin(t) is Union


def process_type(annotation, is_required=True):
"""Determine the JSON schema type of a type annotation."""
origin = get_origin(annotation)
args = get_args(annotation)

if is_optional_type(annotation):
return process_type(args[0], is_required=False)

elif origin is Union:
types = [process_type(t, is_required)[0]["type"] for t in args if t is not type(None)] # noqa: E721
return {"type": types}, is_required

elif origin is list:
item_type = process_type(args[0], is_required)[0]["type"]
return {"type": "array", "items": {"type": item_type}}, is_required

elif origin is Literal:
values = get_args(annotation)
return {"type": "string", "enum": values}, is_required
class FunctionSchemaConfig:
"""Config used for model generation during function schema creation."""

elif issubclass(annotation, Enum):
values = [e.name for e in annotation]
return {"type": "string", "enum": values}, is_required

elif origin is dict:
return {"type": "object"}, is_required

elif annotation in ALLOWED_TYPES:
return {
"type": JSON_SCHEMA_TYPES[annotation],
}, is_required

else:
raise Exception(f"Type annotation must be a JSON serializable type ({ALLOWED_TYPES})")


def process_parameter(name, param):
"""Process a function parameter for use in a JSON schema."""
prop_schema, is_required = process_type(param.annotation, param.default == inspect.Parameter.empty)
if param.default != inspect.Parameter.empty:
prop_schema["default"] = param.default
return prop_schema, is_required
arbitrary_types_allowed = True


def generate_function_schema(
Expand All @@ -146,38 +121,63 @@ def generate_function_schema(
if not doc:
raise Exception("Only functions with docstrings can be registered")

schema = None
schema = {
"name": func_name,
"description": doc,
"parameters": {},
}

if isinstance(parameter_schema, dict):
schema = parameter_schema
parameters = parameter_schema
elif parameter_schema is not None:
schema = parameter_schema.schema()
parameters = parameter_schema.schema()
else:
schema_properties = {}
required = []

# extract function parameters and their type annotations
sig = inspect.signature(function)

fields = {}
for name, param in sig.parameters.items():
prop_schema, is_required = process_parameter(name, param)
schema_properties[name] = prop_schema
if is_required:
required.append(name)

schema = {"type": "object", "properties": {}, "required": []}
if len(schema_properties) > 0:
schema = {
"type": "object",
"properties": schema_properties,
"required": required,
}

if schema is None:
raise Exception(f"Could not generate schema for function {func_name}")

return {
"name": func_name,
"description": doc,
"parameters": schema,
}
# skip 'self' for class methods
if name == "self":
continue

# determine type annotation
if param.annotation == inspect.Parameter.empty:
# no annotation, raise instead of falling back to Any
raise Exception(
f"`{name}` parameter of {func_name} must have a JSON-serializable type annotation"
)
type_annotation = param.annotation

# get the default value, otherwise set as required
default_value = ...
if param.default != inspect.Parameter.empty:
default_value = param.default

fields[name] = (type_annotation, default_value)

# create the pydantic model and return its JSON schema to pass into the 'parameters' part of the
# function schema used by OpenAI
model = create_model(
function.__name__,
__config__=FunctionSchemaConfig,
**fields,
)
parameters: dict = model.schema()

if "properties" not in parameters:
parameters["properties"] = {}

# remove "title" since it's unused by OpenAI
parameters.pop("title", None)
for field_name in parameters["properties"].keys():
parameters["properties"][field_name].pop("title", None)

if "required" not in parameters:
parameters["required"] = []

schema["parameters"] = parameters
return schema


# Declare the type for the python hallucination
Expand Down Expand Up @@ -232,7 +232,9 @@ def __init__(self, python_hallucination_function: Optional[PythonHallucinationFu

self.python_hallucination_function = python_hallucination_function

def decorator(self, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None) -> Callable:
def decorator(
self, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None
) -> Callable:
"""Create a decorator for registering functions with a schema."""

def decorator(function):
Expand All @@ -243,16 +245,22 @@ def decorator(function):

@overload
def register(
self, function: None = None, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None
self,
function: None = None,
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> Callable:
...

@overload
def register(self, function: Callable, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None) -> Dict:
def register(
self, function: Callable, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None
) -> Dict:
...

def register(
self, function: Optional[Callable] = None, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None
self,
function: Optional[Callable] = None,
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> Union[Callable, Dict]:
"""Register a function for use in `Chat`s. Can be used as a decorator or directly to register a function.

Expand Down Expand Up @@ -407,7 +415,9 @@ async def call(self, name: str, arguments: Optional[str] = None) -> Any:
parameters = json.loads(arguments)
# TODO: Validate parameters against schema
except json.JSONDecodeError:
raise FunctionArgumentError(f"Invalid Function call on {name}. Arguments must be a valid JSON object")
raise FunctionArgumentError(
f"Invalid Function call on {name}. Arguments must be a valid JSON object"
)

if function is None:
raise UnknownFunctionError(f"Function {name} is not registered")
Expand Down
Loading