From 2ac8c980e5e837ccdd1cde31e9b833808dfd3a44 Mon Sep 17 00:00:00 2001
From: lujianghu
Date: Wed, 29 Jan 2025 11:41:28 +0800
Subject: [PATCH] clean code

---
 README.md                               |  6 +-
 examples/run_engine.py => run_engine.py | 12 +++-
 tllm/generate/token_utils.py            |  9 ---
 tllm/shared_memory.py                   | 89 -------------------------
 4 files changed, 13 insertions(+), 103 deletions(-)
 rename examples/run_engine.py => run_engine.py (90%)
 delete mode 100644 tllm/shared_memory.py

diff --git a/README.md b/README.md
index 6d180c4..7d1246e 100644
--- a/README.md
+++ b/README.md
@@ -11,13 +11,13 @@
 
 2. run server
 
-   2.1 (no communication)
+   2.1 (localhost)
 
    ```bash
   tllm.server --model_path mlx-community/Llama-3.2-1B-Instruct-4bit --hostname localhost --is_local --client_size 1
    ```
 
-   2.2 (with communication)
+   2.2 (multiple clients)
 
    ```bash
    # first in one terminal
@@ -26,6 +26,8 @@
    # in another terminal
    tllm.client --hostname http://$YOUR_IP:8022
    ```
+
+
 3. testing
 
    ```bash
diff --git a/examples/run_engine.py b/run_engine.py
similarity index 90%
rename from examples/run_engine.py
rename to run_engine.py
index 2f0b5c9..945fa20 100644
--- a/examples/run_engine.py
+++ b/run_engine.py
@@ -19,6 +19,7 @@ def parse_args():
         help="Attention backend if backend is TORCH",
     )
     parser.add_argument("--model_path", type=str, default="Qwen/Qwen2-VL-2B-Instruct")
+    parser.add_argument("--message_type", type=str, default="llm", choices=["llm", "mllm", "image"])
     return parser.parse_args()
 
 
@@ -131,6 +132,11 @@ async def image_generate(args):
 
 if __name__ == "__main__":
     args = parse_args()
-    asyncio.run(llm_generate(args, llm_message()))
-    # asyncio.run(llm_generate(args, mllm_message()))
-    # asyncio.run(image_generate(args))
+    if args.message_type == "llm":
+        asyncio.run(llm_generate(args, llm_message()))
+    elif args.message_type == "mllm":
+        asyncio.run(llm_generate(args, mllm_message()))
+    elif args.message_type == "image":
+        asyncio.run(image_generate(args))
+    else:
+        raise ValueError(f"Unknown message type: {args.message_type}")
\ No newline at end of file
diff --git a/tllm/generate/token_utils.py b/tllm/generate/token_utils.py
index 376245a..1c0beda 100644
--- a/tllm/generate/token_utils.py
+++ b/tllm/generate/token_utils.py
@@ -29,15 +29,6 @@ def preprocess(
         input_ids = self.tokenizer.encode(text, add_special_tokens=False)
         return TokenizerResult(input_ids=input_ids, input_str=text)
 
-    def preprocess_old(self, text: str = None, messages: List[List[Dict[str, str]]] = None) -> TokenizerResult:
-        formatted_prompt = "### Human: {}### Assistant:"
-
-        if messages:
-            text = formatted_prompt.format(messages[0]["content"])
-        input_ids = self.tokenizer.encode(text, add_special_tokens=True)
-        while input_ids[0] == input_ids[1] == self.tokenizer.bos_token_id:
-            input_ids.pop(0)
-        return TokenizerResult(input_ids=input_ids, input_str=text)
 
     def decode(
         self, token_ids: List[int], cache_token_ids: List[Optional[List[int]]]
diff --git a/tllm/shared_memory.py b/tllm/shared_memory.py
deleted file mode 100644
index 7dc1e58..0000000
--- a/tllm/shared_memory.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# shared_memory.py
-import ctypes
-from multiprocessing import shared_memory
-
-
-class RingBuffer:
-    # Buffer header structure
-    class Header(ctypes.Structure):
-        _fields_ = [
-            ("write_idx", ctypes.c_uint64),  # write pointer
-            ("read_idx", ctypes.c_uint64),  # read pointer
-            ("buf_size", ctypes.c_uint64),  # buffer size
-        ]
-
-    def __init__(self, name, size=1024 * 1024):  # 1 MB by default
-        self.buf_size = size
-        self.name = name
-
-        try:
-            # Try to attach to an existing shared memory segment
-            self.shm = shared_memory.SharedMemory(name=name)
-            self.is_creator = False
-        except FileNotFoundError:
-            # Create a new shared memory segment
-            self.shm = shared_memory.SharedMemory(name=name, create=True, size=ctypes.sizeof(self.Header) + size)
-            self.is_creator = True
-
-            # Initialize the header
-            header = self.Header.from_buffer(self.shm.buf)
-            header.write_idx = 0
-            header.read_idx = 0
-            header.buf_size = size
-
-        self.header = self.Header.from_buffer(self.shm.buf)
-        self.buffer = memoryview(self.shm.buf)[ctypes.sizeof(self.Header) :]
-
-    def write(self, data: bytes) -> bool:
-        data_size = len(data)
-        if data_size > self.buf_size:
-            return False
-
-        # Compute the available space
-        write_idx = self.header.write_idx % self.buf_size
-        read_idx = self.header.read_idx % self.buf_size
-
-        if write_idx >= read_idx:
-            available = self.buf_size - (write_idx - read_idx)
-        else:
-            available = read_idx - write_idx
-
-        if data_size + 4 > available:  # 4 bytes are used to store the length
-            return False
-
-        # Write the data length
-        length_bytes = data_size.to_bytes(4, "little")
-        for i, b in enumerate(length_bytes):
-            self.buffer[(write_idx + i) % self.buf_size] = b
-
-        # Write the payload
-        for i, b in enumerate(data):
-            self.buffer[(write_idx + 4 + i) % self.buf_size] = b
-
-        # Advance the write pointer
-        self.header.write_idx = (write_idx + 4 + data_size) % self.buf_size
-        return True
-
-    def read(self) -> bytes:
-        if self.header.read_idx == self.header.write_idx:
-            return None
-
-        read_idx = self.header.read_idx % self.buf_size
-
-        # Read the data length
-        length_bytes = bytes(self.buffer[read_idx : read_idx + 4])
-        data_size = int.from_bytes(length_bytes, "little")
-
-        # Read the payload
-        data = bytearray()
-        for i in range(data_size):
-            data.append(self.buffer[(read_idx + 4 + i) % self.buf_size])
-
-        # Advance the read pointer
-        self.header.read_idx = (read_idx + 4 + data_size) % self.buf_size
-        return bytes(data)
-
-    def close(self):
-        self.shm.close()
-        if self.is_creator:
-            self.shm.unlink()
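
Usage note on the run_engine.py hunk above: the hard-coded asyncio.run calls are replaced by the --message_type flag (choices llm, mllm, image, default llm). A minimal sketch of an invocation after applying this patch, assuming the script is run from the repository root with the project's dependencies installed:

```bash
# Select the multimodal path via the new flag; --model_path keeps the default from parse_args().
python run_engine.py --model_path Qwen/Qwen2-VL-2B-Instruct --message_type mllm
```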