
Commit

Update quick start examples (sgl-project#120)
merrymercy authored Jan 30, 2024
1 parent 4ea92f8 commit 0617528
Showing 20 changed files with 555 additions and 225 deletions.
29 changes: 15 additions & 14 deletions README.md
@@ -39,18 +39,20 @@ pip install -e "python[all]"
- For NVIDIA V100, please install the [nightly](https://triton-lang.org/main/getting-started/installation.html) version.
- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`


## Quick Start
The example below shows how to use sglang to answer a multi-turn question.

-### Using OpenAI Models
-Set the OpenAI API Key
+### Using Local Models
+First, launch a server with
```
-export OPENAI_API_KEY=sk-******
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

-Then, answer a multi-turn question.
+Then, connect to the server and answer a multi-turn question.

```python
-from sglang import function, system, user, assistant, gen, set_default_backend, OpenAI
+from sglang import function, system, user, assistant, gen, set_default_backend, RuntimeEndpoint

@function
def multi_turn_question(s, question_1, question_2):
@@ -60,7 +62,7 @@ def multi_turn_question(s, question_1, question_2):
s += user(question_2)
s += assistant(gen("answer_2", max_tokens=256))

-set_default_backend(OpenAI("gpt-3.5-turbo"))
+set_default_backend(RuntimeEndpoint("http://localhost:30000"))

state = multi_turn_question.run(
question_1="What is the capital of the United States?",
@@ -73,16 +75,15 @@ for m in state.messages():
print(state["answer_1"])
```
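
As a quick sanity check that the launched server is reachable, you can send it one raw request before running the script above. The sketch below assumes the runtime exposes a `/generate` endpoint accepting `text` and `sampling_params`; treat the exact field names as assumptions if your version differs.

```python
# Minimal smoke test for the launched server (a sketch; assumes the
# runtime's /generate endpoint; adjust names if your version differs).
import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Once upon a time,",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
    },
)
print(resp.json())
```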

-### Using Local Models
-First, launch a server with
+### Using OpenAI Models
+Set the OpenAI API Key
```
-python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
+export OPENAI_API_KEY=sk-******
```

-Then, connect to the server and answer a multi-turn question.
-
+Then, answer a multi-turn question.
```python
-from sglang import function, system, user, assistant, gen, set_default_backend, RuntimeEndpoint
+from sglang import function, system, user, assistant, gen, set_default_backend, OpenAI

@function
def multi_turn_question(s, question_1, question_2):
@@ -92,7 +93,7 @@ def multi_turn_question(s, question_1, question_2):
s += user(question_2)
s += assistant(gen("answer_2", max_tokens=256))

-set_default_backend(RuntimeEndpoint("http://localhost:30000"))
+set_default_backend(OpenAI("gpt-3.5-turbo"))

state = multi_turn_question.run(
question_1="What is the capital of the United States?",
@@ -120,7 +121,7 @@ import sglang as sgl
`sglang` provides some simple primitives such as `gen`, `select`, `fork`, `image`.
You can implement your prompt flow in a function decorated by `sgl.function`.
You can then invoke the function with `run` or `run_batch`.
-The system will manage the state, chat template, and parallelism for you.
+The system will manage the state, chat template, parallelism, and batching for you.
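
For illustration, a minimal sketch of `select` and `fork` in this style (the prompt text and variable names here are invented for the example; `fork` and `select` usage follows sglang's documented patterns):

```python
import sglang as sgl

@sgl.function
def tip_suggestion(s):
    s += "Here are two tips for staying healthy: 1. Balanced Diet. 2. Regular Exercise.\n\n"
    # fork() copies the prompt state so the two expansions can run in parallel.
    forks = s.fork(2)
    for i, f in enumerate(forks):
        f += f"Now, expand tip {i+1} into a paragraph:\n"
        f += sgl.gen("detail", max_tokens=128, stop="\n\n")
    s += "Tip 1: " + forks[0]["detail"] + "\n"
    s += "Tip 2: " + forks[1]["detail"] + "\n"
    # select() constrains generation to one of the given choices.
    s += "Overall, these tips are " + sgl.select("quality", choices=["helpful", "unhelpful"])
```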

### Control Flow
You can use any Python code within the function body, including control flow, nested function calls, and external libraries.
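
For instance, a hypothetical branching prompt could route between tools based on a generated choice (a sketch; the tool names and prompt strings are invented for illustration):

```python
import sglang as sgl

@sgl.function
def tool_use(s, question):
    s += "To answer this question: " + question + ", "
    s += "I need to use a " + sgl.select("tool", choices=["calculator", "search engine"]) + ". "
    # Ordinary Python control flow can branch on previously generated values.
    if s["tool"] == "calculator":
        s += "The math expression to evaluate is: " + sgl.gen("expression", stop="\n")
    elif s["tool"] == "search engine":
        s += "The key word to search is: " + sgl.gen("keyword", stop="\n")
```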
74 changes: 61 additions & 13 deletions examples/quick_start/anthropic_example_chat.py
@@ -1,19 +1,67 @@
-from sglang import function, system, user, assistant, gen, set_default_backend, Anthropic
+"""
+Usage:
+export ANTHROPIC_API_KEY=sk-******
+python3 anthropic_example_chat.py
+"""
+import sglang as sgl


-@function
+@sgl.function
 def multi_turn_question(s, question_1, question_2):
-    s += user(question_1)
-    s += assistant(gen("answer_1", max_tokens=256))
-    s += user(question_2)
-    s += assistant(gen("answer_2", max_tokens=256))
+    s += sgl.user(question_1)
+    s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
+    s += sgl.user(question_2)
+    s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))

-set_default_backend(Anthropic("claude-2"))

-state = multi_turn_question.run(
-    question_1="What is the capital of the United States?",
-    question_2="List two local attractions.",
-)
+def single():
+    state = multi_turn_question.run(
+        question_1="What is the capital of the United States?",
+        question_2="List two local attractions.",
+    )

-for m in state.messages():
-    print(m["role"], ":", m["content"])
+    for m in state.messages():
+        print(m["role"], ":", m["content"])

+    print("answer_1", state["answer_1"])


+def stream():
+    state = multi_turn_question.run(
+        question_1="What is the capital of the United States?",
+        question_2="List two local attractions.",
+        stream=True
+    )

+    for out in state.text_iter():
+        print(out, end="", flush=True)
+    print()


+def batch():
+    states = multi_turn_question.run_batch([
+        {"question_1": "What is the capital of the United States?",
+         "question_2": "List two local attractions."},

+        {"question_1": "What is the capital of France?",
+         "question_2": "What is the population of this city?"},
+    ])

+    for s in states:
+        print(s.messages())


+if __name__ == "__main__":
+    sgl.set_default_backend(sgl.Anthropic("claude-2"))

+    # Run a single request
+    print("\n========== single ==========\n")
+    single()

+    # Stream output
+    print("\n========== stream ==========\n")
+    stream()

+    # Run a batch of requests
+    print("\n========== batch ==========\n")
+    batch()
57 changes: 49 additions & 8 deletions examples/quick_start/anthropic_example_complete.py
@@ -1,7 +1,13 @@
-from sglang import function, gen, set_default_backend, Anthropic
+"""
+Usage:
+export ANTHROPIC_API_KEY=sk-******
+python3 anthropic_example_complete.py
+"""

+import sglang as sgl

-@function

+@sgl.function
 def few_shot_qa(s, question):
     s += (
         """
@@ -13,14 +19,49 @@ def few_shot_qa(s, question):
 \n\nAssistant: Rome
 """)
     s += "\n\nHuman: " + question + "\n"
-    s += "\n\nAssistant:" + gen("answer", stop="\n", temperature=0)
+    s += "\n\nAssistant:" + sgl.gen("answer", stop="\n", temperature=0)


+def single():
+    state = few_shot_qa.run(question="What is the capital of the United States?")
+    answer = state["answer"].strip().lower()

+    assert "washington" in answer, f"answer: {state['answer']}"

+    print(state.text())


+def stream():
+    state = few_shot_qa.run(
+        question="What is the capital of the United States?",
+        stream=True)

+    for out in state.text_iter("answer"):
+        print(out, end="", flush=True)
+    print()


+def batch():
+    states = few_shot_qa.run_batch([
+        {"question": "What is the capital of the United States?"},
+        {"question": "What is the capital of China?"},
+    ])

+    for s in states:
+        print(s["answer"])


-set_default_backend(Anthropic("claude-2"))
+if __name__ == "__main__":
+    sgl.set_default_backend(sgl.Anthropic("claude-2"))

-state = few_shot_qa.run(question="What is the capital of the United States?")
-answer = state["answer"].strip().lower()
+    # Run a single request
+    print("\n========== single ==========\n")
+    single()

-assert "washington" in answer, f"answer: {state['answer']}"
+    # Stream output
+    print("\n========== stream ==========\n")
+    stream()

-print(state.text())
+    # Run a batch of requests
+    print("\n========== batch ==========\n")
+    batch()
20 changes: 0 additions & 20 deletions examples/quick_start/anthropic_example_stream.py

This file was deleted.

67 changes: 67 additions & 0 deletions examples/quick_start/gemini_example_chat.py
@@ -0,0 +1,67 @@
"""
Usage:
export GCP_PROJECT_ID=******
python3 gemini_example_chat.py
"""
import sglang as sgl


@sgl.function
def multi_turn_question(s, question_1, question_2):
s += sgl.user(question_1)
s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
s += sgl.user(question_2)
s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))


def single():
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions.",
)

for m in state.messages():
print(m["role"], ":", m["content"])

print("answer_1", state["answer_1"])


def stream():
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions.",
stream=True
)

for out in state.text_iter():
print(out, end="", flush=True)
print()


def batch():
states = multi_turn_question.run_batch([
{"question_1": "What is the capital of the United States?",
"question_2": "List two local attractions."},

{"question_1": "What is the capital of France?",
"question_2": "What is the population of this city?"},
])

for s in states:
print(s.messages())


if __name__ == "__main__":
sgl.set_default_backend(sgl.VertexAI("gemini-pro"))

# Run a single request
print("\n========== single ==========\n")
single()

# Stream output
print("\n========== stream ==========\n")
stream()

# Run a batch of requests
print("\n========== batch ==========\n")
batch()
57 changes: 49 additions & 8 deletions examples/quick_start/gemini_example_complete.py
@@ -1,7 +1,13 @@
-from sglang import function, gen, set_default_backend, VertexAI
+"""
+Usage:
+export GCP_PROJECT_ID=******
+python3 gemini_example_complete.py
+"""

+import sglang as sgl

-@function

+@sgl.function
 def few_shot_qa(s, question):
     s += (
         """The following are questions with answers.
@@ -13,14 +19,49 @@ def few_shot_qa(s, question):
 A: Rome
 """)
     s += "Q: " + question + "\n"
-    s += "A:" + gen("answer", stop="\n", temperature=0)
+    s += "A:" + sgl.gen("answer", stop="\n", temperature=0)


+def single():
+    state = few_shot_qa.run(question="What is the capital of the United States?")
+    answer = state["answer"].strip().lower()

+    assert "washington" in answer, f"answer: {state['answer']}"

+    print(state.text())


+def stream():
+    state = few_shot_qa.run(
+        question="What is the capital of the United States?",
+        stream=True)

+    for out in state.text_iter("answer"):
+        print(out, end="", flush=True)
+    print()


+def batch():
+    states = few_shot_qa.run_batch([
+        {"question": "What is the capital of the United States?"},
+        {"question": "What is the capital of China?"},
+    ])

+    for s in states:
+        print(s["answer"])


-set_default_backend(VertexAI("gemini-pro"))
+if __name__ == "__main__":
+    sgl.set_default_backend(sgl.VertexAI("gemini-pro"))

-state = few_shot_qa.run(question="What is the capital of the United States?")
-answer = state["answer"].strip().lower()
+    # Run a single request
+    print("\n========== single ==========\n")
+    single()

-assert "washington" in answer, f"answer: {state['answer']}"
+    # Stream output
+    print("\n========== stream ==========\n")
+    stream()

-print(state.text())
+    # Run a batch of requests
+    print("\n========== batch ==========\n")
+    batch()
