Skip to content

Commit

Permalink
fix setup logger bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
wnma3mz committed Jan 28, 2025
1 parent 9162adc commit d009edf
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 20 deletions.
4 changes: 2 additions & 2 deletions benchmarks/run_async_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,5 @@ async def main(messages_list: List[List[Dict[str, Any]]]):


if __name__ == "__main__":
# asyncio.run(main(llm_message()))
asyncio.run(main(mllm_message()))
asyncio.run(main(llm_message()))
# asyncio.run(main(mllm_message()))
4 changes: 2 additions & 2 deletions examples/run_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,6 @@ async def image_generate(args):

if __name__ == "__main__":
args = parse_args()
# asyncio.run(llm_generate(args, llm_message()))
asyncio.run(llm_generate(args, mllm_message()))
asyncio.run(llm_generate(args, llm_message()))
# asyncio.run(llm_generate(args, mllm_message()))
# asyncio.run(image_generate(args))
3 changes: 1 addition & 2 deletions tllm/generate/message_processor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Tuple

from PIL import Image
from PIL.ImageFile import ImageFile
Expand Down
3 changes: 1 addition & 2 deletions tllm/models/mlx/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def sdap(q, k, v, scale, mask):
out = mx.matmul(scores, v[:, None])
# 展平结果为[L, H, D]
out = mx.flatten(out, 0, 1)

else: # 标准注意力计算
scores = mx.matmul(q, mx.swapaxes(k, -1, -2))
scores = scores + mask
Expand Down Expand Up @@ -204,7 +203,7 @@ def _rope(self, xs: mx.array, request_cache: RequestsCache, uuid_list: List[str]
x_list = []
start = 0
for uuid, offset in zip(uuid_list, offset_list):
end = start + request_cache.get_seq_len(uuid)
end = start + request_cache.get_q_len(uuid)
x_list.append(self.rope(xs[start:end].transpose(1, 0, 2), offset).transpose(1, 0, 2))
start = end
return cat_func(x_list)
Expand Down
6 changes: 3 additions & 3 deletions tllm/models/mlx/llama.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Dict

import mlx.core as mx
import mlx.nn as nn
Expand Down Expand Up @@ -101,7 +101,7 @@ def from_pretrained(cls, config: AutoConfig, state_dict: Dict[str, mx.array], **
state_dict = model.merge_weights(state_dict, is_merge)

model = quantization_func(config, model, state_dict)
model.load_weights(list(state_dict.items()), strict=False)
model.load_weights(list(state_dict.items()))

mx.eval(model.parameters())
model.eval()
Expand Down Expand Up @@ -137,7 +137,7 @@ def __init__(self, config):
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@classmethod
def from_pretrained(cls, config, state_dict: Optional[Any], **kwargs):
def from_pretrained(cls, config, state_dict, **kwargs):
model = cls(config)

cls.config = config
Expand Down
18 changes: 9 additions & 9 deletions tllm/singleton_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,19 @@ def _add_file_handler(cls, logger: logging.Logger):
logger.addHandler(file_handler)

@classmethod
def _setup_logger(cls, name=Literal["master", "handler"]) -> logging.Logger:
if cls.logger is None:
def _setup_logger(cls, name: Literal["master", "handler"]) -> logging.Logger:
if cls.logger is None: # 仅在第一次调用时创建logger
cls.logger = logging.getLogger(name)
cls.logger.setLevel(cls._level)

ch = logging.StreamHandler()
ch.setLevel(cls._level)
if not cls.logger.hasHandlers(): # 检查是否已经存在 handlers
ch = logging.StreamHandler()
ch.setLevel(cls._level)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
ch.setFormatter(formatter)
cls.logger.addHandler(ch)

formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
ch.setFormatter(formatter)

cls.logger.addHandler(ch)
cls._add_file_handler(cls.logger)
cls._add_file_handler(cls.logger) # 始终添加文件handler
return cls.logger

@classmethod
Expand Down

0 comments on commit d009edf

Please sign in to comment.