forked from sgl-project/sglang
Commit 8644253 (1 parent: 79cb018)
Showing 6 changed files with 246 additions and 2 deletions.
@@ -0,0 +1,68 @@
"""
Usage: python3 srt_example_yi_vl.py
"""

import sglang as sgl


@sgl.function
def image_qa(s, image_path, question):
    s += sgl.user(sgl.image(image_path) + question)
    s += sgl.assistant(sgl.gen("answer"))


def single():
    state = image_qa.run(
        image_path="images/cat.jpeg",
        question="What is this?",
        max_new_tokens=64,
        stop="###")
    print(state["answer"], "\n")


def stream():
    state = image_qa.run(
        image_path="images/cat.jpeg",
        question="What is this?",
        max_new_tokens=64,
        stream=True,
        stop="###")

    for out in state.text_iter("answer"):
        print(out, end="", flush=True)
    print()


def batch():
    states = image_qa.run_batch(
        [
            {"image_path": "images/cat.jpeg", "question": "What is this?"},
            {"image_path": "images/dog.jpeg", "question": "What is this?"},
        ],
        max_new_tokens=64,
        stop="###",
    )
    for s in states:
        print(s["answer"], "\n")


if __name__ == "__main__":
    runtime = sgl.Runtime(model_path="BabyChou/Yi-VL-6B",
                          tokenizer_path="BabyChou/Yi-VL-6B")
    sgl.set_default_backend(runtime)
    # Or you can use API models:
    # sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview"))
    # sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision"))

    # Run a single request
    print("\n========== single ==========\n")
    single()

    # Stream output
    print("\n========== stream ==========\n")
    stream()

    # Run a batch of requests
    print("\n========== batch ==========\n")
    batch()

    runtime.shutdown()
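The example constructs an in-process sgl.Runtime, but the same functions also work against a separately launched server. A minimal sketch, assuming a server is already listening on localhost:30000 (the port and the availability of upstream sglang's RuntimeEndpoint in this fork are assumptions):

import sglang as sgl

# Sketch: connect to an already-running server instead of creating an
# in-process Runtime (assumes a server on localhost:30000).
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
single()  # the functions above work unchanged against the remote backend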
@@ -0,0 +1,101 @@
"""Inference-only Yi-VL model."""
import os
from typing import List, Optional

import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlavaConfig
from vllm.model_executor.weight_utils import (
    default_weight_loader,
    hf_model_weights_iterator,
)

from sglang.srt.models.llava import (
    LlavaLlamaForCausalLM,
    clip_vision_embed_forward,
    monkey_path_clip_vision_embed_forward,
)


class YiVLForCausalLM(LlavaLlamaForCausalLM):
    def __init__(self, *args, **kwargs):
        self.config = kwargs["config"]
        super().__init__(self.config)

        self.multi_modal_projector = YiVLMultiModalProjector(self.config)
        # Everything after "./", i.e. the vision tower subfolder inside the
        # main model directory.
        self.vision_tower_subfolder = self.config.mm_vision_tower.replace("./", "")

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        # The vision tower lives in a subfolder of the main model directory
        # (e.g. 01-ai/Yi-VL-6B), so it has to be loaded from there.
        self.vision_tower = CLIPVisionModel.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.float16,
            subfolder=self.vision_tower_subfolder,
        ).cuda()

        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        # e.g. a 448-px image with 14-px patches gives (448 / 14) ** 2 = 1024
        # image features.
        self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
        if self.vision_feature_select_strategy == "patch":
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            self.image_feature_len += 1
        else:
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )

        # Load the mm_projector weights.
        # TODO: support TP?
        projector_weights = {
            "model.mm_projector.0": "multi_modal_projector.linear_1",
            "model.mm_projector.1": "multi_modal_projector.ln_1",
            "model.mm_projector.3": "multi_modal_projector.linear_2",
            "model.mm_projector.4": "multi_modal_projector.ln_2",
            # Update the vision tower weights if we find them in the
            # checkpoint (it may be fine-tuned).
            "model.vision_tower.vision_tower": "vision_tower",
        }
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision
        ):
            if "projector" in name or "vision_tower" in name:
                for weight_name, param_name in projector_weights.items():
                    if weight_name in name:
                        name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

        # Load the language model weights.
        self.language_model.load_weights(
            model_name_or_path, cache_dir, load_format, revision
        )

        monkey_path_clip_vision_embed_forward()


class YiVLMultiModalProjector(nn.Module):
    def __init__(self, config: LlavaConfig):
        super().__init__()

        self.linear_1 = nn.Linear(
            config.vision_config.hidden_size, config.text_config.hidden_size
        )
        self.ln_1 = nn.LayerNorm(config.text_config.hidden_size)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(
            config.text_config.hidden_size, config.text_config.hidden_size
        )
        self.ln_2 = nn.LayerNorm(config.text_config.hidden_size)

    def forward(self, image_features):
        # Linear -> LayerNorm -> GELU -> Linear -> LayerNorm.
        hidden_states = self.linear_1(image_features)
        hidden_states = self.ln_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.ln_2(hidden_states)
        return hidden_states


EntryClass = YiVLForCausalLM
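The projector_weights mapping above skips index 2, which reads as the original checkpoint storing the projector as an indexed nn.Sequential whose GELU carries no parameters. A sketch of that assumed layout (the dimensions are illustrative, not taken from the source):

import torch.nn as nn

# Assumed example dimensions: 1280 for a CLIP ViT-H/14 tower, 4096 for the
# Yi-6B language model (both hypothetical here).
vision_hidden_size, text_hidden_size = 1280, 4096

# Assumed layout of "model.mm_projector" in the original checkpoint.
# Index 2 (GELU) has no weights, which is why it is absent from the
# projector_weights name mapping above.
mm_projector = nn.Sequential(
    nn.Linear(vision_hidden_size, text_hidden_size),  # .0 -> linear_1
    nn.LayerNorm(text_hidden_size),                   # .1 -> ln_1
    nn.GELU(),                                        # .2 (no parameters)
    nn.Linear(text_hidden_size, text_hidden_size),    # .3 -> linear_2
    nn.LayerNorm(text_hidden_size),                   # .4 -> ln_2
)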
@@ -0,0 +1,38 @@
"""
Convert Yi-VL config into a format usable with SGLang

Usage: python3 scripts/convert_yi_vl.py --model-path <path-to-model>
"""

import argparse
import json
import os

from transformers import AutoConfig, AutoTokenizer


def add_image_token(model_path: str):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.add_tokens(["<image_placeholder>"], special_tokens=True)

    print(tokenizer)
    tokenizer.save_pretrained(model_path)


def edit_model_config(model_path):
    config = AutoConfig.from_pretrained(model_path)

    setattr(config, "architectures", ["YiVLForCausalLM"])
    setattr(config, "image_token_index", 64002)

    print(config)
    config.save_pretrained(model_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str)
    args = parser.parse_args()

    add_image_token(args.model_path)
    edit_model_config(args.model_path)
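Since image_token_index is hard-coded to 64002, it is worth confirming that the added <image_placeholder> token actually receives that id after conversion. A small sanity check, assuming the converted model sits at ~/model_weights/Yi-VL-6B (the path is an assumption from the setup commands below):

import os
from transformers import AutoTokenizer

# Sanity check (a sketch): the converted tokenizer should map the
# placeholder token to the hard-coded image_token_index of 64002.
model_dir = os.path.expanduser("~/model_weights/Yi-VL-6B")  # assumed path
tokenizer = AutoTokenizer.from_pretrained(model_dir)
image_token_id = tokenizer.convert_tokens_to_ids("<image_placeholder>")
assert image_token_id == 64002, f"unexpected image token id: {image_token_id}"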
@@ -0,0 +1,13 @@
# For the 34B model
mkdir -p ~/model_weights
cd ~/model_weights
git clone https://huggingface.co/01-ai/Yi-VL-34B
cp ~/model_weights/Yi-VL-34B/vit/clip-vit-H-14-laion2B-s32B-b79K-yi-vl-34B-448/preprocessor_config.json ~/model_weights/Yi-VL-34B
python3 convert_yi_vl.py --model-path ~/model_weights/Yi-VL-34B

# For the 6B model
mkdir -p ~/model_weights
cd ~/model_weights
git clone https://huggingface.co/01-ai/Yi-VL-6B
cp ~/model_weights/Yi-VL-6B/vit/clip-vit-H-14-laion2B-s32B-b79K-yi-vl-6B-448/preprocessor_config.json ~/model_weights/Yi-VL-6B
python3 convert_yi_vl.py --model-path ~/model_weights/Yi-VL-6B
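Once the weights are converted, the example script's runtime can point at the local directory instead of a hub repo. A minimal sketch, assuming the ~/model_weights/Yi-VL-6B path from the commands above:

import os
import sglang as sgl

# Sketch: serve the locally converted checkpoint. expanduser is used in
# case sgl.Runtime does not expand "~" itself (an assumption).
model_dir = os.path.expanduser("~/model_weights/Yi-VL-6B")
runtime = sgl.Runtime(model_path=model_dir, tokenizer_path=model_dir)
sgl.set_default_backend(runtime)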