Skip to content

[perf] Batch rpyc calls in multimodal path -2 #973

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions lightllm/models/whisper/whisper_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(self, kvargs):
self.sampling_rate = 16000
self.max_length = self.max_seconds * self.sampling_rate
self.cache_port = kvargs["cache_port"]
self.cache_client = rpyc.connect("localhost", self.cache_port)
self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
data_type = kvargs["data_type"]
if data_type in ["bf16", "bfloat16"]:
self.data_type = torch.bfloat16
Expand Down Expand Up @@ -190,8 +190,13 @@ def encode(self, audio_items: List[AudioItem]):
audio_lens_after_cnn = np.array(audio_lens_after_cnn, dtype=np.int32)
audio_token_num = (audio_lens_after_cnn - 2) // 2 + 1

for i in range(len(uuids)):
if not self.cache_client.root.get_item_embed(uuids[i]):
ready_audio = self.cache_client.root.get_items_embed(uuids)
ids_to_set = []
for i, ready in enumerate(ready_audio):
if not ready:
uid = uuids[i]
cur_embed_bytes = tensor2bytes(audios[i][: audio_token_num[i]])
create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes)
self.cache_client.root.set_item_embed(uuids[i])
create_shm(get_shm_name_embed(uid), cur_embed_bytes)
ids_to_set.append(uid)
if ids_to_set:
self.cache_client.root.set_items_data(ids=ids_to_set)
9 changes: 6 additions & 3 deletions lightllm/server/audioserver/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(

self.recv_from_visualserver = context.socket(zmq.PULL)
self.recv_from_visualserver.bind(f"{args.zmq_mode}127.0.0.1:{audio_port}")
self.cache_client = rpyc.connect("localhost", cache_port)
self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
self.cache_port = cache_port
self.waiting_reqs: List[GroupReqIndexes] = []
self.model_weightdir = args.model_dir
Expand Down Expand Up @@ -94,8 +94,11 @@ async def loop_for_fwd(self):

multimodal_params = group_req_indexes.multimodal_params

for audio in multimodal_params.audios:
if not self.cache_client.root.get_item_embed(audio.uuid):
audio_uuids = [audio.uuid for audio in multimodal_params.audios]
ready_audio = self.cache_client.root.get_items_embed(audio_uuids)

for audio, ready in zip(multimodal_params.audios, ready_audio):
if not ready:
audios_need_infer.append(audio)

if len(audios_need_infer) == self.infer_batch_size:
Expand Down
104 changes: 52 additions & 52 deletions lightllm/server/embed_cache/impl/naive_memory_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import threading
import dataclasses
import requests
from ..interface import CacheManager, CacheManagerFactory
from typing import Union
from typing import Union, Optional
import torch
import time
from collections import deque
Expand All @@ -27,8 +26,7 @@ class Record(object):
token_num: int


@CacheManagerFactory.register("naive")
class InMemoryCache(CacheManager):
class InMemoryCache:
def __init__(self, args) -> None:
self.args = args
self._records = dict()
Expand Down Expand Up @@ -89,57 +87,59 @@ def _clear(self):
if deleted >= max_delete:
break

def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[list[dict]]:
    """Batch-allocate cache records for parallel lists of md5sums and token counts.

    For an md5sum already cached this is a hit: the record's visit time is
    refreshed and its ref count incremented.  Otherwise a new Record is
    created with a fresh uuid and a reserved token-id range.

    Returns one dict per input item with keys "id", "token_id" and
    "token_num", or None when the cache cannot hold all new items even
    after clearing unreferenced entries.
    """
    now = time.time()
    with self.lock:
        # Count *distinct* new md5sums: a duplicate md5 inside the same
        # batch only allocates one record (the second occurrence becomes a
        # cache hit once the first is inserted below), so counting the raw
        # list length would over-estimate and reject fillable batches.
        new_needed = len({m for m in md5sum_list if m not in self._md5_to_record})

        if self.occupied + new_needed > self.capacity:
            self._clear()
            if self.occupied + new_needed > self.capacity:
                return None  # still full after evicting unused items

        results: list[dict] = []
        # NOTE(review): zip silently truncates if the two lists differ in
        # length — assumes callers always pass parallel lists; confirm.
        for md5sum, token_num in zip(md5sum_list, token_num_list):
            if md5sum in self._md5_to_record:
                # cache hit: refresh LRU timestamp, take a reference
                rec = self._md5_to_record[md5sum]
                rec.visittime = now
                rec.ref += 1
            else:
                uid_int = uuid.uuid1().int
                self._check_and_set_new_id_range(token_num)
                rec = Record(
                    id=uid_int,
                    md5sum=md5sum,
                    ref=1,
                    data=False,
                    embed=False,
                    createtime=now,
                    visittime=now,
                    token_id=self.token_id_range_start,
                    token_num=token_num,
                )
                self.token_id_range_start += token_num
                self._records[uid_int] = rec
                self._md5_to_record[md5sum] = rec
                self.occupied += 1
            results.append({"id": rec.id, "token_id": rec.token_id, "token_num": rec.token_num})
        return results

def release(self, ids: list[int]) -> None:
    """Drop one reference from each record in ``ids`` (batched release)."""
    with self.lock:
        # avoid shadowing the builtin ``id``
        for record_id in ids:
            self._records[record_id].ref -= 1

def set_items_data(self, ids: list[int]) -> None:
    """Mark the data flag of every record in ``ids`` as set (batched)."""
    for record_id in ids:
        self._records[record_id].data = True

def get_items_data(self, ids: list[int]) -> list[bool]:
    """Return, per id, whether the record exists and its data flag is set.

    Unknown ids yield False rather than raising.  (The previous
    ``Optional[bool]`` annotation was misleading: None is never returned.)
    """
    records = self._records
    # single dict lookup per id instead of ``i in records`` plus ``.get(i)``
    return [rec.data if (rec := records.get(i)) is not None else False for i in ids]

def set_items_embed(self, ids: list[int]) -> None:
    """Mark the embed flag of every record in ``ids`` as set (batched)."""
    for record_id in ids:
        self._records[record_id].embed = True

def get_items_embed(self, ids: list[int]) -> list[bool]:
    """Return, per id, whether the record exists and its embed flag is set.

    Unknown ids yield False rather than raising.  (The previous
    ``Optional[bool]`` annotation was misleading: None is never returned.)
    """
    records = self._records
    # single dict lookup per id instead of ``i in records`` plus ``.get(i)``
    return [rec.embed if (rec := records.get(i)) is not None else False for i in ids]
48 changes: 0 additions & 48 deletions lightllm/server/embed_cache/interface.py

This file was deleted.

49 changes: 23 additions & 26 deletions lightllm/server/embed_cache/manager.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import rpyc
import uuid
import inspect
from typing import Union
from typing import Union, Optional
from lightllm.utils.graceful_utils import graceful_registry
from .interface import CacheManager
from lightllm.server.embed_cache.impl.naive_memory_cache import InMemoryCache
from rpyc.utils.classic import obtain


class CacheServer(rpyc.Service):
def __init__(self, manager_impl: InMemoryCache) -> None:
    """Wrap an InMemoryCache instance so its batched ops are exposed over rpyc."""
    super().__init__()
    self._impl = manager_impl

Expand All @@ -22,41 +22,38 @@ def on_disconnect(self, conn):
# (to finalize the service, if needed)
pass

def exposed_alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[list[dict]]:
    """RPC: batch-allocate cache records; see InMemoryCache.alloc.

    ``obtain`` materializes rpyc netrefs into local objects up front so the
    implementation does not trigger one remote round-trip per element.
    """
    md5sum_list = obtain(md5sum_list)
    token_num_list = obtain(token_num_list)
    return self._impl.alloc(md5sum_list, token_num_list)

def exposed_release(self, ids: list[int]) -> None:
    """RPC: drop one reference from each record in ``ids`` (batched)."""
    # obtain() avoids per-element netref round-trips in the implementation
    ids = obtain(ids)
    return self._impl.release(ids)

def exposed_set_items_data(self, ids: list[int]) -> None:
    """RPC: mark the data flag of every record in ``ids`` (batched)."""
    # obtain() avoids per-element netref round-trips in the implementation
    ids = obtain(ids)
    return self._impl.set_items_data(ids)

def exposed_get_items_data(self, ids: list[int]) -> list[bool]:
    """RPC: return the data-present flag for each record in ``ids`` (batched)."""
    # obtain() avoids per-element netref round-trips in the implementation
    ids = obtain(ids)
    return self._impl.get_items_data(ids)

def exposed_set_items_embed(self, ids: list[int]) -> None:
    """RPC: mark the embed flag of every record in ``ids`` (batched)."""
    # obtain() avoids per-element netref round-trips in the implementation
    ids = obtain(ids)
    return self._impl.set_items_embed(ids)

def exposed_get_items_embed(self, ids: list[int]) -> list[bool]:
    """RPC: return the embed-present flag for each record in ``ids`` (batched)."""
    # obtain() avoids per-element netref round-trips in the implementation
    ids = obtain(ids)
    return self._impl.get_items_embed(ids)


def start_cache_manager(port: int, args, pipe_writer):
# 注册graceful 退出的处理
graceful_registry(inspect.currentframe().f_code.co_name)

from .interface import CacheManagerFactory

manager_cls = CacheManagerFactory.get_impl("naive")
manager = manager_cls(args)
manager = InMemoryCache(args)
service = CacheServer(manager)
from rpyc.utils.server import ThreadedServer

Expand Down
Loading