diff --git a/tllm/commons/manager.py b/tllm/commons/manager.py
index 30317cf..32a3006 100644
--- a/tllm/commons/manager.py
+++ b/tllm/commons/manager.py
@@ -1,11 +1,10 @@
+import glob
 import os
-from typing import List
 
 from transformers import AutoConfig
 
-from tllm import BACKEND, BackendEnum
 from tllm.commons.tp_communicator import BaseCommunicator
-from tllm.models.file_helper import find_weight_file, get_model_path
+from tllm.models.file_helper import get_model_path
 from tllm.models.register import MODEL_REGISTER
 from tllm.models.weight_helper import load_gguf_weight, read_from_safetensors
@@ -39,8 +38,8 @@ def __init__(self, model_path: str):
             self.read_master_weight = self._gguf_read_master_weight
             self.read_client_weight = self._gguf_read_client_weight
         else:
-            self.read_master_weight = self._hf_read_master_weight
-            self.read_client_weight = self._hf_read_client_weight
+            self.read_master_weight = self._hf_read_weight
+            self.read_client_weight = self._hf_read_weight
         self.tok, self.arch, self.config = self._post_init()
 
     def _read_flux_master_weight(self):
@@ -96,43 +95,14 @@ def _gguf_read_master_weight(self):
         state_dict["norm.eps"] = self.config.rms_norm_eps
         return state_dict
 
-    def _hf_read_weight(self, prefix_key_list: List[str]):
-        file_key_dict = find_weight_file(self.model_path, prefix_key_list)
+    def _hf_read_weight(self):
+        sf_file_list = glob.glob(os.path.join(self.model_path, "*.safetensors"))
         state_dict = {}
-        for file, key_list in file_key_dict.items():
-            weight_path = os.path.join(self.model_path, file)
-            state_dict.update(read_from_safetensors(weight_path, key_list if len(key_list) > 0 else prefix_key_list))
+        for sf_file in sf_file_list:
+            state_dict.update(read_from_safetensors(sf_file))
         return state_dict
 
-    def _hf_read_master_weight(self):
-        prefix_key_list = ["model.embed_tokens.", "model.norm.", "lm_head.", "visual.", "vision_tower."]
-        # For Janus-Pro
-        prefix_key_list += [
-            "vision_model.",
-            "language_model.model.embed_tokens.",
-            "language_model.model.norm.",
-            "language_model.lm_head.",
-            "aligner.",
-            "gen_",
-        ]
-        state_dict = self._hf_read_weight(prefix_key_list)
-
-        new_state_dict = {}
-        for k, v in state_dict.items():
-            # for qwen-vl
-            if BACKEND == BackendEnum.MLX and k == "visual.patch_embed.proj.weight":
-                # [out_ch, in_ch, n, h, w] -> [out_ch, n, h, w, in_ch]
-                v = v.transpose(0, 2, 3, 4, 1)
-
-            # model.layers for multi modal encoder
-            if k.startswith("model.") and not k.startswith("model.layers."):
-                new_state_dict[k.split("model.")[-1]] = v
-            else:
-                new_state_dict[k] = v
-
-        return new_state_dict
-
-    def _gguf_read_client_weight(self, start_idx: int, end_idx: int):
+    def _gguf_read_client_weight(self):
         if self.state_dict is None:
             raise ValueError("state_dict is None")
         new_state_dict = {}
@@ -149,13 +119,6 @@ def _gguf_read_client_weight(self, start_idx: int, end_idx: int):
             new_state_dict[k] = v
         return new_state_dict
 
-    def _hf_read_client_weight(self, start_idx: int, end_idx: int):
-        prefix_key_list = [f"model.layers.{layer_idx}." for layer_idx in range(start_idx, end_idx)]
-        # for Janus-Pro
-        prefix_key_list += [f"language_model.model.layers.{layer_idx}." for layer_idx in range(start_idx, end_idx)]
-        state_dict = self._hf_read_weight(prefix_key_list)
-        return state_dict
-
 
 def load_client_model(start_idx: int, end_idx: int, comm: BaseCommunicator, model_path: str):
     weight_manager = WeightManager(model_path)
@@ -163,7 +126,7 @@ def load_client_model(start_idx: int, end_idx: int, comm: BaseCommunicator, model_path: str):
 
     end_idx = min(end_idx, config.num_hidden_layers)
 
-    state_dict = weight_manager.read_client_weight(start_idx, end_idx)
+    state_dict = weight_manager.read_client_weight()
 
     if weight_manager.arch not in MODEL_REGISTER:
         raise ValueError(f"Model {weight_manager.arch} not supported")
diff --git a/tllm/models/file_helper.py b/tllm/models/file_helper.py
index bfb5927..283a623 100644
--- a/tllm/models/file_helper.py
+++ b/tllm/models/file_helper.py
@@ -1,5 +1,3 @@
-from collections import defaultdict
-import json
 import os
 from pathlib import Path
 from typing import List, Optional, Tuple
@@ -8,22 +6,6 @@
 from huggingface_hub.file_download import repo_folder_name
 
 
-def find_weight_file(model_path: str, prefix_key_list: List[str]) -> dict:
-    index_path = os.path.join(model_path, "model.safetensors.index.json")
-    file_key_dict = defaultdict(list)
-    if os.path.isfile(index_path):
-        with open(index_path, "r") as f:
-            index = json.load(f)
-        for key, file_ in index["weight_map"].items():
-            for prefix_key in prefix_key_list:
-                if key.startswith(prefix_key):
-                    # file_set.add(file_)
-                    file_key_dict[file_].append(key)
-    else:
-        file_key_dict["model.safetensors"] = []
-    return file_key_dict
-
-
 def get_hf_cache_model_path(repo_id: str, revision: Optional[str] = None) -> Path:
     cache_dir = constants.HF_HUB_CACHE
     if revision is None:
diff --git a/tllm/models/mlx/helper.py b/tllm/models/mlx/helper.py
index c4a0ef1..67509e3 100644
--- a/tllm/models/mlx/helper.py
+++ b/tllm/models/mlx/helper.py
@@ -78,14 +78,8 @@ def class_predicate(p, m):
     return model
 
 
-def read_from_safetensors(file_path: str, key_list: List[str]) -> Dict[str, mx.array]:
-    tensors = {}
-    weights = mx.load(file_path)
-    for key in weights.keys():
-        for prefix_key in key_list:
-            if key.startswith(prefix_key):
-                tensors[key] = weights[key]
-    return tensors
+def read_from_safetensors(file_path: str) -> Dict[str, mx.array]:
+    return mx.load(file_path)
 
 
 def get_last_hidden_states(
diff --git a/tllm/models/mlx/janus_pro.py b/tllm/models/mlx/janus_pro.py
index 4a0c578..11329f7 100644
--- a/tllm/models/mlx/janus_pro.py
+++ b/tllm/models/mlx/janus_pro.py
@@ -11,14 +11,15 @@
 from tllm.models.mlx.helper import dict_to_dataclass, quantization_func
 from tllm.models.mlx.vq_model import ModelArgs, VQModel, vision_head
 from tllm.models.processor import VLMImageProcessor
+from tllm.models.weight_helper import common_sanitize
 
 
-def replace_vision_model_func(k: str, prefix_key: str) -> Optional[str]:
-    k = k.split("vision_model.", 1)[-1]
-    if f"{prefix_key}blocks." in k:
-        k = k.replace(f"{prefix_key}blocks.", f"{prefix_key}encoder.layers.")
-    if f"{prefix_key}patch_embed.proj." in k:
-        k = k.replace(f"{prefix_key}patch_embed.proj.", f"{prefix_key}embeddings.patch_embedding.")
+def replace_vision_model_func(k: str) -> Optional[str]:
+    # k = k.split("vision_model.", 1)[-1]
+    if "vision_tower.blocks." in k:
+        k = k.replace("vision_tower.blocks.", "vision_tower.encoder.layers.")
+    if "vision_tower.patch_embed.proj." in k:
+        k = k.replace("vision_tower.patch_embed.proj.", "vision_tower.embeddings.patch_embedding.")
     # do not load attn_pool
     if "attn_pool." in k:
@@ -97,17 +98,29 @@ def __call__(self, x: mx.array) -> mx.array:
         return x
 
 
+class MLXJanusProLM(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+
+class MLXJanusProVLM(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.vision_tower = SiglipVisionModel(config)
+
+
 class MLXJanusProConditionalGeneration(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.vision_tower = SiglipVisionModel(config.vision_config)
+        self.vision_model = MLXJanusProVLM(config.vision_config)
         self.aligner = MlpProjector(config.aligner_config["params"])
 
         language_config = dict_to_dataclass(config.language_config, "LanguageConfig")
         self.vocab_size = language_config.vocab_size
-        self.embed_tokens = nn.Embedding(language_config.vocab_size, language_config.hidden_size)
+        self.model = MLXJanusProLM(language_config)
         self.lm_head = nn.Linear(language_config.hidden_size, language_config.vocab_size, bias=False)
-        self.norm = nn.RMSNorm(language_config.hidden_size, eps=language_config.rms_norm_eps)
 
         self.gen_vision_model = VQModel(ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4]))
         self.gen_aligner = MlpProjector(config.gen_aligner_config["params"])
@@ -152,11 +165,11 @@ def from_pretrained(cls, config, state_dict: Dict[str, mx.array], **kwargs):
     def sanitize(weights):
         sanitized_weights = {}
         # Ugly compatibility janus
-        for k, v in weights.items():
+        for k, v in common_sanitize(weights).items():
             if k.startswith("vision_model."):
-                k = replace_vision_model_func(k, prefix_key="vision_tower.")
+                k = replace_vision_model_func(k)
             # Skip attn_pool
-            if k.startswith("vision_tower.head."):
+            if k.startswith("vision_model.vision_tower.head."):
                 continue
             if k.startswith("language_model."):
                 k = k.replace("language_model.model.", "")
@@ -192,7 +205,7 @@ def get_input_embeddings(
         pixel_values: Optional[np.ndarray] = None,
     ) -> mx.array:
         # TODO: Multi-Request Maybe Has Problem
-        inputs_embeds = self.embed_tokens(mx.array(input_ids))
+        inputs_embeds = self.model.embed_tokens(mx.array(input_ids))
 
         if pixel_values is not None:
             # for mlx framework need to transpose
@@ -200,7 +213,7 @@ def get_input_embeddings(
             pixel_values = pixel_values.transpose(0, 2, 3, 1)
             pixel_values = mx.array(pixel_values).astype(DTYPE)
 
-            image_embeds = self.aligner(self.vision_tower(pixel_values))
+            image_embeds = self.aligner(self.vision_model.vision_tower(pixel_values))
             # image_embeds: token_nums x hidden_size
             image_embeds = image_embeds[0]  # TODO: fix this
@@ -218,11 +231,11 @@ def get_input_embeddings(
         return inputs_embeds
 
     def get_logits(self, hidden_states: mx.array) -> mx.array:
-        logits = self.lm_head(self.norm(hidden_states))
+        logits = self.lm_head(self.model.norm(hidden_states))
         return logits
 
     def get_gen_head(self, hidden_states: mx.array, temperature: float = 1.0, cfg_weight: float = 5.0) -> mx.array:
-        logits = self.gen_head(self.norm(hidden_states))
+        logits = self.gen_head(self.model.norm(hidden_states))
         logit_cond = logits[0::2, :]
         logit_uncond = logits[1::2, :]
diff --git a/tllm/models/mlx/llama.py b/tllm/models/mlx/llama.py
index 92719f7..2c9df45 100644
--- a/tllm/models/mlx/llama.py
+++ b/tllm/models/mlx/llama.py
@@ -108,7 +108,7 @@ def from_pretrained(cls, config: AutoConfig, state_dict: Dict[str, mx.array], **kwargs):
         is_merge = True
 
         model = cls(config, is_merge)
-        state_dict = model.sanitize(state_dict)
+        state_dict = model.sanitize(state_dict, config.decoder_start_layer_idx, config.decoder_end_layer_idx)
         state_dict = model.merge_weights(state_dict, is_merge)
 
         model = quantization_func(config, model, state_dict)
@@ -119,10 +119,16 @@ def from_pretrained(cls, config: AutoConfig, state_dict: Dict[str, mx.array], **kwargs):
         return model
 
     @staticmethod
-    def sanitize(weights):
+    def sanitize(weights, start_idx: int, end_idx: int):
         sanitized_weights = {}
-        for key, value in weights.items():
-            sanitized_weights[key.split("language_model.", 1)[-1]] = value
+        for k, v in weights.items():
+            if not k.startswith("model"):
+                continue
+            if "embed_tokens" in k or "model.norm" in k:
+                continue
+            if int(k.split("model.layers.", 1)[-1].split(".")[0]) not in range(start_idx, end_idx):
+                continue
+            sanitized_weights[k.split("language_model.", 1)[-1]] = v
         return sanitized_weights
@@ -163,6 +169,7 @@ def from_pretrained(cls, config, state_dict, **kwargs):
         cls.num_layers = config.num_hidden_layers
 
         state_dict = tie_word_embeddings_func(config, state_dict)
+        state_dict = model.sanitize(state_dict)
         model = quantization_func(config, model, state_dict)
         model.load_weights(list(state_dict.items()))
 
@@ -170,6 +177,15 @@
         model.eval()
         return model
 
+    @staticmethod
+    def sanitize(weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if k.startswith("model.layers."):
+                continue
+            sanitized_weights[k.split("model.")[-1]] = v
+        return sanitized_weights
+
     def get_input_embeddings(self, x: np.ndarray) -> mx.array:
         return self.embed_tokens(mx.array(x))
diff --git a/tllm/models/mlx/qwen.py b/tllm/models/mlx/qwen2.py
similarity index 88%
rename from tllm/models/mlx/qwen.py
rename to tllm/models/mlx/qwen2.py
index 8651fad..02b224d 100644
--- a/tllm/models/mlx/qwen.py
+++ b/tllm/models/mlx/qwen2.py
@@ -73,7 +73,7 @@ def from_pretrained(cls, config: AutoConfig, state_dict: Dict[str, mx.array], **kwargs):
 
         model = cls(config, is_merge)
         state_dict = model.merge_weights(state_dict, is_merge)
-        state_dict = model.sanitize(state_dict)
+        state_dict = model.sanitize(state_dict, config.decoder_start_layer_idx, config.decoder_end_layer_idx)
 
         model = quantization_func(config, model, state_dict)
         model.load_weights(list(state_dict.items()))  # strict=False
@@ -82,11 +82,17 @@ def from_pretrained(cls, config: AutoConfig, state_dict: Dict[str, mx.array], **kwargs):
         return model
 
     @staticmethod
-    def sanitize(weights):
+    def sanitize(weights, start_idx: int, end_idx: int):
         sanitized_weights = {}
         for k, v in weights.items():
             if k.startswith("language_model."):
                 k = k.replace("language_model.", "")
+            if not k.startswith("model."):
+                continue
+            if "embed_tokens" in k or "model.norm" in k:
+                continue
+            if int(k.split("model.layers.", 1)[-1].split(".")[0]) not in range(start_idx, end_idx):
+                continue
             sanitized_weights[k] = v
         return sanitized_weights
@@ -125,6 +131,8 @@ def from_pretrained(cls, config, state_dict: Optional[Dict], **kwargs):
         cls.num_layers = config.num_hidden_layers
 
         state_dict = tie_word_embeddings_func(config, state_dict)
+        state_dict = model.sanitize(state_dict)
+
         model = quantization_func(config, model, state_dict)
         model.load_weights(list(state_dict.items()))  # , strict=False
 
@@ -132,6 +140,15 @@ def from_pretrained(cls, config, state_dict: Optional[Dict], **kwargs):
         model.eval()
         return model
 
+    @staticmethod
+    def sanitize(weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if k.startswith("model.layers."):
+                continue
+            sanitized_weights[k.split("model.")[-1]] = v
+        return sanitized_weights
+
     def get_input_embeddings(self, x: np.ndarray) -> mx.array:
         return self.embed_tokens(mx.array(x))
diff --git a/tllm/models/mlx/qwen2_vl.py b/tllm/models/mlx/qwen2_vl.py
index ff0490f..489e380 100644
--- a/tllm/models/mlx/qwen2_vl.py
+++ b/tllm/models/mlx/qwen2_vl.py
@@ -78,7 +78,7 @@ def from_pretrained(cls, config, state_dict: Dict[str, mx.array], **kwargs):
         state_dict = tie_word_embeddings_func(config, state_dict)
         state_dict = model.sanitize(state_dict)
         model = quantization_func(config, model, state_dict)
-        model.load_weights(list(state_dict.items()))  # , strict=False
+        model.load_weights(list(state_dict.items()))
 
         mx.eval(model.parameters())
         model.eval()
@@ -88,12 +88,23 @@ def from_pretrained(cls, config, state_dict: Dict[str, mx.array], **kwargs):
     def sanitize(weights):
         sanitized_weights = {}
         for k, v in weights.items():
-            if k.startswith("language_model."):
+            if k.startswith("language_model.model.layers"):
+                continue
+            if k.startswith("model.layers"):
+                continue
+            if k.startswith("language_model.model."):
                 k = k.replace("language_model.model.", "")
-                k = k.replace("language_model.", "")
+
+            if k.startswith("model."):
+                k = k.replace("model.", "")
             if k.startswith("vision_tower."):
                 k = k.replace("vision_tower.", "visual.")
 
+            if k == "visual.patch_embed.proj.weight":
+                # [out_ch, in_ch, n, h, w] -> [out_ch, n, h, w, in_ch]
+                if v.shape[3] == v.shape[4]:
+                    v = v.transpose(0, 2, 3, 4, 1)
+
             sanitized_weights[k] = v
         return sanitized_weights
diff --git a/tllm/models/register.py b/tllm/models/register.py
index bf94543..dfe2127 100644
--- a/tllm/models/register.py
+++ b/tllm/models/register.py
@@ -18,8 +18,8 @@
 if BackendEnum.MLX == BACKEND:
     from tllm.models.mlx.janus_pro import MLXJanusProConditionalGeneration
     from tllm.models.mlx.llama import MLXLlamaForCausalLM, MLXLlamaModel
+    from tllm.models.mlx.qwen2 import MLXQwen2ForCausalLM, MLXQwen2Model
     from tllm.models.mlx.qwen2_vl import MLXQwen2VLForConditionalGeneration
-    from tllm.models.mlx.qwen import MLXQwen2ForCausalLM, MLXQwen2Model
 
     MODEL_REGISTER.update({"LlamaForCausalLM": (MLXLlamaForCausalLM, MLXLlamaModel)})
     MODEL_REGISTER.update({"Qwen2ForCausalLM": (MLXQwen2ForCausalLM, MLXQwen2Model)})
@@ -38,8 +38,8 @@
     sampling_func = greedy_decode
 elif BackendEnum.TORCH == BACKEND:
     from tllm.models.torch.llama import HFLlamaForCausalLM, HFLlamaModel
-    from tllm.models.torch.qwen import HFQwen2ForCausalLM, HFQwen2Model
-    from tllm.models.torch.qwen_vl import HFQwen2VLForConditionalGeneration
+    from tllm.models.torch.qwen2 import HFQwen2ForCausalLM, HFQwen2Model
+    from tllm.models.torch.qwen2_vl import HFQwen2VLForConditionalGeneration
 
     MODEL_REGISTER.update({"LlamaForCausalLM": (HFLlamaForCausalLM, HFLlamaModel)})
     MODEL_REGISTER.update({"Qwen2ForCausalLM": (HFQwen2ForCausalLM, HFQwen2Model)})
diff --git a/tllm/models/torch/helper.py b/tllm/models/torch/helper.py
index cc7f37b..a52f468 100644
--- a/tllm/models/torch/helper.py
+++ b/tllm/models/torch/helper.py
@@ -1,7 +1,7 @@
 import itertools
 from typing import Dict, List
 
-from safetensors import safe_open
+from safetensors.torch import load_file as safe_load
 import torch
 
 from tllm.commons.attn import ATTN_TYPE
@@ -45,19 +45,8 @@ def build_mask(q_len_list: List[int], k_len_list: List[int]) -> "torch.Tensor":
     return combined_mask
 
 
-def read_from_safetensors(file_path: str, key_list: List[str] = None) -> Dict[str, "torch.Tensor"]:
-    tensors = {}
-    if key_list:
-        with safe_open(file_path, framework="pt", device="cpu") as f:
-            for key in f.keys():
-                for prefix_key in key_list:
-                    if key.startswith(prefix_key):
-                        tensors[key] = f.get_tensor(key)
-    else:
-        with safe_open(file_path, framework="pt", device="cpu") as f:
-            for key in f.keys():
-                tensors[key] = f.get_tensor(key)
-    return tensors
+def read_from_safetensors(file_path: str) -> Dict[str, torch.Tensor]:
+    return safe_load(file_path, device="cpu")
 
 
 if ATTN_TYPE == "xformers":
@@ -91,4 +80,5 @@ def build_forward_cache(
         attn_mask=attn_mask,
         uuid_list=seq_input.uuid_list,
         position_ids=torch.cat(position_ids_list, dim=-1),
+        q_len_list=q_len_list,
     )
diff --git a/tllm/models/torch/layers.py b/tllm/models/torch/layers.py
index 0677cd1..5faec1f 100644
--- a/tllm/models/torch/layers.py
+++ b/tllm/models/torch/layers.py
@@ -5,10 +5,10 @@
 from transformers import AutoConfig
 from transformers.activations import ACT2FN
 from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
     LlamaConfig,
     LlamaMLP,
     LlamaRMSNorm,
-    LlamaSdpaAttention,
     apply_rotary_pos_emb,
 )
@@ -197,7 +197,7 @@ def forward(
         return attn_output, None
 
 
-class PlainLlamaSdpaAttention(LlamaSdpaAttention):
+class PlainLlamaSdpaAttention(LlamaAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/tllm/models/torch/llama.py b/tllm/models/torch/llama.py
index 54190a5..f868789 100644
--- a/tllm/models/torch/llama.py
+++ b/tllm/models/torch/llama.py
@@ -57,12 +57,27 @@ def __init__(self, config, is_merge: bool = True):
     def from_pretrained(cls, config, state_dict: Dict[str, torch.Tensor], is_merge: bool = True, **kwargs):
         model = cls(config, is_merge)
         state_dict = model.merge_weights(state_dict, is_merge)
+        state_dict = model.sanitize(state_dict, config.decoder_start_layer_idx, config.decoder_end_layer_idx)
         model.load_state_dict(state_dict)
 
         model.to(DTYPE).to(DEVICE)
         model.eval()
         return model
 
+    @staticmethod
+    def sanitize(weights, start_idx: int, end_idx: int):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if not k.startswith("model"):
+                continue
+            if "embed_tokens" in k or "model.norm" in k:
+                continue
+            if int(k.split("model.layers.", 1)[-1].split(".")[0]) not in range(start_idx, end_idx):
+                continue
+            sanitized_weights[k.split("language_model.", 1)[-1]] = v
+
+        return sanitized_weights
+
     def merge_weights(self, state_dict: Dict[str, torch.Tensor], is_merge: bool) -> Dict[str, torch.Tensor]:
         if not is_merge:
             return state_dict
@@ -138,12 +153,22 @@ def from_pretrained(cls, config, state_dict: Optional[Dict] = None, **kwargs):
         cls.config = config
         cls.num_layers = config.num_hidden_layers
 
-        state_dict = tie_word_embeddings_func(state_dict)
+        state_dict = tie_word_embeddings_func(config, state_dict)
+        state_dict = model.sanitize(state_dict)
         model.load_state_dict(state_dict)
 
         model.to(DTYPE).to(DEVICE)
         model.eval()
         return model
 
+    @staticmethod
+    def sanitize(weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if k.startswith("model.layers."):
+                continue
+            sanitized_weights[k.split("model.")[-1]] = v
+        return sanitized_weights
+
     @torch.inference_mode()
     def get_input_embeddings(self, x: np.ndarray) -> torch.Tensor:
         return self.embed_tokens(torch.tensor(x, device=DEVICE))
diff --git a/tllm/models/torch/qwen.py b/tllm/models/torch/qwen2.py
similarity index 85%
rename from tllm/models/torch/qwen.py
rename to tllm/models/torch/qwen2.py
index c309267..7f43b9c 100644
--- a/tllm/models/torch/qwen.py
+++ b/tllm/models/torch/qwen2.py
@@ -62,12 +62,29 @@ def __init__(self, config, is_merge: bool = True):
     def from_pretrained(cls, config, state_dict: Dict[str, torch.Tensor], is_merge: bool = True, **kwargs):
         model = cls(config, is_merge)
         state_dict = model.merge_weights(state_dict, is_merge)
+        state_dict = model.sanitize(state_dict, config.decoder_start_layer_idx, config.decoder_end_layer_idx)
         model.load_state_dict(state_dict)
 
         model.to(DTYPE).to(DEVICE)
         model.eval()
         return model
 
+    @staticmethod
+    def sanitize(weights, start_idx: int, end_idx: int):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if k.startswith("language_model."):
+                k = k.replace("language_model.", "")
+            if not k.startswith("model."):
+                continue
+            if "embed_tokens" in k or "model.norm" in k:
+                continue
+            if int(k.split("model.layers.", 1)[-1].split(".")[0]) not in range(start_idx, end_idx):
+                continue
+
+            sanitized_weights[k] = v
+        return sanitized_weights
+
     def merge_weights(self, state_dict: Dict[str, torch.Tensor], is_merge: bool) -> Dict[str, torch.Tensor]:
         if not is_merge:
             return state_dict
@@ -144,12 +161,22 @@ def from_pretrained(cls, config, state_dict: Optional[Dict] = None, **kwargs):
         cls.config = config
         cls.num_layers = config.num_hidden_layers
 
-        state_dict = tie_word_embeddings_func(state_dict)
+        state_dict = tie_word_embeddings_func(config, state_dict)
+        state_dict = model.sanitize(state_dict)
         model.load_state_dict(state_dict)
 
         model.to(DTYPE).to(DEVICE)
         model.eval()
         return model
 
+    @staticmethod
+    def sanitize(weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if k.startswith("model.layers."):
+                continue
+            sanitized_weights[k.split("model.")[-1]] = v
+        return sanitized_weights
+
     @torch.inference_mode()
     def get_input_embeddings(self, x: np.ndarray) -> torch.Tensor:
         return self.embed_tokens(torch.tensor(x, device=DEVICE))
diff --git a/tllm/models/torch/qwen_vl.py b/tllm/models/torch/qwen2_vl.py
similarity index 86%
rename from tllm/models/torch/qwen_vl.py
rename to tllm/models/torch/qwen2_vl.py
index da99f82..13426a6 100644
--- a/tllm/models/torch/qwen_vl.py
+++ b/tllm/models/torch/qwen2_vl.py
@@ -52,12 +52,32 @@ def from_pretrained(cls, config, state_dict: Dict[str, torch.Tensor], **kwargs):
         cls.config = config
         cls.num_layers = config.num_hidden_layers
 
-        state_dict = tie_word_embeddings_func(state_dict)
+        state_dict = tie_word_embeddings_func(config, state_dict)
+        state_dict = model.sanitize(state_dict)
         model.load_state_dict(state_dict)
 
         model.to(DTYPE).to(DEVICE)
         model.eval()
         return model
 
+    @staticmethod
+    def sanitize(weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if k.startswith("language_model.model.layers"):
+                continue
+            if k.startswith("model.layers"):
+                continue
+            if k.startswith("language_model.model."):
+                k = k.replace("language_model.model.", "")
+
+            if k.startswith("model."):
+                k = k.replace("model.", "")
+            if k.startswith("vision_tower."):
+                k = k.replace("vision_tower.", "visual.")
+
+            sanitized_weights[k] = v
+        return sanitized_weights
+
     @torch.inference_mode()
     def get_input_embeddings(
         self,
diff --git a/tllm/models/weight_helper.py b/tllm/models/weight_helper.py
index 3ee8daa..cc959a9 100644
--- a/tllm/models/weight_helper.py
+++ b/tllm/models/weight_helper.py
@@ -20,6 +20,15 @@
     load_gguf_weight = lambda x: None, None, None
 
 
+def common_sanitize(weights):
+    sanitized_weights = {}
+    for k, v in weights.items():
+        if k.startswith("model.layers."):
+            continue
+        sanitized_weights[k] = v
+    return sanitized_weights
+
+
 def pop_weight_func(
     prefix_key_list: List[str], weights: Dict[str, MIX_TENSOR], num_layers: int, start_idx: int, end_idx: int
 ) -> Dict[str, MIX_TENSOR]: