Skip to content

Commit

Permalink
mlx-vlm and mflux are not mandatory
Browse files Browse the repository at this point in the history
  • Loading branch information
wnma3mz committed Feb 2, 2025
1 parent 3cc0a98 commit 020add1
Show file tree
Hide file tree
Showing 10 changed files with 380 additions and 55 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

1. install dependencies

- for mlx (macos arm): `pip install -U -e ".[mlx]" && pip install -e git+https://github.com/wnma3mz/mlx_clip.git#egg=mlx_clip`
- for mlx (macos arm): `pip install -U -e ".[mlx]"`
- for nvidia: `pip install -e ".[torch]"`

2. run server
Expand All @@ -33,6 +33,11 @@
python3 benchmarks/run_async_requests.py
```

### More Models

- qwen-vl: `pip install mlx-vlm==0.1.12`
- flux: `pip install mflux==0.4.1`

### More Details

In `examples/config.json`
Expand Down
3 changes: 1 addition & 2 deletions requirements/mlx.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
mlx==0.22.0
mlx-lm==0.21.1
mlx-vlm==0.1.12
mlx-lm==0.21.1
1 change: 0 additions & 1 deletion run_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def load_message():


if __name__ == "__main__":
args = parse_args()
messages_dict = load_message()
if args.message_type == "llm":
asyncio.run(llm_generate(args, messages_dict["llm"][0]))
Expand Down
44 changes: 7 additions & 37 deletions run_janus_pro.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,42 +3,18 @@
import base64
import json
import math
import os
import time
from typing import List, Tuple

from PIL import Image
import numpy as np


def parse_args():
    """Parse this script's command-line options.

    Recognised flags:
        --backend       "MLX" (default) or "TORCH".
        --attn_backend  "AUTO" (default), "TORCH", "VLLM" or "XFormers".
        --model_path    model identifier or path.
        --message_type  "llm" (default), "mllm" or "image".

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # Flags are declared as data and registered in order, so the generated
    # --help output lists them in the same sequence as the table below.
    option_table = (
        ("--backend", {"type": str, "default": "MLX", "choices": ["MLX", "TORCH"]}),
        (
            "--attn_backend",
            {
                "type": str,
                "default": "AUTO",
                "choices": ["AUTO", "TORCH", "VLLM", "XFormers"],
                "help": "Attention backend if backend is TORCH",
            },
        ),
        ("--model_path", {"type": str, "default": "Qwen/Qwen2-VL-2B-Instruct"}),
        ("--message_type", {"type": str, "default": "llm", "choices": ["llm", "mllm", "image"]}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


args = parse_args()
os.environ["TLLM_BACKEND"] = args.backend
os.environ["TLLM_ATTN_BACKEND"] = args.attn_backend

from tllm.commons.manager import load_client_model, load_master_model
from tllm.commons.tp_communicator import Communicator
from tllm.engine import AsyncEngine
from tllm.entrypoints.image_server.image_protocol import Text2ImageRequest
from tllm.entrypoints.image_server.server_image import ImageServing
from tllm.entrypoints.protocol import ChatCompletionRequest, ChatCompletionResponse
from tllm.entrypoints.server_chat import OpenAIServing
from tllm.generate import LLMGenerator
from tllm.img_helper import base64_to_pil_image
from tllm.schemas import MIX_TENSOR, SeqInput
from tllm.singleton_logger import SingletonLogger

Expand All @@ -56,19 +32,6 @@ async def forward(self, hidden_states: MIX_TENSOR, seq_input: SeqInput) -> Tuple
output_hidden_states = self.model(hidden_states, seq_input)
return output_hidden_states, [time.perf_counter() - s1]

async def image_forward(
    self,
    hidden_states: MIX_TENSOR,
    text_embeddings: MIX_TENSOR,
    seq_len: int,
    height: int,
    width: int,
    request_id: str,
) -> Tuple[MIX_TENSOR, List[float]]:
    """Invoke the wrapped model once for an image request and time the call.

    ``self.model`` is called with the given arguments, with ``request_id``
    wrapped in a single-element list as the last positional argument.

    Returns:
        A tuple ``(model_output, [elapsed_seconds])`` where the elapsed time
        is measured with ``time.perf_counter`` around the model call.
    """
    started_at = time.perf_counter()
    request_ids = [request_id]
    model_output = self.model(hidden_states, text_embeddings, seq_len, height, width, request_ids)
    elapsed = time.perf_counter() - started_at
    return model_output, [elapsed]


def init_engine(model_path: str) -> AsyncEngine:
model = load_master_model(model_path)
Expand Down Expand Up @@ -123,6 +86,13 @@ def gen_img_message():
]


def parse_args():
    """Parse this script's command-line options.

    Recognised flags:
        --model_path    model identifier or path.
        --message_type  "llm" (default), "mllm" or "image".

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    allowed_message_types = ["llm", "mllm", "image"]

    cli = argparse.ArgumentParser()
    cli.add_argument("--model_path", default="Qwen/Qwen2-VL-2B-Instruct", type=str)
    cli.add_argument("--message_type", default="llm", type=str, choices=allowed_message_types)

    namespace = cli.parse_args()
    return namespace


if __name__ == "__main__":
args = parse_args()
messages_dict = load_message()
Expand Down
8 changes: 6 additions & 2 deletions tllm/commons/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from tllm.commons.tp_communicator import BaseCommunicator
from tllm.models.file_helper import get_model_path
from tllm.models.register import MODEL_REGISTER
from tllm.models.register import DEP_MODEL_REGISTER, MODEL_REGISTER
from tllm.models.weight_helper import load_gguf_weight, read_from_safetensors


Expand Down Expand Up @@ -148,7 +148,11 @@ def load_master_model(model_path: str):
weight_manager = WeightManager(model_path)
state_dict = weight_manager.read_master_weight()
if weight_manager.arch not in MODEL_REGISTER:
raise ValueError(f"Model {weight_manager.arch} not supported")
arch = weight_manager.arch
if weight_manager.arch in DEP_MODEL_REGISTER:
raise ValueError(f"Model {arch} now is support, please execute `pip install {DEP_MODEL_REGISTER[arch]}`")
else:
raise ValueError(f"Model {arch} not supported")

MY_CausalLM_CLASS, _ = MODEL_REGISTER[weight_manager.arch]

Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,22 @@

import mlx.core as mx
import mlx.nn as nn
from mlx_clip.models.siglip.siglip_model import SiglipVisionModel
import numpy as np

from tllm import DTYPE
from tllm.models.mlx.helper import dict_to_dataclass, quantization_func
from tllm.models.mlx.vq_model import ModelArgs, VQModel, vision_head
from tllm.models.mlx.janus_pro.siglip_model import SiglipVisionModel
from tllm.models.mlx.janus_pro.vq_model import ModelArgs, VQModel, vision_head
from tllm.models.processor import VLMImageProcessor
from tllm.models.weight_helper import common_sanitize


def replace_vision_model_func(k: str) -> Optional[str]:
# k = k.split("vision_model.", 1)[-1]
if "vision_tower.blocks." in k:
k = k.replace("vision_tower.blocks.", "vision_tower.encoder.layers.")
if "vision_tower.patch_embed.proj." in k:
k = k.replace("vision_tower.patch_embed.proj.", "vision_tower.embeddings.patch_embedding.")

# do not load attn_pool
if "attn_pool." in k:
k = k.replace("attn_pool.", "head.")

if ".norm2." in k:
k = k.replace(".norm2.", ".layer_norm2.")
if ".norm1." in k:
Expand Down Expand Up @@ -169,7 +164,7 @@ def sanitize(weights):
if k.startswith("vision_model."):
k = replace_vision_model_func(k)
# Skip attn_pool
if k.startswith("vision_model.vision_tower.head."):
if k.startswith("vision_model.vision_tower.attn_pool."):
continue
if k.startswith("language_model."):
k = k.replace("language_model.model.", "")
Expand All @@ -179,6 +174,7 @@ def sanitize(weights):
if "weight" in k and len(v.shape) == 4:
# [out_ch, in_ch, h, w] -> [out_ch, h, w, in_ch]
v = v.transpose(0, 2, 3, 1)
# Skip encoder
if "encoder" in k:
continue
if ".quant_conv" in k:
Expand Down
Loading

0 comments on commit 020add1

Please sign in to comment.