Yi-VL Model (sgl-project#112)
BabyChouSr authored Feb 1, 2024
1 parent 79cb018 commit 8644253
Showing 6 changed files with 246 additions and 2 deletions.
68 changes: 68 additions & 0 deletions examples/quick_start/srt_example_yi_vl.py
@@ -0,0 +1,68 @@
"""
Usage: python3 srt_example_yi_vl.py
"""
import sglang as sgl


@sgl.function
def image_qa(s, image_path, question):
s += sgl.user(sgl.image(image_path) + question)
s += sgl.assistant(sgl.gen("answer"))


def single():
state = image_qa.run(
image_path="images/cat.jpeg",
question="What is this?",
max_new_tokens=64,
stop="###")
print(state["answer"], "\n")


def stream():
state = image_qa.run(
image_path="images/cat.jpeg",
question="What is this?",
max_new_tokens=64,
stream=True,
stop="###")

for out in state.text_iter("answer"):
print(out, end="", flush=True)
print()


def batch():
states = image_qa.run_batch(
[
{"image_path": "images/cat.jpeg", "question":"What is this?"},
{"image_path": "images/dog.jpeg", "question":"What is this?"},
],
max_new_tokens=64,
stop="###"
)
for s in states:
print(s["answer"], "\n")


if __name__ == "__main__":
runtime = sgl.Runtime(model_path="BabyChou/Yi-VL-6B",
tokenizer_path="BabyChou/Yi-VL-6B")
sgl.set_default_backend(runtime)
# Or you can use API models
# sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview"))
# sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision"))

# Run a single request
print("\n========== single ==========\n")
single()

# Stream output
print("\n========== stream ==========\n")
stream()

# Run a batch of requests
print("\n========== batch ==========\n")
batch()

runtime.shutdown()
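Note: BabyChou/Yi-VL-6B above appears to be a pre-converted copy of the official weights. As a variant (a sketch, not part of the commit), the Runtime can also point at a locally converted checkpoint; the path below is an assumption that matches where scripts/convert_yi_vl.sh places the weights.

import os

# Hypothetical local checkpoint produced by scripts/convert_yi_vl.sh
local_path = os.path.expanduser("~/model_weights/Yi-VL-6B")
runtime = sgl.Runtime(model_path=local_path, tokenizer_path=local_path)
sgl.set_default_backend(runtime)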
23 changes: 23 additions & 0 deletions python/sglang/lang/chat_template.py
@@ -146,6 +146,23 @@ def get_chat_template_by_model_path(model_path):
)
)

# Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
register_chat_template(
ChatTemplate(
name="yi",
default_system_prompt=(
"This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers."
"这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。"
),
role_prefix_and_suffix={
"system": ("", "\n\n"),
"user": ("### Human:", "\n"),
"assistant": ("### Assistant:", "\n"),
},
image_token=" <image_placeholder>\n",
)
)


@register_chat_template_matching_function
def match_vicuna(model_path: str):
@@ -176,6 +193,12 @@ def match_chat_ml(model_path: str):
if "qwen" in model_path and "chat" in model_path:
return get_chat_template("chatml")

@register_chat_template_matching_function
def match_chat_yi(model_path: str):
model_path = model_path.lower()
if "yi" in model_path:
return get_chat_template("yi")


if __name__ == "__main__":
messages = [
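For reference (not part of the diff), a rough sketch of the prompt the "yi" template produces for a single image question, based on the prefixes, suffixes, and image token registered above; exact whitespace depends on how the frontend concatenates roles:

<default bilingual system prompt>

### Human: <image_placeholder>
What is this?
### Assistant: <generated answer>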
101 changes: 101 additions & 0 deletions python/sglang/srt/models/yivl.py
@@ -0,0 +1,101 @@
"""Inference-only Yi-VL model."""
import os
from typing import List, Optional

import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlavaConfig
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)

from sglang.srt.models.llava import (
    LlavaLlamaForCausalLM,
    clip_vision_embed_forward,
    monkey_path_clip_vision_embed_forward,
)


class YiVLForCausalLM(LlavaLlamaForCausalLM):
def __init__(self, *args, **kwargs):
self.config = kwargs["config"]
super().__init__(self.config)

self.multi_modal_projector = YiVLMultiModalProjector(self.config)
self.vision_tower_subfolder = self.config.mm_vision_tower.replace("./", "") # Everything after "./"

def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
# We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
self.vision_tower = CLIPVisionModel.from_pretrained(
model_name_or_path, torch_dtype=torch.float16, subfolder=self.vision_tower_subfolder
).cuda()

self.vision_tower.eval()

self.vision_feature_layer = self.config.mm_vision_select_layer
self.vision_feature_select_strategy = self.config.mm_vision_select_feature
self.image_size = self.vision_tower.config.image_size
self.patch_size = self.vision_tower.config.patch_size

self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
if self.vision_feature_select_strategy == "patch":
pass
elif self.vision_feature_select_strategy == "cls_patch":
self.image_feature_len += 1
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")

# load mm_projector
# TODO: support TP?
projector_weights = {
"model.mm_projector.0": "multi_modal_projector.linear_1",
"model.mm_projector.1": "multi_modal_projector.ln_1",
"model.mm_projector.3": "multi_modal_projector.linear_2",
"model.mm_projector.4": "multi_modal_projector.ln_2",
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
}
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision
):
if "projector" in name or "vision_tower" in name:
for weight_name, param_name in projector_weights.items():
if weight_name in name:
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

# load language model
self.language_model.load_weights(
model_name_or_path, cache_dir, load_format, revision
)

monkey_path_clip_vision_embed_forward()

class YiVLMultiModalProjector(nn.Module):
def __init__(self, config: LlavaConfig):
super().__init__()

self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size)
self.ln_1 = nn.LayerNorm(config.text_config.hidden_size)
self.act = nn.GELU()
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
self.ln_2 = nn.LayerNorm(config.text_config.hidden_size)

def forward(self, image_features):
hidden_states = self.linear_1(image_features)
        hidden_states = self.ln_1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
hidden_states = self.ln_2(hidden_states)
return hidden_states

EntryClass = YiVLForCausalLM
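As a sanity check on image_feature_len above (a sketch, not part of the commit): the vision tower referenced by scripts/convert_yi_vl.sh is a 448-pixel CLIP ViT-H/14, so image_size and patch_size below are assumptions taken from that checkpoint name.

image_size = 448   # assumed, from "clip-vit-H-14-...-448"
patch_size = 14    # assumed, ViT-H/14
image_feature_len = int((image_size / patch_size) ** 2)
print(image_feature_len)  # 32 * 32 = 1024 patch tokens
# With the "cls_patch" select strategy, one CLS token is added: 1025.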
5 changes: 3 additions & 2 deletions python/sglang/srt/utils.py
@@ -233,11 +233,12 @@ def ret_func(grid, num_warps, *args):

def is_multimodal_model(model):
    if isinstance(model, str):
-        return "llava" in model
+        return "llava" in model or "yi-vl" in model
    from sglang.srt.model_config import ModelConfig

    if isinstance(model, ModelConfig):
-        return "llava" in model.path.lower()
+        model_path = model.path.lower()
+        return "llava" in model_path or "yi-vl" in model_path
    raise Exception("unrecognized type")


38 changes: 38 additions & 0 deletions scripts/convert_yi_vl.py
@@ -0,0 +1,38 @@
"""
Convert the Yi-VL tokenizer and config into a format usable with SGLang
Usage: python3 scripts/convert_yi_vl.py --model-path <path-to-model>
"""

import argparse
import json
import os

from transformers import AutoConfig, AutoTokenizer

def add_image_token(model_path: str):
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.add_tokens(
["<image_placeholder>"],
special_tokens=True
)

print(tokenizer)
tokenizer.save_pretrained(model_path)

def edit_model_config(model_path):
config = AutoConfig.from_pretrained(model_path)

setattr(config, "architectures", ["YiVLForCausalLM"])
setattr(config, "image_token_index", 64002)

print(config)
config.save_pretrained(model_path)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str)
args = parser.parse_args()

add_image_token(args.model_path)
edit_model_config(args.model_path)
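A quick check after conversion (a sketch, not part of the commit): the tokenizer should map the new <image_placeholder> token to the image_token_index written into the config. The helper name below is hypothetical; the transformers calls are standard.

from transformers import AutoConfig, AutoTokenizer

def check_conversion(model_path: str):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    config = AutoConfig.from_pretrained(model_path)
    token_id = tokenizer.convert_tokens_to_ids("<image_placeholder>")
    # The script hard-codes image_token_index = 64002; if these disagree,
    # the tokenizer's vocab size differs from what that index assumes.
    print(token_id, config.image_token_index)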
13 changes: 13 additions & 0 deletions scripts/convert_yi_vl.sh
@@ -0,0 +1,13 @@
# For 34B Model
mkdir -p ~/model_weights
cd ~/model_weights
git clone https://huggingface.co/01-ai/Yi-VL-34B
cp ~/model_weights/Yi-VL-34B/vit/clip-vit-H-14-laion2B-s32B-b79K-yi-vl-34B-448/preprocessor_config.json ~/model_weights/Yi-VL-34B
python3 convert_yi_vl.py --model-path ~/model_weights/Yi-VL-34B

# For 6B Model
mkdir -p ~/model_weights
cd ~/model_weights
git clone https://huggingface.co/01-ai/Yi-VL-6B
cp ~/model_weights/Yi-VL-6B/vit/clip-vit-H-14-laion2B-s32B-b79K-yi-vl-6B-448/preprocessor_config.json ~/model_weights/Yi-VL-6B
python3 convert_yi_vl.py --model-path ~/model_weights/Yi-VL-6B
