Skip to content

Commit

Permalink
take out noteable origami for prototyping sake
Browse files Browse the repository at this point in the history
  • Loading branch information
rgbkrk committed Nov 6, 2023
1 parent a1980cb commit dd7f0c9
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 342 deletions.
44 changes: 13 additions & 31 deletions chatlab/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from typing import Any, Callable

from docstring_parser import parse
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, create_model, validate_arguments


Expand Down Expand Up @@ -67,18 +68,12 @@ def __init__(self, func: Callable) -> None:

parameters = self.validate_func.model.model_json_schema()
parameters["properties"] = {
k: v
for k, v in parameters["properties"].items()
if k not in ("v__duplicate_kwargs", "args", "kwargs")
k: v for k, v in parameters["properties"].items() if k not in ("v__duplicate_kwargs", "args", "kwargs")
}
for param in self.docstring.params:
if (name := param.arg_name) in parameters["properties"] and (
description := param.description
):
if (name := param.arg_name) in parameters["properties"] and (description := param.description):
parameters["properties"][name]["description"] = description
parameters["required"] = sorted(
k for k, v in parameters["properties"].items() if "default" not in v
)
parameters["required"] = sorted(k for k, v in parameters["properties"].items() if "default" not in v)
self.openai_schema = {
"name": self.func.__name__,
"description": self.docstring.short_description,
Expand All @@ -93,7 +88,7 @@ def wrapper(*args, **kwargs):

return wrapper(*args, **kwargs)

def from_response(self, completion, throw_error=True, strict: bool = None):
def from_message(self, message: ChatCompletionMessageParam, strict: bool = False):
"""
Parse the response from OpenAI's API and return the function call
Expand All @@ -104,13 +99,9 @@ def from_response(self, completion, throw_error=True, strict: bool = None):
Returns:
result (any): result of the function call
"""
message = completion["choices"][0]["message"]

if throw_error:
assert "function_call" in message, "No function call detected"
assert (
message["function_call"]["name"] == self.openai_schema["name"]
), "Function name does not match"
assert "function_call" in message, "No function call detected"
assert message["function_call"]["name"] == self.openai_schema["name"], "Function name does not match"

function_call = message["function_call"]
arguments = json.loads(function_call["arguments"], strict=strict)
Expand Down Expand Up @@ -157,7 +148,7 @@ class User(OpenAISchema):
"""

@property
@classmethod
def openai_schema(cls):
"""Return the schema in the format of OpenAI's schema as jsonschema.
Expand All @@ -169,27 +160,20 @@ def openai_schema(cls):
"""
schema = cls.model_json_schema()
docstring = parse(cls.__doc__ or "")
parameters = {
k: v for k, v in schema.items() if k not in ("title", "description")
}
parameters = {k: v for k, v in schema.items() if k not in ("title", "description")}
for param in docstring.params:
if (name := param.arg_name) in parameters["properties"] and (
description := param.description
):
if (name := param.arg_name) in parameters["properties"] and (description := param.description):
if "description" not in parameters["properties"][name]:
parameters["properties"][name]["description"] = description

parameters["required"] = sorted(
k for k, v in parameters["properties"].items() if "default" not in v
)
parameters["required"] = sorted(k for k, v in parameters["properties"].items() if "default" not in v)

if "description" not in schema:
if docstring.short_description:
schema["description"] = docstring.short_description
else:
schema["description"] = (
f"Correctly extracted `{cls.__name__}` with all "
f"the required parameters with correct types"
f"Correctly extracted `{cls.__name__}` with all " f"the required parameters with correct types"
)

return {
Expand Down Expand Up @@ -221,9 +205,7 @@ def from_response(

if throw_error:
assert "function_call" in message, "No function call detected"
assert (
message["function_call"]["name"] == cls.openai_schema["name"]
), "Function name does not match"
assert message["function_call"]["name"] == cls.openai_schema["name"], "Function name does not match"

return cls.model_validate_json(
message["function_call"]["arguments"],
Expand Down
Loading

0 comments on commit dd7f0c9

Please sign in to comment.