diff --git a/pyproject.toml b/pyproject.toml
index c7db6610..ea186f68 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,16 +13,17 @@ authors = [
 ]
 dependencies = [
     "llm-utils>=0.2.8",
-    "openai>=1.29.0",
+    "openai>=1.76.2",
     "rich>=13.7.0",
     "ansicolors>=1.1.8",
     "traitlets>=5.14.1",
     "ipdb>=0.13.13",
     "ipython==8.18.1",
-    "litellm==1.55.9",
     "PyYAML>=6.0.1",
     "ipyflow>=0.0.130",
     "numpy>=1.26.3",
+    "tiktoken>=0.9.0",
+    "requests>=2.32.3",
 ]
 description = "AI-assisted debugging. Uses AI to answer 'why'."
 readme = "README.md"
diff --git a/src/chatdbg/assistant/assistant.py b/src/chatdbg/assistant/assistant.py
index ea104548..fbf85f8d 100644
--- a/src/chatdbg/assistant/assistant.py
+++ b/src/chatdbg/assistant/assistant.py
@@ -1,20 +1,14 @@
+import collections
 import json
 import string
-import textwrap
 import time
-import pprint
-
-import warnings
-
-with warnings.catch_warnings():
-    warnings.simplefilter("ignore")
-    import litellm
 import openai
 
-from ..util.trim import sandwich_tokens, trim_messages
-from ..util.text import strip_ansi
 from .listeners import Printer
+from ..util import litellm
+from ..util.text import strip_ansi
+from ..util.trim import sandwich_tokens, sum_messages, trim_messages
 
 
 class AssistantError(Exception):
@@ -28,6 +22,54 @@ def remove_non_printable_chars(s: str) -> str:
     return filtered_string
 
 
+def _merge_chunks(chunks):
+    # Check for a final usage chunk, and merge it with the last chunk.
+    if not chunks[-1].choices and chunks[-1].usage:
+        chunks[-2].usage = chunks[-1].usage
+        chunks.pop()
+
+    assert all(len(chunk.choices) == 1 for chunk in chunks)
+
+    finish_reason = chunks[-1].choices[0].finish_reason
+    usage = chunks[-1].usage
+    content = "".join(
+        chunk.choices[0].delta.content
+        for chunk in chunks
+        if chunk.choices[0].delta.content  # It can be None for tool calls.
+    )
+
+    tool_chunks = [
+        bit
+        for chunk in chunks
+        if chunk.choices[0].delta.tool_calls
+        for bit in chunk.choices[0].delta.tool_calls
+    ]
+    tool_calls = collections.defaultdict(
+        lambda: {"id": "", "name": "", "arguments": ""}
+    )
+    for tool_chunk in tool_chunks:
+        if tool_chunk.id:
+            tool_calls[tool_chunk.index]["id"] += tool_chunk.id
+        if tool_chunk.function.name:
+            tool_calls[tool_chunk.index]["name"] += tool_chunk.function.name
+        if tool_chunk.function.arguments:
+            tool_calls[tool_chunk.index]["arguments"] += tool_chunk.function.arguments
+
+    tool_calls = [
+        {
+            "id": tool_call["id"],
+            "function": {
+                "name": tool_call["name"],
+                "arguments": tool_call["arguments"],
+            },
+            "type": "function",
+        }
+        for tool_call in tool_calls.values()
+    ]
+
+    return finish_reason, content, tool_calls, usage
+
+
 class Assistant:
     def __init__(
         self,
@@ -38,10 +80,6 @@ def __init__(
         functions=[],
         max_call_response_tokens=2048,
     ):
-
-        # Hide their debugging info -- it messes with our error handling
-        litellm.suppress_debug_info = True
-
         self._clients = listeners
 
         self._functions = {}
@@ -52,9 +90,13 @@ def __init__(
         self._timeout = timeout
         self._conversation = [{"role": "system", "content": instructions}]
         self._max_call_response_tokens = max_call_response_tokens
-
-        self._check_model()
         self._broadcast("on_begin_dialog", instructions)
+        try:
+            self._client = openai.OpenAI()
+        except openai.OpenAIError as e:
+            raise AssistantError(
+                "OpenAI initialization error. Check your API settings and restart ChatDBG.\nIs OPENAI_API_KEY set?"
+            )
 
     def close(self):
        self._broadcast("on_end_dialog")
 
@@ -74,15 +116,12 @@ def query(self, prompt: str, user_text):
         Returns a dictionary containing:
           - "completed": True if the query ran to completion.
-          - "cost": Cost of query, or 0 if not completed.
+          - "cost": Cost of the query in USD; present only if the cost could be computed.
 
         Other fields only if completed is True
           - "time": completion time in seconds
-          - "model": the model used.
-          - "tokens": total tokens
-          - "prompt_tokens": our prompts
-          - "completion_tokens": the LLM completions part
+          - "model": the model used
         """
-        stats = {"completed": False, "cost": 0}
+        result = {"completed": False, "cost": 0}
 
         start = time.time()
         self._broadcast("on_begin_query", prompt, user_text)
@@ -90,22 +129,31 @@ def query(self, prompt: str, user_text):
             stats = self._streamed_query(prompt, user_text)
             elapsed = time.time() - start
 
-            stats["time"] = elapsed
-            stats["model"] = self._model
-            stats["completed"] = True
-            stats["message"] = f"\n[Cost: ~${stats['cost']:.2f} USD]"
-        except openai.OpenAIError as e:
-            self._warn_about_exception(e, f"Unexpected OpenAI Error. Retry the query.")
-            stats["message"] = f"[Exception: {e}]"
+            if self._model in litellm.model_data:
+                model_data = litellm.model_data[self._model]
+                result["cost"] = (
+                    stats["prompt_tokens"] * model_data["input_cost_per_token"]
+                    + stats["completion_tokens"] * model_data["output_cost_per_token"]
+                )
+                result["message"] = f"\n[Cost: ~${result['cost']:.2f} USD]"
+
+            result["time"] = elapsed
+            result["model"] = self._model
+            result["completed"] = True
         except KeyboardInterrupt:
             # user action -- just ignore
-            stats["message"] = "[Chat Interrupted]"
+            result["message"] = "[Chat Interrupted]"
+        except openai.AuthenticationError as e:
+            self._warn_about_exception(
+                e, "OpenAI Error. Check your API key and restart ChatDBG."
+            )
+        except openai.OpenAIError as e:
+            self._warn_about_exception(e, "Unexpected OpenAI Error.")
         except Exception as e:
-            self._warn_about_exception(e, f"Unexpected Exception.")
-            stats["message"] = f"[Exception: {e}]"
+            self._warn_about_exception(e, "Unexpected Exception.")
 
-        self._broadcast("on_end_query", stats)
-        return stats
+        self._broadcast("on_end_query", result)
+        return result
 
     def _report(self, stats):
         if stats["completed"]:
@@ -119,47 +167,6 @@ def _broadcast(self, method_name, *args):
             if callable(method):
                 method(*args)
 
-    def _check_model(self):
-        result = litellm.validate_environment(self._model)
-        missing_keys = result["missing_keys"]
-        if missing_keys != []:
-            _, provider, _, _ = litellm.get_llm_provider(self._model)
-            if provider == "openai":
-                raise AssistantError(
-                    textwrap.dedent(
-                        f"""\
-                        You need an OpenAI key to use the {self._model} model.
-                        You can get a key here: https://platform.openai.com/api-keys.
-                        Set the environment variable OPENAI_API_KEY to your key value."""
-                    )
-                )
-            else:
-                raise AssistantError(
-                    textwrap.dedent(
-                        f"""\
-                        You need to set the following environment variables
-                        to use the {self._model} model: {', '.join(missing_keys)}."""
-                    )
-                )
-
-        try:
-            if not litellm.supports_function_calling(self._model):
-                raise AssistantError(
-                    textwrap.dedent(
-                        f"""\
-                        The {self._model} model does not support function calls.
-                        You must use a model that does, eg. gpt-4."""
-                    )
-                )
-        except:
-            raise AssistantError(
-                textwrap.dedent(
-                    f"""\
-                    {self._model} does not appear to be a supported model.
-                    See https://docs.litellm.ai/docs/providers."""
-                )
-            )
-
     def _add_function(self, function):
         """
         Add a new function to the list of function tools.
@@ -170,9 +177,9 @@ def _add_function(self, function):
         self._functions[schema["name"]] = {"function": function, "schema": schema}
 
     def _make_call(self, tool_call) -> str:
-        name = tool_call.function.name
+        name = tool_call["function"]["name"]
         try:
-            args = json.loads(tool_call.function.arguments)
+            args = json.loads(tool_call["function"]["arguments"])
             function = self._functions[name]
             call, result = function["function"](**args)
             result = remove_non_printable_chars(strip_ansi(result).expandtabs())
@@ -186,89 +193,53 @@ def _make_call(self, tool_call) -> str:
         return result
 
     def _streamed_query(self, prompt: str, user_text):
-        cost = 0
-
         self._conversation.append({"role": "user", "content": prompt})
+        usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
 
-        while True:
+        while True:  # break only when finish_reason == "stop"
             stream = self._stream_completion()
 
-            # litellm.stream_chunk_builder is broken for new GPT models
-            # that have content before calls, so...
-
-            # stream the response, collecting the tool_call parts separately
-            # from the content
             try:
                 self._broadcast("on_begin_stream")
                 chunks = []
-                tool_chunks = []
                 for chunk in stream:
                     chunks.append(chunk)
-                    if chunk.choices[0].delta.content != None:
-                        self._broadcast(
-                            "on_stream_delta", chunk.choices[0].delta.content
-                        )
-                    else:
-                        tool_chunks.append(chunk)
+                    # The final usage chunk will have an empty `choices` list,
+                    # because of `stream_options={"include_usage": True}`.
+                    if chunk.choices:
+                        assert len(chunk.choices) == 1
+                        if chunk.choices[0].delta.content:
+                            self._broadcast(
+                                "on_stream_delta", chunk.choices[0].delta.content
+                            )
             finally:
                 self._broadcast("on_end_stream")
 
-            # then compute for the part that litellm gives back.
-            completion = litellm.stream_chunk_builder(
-                chunks, messages=self._conversation
-            )
-            cost += litellm.completion_cost(completion)
-
-            # add content to conversation, but if there is no content, then the message
-            # has only tool calls, and skip this step
-            response_message = completion.choices[0].message
-            if response_message.content != None:
-                # fix: remove tool calls. They get added below.
-                response_message = response_message.copy()
-                response_message["tool_calls"] = None
-                self._conversation.append(response_message.json())
-
-            if response_message.content != None:
-                self._broadcast("on_response", response_message.content)
-
-            if completion.choices[0].finish_reason == "tool_calls":
-                # create a message with just the tool calls, append that to the conversation, and generate the responses.
-                tool_completion = litellm.stream_chunk_builder(
-                    tool_chunks, self._conversation
-                )
-
-                # this part wasn't counted above...
-                cost += litellm.completion_cost(tool_completion)
+            finish_reason, content, tool_calls, usage_delta = _merge_chunks(chunks)
+            usage["prompt_tokens"] += usage_delta.prompt_tokens
+            usage["completion_tokens"] += usage_delta.completion_tokens
+            usage["total_tokens"] += usage_delta.total_tokens
 
-                tool_message = tool_completion.choices[0].message
+            if content:
+                self._conversation.append({"role": "assistant", "content": content})
+                self._broadcast("on_response", content)
 
-                tool_json = tool_message.json()
-
-                # patch for litellm sometimes putting index fields in the tool calls it constructs
-                # in stream_chunk_builder. gpt-4-turbo-2024-04-09 can't handle those index fields, so
-                # just remove them for the moment.
-                for tool_call in tool_json.get("tool_calls", []):
-                    _ = tool_call.pop("index", None)
+            if finish_reason == "tool_calls":
+                self._conversation.append(
+                    {"role": "assistant", "tool_calls": tool_calls}
+                )
+                self._add_function_results_to_conversation(tool_calls)
 
-                tool_json["role"] = "assistant"
-                self._conversation.append(tool_json)
-                self._add_function_results_to_conversation(tool_message)
-            else:
+            if finish_reason == "stop":
                 break
 
-        stats = {
-            "cost": cost,
-            "tokens": completion.usage.total_tokens,
-            "prompt_tokens": completion.usage.prompt_tokens,
-            "completion_tokens": completion.usage.completion_tokens,
-        }
-        return stats
+        return usage
 
     def _stream_completion(self):
-        self._trim_conversation()
-        return litellm.completion(
+        # TODO: Seems like OpenAI wants to switch to a new API: client.responses.create.
+        return self._client.chat.completions.create(
             model=self._model,
             messages=self._conversation,
             tools=[
@@ -277,22 +248,20 @@ def _stream_completion(self):
             ],
             timeout=self._timeout,
             stream=True,
+            stream_options={"include_usage": True},
         )
 
     def _trim_conversation(self):
-        old_len = litellm.token_counter(self._model, messages=self._conversation)
-
+        old_len = sum_messages(self._conversation, self._model)
         self._conversation = trim_messages(self._conversation, self._model)
+        new_len = sum_messages(self._conversation, self._model)
 
-        new_len = litellm.token_counter(self._model, messages=self._conversation)
         if old_len != new_len:
             self._broadcast(
                 "on_warn", f"Trimming conversation from {old_len} to {new_len} tokens."
             )
 
-    def _add_function_results_to_conversation(self, response_message):
-        response_message["role"] = "assistant"
-        tool_calls = response_message.tool_calls
+    def _add_function_results_to_conversation(self, tool_calls):
         try:
             for tool_call in tool_calls:
                 function_response = self._make_call(tool_call)
@@ -300,9 +269,9 @@ def _add_function_results_to_conversation(self, response_message):
                     function_response, self._model, self._max_call_response_tokens, 0.5
                 )
                 response = {
-                    "tool_call_id": tool_call.id,
+                    "tool_call_id": tool_call["id"],
                     "role": "tool",
-                    "name": tool_call.function.name,
+                    "name": tool_call["function"]["name"],
                     "content": function_response,
                 }
                 self._conversation.append(response)
diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py
index bd3fcba8..eea5d495 100644
--- a/src/chatdbg/chatdbg_pdb.py
+++ b/src/chatdbg/chatdbg_pdb.py
@@ -592,7 +592,8 @@ def do_chat(self, arg):
                 self._make_assistant()
 
             stats = self._assistant.query(full_prompt, user_text=arg)
-            self.message(stats["message"])
+            if "message" in stats:
+                self.message(stats["message"])
         except AssistantError as e:
             for line in str(e).split("\n"):
                 self.error(line)
diff --git a/src/chatdbg/native_util/dbg_dialog.py b/src/chatdbg/native_util/dbg_dialog.py
index 70db6059..fe16f25c 100644
--- a/src/chatdbg/native_util/dbg_dialog.py
+++ b/src/chatdbg/native_util/dbg_dialog.py
@@ -38,7 +38,9 @@ def query_and_print(self, assistant, user_text, is_followup):
         prompt = self.build_prompt(user_text, is_followup)
         self._history.clear()
-        print(assistant.query(prompt, user_text)["message"])
+        result = assistant.query(prompt, user_text)
+        if "message" in result:
+            print(result["message"])
 
         if self._unsafe_cmd:
             self.warn(
                 f"Warning: One or more debugger commands were blocked as potentially unsafe.\nWarning: You can disable sanitizing with `config --unsafe` and try again at your own risk."
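
Aside (not part of the patch): a minimal sketch of the stream shape that _merge_chunks consumes. It assumes OPENAI_API_KEY is set; the model name and prompt are placeholders, not values from this change.

import openai

client = openai.OpenAI()
stream = client.chat.completions.create(
    model="gpt-4o",  # placeholder model name
    messages=[{"role": "user", "content": "Say hello."}],
    stream=True,
    stream_options={"include_usage": True},
)

chunks = list(stream)
# With include_usage, the final chunk carries the token usage and has no choices.
assert not chunks[-1].choices and chunks[-1].usage is not None
text = "".join(
    chunk.choices[0].delta.content or ""
    for chunk in chunks
    if chunk.choices
)
print(text)
print(chunks[-1].usage.prompt_tokens, chunks[-1].usage.completion_tokens)
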
diff --git a/src/chatdbg/util/litellm.py b/src/chatdbg/util/litellm.py
new file mode 100644
index 00000000..2b2bb5ad
--- /dev/null
+++ b/src/chatdbg/util/litellm.py
@@ -0,0 +1,6 @@
+import requests
+
+# This pricing and context-window table comes from the LiteLLM repo (MIT-licensed).
+model_data = requests.get(
+    "https://raw.githubusercontent.com/BerriAI/litellm/refs/heads/main/model_prices_and_context_window.json"
+).json()
diff --git a/src/chatdbg/util/trim.py b/src/chatdbg/util/trim.py
index 6860541e..2335eb5d 100644
--- a/src/chatdbg/util/trim.py
+++ b/src/chatdbg/util/trim.py
@@ -1,17 +1,23 @@
 import copy
-import warnings
 
-with warnings.catch_warnings():
-    warnings.simplefilter("ignore")
-    import litellm
+import tiktoken
+
+from ..util import litellm
 
 
 def sandwich_tokens(
     text: str, model: str, max_tokens: int = 1024, top_proportion: float = 0.5
 ) -> str:
-    if max_tokens == None:
+    if not max_tokens:
         return text
-    tokens = litellm.encode(model, text)
+
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        # For non-OpenAI models, use the GPT-4o encoding by default.
+        encoding = tiktoken.get_encoding("o200k_base")
+
+    tokens = encoding.encode(text)
     if len(tokens) <= max_tokens:
         return text
     else:
@@ -19,18 +25,34 @@ def sandwich_tokens(
         top_len = int(top_proportion * total_len)
         bot_start = len(tokens) - (total_len - top_len)
         return (
-            litellm.decode(model, tokens[0:top_len])
+            encoding.decode(tokens[0:top_len])
             + " [...] "
-            + litellm.decode(model, tokens[bot_start:])
+            + encoding.decode(tokens[bot_start:])
         )
 
 
-def _sum_messages(messages, model):
-    return litellm.token_counter(model, messages=messages)
+def sum_messages(messages, model):
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        # For non-OpenAI models, use the GPT-4o encoding by default.
+        encoding = tiktoken.get_encoding("o200k_base")
+
+    # This is a lower-bound approximation; it won't match the reported usage.
+    count = 0
+    for message in messages:
+        if "content" in message:
+            count += len(encoding.encode(message["content"]))
+        if "tool_calls" in message:
+            for tool_call in message["tool_calls"]:
+                count += len(encoding.encode(tool_call["function"]["name"]))
+                count += len(encoding.encode(tool_call["function"]["arguments"]))
+
+    return count
 
 
 def _sum_kept_chunks(chunks, model):
-    return sum(_sum_messages(messages, model) for (messages, kept) in chunks if kept)
+    return sum(sum_messages(messages, model) for (messages, kept) in chunks if kept)
 
 
 def _extract(messages, model, tool_call_ids):
@@ -39,7 +61,6 @@ def _extract(messages, model, tool_call_ids):
     for m in messages:
         if m.get("tool_call_id", -1) in tool_call_ids:
             content = sandwich_tokens(m["content"], model, 512, 1.0)
-            # print(len(litellm.encode(model, m['content'])), '->', len(litellm.encode(model, content)))
             m["content"] = content
             tools += [m]
         else:
@@ -79,44 +100,39 @@ def trim_messages(
 
     messages = copy.deepcopy(messages)
 
-    max_tokens_for_model = litellm.model_cost[model]["max_input_tokens"]
+    if model in litellm.model_data:
+        max_tokens_for_model = litellm.model_data[model]["max_input_tokens"]
+    else:
+        # Arbitrary. This is Llama 3.1/3.2/3.3 max input tokens.
+        max_tokens_for_model = 128000
     max_tokens = int(max_tokens_for_model * trim_ratio)
 
-    if litellm.token_counter(model, messages=messages) < max_tokens:
+    if sum_messages(messages, model) < max_tokens:
         return messages
 
     chunks = _chunkify(messages=messages, model=model)
 
-    # print("0", sum_all_chunks(chunks, model), max_tokens)
 
     # 1. System messages
     chunks = [(m, b or m[0]["role"] == "system") for (m, b) in chunks]
-    # print("1", sum_kept_chunks(chunks, model))
 
     # 2. First User Message
     for i in range(len(chunks)):
         messages, kept = chunks[i]
         if messages[0]["role"] == "user":
             chunks[i] = (messages, True)
-    # print("2", sum_kept_chunks(chunks, model))
 
     # 3. Fill it up
     for i in range(len(chunks))[::-1]:
         messages, kept = chunks[i]
         if kept:
-            # print('+')
             continue
         elif (
-            _sum_kept_chunks(chunks, model) + _sum_messages(messages, model)
-            < max_tokens
+            _sum_kept_chunks(chunks, model) + sum_messages(messages, model) < max_tokens
         ):
-            # print('-', len(messages))
             chunks[i] = (messages, True)
         else:
-            # print("N", sum_kept_chunks(chunks, model), sum_messages(messages, model))
             break
 
-    # print("3", sum_kept_chunks(chunks, model))
-
     assert (
         _sum_kept_chunks(chunks, model) < max_tokens
     ), f"New conversation too big {_sum_kept_chunks(chunks, model)} vs {max_tokens}!"
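
Aside (not part of the patch): a quick sanity check of the tiktoken fallback used by sandwich_tokens and sum_messages. The helper name _encoding_for and the model name below are illustrative only.

import tiktoken

def _encoding_for(model: str) -> tiktoken.Encoding:
    try:
        return tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown (e.g. non-OpenAI) model names fall back to the GPT-4o encoding.
        return tiktoken.get_encoding("o200k_base")

enc = _encoding_for("some-local-model")  # hypothetical model name
tokens = enc.encode("ChatDBG sandwiches long tool output, keeping the head and tail.")
print(len(tokens))
print(enc.decode(tokens[:4]) + " [...] " + enc.decode(tokens[-4:]))
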