Chatbot interface for models using Gradio (#10)
* Add a chat interface for the original Gemma-2-9B-it model and the tuned model

* Fix type errors raised by mypy
chiffonng authored Dec 12, 2024
1 parent bee206b commit bc00af1
Showing 15 changed files with 425 additions and 37 deletions.
1 change: 1 addition & 0 deletions .env.template
@@ -1,3 +1,4 @@
OPENAI_API_KEY=sk-proj-
HF_TOKEN=hf_B-
WANDB_API_KEY=
+PYTHONPATH=.
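For context, the project lists `python-dotenv` among its dependencies, so these variables are presumably loaded at startup; a minimal sketch of that pattern (the loading code itself is assumed, not part of this commit):

```python
# Minimal sketch (assumed usage): read the .env values with python-dotenv.
import os

from dotenv import load_dotenv

load_dotenv()  # reads key=value pairs from .env into the process environment

hf_token = os.getenv("HF_TOKEN")  # Hugging Face Hub access
openai_key = os.getenv("OPENAI_API_KEY")  # OpenAI API calls
wandb_key = os.getenv("WANDB_API_KEY")  # Weights & Biases logging
```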
38 changes: 38 additions & 0 deletions .gitattributes
@@ -1,2 +1,40 @@
# Ignore Jupyter Notebooks from Github Linguist Stats
*.ipynb linguist-vendored
+
+# Ignore Large File Storage objects
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.csv filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
2 changes: 0 additions & 2 deletions .gitignore
@@ -167,8 +167,6 @@ cython_debug/
# Data
/data
/temp
-*.parquet
-*.csv

# Write up
*pdf/
21 changes: 20 additions & 1 deletion .pre-commit-config.yaml
@@ -10,7 +10,26 @@ repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
+      - id: check-added-large-files
+      - id: check-executables-have-shebangs
+      - id: check-json
+      - id: check-merge-conflict
+      - id: check-toml
- id: check-yaml
- id: end-of-file-fixer
+      - id: mixed-line-ending
+        args: ["--fix=lf"]
+      - id: requirements-txt-fixer
- id: trailing-whitespace
-      - id: check-added-large-files
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.13.0
+    hooks:
+      - id: mypy
+        args: ["--ignore-missing-imports"]
+        additional_dependencies:
+          [
+            "types-python-slugify",
+            "types-requests",
+            "types-PyYAML",
+            "types-pytz",
+          ]
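For context, the new mirrors-mypy hook type-checks the codebase on every commit with `--ignore-missing-imports`. A hedged illustration of the kind of error it catches (hypothetical function, not from this repository):

```python
# Hypothetical snippet: mypy flags the mismatched list item below.
def word_lengths(words: list[str]) -> dict[str, int]:
    return {w: len(w) for w in words}

word_lengths(["mnemonic", 42])  # mypy: List item 1 has incompatible type "int"; expected "str"
```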
2 changes: 1 addition & 1 deletion README.md
@@ -42,7 +42,7 @@ Otherwise, you can try the setup script:
```bash
bash setup.sh
```

-It attempts to install with [uv](https://docs.astral.sh/uv/) (a fast, Rust-based Python package and project manager) using `.python-version` file and `pyproject.toml` file. This is the recommended way to manage the project, since its resolver is faster and more reliable than `pip`.
+It attempts to install with [uv](https://docs.astral.sh/uv/) (a fast, Rust-based Python package and project manager) using the `pyproject.toml` file. This is the recommended way to manage the project, since its dependency resolver is faster and more reliable than `pip`.

Otherwise, it falls back to `pip` installation.

7 changes: 5 additions & 2 deletions pyproject.toml
@@ -5,18 +5,21 @@ description = "Generate mnemonic sentences for English words"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"accelerate>=1.0.1",
"datasets", # Hugging Face datasets
"evaluate", # HF evaluation
"gradio>=4.26.0", # Web app
"gradio>=4.26.0",
"hf-transfer>=0.1.8", # Web app
"numpy<2.0.0", # Wait for other packages to update
"openai>=1.57.0",
"peft", # HF parameter-efficient training
"pre-commit>=4.0.1", # Pre-commit hooks
"python-dotenv>=1.0.1", # Load environment variables
"pyyaml>=6.0.2", # YAML config
"ruff>=0.7.1",
"spaces>=0.31.0",
"tenacity>=9.0.0", # Retry (e.g. API calls)
"torch>=2.5.1", # PyTorch
"torch>=2.4.0", # PyTorch
"tqdm>=4.67.1", # Progress bar
"transformers", # HF transformers
"trl", # HF transformer reinforcement learning
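One note on the new `hf-transfer` dependency: it only speeds up downloads when the corresponding Hub switch is set. A hedged sketch (the environment variable is the documented Hub flag; the model id is the one used by the new app):

```python
# Hedged sketch: opt in to hf-transfer's accelerated downloads.
import os

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # must be set before huggingface_hub is imported

from huggingface_hub import snapshot_download

snapshot_download("google/gemma-2-9b-it")  # fetch the model weights used by src/app/app.py
```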
10 changes: 10 additions & 0 deletions requirements.txt
@@ -98,6 +98,7 @@ gitdb==4.0.11
gitpython==3.1.43
# via wandb
gradio==4.26.0
+    # via spaces
gradio-client==0.15.1
# via gradio
h11==0.14.0
@@ -111,6 +112,8 @@ httpx==0.27.2
# gradio
# gradio-client
# openai
+    #   safehttpx
+    #   spaces
huggingface-hub==0.26.1
# via
# accelerate
@@ -196,6 +199,7 @@ packaging==24.1
# huggingface-hub
# matplotlib
# peft
+    #   spaces
# transformers
pandas==2.2.3
# via
@@ -220,6 +224,7 @@ psutil==5.9.8
# via
# accelerate
# peft
+    #   spaces
# wandb
pyarrow==18.1.0
# via datasets
@@ -230,6 +235,7 @@ pydantic==2.9.2
# fastapi
# gradio
# openai
+    #   spaces
# wandb
pydantic-core==2.23.4
# via pydantic
@@ -269,6 +275,7 @@ requests==2.31.0
# datasets
# evaluate
# huggingface-hub
+    #   spaces
# transformers
# wandb
rich==13.9.3
@@ -281,6 +288,7 @@ rpds-py==0.22.3
# referencing
ruff==0.8.2
# via gradio
+safehttpx==0.1.6
safetensors==0.4.5
# via
# accelerate
@@ -309,6 +317,7 @@ sniffio==1.3.1
# anyio
# httpx
# openai
+spaces==0.31.0
starlette==0.41.2
# via fastapi
sympy==1.13.1
@@ -353,6 +362,7 @@ typing-extensions==4.12.2
# pydantic
# pydantic-core
# rich
+    #   spaces
# torch
# typeguard
# typer
14 changes: 14 additions & 0 deletions src/app/README.md
@@ -0,0 +1,14 @@
---
title: Gemma 2 9B IT
emoji: 😻
colorFrom: indigo
colorTo: pink
sdk: gradio
sdk_version: 5.8.0
python_version: 3.10
app_file: app.py
pinned: false
short_description: Chatbot
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
162 changes: 162 additions & 0 deletions src/app/app.py
@@ -0,0 +1,162 @@
"""Chat interface demo for Google Gemma 2 9B IT model.
Cloned and adapted from the demo: https://huggingface.co/spaces/huggingface-projects/gemma-2-9b-it/tree/main/app.py
"""

import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

DESCRIPTION = """
This is a demo for the Google Gemma 2 9B IT model. Use it to generate mnemonics for English words you want to learn and remember.
Enter your instructions or start with one of the examples provided. The input supports a subset of Markdown formatting, such as bold, italics, code, and tables.
"""

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "google/gemma-2-9b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
)
model.config.sliding_window = 4096
model.eval()


@spaces.GPU(duration=90)
def generate(
message: str,
chat_history: list[dict],
max_new_tokens: int = 1024,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.2,
) -> Iterator[str]:
"""Generate a response to a message using the model.
Args:
message: The message to respond to.
chat_history: The conversation history.
max_new_tokens: The maximum number of tokens to generate.
temperature: The temperature for sampling.
top_p: The top-p value for nucleus sampling.
top_k: The top-k value for sampling.
repetition_penalty: The repetition penalty.
Yields:
Iterator[str]: The generated response.
"""
conversation = chat_history.copy()
conversation.append({"role": "user", "content": message})

input_ids = tokenizer.apply_chat_template(
conversation, add_generation_prompt=True, return_tensors="pt"
)
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(
f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
)
input_ids = input_ids.to(model.device)

streamer = TextIteratorStreamer(
tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
)
generate_kwargs = dict(
{"input_ids": input_ids},
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_beams=1,
repetition_penalty=repetition_penalty,
)
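    # Generation runs in a background thread so the streamer below can yield
    # partial text to the Gradio UI as soon as tokens are produced.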
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()

outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)


chat_interface = gr.ChatInterface(
fn=generate,
additional_inputs=[
gr.Slider(
label="Max new tokens",
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
),
gr.Slider(
label="Temperature",
minimum=0.1,
maximum=4.0,
step=0.1,
value=0.6,
),
gr.Slider(
label="Top-p (nucleus sampling)",
minimum=0.05,
maximum=1.0,
step=0.05,
value=0.9,
),
gr.Slider(
label="Top-k",
minimum=1,
maximum=1000,
step=1,
value=50,
),
gr.Slider(
label="Repetition penalty",
minimum=1.0,
maximum=2.0,
step=0.05,
value=1.2,
),
],
stop_btn=True,
examples=[
[
"Produce a cue to help me learn and retrieve the meaning of this word whenever I look at it (and nothing else): preposterous"
],
[
"Create a cue that elicits vivid mental image for the word 'observient' so I could remember its meaning."
],
[
"I need a mnemonic for 'dilapidated' to learn its meaning and contextual usage."
],
[
"Help me remember the meaning of 'encapsulate' by connecting it to its etymology or related words."
],
],
cache_examples=False,
type="messages",
)

with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
gr.Markdown(DESCRIPTION)
    chat_interface.render()
gr.ClearButton(elem_id="clear-button")


if __name__ == "__main__":
demo.queue(max_size=20).launch()
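
Because `generate` is a plain generator function, it can also be smoke-tested without launching the UI; a hedged sketch (assumes the model fits on the available hardware and that an empty history is acceptable):

```python
# Hedged sketch: call generate() directly to sanity-check the model outside Gradio.
prompt = "Give me a mnemonic for the word 'ephemeral'."
response = ""
for partial in generate(prompt, chat_history=[], max_new_tokens=64):
    response = partial  # each yield is the accumulated response so far
print(response)
```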