From 020add12b89e674433903c3fcba79ff8b24a590d Mon Sep 17 00:00:00 2001
From: lujianghu
Date: Sun, 2 Feb 2025 10:04:41 +0800
Subject: [PATCH] mlx-vlm and mflux are not mandatory

---
 README.md                                    |   7 +-
 requirements/mlx.txt                         |   3 +-
 run_engine.py                                |   1 -
 run_janus_pro.py                             |  44 +--
 tllm/commons/manager.py                      |   8 +-
 tllm/models/mlx/janus_pro/__init__.py        |   0
 tllm/models/mlx/{ => janus_pro}/janus_pro.py |  12 +-
 tllm/models/mlx/janus_pro/siglip_model.py    | 341 +++++++++++++++++++
 tllm/models/mlx/{ => janus_pro}/vq_model.py  |   0
 tllm/models/register.py                      |  19 +-
 10 files changed, 380 insertions(+), 55 deletions(-)
 create mode 100644 tllm/models/mlx/janus_pro/__init__.py
 rename tllm/models/mlx/{ => janus_pro}/janus_pro.py (96%)
 create mode 100644 tllm/models/mlx/janus_pro/siglip_model.py
 rename tllm/models/mlx/{ => janus_pro}/vq_model.py (100%)

diff --git a/README.md b/README.md
index 678d0fe..9fb5774 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@
 
 1. install dependencies
 
-- for mlx (macos arm): `pip install -U -e ".[mlx]" && pip install -e git+https://github.com/wnma3mz/mlx_clip.git#egg=mlx_clip`
+- for mlx (macos arm): `pip install -U -e ".[mlx]"`
 - for nvidia: `pip install -e ".[torch]"`
 
 2. run server
@@ -33,6 +33,11 @@ python3 benchmarks/run_async_requests.py
 ```
 
+### More Models
+
+- qwen-vl: `pip install mlx-vlm==0.1.12`
+- flux: `pip install mflux==0.4.1`
+
 ### More Details
 
 In `examples/config.json`
diff --git a/requirements/mlx.txt b/requirements/mlx.txt
index 36f2aa3..f9bf761 100644
--- a/requirements/mlx.txt
+++ b/requirements/mlx.txt
@@ -1,3 +1,2 @@
 mlx==0.22.0
-mlx-lm==0.21.1
-mlx-vlm==0.1.12
\ No newline at end of file
+mlx-lm==0.21.1
\ No newline at end of file
diff --git a/run_engine.py b/run_engine.py
index 786c771..2beb57d 100644
--- a/run_engine.py
+++ b/run_engine.py
@@ -120,7 +120,6 @@ def load_message():
 
 
 if __name__ == "__main__":
-    args = parse_args()
     messages_dict = load_message()
     if args.message_type == "llm":
         asyncio.run(llm_generate(args, messages_dict["llm"][0]))
diff --git a/run_janus_pro.py b/run_janus_pro.py
index 36b2c26..48ffde4 100644
--- a/run_janus_pro.py
+++ b/run_janus_pro.py
@@ -3,42 +3,18 @@
 import base64
 import json
 import math
-import os
 import time
 from typing import List, Tuple
 
 from PIL import Image
 import numpy as np
 
-
-def parse_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--backend", type=str, default="MLX", choices=["MLX", "TORCH"])
-    parser.add_argument(
-        "--attn_backend",
-        type=str,
-        default="AUTO",
-        choices=["AUTO", "TORCH", "VLLM", "XFormers"],
-        help="Attention backend if backend is TORCH",
-    )
-    parser.add_argument("--model_path", type=str, default="Qwen/Qwen2-VL-2B-Instruct")
-    parser.add_argument("--message_type", type=str, default="llm", choices=["llm", "mllm", "image"])
-    return parser.parse_args()
-
-
-args = parse_args()
-os.environ["TLLM_BACKEND"] = args.backend
-os.environ["TLLM_ATTN_BACKEND"] = args.attn_backend
-
 from tllm.commons.manager import load_client_model, load_master_model
 from tllm.commons.tp_communicator import Communicator
 from tllm.engine import AsyncEngine
-from tllm.entrypoints.image_server.image_protocol import Text2ImageRequest
-from tllm.entrypoints.image_server.server_image import ImageServing
 from tllm.entrypoints.protocol import ChatCompletionRequest, ChatCompletionResponse
 from tllm.entrypoints.server_chat import OpenAIServing
 from tllm.generate import LLMGenerator
-from tllm.img_helper import base64_to_pil_image
 from tllm.schemas import MIX_TENSOR, SeqInput
 from tllm.singleton_logger import SingletonLogger
 
@@ -56,19 +32,6 @@ async def forward(self, hidden_states: MIX_TENSOR, seq_input: SeqInput) -> Tuple
         output_hidden_states = self.model(hidden_states, seq_input)
         return output_hidden_states, [time.perf_counter() - s1]
 
-    async def image_forward(
-        self,
-        hidden_states: MIX_TENSOR,
-        text_embeddings: MIX_TENSOR,
-        seq_len: int,
-        height: int,
-        width: int,
-        request_id: str,
-    ) -> Tuple[MIX_TENSOR, List[float]]:
-        s1 = time.perf_counter()
-        output_hidden_states = self.model(hidden_states, text_embeddings, seq_len, height, width, [request_id])
-        return output_hidden_states, [time.perf_counter() - s1]
-
 
 def init_engine(model_path: str) -> AsyncEngine:
     model = load_master_model(model_path)
@@ -123,6 +86,13 @@ def gen_img_message():
     ]
 
 
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_path", type=str, default="Qwen/Qwen2-VL-2B-Instruct")
+    parser.add_argument("--message_type", type=str, default="llm", choices=["llm", "mllm", "image"])
+    return parser.parse_args()
+
+
 if __name__ == "__main__":
     args = parse_args()
     messages_dict = load_message()
diff --git a/tllm/commons/manager.py b/tllm/commons/manager.py
index 32a3006..a0b1b99 100644
--- a/tllm/commons/manager.py
+++ b/tllm/commons/manager.py
@@ -5,7 +5,7 @@
 
 from tllm.commons.tp_communicator import BaseCommunicator
 from tllm.models.file_helper import get_model_path
-from tllm.models.register import MODEL_REGISTER
+from tllm.models.register import DEP_MODEL_REGISTER, MODEL_REGISTER
 from tllm.models.weight_helper import load_gguf_weight, read_from_safetensors
 
 
@@ -148,7 +148,11 @@ def load_master_model(model_path: str):
     weight_manager = WeightManager(model_path)
     state_dict = weight_manager.read_master_weight()
     if weight_manager.arch not in MODEL_REGISTER:
-        raise ValueError(f"Model {weight_manager.arch} not supported")
+        arch = weight_manager.arch
+        if weight_manager.arch in DEP_MODEL_REGISTER:
+            raise ValueError(f"Model {arch} is supported but requires an extra package; please run `pip install {DEP_MODEL_REGISTER[arch]}`")
+        else:
+            raise ValueError(f"Model {arch} not supported")
 
     MY_CausalLM_CLASS, _ = MODEL_REGISTER[weight_manager.arch]
 
diff --git a/tllm/models/mlx/janus_pro/__init__.py b/tllm/models/mlx/janus_pro/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tllm/models/mlx/janus_pro.py b/tllm/models/mlx/janus_pro/janus_pro.py
similarity index 96%
rename from tllm/models/mlx/janus_pro.py
rename to tllm/models/mlx/janus_pro/janus_pro.py
index 11329f7..b57b040 100644
--- a/tllm/models/mlx/janus_pro.py
+++ b/tllm/models/mlx/janus_pro/janus_pro.py
@@ -4,27 +4,22 @@
 
 import mlx.core as mx
 import mlx.nn as nn
-from mlx_clip.models.siglip.siglip_model import SiglipVisionModel
 import numpy as np
 
 from tllm import DTYPE
 from tllm.models.mlx.helper import dict_to_dataclass, quantization_func
-from tllm.models.mlx.vq_model import ModelArgs, VQModel, vision_head
+from tllm.models.mlx.janus_pro.siglip_model import SiglipVisionModel
+from tllm.models.mlx.janus_pro.vq_model import ModelArgs, VQModel, vision_head
 from tllm.models.processor import VLMImageProcessor
 from tllm.models.weight_helper import common_sanitize
 
 
 def replace_vision_model_func(k: str) -> Optional[str]:
-    # k = k.split("vision_model.", 1)[-1]
    if "vision_tower.blocks." in k:
        k = k.replace("vision_tower.blocks.", "vision_tower.encoder.layers.")
    if "vision_tower.patch_embed.proj." in k:
        k = k.replace("vision_tower.patch_embed.proj.", "vision_tower.embeddings.patch_embedding.")
 
-    # do not load attn_pool
-    if "attn_pool." in k:
-        k = k.replace("attn_pool.", "head.")
-
     if ".norm2." in k:
         k = k.replace(".norm2.", ".layer_norm2.")
     if ".norm1." in k:
@@ -169,7 +164,7 @@ def sanitize(weights):
             if k.startswith("vision_model."):
                 k = replace_vision_model_func(k)
                 # Skip attn_pool
-                if k.startswith("vision_model.vision_tower.head."):
+                if k.startswith("vision_model.vision_tower.attn_pool."):
                     continue
                 if k.startswith("language_model."):
                     k = k.replace("language_model.model.", "")
@@ -179,6 +174,7 @@ def sanitize(weights):
             if "weight" in k and len(v.shape) == 4:
                 # [out_ch, in_ch, h, w] -> [out_ch, h, w, in_ch]
                 v = v.transpose(0, 2, 3, 1)
+            # Skip encoder
             if "encoder" in k:
                 continue
             if ".quant_conv" in k:
diff --git a/tllm/models/mlx/janus_pro/siglip_model.py b/tllm/models/mlx/janus_pro/siglip_model.py
new file mode 100644
index 0000000..ec496a1
--- /dev/null
+++ b/tllm/models/mlx/janus_pro/siglip_model.py
@@ -0,0 +1,341 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from dataclasses import dataclass
+import glob
+import json
+import logging
+import math
+from pathlib import Path
+from typing import Dict, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+# from mlx.core import linalg as LA
+# from transformers import SiglipVisionConfig  # unused: shadowed by the local dataclass below
+
+
+@dataclass
+class SiglipVisionConfig:
+    num_hidden_layers: int
+    hidden_size: int
+    intermediate_size: int
+    num_attention_heads: int
+    num_channels: int
+    image_size: int
+    patch_size: int
+    layer_norm_eps: float
+    use_head: bool = True
+    mlp_ratio: int = 4
+    ignore_head: bool = True
+
+    @classmethod
+    def from_dict(cls, data_dict: Dict):
+        return cls(
+            num_hidden_layers=data_dict.get("num_hidden_layers", 12),
+            hidden_size=data_dict.get("hidden_size", 768),
+            intermediate_size=data_dict.get("hidden_size", 768) * data_dict.get("mlp_ratio", 4),
+            num_attention_heads=data_dict.get("num_attention_heads", 12),
+            num_channels=data_dict.get("num_channels", 3),
+            image_size=data_dict.get("image_size", 224),
+            patch_size=data_dict.get("patch_size", 16),
+            layer_norm_eps=data_dict.get("layer_norm_eps", 1e-6),
+        )
+
+
+# Modified from CLIP
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        query_input_dims: Optional[int] = None,
+        key_input_dims: Optional[int] = None,
+        value_input_dims: Optional[int] = None,
+        value_dims: Optional[int] = None,
+        value_output_dims: Optional[int] = None,
+        bias: bool = False,
+    ):
+        super().__init__()
+
+        if (dims % num_heads) != 0:
+            raise ValueError(
+                "The input feature dimensions should be divisible by the "
+                f"number of heads ({dims} % {num_heads}) != 0"
+            )
+
+        query_input_dims = query_input_dims or dims
+        key_input_dims = key_input_dims or dims
+        value_input_dims = value_input_dims or key_input_dims
+        value_dims = value_dims or dims
+        value_output_dims = value_output_dims or dims
+
+        self.num_heads = num_heads
+        self.qkv = nn.Linear(dims, dims * 3, bias=bias)
+        self.proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+
+    def __call__(self, hidden_states, mask=None):
+        B, L, D = hidden_states.shape
+
+        queries, keys, values = self.qkv(hidden_states).reshape(B, L, 3, -1).transpose(2, 0, 1, 3)
+
+        _, S, _ = keys.shape
+        queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, S, self.num_heads, -1).transpose(0, 2, 3, 1)
+        values = values.reshape(B, S, self.num_heads, -1).transpose(0, 2, 1, 3)
+
+        scale = math.sqrt(1 / queries.shape[-1])
+        scores = (queries * scale) @ keys
+        if mask is not None:
+            scores = scores + mask.astype(scores.dtype)
+
+        scores = mx.softmax(scores, axis=-1)
+        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        return self.proj(values_hat)
+
+
+# Copied from CLIP
+class MLP(nn.Module):
+    def __init__(self, config: SiglipVisionConfig):
+        super().__init__()
+        self.config = config
+        self.activation_fn = nn.GELU()
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.activation_fn(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+
+# Copied from CLIP
+class EncoderLayer(nn.Module):
+    def __init__(self, config: SiglipVisionConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        # Add biases to the attention projections
+        self.self_attn = Attention(config.hidden_size, config.num_attention_heads, bias=True)
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = MLP(config)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+        x = x + self.self_attn(self.layer_norm1(x), mask)
+        x = x + self.mlp(self.layer_norm2(x))
+        return x
+
+
+# Copied from CLIP
+class Encoder(nn.Module):
+    def __init__(self, config: SiglipVisionConfig):
+        self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
+
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+        for layer in self.layers:
+            x = layer(x, mask)
+        return x
+
+
+class SiglipVisionEmbeddings(nn.Module):
+    def __init__(self, config: SiglipVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.patch_size = config.patch_size
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=True,
+        )
+
+    def __call__(self, pixel_values: mx.array) -> mx.array:
+        target_dtype = self.patch_embedding.weight.dtype
+        # shape = [*, width, grid, grid]
+        patch_embeds = self.patch_embedding(pixel_values.astype(target_dtype))
+        # patch_embeds (1, 14, 14, 1024)
+        # [batch_size, h, w, embed_dim]
+        embeddings = mx.flatten(patch_embeds, start_axis=1, end_axis=2)
+        # embeddings = patch_embeds.flatten(2).transpose(1, 2)  # BCHW -> BNC
+        return embeddings
+
+
+class SiglipMultiheadAttentionPoolingHead(nn.Module):
+    """Multihead Attention Pooling."""
+
+    def __init__(self, config: SiglipVisionConfig):
+        super().__init__()
+
+        self.q = nn.Linear(config.hidden_size, config.hidden_size)
+        self.kv = nn.Linear(config.hidden_size, 2 * config.hidden_size)
+        self.proj = nn.Linear(config.hidden_size, config.hidden_size)
+
+        self.latent_len = 1
+        self.num_heads = config.num_attention_heads
+        self.head_dim = config.hidden_size // config.num_attention_heads
+        self.scale = self.head_dim**-0.5
+
+        self.latent = mx.zeros((1, self.latent_len, config.hidden_size))
+        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.mlp = MLP(config)
+        self.pos_embed = None
+
+        self.pool = "token"
+
+    def __call__(self, x: mx.array) -> mx.array:
+        B, N, C = x.shape
+
+        # if self.pos_embed is not None:
+        #     # FIXME interpolate
+        #     x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
+
+        # TODO: maybe slow
+        q_latent = mx.repeat(self.latent, B, axis=0)
+        q = self.q(q_latent)
+        q = q.reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
+
+        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).transpose(2, 0, 3, 1, 4)
+        k, v = kv
+
+        # q, k = self.q_norm(q), self.k_norm(k)
+
+        # self attn
+        x = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
+
+        x = x.transpose(0, 2, 1, 3).reshape(B, self.latent_len, C)
+        x = self.proj(x)
+        # x = self.proj_drop(x)
+
+        x = x + self.mlp(self.norm(x))
+
+        # optional pool if latent seq_len > 1 and pooled output is desired
+        if self.pool == "token":
+            x = x[:, 0]
+        elif self.pool == "avg":
+            x = x.mean(1)
+        return x
+
+
+class SiglipVisionModel(nn.Module):
+    def __init__(self, config: SiglipVisionConfig):
+        super().__init__()
+        self.embeddings = SiglipVisionEmbeddings(config)  # patch_embed
+        # self.pre_layrnorm = nn.LayerNorm(config.hidden_size, eps=1e-6)
+        self.encoder = Encoder(config)
+        self.norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
+        self.use_head = getattr(config, "use_head", True)
+
+        self.no_embed_class = False
+        self.num_prefix_tokens = 0
+
+        grid_size = config.image_size // config.patch_size
+        self.pos_embed = mx.zeros((1, grid_size * grid_size, config.hidden_size))
+
+        self.head = None
+        self.ignore_head = config.ignore_head
+        # attn_pool
+        if not self.ignore_head:
+            self.head = SiglipMultiheadAttentionPoolingHead(config)
+
+    def _pos_embed(self, x: mx.array) -> mx.array:
+        return x + self.pos_embed
+
+    def __call__(self, x: mx.array):
+        x = self.embeddings(x)
+        x = self._pos_embed(x)
+
+        mask = None
+        x = self.encoder(x, mask)
+        x = self.norm(x)
+
+        return x if self.ignore_head else self.head(x)
+
+    @staticmethod
+    def _load_default_config(config_data):
+        config_data.update(
+            {
+                "image_size": 384,
+                "patch_size": 16,
+                "hidden_size": 1024,
+                "num_hidden_layers": 24,
+                "num_attention_heads": 16,
+                "mlp_ratio": 4,
+                "global_pool": "map",
+                "use_checkpoint": False,
+            }
+        )
+        return config_data
+
+    @classmethod
+    def from_pretrained(cls, path: str):
+        import os
+
+        with open(os.path.join(path, "config.json"), "r") as f:
+            config_data = json.load(f)["vision_config"]["params"]
+        config_data = cls._load_default_config(config_data)
+        config = SiglipVisionConfig.from_dict(config_data)
+
+        model = cls(config)
+        weight_files = glob.glob(str(Path(path) / "*.safetensors"))
+        if not weight_files:
+            logging.error(f"No safetensors found in {path}")
+            raise FileNotFoundError(f"No safetensors found in {path}")
+
+        weights = {}
+        for wf in weight_files:
+            weights.update(mx.load(wf))
+
+        weights = model.sanitize(weights)
+        model.load_weights(list(weights.items()))
+        return model
+
+    @staticmethod
+    def sanitize(weights):
+        sanitized_weights = {}
+        # Ugly compatibility shim for Janus checkpoints
+        for k, v in weights.items():
+            k = k.split("vision_tower.", 1)[-1]
+            if "blocks." in k:
+                k = k.replace("blocks.", "encoder.layers.")
+            if "patch_embed.proj." in k:
+                k = k.replace("patch_embed.proj.", "embeddings.patch_embedding.")
+            # if "norm." in k:
+            #     k = k.replace("norm.", "post_layernorm.")
+            if "attn_pool." in k:
+                k = k.replace("attn_pool.", "head.")
+
+            if ".norm2." in k:
+                k = k.replace(".norm2.", ".layer_norm2.")
+            if ".norm1." in k:
+                k = k.replace(".norm1.", ".layer_norm1.")
+            if ".attn." in k:
+                k = k.replace(".attn.", ".self_attn.")
+
+            if "position_ids" in k:
+                # Remove unused position_ids
+                continue
+            elif "patch_embedding.weight" in k:
+                # pytorch conv2d expects the weight tensor to be of shape [out_channels, in_channels, kH, kW]
+                # mlx conv2d expects the weight tensor to be of shape [out_channels, kH, kW, in_channels]
+                sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+            else:
+                sanitized_weights[k] = v
+
+        return sanitized_weights
+
+
+if __name__ == "__main__":
+    model = SiglipVisionModel.from_pretrained("siglip_model")
+
+    shape = (1, 384, 384, 3)
+    pixel_value = mx.random.normal(shape=shape)
+    # import torch
+    # pixel_value_torch = torch.load("x.pth")
+    # pixel_value = mx.array(pixel_value_torch).transpose(0, 2, 3, 1)
+
+    output = model(pixel_value)
+    # print("output", output, output.shape)
diff --git a/tllm/models/mlx/vq_model.py b/tllm/models/mlx/janus_pro/vq_model.py
similarity index 100%
rename from tllm/models/mlx/vq_model.py
rename to tllm/models/mlx/janus_pro/vq_model.py
diff --git a/tllm/models/register.py b/tllm/models/register.py
index dfe2127..e54b01b 100644
--- a/tllm/models/register.py
+++ b/tllm/models/register.py
@@ -4,6 +4,7 @@
 from tllm.models.torch.helper import greedy_decode
 
 MODEL_REGISTER = {}
+DEP_MODEL_REGISTER = {}
 try:
     # in testing
     from tllm.models.tinygrad.helper import greedy_decode
@@ -16,15 +17,12 @@
     pass
 
 if BackendEnum.MLX == BACKEND:
-    from tllm.models.mlx.janus_pro import MLXJanusProConditionalGeneration
+    from tllm.models.mlx.janus_pro.janus_pro import MLXJanusProConditionalGeneration
     from tllm.models.mlx.llama import MLXLlamaForCausalLM, MLXLlamaModel
     from tllm.models.mlx.qwen2 import MLXQwen2ForCausalLM, MLXQwen2Model
-    from tllm.models.mlx.qwen2_vl import MLXQwen2VLForConditionalGeneration
 
     MODEL_REGISTER.update({"LlamaForCausalLM": (MLXLlamaForCausalLM, MLXLlamaModel)})
     MODEL_REGISTER.update({"Qwen2ForCausalLM": (MLXQwen2ForCausalLM, MLXQwen2Model)})
-    MODEL_REGISTER.update({"Qwen2VLForConditionalGeneration": (MLXQwen2VLForConditionalGeneration, MLXQwen2Model)})
-    MODEL_REGISTER.update({"Qwen2_5_VLForConditionalGeneration": (MLXQwen2VLForConditionalGeneration, MLXQwen2Model)})
     MODEL_REGISTER.update({"JanusProConditionalGeneration": (MLXJanusProConditionalGeneration, MLXLlamaModel)})
 
     if importlib.util.find_spec("mflux"):
@@ -32,6 +30,19 @@
         from tllm.models.mlx.flux.transformer import FLUXModel
 
         MODEL_REGISTER.update({"FLUX": (Flux1, FLUXModel)})
+    else:
+        DEP_MODEL_REGISTER.update({"FLUX": "mflux"})
+
+    if importlib.util.find_spec("mlx_vlm"):
+        from tllm.models.mlx.qwen2_vl import MLXQwen2VLForConditionalGeneration
+
+        MODEL_REGISTER.update({"Qwen2VLForConditionalGeneration": (MLXQwen2VLForConditionalGeneration, MLXQwen2Model)})
+        MODEL_REGISTER.update(
+            {"Qwen2_5_VLForConditionalGeneration": (MLXQwen2VLForConditionalGeneration, MLXQwen2Model)}
+        )
+    else:
+        DEP_MODEL_REGISTER.update({"Qwen2VLForConditionalGeneration": "mlx_vlm"})
+        DEP_MODEL_REGISTER.update({"Qwen2_5_VLForConditionalGeneration": "mlx_vlm"})
 
     from tllm.models.mlx.helper import greedy_decode
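
Note: the snippet below is an illustrative sketch, not part of the diff above. It shows the optional-dependency pattern this change introduces in `tllm/models/register.py` and `tllm/commons/manager.py`: architectures whose extra package is missing land in `DEP_MODEL_REGISTER`, and `load_master_model` turns that into a `pip install` hint. The `check_arch` helper and the placeholder registry value are hypothetical and exist only for illustration.

```python
import importlib.util

MODEL_REGISTER = {}      # arch -> model classes, filled only when the extra package is installed
DEP_MODEL_REGISTER = {}  # arch -> pip package that would provide it

if importlib.util.find_spec("mlx_vlm"):
    # In the real code this imports MLXQwen2VLForConditionalGeneration and registers it.
    MODEL_REGISTER["Qwen2VLForConditionalGeneration"] = ("<model classes>",)
else:
    DEP_MODEL_REGISTER["Qwen2VLForConditionalGeneration"] = "mlx_vlm"


def check_arch(arch: str) -> None:
    # Hypothetical helper mirroring the new branch in load_master_model():
    # a known-but-uninstalled arch gets an install hint, anything else is unsupported.
    if arch in MODEL_REGISTER:
        return
    if arch in DEP_MODEL_REGISTER:
        raise ValueError(f"Model {arch} is supported but requires `pip install {DEP_MODEL_REGISTER[arch]}`")
    raise ValueError(f"Model {arch} not supported")
```

With this split, a plain `pip install -e ".[mlx]"` no longer pulls in mlx-vlm or mflux, and users who request a qwen-vl or flux model get an actionable error instead of an import failure.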