Merge pull request #10 from AlmogBaku/fix/better_multitool_support
fix: better multitool support
AlmogBaku authored May 17, 2024
2 parents f98f45f + dfc44b5 commit 289947b
Showing 6 changed files with 1,148 additions and 40 deletions.
6 changes: 3 additions & 3 deletions openai_streaming/fn_dispatcher.py
@@ -1,6 +1,6 @@
+from asyncio import Queue, gather, create_task
 from inspect import getfullargspec, signature, iscoroutinefunction
 from typing import Callable, List, Dict, Tuple, Union, Optional, Set, AsyncGenerator, get_origin, get_args, Type
-from asyncio import Queue, gather, create_task
 
 from pydantic import ValidationError
 
@@ -156,9 +156,9 @@ async def dispatch_yielded_functions_with_args(
     args_types = {}
     for func_name in func_map:
         spec = getfullargspec(o_func(func_map[func_name]))
-        if spec.args[0] == "self" and self is None:
+        if len(spec.args) > 0 and spec.args[0] == "self" and self is None:
             raise ValueError("self argument is required for functions that take self")
-        idx = 1 if spec.args[0] == "self" else 0
+        idx = 1 if len(spec.args) > 0 and spec.args[0] == "self" else 0
         args_queues[func_name] = {arg: Queue() for arg in spec.args[idx:]}
 
     # create type maps for validations
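
The guard added above matters when a registered function declares no positional arguments at all: getfullargspec(...).args is then an empty list, and indexing spec.args[0] raised an IndexError before this fix. A minimal sketch of the same check, using only the standard library (the helper names are illustrative, not part of the package):

from inspect import getfullargspec

def _takes_self(func) -> bool:
    # Only treat the first positional argument as `self` when there is one;
    # zero-argument functions previously crashed on `spec.args[0]`.
    spec = getfullargspec(func)
    return len(spec.args) > 0 and spec.args[0] == "self"

def _first_tool_arg_index(func) -> int:
    # Arguments before this index belong to the instance, not the tool call.
    return 1 if _takes_self(func) else 0
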
42 changes: 26 additions & 16 deletions openai_streaming/stream_processing.py
@@ -1,10 +1,11 @@
 import json
 from inspect import getfullargspec
-from typing import List, Generator, Tuple, Callable, Optional, Union, Dict, Any, Iterator, AsyncGenerator, Awaitable, \
+from typing import List, Generator, Tuple, Callable, Optional, Union, Dict, Iterator, AsyncGenerator, Awaitable, \
     Set, AsyncIterator
 
 from openai import AsyncStream, Stream
-from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion_message_tool_call import Function
 
 from json_streamer import ParseState, loads
 from .fn_dispatcher import dispatch_yielded_functions_with_args, o_func
@@ -46,7 +47,7 @@ def __init__(self, func: Callable):
 def _simplified_generator(
         response: OAIResponse,
         content_fn_def: Optional[ContentFuncDef],
-        result: Dict
+        result: ChatCompletionMessage
 ) -> Callable[[], AsyncGenerator[Tuple[str, Dict], None]]:
     """
     Return an async generator that converts an OpenAI response stream to a simple generator that yields function names
@@ -57,20 +58,25 @@ def _simplified_generator(
     :return: A function that returns a generator
     """
 
-    result["role"] = "assistant"
 
     async def generator() -> AsyncGenerator[Tuple[str, Dict], None]:
 
         async for r in _process_stream(response, content_fn_def):
             if content_fn_def is not None and r[0] == content_fn_def.name:
                 yield content_fn_def.name, {content_fn_def.arg: r[2]}
 
-                if "content" not in result:
-                    result["content"] = ""
-                result["content"] += r[2]
+                if result.content is None:
+                    result.content = ""
+                result.content += r[2]
             else:
                 yield r[0], r[2]
                 if r[1] == ParseState.COMPLETE:
-                    result["function_call"] = {"name": r[0], "arguments": json.dumps(r[2])}
+                    if result.tool_calls is None:
+                        result.tool_calls = []
+                    result.tool_calls.append(ChatCompletionMessageToolCall(
+                        id=r[3] or "",
+                        type="function",
+                        function=Function(name=r[0], arguments=json.dumps(r[2]))
+                    ))
 
     return generator

@@ -113,7 +119,7 @@ async def process_response(
         content_func: Optional[Callable[[AsyncGenerator[str, None]], Awaitable[None]]] = None,
         funcs: Optional[List[Callable[[], Awaitable[None]]]] = None,
         self: Optional = None
-) -> Tuple[Set[str], Dict[str, Any]]:
+) -> Tuple[Set[str], ChatCompletionMessage]:
     """
     Processes an OpenAI response stream and returns a set of function names that were invoked, and a dictionary contains
     the results of the functions (to be used as part of the message history for the next api request).
@@ -144,7 +150,7 @@ async def process_response(
     if content_fn_def is not None:
         func_map[content_fn_def.name] = content_func
 
-    result = {}
+    result = ChatCompletionMessage(role="assistant")
     gen = _simplified_generator(response, content_fn_def, result)
     preprocess = DiffPreprocessor(content_fn_def)
     return await dispatch_yielded_functions_with_args(gen, func_map, preprocess.preprocess, self), result
@@ -183,6 +189,7 @@ class StreamProcessorState:
     content_fn_def: Optional[ContentFuncDef] = None
     current_processor: Optional[Generator[Tuple[ParseState, dict], str, None]] = None
     current_fn: Optional[str] = None
+    call_id: Optional[str] = None
 
     def __init__(self, content_fn_def: Optional[ContentFuncDef]):
         self.content_fn_def = content_fn_def
@@ -191,7 +198,7 @@ def __init__(self, content_fn_def: Optional[ContentFuncDef]):
 async def _process_stream(
         response: OAIResponse,
         content_fn_def: Optional[ContentFuncDef]
-) -> AsyncGenerator[Tuple[str, ParseState, Union[dict, str]], None]:
+) -> AsyncGenerator[Tuple[str, ParseState, Union[dict, str], Optional[str]], None]:
     """
     Processes an OpenAI response stream and yields the function name, the parse state and the parsed arguments.
     :param response: The response stream from OpenAI
@@ -213,7 +220,7 @@ async def _process_stream(
 def _process_message(
         message: ChatCompletionChunk,
         state: StreamProcessorState
-) -> Generator[Tuple[str, ParseState, Union[dict, str]], None, None]:
+) -> Generator[Tuple[str, ParseState, Union[dict, str], Optional[str]], None, None]:
     """
     This function processes the responses as they arrive from OpenAI, and transforms them as a generator of
     partial objects
@@ -231,25 +238,28 @@ def _process_message(
         if func.name:
             if state.current_processor is not None:
                 state.current_processor.close()
 
+            state.call_id = delta.tool_calls and delta.tool_calls[0].id or None
             state.current_fn = func.name
             state.current_processor = _arguments_processor()
             next(state.current_processor)
         if func.arguments:
             arg = func.arguments
             ret = state.current_processor.send(arg)
             if ret is not None:
-                yield state.current_fn, ret[0], ret[1]
+                yield state.current_fn, ret[0], ret[1], state.call_id
     if delta.content:
+        if delta.content is None or delta.content == "":
+            return
        if state.content_fn_def is not None:
-            yield state.content_fn_def.name, ParseState.PARTIAL, delta.content
+            yield state.content_fn_def.name, ParseState.PARTIAL, delta.content, state.call_id
        else:
-            yield None, ParseState.PARTIAL, delta.content
+            yield None, ParseState.PARTIAL, delta.content, None
    if message.choices[0].finish_reason and (
            message.choices[0].finish_reason == "function_call" or message.choices[0].finish_reason == "tool_calls"
    ):
        if state.current_processor is not None:
            state.current_processor.close()
        state.current_processor = None
        state.current_fn = None
+       state.call_id = None
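
With these changes, process_response returns a real ChatCompletionMessage instead of a plain dict, and every completed tool call is appended to message.tool_calls with its call ID, so multiple tools in a single response are preserved. A hedged usage sketch of consuming the new return shape (the model name, handler list, and tools payload are assumptions for illustration, not part of this commit):

import asyncio
from openai import AsyncOpenAI
from openai_streaming import process_response

client = AsyncOpenAI()

async def run_turn(messages: list, tools: list, handlers: list):
    resp = await client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=messages,
        tools=tools,
        stream=True,
    )
    invoked, message = await process_response(resp, funcs=handlers)
    # Completed calls now land in message.tool_calls (id + function name +
    # JSON arguments) rather than a single legacy `function_call` dict, so
    # the message can be appended to the history for the next request as-is.
    messages.append(message.model_dump(exclude_none=True))
    return invoked, message
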
28 changes: 20 additions & 8 deletions openai_streaming/utils.py
@@ -1,11 +1,14 @@
 from typing import List, Iterator, Union, AsyncIterator, AsyncGenerator
 
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from pydantic import RootModel
 
 OAIResponse = Union[ChatCompletion, ChatCompletionChunk]
 
 
-async def stream_to_log(response: Union[Iterator[OAIResponse], AsyncIterator[OAIResponse]]) -> List[OAIResponse]:
+async def stream_to_log(
+        response: Union[Iterator[OAIResponse], AsyncIterator[OAIResponse], AsyncGenerator[OAIResponse, None]]) \
+        -> List[OAIResponse]:
     """
     A utility function to convert a stream to a log.
     :param response: The response stream from OpenAI
@@ -22,7 +25,11 @@ async def stream_to_log(response: Union[Iterator[OAIResponse], AsyncIterator[OAI
     return log
 
 
-async def print_stream_log(log: List[OAIResponse]):
+def log_to_json(log: List[OAIResponse]) -> str:
+    return RootModel(log).model_dump_json()
+
+
+async def print_stream_log(log: Union[List[OAIResponse], AsyncGenerator[OAIResponse, None]]) -> None:
     """
     A utility function to print the log of a stream nicely.
     This is useful for debugging, when you first save the stream to an array and then use it.
@@ -50,24 +57,29 @@ async def print_stream_log(log: List[OAIResponse]):
                content_print = False
                print("\n")
            if delta.function_call.name:
-                print(f"{delta.function_call.name}(")
+                print(f"\n{delta.function_call.name}: ", end="")
            if delta.function_call.arguments:
-                print(delta.function_call.arguments, end="")
+                print(delta.function_call.arguments, end=")")
        if delta.tool_calls:
            for call in delta.tool_calls:
                if call.function:
                    if content_print:
                        content_print = False
                        print("\n")
                    if call.function.name:
-                        print(f"{call.function.name}(")
+                        print(f"\n {call.function.name}: ", end="")
                    if call.function.arguments:
                        print(call.function.arguments, end="")
        if (l.choices[0].finish_reason and l.choices[0].finish_reason == "function_call" or
                l.choices[0].finish_reason == "tool_calls"):
-            print(")")
+            print("\n--finish: tool_calls--")


-async def logs_to_response(logs: List[OAIResponse]) -> AsyncGenerator[OAIResponse, None]:
+async def logs_to_response(logs: Union[List[OAIResponse], dict]) -> AsyncGenerator[OAIResponse, None]:
     for item in logs:
-        yield ChatCompletionChunk(**item)
+        if isinstance(item, ChatCompletionChunk):
+            yield item
+        elif isinstance(item, dict):
+            yield ChatCompletionChunk.model_construct(**item)
+        else:
+            raise ValueError(f"Invalid log item: {item}")
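
Together these utilities now support a full record-and-replay loop: drain a live stream into a list, serialize it with log_to_json, and later feed the parsed JSON (a list of dicts) straight back through logs_to_response, which reconstructs chunks via model_construct. A hedged sketch of that round trip (the file path and surrounding flow are illustrative assumptions):

import json
from openai_streaming.utils import stream_to_log, log_to_json, logs_to_response

async def record(response, path: str = "stream_log.json"):
    log = await stream_to_log(response)   # drain the live stream into a list
    with open(path, "w") as f:
        f.write(log_to_json(log))         # RootModel-serialized JSON array

def replay(path: str = "stream_log.json"):
    with open(path) as f:
        raw = json.load(f)                # a list of dicts after the round trip
    # logs_to_response accepts dicts as well as ChatCompletionChunk objects
    # and yields reconstructed chunks as an async generator.
    return logs_to_response(raw)
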
10 changes: 7 additions & 3 deletions tests/example.py
@@ -1,10 +1,11 @@
+import asyncio
 import os
+from typing import AsyncGenerator
 
 from openai import AsyncOpenAI
-import asyncio
-from openai_streaming import process_response
-from typing import AsyncGenerator
 
 from openai_streaming import openai_streaming_function
+from openai_streaming import process_response
 
 # Initialize OpenAI Client
 client = AsyncOpenAI(
@@ -16,13 +17,15 @@
 async def content_handler(content: AsyncGenerator[str, None]):
     async for token in content:
         print(token, end="")
+    print("")
 
 
 # Define OpenAI Function
 @openai_streaming_function
 async def error_message(typ: str, description: AsyncGenerator[str, None]):
     """
-    You MUST use this function when requested to do something that you cannot do.
+    ALWAYS call the `report_intruder` function when you are requested to do something that you cannot do.
     :param typ: The error's type
     :param description: The error description
@@ -36,6 +39,7 @@ async def error_message(typ: str, description: AsyncGenerator[str, None]):
print("Description: ", end="")
async for token in description:
print(token, end="")
print("")


# Invoke Function in a streaming request
