From d009edfea772fbe51fc95f0f277228f5acfef1e3 Mon Sep 17 00:00:00 2001 From: lujianghu Date: Tue, 28 Jan 2025 18:47:15 +0800 Subject: [PATCH] fix setup logger bugs --- benchmarks/run_async_requests.py | 4 ++-- examples/run_engine.py | 4 ++-- tllm/generate/message_processor.py | 3 +-- tllm/models/mlx/layers.py | 3 +-- tllm/models/mlx/llama.py | 6 +++--- tllm/singleton_logger.py | 18 +++++++++--------- 6 files changed, 18 insertions(+), 20 deletions(-) diff --git a/benchmarks/run_async_requests.py b/benchmarks/run_async_requests.py index 6553cc9..672aed3 100644 --- a/benchmarks/run_async_requests.py +++ b/benchmarks/run_async_requests.py @@ -74,5 +74,5 @@ async def main(messages_list: List[List[Dict[str, Any]]]): if __name__ == "__main__": - # asyncio.run(main(llm_message())) - asyncio.run(main(mllm_message())) + asyncio.run(main(llm_message())) + # asyncio.run(main(mllm_message())) diff --git a/examples/run_engine.py b/examples/run_engine.py index 42775ff..2f0b5c9 100644 --- a/examples/run_engine.py +++ b/examples/run_engine.py @@ -131,6 +131,6 @@ async def image_generate(args): if __name__ == "__main__": args = parse_args() - # asyncio.run(llm_generate(args, llm_message())) - asyncio.run(llm_generate(args, mllm_message())) + asyncio.run(llm_generate(args, llm_message())) + # asyncio.run(llm_generate(args, mllm_message())) # asyncio.run(image_generate(args)) diff --git a/tllm/generate/message_processor.py b/tllm/generate/message_processor.py index cf3e443..97907b2 100644 --- a/tllm/generate/message_processor.py +++ b/tllm/generate/message_processor.py @@ -1,5 +1,4 @@ -import copy -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple from PIL import Image from PIL.ImageFile import ImageFile diff --git a/tllm/models/mlx/layers.py b/tllm/models/mlx/layers.py index bd2e872..2f5cb59 100644 --- a/tllm/models/mlx/layers.py +++ b/tllm/models/mlx/layers.py @@ -104,7 +104,6 @@ def sdap(q, k, v, scale, mask): out = mx.matmul(scores, v[:, None]) # 展平结果为[L, H, D] out = mx.flatten(out, 0, 1) - else: # 标准注意力计算 scores = mx.matmul(q, mx.swapaxes(k, -1, -2)) scores = scores + mask @@ -204,7 +203,7 @@ def _rope(self, xs: mx.array, request_cache: RequestsCache, uuid_list: List[str] x_list = [] start = 0 for uuid, offset in zip(uuid_list, offset_list): - end = start + request_cache.get_seq_len(uuid) + end = start + request_cache.get_q_len(uuid) x_list.append(self.rope(xs[start:end].transpose(1, 0, 2), offset).transpose(1, 0, 2)) start = end return cat_func(x_list) diff --git a/tllm/models/mlx/llama.py b/tllm/models/mlx/llama.py index 582ec96..d8648fe 100644 --- a/tllm/models/mlx/llama.py +++ b/tllm/models/mlx/llama.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Dict import mlx.core as mx import mlx.nn as nn @@ -101,7 +101,7 @@ def from_pretrained(cls, config: AutoConfig, state_dict: Dict[str, mx.array], ** state_dict = model.merge_weights(state_dict, is_merge) model = quantization_func(config, model, state_dict) - model.load_weights(list(state_dict.items()), strict=False) + model.load_weights(list(state_dict.items())) mx.eval(model.parameters()) model.eval() @@ -137,7 +137,7 @@ def __init__(self, config): self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @classmethod - def from_pretrained(cls, config, state_dict: Optional[Any], **kwargs): + def from_pretrained(cls, config, state_dict, **kwargs): model = cls(config) cls.config = config diff --git a/tllm/singleton_logger.py b/tllm/singleton_logger.py index 2a6f9fb..ad1a9dc 100644 --- a/tllm/singleton_logger.py +++ b/tllm/singleton_logger.py @@ -65,19 +65,19 @@ def _add_file_handler(cls, logger: logging.Logger): logger.addHandler(file_handler) @classmethod - def _setup_logger(cls, name=Literal["master", "handler"]) -> logging.Logger: - if cls.logger is None: + def _setup_logger(cls, name: Literal["master", "handler"]) -> logging.Logger: + if cls.logger is None: # 仅在第一次调用时创建logger cls.logger = logging.getLogger(name) cls.logger.setLevel(cls._level) - ch = logging.StreamHandler() - ch.setLevel(cls._level) + if not cls.logger.hasHandlers(): # 检查是否已经存在 handlers + ch = logging.StreamHandler() + ch.setLevel(cls._level) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + ch.setFormatter(formatter) + cls.logger.addHandler(ch) - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - ch.setFormatter(formatter) - - cls.logger.addHandler(ch) - cls._add_file_handler(cls.logger) + cls._add_file_handler(cls.logger) # 始终添加文件handler return cls.logger @classmethod