From 968e9dabe5e7dbff8a43d75482bbc81fd8068a34 Mon Sep 17 00:00:00 2001 From: Dave Shoup Date: Sat, 30 Sep 2023 10:31:33 -0400 Subject: [PATCH 1/3] support additional JSON-serializable function arguments by building pydantic model from function inspection --- chatlab/registry.py | 165 ++++++++++++++++++++++++-------------------- 1 file changed, 89 insertions(+), 76 deletions(-) diff --git a/chatlab/registry.py b/chatlab/registry.py index 8e479f0..c385080 100644 --- a/chatlab/registry.py +++ b/chatlab/registry.py @@ -43,9 +43,22 @@ class WhatTime(BaseModel): import inspect import json from enum import Enum -from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Type, Union, get_args, get_origin, overload - -from pydantic import BaseModel +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Literal, + Optional, + Type, + Union, + get_args, + get_origin, + overload, +) + +from pydantic import BaseModel, create_model from .decorators import ChatlabMetadata @@ -87,48 +100,10 @@ def is_union_type(t): return get_origin(t) is Union -def process_type(annotation, is_required=True): - """Determine the JSON schema type of a type annotation.""" - origin = get_origin(annotation) - args = get_args(annotation) - - if is_optional_type(annotation): - return process_type(args[0], is_required=False) - - elif origin is Union: - types = [process_type(t, is_required)[0]["type"] for t in args if t is not type(None)] # noqa: E721 - return {"type": types}, is_required - - elif origin is list: - item_type = process_type(args[0], is_required)[0]["type"] - return {"type": "array", "items": {"type": item_type}}, is_required - - elif origin is Literal: - values = get_args(annotation) - return {"type": "string", "enum": values}, is_required - - elif issubclass(annotation, Enum): - values = [e.name for e in annotation] - return {"type": "string", "enum": values}, is_required - - elif origin is dict: - return {"type": "object"}, is_required - - elif annotation in ALLOWED_TYPES: - return { - "type": JSON_SCHEMA_TYPES[annotation], - }, is_required - - else: - raise Exception(f"Type annotation must be a JSON serializable type ({ALLOWED_TYPES})") - +class FunctionSchemaConfig: + """Config used for model generation during function schema creation.""" -def process_parameter(name, param): - """Process a function parameter for use in a JSON schema.""" - prop_schema, is_required = process_type(param.annotation, param.default == inspect.Parameter.empty) - if param.default != inspect.Parameter.empty: - prop_schema["default"] = param.default - return prop_schema, is_required + arbitrary_types_allowed = True def generate_function_schema( @@ -146,38 +121,66 @@ def generate_function_schema( if not doc: raise Exception("Only functions with docstrings can be registered") - schema = None + schema = { + "name": func_name, + "description": doc, + "parameters": {}, + } + if isinstance(parameter_schema, dict): - schema = parameter_schema + parameters = parameter_schema elif parameter_schema is not None: - schema = parameter_schema.schema() + parameters = parameter_schema.schema() else: - schema_properties = {} - required = [] - + # extract function parameters and their type annotations sig = inspect.signature(function) + + fields = {} for name, param in sig.parameters.items(): - prop_schema, is_required = process_parameter(name, param) - schema_properties[name] = prop_schema - if is_required: - required.append(name) - - schema = {"type": "object", "properties": {}, "required": []} - if len(schema_properties) > 0: - schema = { - "type": "object", - "properties": schema_properties, - "required": required, - } - - if schema is None: - raise Exception(f"Could not generate schema for function {func_name}") - - return { - "name": func_name, - "description": doc, - "parameters": schema, - } + # skip 'self' for class methods + if name == "self": + continue + + # determine type annotation + if param.annotation == inspect.Parameter.empty: + # no annotation, raise instead of falling back to Any + raise Exception( + f"`{name}` parameter of {func_name} must have a JSON-serializable type annotation" + ) + type_annotation = param.annotation + + # get the default value, otherwise set as required + default_value = ... + if param.default != inspect.Parameter.empty: + default_value = param.default + + fields[name] = (type_annotation, default_value) + + # create the pydantic model and return its JSON schema to pass into the 'parameters' part of the + # function schema used by OpenAI + model = create_model( + function.__name__, + __config__=FunctionSchemaConfig, + **fields, + ) + parameters: dict = model.schema() + + if "properties" not in parameters: + parameters["properties"] = {} + + # remove "title" since it's unused by OpenAI + parameters.pop("title", None) + for field_name in parameters["properties"].keys(): + parameters["properties"][field_name].pop("title", None) + + if "required" not in parameters: + parameters["required"] = [] + + schema["parameters"] = parameters + # from rich import print as rprint + + # breakpoint() + return schema # Declare the type for the python hallucination @@ -232,7 +235,9 @@ def __init__(self, python_hallucination_function: Optional[PythonHallucinationFu self.python_hallucination_function = python_hallucination_function - def decorator(self, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None) -> Callable: + def decorator( + self, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None + ) -> Callable: """Create a decorator for registering functions with a schema.""" def decorator(function): @@ -243,16 +248,22 @@ def decorator(function): @overload def register( - self, function: None = None, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None + self, + function: None = None, + parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None, ) -> Callable: ... @overload - def register(self, function: Callable, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None) -> Dict: + def register( + self, function: Callable, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None + ) -> Dict: ... def register( - self, function: Optional[Callable] = None, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None + self, + function: Optional[Callable] = None, + parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None, ) -> Union[Callable, Dict]: """Register a function for use in `Chat`s. Can be used as a decorator or directly to register a function. @@ -407,7 +418,9 @@ async def call(self, name: str, arguments: Optional[str] = None) -> Any: parameters = json.loads(arguments) # TODO: Validate parameters against schema except json.JSONDecodeError: - raise FunctionArgumentError(f"Invalid Function call on {name}. Arguments must be a valid JSON object") + raise FunctionArgumentError( + f"Invalid Function call on {name}. Arguments must be a valid JSON object" + ) if function is None: raise UnknownFunctionError(f"Function {name} is not registered") From 9ff3ccc787b25512eb0b0c23882990647b139890 Mon Sep 17 00:00:00 2001 From: Dave Shoup Date: Sat, 30 Sep 2023 10:31:48 -0400 Subject: [PATCH 2/3] update & add tests --- tests/test_registry.py | 217 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 207 insertions(+), 10 deletions(-) diff --git a/tests/test_registry.py b/tests/test_registry.py index 74e6d8d..9c160f8 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -1,11 +1,17 @@ # flake8: noqa +import uuid from unittest import mock from unittest.mock import MagicMock, patch import pytest -from pydantic import BaseModel +from pydantic import BaseModel, Field -from chatlab.registry import FunctionArgumentError, FunctionRegistry, UnknownFunctionError, generate_function_schema +from chatlab.registry import ( + FunctionArgumentError, + FunctionRegistry, + UnknownFunctionError, + generate_function_schema, +) # Define a function to use in testing @@ -17,7 +23,51 @@ def simple_func(x: int, y: str, z: bool = False): class SimpleModel(BaseModel): x: int y: str - z: bool = False + z: bool = Field(default=False, description="A simple boolean field") + + +class SimpleClass: + def simple_method(self, x: int, y: str, z: bool = False): + """A simple test method""" + return f"{x}, {y}, {z}" + + +def simple_func_with_model_arg( + x: int, + y: str, + z: bool = False, + model: SimpleModel = None, +) -> str: + """A simple test function with a model argument""" + return f"{x}, {y}, {z}, {model}" + + +class NestedModel(BaseModel): + foo: int + bar: str + baz: bool = True + simple_model: SimpleModel + + +def simple_func_with_model_args( + x: int, + y: str, + z: bool = False, + model: SimpleModel = None, + nested_model: NestedModel = None, +) -> str: + """A simple test function with model arguments""" + return f"{x}, {y}, {z}, {model}, {nested_model}" + + +def simple_func_with_uuid_arg( + x: int, + y: str, + z: bool = False, + uuid: uuid.UUID = None, +) -> str: + """A simple test function with a uuid argument""" + return f"{x}, {y}, {z}, {uuid}" # Test the function generation schema @@ -39,16 +89,25 @@ def no_type_annotation(x): """Return back x""" return x - with pytest.raises(Exception, match="Type annotation must be a JSON serializable type"): + with pytest.raises( + Exception, + match=f"`x` parameter of no_type_annotation must have a JSON-serializable type annotation", + ): generate_function_schema(no_type_annotation) def test_generate_function_schema_unallowed_type(): - def unallowed_type(x: set): + class NewType: + pass + + def unallowed_type(x: NewType): '''Return back x''' return x - with pytest.raises(Exception, match="Type annotation must be a JSON serializable type"): + with pytest.raises( + ValueError, + match="Value not declarable with JSON Schema, field: name='x' type=NewType required=True", + ): generate_function_schema(unallowed_type) @@ -73,9 +132,146 @@ def test_generate_function_schema(): def test_generate_function_schema_with_model(): schema = generate_function_schema(simple_func, SimpleModel) expected_schema = { - "name": "simple_func", - "description": "A simple test function", - "parameters": SimpleModel.schema(), + 'name': 'simple_func', + 'description': 'A simple test function', + 'parameters': { + 'type': 'object', + 'properties': { + 'x': {'type': 'integer'}, + 'y': {'type': 'string'}, + 'z': { + 'default': False, + 'type': 'boolean', + "description": "A simple boolean field", + }, + }, + 'required': ['x', 'y'], + }, + } + assert schema == expected_schema + + +def test_generate_function_schema_with_method(): + schema = generate_function_schema(SimpleClass().simple_method) + expected_schema = { + "name": "simple_method", + "description": "A simple test method", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "integer"}, + "y": {"type": "string"}, + "z": {"type": "boolean", "default": False}, + }, + "required": ["x", "y"], + }, + } + assert schema == expected_schema + + +def test_generate_function_schema_with_model_argument(): + schema = generate_function_schema(simple_func_with_model_arg) + expected_schema = { + "name": "simple_func_with_model_arg", + "description": "A simple test function with a model argument", + "parameters": { + "type": "object", + "properties": { + 'x': {'type': 'integer'}, + 'y': {'type': 'string'}, + 'z': {'default': False, 'type': 'boolean'}, + 'model': {'$ref': '#/definitions/SimpleModel'}, + }, + "required": ["x", "y"], + "definitions": { + 'SimpleModel': { + 'title': 'SimpleModel', + 'type': 'object', + 'properties': { + 'x': {'title': 'X', 'type': 'integer'}, + 'y': {'title': 'Y', 'type': 'string'}, + 'z': { + 'title': 'Z', + 'description': 'A simple boolean field', + 'default': False, + 'type': 'boolean', + }, + }, + 'required': ['x', 'y'], + } + }, + }, + } + assert schema == expected_schema + + +def test_generate_function_schema_with_model_and_nested_model_arguments(): + schema = generate_function_schema(simple_func_with_model_args) + expected_schema = { + "name": "simple_func_with_model_args", + "description": "A simple test function with model arguments", + "parameters": { + "type": "object", + "properties": { + 'x': {'type': 'integer'}, + 'y': {'type': 'string'}, + 'z': {'default': False, 'type': 'boolean'}, + 'model': {'$ref': '#/definitions/SimpleModel'}, + 'nested_model': {'$ref': '#/definitions/NestedModel'}, + }, + "required": ["x", "y"], + "definitions": { + 'SimpleModel': { + 'title': 'SimpleModel', + 'type': 'object', + 'properties': { + 'x': {'title': 'X', 'type': 'integer'}, + 'y': {'title': 'Y', 'type': 'string'}, + 'z': { + 'title': 'Z', + 'description': 'A simple boolean field', + 'default': False, + 'type': 'boolean', + }, + }, + 'required': ['x', 'y'], + }, + 'NestedModel': { + 'title': 'NestedModel', + 'type': 'object', + 'properties': { + 'foo': {'title': 'Foo', 'type': 'integer'}, + 'bar': {'title': 'Bar', 'type': 'string'}, + 'baz': { + 'title': 'Baz', + 'default': True, + 'type': 'boolean', + }, + 'simple_model': {'$ref': '#/definitions/SimpleModel'}, + }, + 'required': ['foo', 'bar', 'simple_model'], + }, + }, + }, + } + assert schema == expected_schema + + +def test_generate_function_schema_with_uuid_argument(): + schema = generate_function_schema(simple_func_with_uuid_arg) + expected_schema = { + "name": "simple_func_with_uuid_arg", + "description": "A simple test function with a uuid argument", + "parameters": { + "type": "object", + "properties": { + 'x': {'type': 'integer'}, + 'y': {'type': 'string'}, + 'z': {'default': False, 'type': 'boolean'}, + 'uuid': {'type': 'string', 'format': 'uuid'}, + }, + "required": ["x", "y"], + }, } assert schema == expected_schema @@ -93,7 +289,8 @@ async def test_function_registry_function_argument_error(): registry = FunctionRegistry() registry.register(simple_func, SimpleModel) with pytest.raises( - FunctionArgumentError, match="Invalid Function call on simple_func. Arguments must be a valid JSON object" + FunctionArgumentError, + match="Invalid Function call on simple_func. Arguments must be a valid JSON object", ): await registry.call("simple_func", arguments="not json") From 174d9ab503db3c996040ab0198f767477927072a Mon Sep 17 00:00:00 2001 From: Dave Shoup Date: Sat, 30 Sep 2023 10:34:55 -0400 Subject: [PATCH 3/3] remove leftovers from debugging tests --- chatlab/registry.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/chatlab/registry.py b/chatlab/registry.py index c385080..910b6e3 100644 --- a/chatlab/registry.py +++ b/chatlab/registry.py @@ -177,9 +177,6 @@ def generate_function_schema( parameters["required"] = [] schema["parameters"] = parameters - # from rich import print as rprint - - # breakpoint() return schema