Multimodal improve #951
Changes from all commits
d498aaf
691b89c
8624e8f
3e04c7f
06c38a0
d21ccaa
2ab9dfc
b0768e2
cbf93a0

New file: lightllm/common/image_cache_manager.py

@@ -0,0 +1,76 @@
from collections import OrderedDict
from lightllm.utils.dist_utils import get_current_device_id


class ImageCacheManager:
    def __init__(self):
        """
        Initialize the image cache manager with a simple GPU cache and an LRU CPU cache.
        """
        self._gpu_cache = dict()
        self._cpu_cache = OrderedDict()

    def set_max_size(self, max_size: int):
        """
        Set the maximum number of items to keep in the CPU cache.
        :param max_size: Maximum number of items to keep in the CPU cache.
        """
        if max_size <= 0:
            raise ValueError("max_size must be greater than 0")
        self._max_size = max_size

    def set_embed(self, uuid, embed):
        """
        Store the embedding for the given uuid in the GPU cache.
        :param uuid: Unique identifier for the image
        :param embed: Embedding vector for the image (on GPU)
        """
        self._gpu_cache[uuid] = embed

    def get_embed(self, uuid):
        """
        Retrieve the embedding for the given uuid. Prefer GPU cache,
        otherwise return CPU cache and move to GPU (simulate .cuda()).
        :param uuid: Unique identifier for the image
        :return: Embedding vector (on GPU if possible, else move from CPU to GPU)
        """
        if uuid in self._gpu_cache:
            return self._gpu_cache[uuid]
        elif uuid in self._cpu_cache:
            self._cpu_cache.move_to_end(uuid)
            embed = self._cpu_cache[uuid].cuda(get_current_device_id())
            return embed
        return None

    def query_embed(self, uuid):
        """
        Query if the embedding for the given uuid is in the cache.
        :param uuid: Unique identifier for the image
        :return: True if the embedding is in the cache, False otherwise
        """
        return uuid in self._gpu_cache or uuid in self._cpu_cache

    def filter(self, uuid_list):
        """
        Given a list of uuids, move their embeddings from GPU cache to CPU cache if present,
        and return a dict of those found in the cache and their embeddings (on CPU).
        :param uuid_list: List of uuids
        """
        for uuid in uuid_list:
            if uuid in self._gpu_cache:
                embed_cpu = self._gpu_cache[uuid].cpu()
                # Move to CPU cache and remove from GPU cache
                self._gpu_cache.pop(uuid)
                if uuid in self._cpu_cache:
                    self._cpu_cache.move_to_end(uuid)
                self._cpu_cache[uuid] = embed_cpu
                if len(self._cpu_cache) > self._max_size:
                    self._cpu_cache.popitem(last=False)
            elif uuid in self._cpu_cache:
                self._cpu_cache.move_to_end(uuid)
        print(self._gpu_cache.keys())
        print(self._cpu_cache.keys())
        return


image_cache_manager = ImageCacheManager()
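
A minimal usage sketch of the cache API added above, for review purposes only: the uuid and tensor shape are made up, a CUDA device is assumed, and the call site of filter() is not shown in this diff.

import torch

from lightllm.common.image_cache_manager import image_cache_manager

# The CPU-side LRU bound must be set before filter() is used.
image_cache_manager.set_max_size(8)

# After the in-process visual model encodes an image, keep its embedding on the GPU.
uuid = "img-0001"                      # hypothetical identifier
embed = torch.randn(256, 4096).cuda()  # hypothetical (token_num, hidden_dim) embedding
image_cache_manager.set_embed(uuid, embed)

# During prefill, check for and fetch the embedding (GPU hit, or CPU hit moved back to GPU).
if image_cache_manager.query_embed(uuid):
    img_embed = image_cache_manager.get_embed(uuid)

# Afterwards, demote the embedding from the GPU dict into the bounded CPU LRU cache.
image_cache_manager.filter([uuid])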

Changes to LlamaMultimodalPreLayerInfer:

@@ -6,7 +6,9 @@
 from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
 from lightllm.utils.infer_utils import mark_cost_time
+from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed
+from lightllm.common.image_cache_manager import image_cache_manager
 from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb
 from lightllm.distributed.communication_op import all_reduce


@@ -29,8 +31,24 @@
 class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer):
     def __init__(self, network_config, mode):
         super().__init__(network_config, mode)
+        self.disable_extra_process_for_multimodal = get_env_start_args().disable_extra_process_for_multimodal
         return

+    def _infer_image_embeds(self, infer_state, layer_weight):
+        if not self.disable_extra_process_for_multimodal:
+            return
+        infer_images = []
+        for _, p in enumerate(infer_state.multimodal_params):
+            for img in p["images"] + p["audios"]:
+                if (img["_prefill_"] is True) and (not image_cache_manager.query_embed(img["uuid"])):
+                    infer_images.append(img)
+        if len(infer_images) > 0:
+            infer_batch_size = get_env_start_args().visual_infer_batch_size
+            for i in range(0, len(infer_images), infer_batch_size):
+                img_embeds, uuids, valid_ids = layer_weight.visual_model.encode(infer_images[i : i + infer_batch_size])
+                for uuid, valid_id in zip(uuids, valid_ids):
+                    image_cache_manager.set_embed(uuid, img_embeds[valid_id[0] : valid_id[1]])
+
     def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):

         img_weight = []
@@ -45,14 +63,19 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei

         infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids)

+        self._infer_image_embeds(infer_state, layer_weight)
         for batch_id, p in enumerate(infer_state.multimodal_params):
             for img in p["images"] + p["audios"]:
                 # skip the same image
                 if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
                     continue
                 # pull the img_embeds by uid from shm
-                data = read_shm(get_shm_name_embed(img["uuid"]))
-                img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1))
+                if self.disable_extra_process_for_multimodal:
+                    img_embed = image_cache_manager.get_embed(img["uuid"])
+                    img_weight.append(img_embed.reshape(img["token_num"], -1))
+                else:
+                    data = read_shm(get_shm_name_embed(img["uuid"]))
+                    img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1))
                 img_start_token_ids.append(img["token_id"])
                 img_token_lens.append(img["token_num"])
                 img_start_locs.append(img_start_loc)

Changes to the visual (ViT) model:

@@ -18,7 +18,8 @@
 from io import BytesIO
 from rpyc.utils.classic import obtain
 from lightllm.common.quantization import Quantcfg
-from lightllm.utils.dist_utils import get_dp_world_size
+from lightllm.utils.dist_utils import get_dp_world_size, get_global_world_size
+from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


@@ -47,6 +48,7 @@ def __init__(self, kvargs):
         self.quant_cfg_path = kvargs.get("quant_cfg", None)
         self.load_image_func = get_load_image_func(self.weight_dir_)
         self.max_batch_size = kvargs.get("max_batch_size", 1)
+        self.enable_tensor_cache = not get_env_start_args().disable_extra_process_for_multimodal

         self._init_datatype()
         self._init_config()
@@ -63,13 +65,15 @@ def _check_max_len_infer(self):
         disable_check_max_len_infer = os.getenv("DISABLE_CHECK_MAX_LEN_INFER", None) is not None
         if disable_check_max_len_infer:
             return
+        self.enable_tensor_cache = True

         try:
             dummy_images = torch.randn(
                 (self.MAX_PATH_NUM * self.max_batch_size, 3, self.IMAGE_H, self.IMAGE_W), dtype=self.data_type
             ).cuda()
             all_img_embeds = self.forward(dummy_images)
             del all_img_embeds
+            del dummy_images
             logger.info(f"vit check max_len {self.max_batch_size} infer ok")
         except (RuntimeError, torch.OutOfMemoryError) as e:
             logger.exception(str(e))
@@ -78,6 +82,7 @@ def _check_max_len_infer(self):
             )
             logger.error(exception_str)
             raise Exception(exception_str)
+        self.enable_tensor_cache = not get_env_start_args().disable_extra_process_for_multimodal
         return

     def _init_config(self):
@@ -150,6 +155,8 @@ def _init_infer_layer(self):
         return

     def _init_datatype(self):
+        if isinstance(self.data_type, torch.dtype):
+            return
         if self.data_type in ["fp16", "float16"]:
             self.data_type = torch.float16
         elif self.data_type in ["bf16", "bfloat16"]:
@@ -161,12 +168,14 @@ def _init_datatype(self):

     @torch.no_grad()
     def forward(self, pixel_values):
-        g_cache_manager.cache_env_in()
+        if self.enable_tensor_cache:
+            g_cache_manager.cache_env_in()
         input_embs = self.pre_infer.forward(pixel_values, self.pre_post_weight)
         for i in range(self.layers_num + self.select_layer + 1):
             input_embs = self.layers_infer[i].forward(input_embs, self.trans_layers_weight[i])
         input_embs = self.post_infer.forward(input_embs[:, 1:, :], self.pre_post_weight)
-        g_cache_manager.cache_env_out()
+        if self.enable_tensor_cache:
+            g_cache_manager.cache_env_out()
         return input_embs

     @torch.no_grad()
@@ -182,6 +191,12 @@ def encode(self, images: List[ImageItem]):
                 image_data = Image.open(BytesIO(image_data))
                 t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"])
                 img_tensors.append(t)
+            elif isinstance(img, dict):
+                uuids.append(img["uuid"])
+                image_data = read_shm(get_shm_name_data(img["uuid"]))
+                image_data = Image.open(BytesIO(image_data))
+                t = self.load_image_func(image_data, max_num=img["extra_params"]["image_patch_max_num"])
+                img_tensors.append(t)
             else:
                 raise Exception("Unsupport input types: {} for {}".format(type(img), img))

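As a reading aid (not part of the diff): when the disable_extra_process_for_multimodal start argument is set, _infer_image_embeds() feeds plain dicts from infer_state.multimodal_params into encode(), which is what the new elif branch handles. Below is a hypothetical example of such an entry; the field names come from the diff, the values are made up, and encode() itself reads only uuid and extra_params, while the remaining fields are used by the pre-layer code.

# Hypothetical multimodal_params image entry (values illustrative only).
example_img = {
    "uuid": 123456789,                           # shared-memory / image-cache key
    "_prefill_": True,                           # only prefill-marked, uncached images get encoded
    "token_id": 151339,                          # placeholder token id in the prompt (made up)
    "token_num": 256,                            # number of embedding rows expected (made up)
    "extra_params": {"image_patch_max_num": 6},  # forwarded to load_image_func
}

# The new elif branch in encode() reads these two fields:
print(example_img["uuid"], example_img["extra_params"]["image_patch_max_num"])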

Review comment: The _max_size attribute is used in the filter method but is not initialized in the __init__ method, which can lead to an AttributeError if filter() is called before set_max_size().
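
One possible fix, sketched here under the assumption that a default capacity is acceptable: initialize _max_size in __init__ so filter() can never run before the bound exists.

from collections import OrderedDict


class ImageCacheManager:
    def __init__(self, max_size: int = 1024):  # assumed default capacity, not taken from the PR
        self._gpu_cache = dict()
        self._cpu_cache = OrderedDict()
        # Setting the LRU bound up front avoids an AttributeError if filter()
        # runs before set_max_size() has been called.
        self._max_size = max_size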