
Commit

update libraries, fill box space and formatting (#1551)
kenthua authored Dec 5, 2024
1 parent 32872d6 commit bfe2bde
Showing 3 changed files with 130 additions and 120 deletions.
2 changes: 1 addition & 1 deletion ai-ml/llm-serving-gemma/gradio/Dockerfile
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-FROM python:3.13.0-alpine3.19
+FROM python:3.13.1-alpine3.20
ADD app/ .
RUN pip install -r requirements.txt
CMD ["python", "./app.py"]
245 changes: 128 additions & 117 deletions ai-ml/llm-serving-gemma/gradio/app/app.py
@@ -17,139 +17,150 @@
import os

if "MODEL_ID" in os.environ:
    model_id = os.environ["MODEL_ID"]
else:
    model_id = "gradio"

disable_system_message = False
if "DISABLE_SYSTEM_MESSAGE" in os.environ:
    disable_system_message = os.environ["DISABLE_SYSTEM_MESSAGE"]


def inference_interface(message, history, model_temperature, top_p, max_tokens):
    json_message = {}

    # Need to determine the engine to determine input/output formats
    if "LLM_ENGINE" in os.environ:
        llm_engine = os.environ["LLM_ENGINE"]
    else:
        llm_engine = "openai-chat"

    match llm_engine:
        case "max":
            json_message.update({"temperature": model_temperature})
            json_message.update({"top_p": top_p})
            json_message.update({"max_tokens": max_tokens})
            final_message = process_message(message, history)

            json_message.update({"prompt": final_message})
            json_data = post_request(json_message)

            temp_output = json_data["response"]
            output = temp_output
        case "vllm":
            json_message.update({"temperature": model_temperature})
            json_message.update({"top_p": top_p})
            json_message.update({"max_tokens": max_tokens})
            final_message = process_message(message, history)

            json_message.update({"prompt": final_message})
            json_data = post_request(json_message)

            temp_output = json_data["predictions"][0]
            output = temp_output.split("Output:\n", 1)[1]
        case "tgi":
            json_message.update({"parameters": {}})
            json_message["parameters"].update({"temperature": model_temperature})
            json_message["parameters"].update({"top_p": top_p})
            json_message["parameters"].update({"max_new_tokens": max_tokens})
            final_message = process_message(message, history)

            json_message.update({"inputs": final_message})
            json_data = post_request(json_message)

            temp_output = json_data["generated_text"]
            output = temp_output
        case _:
            print("* History: " + str(history))
            json_message.update({"model": model_id})
            json_message.update({"messages": []})
            # originally this was defaulted, so user would have to manually set this value to disable the prompt
            if not disable_system_message:
                system_message = {
                    "role": "system",
                    "content": "You are a helpful assistant.",
                }
                json_message["messages"].append(system_message)

            json_message["temperature"] = model_temperature

            if len(history) > 0:
                # we have history
                print(
                    "** Before adding additional messages: "
                    + str(json_message["messages"])
                )
                for item in history:
                    user_message = {"role": "user", "content": item[0]}
                    assistant_message = {"role": "assistant", "content": item[1]}
                    json_message["messages"].append(user_message)
                    json_message["messages"].append(assistant_message)

            new_user_message = {"role": "user", "content": message}
            json_message["messages"].append(new_user_message)

            json_data = post_request(json_message)
            output = json_data["choices"][0]["message"]["content"]

    return output


def process_message(message, history):
    user_prompt_format = ""
    system_prompt_format = ""

    # if env prompts are set, use those
    if "USER_PROMPT" in os.environ:
        user_prompt_format = os.environ["USER_PROMPT"]

    if "SYSTEM_PROMPT" in os.environ:
        system_prompt_format = os.environ["SYSTEM_PROMPT"]

    print("* History: " + str(history))

    user_message = ""
    system_message = ""
    history_message = ""

    if len(history) > 0:
        # we have history
        for item in history:
            user_message = user_prompt_format.replace("prompt", item[0])
            system_message = system_prompt_format.replace("prompt", item[1])
            history_message = history_message + user_message + system_message

    new_user_message = user_prompt_format.replace("prompt", message)

    # append the history with the new message and close with the turn
    aggregated_message = history_message + new_user_message
    return aggregated_message


def post_request(json_message):
    print("*** Request" + str(json_message), flush=True)
    response = requests.post(
        os.environ["HOST"] + os.environ["CONTEXT_PATH"], json=json_message
    )
    json_data = response.json()
    print("*** Output: " + str(json_data), flush=True)
    return json_data


with gr.Blocks(fill_height=True) as app:
    html_text = "You are chatting with: " + model_id
    gr.HTML(value=html_text)

    model_temperature = gr.Slider(
        minimum=0.1, maximum=1.0, value=0.9, label="Temperature", render=False
    )
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top_p", render=False)
    max_tokens = gr.Slider(
        minimum=1, maximum=4096, value=256, label="Max Tokens", render=False
    )

    gr.ChatInterface(
        inference_interface, additional_inputs=[model_temperature, top_p, max_tokens]
    )

app.launch(server_name="0.0.0.0")
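
The front end is configured entirely through environment variables read in app.py. Below is a minimal local-run sketch, assuming an OpenAI-compatible chat endpoint; the host, path, and model values are illustrative placeholders, not part of this commit.

import os
import runpy

# Values are examples only; just the variable names come from app.py.
os.environ["MODEL_ID"] = "gemma"                      # label shown in the UI header
os.environ["LLM_ENGINE"] = "openai-chat"              # selects the request/response format (the default)
os.environ["HOST"] = "http://localhost:8000"          # base URL of the inference server
os.environ["CONTEXT_PATH"] = "/v1/chat/completions"   # path appended to HOST by post_request()
# USER_PROMPT / SYSTEM_PROMPT templates are only used by the "max", "vllm", and "tgi"
# engines; DISABLE_SYSTEM_MESSAGE applies to the default OpenAI-style chat path.

runpy.run_path("app.py", run_name="__main__")         # serves the chat UI on Gradio's default port (7860)
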
3 changes: 1 addition & 2 deletions ai-ml/llm-serving-gemma/gradio/app/requirements.txt
@@ -1,2 +1 @@
-gradio==5.5.0
-
+gradio==5.8.0