[FEAT] Support GGUF format (#2215)
Co-authored-by: Yang Zheng(SW)(Alex) <[email protected]>
zhengy001 and Yang Zheng(SW)(Alex) authored Nov 30, 2024
1 parent 0d6a49b commit 883c955
Showing 39 changed files with 180 additions and 89 deletions.
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -15,3 +15,4 @@ sphinx-copybutton
sphinx-tabs
sphinxcontrib-mermaid
urllib3<2.0.0
+gguf>=0.10.0
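
The only dependency change is the gguf package for the docs build. A quick, optional sanity check that the pin resolves in a local environment (a sketch, not part of the commit):

import importlib.metadata

print(importlib.metadata.version("gguf"))  # expect >= 0.10.0
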
36 changes: 35 additions & 1 deletion python/sglang/srt/hf_transformers_utils.py
@@ -16,6 +16,7 @@
import contextlib
import os
import warnings
+from pathlib import Path
from typing import Dict, Optional, Type, Union

from huggingface_hub import snapshot_download
@@ -27,6 +28,7 @@
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

try:
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
@@ -60,15 +62,29 @@ def get_config(
trust_remote_code: bool,
revision: Optional[str] = None,
model_override_args: Optional[dict] = None,
+    **kwargs,
):
+    is_gguf = check_gguf_file(model)
+    if is_gguf:
+        kwargs["gguf_file"] = model
+        model = Path(model).parent
+
config = AutoConfig.from_pretrained(
-        model, trust_remote_code=trust_remote_code, revision=revision
+        model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model, revision=revision)
if model_override_args:
config.update(model_override_args)

+    # Special architecture mapping check for GGUF models
+    if is_gguf:
+        if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+            raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
+        model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
+        config.update({"architectures": [model_type]})

return config


@@ -123,6 +139,11 @@ def get_tokenizer(
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False

+    is_gguf = check_gguf_file(tokenizer_name)
+    if is_gguf:
+        kwargs["gguf_file"] = tokenizer_name
+        tokenizer_name = Path(tokenizer_name).parent

try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
@@ -195,3 +216,16 @@ def attach_additional_stop_token_ids(tokenizer):
)
else:
tokenizer.additional_stop_token_ids = None


+def check_gguf_file(model: Union[str, os.PathLike]) -> bool:
+    """Check if the file is a GGUF model."""
+    model = Path(model)
+    if not model.is_file():
+        return False
+    elif model.suffix == ".gguf":
+        return True
+
+    with open(model, "rb") as f:
+        header = f.read(4)
+    return header == b"GGUF"
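
Taken together, the helpers above let a single local .gguf file stand in for a model directory: check_gguf_file detects it (by suffix, else by the 4-byte GGUF magic), the gguf_file kwarg is forwarded to transformers, and the file's parent directory becomes the model path. A minimal sketch of exercising this path; the file location is hypothetical:

from sglang.srt.hf_transformers_utils import check_gguf_file, get_config, get_tokenizer

gguf_path = "/models/llama-3.1-8b/llama-3.1-8b-q4_k_m.gguf"  # hypothetical local file
assert check_gguf_file(gguf_path)
config = get_config(gguf_path, trust_remote_code=False)  # config read via gguf_file
tokenizer = get_tokenizer(gguf_path)  # tokenizer likewise extracted from the GGUF
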
20 changes: 17 additions & 3 deletions python/sglang/srt/layers/logits_processor.py
@@ -23,6 +23,7 @@
tensor_model_parallel_all_gather,
)

+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode


@@ -163,7 +164,7 @@ def forward(
self,
input_ids,
hidden_states,
-        weight,
+        lm_head: VocabParallelEmbedding,
logits_metadata: Union[LogitsMetadata, ForwardBatch],
):
if isinstance(logits_metadata, ForwardBatch):
@@ -178,7 +179,7 @@ def forward(
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
last_hidden = hidden_states[last_index]

-        last_logits = torch.matmul(last_hidden, weight.T)
+        last_logits = self._get_logits(last_hidden, lm_head)
if self.do_tensor_parallel_all_gather:
last_logits = tensor_model_parallel_all_gather(last_logits)
last_logits = last_logits[:, : self.config.vocab_size].float()
@@ -229,7 +230,7 @@ def forward(

# Compute the logits and logprobs for all required tokens
states = torch.cat(states, dim=0)
-        all_logits = torch.matmul(states, weight.T)
+        all_logits = self._get_logits(states, lm_head)
if self.do_tensor_parallel_all_gather:
all_logits = tensor_model_parallel_all_gather(all_logits)
all_logits = all_logits[:, : self.config.vocab_size].float()
@@ -276,6 +277,19 @@ def forward(
output_top_logprobs=output_top_logprobs,
)

+    def _get_logits(
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if hasattr(lm_head, "weight"):
+            logits = torch.matmul(hidden_states, lm_head.weight.T)
+        else:
+            # GGUF models
+            logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
+        return logits


def test():
all_logprobs = torch.tensor(
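
The processor now receives the head module itself and lets _get_logits choose the projection: tied or unquantized heads expose a dense weight to multiply against, while GGUF-quantized heads keep packed tensors and must project through their linear method. A self-contained sketch of the same dispatch, with a toy head standing in for VocabParallelEmbedding:

import torch
import torch.nn as nn

def get_logits(hidden_states, lm_head, embedding_bias=None):
    if hasattr(lm_head, "weight"):  # tied or unquantized head: dense weight
        return torch.matmul(hidden_states, lm_head.weight.T)
    # GGUF-quantized head: weight stays packed; its linear method projects
    return lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)

head = nn.Embedding(32000, 4096)  # stands in for a vocab-parallel embedding
hidden = torch.randn(2, 4096)
print(get_logits(hidden, head).shape)  # torch.Size([2, 32000])
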
1 change: 1 addition & 0 deletions python/sglang/srt/layers/vocab_parallel_embedding.py
@@ -222,6 +222,7 @@ def __init__(
enable_tp: bool = True,
):
super().__init__()
+        self.quant_config = quant_config

self.enable_tp = enable_tp
if self.enable_tp:
3 changes: 3 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -59,6 +59,7 @@
enable_show_time_cost,
get_available_gpu_memory,
is_hip,
+    monkey_patch_vllm_gguf_config,
monkey_patch_vllm_model_config,
monkey_patch_vllm_p2p_access_check,
set_cpu_offload_max_bytes,
@@ -297,6 +298,8 @@ def load_model(self):
download_dir=self.server_args.download_dir,
)
monkey_patch_vllm_model_config()
+        if self.server_args.load_format == "gguf":
+            monkey_patch_vllm_gguf_config()
self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
if self.model_config.model_override_args is not None:
self.vllm_model_config.hf_config.update(
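
With the vLLM GGUF config monkey-patched in, a GGUF checkpoint should be servable by pointing the server at the .gguf file and selecting the gguf load format. A hedged launch sketch; the model path is hypothetical and the flag spellings follow server_args:

import subprocess

subprocess.run([
    "python", "-m", "sglang.launch_server",
    "--model-path", "/models/llama-3.1-8b/llama-3.1-8b-q4_k_m.gguf",  # hypothetical
    "--load-format", "gguf",
])
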
11 changes: 6 additions & 5 deletions python/sglang/srt/models/baichuan.py
@@ -338,11 +338,12 @@ def __init__(

self.quant_config = quant_config
self.model = BaiChuanModel(config, position_embedding, quant_config)
-        self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, quant_config=quant_config
-        )
        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
self.logits_processor = LogitsProcessor(config)

def forward(
@@ -353,7 +354,7 @@ def forward(
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
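
This constructor rewiring is the template for the per-model edits below: when tie_word_embeddings is set, the embedding module itself becomes the head (an alias, not a copy), and forward hands the module, rather than its weight, to the logits processor. A toy illustration of the aliasing:

import torch.nn as nn

class ToyLM(nn.Module):
    def __init__(self, vocab=100, hidden=16, tie=True):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        # New pattern: reuse the embedding as the output head when tied
        self.lm_head = self.embed_tokens if tie else nn.Embedding(vocab, hidden)

model = ToyLM()
assert model.lm_head.weight is model.embed_tokens.weight  # one tensor, shared
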
2 changes: 1 addition & 1 deletion python/sglang/srt/models/chatglm.py
@@ -378,7 +378,7 @@ def forward(
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/commandr.py
@@ -339,7 +339,7 @@
forward_batch,
)
return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/dbrx.py
@@ -390,7 +390,7 @@ def forward(
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/deepseek.py
@@ -394,7 +394,7 @@ def forward(
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/deepseek_v2.py
@@ -763,7 +763,7 @@ def forward(
hidden_states = self.model(input_ids, positions, forward_batch)
if not forward_batch.forward_mode.is_idle():
return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/exaone.py
@@ -314,7 +314,7 @@ def forward(
input_ids, positions, forward_batch, input_embeds
)
return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/gemma.py
@@ -298,7 +298,7 @@ def forward(
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/gemma2.py
@@ -363,7 +363,7 @@ def forward(
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
)

def get_attention_sliding_window_size(self):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/gpt2.py
@@ -247,7 +247,7 @@ def forward(
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/gpt_bigcode.py
@@ -271,7 +271,7 @@ def forward(
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/grok.py
@@ -304,7 +304,7 @@ def forward(
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/internlm2.py
@@ -270,7 +270,7 @@ def forward(
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
-            input_ids, hidden_states, self.output.weight, forward_batch
+            input_ids, hidden_states, self.output, forward_batch
)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
27 changes: 8 additions & 19 deletions python/sglang/srt/models/llama.py
@@ -258,6 +258,7 @@ def __init__(
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
+            quant_config=quant_config,
)
self.layers = make_layers(
config.num_hidden_layers,
@@ -305,7 +306,12 @@ def __init__(
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = LlamaModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.stacked_params_mapping = [
@@ -329,7 +335,7 @@ def forward(
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
if not get_embedding:
return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
)
else:
return self.pooler(hidden_states, forward_batch)
@@ -373,7 +379,6 @@ def get_num_params(self):
return len(params_dict)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        embed_tokens_weight = None
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
@@ -385,12 +390,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

params_dict = dict(self.named_parameters())

-        load_tie_word_embeddings = (
-            hasattr(self.config, "tie_word_embeddings")
-            and self.config.tie_word_embeddings
-            and "lm_head.weight" in params_dict
-        )
-
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:
continue
@@ -423,16 +422,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

-            if load_tie_word_embeddings and name == "model.embed_tokens.weight":
-                embed_tokens_weight = loaded_weight
-
-        if load_tie_word_embeddings:
-            # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
-            param = self.lm_head.weight
-            weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            if embed_tokens_weight is not None:
-                weight_loader(param, embed_tokens_weight)
-
apply_torchao_config_(self, params_dict, set(["proj.weight"]))

def get_weights_by_name(
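
Because the tied head is now the same module as embed_tokens, loading the checkpoint's embed_tokens.weight populates the head too, which is why the manual tie-up block in load_weights could be deleted. A toy demonstration that no special case is needed:

import torch
import torch.nn as nn

model = nn.ModuleDict({"embed_tokens": nn.Embedding(8, 4)})
lm_head = model["embed_tokens"]  # tied: alias of the embedding module
state = {"embed_tokens.weight": torch.randn(8, 4)}  # toy checkpoint shard
model.load_state_dict(state)  # no lm_head key required
assert torch.equal(lm_head.weight, state["embed_tokens.weight"])
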
8 changes: 3 additions & 5 deletions python/sglang/srt/models/minicpm.py
@@ -308,12 +308,10 @@ def forward(
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
        else:
-            lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
8 changes: 3 additions & 5 deletions python/sglang/srt/models/minicpm3.py
@@ -585,12 +585,10 @@ def forward(
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
        else:
-            lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
(Remaining file diffs were not loaded in this view.)