Multimodal improve #951

Open · wants to merge 9 commits into main

Changes from all commits

76 changes: 76 additions & 0 deletions lightllm/common/image_cache_manager.py
@@ -0,0 +1,76 @@
from collections import OrderedDict
from lightllm.utils.dist_utils import get_current_device_id


class ImageCacheManager:
def __init__(self):
"""
Initialize the image cache manager with a simple GPU cache and an LRU CPU cache.
"""
self._gpu_cache = dict()
self._cpu_cache = OrderedDict()
Comment on lines +10 to +11 (critical)

The _max_size attribute is used in the filter method but is not initialized in the __init__ method, which can lead to an AttributeError if filter() is called before set_max_size().

Suggested change
self._gpu_cache = dict()
self._cpu_cache = OrderedDict()
self._max_size = 0
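
For reference, the failure mode is easy to reproduce with the class exactly as written above (illustrative snippet; the uuid and tensor shape are made up):

```python
import torch

from lightllm.common.image_cache_manager import ImageCacheManager

mgr = ImageCacheManager()
mgr.set_embed("img-1", torch.zeros(4, 8))
# filter() reaches the CPU-cache size check before set_max_size() was ever called:
mgr.filter(["img-1"])  # AttributeError: 'ImageCacheManager' object has no attribute '_max_size'
```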


def set_max_size(self, max_size: int):
"""
Set the maximum number of items to keep in the CPU cache.
:param max_size: Maximum number of items to keep in the CPU cache.
"""
if max_size <= 0:
raise ValueError("max_size must be greater than 0")
self._max_size = max_size

def set_embed(self, uuid, embed):
"""
Store the embedding for the given uuid in the GPU cache.
:param uuid: Unique identifier for the image
:param embed: Embedding vector for the image (on GPU)
"""
self._gpu_cache[uuid] = embed

def get_embed(self, uuid):
"""
Retrieve the embedding for the given uuid. Prefer the GPU cache;
otherwise move the CPU-cached embedding to the current GPU device and return it.
:param uuid: Unique identifier for the image
:return: Embedding vector (on GPU if possible, else move from CPU to GPU)
"""
if uuid in self._gpu_cache:
return self._gpu_cache[uuid]
elif uuid in self._cpu_cache:
self._cpu_cache.move_to_end(uuid)
embed = self._cpu_cache[uuid].cuda(get_current_device_id())
return embed
return None

def query_embed(self, uuid):
"""
Query if the embedding for the given uuid is in the cache.
:param uuid: Unique identifier for the image
:return: True if the embedding is in the cache, False otherwise
"""
return uuid in self._gpu_cache or uuid in self._cpu_cache

def filter(self, uuid_list):
"""
Given a list of uuids, move their embeddings from the GPU cache to the LRU CPU cache if present,
evicting the least recently used CPU entries once the cache exceeds its maximum size.
:param uuid_list: List of uuids
"""
for uuid in uuid_list:
if uuid in self._gpu_cache:
embed_cpu = self._gpu_cache[uuid].cpu()
# Move to CPU cache and remove from GPU cache
self._gpu_cache.pop(uuid)
if uuid in self._cpu_cache:
self._cpu_cache.move_to_end(uuid)
self._cpu_cache[uuid] = embed_cpu
if len(self._cpu_cache) > self._max_size:
self._cpu_cache.popitem(last=False)
elif uuid in self._cpu_cache:
self._cpu_cache.move_to_end(uuid)
print(self._gpu_cache.keys())
print(self._cpu_cache.keys())
Comment on lines +71 to +72 (medium)

These print statements are likely for debugging and should be removed or replaced with proper logging.
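
A possible replacement, sketched with the standard logging module (lightllm's own logger helper could be used instead if it is importable here; that import path is not shown in this diff, so it is an assumption):

```python
import logging

logger = logging.getLogger(__name__)


def _log_cache_state(gpu_cache, cpu_cache):
    # Drop-in replacement for the two print() calls at the end of filter().
    logger.debug(
        "gpu cache keys: %s, cpu cache keys: %s",
        list(gpu_cache.keys()),
        list(cpu_cache.keys()),
    )
```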

return


image_cache_manager = ImageCacheManager()
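
For orientation, a minimal usage sketch of the cache manager defined above (the uuid, tensor shape, and cache size are made up, and a CUDA device is assumed to be available):

```python
import torch

from lightllm.common.image_cache_manager import image_cache_manager

image_cache_manager.set_max_size(8)            # CPU LRU cache keeps at most 8 embeddings

uuid = "img-0001"                              # hypothetical image uuid
embed = torch.randn(256, 4096, device="cuda")  # hypothetical (token_num, hidden_dim) embedding

image_cache_manager.set_embed(uuid, embed)     # cached on GPU right after the ViT forward
assert image_cache_manager.query_embed(uuid)

# After prefill, demote the embedding to the CPU LRU cache to free GPU memory.
image_cache_manager.filter([uuid])

# A later request for the same image pulls it back onto the current GPU device.
restored = image_cache_manager.get_embed(uuid)
assert restored.is_cuda
```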
@@ -3,6 +3,9 @@
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight

from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight
from lightllm.models.vit.model import VisionTransformer
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.image_cache_manager import image_cache_manager


# add key: language_model.xxx -> xxx
@@ -15,9 +18,45 @@ def rename_weight_keys(weights):
weights[k[len(prefix) :]] = weights[k]


class InternVLPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
# if we don't assign an extra process for the visual model, we need to initialize the visual model and image cache manager here
if get_env_start_args().disable_extra_process_for_multimodal:
kvargs = {
"weight_dir": get_env_start_args().model_dir,
"data_type": self.data_type_,
"quant_type": get_env_start_args().vit_quant_type,
"quant_cfg": get_env_start_args().vit_quant_cfg,
"max_batch_size": get_env_start_args().visual_infer_batch_size,
}
self.visual_model = VisionTransformer(
kvargs=kvargs,
)
image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2)
return

def load_hf_weights(self, weights):
rename_weight_keys(weights)
super().load_hf_weights(weights)


class InternVLPhi3PreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
# if we don't assign an extra process for the visual model, we need to initialize the visual model and image cache manager here
if get_env_start_args().disable_extra_process_for_multimodal:
kvargs = {
"weight_dir": get_env_start_args().model_dir,
"data_type": self.data_type_,
"quant_type": get_env_start_args().vit_quant_type,
"quant_cfg": get_env_start_args().vit_quant_cfg,
"max_batch_size": get_env_start_args().visual_infer_batch_size,
}
self.visual_model = VisionTransformer(
kvargs=kvargs,
)
image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2)
return

def load_hf_weights(self, weights):
@@ -29,6 +68,19 @@ def load_hf_weights(self, weights):
class InternVLInternlm2PreAndPostLayerWeight(Internlm2PreAndPostLayerWeight):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
# if we don't assign an extra process for the visual model, we need to initialize the visual model and image cache manager here
if get_env_start_args().disable_extra_process_for_multimodal:
kvargs = {
"weight_dir": get_env_start_args().model_dir,
"data_type": self.data_type_,
"quant_type": get_env_start_args().vit_quant_type,
"quant_cfg": get_env_start_args().vit_quant_cfg,
"max_batch_size": get_env_start_args().visual_infer_batch_size,
}
self.visual_model = VisionTransformer(
kvargs=kvargs,
)
image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2)
return

def load_hf_weights(self, weights):
@@ -40,6 +92,19 @@ def load_hf_weights(self, weights):
class InternVLLlamaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
# if we don't assign an extra process for the visual model, we need to initialize the visual model and image cache manager here
if get_env_start_args().disable_extra_process_for_multimodal:
kvargs = {
"weight_dir": get_env_start_args().model_dir,
"data_type": self.data_type_,
"quant_type": get_env_start_args().vit_quant_type,
"quant_cfg": get_env_start_args().vit_quant_cfg,
"max_batch_size": get_env_start_args().visual_infer_batch_size,
}
self.visual_model = VisionTransformer(
kvargs=kvargs,
)
image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2)
return

def load_hf_weights(self, weights):
27 changes: 25 additions & 2 deletions lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
@@ -6,7 +6,9 @@

from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
from lightllm.utils.infer_utils import mark_cost_time
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed
from lightllm.common.image_cache_manager import image_cache_manager
from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb
from lightllm.distributed.communication_op import all_reduce

@@ -29,8 +31,24 @@
class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer):
def __init__(self, network_config, mode):
super().__init__(network_config, mode)
self.disable_extra_process_for_multimodal = get_env_start_args().disable_extra_process_for_multimodal
return

def _infer_image_embeds(self, infer_state, layer_weight):
if not self.disable_extra_process_for_multimodal:
return
infer_images = []
for _, p in enumerate(infer_state.multimodal_params):
for img in p["images"] + p["audios"]:
if (img["_prefill_"] is True) and (not image_cache_manager.query_embed(img["uuid"])):
infer_images.append(img)
if len(infer_images) > 0:
infer_batch_size = get_env_start_args().visual_infer_batch_size
for i in range(0, len(infer_images), infer_batch_size):
img_embeds, uuids, valid_ids = layer_weight.visual_model.encode(infer_images[i : i + infer_batch_size])
for uuid, valid_id in zip(uuids, valid_ids):
image_cache_manager.set_embed(uuid, img_embeds[valid_id[0] : valid_id[1]])
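
The loop above relies on VisionTransformer.encode() returning a concatenated embedding tensor plus, for each image, its uuid and a (start, end) row range. A self-contained sketch of that slicing step under those assumptions (shapes and uuids are made up, not real encode() output):

```python
import torch

# Stand-ins for encode()'s assumed outputs: two images with 256 and 64 visual tokens.
img_embeds = torch.randn(256 + 64, 4096)
uuids = ["img-a", "img-b"]
valid_ids = [(0, 256), (256, 320)]

# Each image's embedding is the corresponding row slice of the concatenated tensor.
per_image = {u: img_embeds[s:e] for u, (s, e) in zip(uuids, valid_ids)}
assert per_image["img-a"].shape == (256, 4096)
assert per_image["img-b"].shape == (64, 4096)
```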

def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):

img_weight = []
@@ -45,14 +63,19 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei

infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids)

self._infer_image_embeds(infer_state, layer_weight)
for batch_id, p in enumerate(infer_state.multimodal_params):
for img in p["images"] + p["audios"]:
# skip the same image
if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
continue
# pull the img_embeds by uid from shm
data = read_shm(get_shm_name_embed(img["uuid"]))
img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1))
if self.disable_extra_process_for_multimodal:
img_embed = image_cache_manager.get_embed(img["uuid"])
img_weight.append(img_embed.reshape(img["token_num"], -1))
Comment on lines +74 to +75 (critical)

image_cache_manager.get_embed() can return None. Handle this case to prevent a potential AttributeError on img_embed.reshape().

Suggested change
img_embed = image_cache_manager.get_embed(img["uuid"])
if img_embed is None:
    raise ValueError(f"Image embedding for uuid {img['uuid']} not found in cache.")
img_weight.append(img_embed.reshape(img["token_num"], -1))

else:
data = read_shm(get_shm_name_embed(img["uuid"]))
img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1))
img_start_token_ids.append(img["token_id"])
img_token_lens.append(img["token_num"])
img_start_locs.append(img_start_loc)
@@ -11,7 +11,9 @@
MultiROWMMWeight,
TpNormWeight,
)
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.utils.dist_utils import (
get_current_device_id,
)


class ViTTransformerLayerWeight(TransformerLayerWeight):
21 changes: 18 additions & 3 deletions lightllm/models/vit/model.py
@@ -18,7 +18,8 @@
from io import BytesIO
from rpyc.utils.classic import obtain
from lightllm.common.quantization import Quantcfg
from lightllm.utils.dist_utils import get_dp_world_size
from lightllm.utils.dist_utils import get_dp_world_size, get_global_world_size
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


@@ -47,6 +48,7 @@ def __init__(self, kvargs):
self.quant_cfg_path = kvargs.get("quant_cfg", None)
self.load_image_func = get_load_image_func(self.weight_dir_)
self.max_batch_size = kvargs.get("max_batch_size", 1)
self.enable_tensor_cache = not get_env_start_args().disable_extra_process_for_multimodal

self._init_datatype()
self._init_config()
@@ -63,13 +65,15 @@ def _check_max_len_infer(self):
disable_check_max_len_infer = os.getenv("DISABLE_CHECK_MAX_LEN_INFER", None) is not None
if disable_check_max_len_infer:
return
self.enable_tensor_cache = True

try:
dummy_images = torch.randn(
(self.MAX_PATH_NUM * self.max_batch_size, 3, self.IMAGE_H, self.IMAGE_W), dtype=self.data_type
).cuda()
all_img_embeds = self.forward(dummy_images)
del all_img_embeds
del dummy_images
logger.info(f"vit check max_len {self.max_batch_size} infer ok")
except (RuntimeError, torch.OutOfMemoryError) as e:
logger.exception(str(e))
@@ -78,6 +82,7 @@
)
logger.error(exception_str)
raise Exception(exception_str)
self.enable_tensor_cache = not get_env_start_args().disable_extra_process_for_multimodal
return

def _init_config(self):
@@ -150,6 +155,8 @@ def _init_infer_layer(self):
return

def _init_datatype(self):
if isinstance(self.data_type, torch.dtype):
return
if self.data_type in ["fp16", "float16"]:
self.data_type = torch.float16
elif self.data_type in ["bf16", "bfloat16"]:
@@ -161,12 +168,14 @@

@torch.no_grad()
def forward(self, pixel_values):
g_cache_manager.cache_env_in()
if self.enable_tensor_cache:
g_cache_manager.cache_env_in()
input_embs = self.pre_infer.forward(pixel_values, self.pre_post_weight)
for i in range(self.layers_num + self.select_layer + 1):
input_embs = self.layers_infer[i].forward(input_embs, self.trans_layers_weight[i])
input_embs = self.post_infer.forward(input_embs[:, 1:, :], self.pre_post_weight)
g_cache_manager.cache_env_out()
if self.enable_tensor_cache:
g_cache_manager.cache_env_out()
return input_embs

@torch.no_grad()
@@ -182,6 +191,12 @@ def encode(self, images: List[ImageItem]):
image_data = Image.open(BytesIO(image_data))
t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"])
img_tensors.append(t)
elif isinstance(img, dict):
uuids.append(img["uuid"])
image_data = read_shm(get_shm_name_data(img["uuid"]))
image_data = Image.open(BytesIO(image_data))
t = self.load_image_func(image_data, max_num=img["extra_params"]["image_patch_max_num"])
img_tensors.append(t)
Comment on lines +194 to +199 (medium)

The elif block is nearly identical to the preceding if block. Extract the common logic into a helper function to reduce code duplication.
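
A possible shape for that refactor (a sketch only; the helper name is made up, and it assumes the imports already present in this module, i.e. read_shm, get_shm_name_data, Image, BytesIO, and self.load_image_func):

```python
def _load_image_tensor(self, uuid, extra_params):
    # Shared logic for ImageItem objects and plain dict image descriptors.
    image_bytes = read_shm(get_shm_name_data(uuid))
    image = Image.open(BytesIO(image_bytes))
    return self.load_image_func(image, max_num=extra_params["image_patch_max_num"])

# encode() would then collapse to:
#     if isinstance(img, ImageItem):
#         uuids.append(img.uuid)
#         img_tensors.append(self._load_image_tensor(img.uuid, img.extra_params))
#     elif isinstance(img, dict):
#         uuids.append(img["uuid"])
#         img_tensors.append(self._load_image_tensor(img["uuid"], img["extra_params"]))
```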

else:
raise Exception("Unsupport input types: {} for {}".format(type(img), img))

5 changes: 5 additions & 0 deletions lightllm/server/api_cli.py
@@ -240,6 +240,11 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--enable_multimodal", action="store_true", help="Whether or not to allow to load additional visual models."
)
parser.add_argument(
"--disable_extra_process_for_multimodal",
action="store_true",
help="Whether or not to disable extra process for multimodal.",
)
parser.add_argument(
"--enable_multimodal_audio",
action="store_true",
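
The new flag is a plain store_true switch, so it defaults to False and only changes behavior when passed explicitly. A quick self-contained check of that behavior with a stripped-down parser (illustrative only; the real parser lives in lightllm/server/api_cli.py):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable_multimodal", action="store_true")
parser.add_argument("--disable_extra_process_for_multimodal", action="store_true")

args = parser.parse_args(["--enable_multimodal", "--disable_extra_process_for_multimodal"])
assert args.enable_multimodal is True
assert args.disable_extra_process_for_multimodal is True

args_default = parser.parse_args(["--enable_multimodal"])
assert args_default.disable_extra_process_for_multimodal is False
```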
4 changes: 2 additions & 2 deletions lightllm/server/api_start.py
@@ -258,7 +258,7 @@ def normal_or_p_d_start(args):
],
start_args=[(cache_port, args)],
)
if args.enable_multimodal_audio:
if args.enable_multimodal_audio and not args.disable_extra_process_for_multimodal:
from .audioserver.manager import start_audio_process

process_manager.start_submodule_processes(
@@ -278,7 +278,7 @@
],
)

else:
elif not args.disable_extra_process_for_multimodal:
process_manager.start_submodule_processes(
start_funcs=[
start_visual_process,
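
Taken together, the two conditions above mean that when --disable_extra_process_for_multimodal is set, neither the visual nor the audio server process is launched and image embedding runs inside the inference processes. A condensed sketch of that decision (a simplification, not a literal copy of normal_or_p_d_start; it assumes the audio branch also starts the visual server, as in the pre-existing code):

```python
def plan_extra_multimodal_processes(args):
    """Return the extra multimodal server processes that will be spawned."""
    if not args.enable_multimodal or args.disable_extra_process_for_multimodal:
        return []                    # embeddings are computed inside the inference process
    if args.enable_multimodal_audio:
        return ["audio", "visual"]   # assumption: the audio branch launches both managers
    return ["visual"]
```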
10 changes: 6 additions & 4 deletions lightllm/server/httpserver/manager.py
@@ -81,10 +81,12 @@ def __init__(
)

self.enable_multimodal = enable_multimodal
self.disable_extra_process_for_multimodal = args.disable_extra_process_for_multimodal
if self.enable_multimodal:
self.cache_client = rpyc.connect("localhost", cache_port)
self.send_to_visual = context.socket(zmq.PUSH)
self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}")
if not self.disable_extra_process_for_multimodal:
self.send_to_visual = context.socket(zmq.PUSH)
self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}")

self.shm_req_manager = ShmReqManager()

@@ -449,7 +451,7 @@ async def transfer_to_next_module(
):

if self.pd_mode == NodeRole.P:
if self.enable_multimodal:
if self.enable_multimodal and not self.disable_extra_process_for_multimodal:
self.send_to_visual.send_pyobj(
group_req_objs.to_group_req_index(),
protocol=pickle.HIGHEST_PROTOCOL,
@@ -470,7 +472,7 @@ async def transfer_to_next_module(
return

if self.pd_mode == NodeRole.NORMAL:
if self.enable_multimodal:
if self.enable_multimodal and not self.disable_extra_process_for_multimodal:
self.send_to_visual.send_pyobj(
group_req_objs.to_group_req_index(),
protocol=pickle.HIGHEST_PROTOCOL,
1 change: 1 addition & 0 deletions lightllm/server/multimodal_params.py
@@ -120,6 +120,7 @@ def to_dict(self):
ret["uuid"] = self.uuid
ret["token_id"] = self.token_id
ret["token_num"] = self.token_num
ret["extra_params"] = self.extra_params
return ret

def to_origin_dict(self):