Skip to content

Commit

Permalink
move sanitize to model class
Browse files Browse the repository at this point in the history
  • Loading branch information
wnma3mz committed Feb 1, 2025
1 parent 057a2e5 commit 3cc0a98
Show file tree
Hide file tree
Showing 14 changed files with 187 additions and 120 deletions.
57 changes: 10 additions & 47 deletions tllm/commons/manager.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import glob
import os
from typing import List

from transformers import AutoConfig

from tllm import BACKEND, BackendEnum
from tllm.commons.tp_communicator import BaseCommunicator
from tllm.models.file_helper import find_weight_file, get_model_path
from tllm.models.file_helper import get_model_path
from tllm.models.register import MODEL_REGISTER
from tllm.models.weight_helper import load_gguf_weight, read_from_safetensors

Expand Down Expand Up @@ -39,8 +38,8 @@ def __init__(self, model_path: str):
self.read_master_weight = self._gguf_read_master_weight
self.read_client_weight = self._gguf_read_client_weight
else:
self.read_master_weight = self._hf_read_master_weight
self.read_client_weight = self._hf_read_client_weight
self.read_master_weight = self._hf_read_weight
self.read_client_weight = self._hf_read_weight
self.tok, self.arch, self.config = self._post_init()

def _read_flux_master_weight(self):
Expand Down Expand Up @@ -96,43 +95,14 @@ def _gguf_read_master_weight(self):
state_dict["norm.eps"] = self.config.rms_norm_eps
return state_dict

def _hf_read_weight(self, prefix_key_list: List[str]):
file_key_dict = find_weight_file(self.model_path, prefix_key_list)
def _hf_read_weight(self):
sf_file_list = glob.glob(os.path.join(self.model_path, "*.safetensors"))
state_dict = {}
for file, key_list in file_key_dict.items():
weight_path = os.path.join(self.model_path, file)
state_dict.update(read_from_safetensors(weight_path, key_list if len(key_list) > 0 else prefix_key_list))
for sf_file in sf_file_list:
state_dict.update(read_from_safetensors(sf_file))
return state_dict

def _hf_read_master_weight(self):
    """Read the non-decoder-layer weights from a Hugging Face checkpoint.

    Requests every tensor whose name starts with one of the prefixes below
    from ``self._hf_read_weight`` and returns a new dict in which:

    * ``visual.patch_embed.proj.weight`` is transposed with axes
      ``(0, 2, 3, 4, 1)`` when the backend is MLX, and
    * keys that start with ``model.`` but not ``model.layers.`` are replaced
      by the text following the last ``model.`` in the key.

    All other keys are copied unchanged.
    """
    shared_prefixes = ["model.embed_tokens.", "model.norm.", "lm_head.", "visual.", "vision_tower."]
    # Extra prefixes used by Janus-Pro checkpoints.
    janus_pro_prefixes = [
        "vision_model.",
        "language_model.model.embed_tokens.",
        "language_model.model.norm.",
        "language_model.lm_head.",
        "aligner.",
        "gen_",
    ]
    loaded = self._hf_read_weight(shared_prefixes + janus_pro_prefixes)

    renamed = {}
    for name, tensor in loaded.items():
        # qwen-vl: [out_ch, in_ch, n, h, w] -> [out_ch, n, h, w, in_ch]
        if BACKEND == BackendEnum.MLX and name == "visual.patch_embed.proj.weight":
            tensor = tensor.transpose(0, 2, 3, 4, 1)

        # "model.layers." keys are kept verbatim (multi-modal encoder layers).
        strip_model_prefix = name.startswith("model.") and not name.startswith("model.layers.")
        target_name = name.split("model.")[-1] if strip_model_prefix else name
        renamed[target_name] = tensor

    return renamed

def _gguf_read_client_weight(self, start_idx: int, end_idx: int):
def _gguf_read_client_weight(self):
if self.state_dict is None:
raise ValueError("state_dict is None")
new_state_dict = {}
Expand All @@ -149,21 +119,14 @@ def _gguf_read_client_weight(self, start_idx: int, end_idx: int):
new_state_dict[k] = v
return new_state_dict

def _hf_read_client_weight(self, start_idx: int, end_idx: int):
    """Read the weights of decoder layers ``start_idx`` .. ``end_idx - 1``.

    Builds one key prefix per layer in the plain layout
    (``model.layers.<i>.``) followed by one per layer in the Janus-Pro layout
    (``language_model.model.layers.<i>.``), and returns whatever
    ``self._hf_read_weight`` yields for that prefix list.
    """
    layer_ids = range(start_idx, end_idx)
    plain_prefixes = [f"model.layers.{layer_idx}." for layer_idx in layer_ids]
    # Janus-Pro nests the language model one level deeper.
    janus_pro_prefixes = [f"language_model.model.layers.{layer_idx}." for layer_idx in layer_ids]
    return self._hf_read_weight(plain_prefixes + janus_pro_prefixes)


def load_client_model(start_idx: int, end_idx: int, comm: BaseCommunicator, model_path: str):
weight_manager = WeightManager(model_path)
config = weight_manager.config

end_idx = min(end_idx, config.num_hidden_layers)

state_dict = weight_manager.read_client_weight(start_idx, end_idx)
state_dict = weight_manager.read_client_weight()
if weight_manager.arch not in MODEL_REGISTER:
raise ValueError(f"Model {weight_manager.arch} not supported")

Expand Down
18 changes: 0 additions & 18 deletions tllm/models/file_helper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from collections import defaultdict
import json
import os
from pathlib import Path
from typing import List, Optional, Tuple
Expand All @@ -8,22 +6,6 @@
from huggingface_hub.file_download import repo_folder_name


def find_weight_file(model_path: str, prefix_key_list: List[str]) -> dict:
    """Group the wanted weight keys by the safetensors file that stores them.

    Args:
        model_path: Directory that holds the checkpoint files.
        prefix_key_list: Key prefixes to select; a key is selected when it
            starts with at least one of them.

    Returns:
        A ``defaultdict(list)`` mapping file name -> list of selected keys,
        built from the ``weight_map`` of ``model.safetensors.index.json``.
        When that index file does not exist, the mapping is
        ``{"model.safetensors": []}``; the empty list signals that no
        per-key index is available.

    Raises:
        KeyError: If the index file exists but has no ``weight_map`` entry.
    """
    index_path = os.path.join(model_path, "model.safetensors.index.json")
    file_key_dict = defaultdict(list)

    if not os.path.isfile(index_path):
        file_key_dict["model.safetensors"] = []
        return file_key_dict

    with open(index_path, "r", encoding="utf-8") as f:
        weight_map = json.load(f)["weight_map"]

    # A tuple lets str.startswith test every prefix in one call and records
    # each key once, even when several prefixes match the same key.
    prefixes = tuple(prefix_key_list)
    for key, file_ in weight_map.items():
        if key.startswith(prefixes):
            file_key_dict[file_].append(key)
    return file_key_dict


def get_hf_cache_model_path(repo_id: str, revision: Optional[str] = None) -> Path:
cache_dir = constants.HF_HUB_CACHE
if revision is None:
Expand Down
10 changes: 2 additions & 8 deletions tllm/models/mlx/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,8 @@ def class_predicate(p, m):
return model


def read_from_safetensors(file_path: str, key_list: List[str]) -> Dict[str, mx.array]:
    """Load ``file_path`` with ``mx.load`` and keep only the selected tensors.

    A tensor is kept when its name starts with at least one entry of
    ``key_list``; an empty ``key_list`` therefore yields an empty dict.
    """
    wanted_prefixes = tuple(key_list)
    return {
        name: tensor
        for name, tensor in mx.load(file_path).items()
        if name.startswith(wanted_prefixes)
    }
def read_from_safetensors(file_path: str) -> Dict[str, mx.array]:
    """Return the result of ``mx.load(file_path)`` without any key filtering.

    The return annotation declares a mapping from tensor name to ``mx.array``.
    """
    return mx.load(file_path)


def get_last_hidden_states(
Expand Down
45 changes: 29 additions & 16 deletions tllm/models/mlx/janus_pro.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
from tllm.models.mlx.helper import dict_to_dataclass, quantization_func
from tllm.models.mlx.vq_model import ModelArgs, VQModel, vision_head
from tllm.models.processor import VLMImageProcessor
from tllm.models.weight_helper import common_sanitize


def replace_vision_model_func(k: str, prefix_key: str) -> Optional[str]:
k = k.split("vision_model.", 1)[-1]
if f"{prefix_key}blocks." in k:
k = k.replace(f"{prefix_key}blocks.", f"{prefix_key}encoder.layers.")
if f"{prefix_key}patch_embed.proj." in k:
k = k.replace(f"{prefix_key}patch_embed.proj.", f"{prefix_key}embeddings.patch_embedding.")
def replace_vision_model_func(k: str) -> Optional[str]:
# k = k.split("vision_model.", 1)[-1]
if "vision_tower.blocks." in k:
k = k.replace("vision_tower.blocks.", "vision_tower.encoder.layers.")
if "vision_tower.patch_embed.proj." in k:
k = k.replace("vision_tower.patch_embed.proj.", "vision_tower.embeddings.patch_embedding.")

# do not load attn_pool
if "attn_pool." in k:
Expand Down Expand Up @@ -97,17 +98,29 @@ def __call__(self, x: mx.array) -> mx.array:
return x


class MLXJanusProLM(nn.Module):
    """Holder for the language model's embedding table and RMSNorm layer.

    Only ``__init__`` is defined in the lines shown; the two submodules are
    reached through the attributes ``embed_tokens`` and ``norm``.
    """

    def __init__(self, config):
        super().__init__()
        # Embedding table of shape (config.vocab_size, config.hidden_size).
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        # RMSNorm over config.hidden_size with epsilon config.rms_norm_eps.
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)


class MLXJanusProVLM(nn.Module):
    """Thin wrapper exposing a ``SiglipVisionModel`` as ``vision_tower``.

    NOTE(review): presumably this extra nesting makes parameter names start
    with ``vision_tower.`` inside this module so they line up with the
    checkpoint's key layout - confirm against the weight sanitizing code.
    """

    def __init__(self, config):
        super().__init__()
        # The vision encoder is built directly from the config passed in.
        self.vision_tower = SiglipVisionModel(config)


class MLXJanusProConditionalGeneration(nn.Module):
def __init__(self, config):
super().__init__()
self.vision_tower = SiglipVisionModel(config.vision_config)
self.vision_model = MLXJanusProVLM(config.vision_config)
self.aligner = MlpProjector(config.aligner_config["params"])

language_config = dict_to_dataclass(config.language_config, "LanguageConfig")
self.vocab_size = language_config.vocab_size
self.embed_tokens = nn.Embedding(language_config.vocab_size, language_config.hidden_size)
self.model = MLXJanusProLM(language_config)
self.lm_head = nn.Linear(language_config.hidden_size, language_config.vocab_size, bias=False)
self.norm = nn.RMSNorm(language_config.hidden_size, eps=language_config.rms_norm_eps)

self.gen_vision_model = VQModel(ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4]))
self.gen_aligner = MlpProjector(config.gen_aligner_config["params"])
Expand Down Expand Up @@ -152,11 +165,11 @@ def from_pretrained(cls, config, state_dict: Dict[str, mx.array], **kwargs):
def sanitize(weights):
sanitized_weights = {}
# Ugly compatibility janus
for k, v in weights.items():
for k, v in common_sanitize(weights).items():
if k.startswith("vision_model."):
k = replace_vision_model_func(k, prefix_key="vision_tower.")
k = replace_vision_model_func(k)
# Skip attn_pool
if k.startswith("vision_tower.head."):
if k.startswith("vision_model.vision_tower.head."):
continue
if k.startswith("language_model."):
k = k.replace("language_model.model.", "")
Expand Down Expand Up @@ -192,15 +205,15 @@ def get_input_embeddings(
pixel_values: Optional[np.ndarray] = None,
) -> mx.array:
# TODO: Multi-Request Maybe Has Problem
inputs_embeds = self.embed_tokens(mx.array(input_ids))
inputs_embeds = self.model.embed_tokens(mx.array(input_ids))

if pixel_values is not None:
# for mlx framework need to transpose
# bs, c, h, w -> bs, h, w, c
pixel_values = pixel_values.transpose(0, 2, 3, 1)

pixel_values = mx.array(pixel_values).astype(DTYPE)
image_embeds = self.aligner(self.vision_tower(pixel_values))
image_embeds = self.aligner(self.vision_model.vision_tower(pixel_values))
# image_embeds: token_nums x hidden_size
image_embeds = image_embeds[0] # TODO: fix this

Expand All @@ -218,11 +231,11 @@ def get_input_embeddings(
return inputs_embeds

def get_logits(self, hidden_states: mx.array) -> mx.array:
logits = self.lm_head(self.norm(hidden_states))
logits = self.lm_head(self.model.norm(hidden_states))
return logits

def get_gen_head(self, hidden_states: mx.array, temperature: float = 1.0, cfg_weight: float = 5.0) -> mx.array:
logits = self.gen_head(self.norm(hidden_states))
logits = self.gen_head(self.model.norm(hidden_states))
logit_cond = logits[0::2, :]
logit_uncond = logits[1::2, :]

Expand Down
24 changes: 20 additions & 4 deletions tllm/models/mlx/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def from_pretrained(cls, config: AutoConfig, state_dict: Dict[str, mx.array], **
is_merge = True

model = cls(config, is_merge)
state_dict = model.sanitize(state_dict)
state_dict = model.sanitize(state_dict, config.decoder_start_layer_idx, config.decoder_end_layer_idx)
state_dict = model.merge_weights(state_dict, is_merge)

model = quantization_func(config, model, state_dict)
Expand All @@ -119,10 +119,16 @@ def from_pretrained(cls, config: AutoConfig, state_dict: Dict[str, mx.array], **
return model

@staticmethod
def sanitize(weights):
def sanitize(weights, start_idx: int, end_idx: int):
sanitized_weights = {}
for key, value in weights.items():
sanitized_weights[key.split("language_model.", 1)[-1]] = value
for k, v in weights.items():
if not k.startswith("model"):
continue
if "embed_tokens" in k or "model.norm" in k:
continue
if int(k.split("model.layers.", 1)[-1].split(".")[0]) not in range(start_idx, end_idx):
continue
sanitized_weights[k.split("language_model.", 1)[-1]] = v

return sanitized_weights

Expand Down Expand Up @@ -163,13 +169,23 @@ def from_pretrained(cls, config, state_dict, **kwargs):
cls.num_layers = config.num_hidden_layers
state_dict = tie_word_embeddings_func(config, state_dict)

state_dict = model.sanitize(state_dict)
model = quantization_func(config, model, state_dict)
model.load_weights(list(state_dict.items()))

mx.eval(model.parameters())
model.eval()
return model

@staticmethod
def sanitize(weights):
sanitized_weights = {}
for k, v in weights.items():
if k.startswith("model.layers."):
continue
sanitized_weights[k.split("model.")[-1]] = v
return sanitized_weights

def get_input_embeddings(self, x: np.ndarray) -> mx.array:
return self.embed_tokens(mx.array(x))

Expand Down
21 changes: 19 additions & 2 deletions tllm/models/mlx/qwen.py → tllm/models/mlx/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def from_pretrained(cls, config: AutoConfig, state_dict: Dict[str, mx.array], **
model = cls(config, is_merge)
state_dict = model.merge_weights(state_dict, is_merge)

state_dict = model.sanitize(state_dict)
state_dict = model.sanitize(state_dict, config.decoder_start_layer_idx, config.decoder_end_layer_idx)
model = quantization_func(config, model, state_dict)
model.load_weights(list(state_dict.items())) # strict=False

Expand All @@ -82,11 +82,17 @@ def from_pretrained(cls, config: AutoConfig, state_dict: Dict[str, mx.array], **
return model

@staticmethod
def sanitize(weights):
def sanitize(weights, start_idx: int, end_idx: int):
sanitized_weights = {}
for k, v in weights.items():
if k.startswith("language_model."):
k = k.replace("language_model.", "")
if not k.startswith("model."):
continue
if "embed_tokens" in k or "model.norm" in k:
continue
if int(k.split("model.layers.", 1)[-1].split(".")[0]) not in range(start_idx, end_idx):
continue

sanitized_weights[k] = v
return sanitized_weights
Expand Down Expand Up @@ -125,13 +131,24 @@ def from_pretrained(cls, config, state_dict: Optional[Dict], **kwargs):
cls.num_layers = config.num_hidden_layers

state_dict = tie_word_embeddings_func(config, state_dict)
state_dict = model.sanitize(state_dict)

model = quantization_func(config, model, state_dict)
model.load_weights(list(state_dict.items())) # , strict=False

mx.eval(model.parameters())
model.eval()
return model

@staticmethod
def sanitize(weights):
sanitized_weights = {}
for k, v in weights.items():
if k.startswith("model.layers."):
continue
sanitized_weights[k.split("model.")[-1]] = v
return sanitized_weights

def get_input_embeddings(self, x: np.ndarray) -> mx.array:
return self.embed_tokens(mx.array(x))

Expand Down
17 changes: 14 additions & 3 deletions tllm/models/mlx/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def from_pretrained(cls, config, state_dict: Dict[str, mx.array], **kwargs):
state_dict = tie_word_embeddings_func(config, state_dict)
state_dict = model.sanitize(state_dict)
model = quantization_func(config, model, state_dict)
model.load_weights(list(state_dict.items())) # , strict=False
model.load_weights(list(state_dict.items()))

mx.eval(model.parameters())
model.eval()
Expand All @@ -88,12 +88,23 @@ def from_pretrained(cls, config, state_dict: Dict[str, mx.array], **kwargs):
def sanitize(weights):
sanitized_weights = {}
for k, v in weights.items():
if k.startswith("language_model."):
if k.startswith("language_model.model.layers"):
continue
if k.startswith("model.layers"):
continue
if k.startswith("language_model.model."):
k = k.replace("language_model.model.", "")
k = k.replace("language_model.", "")

if k.startswith("model."):
k = k.replace("model.", "")
if k.startswith("vision_tower."):
k = k.replace("vision_tower.", "visual.")

if k == "visual.patch_embed.proj.weight":
# [out_ch, in_ch, n, h, w] -> [out_ch, n, h, w, in_ch]
if v.shape[3] == v.shape[4]:
v = v.transpose(0, 2, 3, 4, 1)

sanitized_weights[k] = v
return sanitized_weights

Expand Down
Loading

0 comments on commit 3cc0a98

Please sign in to comment.