✨ feat: add deploy #177

Open · wants to merge 4 commits into master
6 changes: 6 additions & 0 deletions README.md
@@ -85,6 +85,12 @@ python finetune.py \

</details>

### Deployment

The usage of `web_demo.py`, `web_demo2.py`, and `api.py` in this folder is the same as the official scripts.

Before running them, change `peft_path` to the path of your own trained model, and set `r` in `peft_config` to the same value as the `--lora_rank` you used for training.
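
For example, assuming the default adapter path from fine-tuning and a LoRA rank of 8 (adjust both to your own run), the three entry points can be launched as follows:

```bash
# Gradio web demo
python web_demo.py --peft_path output/adapter_model.bin --r 8

# Streamlit web demo (script arguments go after the `--` separator)
streamlit run web_demo2.py -- --peft_path output/adapter_model.bin --r 8

# HTTP API, served on localhost:8000 by default
python api.py --peft_path output/adapter_model.bin --r 8
```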


## S2. Reward Model

85 changes: 85 additions & 0 deletions api.py
@@ -0,0 +1,85 @@
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
import uvicorn
import json
import datetime
import torch
from peft import get_peft_model, LoraConfig, TaskType
import argparse

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE


def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


app = FastAPI()


@app.post("/")
async def chat(request: Request):
    global model, tokenizer
    json_post_raw = await request.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    prompt = json_post_list.get('prompt')
    history = json_post_list.get('history')
    max_length = json_post_list.get('max_length')
    top_p = json_post_list.get('top_p')
    temperature = json_post_list.get('temperature')
    response, history = model.chat(tokenizer,
                                   prompt,
                                   history=history,
                                   max_length=max_length if max_length else 2048,
                                   top_p=top_p if top_p else 0.7,
                                   temperature=temperature if temperature else 0.95)
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": time
    }
    log = f'[{time}] prompt: "{prompt}", response: "{repr(response)}"'
    print(log)
    torch_gc()
    return answer


parser = argparse.ArgumentParser()
parser.add_argument('--peft_path', type=str,
                    default='output/adapter_model.bin')
parser.add_argument('--r', type=int, default=8)
parser.add_argument('--host', type=str, default='localhost')
parser.add_argument('--port', type=int, default=8000)
parser.add_argument('--workers', type=int, default=1)

if __name__ == '__main__':
    args = parser.parse_args()
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    tokenizer = AutoTokenizer.from_pretrained(
        "THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained(
        "THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

    peft_path = args.peft_path

    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM, inference_mode=True,
        r=args.r,
        lora_alpha=32, lora_dropout=0.1
    )

    model = get_peft_model(model, peft_config)
    model.load_state_dict(torch.load(peft_path), strict=False)

    model.eval()

    uvicorn.run(app, host=args.host, port=args.port, workers=args.workers)
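
For reference, a minimal client sketch for the endpoint above, assuming the server is running with the default `--host localhost --port 8000`; the request fields mirror the keys read in the handler, and omitted fields fall back to the defaults shown there:

```python
# Hypothetical client for api.py; assumes the server is already running locally.
import requests

payload = {
    "prompt": "Hello",    # the user prompt
    "history": [],        # (query, response) pairs from earlier turns
    "max_length": 2048,   # optional, server default is 2048
    "top_p": 0.7,         # optional, server default is 0.7
    "temperature": 0.95,  # optional, server default is 0.95
}
resp = requests.post("http://localhost:8000/", json=payload)
data = resp.json()
print(data["response"])  # model reply
print(data["history"])   # updated history to send on the next turn
```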
124 changes: 124 additions & 0 deletions web_demo.py
@@ -0,0 +1,124 @@
from transformers import AutoModel, AutoTokenizer
import gradio as gr
import mdtex2html
from peft import get_peft_model, LoraConfig, TaskType
import torch
import argparse

tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

parser = argparse.ArgumentParser()
parser.add_argument('--peft_path', type=str,
                    default='output/adapter_model.bin')
parser.add_argument('--r', type=int, default=8)
args = parser.parse_args()

peft_path = args.peft_path
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=True,
    r=args.r,
    lora_alpha=32, lora_dropout=0.1
)
model = get_peft_model(model, peft_config)
model.load_state_dict(torch.load(peft_path), strict=False)
model.eval()

"""Override Chatbot.postprocess"""


def postprocess(self, y):
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert(message),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess


def parse_text(text):
    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f'<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>"+line
    text = "".join(lines)
    return text


def predict(input, chatbot, max_length, top_p, temperature, history):
    chatbot.append((parse_text(input), ""))
    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
                                               temperature=temperature):
        chatbot[-1] = (parse_text(input), parse_text(response))

        yield chatbot, history


def reset_user_input():
    return gr.update(value='')


def reset_state():
    return [], []


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">ChatGLM</h1>""")

    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(
                0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.7, step=0.01,
                              label="Top P", interactive=True)
            temperature = gr.Slider(
                0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)

    history = gr.State([])

    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
                    show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

demo.queue().launch(share=False, inbrowser=True)
89 changes: 89 additions & 0 deletions web_demo2.py
@@ -0,0 +1,89 @@
from transformers import AutoModel, AutoTokenizer
import streamlit as st
from streamlit_chat import message
from peft import get_peft_model, LoraConfig, TaskType
import torch
import argparse

st.set_page_config(
    page_title="ChatGLM-6b Demo",
    page_icon=":robot:"
)

parser = argparse.ArgumentParser()
parser.add_argument('--peft_path', type=str,
                    default='output/adapter_model.bin')
parser.add_argument('--r', type=int, default=8)
args = parser.parse_args()


@st.cache_resource
def get_model():
    tokenizer = AutoTokenizer.from_pretrained(
        "THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained(
        "THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    peft_path = args.peft_path
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM, inference_mode=True,
        r=args.r,
        lora_alpha=32, lora_dropout=0.1
    )
    model = get_peft_model(model, peft_config)
    model.load_state_dict(torch.load(peft_path), strict=False)
    model.eval()
    return tokenizer, model


MAX_TURNS = 20
MAX_BOXES = MAX_TURNS * 2


def predict(input, max_length, top_p, temperature, history=None):
    tokenizer, model = get_model()
    if history is None:
        history = []

    with container:
        if len(history) > 0:
            for i, (query, response) in enumerate(history):
                message(query, avatar_style="big-smile", key=str(i) + "_user")
                message(response, avatar_style="bottts", key=str(i))

        message(input, avatar_style="big-smile",
                key=str(len(history)) + "_user")
        st.write("AI is responding:")
        with st.empty():
            for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
                                                        temperature=temperature):
                query, response = history[-1]
                st.write(response)

    return history


container = st.container()

# create a prompt text for the text generation
prompt_text = st.text_area(label="User prompt",
                           height=100,
                           placeholder="Enter your prompt here")

max_length = st.sidebar.slider(
    'max_length', 0, 4096, 2048, step=1
)
top_p = st.sidebar.slider(
    'top_p', 0.0, 1.0, 0.6, step=0.01
)
temperature = st.sidebar.slider(
    'temperature', 0.0, 1.0, 0.95, step=0.01
)

if 'state' not in st.session_state:
    st.session_state['state'] = []

if st.button("Send", key="predict"):
    with st.spinner("AI is thinking, please wait..."):
        # text generation
        st.session_state["state"] = predict(
            prompt_text, max_length, top_p, temperature, st.session_state["state"])