forked from sgl-project/sglang
Commit
Co-authored-by: Ying Sheng <[email protected]>
Showing 13 changed files with 311 additions and 2 deletions.
@@ -0,0 +1,26 @@
from sglang import function, gen, set_default_backend, Gemini


@function
def few_shot_qa(s, question):
    s += (
"""The following are questions with answers.
Q: What is the capital of France?
A: Paris
Q: What is the capital of Germany?
A: Berlin
Q: What is the capital of Italy?
A: Rome
""")
    s += "Q: " + question + "\n"
    s += "A:" + gen("answer", stop="\n", temperature=0)


set_default_backend(Gemini("gemini-pro"))

state = few_shot_qa.run(question="What is the capital of the United States?")
answer = state["answer"].strip().lower()

assert "washington" in answer, f"answer: {state['answer']}"

print(state.text())
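Note: the Gemini backend added in this commit initializes Vertex AI from the GCP_PROJECT_ID and GCP_LOCATION environment variables (see the backend file further down), so both must be set before Gemini(...) is constructed. A minimal setup sketch with placeholder values, not real ones:

import os

# Placeholders only; point these at a real GCP project and Vertex AI region.
os.environ["GCP_PROJECT_ID"] = "my-gcp-project"
os.environ["GCP_LOCATION"] = "us-central1"

Alternatively, export the same variables in the shell before running the example.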
@@ -0,0 +1,19 @@
from sglang import function, user, assistant, gen, image, set_default_backend, Gemini


@function
def image_qa(s, image_file1, image_file2, question):
    s += user(image(image_file1) + image(image_file2) + question)
    s += assistant(gen("answer_1", max_tokens=256))

set_default_backend(Gemini("gemini-pro-vision"))

state = image_qa.run(
    image_file1="./images/cat.jpeg",
    image_file2="./images/dog.jpeg",
    question="Describe difference of the 2 images in one sentence.",
    stream=True
)

for out in state.text_iter():
    print(out, end="", flush=True)
@@ -0,0 +1,20 @@
from sglang import function, user, assistant, gen, set_default_backend, Gemini


@function
def multi_turn_question(s, question_1, question_2):
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))

set_default_backend(Gemini("gemini-pro"))

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions.",
    stream=True
)

for out in state.text_iter():
    print(out, end="", flush=True)
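Because both generations are named, the answers can also be read back from the returned state once the stream is exhausted, mirroring the dictionary-style access in the few-shot example above. A small sketch, assuming the streamed run above has completed:

# The named generations remain accessible on the state after streaming.
print("\nFirst answer:", state["answer_1"])
print("Second answer:", state["answer_2"])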
@@ -0,0 +1,152 @@
import os
import warnings
from typing import List, Optional, Union

import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams

try:
    import vertexai
    from vertexai.preview.generative_models import (
        GenerationConfig,
        GenerativeModel,
        Image,
    )
except ImportError as e:
    GenerativeModel = e

GEMINI_MODEL_NAMES = [
    "gemini-pro",
    "gemini-pro-vision",
]


class Gemini(BaseBackend):
    def __init__(self, model_name):
        super().__init__()

        if isinstance(GenerativeModel, Exception):
            raise GenerativeModel

        project_id = os.environ["GCP_PROJECT_ID"]
        location = os.environ["GCP_LOCATION"]
        vertexai.init(project=project_id, location=location)

        self.model_name = model_name
        self.chat_template = get_chat_template("default")

    def get_chat_template(self):
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        if s.messages_:
            prompt = self.messages_to_gemini_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_gemini_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        ret = GenerativeModel(self.model_name).generate_content(
            prompt,
            generation_config=GenerationConfig(**sampling_params.to_gemini_kwargs()),
        )

        comp = ret.text

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        if s.messages_:
            prompt = self.messages_to_gemini_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_gemini_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        generator = GenerativeModel(self.model_name).generate_content(
            prompt,
            stream=True,
            generation_config=GenerationConfig(**sampling_params.to_gemini_kwargs()),
        )
        for ret in generator:
            yield ret.text, {}

    def text_to_gemini_input(self, text, images):
        input = []
        # split with image token
        text_segs = text.split(self.chat_template.image_token)
        for image_path, image_base64_data in images:
            text_seg = text_segs.pop(0)
            if text_seg != "":
                input.append(text_seg)
            input.append(Image.from_bytes(image_base64_data))
        text_seg = text_segs.pop(0)
        if text_seg != "":
            input.append(text_seg)
        return input

    def messages_to_gemini_input(self, messages):
        gemini_message = []
        # from openai message format to gemini message format
        for msg in messages:
            if isinstance(msg["content"], str):
                text = msg["content"]
            else:
                text = msg["content"][0]["text"]

            if msg["role"] == "system":
                warnings.warn("Warning: system prompt is not supported in Gemini.")
                gemini_message.append(
                    {
                        "role": "user",
                        "parts": [{"text": "System prompt: " + text}],
                    }
                )
                gemini_message.append(
                    {
                        "role": "model",
                        "parts": [{"text": "Understood."}],
                    }
                )
                continue
            if msg["role"] == "user":
                gemini_msg = {
                    "role": "user",
                    "parts": [{"text": text}],
                }
            elif msg["role"] == "assistant":
                gemini_msg = {
                    "role": "model",
                    "parts": [{"text": text}],
                }

            # images
            if isinstance(msg["content"], list) and len(msg["content"]) > 1:
                for image in msg["content"][1:]:
                    assert image["type"] == "image_url"
                    gemini_msg["parts"].append(
                        {
                            "inline_data": {
                                "data": image["image_url"]["url"].split(",")[1],
                                "mime_type": "image/jpeg",
                            }
                        }
                    )

            gemini_message.append(gemini_msg)
        return gemini_message
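The generation_config above is built from sampling_params.to_gemini_kwargs(), which is defined elsewhere in this commit (in sglang's IR module) and not shown here. A rough sketch of the kind of kwargs GenerationConfig accepts; the field names come from the Vertex AI SDK, while the mapping itself is an assumption:

# Hypothetical illustration only; the real to_gemini_kwargs lives in sglang.lang.ir.
def to_gemini_kwargs_sketch(sampling_params):
    stop = sampling_params.stop
    return {
        "max_output_tokens": sampling_params.max_new_tokens,
        # GenerationConfig expects a list of stop sequences.
        "stop_sequences": [stop] if isinstance(stop, str) else stop,
        "temperature": sampling_params.temperature,
        "top_p": sampling_params.top_p,
    }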
@@ -0,0 +1,66 @@
import unittest

from sglang.test.test_programs import (
    test_expert_answer,
    test_few_shot_qa,
    test_image_qa,
    test_mt_bench,
    test_parallel_decoding,
    test_parallel_encoding,
    test_stream,
)

from sglang import Gemini, set_default_backend


class TestGeminiBackend(unittest.TestCase):
    backend = None
    chat_backend = None
    chat_vision_backend = None

    def setUp(self):
        cls = type(self)

        if cls.backend is None:
            cls.backend = Gemini("gemini-pro")
            cls.chat_backend = Gemini("gemini-pro")
            cls.chat_vision_backend = Gemini("gemini-pro-vision")

    def test_few_shot_qa(self):
        set_default_backend(self.backend)
        test_few_shot_qa()

    def test_mt_bench(self):
        set_default_backend(self.chat_backend)
        test_mt_bench()

    def test_expert_answer(self):
        set_default_backend(self.backend)
        test_expert_answer()

    def test_parallel_decoding(self):
        set_default_backend(self.backend)
        test_parallel_decoding()

    def test_parallel_encoding(self):
        set_default_backend(self.backend)
        test_parallel_encoding()

    def test_image_qa(self):
        set_default_backend(self.chat_vision_backend)
        test_image_qa()

    def test_stream(self):
        set_default_backend(self.backend)
        test_stream()


if __name__ == "__main__":
    unittest.main(warnings="ignore")

    # from sglang.global_config import global_config

    # global_config.verbosity = 2
    # t = TestGeminiBackend()
    # t.setUp()
    # t.test_stream()