forked from sgl-project/sglang
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Together and AzureOpenAI examples (sgl-project#184)
- Loading branch information
1 parent
9312132
commit bb824da
Showing
8 changed files
with
262 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
""" | ||
Usage: | ||
export AZURE_OPENAI_API_KEY=sk-****** | ||
python3 openai_example_chat.py | ||
""" | ||
import sglang as sgl | ||
import os | ||
|
||
|
||
@sgl.function | ||
def multi_turn_question(s, question_1, question_2): | ||
s += sgl.system("You are a helpful assistant.") | ||
s += sgl.user(question_1) | ||
s += sgl.assistant(sgl.gen("answer_1", max_tokens=256)) | ||
s += sgl.user(question_2) | ||
s += sgl.assistant(sgl.gen("answer_2", max_tokens=256)) | ||
|
||
|
||
def single(): | ||
state = multi_turn_question.run( | ||
question_1="What is the capital of the United States?", | ||
question_2="List two local attractions.", | ||
) | ||
|
||
for m in state.messages(): | ||
print(m["role"], ":", m["content"]) | ||
|
||
print("\n-- answer_1 --\n", state["answer_1"]) | ||
|
||
|
||
def stream(): | ||
state = multi_turn_question.run( | ||
question_1="What is the capital of the United States?", | ||
question_2="List two local attractions.", | ||
stream=True | ||
) | ||
|
||
for out in state.text_iter(): | ||
print(out, end="", flush=True) | ||
print() | ||
|
||
|
||
def batch(): | ||
states = multi_turn_question.run_batch([ | ||
{"question_1": "What is the capital of the United States?", | ||
"question_2": "List two local attractions."}, | ||
|
||
{"question_1": "What is the capital of France?", | ||
"question_2": "What is the population of this city?"}, | ||
]) | ||
|
||
for s in states: | ||
print(s.messages()) | ||
|
||
|
||
if __name__ == "__main__": | ||
backend = sgl.OpenAI( | ||
model_name="azure-gpt-4", | ||
api_version="2023-07-01-preview", | ||
azure_endpoint="https://oai-arena-sweden.openai.azure.com/", | ||
api_key=os.environ["AZURE_OPENAI_API_KEY"], | ||
is_azure=True, | ||
) | ||
sgl.set_default_backend(backend) | ||
|
||
# Run a single request | ||
print("\n========== single ==========\n") | ||
single() | ||
|
||
# Stream output | ||
print("\n========== stream ==========\n") | ||
stream() | ||
|
||
# Run a batch of requests | ||
print("\n========== batch ==========\n") | ||
batch() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
""" | ||
Usage: | ||
export TOGETHER_API_KEY=sk-****** | ||
python3 together_example_chat.py | ||
""" | ||
import sglang as sgl | ||
import os | ||
|
||
|
||
@sgl.function | ||
def multi_turn_question(s, question_1, question_2): | ||
s += sgl.system("You are a helpful assistant.") | ||
s += sgl.user(question_1) | ||
s += sgl.assistant(sgl.gen("answer_1", max_tokens=256)) | ||
s += sgl.user(question_2) | ||
s += sgl.assistant(sgl.gen("answer_2", max_tokens=256)) | ||
|
||
|
||
def single(): | ||
state = multi_turn_question.run( | ||
question_1="What is the capital of the United States?", | ||
question_2="List two local attractions.", | ||
) | ||
|
||
for m in state.messages(): | ||
print(m["role"], ":", m["content"]) | ||
|
||
print("\n-- answer_1 --\n", state["answer_1"]) | ||
|
||
|
||
def stream(): | ||
state = multi_turn_question.run( | ||
question_1="What is the capital of the United States?", | ||
question_2="List two local attractions.", | ||
stream=True | ||
) | ||
|
||
for out in state.text_iter(): | ||
print(out, end="", flush=True) | ||
print() | ||
|
||
|
||
def batch(): | ||
states = multi_turn_question.run_batch([ | ||
{"question_1": "What is the capital of the United States?", | ||
"question_2": "List two local attractions."}, | ||
|
||
{"question_1": "What is the capital of France?", | ||
"question_2": "What is the population of this city?"}, | ||
]) | ||
|
||
for s in states: | ||
print(s.messages()) | ||
|
||
|
||
if __name__ == "__main__": | ||
backend = sgl.OpenAI( | ||
model_name="mistralai/Mixtral-8x7B-Instruct-v0.1", | ||
base_url="https://api.together.xyz/v1", | ||
api_key=os.environ.get("TOGETHER_API_KEY"), | ||
) | ||
sgl.set_default_backend(backend) | ||
|
||
# Run a single request | ||
print("\n========== single ==========\n") | ||
single() | ||
|
||
# Stream output | ||
print("\n========== stream ==========\n") | ||
stream() | ||
|
||
# Run a batch of requests | ||
print("\n========== batch ==========\n") | ||
batch() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
""" | ||
Usage: | ||
export TOGETHER_API_KEY=sk-****** | ||
python3 together_example_complete.py | ||
""" | ||
|
||
import sglang as sgl | ||
import os | ||
|
||
|
||
@sgl.function | ||
def few_shot_qa(s, question): | ||
s += ( | ||
"""The following are questions with answers. | ||
Q: What is the capital of France? | ||
A: Paris | ||
Q: What is the capital of Germany? | ||
A: Berlin | ||
Q: What is the capital of Italy? | ||
A: Rome | ||
""") | ||
s += "Q: " + question + "\n" | ||
s += "A:" + sgl.gen("answer", stop="\n", temperature=0) | ||
|
||
|
||
def single(): | ||
state = few_shot_qa.run(question="What is the capital of the United States?") | ||
answer = state["answer"].strip().lower() | ||
|
||
assert "washington" in answer, f"answer: {state['answer']}" | ||
|
||
print(state.text()) | ||
|
||
|
||
def stream(): | ||
state = few_shot_qa.run( | ||
question="What is the capital of the United States?", | ||
stream=True) | ||
|
||
for out in state.text_iter("answer"): | ||
print(out, end="", flush=True) | ||
print() | ||
|
||
|
||
def batch(): | ||
states = few_shot_qa.run_batch([ | ||
{"question": "What is the capital of the United States?"}, | ||
{"question": "What is the capital of China?"}, | ||
]) | ||
|
||
for s in states: | ||
print(s["answer"]) | ||
|
||
|
||
if __name__ == "__main__": | ||
backend = sgl.OpenAI( | ||
model_name="mistralai/Mixtral-8x7B-Instruct-v0.1", | ||
is_chat_model=False, | ||
base_url="https://api.together.xyz/v1", | ||
api_key=os.environ.get("TOGETHER_API_KEY"), | ||
) | ||
sgl.set_default_backend(backend) | ||
|
||
# Run a single request | ||
print("\n========== single ==========\n") | ||
single() | ||
|
||
# Stream output | ||
print("\n========== stream ==========\n") | ||
stream() | ||
|
||
# Run a batch of requests | ||
print("\n========== batch ==========\n") | ||
batch() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters