Skip to content

Commit

Permalink
Merge pull request #14 from jepler/huggingface
Browse files Browse the repository at this point in the history
Add huggingface back-end
  • Loading branch information
jepler authored Sep 29, 2023
2 parents 9919c9a + b6fa44f commit f3bf17c
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 11 deletions.
117 changes: 117 additions & 0 deletions src/chap/backends/huggingface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# SPDX-FileCopyrightText: 2023 Jeff Epler <[email protected]>
#
# SPDX-License-Identifier: MIT

import asyncio
import json
from dataclasses import dataclass

import httpx

from ..key import get_key
from ..session import Assistant, User


class HuggingFace:
    """Chat back-end that streams generated text from a Hugging Face HTTP endpoint.

    The conversation is flattened into a single prompt string using the
    marker strings held in ``Parameters`` and POSTed to
    ``{url}/models/{model}``; the streamed tokens are yielded as they arrive.
    """

    @dataclass
    class Parameters:
        # Base URL of the service; "/models/<model>" is appended per request.
        url: str = "https://api-inference.huggingface.co"
        # Model identifier inserted into the request URL.
        model: str = "mistralai/Mistral-7B-Instruct-v0.1"
        # NOTE(review): this value is not included in the request payload
        # built by chained_query() -- confirm whether it should be sent.
        max_new_tokens: int = 250
        # Text emitted once, before the first message of the prompt.
        start_prompt: str = """<s>[INST] <<SYS>>\n"""
        # Text emitted after each non-empty system message.
        after_system: str = "\n<</SYS>>\n\n"
        # Text emitted after each non-empty user message.
        after_user: str = """ [/INST] """
        # Text emitted after each non-empty assistant message.
        after_assistant: str = """ </s><s>[INST] """
        # Token id that ends generation.  Annotated so that it is a real
        # dataclass field (settable via Parameters(...)) instead of a bare
        # class attribute.
        stop_token_id: int = 2

    def __init__(self):
        self.parameters = self.Parameters()

    system_message = """\
A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, and knows its own limits.
"""

    def make_full_query(self, messages, max_query_size):
        """Flatten ``messages`` into one prompt string.

        The first message is always kept; for a positive ``max_query_size``
        only that many trailing messages are kept after it.  Messages whose
        stripped content is empty are skipped.  Each kept message is followed
        by the marker configured for its role (system / assistant / user);
        other roles get no marker.

        The caller's list is left untouched.
        """
        # Trim a copy so the caller's history is not truncated in place.
        messages = list(messages)
        del messages[1:-max_query_size]

        parameters = self.parameters
        suffix_for_role = {
            "system": parameters.after_system,
            "assistant": parameters.after_assistant,
            "user": parameters.after_user,
        }

        result = [parameters.start_prompt]
        for m in messages:
            content = (m.content or "").strip()
            if not content:
                continue
            result.append(content)
            suffix = suffix_for_role.get(m.role)
            if suffix is not None:
                result.append(suffix)
        return "".join(result)

    async def chained_query(self, inputs, timeout):
        """Yield generated token texts for ``inputs``, re-querying as needed.

        Each request streams ``data:`` lines containing JSON events.  The
        generator stops as soon as the token whose id equals
        ``parameters.stop_token_id`` is seen (its text is not yielded).  If a
        response ends without that token but an event carried a non-empty
        ``generated_text``, that text becomes the ``inputs`` of a follow-up
        request.  A non-200 response yields a single failure message and ends
        the generator.
        """
        async with httpx.AsyncClient(timeout=timeout) as client:
            while inputs:
                params = {
                    "inputs": inputs,
                    "stream": True,
                }
                # Cleared here; only refilled if the stream reports
                # "generated_text", which is what triggers another round.
                inputs = None
                async with client.stream(
                    "POST",
                    f"{self.parameters.url}/models/{self.parameters.model}",
                    json=params,
                    headers={
                        "Authorization": f"Bearer {self.get_key()}",
                    },
                ) as response:
                    if response.status_code != 200:
                        yield f"\nFailed with {response=!r}"
                        return
                    async for line in response.aiter_lines():
                        if not line.startswith("data:"):
                            continue
                        data = line.removeprefix("data:").strip()
                        j = json.loads(data)
                        # Tolerate an explicit `"token": null` in an event.
                        token = j.get("token") or {}
                        inputs = j.get("generated_text", inputs)
                        if token.get("id") == self.parameters.stop_token_id:
                            return
                        yield token.get("text", "")

    async def aask(
        self, session, query, *, max_query_size=5, timeout=180
    ):  # pylint: disable=unused-argument,too-many-locals,too-many-branches
        """Stream the reply to ``query`` and record the exchange in ``session``.

        Yields each non-empty piece of the reply; leading whitespace is
        removed until the first non-empty piece has been produced.  An
        ``httpx.HTTPError`` is reported as part of the reply text rather than
        raised.  Once the stream is exhausted, the user query and the
        concatenated reply are appended to ``session.session``.
        """
        new_content = []
        inputs = self.make_full_query(session.session + [User(query)], max_query_size)
        try:
            async for content in self.chained_query(inputs, timeout=timeout):
                if not new_content:
                    # Nothing emitted yet: drop leading whitespace.
                    content = content.lstrip()
                if content:
                    new_content.append(content)
                    yield content

        except httpx.HTTPError as e:
            content = f"\nException: {e!r}"
            new_content.append(content)
            yield content

        session.session.extend([User(query), Assistant("".join(new_content))])

    def ask(self, session, query, *, max_query_size=5, timeout=60):
        """Blocking wrapper around ``aask``; returns the reply text.

        ``aask`` is an async generator, which ``asyncio.run`` cannot execute
        directly, so it is consumed inside a small coroutine.  The reply is
        read from the ``content`` of the last entry ``aask`` appended to
        ``session.session``.
        """

        async def drain():
            async for _ in self.aask(
                session, query, max_query_size=max_query_size, timeout=timeout
            ):
                pass

        asyncio.run(drain())
        return session.session[-1].content

    @classmethod
    def get_key(cls):
        """Return the API token looked up under ``huggingface_api_token``."""
        return get_key("huggingface_api_token")


def factory():
    """Uses the huggingface text-generation-interface web API"""
    # Build a fresh back-end instance and hand it to the caller.
    backend = HuggingFace()
    return backend
20 changes: 13 additions & 7 deletions src/chap/backends/llama_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ class Parameters:
url: str = "http://localhost:8080/completion"
"""The URL of a llama.cpp server's completion endpoint."""

start_prompt: str = """<s>[INST] <<SYS>>\n"""
after_system: str = "\n<</SYS>>\n\n"
after_user: str = """ [/INST] """
after_assistant: str = """ </s><s>[INST] """

def __init__(self):
self.parameters = self.Parameters()

Expand All @@ -26,29 +31,30 @@ def __init__(self):

def make_full_query(self, messages, max_query_size):
del messages[1:-max_query_size]
rows = []
result = [self.parameters.start_prompt]
for m in messages:
content = (m.content or "").strip()
if not content:
continue
result.append(content)
if m.role == "system":
rows.append(f"ASSISTANT'S RULE: {content}\n")
result.append(self.parameters.after_system)
elif m.role == "assistant":
rows.append(f"ASSISTANT: {content}\n")
result.append(self.parameters.after_assistant)
elif m.role == "user":
rows.append(f"USER: {content}")
rows.append("ASSISTANT: ")
full_query = ("\n".join(rows)).rstrip()
result.append(self.parameters.after_user)
full_query = "".join(result)
return full_query

async def aask(
self, session, query, *, max_query_size=5, timeout=60
self, session, query, *, max_query_size=5, timeout=180
): # pylint: disable=unused-argument,too-many-locals,too-many-branches
params = {
"prompt": self.make_full_query(
session.session + [User(query)], max_query_size
),
"stream": True,
"stop": ["</s>", "<s>", "[INST]"],
}
new_content = []
try:
Expand Down
9 changes: 9 additions & 0 deletions src/chap/commands/tui.css
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@
* SPDX-License-Identifier: MIT
*/

.role_user.history_exclude, .role_assistant.history_exclude {
color: $text-disabled;
border-left: dashed $primary;
}
.role_assistant.history_exclude:focus-within {
color: $text-disabled;
border-left: dashed $secondary;
}

.role_system {
text-style: italic;
color: $text-muted;
Expand Down
27 changes: 24 additions & 3 deletions src/chap/commands/tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class Markdown(
Binding("ctrl+y", "yank", "Yank text", show=True),
Binding("ctrl+r", "resubmit", "resubmit", show=True),
Binding("ctrl+x", "delete", "delete to end", show=True),
Binding("ctrl+q", "toggle_history", "history toggle", show=True),
]


Expand All @@ -43,7 +44,7 @@ def markdown_for_step(step):
class Tui(App):
CSS_PATH = "tui.css"
BINDINGS = [
Binding("ctrl+q", "app.quit", "Quit", show=True, priority=True),
Binding("ctrl+c", "app.quit", "Quit", show=True, priority=True),
]

def __init__(self, api=None, session=None):
Expand Down Expand Up @@ -82,6 +83,12 @@ async def on_input_submitted(self, event) -> None:
tokens = []
update = asyncio.Queue(1)

# Construct a fake session with only select items
session = Session()
for si, wi in zip(self.session.session, self.container.children):
if not wi.has_class("history_exclude"):
session.session.append(si)

async def render_fun():
while await update.get():
if tokens:
Expand All @@ -90,7 +97,7 @@ async def render_fun():
await asyncio.sleep(0.1)

async def get_token_fun():
async for token in self.api.aask(self.session, event.value):
async for token in self.api.aask(session, event.value):
tokens.append(token)
try:
update.put_nowait(True)
Expand All @@ -102,6 +109,7 @@ async def get_token_fun():
await asyncio.gather(render_fun(), get_token_fun())
self.input.value = ""
finally:
self.session.session.extend(session.session[-2:])
all_output = self.session.session[-1].content
output.update(all_output)
output._markdown = all_output # pylint: disable=protected-access
Expand All @@ -118,6 +126,19 @@ def action_yank(self):
content = widget._markdown # pylint: disable=protected-access
subprocess.run(["xsel", "-ib"], input=content.encode("utf-8"), check=False)

def action_toggle_history(self):
    """Toggle the ``history_exclude`` class on the focused exchange.

    Starting from the focused Markdown widget, walk backwards to the
    nearest widget carrying the ``role_user`` class (not searching below
    index 1), then toggle ``history_exclude`` on that widget and on the
    widget that follows it, if there is one.  Does nothing when the focused
    widget is not a Markdown widget.
    """
    widget = self.focused
    if not isinstance(widget, Markdown):
        return
    children = self.container.children
    idx = children.index(widget)
    # Same role test as delete_or_resubmit uses.
    while idx > 1 and not children[idx].has_class("role_user"):
        idx -= 1

    children[idx].toggle_class("history_exclude")
    # The prompt may be the last widget; only toggle a reply that exists.
    if idx + 1 < len(children):
        children[idx + 1].toggle_class("history_exclude")

async def action_resubmit(self):
await self.delete_or_resubmit(True)

Expand All @@ -130,7 +151,7 @@ async def delete_or_resubmit(self, resubmit):
return
children = self.container.children
idx = children.index(widget)
while idx > 1 and not "role_user" in children[idx].classes:
while idx > 1 and not children[idx].has_class("role_user"):
idx -= 1
widget = children[idx]

Expand Down
2 changes: 1 addition & 1 deletion src/chap/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def format_backend_help(api, formatter):
doc = get_attribute_docstring(type(api.parameters), f.name).docstring_below
if doc:
doc += " "
doc += f"(Default: {default})"
doc += f"(Default: {default!r})"
rows.append((f"-B {name}:{f.type.__name__.upper()}", doc))
formatter.write_dl(rows)

Expand Down

0 comments on commit f3bf17c

Please sign in to comment.