diff --git a/docs/CN/source/tutorial/api_server_args_zh.rst b/docs/CN/source/tutorial/api_server_args_zh.rst
index d7c055ef4..f40528bba 100755
--- a/docs/CN/source/tutorial/api_server_args_zh.rst
+++ b/docs/CN/source/tutorial/api_server_args_zh.rst
@@ -274,10 +274,6 @@ attention类型选择参数
 
     多模态资源的缓存服务器容量,默认为 ``200``
 
-.. option:: --cache_reserved_ratio
-
-    缓存服务器清理后的保留容量比例,默认为 ``0.5``
-
 .. option:: --visual_infer_batch_size
 
     每次推理批次中处理的图像数量,默认为 ``1``
diff --git a/docs/EN/source/tutorial/api_server_args_zh.rst b/docs/EN/source/tutorial/api_server_args_zh.rst
index 3b25ae85c..629d34bf8 100755
--- a/docs/EN/source/tutorial/api_server_args_zh.rst
+++ b/docs/EN/source/tutorial/api_server_args_zh.rst
@@ -273,10 +273,6 @@ Multimodal Parameters
 
    Cache server capacity for multimodal resources, default is ``200``
 
-.. option:: --cache_reserved_ratio
-
-   Reserved capacity ratio after cache server cleanup, default is ``0.5``
-
 .. option:: --visual_infer_batch_size
 
    Number of images processed in each inference batch, default is ``1``
diff --git a/lightllm/models/whisper/whisper_audio.py b/lightllm/models/whisper/whisper_audio.py
index 4a96efbf1..c5959ea1e 100644
--- a/lightllm/models/whisper/whisper_audio.py
+++ b/lightllm/models/whisper/whisper_audio.py
@@ -11,7 +11,7 @@ from transformers.processing_utils import ProcessorMixin
 from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed
 from lightllm.server.multimodal_params import AudioItem
-
+from rpyc.utils.classic import obtain
 
 # tokenizer_class removed
 class WhisperProcessor(ProcessorMixin):
@@ -89,7 +89,7 @@ def __init__(self, kvargs):
         self.sampling_rate = 16000
         self.max_length = self.max_seconds * self.sampling_rate
         self.cache_port = kvargs["cache_port"]
-        self.cache_client = rpyc.connect("localhost", self.cache_port)
+        self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
         data_type = kvargs["data_type"]
         if data_type in ["bf16", "bfloat16"]:
             self.data_type = torch.bfloat16
@@ -190,8 +190,13 @@ def encode(self, audio_items: List[AudioItem]):
         audio_lens_after_cnn = np.array(audio_lens_after_cnn, dtype=np.int32)
         audio_token_num = (audio_lens_after_cnn - 2) // 2 + 1
 
-        for i in range(len(uuids)):
-            if not self.cache_client.root.get_item_embed(uuids[i]):
+        ready_audio = obtain(self.cache_client.root.get_items_embed(uuids))
+        ids_to_set = []
+        for i, ready in enumerate(ready_audio):
+            if not ready:
+                uid = uuids[i]
                 cur_embed_bytes = tensor2bytes(audios[i][: audio_token_num[i]])
-                create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes)
-                self.cache_client.root.set_item_embed(uuids[i])
+                create_shm(get_shm_name_embed(uid), cur_embed_bytes)
+                ids_to_set.append(uid)
+        if ids_to_set:
+            self.cache_client.root.set_items_embed(ids=ids_to_set)
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index c0ccd7a1c..ee5518ea8 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -288,9 +288,6 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources"
     )
-    parser.add_argument(
-        "--cache_reserved_ratio", type=float, default=0.5, help="cache server reserved capacity ratio after clear"
-    )
     parser.add_argument(
         "--data_type",
         type=str,
diff --git a/lightllm/server/audioserver/manager.py b/lightllm/server/audioserver/manager.py
index abd1416d1..707fd11d0 100644
--- a/lightllm/server/audioserver/manager.py
+++ b/lightllm/server/audioserver/manager.py
@@ -14,6 +14,7 @@
 from lightllm.server.multimodal_params import AudioItem
 from .model_infer.model_rpc import start_model_process, AudioModelRpcClient
 from lightllm.utils.graceful_utils import graceful_registry
+from rpyc.utils.classic import obtain
 
 logger = init_logger(__name__)
 
@@ -33,7 +34,7 @@ def __init__(
         self.recv_from_visualserver = context.socket(zmq.PULL)
         self.recv_from_visualserver.bind(f"{args.zmq_mode}127.0.0.1:{audio_port}")
 
-        self.cache_client = rpyc.connect("localhost", cache_port)
+        self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
         self.cache_port = cache_port
         self.waiting_reqs: List[GroupReqIndexes] = []
         self.model_weightdir = args.model_dir
@@ -94,8 +95,11 @@ async def loop_for_fwd(self):
 
                     multimodal_params = group_req_indexes.multimodal_params
 
-                    for audio in multimodal_params.audios:
-                        if not self.cache_client.root.get_item_embed(audio.uuid):
+                    audio_uuids = [audio.uuid for audio in multimodal_params.audios]
+                    ready_audio = obtain(self.cache_client.root.get_items_embed(audio_uuids))
+
+                    for audio, ready in zip(multimodal_params.audios, ready_audio):
+                        if not ready:
                             audios_need_infer.append(audio)
 
                             if len(audios_need_infer) == self.infer_batch_size:
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index 15e344871..d4a205a15 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -57,7 +57,6 @@ class StartArgs:
     enable_decode_microbatch_overlap: bool = field(default=False)
     enable_prefill_microbatch_overlap: bool = field(default=False)
     cache_capacity: int = field(default=200)
-    cache_reserved_ratio: float = field(default=0.5)
     data_type: Optional[str] = field(
         default=None, metadata={"choices": ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]}
     )
diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py
index c03b084c4..5477be22b 100644
--- a/lightllm/server/embed_cache/impl/naive_memory_cache.py
+++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py
@@ -2,8 +2,7 @@
 import threading
 import dataclasses
 import requests
-from ..interface import CacheManager, CacheManagerFactory
-from typing import Union
+from typing import Union, Optional
 import torch
 import time
 from collections import deque
@@ -27,15 +26,12 @@ class Record(object):
     token_num: int
 
 
-@CacheManagerFactory.register("naive")
-class InMemoryCache(CacheManager):
+class InMemoryCache:
     def __init__(self, args) -> None:
         self.args = args
         self._records = dict()
         self._md5_to_record = dict()
         self.capacity = max(1, args.cache_capacity)
-        self.reserved = max(0, int(self.capacity * args.cache_reserved_ratio))
-        self.reserved = min(self.reserved, self.capacity - 1)
         self.occupied = 0
         self.expired_secs = 60 * 60
         self.lock = threading.Lock()
@@ -71,9 +67,9 @@ def _check_and_set_new_id_range(self, alloced_token_num):
                 time.sleep(3)
         return
 
-    def _clear(self):
+    def _clear(self, free_max_count: int):
         deleted = 0
-        max_delete = max(1, self.occupied - self.reserved)
+        max_delete = free_max_count
         items = sorted(self._records.items(), key=lambda x: x[1].visittime)
         t = time.time()
         for id, record in items:
@@ -89,57 +85,59 @@
                 if deleted >= max_delete:
                     break
 
-    def alloc(self, md5sum: str, token_num: int) -> dict:
+    def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[list[dict]]:
+        now = time.time()
         with self.lock:
-            t = time.time()
-            # add new record
-            if md5sum not in self._md5_to_record:
-
-                # full, need to clear some unused items
-                if self.occupied >= self.capacity:
-                    self._clear()
-                    if self.occupied >= self.capacity:
-                        return None
-
-                id = uuid.uuid1()
-                id = id.int
-                self._check_and_set_new_id_range(token_num)
-                record = Record(
-                    id=id,
-                    md5sum=md5sum,
-                    ref=1,
-                    data=False,
-                    embed=False,
-                    createtime=t,
-                    visittime=t,
-                    token_id=self.token_id_range_start,
-                    token_num=token_num,
-                )
-                self.token_id_range_start += token_num
-                self._records[id] = record
-                self._md5_to_record[md5sum] = record
-                self.occupied += 1
-
-            # cache hit
-            else:
-                record = self._md5_to_record[md5sum]
-                record.visittime = t
-                record.ref += 1
-
-            return {"id": record.id, "token_id": record.token_id, "token_num": record.token_num}
-
-    def release(self, id: int) -> None:
+            new_md5s = [m for m in md5sum_list if m not in self._md5_to_record]
+            new_needed = len(set(new_md5s))
+
+            if self.occupied + new_needed > self.capacity:
+                self._clear(free_max_count=new_needed - (self.capacity - self.occupied))
+                if self.occupied + new_needed > self.capacity:
+                    return None
+
+            results: list[dict] = []
+            for md5sum, token_num in zip(md5sum_list, token_num_list):
+                if md5sum in self._md5_to_record:
+                    rec = self._md5_to_record[md5sum]
+                    rec.visittime = now
+                    rec.ref += 1
+                else:
+                    uid_int = uuid.uuid1().int
+                    self._check_and_set_new_id_range(token_num)
+                    rec = Record(
+                        id=uid_int,
+                        md5sum=md5sum,
+                        ref=1,
+                        data=False,
+                        embed=False,
+                        createtime=now,
+                        visittime=now,
+                        token_id=self.token_id_range_start,
+                        token_num=token_num,
+                    )
+                    self.token_id_range_start += token_num
+                    self._records[uid_int] = rec
+                    self._md5_to_record[md5sum] = rec
+                    self.occupied += 1
+                results.append({"id": rec.id, "token_id": rec.token_id, "token_num": rec.token_num})
+            return results
+
+    def release(self, ids: list[int]) -> None:
         with self.lock:
-            self._records[id].ref -= 1
+            for id_ in ids:
+                self._records[id_].ref -= 1
 
-    def set_item_data(self, id: int) -> None:
-        self._records[id].data = True
+    def set_items_data(self, ids: list[int]) -> None:
+        for id_ in ids:
+            self._records[id_].data = True
 
-    def get_item_data(self, id: int) -> bool:
-        return self._records[id].data
+    def get_items_data(self, ids: list[int]) -> list[Optional[bool]]:
+        return [self._records.get(id_).data if id_ in self._records else False for id_ in ids]
 
-    def set_item_embed(self, id: int) -> None:
-        self._records[id].embed = True
+    def set_items_embed(self, ids: list[int]) -> None:
+        for id_ in ids:
+            self._records[id_].embed = True
 
-    def get_item_embed(self, id: int) -> bool:
-        return self._records[id].embed
+    def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]:
+        return [self._records.get(id_).embed if id_ in self._records else False for id_ in ids]
diff --git a/lightllm/server/embed_cache/interface.py b/lightllm/server/embed_cache/interface.py
deleted file mode 100644
index 030b59986..000000000
--- a/lightllm/server/embed_cache/interface.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from typing import Union
-
-class CacheManager(object):
-    ''' Defines the interface of embedding cache manager.
-    '''
-    def __init__(self) -> None:
-        pass
-
-    def alloc(self, md5sum: str, token_num: int) -> dict:
-        pass
-
-    def release(self, id: int) -> None:
-        pass
-
-    def set_item_data(self, id: int) -> None:
-        pass
-
-    def get_item_data(self, id: int) -> bool:
-        pass
-
-    def set_item_embed(self, id: int) -> None:
-        pass
-
-    def get_item_embed(self, id: int) -> bool:
-        pass
-
-
-class CacheManagerFactory(object):
-    _impls = dict()
-
-    @classmethod
-    def register(cls, target):
-        def add_register_item(key, value):
-            if not callable(value):
-                raise Exception(f"register object must be callable! But receice:{value} is not callable!")
-            if key in cls._impls:
-                print(f"warning: \033[33m{value.__name__} has been registered before, so we will overriden it\033[0m")
-            cls._impls[key] = value
-            return value
-
-        if callable(target):  # 如果传入的目标可调用,说明之前没有给出注册名字,我们就以传入的函数或者类的名字作为注册名
-            return add_register_item(target.__name__, target)
-        else:  # 如果不可调用,说明额外说明了注册的可调用对象的名字
-            return lambda x : add_register_item(target, x)
-
-    @classmethod
-    def get_impl(cls, name: str):
-        return cls._impls[name]
diff --git a/lightllm/server/embed_cache/manager.py b/lightllm/server/embed_cache/manager.py
index 85ed32505..566124142 100644
--- a/lightllm/server/embed_cache/manager.py
+++ b/lightllm/server/embed_cache/manager.py
@@ -1,14 +1,14 @@
 import rpyc
 import uuid
 import inspect
-from typing import Union
+from typing import Union, Optional
 from lightllm.utils.graceful_utils import graceful_registry
-from .interface import CacheManager
+from lightllm.server.embed_cache.impl.naive_memory_cache import InMemoryCache
 from rpyc.utils.classic import obtain
 
 
 class CacheServer(rpyc.Service):
-    def __init__(self, manager_impl: CacheManager) -> None:
+    def __init__(self, manager_impl: InMemoryCache) -> None:
         super().__init__()
         self._impl = manager_impl
 
@@ -22,41 +22,38 @@ def on_disconnect(self, conn):
         # (to finalize the service, if needed)
         pass
 
-    def exposed_alloc(self, md5sum: str, token_num: int) -> dict:
-        md5sum = obtain(md5sum)
-        token_num = obtain(token_num)
-        record = self._impl.alloc(md5sum, token_num)
+    def exposed_alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[list[dict]]:
+        md5sum_list = obtain(md5sum_list)
+        token_num_list = obtain(token_num_list)
+        record = self._impl.alloc(md5sum_list, token_num_list)
         return record
 
-    def exposed_release(self, id: int) -> None:
-        id = obtain(id)
-        return self._impl.release(id)
+    def exposed_release(self, ids: list[int]) -> None:
+        ids = obtain(ids)
+        return self._impl.release(ids)
 
-    def exposed_set_item_data(self, id: int) -> None:
-        id = obtain(id)
-        return self._impl.set_item_data(id=id)
+    def exposed_set_items_data(self, ids: list[int]) -> None:
+        ids = obtain(ids)
+        return self._impl.set_items_data(ids)
 
-    def exposed_get_item_data(self, id: int) -> bool:
-        id = obtain(id)
-        return self._impl.get_item_data(id=id)
+    def exposed_get_items_data(self, ids: list[int]) -> list[bool]:
+        ids = obtain(ids)
+        return self._impl.get_items_data(ids)
 
-    def exposed_set_item_embed(self, id: int) -> None:
-        id = obtain(id)
-        return self._impl.set_item_embed(id=id)
+    def exposed_set_items_embed(self, ids: list[int]) -> None:
+        ids = obtain(ids)
+        return self._impl.set_items_embed(ids)
 
-    def exposed_get_item_embed(self, id: int) -> bool:
-        id = obtain(id)
-        return self._impl.get_item_embed(id=id)
+    def exposed_get_items_embed(self, ids: list[int]) -> list[bool]:
+        ids = obtain(ids)
+        return self._impl.get_items_embed(ids)
 
 
 def start_cache_manager(port: int, args, pipe_writer):
     # 注册graceful 退出的处理
     graceful_registry(inspect.currentframe().f_code.co_name)
-    from .interface import CacheManagerFactory
-
-    manager_cls = CacheManagerFactory.get_impl("naive")
-    manager = manager_cls(args)
+    manager = InMemoryCache(args)
     service = CacheServer(manager)
     from rpyc.utils.server import ThreadedServer
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 72a33f5e0..96e48fb13 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -32,6 +32,7 @@
 from lightllm.utils.statics_utils import MovingAverage
 from lightllm.utils.config_utils import get_vocab_size
 from lightllm.utils.envs_utils import get_unique_server_name
+from rpyc.utils.classic import obtain
 
 logger = init_logger(__name__)
 
@@ -81,7 +82,7 @@ def __init__(
         self.enable_multimodal = enable_multimodal
         if self.enable_multimodal:
-            self.cache_client = rpyc.connect("localhost", cache_port)
+            self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
             self.send_to_visual = context.socket(zmq.PUSH)
             self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}")
 
@@ -113,33 +114,33 @@ def __init__(
         self.latest_success_infer_time_mark.set_value(int(time.time()))
         return
 
-    # connect cache server, calculate md5, alloc resource, return uuid
-    async def _alloc_resource(self, item: Union[ImageItem, AudioItem]):
-        if isinstance(item, ImageItem):
-            data = item.read()
-            # must after init_imageitem_extral_params
-            num_tokens = self.tokenizer.get_image_token_length(item)
-        elif isinstance(item, AudioItem):
-            data = item.read()
-            num_tokens = self.tokenizer.get_audio_token_length(item)
-        else:
-            raise ValueError(f"unexpected item type {type(item)}")
+    async def _alloc_resource(self, items, md5sums, token_nums, datas):
 
-        md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(item.extra_params)))
-        wait_time = 1
         while True:
-            record = self.cache_client.root.alloc(md5sum, num_tokens)
-            # hit or new
-            if record:
-                uid = record["id"]
-                if not self.cache_client.root.get_item_data(uid):
+            records = obtain(self.cache_client.root.alloc(md5sums, token_nums))
+
+            if records is None:
+                await asyncio.sleep(0.1)
+                continue
+
+            uid_list = []
+            for item, rec in zip(items, records):
+                item.uuid = rec["id"]
+                item.token_id = rec["token_id"]
+                item.token_num = rec["token_num"]
+                uid_list.append(rec["id"])
+
+            ready_flags = obtain(self.cache_client.root.get_items_data(uid_list))
+            update_data_ids = []
+
+            for uid, ready, data in zip(uid_list, ready_flags, datas):
+                if not ready:
                     create_shm(get_shm_name_data(uid), data)
-                    self.cache_client.root.set_item_data(uid)
-                return record
-            # cache full
-            else:
-                await asyncio.sleep(wait_time)
-                wait_time = min(wait_time + 2, 9)
+                    update_data_ids.append(uid)
+
+            if update_data_ids:
+                self.cache_client.root.set_items_data(update_data_ids)
+            return
@@ -148,38 +149,51 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, sampling_params: SamplingParams):
         # 只有 P 和 NORMAL 节点需要真的管理多模态资源
         if self.pd_mode.is_P_or_NORMAL():
             # 如果不加任何锁,假如请求1和请求2都有6张图片,而cache_capacity为10,
             # 那么如果某一时刻shm中存在请求1的5张图和请求2的5张图,将会资源竞争产生死锁。
             async with self._resource_lock:
+                items, md5sums, tokens_nums, datas = [], [], [], []
                 for img in multimodal_params.images:
                     self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params)
-                    record = await self._alloc_resource(img)
-                    img.uuid = record["id"]
-                    img.token_id = record["token_id"]
-                    img.token_num = record["token_num"]
+                    data = img.read()
+                    # must after init_imageitem_extral_params
+                    token_num = self.tokenizer.get_image_token_length(img)
+                    md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params)))
+                    md5sums.append(md5sum)
+                    tokens_nums.append(token_num)
+                    datas.append(data)
+                    items.append(img)
                 for audio in multimodal_params.audios:
                     self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, sampling_params)
-                    record = await self._alloc_resource(audio)
-                    audio.uuid = record["id"]
-                    audio.token_id = record["token_id"]
-                    audio.token_num = record["token_num"]
-        return
+                    data = audio.read()
+                    token_num = self.tokenizer.get_audio_token_length(audio)
+                    md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(audio.extra_params)))
+                    md5sums.append(md5sum)
+                    tokens_nums.append(token_num)
+                    datas.append(data)
+                    items.append(audio)
+
+                await self._alloc_resource(items, md5sums, tokens_nums, datas)
+            return
 
     async def _release_multimodal_resources(self, multimodal_params: MultimodalParams):
         # 只有 P 和 NORMAL 节点需要真的管理多模态资源
         if self.pd_mode.is_P_or_NORMAL():
             if multimodal_params is not None:
+                ids_to_release = []
                 for img in multimodal_params.images:
                     if img.uuid is not None:
-                        self.cache_client.root.release(img.uuid)
+                        ids_to_release.append(img.uuid)
                         # 将 uuid 等 赋值为 None, 防止因为abort等异常情况造成重复释放异常
                         img.uuid = None
                         img.token_id = None
                         img.token_num = None
                 for audio in multimodal_params.audios:
                     if audio.uuid is not None:
-                        self.cache_client.root.release(audio.uuid)
+                        ids_to_release.append(audio.uuid)
                         # 将 uuid 等 赋值为 None, 防止因为abort等异常情况造成重复释放异常
                         audio.uuid = None
                         audio.token_id = None
                         audio.token_num = None
+                if ids_to_release:
+                    self.cache_client.root.release(ids_to_release)
         return
 
     def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwargs=None):
diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py
index 6fabe2465..7a4557c47 100644
--- a/lightllm/server/visualserver/manager.py
+++ b/lightllm/server/visualserver/manager.py
@@ -15,6 +15,7 @@
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.process_check import start_parent_check_thread
+from rpyc.utils.classic import obtain
 
 logger = init_logger(__name__)
 
@@ -35,7 +36,7 @@ def __init__(
         self.recv_from_httpserver = context.socket(zmq.PULL)
         self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{visual_port}")
 
-        self.cache_client = rpyc.connect("localhost", cache_port)
+        self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
         self.cache_port = cache_port
         self.waiting_reqs: List[GroupReqIndexes] = []
         self.model_weightdir = args.model_dir
@@ -120,8 +121,11 @@ async def loop_for_fwd(self):
 
                     multimodal_params = group_req_indexes.multimodal_params
 
-                    for img in multimodal_params.images:
-                        if not self.cache_client.root.get_item_embed(img.uuid):
+                    img_uuids = [img.uuid for img in multimodal_params.images]
+                    ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids))
+
+                    for img, ready in zip(multimodal_params.images, ready_image):
+                        if not ready:
                             images_need_infer.append(img)
 
                             if len(images_need_infer) == self.infer_batch_size:
diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py
index 2d0d99a50..d2d45f2fd 100644
--- a/lightllm/server/visualserver/model_infer/model_rpc.py
+++ b/lightllm/server/visualserver/model_infer/model_rpc.py
@@ -38,7 +38,7 @@ def exposed_init_model(self, kvargs):
         self.cache_port = kvargs["cache_port"]
         weight_dir = kvargs["weight_dir"]
         self.vit_rank_id = kvargs["vit_rank_id"]
-        self.cache_client = rpyc.connect("localhost", self.cache_port)
+        self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
         self.data_type = kvargs["data_type"]
 
         init_vision_distributed_env(kvargs)
@@ -94,14 +94,20 @@ def exposed_encode(self, images: List[ImageItem]):
         images = obtain(images)
         all_img_embeds, uuids, valid_ids = self.forward(images)
         all_img_embeds = all_img_embeds.to(torch.device("cpu"))
+
         if self.tp_rank_id == 0:
-            for i in range(len(uuids)):
+            ready_flags = obtain(self.cache_client.root.get_items_embed(uuids))
+            ids_to_set = []
+            for i, ready in enumerate(ready_flags):
+                if ready:
+                    continue
                 uid = uuids[i]
-                if not self.cache_client.root.get_item_embed(uid):
-                    start, end = valid_ids[i]
-                    cur_embed_bytes = tensor2bytes(all_img_embeds[start:end])
-                    create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes)
-                    self.cache_client.root.set_item_embed(uuids[i])
+                start, end = valid_ids[i]
+                cur_embed_bytes = tensor2bytes(all_img_embeds[start:end])
+                create_shm(get_shm_name_embed(uid), cur_embed_bytes)
+                ids_to_set.append(uid)
+            if ids_to_set:
+                self.cache_client.root.set_items_embed(ids_to_set)
         return
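
Usage sketch (not part of the patch): the client-side flow over the batched cache RPC surface introduced above, roughly as the httpserver drives it. The port number and the md5/token-count values are illustrative assumptions; only alloc, get_items_data/set_items_data, get_items_embed/set_items_embed, and release (all list-based now) come from the diff itself.

# Minimal sketch, assuming a cache server is already listening on an arbitrary local port.
import rpyc
from rpyc.utils.classic import obtain

cache_client = rpyc.connect("localhost", 10004, config={"allow_pickle": True})

# One round trip allocates (or re-references) a record per multimodal item.
records = obtain(cache_client.root.alloc(["md5_img_0_extra", "md5_audio_0_extra"], [64, 32]))
if records is None:
    # Cache is full even after _clear(); the httpserver sleeps 0.1s and retries.
    pass
else:
    uids = [r["id"] for r in records]
    # Readiness flags also come back as one list per batch.
    data_ready = obtain(cache_client.root.get_items_data(uids))
    missing = [uid for uid, ok in zip(uids, data_ready) if not ok]
    if missing:
        # ...write the raw payloads to shared memory here, then publish them in one call.
        cache_client.root.set_items_data(missing)
    # Release every reference held by the request in a single RPC.
    cache_client.root.release(uids)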