Gemini Backend (sgl-project#9)
Co-authored-by: Ying Sheng <[email protected]>
caoshiyi and Ying1123 authored Jan 17, 2024
1 parent c4707f1 commit fd7c479
Showing 13 changed files with 311 additions and 2 deletions.
26 changes: 26 additions & 0 deletions examples/quick_start/gemini_example_complete.py
@@ -0,0 +1,26 @@
from sglang import function, gen, set_default_backend, Gemini


@function
def few_shot_qa(s, question):
    s += (
        """The following are questions with answers.
Q: What is the capital of France?
A: Paris
Q: What is the capital of Germany?
A: Berlin
Q: What is the capital of Italy?
A: Rome
"""
    )
    s += "Q: " + question + "\n"
    s += "A:" + gen("answer", stop="\n", temperature=0)


set_default_backend(Gemini("gemini-pro"))

state = few_shot_qa.run(question="What is the capital of the United States?")
answer = state["answer"].strip().lower()

assert "washington" in answer, f"answer: {state['answer']}"

print(state.text())
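
A note on setup: the Gemini backend added in this commit reads the GCP project and region from environment variables (see python/sglang/backend/gemini.py below), so the quick-start examples assume something like the following before they run. The values here are placeholders, not part of the commit:

import os

# Placeholder values -- substitute your own GCP project and region.
os.environ["GCP_PROJECT_ID"] = "my-gcp-project"
os.environ["GCP_LOCATION"] = "us-central1"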
19 changes: 19 additions & 0 deletions examples/quick_start/gemini_example_multimodal_chat.py
@@ -0,0 +1,19 @@
from sglang import function, user, assistant, gen, image, set_default_backend, Gemini


@function
def image_qa(s, image_file1, image_file2, question):
    s += user(image(image_file1) + image(image_file2) + question)
    s += assistant(gen("answer_1", max_tokens=256))


set_default_backend(Gemini("gemini-pro-vision"))

state = image_qa.run(
    image_file1="./images/cat.jpeg",
    image_file2="./images/dog.jpeg",
    question="Describe the difference between the two images in one sentence.",
    stream=True,
)

for out in state.text_iter():
    print(out, end="", flush=True)
20 changes: 20 additions & 0 deletions examples/quick_start/gemini_example_stream.py
@@ -0,0 +1,20 @@
from sglang import function, user, assistant, gen, set_default_backend, Gemini


@function
def multi_turn_question(s, question_1, question_2):
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))


set_default_backend(Gemini("gemini-pro"))

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions.",
    stream=True,
)

for out in state.text_iter():
    print(out, end="", flush=True)
Binary file added examples/quick_start/images/cat.jpeg
Binary file added examples/quick_start/images/dog.jpeg
1 change: 1 addition & 0 deletions python/sglang/api.py
@@ -4,6 +4,7 @@

from sglang.backend.anthropic import Anthropic
from sglang.backend.base_backend import BaseBackend
from sglang.backend.gemini import Gemini
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.global_config import global_config
152 changes: 152 additions & 0 deletions python/sglang/backend/gemini.py
@@ -0,0 +1,152 @@
import os
import warnings
from typing import List, Optional, Union

import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams

try:
    import vertexai
    from vertexai.preview.generative_models import (
        GenerationConfig,
        GenerativeModel,
        Image,
    )
except ImportError as e:
    # Defer the import error until the backend is actually instantiated.
    GenerativeModel = e

GEMINI_MODEL_NAMES = [
    "gemini-pro",
    "gemini-pro-vision",
]


class Gemini(BaseBackend):
    def __init__(self, model_name):
        super().__init__()

        if isinstance(GenerativeModel, Exception):
            raise GenerativeModel

        project_id = os.environ["GCP_PROJECT_ID"]
        location = os.environ["GCP_LOCATION"]
        vertexai.init(project=project_id, location=location)

        self.model_name = model_name
        self.chat_template = get_chat_template("default")

    def get_chat_template(self):
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        if s.messages_:
            prompt = self.messages_to_gemini_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_gemini_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        ret = GenerativeModel(self.model_name).generate_content(
            prompt,
            generation_config=GenerationConfig(**sampling_params.to_gemini_kwargs()),
        )

        comp = ret.text

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        if s.messages_:
            prompt = self.messages_to_gemini_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_gemini_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        generator = GenerativeModel(self.model_name).generate_content(
            prompt,
            stream=True,
            generation_config=GenerationConfig(**sampling_params.to_gemini_kwargs()),
        )
        for ret in generator:
            yield ret.text, {}

    def text_to_gemini_input(self, text, images):
        input = []
        # Split the prompt on the chat template's image token and interleave
        # the text segments with the images.
        text_segs = text.split(self.chat_template.image_token)
        for image_path, image_base64_data in images:
            text_seg = text_segs.pop(0)
            if text_seg != "":
                input.append(text_seg)
            input.append(Image.from_bytes(image_base64_data))
        # Append the trailing text segment after the last image.
        text_seg = text_segs.pop(0)
        if text_seg != "":
            input.append(text_seg)
        return input

    def messages_to_gemini_input(self, messages):
        gemini_message = []
        # Convert from the OpenAI message format to the Gemini message format.
        for msg in messages:
            if isinstance(msg["content"], str):
                text = msg["content"]
            else:
                text = msg["content"][0]["text"]

            if msg["role"] == "system":
                # Gemini has no system role; emulate it with a user/model turn.
                warnings.warn("Warning: system prompt is not supported in Gemini.")
                gemini_message.append(
                    {
                        "role": "user",
                        "parts": [{"text": "System prompt: " + text}],
                    }
                )
                gemini_message.append(
                    {
                        "role": "model",
                        "parts": [{"text": "Understood."}],
                    }
                )
                continue
            if msg["role"] == "user":
                gemini_msg = {
                    "role": "user",
                    "parts": [{"text": text}],
                }
            elif msg["role"] == "assistant":
                gemini_msg = {
                    "role": "model",
                    "parts": [{"text": text}],
                }

            # Attach any images that follow the text part.
            if isinstance(msg["content"], list) and len(msg["content"]) > 1:
                for image in msg["content"][1:]:
                    assert image["type"] == "image_url"
                    gemini_msg["parts"].append(
                        {
                            "inline_data": {
                                "data": image["image_url"]["url"].split(",")[1],
                                "mime_type": "image/jpeg",
                            }
                        }
                    )

            gemini_message.append(gemini_msg)
        return gemini_message
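
As a concrete illustration of the conversion implemented by messages_to_gemini_input above, here is a sketch of its input and output for a short history with a system prompt (the message values are hypothetical; the shapes follow the code):

# Hypothetical OpenAI-format input:
messages = [
    {"role": "system", "content": "You are concise."},
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris"},
]

# Resulting Gemini-format output: the system turn becomes a user/model
# pair, and "assistant" is renamed to "model".
gemini_message = [
    {"role": "user", "parts": [{"text": "System prompt: You are concise."}]},
    {"role": "model", "parts": [{"text": "Understood."}]},
    {"role": "user", "parts": [{"text": "What is the capital of France?"}]},
    {"role": "model", "parts": [{"text": "Paris"}]},
]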
1 change: 1 addition & 0 deletions python/sglang/lang/interpreter.py
@@ -428,6 +428,7 @@ def _execute_role_end(self, expr: SglRoleEnd):
            self.messages_.append(last_msg)
            self.cur_images = []
        else:
            # OpenAI chat API format
            self.messages_.append({"role": expr.role, "content": new_text})

        self.cur_role = None
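
For reference, the message appended in the text-only branch above has the plain OpenAI chat shape, e.g. (hypothetical values):

{"role": "user", "content": "What is the capital of France?"}

whereas the multimodal branch keeps a content list so image parts can ride along with the text, which is what the Gemini backend's messages_to_gemini_input consumes.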
10 changes: 10 additions & 0 deletions python/sglang/lang/ir.py
@@ -49,6 +49,16 @@ def to_openai_kwargs(self):
            "presence_penalty": self.presence_penalty,
        }

    def to_gemini_kwargs(self):
        return {
            "candidate_count": 1,
            "max_output_tokens": self.max_new_tokens,
            "stop_sequences": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k if self.top_k > 0 else None,
        }

    def to_anthropic_kwargs(self):
        # Anthropic does not support frequency_penalty or presence_penalty, so we drop them here.
        return {
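
To make the mapping concrete: with, say, max_new_tokens=256, stop="\n", temperature=0, top_p=1.0, and top_k=-1 (hypothetical values), to_gemini_kwargs would return:

{
    "candidate_count": 1,
    "max_output_tokens": 256,
    "stop_sequences": "\n",
    "temperature": 0,
    "top_p": 1.0,
    "top_k": None,  # top_k <= 0 is mapped to None
}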
2 changes: 1 addition & 1 deletion python/sglang/srt/models/mixtral.py
@@ -355,7 +355,7 @@ def load_weights(
        ):
            if "rotary_emb.inv_freq" in name:
                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
5 changes: 4 additions & 1 deletion python/sglang/test/test_programs.py
@@ -304,7 +304,10 @@ def image_qa(s, question):
        temperature=0,
        max_new_tokens=64,
    )
-    assert "taxi" in state.messages()[-1]["content"]
+    assert (
+        "taxi" in state.messages()[-1]["content"]
+        or "car" in state.messages()[-1]["content"]
+    )


def test_stream():
66 changes: 66 additions & 0 deletions test/lang/test_gemini_backend.py
@@ -0,0 +1,66 @@
import unittest

from sglang.test.test_programs import (
    test_expert_answer,
    test_few_shot_qa,
    test_image_qa,
    test_mt_bench,
    test_parallel_decoding,
    test_parallel_encoding,
    test_stream,
)

from sglang import Gemini, set_default_backend


class TestGeminiBackend(unittest.TestCase):
    backend = None
    chat_backend = None
    chat_vision_backend = None

    def setUp(self):
        cls = type(self)

        if cls.backend is None:
            cls.backend = Gemini("gemini-pro")
            cls.chat_backend = Gemini("gemini-pro")
            cls.chat_vision_backend = Gemini("gemini-pro-vision")

    def test_few_shot_qa(self):
        set_default_backend(self.backend)
        test_few_shot_qa()

    def test_mt_bench(self):
        set_default_backend(self.chat_backend)
        test_mt_bench()

    def test_expert_answer(self):
        set_default_backend(self.backend)
        test_expert_answer()

    def test_parallel_decoding(self):
        set_default_backend(self.backend)
        test_parallel_decoding()

    def test_parallel_encoding(self):
        set_default_backend(self.backend)
        test_parallel_encoding()

    def test_image_qa(self):
        set_default_backend(self.chat_vision_backend)
        test_image_qa()

    def test_stream(self):
        set_default_backend(self.backend)
        test_stream()


if __name__ == "__main__":
    unittest.main(warnings="ignore")

# from sglang.global_config import global_config

# global_config.verbosity = 2
# t = TestGeminiBackend()
# t.setUp()
# t.test_stream()
11 changes: 11 additions & 0 deletions test/lang/test_openai_backend.py
@@ -88,4 +88,15 @@ def test_stream(self):
# global_config.verbosity = 2
# t = TestOpenAIBackend()
# t.setUp()
# t.test_few_shot_qa()
# t.test_mt_bench()
# t.test_select()
# t.test_decode_int()
# t.test_decode_json()
# t.test_expert_answer()
# t.test_tool_use()
# t.test_react()
# t.test_parallel_decoding()
# t.test_parallel_encoding()
# t.test_image_qa()
# t.test_stream()
