Skip to content

Commit

Permalink
Merge pull request #11 from jepler/backend-option-help
Browse files Browse the repository at this point in the history
Add backend option help; add chatgpt max-request-tokens
  • Loading branch information
jepler authored Sep 24, 2023
2 parents a35a4d6 + c2d801d commit 02de0b3
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ repos:
rev: v2.17.0
hooks:
- id: pylint
additional_dependencies: [click,dataclasses_json,httpx,lorem-text,'textual>=0.18.0',websockets]
additional_dependencies: [click,dataclasses_json,httpx,lorem-text,simple-parsing,'textual>=0.18.0',tiktoken,websockets]
args: ['--source-roots', 'src']
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ dependencies = [
"httpx",
"lorem-text",
"platformdirs",
"simple_parsing",
"textual>=0.18.0",
"tiktoken",
"websockets",
]
classifiers = [
Expand Down
4 changes: 4 additions & 0 deletions src/chap/backends/lorem.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ class Lorem:
@dataclass
class Parameters:
    """Tunable knobs for the lorem backend's simulated token stream.

    The docstring below each field is surfaced as CLI help text
    (read at runtime via source introspection), so the field
    docstrings are part of the user-visible behavior.
    """

    delay_mu: float = 0.035
    """Average delay between tokens"""
    delay_sigma: float = 0.02
    """Standard deviation of token delay"""
    paragraph_lo: int = 1
    """Minimum response paragraph count"""
    paragraph_hi: int = 5
    """Maximum response paragraph count (inclusive)"""

def __init__(self):
    # Each backend instance owns a mutable copy of the default parameters,
    # so -B overrides on one instance cannot leak into another.
    self.parameters = self.Parameters()
Expand Down
89 changes: 82 additions & 7 deletions src/chap/backends/openai_chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,105 @@
#
# SPDX-License-Identifier: MIT

import functools
import json
from dataclasses import dataclass

import httpx
import tiktoken

from ..key import get_key
from ..session import Assistant, Session, User


@dataclass(frozen=True)
class EncodingMeta:
    """A tiktoken encoding bundled with a model's per-message token overheads.

    Used to estimate how many tokens a chat request will consume, following
    the algorithm from OpenAI's "How to count tokens with tiktoken" guide.
    Frozen (hashable) so instances can serve as lru_cache keys.
    """

    encoding: tiktoken.Encoding
    tokens_per_message: int  # fixed overhead charged for every message
    tokens_per_name: int  # extra overhead when a message carries a name

    # NOTE: caching on an instance method keys on (self, s) and pins the
    # instance alive; acceptable here because instances are few and are
    # themselves cached for the process lifetime by from_model().
    @functools.lru_cache()
    def encode(self, s):
        """Tokenize *s* with this encoding, memoizing repeated strings."""
        return self.encoding.encode(s)

    def num_tokens_for_message(self, message):
        """Estimate the tokens one chat message will consume.

        Counts the fixed per-message overhead plus the encoded role and
        content.  (Bug fix: tokens_per_message was stored but never added,
        so requests were under-counted by a few tokens per message.)
        """
        # n.b. chap doesn't use message.name yet, so tokens_per_name is unused
        return (
            self.tokens_per_message
            + len(self.encode(message.role))
            + len(self.encode(message.content))
        )

    def num_tokens_for_messages(self, messages):
        """Estimate tokens for a whole conversation.

        The trailing +3 accounts for the assistant reply priming tokens.
        """
        return sum(self.num_tokens_for_message(message) for message in messages) + 3

    @classmethod
    @functools.cache
    def from_model(cls, model):
        """Return the (cached) EncodingMeta for *model*.

        Falls back to the cl100k_base encoding for unknown model names and
        raises NotImplementedError when the per-message overheads for the
        model family are unknown.
        """
        # Pin floating aliases to a dated snapshot for counting purposes.
        if model == "gpt-3.5-turbo":
            # print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
            model = "gpt-3.5-turbo-0613"
        if model == "gpt-4":
            # print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
            model = "gpt-4-0613"

        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            print("Warning: model not found. Using cl100k_base encoding.")
            encoding = tiktoken.get_encoding("cl100k_base")

        if model in {
            "gpt-3.5-turbo-0613",
            "gpt-3.5-turbo-16k-0613",
            "gpt-4-0314",
            "gpt-4-32k-0314",
            "gpt-4-0613",
            "gpt-4-32k-0613",
        }:
            tokens_per_message = 3
            tokens_per_name = 1
        elif model == "gpt-3.5-turbo-0301":
            tokens_per_message = (
                4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
            )
            tokens_per_name = -1  # if there's a name, the role is omitted
        else:
            raise NotImplementedError(
                f"""EncodingMeta is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
            )
        return cls(encoding, tokens_per_message, tokens_per_name)


class ChatGPT:
@dataclass
class Parameters:
    """User-tweakable ChatGPT backend settings (set via -B name:value).

    The docstring below each field is surfaced as CLI help text
    (read at runtime via source introspection), so the field
    docstrings are part of the user-visible behavior.
    """

    model: str = "gpt-3.5-turbo"
    """The model to use. The most common alternative value is 'gpt-4'."""

    max_request_tokens: int = 1024
    """The approximate greatest number of tokens to send in a request. When the session is long, the system prompt and 1 or more of the most recent interaction steps are sent."""

def __init__(self):
    # Each backend instance owns a mutable copy of the default parameters,
    # so -B overrides on one instance cannot leak into another.
    self.parameters = self.Parameters()

# Default system prompt sent as the first message of every new session.
system_message: str = "You're named Chap (they/them). Chap strives to answer truthfully and frankly. You have opinions, preferences, and emotions, you can say when you don't know, and you prefer to speak in plain language."

def ask(self, session, query, *, max_query_size=5, timeout=60):
full_prompt = Session(session.session + [User(query)])
del full_prompt.session[1:-max_query_size]
def make_full_prompt(self, all_history):
    """Trim *all_history* to fit the configured token budget.

    The first entry (assumed to be the system prompt) is always kept;
    the most recent messages are then added, newest first, until the
    estimated token count would exceed max_request_tokens.  Returns a
    new Session; *all_history* is not modified.
    """
    meta = EncodingMeta.from_model(self.parameters.model)
    kept = [all_history[0]]  # system prompt is non-negotiable
    budget = self.parameters.max_request_tokens - meta.num_tokens_for_messages(kept)

    tail = []
    for message in reversed(all_history[1:]):
        cost = meta.num_tokens_for_message(message)
        if cost > budget:
            break
        budget -= cost
        tail.append(message)
    tail.reverse()  # restore chronological order

    return Session(kept + tail)

def ask(self, session, query, *, timeout=60):
full_prompt = self.make_full_prompt(session.session + [User(query)])
response = httpx.post(
"https://api.openai.com/v1/chat/completions",
json={
Expand Down Expand Up @@ -51,10 +128,8 @@ def ask(self, session, query, *, max_query_size=5, timeout=60):
session.session.extend([User(query), Assistant(result)])
return result

async def aask(self, session, query, *, max_query_size=5, timeout=60):
full_prompt = Session(session.session + [User(query)])
del full_prompt.session[1:-max_query_size]

async def aask(self, session, query, *, timeout=60):
full_prompt = self.make_full_prompt(session.session + [User(query)])
new_content = []
try:
async with httpx.AsyncClient(timeout=timeout) as client:
Expand Down
5 changes: 2 additions & 3 deletions src/chap/commands/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import click
import rich

from ..core import uses_new_session
from ..core import command_uses_new_session

if sys.stdout.isatty():
bold = "\033[1m"
Expand Down Expand Up @@ -78,8 +78,7 @@ async def work():
return result


@click.command
@uses_new_session
@command_uses_new_session
@click.argument("prompt", nargs=-1, required=True)
def main(obj, prompt):
"""Ask a question (command-line argument is passed as prompt)"""
Expand Down
6 changes: 2 additions & 4 deletions src/chap/commands/tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
import subprocess
import sys

import click
from markdown_it import MarkdownIt
from textual.app import App
from textual.binding import Binding
from textual.containers import Container, VerticalScroll
from textual.widgets import Footer, Input, Markdown

from ..core import get_api, uses_new_session
from ..core import command_uses_new_session, get_api
from ..session import Assistant, Session, User


Expand Down Expand Up @@ -115,8 +114,7 @@ def action_yank(self):
subprocess.run(["xsel", "-ib"], input=content.encode("utf-8"), check=False)


@click.command
@uses_new_session
@command_uses_new_session
def main(obj):
"""Start interactive terminal user interface session"""
api = obj.api
Expand Down
59 changes: 55 additions & 4 deletions src/chap/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
import pathlib
import pkgutil
import subprocess
from dataclasses import dataclass, fields
from dataclasses import MISSING, dataclass, fields

import click
import platformdirs
from simple_parsing.docstring import get_attribute_docstring

from . import commands # pylint: disable=no-name-in-module
from .session import Session
Expand Down Expand Up @@ -88,7 +89,40 @@ def set_system_message(ctx, param, value): # pylint: disable=unused-argument


def set_backend(ctx, param, value): # pylint: disable=unused-argument
ctx.obj.api = get_api(value)
try:
ctx.obj.api = get_api(value)
except ModuleNotFoundError as e:
raise click.BadParameter(str(e))


def format_backend_help(api, formatter):
    """Write a help section describing *api*'s -B backend options.

    One definition-list row is emitted per dataclass field of
    api.parameters, showing the kebab-cased option name, the field's
    type, its attribute docstring, and its default value.
    """
    with formatter.section(f"Backend options for {api.__class__.__name__}"):
        rows = []
        for field_info in fields(api.parameters):
            option = field_info.name.replace("_", "-")
            if field_info.default_factory is MISSING:
                default = field_info.default
            else:
                default = field_info.default_factory()
            doc = get_attribute_docstring(
                type(api.parameters), field_info.name
            ).docstring_below
            if doc:
                doc += " "
            doc += f"(Default: {default})"
            rows.append((f"-B {option}:{field_info.type.__name__.upper()}", doc))
        formatter.write_dl(rows)


def backend_help(ctx, param, value):  # pylint: disable=unused-argument
    """Eager click callback for --backend-help: print options, then exit.

    Does nothing during resilient parsing (shell completion) or when the
    flag was not given.
    """
    if ctx.resilient_parsing or not value:
        return

    api = ctx.obj.api or get_api()

    if hasattr(api, "parameters"):
        formatter = ctx.make_formatter()
        format_backend_help(api, formatter)
        click.utils.echo(formatter.getvalue().rstrip("\n"))
    else:
        click.utils.echo(f"{api.__class__.__name__} does not support parameters")

    ctx.exit()


def set_backend_option(ctx, param, opts): # pylint: disable=unused-argument
Expand All @@ -97,7 +131,7 @@ def set_backend_option(ctx, param, opts): # pylint: disable=unused-argument
raise click.BadParameter(
f"{api.__class__.__name__} does not support parameters"
)
all_fields = dict((f.name, f) for f in fields(api.parameters))
all_fields = dict((f.name.replace("_", "-"), f) for f in fields(api.parameters))

def set_one_backend_option(kv):
name, value = kv
Expand Down Expand Up @@ -137,7 +171,15 @@ def uses_existing_session(f):
return f


def uses_new_session(f):
class CommandWithBackendHelp(click.Command):
    """click.Command that appends backend option help to --help output."""

    def format_options(self, ctx, formatter):
        # Render the standard click options first, then the backend section.
        super().format_options(ctx, formatter)
        api = ctx.obj.api or get_api()
        if hasattr(api, "parameters"):
            format_backend_help(api, formatter)


def command_uses_new_session(f):
f = click.option(
"--system-message",
"-S",
Expand All @@ -155,6 +197,14 @@ def uses_new_session(f):
expose_value=False,
is_eager=True,
)(f)
f = click.option(
"--backend-help",
is_flag=True,
is_eager=True,
callback=backend_help,
expose_value=False,
help="Show information about backend options",
)(f)
f = click.option(
"--backend-option",
"-B",
Expand All @@ -172,6 +222,7 @@ def uses_new_session(f):
callback=do_session_new,
expose_value=False,
)(f)
f = click.command(cls=CommandWithBackendHelp)(f)
return f


Expand Down

0 comments on commit 02de0b3

Please sign in to comment.