Commit 649c6a3

niushengxiaosangchengmeng authored and committed
[fix] batch rpyc call in multimodal
1 parent c074ecb commit 649c6a3

4 files changed, +54 -72 lines changed

lightllm/models/whisper/whisper_audio.py

Lines changed: 1 addition & 1 deletion
@@ -190,7 +190,7 @@ def encode(self, audio_items: List[AudioItem]):
         audio_lens_after_cnn = np.array(audio_lens_after_cnn, dtype=np.int32)
         audio_token_num = (audio_lens_after_cnn - 2) // 2 + 1

-        ready_audio = self.cache_client.root.get_items_data(uuids)
+        ready_audio = self.cache_client.root.get_items_embed(uuids)
         ids_to_set = []
         for i, ready in enumerate(ready_audio):
             if not ready:

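This one-line change makes the audio path consistent with the batched cache API: the encoder is deciding whether it can skip recomputing an embedding, so it must poll the embed flag rather than the data flag. Below is a minimal, self-contained sketch of that check-then-fill pattern; FakeCacheRoot and compute_embed are stand-ins for illustration, not lightllm code.

    # Sketch: one batched embed-flag query, recompute only the misses, one batched update.
    class FakeCacheRoot:
        def __init__(self):
            self._embed_ready = {}            # uid -> bool

        def get_items_embed(self, uids):      # one batched query instead of N calls
            return [self._embed_ready.get(u, False) for u in uids]

        def set_items_embed(self, uids):      # one batched update for all newly computed items
            for u in uids:
                self._embed_ready[u] = True

    def compute_embed(uid):                   # placeholder for the real whisper encoder
        return f"embedding-for-{uid}"

    cache = FakeCacheRoot()
    uuids = [101, 102, 103]
    ready_audio = cache.get_items_embed(uuids)
    ids_to_set = [uid for uid, ready in zip(uuids, ready_audio) if not ready]
    for uid in ids_to_set:
        compute_embed(uid)                    # only the misses are (re)encoded
    if ids_to_set:
        cache.set_items_embed(ids_to_set)
    print(cache.get_items_embed(uuids))       # -> [True, True, True]
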
lightllm/server/embed_cache/impl/naive_memory_cache.py

Lines changed: 27 additions & 26 deletions
@@ -2,7 +2,7 @@
 import threading
 import dataclasses
 import requests
-from typing import Union
+from typing import Union, Optional
 import torch
 import time
 from collections import deque
@@ -87,41 +87,42 @@ def _clear(self):
            if deleted >= max_delete:
                break

-    def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> list[dict]:
-        results = []
+    def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[list[dict]]:
+        now = time.time()
         with self.lock:
+            new_md5s = [m for m in md5sum_list if m not in self._md5_to_record]
+            new_needed = len(new_md5s)
+
+            if self.occupied + new_needed > self.capacity:
+                self._clear()
+                if self.occupied + new_needed > self.capacity:
+                    return None
+
+            results: list[dict] = []
             for md5sum, token_num in zip(md5sum_list, token_num_list):
-                t = time.time()
-                if md5sum not in self._md5_to_record:
-                    if self.occupied >= self.capacity:
-                        self._clear()
-                        if self.occupied >= self.capacity:
-                            results.append(None)
-                            continue
-                    id = uuid.uuid1()
-                    id = id.int
+                if md5sum in self._md5_to_record:
+                    rec = self._md5_to_record[md5sum]
+                    rec.visittime = now
+                    rec.ref += 1
+                else:
+                    uid_int = uuid.uuid1().int
                     self._check_and_set_new_id_range(token_num)
-                    record = Record(
-                        id=id,
+                    rec = Record(
+                        id=uid_int,
                         md5sum=md5sum,
                         ref=1,
                         data=False,
                         embed=False,
-                        createtime=t,
-                        visittime=t,
+                        createtime=now,
+                        visittime=now,
                         token_id=self.token_id_range_start,
                         token_num=token_num,
                     )
                     self.token_id_range_start += token_num
-                    self._records[id] = record
-                    self._md5_to_record[md5sum] = record
+                    self._records[uid_int] = rec
+                    self._md5_to_record[md5sum] = rec
                     self.occupied += 1
-                # cache hit
-                else:
-                    record = self._md5_to_record[md5sum]
-                    record.visittime = t
-                    record.ref += 1
-                results.append({"id": record.id, "token_id": record.token_id, "token_num": record.token_num})
+                results.append({"id": rec.id, "token_id": rec.token_id, "token_num": rec.token_num})
             return results

     def release(self, ids: list[int]) -> None:
@@ -133,12 +134,12 @@ def set_items_data(self, ids: list[int]) -> None:
         for id in ids:
             self._records[id].data = True

-    def get_items_data(self, ids: list[int]) -> list[bool]:
+    def get_items_data(self, ids: list[int]) -> list[Optional[bool]]:
         return [self._records.get(i).data if i in self._records else False for i in ids]

     def set_items_embed(self, ids: list[int]) -> None:
         for id in ids:
             self._records[id].embed = True

-    def get_items_embed(self, ids: list[int]) -> list[bool]:
+    def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]:
         return [self._records.get(i).embed if i in self._records else False for i in ids]

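The batched alloc now has all-or-nothing semantics: before touching any record it counts only the md5sums that are not already cached, runs one eviction pass if they would not fit, and returns None when the whole batch still cannot be placed, so callers never receive a partially allocated list. The toy capacity check below illustrates that accounting; the names are illustrative, not the class's real fields.

    # Toy illustration of the all-or-nothing capacity check used by the batched alloc.
    capacity = 3
    occupied = 2
    cached = {"md5-a": {"id": 1}}                    # already-resident items need no new slots

    def can_place(md5sum_list):
        new_needed = sum(1 for m in md5sum_list if m not in cached)
        return occupied + new_needed <= capacity

    print(can_place(["md5-a", "md5-a"]))             # True: a fully cached batch never needs space
    print(can_place(["md5-a", "md5-b"]))             # True: one new slot fits (2 + 1 <= 3)
    print(can_place(["md5-b", "md5-c", "md5-d"]))    # False: the caller gets None and retries later
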
lightllm/server/embed_cache/manager.py

Lines changed: 6 additions & 6 deletions
@@ -1,7 +1,7 @@
 import rpyc
 import uuid
 import inspect
-from typing import Union
+from typing import Union, Optional
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.server.embed_cache.impl.naive_memory_cache import InMemoryCache
 from rpyc.utils.classic import obtain
@@ -22,7 +22,7 @@ def on_disconnect(self, conn):
         # (to finalize the service, if needed)
         pass

-    def exposed_alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> dict:
+    def exposed_alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[list[dict]]:
         md5sum_list = obtain(md5sum_list)
         token_num_list = obtain(token_num_list)
         record = self._impl.alloc(md5sum_list, token_num_list)
@@ -34,19 +34,19 @@ def exposed_release(self, ids: list[int]) -> None:

     def exposed_set_items_data(self, ids: list[int]) -> None:
         ids = obtain(ids)
-        return self._impl.set_items_data(ids=ids)
+        return self._impl.set_items_data(ids)

     def exposed_get_items_data(self, ids: list[int]) -> list[bool]:
         ids = obtain(ids)
-        return self._impl.get_items_data(ids=ids)
+        return self._impl.get_items_data(ids)

     def exposed_set_items_embed(self, ids: list[int]) -> None:
         ids = obtain(ids)
-        return self._impl.set_items_embed(ids=ids)
+        return self._impl.set_items_embed(ids)

     def exposed_get_items_embed(self, ids: list[int]) -> list[bool]:
         ids = obtain(ids)
-        return self._impl.get_items_embed(ids=ids)
+        return self._impl.get_items_embed(ids)


 def start_cache_manager(port: int, args, pipe_writer):

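Every exposed method keeps the same pattern: materialize the (possibly netref) list arguments with obtain() once, then hand plain local lists to the in-memory implementation, so each batch costs a single rpyc round trip instead of one per element. The self-contained loopback sketch below shows that pattern; EchoService, sum_tokens, and the port number are made up for the example and are not part of lightllm.

    import threading
    import time

    import rpyc
    from rpyc.utils.classic import obtain
    from rpyc.utils.server import ThreadedServer

    class EchoService(rpyc.Service):
        def exposed_sum_tokens(self, token_num_list):
            # obtain() copies the remote list in one bulk transfer; iterating the raw
            # netref would trigger a network round trip per element.
            token_num_list = obtain(token_num_list)
            return sum(token_num_list)

    server = ThreadedServer(EchoService, port=18861)
    threading.Thread(target=server.start, daemon=True).start()
    time.sleep(0.5)                                    # give the listener a moment to come up

    conn = rpyc.connect("localhost", 18861)
    print(conn.root.sum_tokens([3, 5, 7]))             # one rpyc call for the whole batch -> 15
    conn.close()
    server.close()
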
lightllm/server/httpserver/manager.py

Lines changed: 20 additions & 39 deletions
@@ -114,53 +114,34 @@ def __init__(
         self.latest_success_infer_time_mark.set_value(int(time.time()))
         return

-    async def _alloc_resource(self, items, md5sums, tokens_nums, datas):
+    async def _alloc_resource(self, items, md5sums, token_nums, datas):
         wait_time = 1
-        pending_idx = list(range(len(items)))
-        while pending_idx:
-            sub_md5sum = [md5sums[i] for i in pending_idx]
-            sub_tokens_num = [tokens_nums[i] for i in pending_idx]
-
-            records = self.cache_client.root.alloc(sub_md5sum, sub_tokens_num)
+        while True:
+            records = self.cache_client.root.alloc(md5sums, token_nums)

-            if all(record is None for record in records):
+            if records is None:
                 await asyncio.sleep(wait_time)
-                wait_time = min(wait_time + 2, 9)
+                wait_time = min(wait_time + 0.5, 2)
                 continue

-            next_pending = []  # record is None, scheduled for the next round
-            uids_to_check = []  # record exists, handled this round
-            uid_to_idx = {}  # uid -> index into the original items
+            uid_list = []
+            for item, rec in zip(items, records):
+                item.uuid = rec["id"]
+                item.token_id = rec["token_id"]
+                item.token_num = rec["token_num"]
+                uid_list.append(rec["id"])

-            for local_pos, record in enumerate(records):
-                global_pos = pending_idx[local_pos]
+            ready_flags = self.cache_client.root.get_items_data(uid_list)
+            need_write = []

-                if record is None:
-                    next_pending.append(global_pos)
-                    continue
+            for uid, ready, data in zip(uid_list, ready_flags, datas):
+                if not ready:
+                    create_shm(get_shm_name_data(uid), data)
+                    need_write.append(uid)

-                uid = record["id"]
-                uid_to_idx[uid] = global_pos
-                uids_to_check.append(uid)
-
-                item = items[global_pos]
-                item.uuid = uid
-                item.token_id = record["token_id"]
-                item.token_num = record["token_num"]
-
-            if uids_to_check:
-                ready_flags = self.cache_client.root.get_items_data(uids_to_check)
-                need_write = []
-
-                for uid, ready in zip(uids_to_check, ready_flags):
-                    if not ready:
-                        idx = uid_to_idx[uid]
-                        create_shm(get_shm_name_data(uid), datas[idx])
-                        need_write.append(uid)
-                if need_write:
-                    self.cache_client.root.set_items_data(need_write)
-                pending_idx = next_pending
-            return
+            if need_write:
+                self.cache_client.root.set_items_data(need_write)
+            return

     async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, sampling_params: SamplingParams):
         # only P and NORMAL nodes actually need to manage multimodal resources

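With the all-or-nothing alloc, the http server loop collapses to one batched call per attempt: if the cache answers None it backs off briefly (now in 0.5 s steps capped at 2 s instead of up to 9 s) and retries the whole batch; otherwise it records the ids, checks the data flags for the whole batch at once, writes shared memory only for the misses, and marks them ready in one final batched call. A minimal async sketch of that control flow follows, using a fake in-process cache root rather than the real rpyc client.

    import asyncio

    class FakeCacheRoot:
        """Stand-in for cache_client.root; not the real rpyc proxy."""
        def __init__(self, fail_first=1):
            self._fail = fail_first                  # simulate "cache full" on the first attempt(s)
            self._data_ready = set()

        def alloc(self, md5sums, token_nums):
            if self._fail > 0:
                self._fail -= 1
                return None                          # batch does not fit yet
            return [{"id": i, "token_id": 1000 + i, "token_num": n}
                    for i, n in enumerate(token_nums)]

        def get_items_data(self, uids):
            return [u in self._data_ready for u in uids]

        def set_items_data(self, uids):
            self._data_ready.update(uids)

    async def alloc_resource(root, md5sums, token_nums, datas):
        wait_time = 0.01                             # shortened for the demo
        while True:
            records = root.alloc(md5sums, token_nums)
            if records is None:                      # all-or-nothing: retry the whole batch
                await asyncio.sleep(wait_time)
                wait_time = min(wait_time + 0.01, 0.05)
                continue
            uid_list = [rec["id"] for rec in records]
            ready_flags = root.get_items_data(uid_list)
            need_write = [uid for uid, ready in zip(uid_list, ready_flags) if not ready]
            # the real code writes `datas` into shared memory here before marking items ready
            if need_write:
                root.set_items_data(need_write)
            return uid_list

    print(asyncio.run(alloc_resource(FakeCacheRoot(), ["a", "b"], [8, 16], [b"..", b".."])))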