diff --git a/lightllm/common/image_cache_manager.py b/lightllm/common/image_cache_manager.py new file mode 100644 index 000000000..fb04e4b59 --- /dev/null +++ b/lightllm/common/image_cache_manager.py @@ -0,0 +1,76 @@ +from collections import OrderedDict +from lightllm.utils.dist_utils import get_current_device_id + + +class ImageCacheManager: + def __init__(self): + """ + Initialize the image cache manager with a simple GPU cache and an LRU CPU cache. + """ + self._gpu_cache = dict() + self._cpu_cache = OrderedDict() + + def set_max_size(self, max_size: int): + """ + Set the maximum number of items to keep in the CPU cache. + :param max_size: Maximum number of items to keep in the CPU cache. + """ + if max_size <= 0: + raise ValueError("max_size must be greater than 0") + self._max_size = max_size + + def set_embed(self, uuid, embed): + """ + Store the embedding for the given uuid in the GPU cache. + :param uuid: Unique identifier for the image + :param embed: Embedding vector for the image (on GPU) + """ + self._gpu_cache[uuid] = embed + + def get_embed(self, uuid): + """ + Retrieve the embedding for the given uuid. Prefer GPU cache, + otherwise return CPU cache and move to GPU (simulate .cuda()). + :param uuid: Unique identifier for the image + :return: Embedding vector (on GPU if possible, else move from CPU to GPU) + """ + if uuid in self._gpu_cache: + return self._gpu_cache[uuid] + elif uuid in self._cpu_cache: + self._cpu_cache.move_to_end(uuid) + embed = self._cpu_cache[uuid].cuda(get_current_device_id()) + return embed + return None + + def query_embed(self, uuid): + """ + Query if the embedding for the given uuid is in the cache. + :param uuid: Unique identifier for the image + :return: True if the embedding is in the cache, False otherwise + """ + return uuid in self._gpu_cache or uuid in self._cpu_cache + + def filter(self, uuid_list): + """ + Given a list of uuids, move their embeddings from GPU cache to CPU cache if present, + and return a dict of those found in the cache and their embeddings (on CPU). 
+ :param uuid_list: List of uuids + """ + for uuid in uuid_list: + if uuid in self._gpu_cache: + embed_cpu = self._gpu_cache[uuid].cpu() + # Move to CPU cache and remove from GPU cache + self._gpu_cache.pop(uuid) + if uuid in self._cpu_cache: + self._cpu_cache.move_to_end(uuid) + self._cpu_cache[uuid] = embed_cpu + if len(self._cpu_cache) > self._max_size: + self._cpu_cache.popitem(last=False) + elif uuid in self._cpu_cache: + self._cpu_cache.move_to_end(uuid) + print(self._gpu_cache.keys()) + print(self._cpu_cache.keys()) + return + + +image_cache_manager = ImageCacheManager() diff --git a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py index f19563932..486319495 100644 --- a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py @@ -3,6 +3,9 @@ from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight +from lightllm.models.vit.model import VisionTransformer +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.common.image_cache_manager import image_cache_manager # add key: language_model.xxx -> xxx @@ -15,9 +18,45 @@ def rename_weight_keys(weights): weights[k[len(prefix) :]] = weights[k] +class InternVLPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + # if we don't assign an extra process for visual model, we need initialize the image cache manager here + if get_env_start_args().disable_extra_process_for_multimodal: + kvargs = { + "weight_dir": get_env_start_args().model_dir, + "data_type": self.data_type_, + "quant_type": get_env_start_args().vit_quant_type, + "quant_cfg": get_env_start_args().vit_quant_cfg, + "max_batch_size": get_env_start_args().visual_infer_batch_size, + } + self.visual_model = VisionTransformer( + kvargs=kvargs, + ) + image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2) + return + + def load_hf_weights(self, weights): + rename_weight_keys(weights) + super().load_hf_weights(weights) + + class InternVLPhi3PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + # if we don't assign an extra process for visual model, we need initialize the image cache manager here + if get_env_start_args().disable_extra_process_for_multimodal: + kvargs = { + "weight_dir": get_env_start_args().model_dir, + "data_type": self.data_type_, + "quant_type": get_env_start_args().vit_quant_type, + "quant_cfg": get_env_start_args().vit_quant_cfg, + "max_batch_size": get_env_start_args().visual_infer_batch_size, + } + self.visual_model = VisionTransformer( + kvargs=kvargs, + ) + image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2) return def load_hf_weights(self, weights): @@ -29,6 +68,19 @@ def load_hf_weights(self, weights): class InternVLInternlm2PreAndPostLayerWeight(Internlm2PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + # if we don't assign an extra process for visual model, we need initialize the image cache manager here + if get_env_start_args().disable_extra_process_for_multimodal: + kvargs = { + "weight_dir": 
get_env_start_args().model_dir, + "data_type": self.data_type_, + "quant_type": get_env_start_args().vit_quant_type, + "quant_cfg": get_env_start_args().vit_quant_cfg, + "max_batch_size": get_env_start_args().visual_infer_batch_size, + } + self.visual_model = VisionTransformer( + kvargs=kvargs, + ) + image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2) return def load_hf_weights(self, weights): @@ -40,6 +92,19 @@ def load_hf_weights(self, weights): class InternVLLlamaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + # if we don't assign an extra process for visual model, we need initialize the image cache manager here + if get_env_start_args().disable_extra_process_for_multimodal: + kvargs = { + "weight_dir": get_env_start_args().model_dir, + "data_type": self.data_type_, + "quant_type": get_env_start_args().vit_quant_type, + "quant_cfg": get_env_start_args().vit_quant_cfg, + "max_batch_size": get_env_start_args().visual_infer_batch_size, + } + self.visual_model = VisionTransformer( + kvargs=kvargs, + ) + image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2) return def load_hf_weights(self, weights): diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index b5b31a413..f4ac4c326 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -6,7 +6,9 @@ from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.utils.infer_utils import mark_cost_time +from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed +from lightllm.common.image_cache_manager import image_cache_manager from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce @@ -29,8 +31,24 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer): def __init__(self, network_config, mode): super().__init__(network_config, mode) + self.disable_extra_process_for_multimodal = get_env_start_args().disable_extra_process_for_multimodal return + def _infer_image_embeds(self, infer_state, layer_weight): + if not self.disable_extra_process_for_multimodal: + return + infer_images = [] + for _, p in enumerate(infer_state.multimodal_params): + for img in p["images"] + p["audios"]: + if (img["_prefill_"] is True) and (not image_cache_manager.query_embed(img["uuid"])): + infer_images.append(img) + if len(infer_images) > 0: + infer_batch_size = get_env_start_args().visual_infer_batch_size + for i in range(0, len(infer_images), infer_batch_size): + img_embeds, uuids, valid_ids = layer_weight.visual_model.encode(infer_images[i : i + infer_batch_size]) + for uuid, valid_id in zip(uuids, valid_ids): + image_cache_manager.set_embed(uuid, img_embeds[valid_id[0] : valid_id[1]]) + def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): img_weight = [] @@ -45,14 +63,19 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids) + self._infer_image_embeds(infer_state, layer_weight) for batch_id, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: # skip the same image if img["token_id"] in 
img_start_token_ids or img["_prefill_"] is False: continue # pull the img_embeds by uid from shm - data = read_shm(get_shm_name_embed(img["uuid"])) - img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1)) + if self.disable_extra_process_for_multimodal: + img_embed = image_cache_manager.get_embed(img["uuid"]) + img_weight.append(img_embed.reshape(img["token_num"], -1)) + else: + data = read_shm(get_shm_name_embed(img["uuid"])) + img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1)) img_start_token_ids.append(img["token_id"]) img_token_lens.append(img["token_num"]) img_start_locs.append(img_start_loc) diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index 55d73fa73..3c42f712e 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -11,7 +11,9 @@ MultiROWMMWeight, TpNormWeight, ) -from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.dist_utils import ( + get_current_device_id, +) class ViTTransformerLayerWeight(TransformerLayerWeight): diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 01bb69bdf..a8b475889 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -18,7 +18,8 @@ from io import BytesIO from rpyc.utils.classic import obtain from lightllm.common.quantization import Quantcfg -from lightllm.utils.dist_utils import get_dp_world_size +from lightllm.utils.dist_utils import get_dp_world_size, get_global_world_size +from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager @@ -47,6 +48,7 @@ def __init__(self, kvargs): self.quant_cfg_path = kvargs.get("quant_cfg", None) self.load_image_func = get_load_image_func(self.weight_dir_) self.max_batch_size = kvargs.get("max_batch_size", 1) + self.enable_tensor_cache = not get_env_start_args().disable_extra_process_for_multimodal self._init_datatype() self._init_config() @@ -63,6 +65,7 @@ def _check_max_len_infer(self): disable_check_max_len_infer = os.getenv("DISABLE_CHECK_MAX_LEN_INFER", None) is not None if disable_check_max_len_infer: return + self.enable_tensor_cache = True try: dummy_images = torch.randn( @@ -70,6 +73,7 @@ def _check_max_len_infer(self): ).cuda() all_img_embeds = self.forward(dummy_images) del all_img_embeds + del dummy_images logger.info(f"vit check max_len {self.max_batch_size} infer ok") except (RuntimeError, torch.OutOfMemoryError) as e: logger.exception(str(e)) @@ -78,6 +82,7 @@ def _check_max_len_infer(self): ) logger.error(exception_str) raise Exception(exception_str) + self.enable_tensor_cache = not get_env_start_args().disable_extra_process_for_multimodal return def _init_config(self): @@ -150,6 +155,8 @@ def _init_infer_layer(self): return def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return if self.data_type in ["fp16", "float16"]: self.data_type = torch.float16 elif self.data_type in ["bf16", "bfloat16"]: @@ -161,12 +168,14 @@ def _init_datatype(self): @torch.no_grad() def forward(self, pixel_values): - g_cache_manager.cache_env_in() + if self.enable_tensor_cache: + g_cache_manager.cache_env_in() input_embs = self.pre_infer.forward(pixel_values, self.pre_post_weight) for i in range(self.layers_num + self.select_layer + 1): input_embs = self.layers_infer[i].forward(input_embs, self.trans_layers_weight[i]) 
input_embs = self.post_infer.forward(input_embs[:, 1:, :], self.pre_post_weight) - g_cache_manager.cache_env_out() + if self.enable_tensor_cache: + g_cache_manager.cache_env_out() return input_embs @torch.no_grad() @@ -182,6 +191,12 @@ def encode(self, images: List[ImageItem]): image_data = Image.open(BytesIO(image_data)) t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"]) img_tensors.append(t) + elif isinstance(img, dict): + uuids.append(img["uuid"]) + image_data = read_shm(get_shm_name_data(img["uuid"])) + image_data = Image.open(BytesIO(image_data)) + t = self.load_image_func(image_data, max_num=img["extra_params"]["image_patch_max_num"]) + img_tensors.append(t) else: raise Exception("Unsupport input types: {} for {}".format(type(img), img)) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index d904c727f..7e1d3dc29 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -240,6 +240,11 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--enable_multimodal", action="store_true", help="Whether or not to allow to load additional visual models." ) + parser.add_argument( + "--disable_extra_process_for_multimodal", + action="store_true", + help="Whether or not to disable extra process for multimodal.", + ) parser.add_argument( "--enable_multimodal_audio", action="store_true", diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 6e6c27b5e..79787cfb7 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -258,7 +258,7 @@ def normal_or_p_d_start(args): ], start_args=[(cache_port, args)], ) - if args.enable_multimodal_audio: + if args.enable_multimodal_audio and not args.disable_extra_process_for_multimodal: from .audioserver.manager import start_audio_process process_manager.start_submodule_processes( @@ -278,7 +278,7 @@ def normal_or_p_d_start(args): ], ) - else: + elif not args.disable_extra_process_for_multimodal: process_manager.start_submodule_processes( start_funcs=[ start_visual_process, diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index fa455c225..967c716dd 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -81,10 +81,12 @@ def __init__( ) self.enable_multimodal = enable_multimodal + self.disable_extra_process_for_multimodal = args.disable_extra_process_for_multimodal if self.enable_multimodal: self.cache_client = rpyc.connect("localhost", cache_port) - self.send_to_visual = context.socket(zmq.PUSH) - self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") + if not self.disable_extra_process_for_multimodal: + self.send_to_visual = context.socket(zmq.PUSH) + self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") self.shm_req_manager = ShmReqManager() @@ -449,7 +451,7 @@ async def transfer_to_next_module( ): if self.pd_mode == NodeRole.P: - if self.enable_multimodal: + if self.enable_multimodal and not self.disable_extra_process_for_multimodal: self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, @@ -470,7 +472,7 @@ async def transfer_to_next_module( return if self.pd_mode == NodeRole.NORMAL: - if self.enable_multimodal: + if self.enable_multimodal and not self.disable_extra_process_for_multimodal: self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, diff --git a/lightllm/server/multimodal_params.py 
b/lightllm/server/multimodal_params.py index e3c1d19d2..5af5d5d95 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -120,6 +120,7 @@ def to_dict(self): ret["uuid"] = self.uuid ret["token_id"] = self.token_id ret["token_num"] = self.token_num + ret["extra_params"] = self.extra_params return ret def to_origin_dict(self): diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 10b68245c..551a70a22 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -9,6 +9,7 @@ from dataclasses import dataclass, field from typing import List, Dict, Tuple, Optional, Union, Any from lightllm.common.req_manager import ReqManager +from lightllm.common.image_cache_manager import image_cache_manager from lightllm.utils.infer_utils import mark_start, mark_end from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode @@ -131,6 +132,7 @@ def filter(self, finished_request_ids: List[int]): free_req_index = [] free_token_index = [] + image_uuid_list = [] for request_id in finished_request_ids: req: InferReq = self.requests_mapping.pop(request_id) group_req_id = convert_sub_id_to_group_id(req.shm_req.request_id) @@ -145,6 +147,10 @@ def filter(self, finished_request_ids: List[int]): # logger.info(f"infer release req id {req.shm_req.request_id}") req.shm_req.shm_infer_released = True self.shm_req_manager.put_back_req_obj(req.shm_req) + if req.multimodal_params is not None and get_env_start_args().disable_extra_process_for_multimodal: + for img in req.multimodal_params["images"]: + image_uuid_list.append(img["uuid"]) + image_cache_manager.filter(image_uuid_list) free_token_index = custom_cat(free_token_index) self.req_manager.free(free_req_index, free_token_index) diff --git a/unit_tests/common/basemodel/triton_kernel/test_multimodal_emb.py b/unit_tests/common/basemodel/triton_kernel/test_multimodal_emb.py index f5f7e903d..438eaa157 100755 --- a/unit_tests/common/basemodel/triton_kernel/test_multimodal_emb.py +++ b/unit_tests/common/basemodel/triton_kernel/test_multimodal_emb.py @@ -6,7 +6,7 @@ logger = init_logger(__name__) -def test_mark_multimodal_obj(): +def test_mark_mubltimodal_obj(): obj_start_ids = torch.tensor([1, 4, 100], device="cuda", dtype=torch.int64) obj_token_lens = torch.tensor([1, 3, 2], device="cuda", dtype=torch.int64) input_ids = torch.tensor([1, 7, 9, 333], device="cuda", dtype=torch.int64) diff --git a/unit_tests/models/llama/llama_gqa_decode_vsm.py b/unit_tests/models/llama/llama_gqa_decode_vsm.py new file mode 100644 index 000000000..f124a28eb --- /dev/null +++ b/unit_tests/models/llama/llama_gqa_decode_vsm.py @@ -0,0 +1,104 @@ +import unittest +import random +import torch +from tqdm import tqdm +from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.common.req_manager import ReqManager +from lightllm.models.llama.triton_kernel.gqa_flash_decoding_vsm import ( + gqa_token_decode_attention_flash_decoding_vsm, +) +from lightllm.models.llama.triton_kernel.gqa_flash_decoding import ( + gqa_token_decode_attention_flash_decoding, +) + + +class TestVSMGQADecoding(unittest.TestCase): + def test_vsm_gqa_decoding_align(self): + random.seed(0) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + torch.cuda.manual_seed_all(0) + torch.backends.cudnn.deterministic = True + 
torch.backends.cudnn.benchmark = False + + bs_list = [1, 8, 16, 32, 64, 128, 256] + group_size_list = [16, 32, 64] + seq_len_list = [128, 512, 1024, 2048, 4096, 8192] + q_head_dim_list = [64, 128] + q_head_num_list = [8, 16, 32] + + def get_test_configs(): + for bs in bs_list: + for group_size in group_size_list: + for seq_len_m in seq_len_list: + for q_head_dim in q_head_dim_list: + for q_head_num in q_head_num_list: + if q_head_num < group_size: + continue + yield bs, group_size, seq_len_m, q_head_dim, q_head_num + + for bs, group_size, seq_len_m, q_head_dim, q_head_num in tqdm(list(get_test_configs())): + kv_head_num = q_head_num // group_size + q_head_dim = q_head_dim + kv_head_dim = q_head_dim + seq_len = (torch.zeros(bs, dtype=torch.int32) + seq_len_m).to(torch.int32) + total_token_in_the_batch = seq_len.sum().item() + rounded_total_token_in_the_batch = (total_token_in_the_batch + 128 - 1) // 128 * 128 + + q_shape = [bs, q_head_num, q_head_dim] + kv_shape = [ + rounded_total_token_in_the_batch, + kv_head_num, + kv_head_dim, + ] + qkv_dtype = torch.float16 + + q, k, v = ( + torch.randn(q_shape, dtype=qkv_dtype, device="cuda"), + torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"), + torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"), + ) + q, k, v = q / 10, k / 10, v / 10 + + req_to_token_index = torch.zeros((bs, seq_len_m)) - 1 + token_index = torch.arange(rounded_total_token_in_the_batch) + + total_count = 0 + for i in range(bs): + req_to_token_index[i, : seq_len[i]] = token_index[total_count : total_count + seq_len[i]] + total_count += seq_len[i] + + req_to_token_index = req_to_token_index.long().cuda() + + b_req_idx = torch.arange(bs, device="cuda") + infer_state = InferStateInfo() + infer_state.req_manager = ReqManager(bs, 2048, None) + infer_state.req_manager.req_to_token_indexs = req_to_token_index + infer_state.b_req_idx = b_req_idx.cuda() + infer_state.b_seq_len = seq_len.cuda() + infer_state.max_len_in_batch = seq_len_m + infer_state.batch_size = bs + infer_state.q_head_num = q_head_num + infer_state.q_head_dim = q_head_dim + infer_state.kv_head_num = kv_head_num + infer_state.softmax_scale = 1 / (q_head_dim ** 0.5) + infer_state.total_token_num = torch.tensor([total_token_in_the_batch], dtype=torch.int32).cuda() + new_out = gqa_token_decode_attention_flash_decoding_vsm(q, k, v, infer_state) + old_out = gqa_token_decode_attention_flash_decoding( + q, + infer_state, + infer_state.q_head_num, + infer_state.q_head_dim, + k, + v, + ) + cos_sim = torch.nn.functional.cosine_similarity(new_out, old_out, dim=-1).mean().cpu().item() + self.assertGreaterEqual( + cos_sim, + 0.9, + f"bs={bs},group_size={group_size},seq_len={seq_len_m},q_head_dim={q_head_dim},q_head_num={q_head_num}", + ) + + +if __name__ == "__main__": + unittest.main()
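
Note on the new image cache: the demotion/eviction semantics of the ImageCacheManager added in lightllm/common/image_cache_manager.py can be hard to read straight out of the hunk, so below is a minimal, torch-free sketch of the same flow. It is an illustration only: plain strings stand in for the embedding tensors, the _LruCacheSketch class and the max_size=2 capacity are invented for the example, and the real class additionally moves demoted entries back to the GPU in get_embed() via .cuda(get_current_device_id()).

    from collections import OrderedDict


    class _LruCacheSketch:
        # Simplified stand-in for ImageCacheManager: a "hot" dict of entries
        # produced during prefill, plus a bounded LRU that holds entries
        # demoted when their requests finish.

        def __init__(self, max_size: int):
            if max_size <= 0:
                raise ValueError("max_size must be greater than 0")
            self._max_size = max_size
            self._gpu_cache = dict()         # hot entries written by set_embed()
            self._cpu_cache = OrderedDict()  # bounded LRU of demoted entries

        def set_embed(self, uuid, embed):
            self._gpu_cache[uuid] = embed

        def filter(self, uuid_list):
            # Demote finished requests' embeddings from the hot dict into the
            # LRU, evicting the least recently used entry once the cache grows
            # past max_size; already-demoted entries are just refreshed.
            for uuid in uuid_list:
                if uuid in self._gpu_cache:
                    self._cpu_cache[uuid] = self._gpu_cache.pop(uuid)
                    self._cpu_cache.move_to_end(uuid)
                    if len(self._cpu_cache) > self._max_size:
                        self._cpu_cache.popitem(last=False)
                elif uuid in self._cpu_cache:
                    self._cpu_cache.move_to_end(uuid)


    cache = _LruCacheSketch(max_size=2)
    for uid in ("img-a", "img-b", "img-c"):
        cache.set_embed(uid, f"embed-{uid}")
    cache.filter(["img-a", "img-b", "img-c"])
    print(list(cache._cpu_cache.keys()))  # ['img-b', 'img-c']; 'img-a' was evicted

In the flow wired up by this patch, set_max_size() is called from the InternVL pre/post layer weights with cache_capacity * 2, set_embed() is filled by _infer_image_embeds() during prefill in the qwen_vl pre-layer infer, and filter() is invoked from InferBatch.filter() with the image uuids of finished requests.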