
Commit

expose augmented basemodel
rgbkrk committed Nov 6, 2023
1 parent dd7f0c9 commit d82c58a
Showing 1 changed file with 15 additions and 93 deletions.
108 changes: 15 additions & 93 deletions chatlab/extractor.py
@@ -27,91 +27,19 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import json
from functools import wraps
from typing import Any, Callable

from docstring_parser import parse
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, create_model, validate_arguments


class openai_function:
"""Decorator to convert a function into an OpenAI function.
This decorator will convert a function into an OpenAI function. The
function will be validated using pydantic and the schema will be
generated from the function signature.
Example:
```python
@openai_function
def sum(a: int, b: int) -> int:
return a + b
completion = openai.ChatCompletion.create(
...
messages=[{
"content": "What is 1 + 1?",
"role": "user"
}]
)
sum.from_response(completion)
# 2
```
"""

def __init__(self, func: Callable) -> None:
self.func = func
self.validate_func = validate_arguments(func)
self.docstring = parse(self.func.__doc__ or "")

parameters = self.validate_func.model.model_json_schema()
parameters["properties"] = {
k: v for k, v in parameters["properties"].items() if k not in ("v__duplicate_kwargs", "args", "kwargs")
}
for param in self.docstring.params:
if (name := param.arg_name) in parameters["properties"] and (description := param.description):
parameters["properties"][name]["description"] = description
parameters["required"] = sorted(k for k, v in parameters["properties"].items() if "default" not in v)
self.openai_schema = {
"name": self.func.__name__,
"description": self.docstring.short_description,
"parameters": parameters,
}
self.model = self.validate_func.model

def __call__(self, *args: Any, **kwargs: Any) -> Any:
@wraps(self.func)
def wrapper(*args, **kwargs):
return self.validate_func(*args, **kwargs)

return wrapper(*args, **kwargs)

def from_message(self, message: ChatCompletionMessageParam, strict: bool = False):
"""
Parse the response from OpenAI's API and return the function call
Parameters:
message (ChatCompletionMessageParam): The message from OpenAI's API
strict (bool): Whether to use strict json parsing
Returns:
result (any): result of the function call
"""

assert "function_call" in message, "No function call detected"
assert message["function_call"]["name"] == self.openai_schema["name"], "Function name does not match"

function_call = message["function_call"]
arguments = json.loads(function_call["arguments"], strict=strict)
return self.validate_func(**arguments)
from pydantic import BaseModel, create_model


class OpenAISchema(BaseModel):
"""Augments a Pydantic model with OpenAI's schema for function calling.
This class augments a Pydantic model with OpenAI's schema for function calling. The schema is generated from the model's signature and docstring. The schema can be used to validate the response from OpenAI's API and extract the function call.
## Usage
@@ -153,7 +81,8 @@ def openai_schema(cls):
"""Return the schema in the format of OpenAI's schema as jsonschema.
Note:
It's important to add a docstring describing how best to use this class; it will be included in the description attribute and become part of the prompt.
Returns:
model_json_schema (dict): A dictionary in the format of OpenAI's schema as jsonschema
@@ -183,29 +112,22 @@ def openai_schema(cls):
}

@classmethod
def from_response(
def from_message(
cls,
completion,
throw_error: bool = True,
message: ChatCompletionMessageParam,
validation_context=None,
strict: bool = None,
strict: bool = False,
):
"""Execute the function from the response of an openai chat completion
"""Execute the function from the response of an openai chat completion.
Parameters:
completion (openai.ChatCompletion): The response from an openai chat completion
throw_error (bool): Whether to throw an error if the function call is not detected
validation_context (dict): The validation context to use for validating the response
strict (bool): Whether to use strict json parsing
Args:
cls (OpenAISchema): An instance of the class
Returns:
cls (OpenAISchema): An instance of the class
None
"""
message = completion["choices"][0]["message"]

if throw_error:
assert "function_call" in message, "No function call detected"
assert message["function_call"]["name"] == cls.openai_schema["name"], "Function name does not match"
assert "function_call" in message, "No function call detected"
assert message["function_call"]["name"] == cls.openai_schema()["name"], "Function name does not match"

return cls.model_validate_json(
message["function_call"]["arguments"],
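For orientation only (this is not part of the commit): a minimal sketch of how the renamed `from_message` classmethod might be used after this change. The `WeatherReport` model, the import path, and the assumption that the generated schema name matches the class name are illustrative, not taken from the diff.

```python
from pydantic import Field

# Assumed import path; the diff only shows chatlab/extractor.py.
from chatlab.extractor import OpenAISchema


class WeatherReport(OpenAISchema):
    """Report the current weather for a city."""

    city: str = Field(description="City to report on")
    temperature_c: float = Field(description="Temperature in Celsius")


# Per the new assertion in the diff, openai_schema is now called as a
# classmethod; it returns the function-calling schema dict (name,
# description, parameters) to send alongside the chat messages.
schema = WeatherReport.openai_schema()

# A message shaped like the legacy `function_call` responses this diff asserts on.
message = {
    "role": "assistant",
    "content": None,
    "function_call": {
        "name": "WeatherReport",  # assumed to match the generated schema name
        "arguments": '{"city": "Oslo", "temperature_c": 4.5}',
    },
}

# Unlike the removed from_response, from_message takes the message itself,
# not the whole completion object.
report = WeatherReport.from_message(message)
print(report.city, report.temperature_c)
```

The visible difference from the removed `from_response` is that the caller now passes the chat message directly instead of indexing into `completion["choices"][0]["message"]`, and the function-call check is always applied rather than gated on `throw_error`.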
