From 765148145d9cda39234d84367fc8d605730c3229 Mon Sep 17 00:00:00 2001 From: Ychoukewen <2369589402@qq.com> Date: Thu, 8 May 2025 03:27:29 +0000 Subject: [PATCH 01/12] test.py --- test.py | 67 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 test.py diff --git a/test.py b/test.py new file mode 100644 index 0000000000..ea92f5df99 --- /dev/null +++ b/test.py @@ -0,0 +1,67 @@ +import datetime +import os +from lmdeploy import pipeline, PytorchEngineConfig +from lmdeploy.messages import GenerationConfig +from lmdeploy.serve.openai.api_server import handle_torchrun +import torch.distributed as dist + +def main(rank: int): + # model_path ='/nvme2/huggingface_hub_137_llm/hub/models--Qwen--Qwen2.5-7B-Instruct/snapshots/a09a35458c702b33eeacc393d103063234e8bc28' + # model_path ='/nvme1/zhaochaoxing/hub/models--deepseek-ai--DeepSeek-V3/snapshots/86518964eaef84e3fdd98e9861759a1384f9c29d' + model_path = '/mnt/huggingface_hub_137_llm/hub/models--deepseek-ai--DeepSeek-V3/snapshots/86518964eaef84e3fdd98e9861759a1384f9c29d' + # model_path = '/nvme2/huggingface_hub_137_llm/hub/models--deepseek-ai--DeepSeek-V2-Lite/snapshots/604d5664dddd88a0433dbae533b7fe9472482de0' + log_level = 'WARNING' + prompts = [ + 'fast fox jump over the lazy dog.', + 'fast fox jump over the lazy dog.', + 'fast fox jump over the lazy dog.', + 'fast fox jump over the lazy dog.', + 'fast fox jump over the lazy dog.', + 'fast fox jump over the lazy dog.', + 'fast fox jump over the lazy dog.', + 'fast fox jump over the lazy dog.', + ] + prompts1 = [ + 'The stars twinkle in the night sky.', + 'The stars twinkle in the night sky.', + 'The stars twinkle in the night sky.', + 'The stars twinkle in the night sky.', + 'The stars twinkle in the night sky.', + 'The stars twinkle in the night sky.', + 'The stars twinkle in the night sky.', + 'The stars twinkle in the night sky.', + ] + prompts = prompts1[rank:rank+1] + + backend_config = PytorchEngineConfig( + tp=1, + dp=4, + ep=4, + dp_rank=rank, + eager_mode=True, + ) + gen_config = GenerationConfig( + temperature=1.0, + top_k=1, + do_sample=True, + max_new_tokens=64, + ) + + os.environ['LMDEPLOY_DP_MASTER_ADDR'] = '10.130.8.148' + os.environ['LMDEPLOY_DP_MASTER_PORT'] = str(29551) + with pipeline(model_path, backend_config=backend_config, log_level=log_level) as pipe: + outputs = pipe(prompts, gen_config=gen_config) + print(outputs) + + dist.barrier() + +if __name__ == '__main__': + handle_torchrun() + rank = int(os.environ['LOCAL_RANK']) + print(f"local rank : {rank}") + os.environ['CUDA_VISIBLE_DEVICES'] = str(rank) + dist.init_process_group(backend='nccl', timeout=datetime.timedelta(90)) + try: + main(rank) + finally: + dist.destroy_process_group() \ No newline at end of file From 37d234c28353a248fe9a7ebabd0f8c931743d946 Mon Sep 17 00:00:00 2001 From: Ychoukewen <2369589402@qq.com> Date: Thu, 8 May 2025 04:31:43 +0000 Subject: [PATCH 02/12] adjust: change model layer count --- lmdeploy/pytorch/models/deepseek_v2.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py index 22b1382aed..9e6d8da43f 100644 --- a/lmdeploy/pytorch/models/deepseek_v2.py +++ b/lmdeploy/pytorch/models/deepseek_v2.py @@ -1113,6 +1113,7 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None): super().__init__() + config.num_hidden_layers = 4 self.config = config self.quantization_config = getattr(config, 'quantization_config', None) self.dtype = dtype @@ 
-1365,6 +1366,13 @@ def __skip_nextn(name, nextn_keys): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: + # zcx begin + strs = name.split(".") + if len(strs) >= 3 and str.isdigit(strs[2]): + layer_number = int(strs[2]) + if layer_number >= 4: + continue + # zcx end if 'rotary_emb.inv_freq' in name: continue if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name): From afc7c5ed0a17d546271d1965913ecab82e6efa7b Mon Sep 17 00:00:00 2001 From: Ychoukewen <2369589402@qq.com> Date: Thu, 8 May 2025 04:43:09 +0000 Subject: [PATCH 03/12] add _load_weight_experts_with_eplb in deepseek_v2.py --- lmdeploy/pytorch/models/deepseek_v2.py | 32 ++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py index 9e6d8da43f..12f8cad5e5 100644 --- a/lmdeploy/pytorch/models/deepseek_v2.py +++ b/lmdeploy/pytorch/models/deepseek_v2.py @@ -22,10 +22,13 @@ from lmdeploy.utils import get_logger from .utils.cudagraph import CudaGraphMixin +import os +enable_eplb = os.environ.get('EPLB_ENABLED', '0') == '1' logger = get_logger('lmdeploy') + # microbatch class ExecType(Enum): """batch ecex type.""" @@ -668,7 +671,10 @@ def __init__(self, config: Any, layer_idx, dtype: torch.dtype = None, device: to quantization_config = getattr(config, 'quantization_config', None) self.hidden_dim = config.hidden_size self.ffn_dim = config.moe_intermediate_size - self.num_experts = config.n_routed_experts + if enable_eplb: + self.num_experts = config.n_routed_experts + 32 + else: + self.num_experts = config.n_routed_experts self.top_k = config.num_experts_per_tok self.norm_topk_prob = config.norm_topk_prob self.routed_scaling_factor = config.routed_scaling_factor @@ -1364,6 +1370,23 @@ def __skip_nextn(name, nextn_keys): num_nextn_predict_layers = getattr(self.config, 'num_nextn_predict_layers', 1) nextn_keys = [f'.layers.{num_hidden_layers+i}' for i in range(num_nextn_predict_layers)] + # 分层行为:每层独立专家映射表 + layer_expert_params_mapping = {} + if enable_eplb: + moe_layers = [] + for layer_idx in range(num_hidden_layers): + if config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0: + moe_layers.append(layer_idx) + + for layer_idx in moe_layers: + expert_params = [] + for exp_id in range(num_experts): + gate_param = (f'.layers.{layer_idx}.mlp.experts.gate_up', f'.layers.{layer_idx}.mlp.experts.{exp_id}.gate_proj', exp_id, 'gate') + up_param = (f'.layers.{layer_idx}.mlp.experts.gate_up', f'.layers.{layer_idx}.mlp.experts.{exp_id}.up_proj', exp_id, 'up') + down_param = (f'.layers.{layer_idx}.mlp.experts.down', f'.layers.{layer_idx}.mlp.experts.{exp_id}.down_proj', exp_id, 'down') + expert_params += [gate_param, up_param, down_param] + layer_expert_params_mapping[layer_idx] = expert_params + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: # zcx begin @@ -1386,7 +1409,12 @@ def __skip_nextn(name, nextn_keys): if name.endswith(scale_suffix): name = name[:-len(scale_suffix)] + '.scale' if '.experts' in name: - self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping) + if enable_eplb: + self._load_weight_experts_with_eplb(name, loaded_weight, params_dict, + layer_expert_params_mapping=layer_expert_params_mapping) + else: + self._load_weight_experts(name, loaded_weight, params_dict, + expert_params_mapping=expert_params_mapping) elif '.self_attn' in 
name and getattr(config, 'use_mla', True): # attention self._load_weight_attention(name, loaded_weight, params_dict, update_pe_mapping) From da496cf272a5260c59c46a2e411b6041b696621e Mon Sep 17 00:00:00 2001 From: Ychoukewen <2369589402@qq.com> Date: Thu, 8 May 2025 04:53:43 +0000 Subject: [PATCH 04/12] add enable_eplb in nn/moe.py --- lmdeploy/pytorch/nn/moe.py | 80 ++++++++++++++++--- .../weight_loader/model_weight_loader.py | 13 +++ 2 files changed, 80 insertions(+), 13 deletions(-) diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index 0bbc7575b1..f017912c1e 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -12,6 +12,10 @@ from ..backends import OpType, get_backend from .utils import div_up +import os +enable_eplb = os.environ.get('EPLB_ENABLED', '0') == '1' +from collections import defaultdict + class MoeType(Enum): """batch ecex type.""" @@ -73,7 +77,12 @@ def __init__(self, self.half_out = out_features // 2 if self.ep: - self.expert_map = dict((eid, idx) for idx, eid in enumerate(expert_list)) + if enable_eplb: + self.expert_map = defaultdict(list) + for idx, eid in enumerate(expert_list): + self.expert_map[eid].append(idx) + else: + self.expert_map = dict((eid, idx) for idx, eid in enumerate(expert_list)) self.weight.weight_loader = self.weight_loader_ep else: self.weight.weight_loader = self.weight_loader_tp @@ -110,16 +119,51 @@ def weight_loader_ep(self, param: torch.nn.Parameter, loaded_weight: torch.Tenso return expert_map = self.expert_map - param_id = expert_map[expert_id] - if shard_id == 'gate': - param_data = param.data[param_id, :self.half_out] - elif shard_id == 'up': - param_data = param.data[param_id, self.half_out:] - elif shard_id == 'down': - param_data = param.data[param_id] + + if not enable_eplb: + param_id = expert_map[expert_id] + if shard_id == 'gate': + param_data = param.data[param_id, :self.half_out] + elif shard_id == 'up': + param_data = param.data[param_id, self.half_out:] + elif shard_id == 'down': + param_data = param.data[param_id] + else: + raise RuntimeError(f'Unknown shard_id: {shard_id}') + param_data.copy_(loaded_weight) else: - raise RuntimeError(f'Unknown shard_id: {shard_id}') - param_data.copy_(loaded_weight) + param_ids = expert_map[expert_id] + for param_id in param_ids: + if param.data.dtype == torch.float8_e4m3fn: + # 临时转为 float16 做索引 + temp_param = param.data.to(torch.float16) + + if shard_id == 'gate': + param_data = temp_param[param_id, :self.half_out] + elif shard_id == 'up': + param_data = temp_param[param_id, self.half_out:] + elif shard_id == 'down': + param_data = temp_param[param_id] + else: + raise RuntimeError(f'Unknown shard_id: {shard_id}') + + # 将 loaded_weight 也转成 float16 + weight_to_copy = loaded_weight.to(torch.float16) + param_data.copy_(weight_to_copy) + + # 再写回原始 param.data(转换回 float8) + param.data.copy_(temp_param.to(torch.float8_e4m3fn)) + else: + if shard_id == 'gate': + param_data = param.data[param_id, :self.half_out] + elif shard_id == 'up': + param_data = param.data[param_id, self.half_out:] + elif shard_id == 'down': + param_data = param.data[param_id] + else: + raise RuntimeError(f'Unknown shard_id: {shard_id}') + param_data.copy_(loaded_weight.to(param_data.dtype)) + print(f"[Rank {rank}] ✅ Loaded Expert {expert_id} for Layer {layer_idx} ({shard_id}) shape={param_data.shape}") def _gather_input(x: torch.Tensor, tp_sizes: List[int]): @@ -428,7 +472,12 @@ def __init__(self, self.register_parameter('scale', scale) if self.ep: - self.expert_map = dict((eid, idx) for 
idx, eid in enumerate(expert_list)) + if enable_eplb: + self.expert_map = defaultdict(list) + for idx, eid in enumerate(expert_list): + self.expert_map[eid].append(idx) + else: + self.expert_map = dict((eid, idx) for idx, eid in enumerate(expert_list)) self.scale.weight_loader = self.weight_loader_scale_ep else: self.scale.weight_loader = self.weight_loader_scale_tp @@ -446,8 +495,13 @@ def weight_loader_scale_ep(self, param: torch.nn.Parameter, loaded_weight: torch expert_list = self.expert_list if expert_id not in expert_list: return - expert_id = self.expert_map[expert_id] - self.weight_loader_scale_tp(param, loaded_weight, expert_id, shard_id) + if not enable_eplb: + expert_id = self.expert_map[expert_id] + self.weight_loader_scale_tp(param, loaded_weight, expert_id, shard_id) + else: + expert_ids = self.expert_map[expert_id] + for expert_id in expert_ids: + self.weight_loader_scale_tp(param, loaded_weight, expert_id, shard_id) def weight_loader_scale_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str): diff --git a/lmdeploy/pytorch/weight_loader/model_weight_loader.py b/lmdeploy/pytorch/weight_loader/model_weight_loader.py index b3a22988ae..e242cdbf58 100644 --- a/lmdeploy/pytorch/weight_loader/model_weight_loader.py +++ b/lmdeploy/pytorch/weight_loader/model_weight_loader.py @@ -16,6 +16,19 @@ def load_weight(param: torch.nn.Parameter, loaded_weight: torch.Tensor, **kwargs): """load weight.""" + # expert_id = kwargs.get('expert_id', None) + # # for debug + # shard_id = kwargs.get('shard_id', '?') + # rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + + # if expert_id is not None and hasattr(param, 'expert_list'): + # if expert_id not in param.expert_list: + # print(f"[Rank {rank}] 🔁 Skip Expert {expert_id} for param {param.shape}") + # return + # else: + # layer_idx = getattr(param, 'layer_idx', '?') + # print(f"[Rank {rank}] ✅ Load Expert {expert_id} for Layer {layer_idx} ({shard_id})") + if hasattr(param, 'weight_loader'): param.weight_loader(param, loaded_weight, **kwargs) else: From 1cebe01423c3aedb6c0ea1ec816f96bc2d459e89 Mon Sep 17 00:00:00 2001 From: Ychoukewen <2369589402@qq.com> Date: Thu, 8 May 2025 04:57:58 +0000 Subject: [PATCH 05/12] add layer_idx in nn/moe.py --- lmdeploy/pytorch/nn/moe.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index f017912c1e..19d7250e27 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -66,10 +66,14 @@ def __init__(self, dtype: torch.dtype, device: torch.device, expert_list: List[int] = None, - ep: bool = False): + ep: bool = False, + layer_idx: int = 0): super().__init__() weight = torch.empty((num_experts, out_features, in_features), dtype=dtype, device=device) weight = torch.nn.Parameter(weight, requires_grad=False) + if ep and enable_eplb and expert_list is not None: + weight.expert_list = expert_list # ✅ 添加这一行,仅在 EPLB 时才加 + weight.layer_idx = layer_idx self.register_parameter('weight', weight) self.ep = ep self.expert_list = expert_list @@ -453,7 +457,8 @@ def __init__(self, dtype: torch.dtype, device: torch.device, expert_list: List[int] = None, - ep: bool = False): + ep: bool = False, + layer_idx: int = 0): super().__init__( num_experts=num_experts, in_features=in_features, @@ -463,6 +468,7 @@ def __init__(self, device=device, expert_list=expert_list, ep=ep, + layer_idx=layer_idx, ) self.block_size = block_size scale = torch.empty((num_experts, 
div_up(out_features, block_size), div_up(in_features, block_size)), @@ -573,7 +579,8 @@ def __init__(self, dtype=fp8_dtype, device=device, expert_list=expert_list, - ep=self.ep_size > 1) + ep=self.ep_size > 1, + layer_idx=layer_idx) self.down = LinearWeightsBlockedF8( num_experts, ffn_dim, @@ -584,7 +591,7 @@ def __init__(self, device=device, expert_list=expert_list, ep=self.ep_size > 1, - ) + layer_idx=layer_idx) self.hidden_dim = hidden_dim self.ffn_dim = ffn_dim From d5dd6df1a2fde12702c9a82c0134638866b019da Mon Sep 17 00:00:00 2001 From: Ychoukewen <2369589402@qq.com> Date: Thu, 8 May 2025 05:14:14 +0000 Subject: [PATCH 06/12] add rank_expert_list print in nn/moe.py --- lmdeploy/pytorch/nn/moe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index 19d7250e27..eb9a3f3a0e 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -74,6 +74,8 @@ def __init__(self, if ep and enable_eplb and expert_list is not None: weight.expert_list = expert_list # ✅ 添加这一行,仅在 EPLB 时才加 weight.layer_idx = layer_idx + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + print(f"[Rank {rank}] ✅ Load Expert {expert_list} for Layer {layer_idx} ({weight_type})") self.register_parameter('weight', weight) self.ep = ep self.expert_list = expert_list From a3e5531422ba42bf782d3a2677dd300a95c218e2 Mon Sep 17 00:00:00 2001 From: Ychoukewen <2369589402@qq.com> Date: Thu, 8 May 2025 06:36:55 +0000 Subject: [PATCH 07/12] add expert_list to FusedMoENormal class and FusedMoELowLatency class in backends/cuda/moe.py --- ep_mapping_json_gen.py | 18 ++ ep_mapping_json_path.json | 262 ++++++++++++++++++++++ lmdeploy/pytorch/backends/cuda/moe.py | 275 +++++++++++++++++++++++- lmdeploy/pytorch/kernels/cuda/ep_moe.py | 57 +++++ 4 files changed, 610 insertions(+), 2 deletions(-) create mode 100644 ep_mapping_json_gen.py create mode 100644 ep_mapping_json_path.json diff --git a/ep_mapping_json_gen.py b/ep_mapping_json_gen.py new file mode 100644 index 0000000000..411524b131 --- /dev/null +++ b/ep_mapping_json_gen.py @@ -0,0 +1,18 @@ +import numpy as np +import json + +num_experts = 12 +np.random.seed(42) # 保持可复现 + +# 随机整数负载:每个专家处理 80~160 个 token +weight = np.random.randint(low=80, high=161, size=num_experts).tolist() + +data = { + "num_groups": 4, + "num_nodes": 1, + "weight": weight +} + +with open("/nvme1/liudongyan/workspace/lmdeploy/ep_mapping_json_path_logicexp12.json", "w") as f: + json.dump(data, f, indent=2) +print("JSON 写入完成, weight 总和 =", sum(weight)) diff --git a/ep_mapping_json_path.json b/ep_mapping_json_path.json new file mode 100644 index 0000000000..2058c3fb20 --- /dev/null +++ b/ep_mapping_json_path.json @@ -0,0 +1,262 @@ +{ + "num_groups": 4, + "num_nodes": 1, + "weight": [[ + 131, + 94, + 151, + 140, + 100, + 154, + 154, + 103, + 82, + 101, + 132, + 81, + 109, + 117, + 81, + 143, + 139, + 100, + 112, + 155, + 137, + 101, + 128, + 138, + 121, + 139, + 159, + 94, + 141, + 141, + 126, + 141, + 130, + 134, + 143, + 82, + 130, + 86, + 100, + 152, + 118, + 97, + 83, + 139, + 93, + 88, + 132, + 81, + 139, + 150, + 123, + 87, + 126, + 114, + 157, + 160, + 115, + 129, + 83, + 81, + 85, + 133, + 83, + 133, + 142, + 97, + 123, + 113, + 153, + 141, + 93, + 127, + 94, + 151, + 157, + 141, + 119, + 159, + 132, + 103, + 105, + 139, + 120, + 108, + 94, + 124, + 144, + 150, + 88, + 80, + 87, + 142, + 90, + 160, + 87, + 114, + 114, + 112, + 84, + 120, + 107, + 86, + 152, + 151, + 91, + 113, + 112, + 127, + 102, + 141, + 116, + 
123, + 114, + 144, + 126, + 157, + 82, + 80, + 84, + 93, + 106, + 88, + 158, + 94, + 121, + 156, + 130, + 142, + 131, + 83, + 102, + 94, + 122, + 108, + 115, + 92, + 111, + 150, + 138, + 107, + 145, + 121, + 124, + 141, + 136, + 85, + 107, + 107, + 123, + 109, + 141, + 154, + 141, + 80, + 106, + 141, + 156, + 82, + 149, + 151, + 106, + 88, + 141, + 116, + 130, + 123, + 103, + 158, + 138, + 111, + 131, + 141, + 137, + 131, + 91, + 118, + 81, + 82, + 135, + 160, + 138, + 81, + 81, + 133, + 80, + 98, + 81, + 132, + 123, + 111, + 149, + 111, + 147, + 134, + 154, + 135, + 96, + 117, + 103, + 148, + 149, + 90, + 95, + 152, + 138, + 149, + 159, + 82, + 99, + 138, + 115, + 98, + 146, + 98, + 99, + 150, + 131, + 112, + 119, + 118, + 80, + 90, + 136, + 129, + 102, + 110, + 121, + 86, + 95, + 139, + 81, + 80, + 127, + 91, + 148, + 116, + 111, + 88, + 98, + 127, + 159, + 82, + 99, + 103, + 133, + 112, + 103, + 154, + 151, + 115, + 117, + 104, + 97, + 145, + 133, + 114 + ]] +} \ No newline at end of file diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py index d4c4a2a83e..dedea5431e 100644 --- a/lmdeploy/pytorch/backends/cuda/moe.py +++ b/lmdeploy/pytorch/backends/cuda/moe.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional +from typing import List, Optional, Tuple import torch import torch.distributed as dist @@ -10,7 +10,7 @@ from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8 from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8 from lmdeploy.pytorch.kernels.cuda.ep_moe import (grouped_gemm_triton, silu_and_mul_masked_post_quant_fwd, - silu_and_mul_triton_kernel) + silu_and_mul_triton_kernel, map_logic_to_physical_idx_hash_random) from lmdeploy.pytorch.kernels.cuda.fused_moe import _renormalize from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import per_token_quant_int8 from lmdeploy.pytorch.model_inputs import get_step_ctx_manager @@ -417,6 +417,139 @@ def __init__(self, params_dtype=out_dtype, ) + def balanced_packing(self, weight: torch.Tensor, num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]: + + num_layers, num_groups = weight.shape + assert num_groups % num_packs == 0 + groups_per_pack = num_groups // num_packs + + if groups_per_pack == 1: + pack_index = torch.arange(weight.size(-1), dtype=torch.int64, device=weight.device).expand(weight.shape) + rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) + return pack_index, rank_in_pack + + indices = weight.float().sort(-1, descending=True).indices.cpu() + pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device='cpu') + rank_in_pack = torch.full_like(pack_index, fill_value=-1) + for i in range(num_layers): + pack_weights = [0] * num_packs + pack_items = [0] * num_packs + for group in indices[i]: + pack = min((i for i in range(num_packs) if pack_items[i] < groups_per_pack), + key=pack_weights.__getitem__) + assert pack_items[pack] < groups_per_pack + pack_index[i, group] = pack + rank_in_pack[i, group] = pack_items[pack] + pack_weights[pack] += weight[i, group] + pack_items[pack] += 1 + return pack_index, rank_in_pack + + def replicate_experts(self, weight: torch.Tensor, num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + n, num_log = weight.shape + num_redundant = num_phy - num_log + assert num_redundant >= 0 + device = weight.device + phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) + rank = torch.zeros(n, num_phy, 
dtype=torch.int64, device=device) + logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) + arangen = torch.arange(n, dtype=torch.int64, device=device) + for i in range(num_log, num_phy): + redundant_indices = (weight / logcnt).max(dim=-1).indices + phy2log[:, i] = redundant_indices + rank[:, i] = logcnt[arangen, redundant_indices] + logcnt[arangen, redundant_indices] += 1 + return phy2log, rank, logcnt + + def rebalance_experts_hierarchical(self, weight: torch.Tensor, num_physical_experts: int, num_groups: int, + num_nodes: int, num_gpus: int): + + num_layers, num_logical_experts = weight.shape + assert num_logical_experts % num_groups == 0 + group_size = num_logical_experts // num_groups + assert num_groups % num_nodes == 0 + groups_per_node = num_groups // num_nodes + assert num_gpus % num_nodes == 0 + assert num_physical_experts % num_gpus == 0 + phy_experts_per_gpu = num_physical_experts // num_gpus + + def inverse(perm: torch.Tensor) -> torch.Tensor: + inv = torch.empty_like(perm) + inv.scatter_(1, perm, torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(perm.shape)) + return inv + + tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) + group_pack_index, group_rank_in_pack = self.balanced_packing(tokens_per_group, num_nodes) + log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) * group_size).unsqueeze(-1) + + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)).flatten(-2) + mlog2log = inverse(log2mlog) + tokens_per_mlog = weight.gather(-1, mlog2log).view(-1, num_logical_experts // num_nodes) + phy2mlog, phyrank, mlogcnt = self.replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes) + tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) + pack_index, rank_in_pack = self.balanced_packing(tokens_per_phy, num_gpus // num_nodes) + phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack + pphy2phy = inverse(phy2pphy) + pphy2mlog = phy2mlog.gather(-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes] + pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + + torch.arange(0, num_logical_experts, num_logical_experts // num_nodes).view(1, -1, 1)).flatten(-2) + pphy2log = mlog2log.gather(-1, pphy2mlog) + pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) + logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) + return pphy2log, pphyrank, logcnt + + def rebalance_experts(self, weight: torch.Tensor, num_replicas: int, num_groups: int, num_nodes: int, + num_gpus: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + num_layers, num_logical_experts = weight.shape + weight = weight.float().cpu() + if num_groups % num_nodes == 0: + # use hierarchical load-balance policy + phy2log, phyrank, logcnt = self.rebalance_experts_hierarchical(weight, num_replicas, num_groups, num_nodes, + num_gpus) + else: + # use global load-balance policy + phy2log, phyrank, logcnt = self.replicate_experts(weight, num_replicas) + maxlogcnt = logcnt.max().item() + log2phy: torch.Tensor = torch.full((num_layers, num_logical_experts, maxlogcnt), + -1, + dtype=torch.int64, + device=logcnt.device) + log2phy.view(num_layers, -1).scatter_( + -1, phy2log * maxlogcnt + phyrank, + torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(num_layers, -1)) + return phy2log, log2phy, logcnt + + def _load_ep_mapping(self, json_path: str): + """Load expert partition metadata from JSON (class-internal).""" + import json + with open(json_path, 'r') as f: + data = 
json.load(f) + num_groups = data['num_groups'] + num_nodes = data['num_nodes'] + weight = torch.tensor(data['weight'], dtype=torch.float32, device='cuda') + return num_groups, num_nodes, weight + + def ep_expert_list(self, world_size: int, rank: int, num_groups: int=None, num_nodes: int=None, weight: torch.Tensor=None): + """experts list of current rank.""" + if enable_eplb: + num_groups, num_nodes, weight = self._load_ep_mapping("ep_mapping_json_prefill.json") + phy2log, log2phy, logcnt = self.rebalance_experts(weight, num_experts, num_groups, num_nodes, self.world_size) + self.phy2log = phy2log[0].to('cuda') + self.log2phy = log2phy[0].to('cuda') + self.logcnt = logcnt[0].to('cuda') + expert_per_rank = (self.num_experts + world_size - 1) // world_size + first_expert = rank * expert_per_rank + last_expert = min(first_expert + expert_per_rank, self.num_experts) + sliced_phy2log = self.phy2log[first_expert:last_expert].tolist() + + return sliced_phy2log + else: + num_experts = self.num_experts + expert_per_rank = (num_experts + world_size - 1) // world_size + first_expert = rank * expert_per_rank + last_expert = min(first_expert + expert_per_rank, num_experts) + return list(range(first_expert, last_expert)) + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, @@ -427,6 +560,10 @@ def forward(self, down_scale: torch.Tensor, expert_list: List[int] = None): """forward.""" + if enable_eplb: + topk_ids = map_logic_to_physical_idx_hash_random(topk_ids, self.log2phy, self.logcnt) + else: + topk_ids = topk_ids recv_hidden_states, recv_topk_ids, recv_topk_weights, tokens_per_expert = self.token_dispatcher.dispatch( hidden_states, topk_ids, @@ -506,6 +643,136 @@ def __init__(self, params_dtype=out_dtype, ) + def balanced_packing(self, weight: torch.Tensor, num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]: + + num_layers, num_groups = weight.shape + assert num_groups % num_packs == 0 + groups_per_pack = num_groups // num_packs + + if groups_per_pack == 1: + pack_index = torch.arange(weight.size(-1), dtype=torch.int64, device=weight.device).expand(weight.shape) + rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) + return pack_index, rank_in_pack + + indices = weight.float().sort(-1, descending=True).indices.cpu() + pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device='cpu') + rank_in_pack = torch.full_like(pack_index, fill_value=-1) + for i in range(num_layers): + pack_weights = [0] * num_packs + pack_items = [0] * num_packs + for group in indices[i]: + pack = min((i for i in range(num_packs) if pack_items[i] < groups_per_pack), + key=pack_weights.__getitem__) + assert pack_items[pack] < groups_per_pack + pack_index[i, group] = pack + rank_in_pack[i, group] = pack_items[pack] + pack_weights[pack] += weight[i, group] + pack_items[pack] += 1 + return pack_index, rank_in_pack + + + def replicate_experts(self, weight: torch.Tensor, num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + n, num_log = weight.shape + num_redundant = num_phy - num_log + assert num_redundant >= 0 + device = weight.device + phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) + rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) + logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) + arangen = torch.arange(n, dtype=torch.int64, device=device) + for i in range(num_log, num_phy): + redundant_indices = (weight / logcnt).max(dim=-1).indices + phy2log[:, i] = redundant_indices + rank[:, i] = logcnt[arangen, 
redundant_indices] + logcnt[arangen, redundant_indices] += 1 + return phy2log, rank, logcnt + + + def rebalance_experts_hierarchical(self, weight: torch.Tensor, num_physical_experts: int, num_groups: int, num_nodes: int, num_gpus: int): + + num_layers, num_logical_experts = weight.shape + assert num_logical_experts % num_groups == 0 + group_size = num_logical_experts // num_groups + assert num_groups % num_nodes == 0 + groups_per_node = num_groups // num_nodes + assert num_gpus % num_nodes == 0 + assert num_physical_experts % num_gpus == 0 + phy_experts_per_gpu = num_physical_experts // num_gpus + + def inverse(perm: torch.Tensor) -> torch.Tensor: + inv = torch.empty_like(perm) + inv.scatter_(1, perm, torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(perm.shape)) + return inv + + tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) + group_pack_index, group_rank_in_pack = self.balanced_packing(tokens_per_group, num_nodes) + log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) * group_size).unsqueeze(-1) + + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)).flatten(-2) + mlog2log = inverse(log2mlog) + tokens_per_mlog = weight.gather(-1, mlog2log).view(-1, num_logical_experts // num_nodes) + phy2mlog, phyrank, mlogcnt = self.replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes) + tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) + pack_index, rank_in_pack = self.balanced_packing(tokens_per_phy, num_gpus // num_nodes) + phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack + pphy2phy = inverse(phy2pphy) + pphy2mlog = phy2mlog.gather(-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes] + pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + + torch.arange(0, num_logical_experts, num_logical_experts // num_nodes).view(1, -1, 1)).flatten(-2) + pphy2log = mlog2log.gather(-1, pphy2mlog) + pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) + logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) + return pphy2log, pphyrank, logcnt + + def rebalance_experts(self, weight: torch.Tensor, num_replicas: int, num_groups: int, + num_nodes: int, num_gpus: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + num_layers, num_logical_experts = weight.shape + weight = weight.float().cpu() + if num_groups % num_nodes == 0: + # use hierarchical load-balance policy + phy2log, phyrank, logcnt = self.rebalance_experts_hierarchical(weight, num_replicas, + num_groups, num_nodes, num_gpus) + else: + # use global load-balance policy + phy2log, phyrank, logcnt = self.replicate_experts(weight, num_replicas) + maxlogcnt = logcnt.max().item() + log2phy: torch.Tensor = torch.full((num_layers, num_logical_experts, maxlogcnt), + -1, dtype=torch.int64, device=logcnt.device) + log2phy.view(num_layers, -1).scatter_(-1, phy2log * maxlogcnt + phyrank, + torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(num_layers, -1)) + return phy2log, log2phy, logcnt + + def _load_ep_mapping(self, json_path: str): + """Load expert partition metadata from JSON (class-internal).""" + import json + with open(json_path, 'r') as f: + data = json.load(f) + num_groups = data['num_groups'] + num_nodes = data['num_nodes'] + weight = torch.tensor(data['weight'], dtype=torch.float32, device='cuda') + return num_groups, num_nodes, weight + + def ep_expert_list(self, world_size: int, rank: int): + """experts list of current rank.""" + if enable_eplb: + num_groups, num_nodes, 
weight = self._load_ep_mapping("ep_mapping_json_decode.json")
+            phy2log, log2phy, logcnt = self.rebalance_experts(weight, num_experts, num_groups, num_nodes, world_size)
+            self.phy2log = phy2log[0].to('cuda')
+            self.log2phy = log2phy[0].to('cuda')
+            self.logcnt = logcnt[0].to('cuda')
+            expert_per_rank = (self.num_experts + world_size - 1) // world_size
+            first_expert = rank * expert_per_rank
+            last_expert = min(first_expert + expert_per_rank, self.num_experts)
+            sliced_phy2log = self.phy2log[first_expert:last_expert].tolist()
+            return sliced_phy2log
+        else:
+            num_experts = self.num_experts
+            expert_per_rank = (num_experts + world_size - 1) // world_size
+            first_expert = rank * expert_per_rank
+            last_expert = min(first_expert + expert_per_rank, num_experts)
+            return list(range(first_expert, last_expert))
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 topk_weights: torch.Tensor,
@@ -516,6 +783,10 @@ def forward(self,
                 down_scale: torch.Tensor,
                 expert_list: List[int] = None):
         """forward."""
+        if enable_eplb:
+            topk_ids = map_logic_to_physical_idx_hash_random(topk_ids, self.log2phy, self.logcnt)
+        else:
+            topk_ids = topk_ids
         recv_hidden_states, topk_idx, topk_weights, masked_m, expected_m = self.token_dispatcher.dispatch(
             hidden_states,
             topk_ids,
diff --git a/lmdeploy/pytorch/kernels/cuda/ep_moe.py b/lmdeploy/pytorch/kernels/cuda/ep_moe.py
index ad620ead1c..ef1acfe144 100644
--- a/lmdeploy/pytorch/kernels/cuda/ep_moe.py
+++ b/lmdeploy/pytorch/kernels/cuda/ep_moe.py
@@ -361,3 +361,60 @@ def silu_and_mul_masked_post_quant_fwd(
         NUM_STAGE=NUM_STAGES,
         num_warps=num_warps,
     )
+
+@triton.jit
+def kernel_map_logic_to_physical_hash(topk_idx_ptr, physical_idx_ptr, log2phy_ptr, logcnt_ptr, seed, num_tokens,
+                                      num_topk, num_logical_experts, max_replica, BLOCK: tl.constexpr):
+    pid = tl.program_id(0)
+    start_t = pid * BLOCK
+    end_t = tl.minimum(start_t + BLOCK, num_tokens)
+
+    for t in range(start_t, end_t):
+        row_off = t * num_topk
+        for k in range(num_topk):
+            logic_exp = tl.load(topk_idx_ptr + row_off + k)
+
+            # continue is not usable in this conditional, so the branches are nested instead
+            if logic_exp >= 0:
+                cnt = tl.load(logcnt_ptr + logic_exp)
+                if cnt > 0:
+                    # build the hash value
+                    combined = ((t << 16) ^ (k << 8) ^ seed) & 0xFFFFFFFF
+                    x = combined
+                    x = ((x >> 16) ^ x) * 0x45d9f3b
+                    x = ((x >> 16) ^ x) * 0x45d9f3b
+                    x = (x >> 16) ^ x
+                    rand_val = x & 0x7fffffff
+
+                    replica_id = rand_val % cnt.to(tl.uint32)
+                    phy_id_addr = log2phy_ptr + logic_exp * max_replica + replica_id
+                    phy_id = tl.load(phy_id_addr)
+                    tl.store(physical_idx_ptr + row_off + k, phy_id)
+                else:
+                    tl.store(physical_idx_ptr + row_off + k, -1)
+            else:
+                tl.store(physical_idx_ptr + row_off + k, -1)
+
+
+def map_logic_to_physical_idx_hash_random(topk_idx: torch.Tensor,
+                                          log2phy: torch.Tensor,
+                                          logcnt: torch.Tensor,
+                                          seed: int = 12345,
+                                          block_size: int = 128) -> torch.Tensor:
+    num_tokens, num_topk = topk_idx.shape
+    physical_idx = torch.empty_like(topk_idx, dtype=torch.int32, device=topk_idx.device)
+
+    grid = ((num_tokens + block_size - 1) // block_size, )
+
+    kernel_map_logic_to_physical_hash[grid](topk_idx,
+                                            physical_idx,
+                                            log2phy,
+                                            logcnt,
+                                            seed,
+                                            num_tokens,
+                                            num_topk,
+                                            log2phy.shape[0],
+                                            log2phy.shape[1],
+                                            BLOCK=block_size)
+
+    return physical_idx
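The Triton kernel added in the patch above resolves each routed logical expert to one of its physical replicas by hashing the (token index, top-k slot, seed) triple, so replicated experts receive roughly even traffic without extra routing state. A minimal pure-PyTorch sketch of that selection logic follows for reference; the helper name map_logic_to_physical_reference is illustrative only, it assumes log2phy has shape [num_logical_experts, max_replica] padded with -1 and logcnt holds the replica count per logical expert, and it is not guaranteed bit-exact with the kernel (the kernel relies on 32-bit integer wrap-around, while this sketch masks explicitly).

import torch


def map_logic_to_physical_reference(topk_idx: torch.Tensor,
                                    log2phy: torch.Tensor,
                                    logcnt: torch.Tensor,
                                    seed: int = 12345) -> torch.Tensor:
    """Pick one physical replica for every (token, slot) routed logical expert."""
    num_tokens, num_topk = topk_idx.shape
    physical_idx = torch.full_like(topk_idx, -1)
    for t in range(num_tokens):
        for k in range(num_topk):
            logic_exp = int(topk_idx[t, k])
            if logic_exp < 0:
                continue
            cnt = int(logcnt[logic_exp])
            if cnt <= 0:
                continue
            # mix (token index, top-k slot, seed) into a pseudo-random 32-bit value
            x = ((t << 16) ^ (k << 8) ^ seed) & 0xFFFFFFFF
            x = (((x >> 16) ^ x) * 0x45d9f3b) & 0xFFFFFFFF
            x = (((x >> 16) ^ x) * 0x45d9f3b) & 0xFFFFFFFF
            x = (x >> 16) ^ x
            replica_id = (x & 0x7FFFFFFF) % cnt
            # look up the chosen replica's physical expert id in the per-layer table
            physical_idx[t, k] = int(log2phy[logic_exp, replica_id])
    return physical_idx

With a toy table such as log2phy = torch.tensor([[0, 4], [1, -1]]) and logcnt = torch.tensor([2, 1]), logical expert 0 maps to physical expert 0 or its replica 4 depending on the hash, while logical expert 1 always maps to physical expert 1.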
From 112e0aa795eba6bf5f2e454a7e0648b01351e667 Mon Sep 17 00:00:00 2001
From: Ychoukewen <2369589402@qq.com>
Date: Thu, 8 May 2025 07:34:06 +0000
Subject: [PATCH 08/12] add layer_idx to FusedMoENormal class and ep_expert_list()

---
 ep_mapping_json_gen.py                | 19 +++++++++++++------
 lmdeploy/pytorch/backends/cuda/moe.py | 26 ++++++++++++++++----------
 2 files changed, 29 insertions(+), 16 deletions(-)

diff --git a/ep_mapping_json_gen.py b/ep_mapping_json_gen.py
index 411524b131..7982cdf807 100644
--- a/ep_mapping_json_gen.py
+++ b/ep_mapping_json_gen.py
@@ -1,18 +1,25 @@
 import numpy as np
 import json
 
-num_experts = 12
+# configuration
+num_experts = 256
+layers = 12  # number of layers
 np.random.seed(42)  # keep the result reproducible
 
-# random integer load: each expert handles 80~160 tokens
-weight = np.random.randint(low=80, high=161, size=num_experts).tolist()
+# randomly generated weight of shape [layers, num_experts]; each expert handles 80~160 tokens
+weight = np.random.randint(low=80, high=161, size=(layers, num_experts)).tolist()
 
+# build the data
 data = {
     "num_groups": 4,
     "num_nodes": 1,
-    "weight": weight
+    "weight": weight  # weight has shape [layers, num_experts]
 }
 
-with open("/nvme1/liudongyan/workspace/lmdeploy/ep_mapping_json_path_logicexp12.json", "w") as f:
+# write the JSON file
+output_path = "/nvme1/liudongyan/workspace/lmdeploy/ep_mapping_json_prefill.json"
+with open(output_path, "w") as f:
     json.dump(data, f, indent=2)
-print("JSON written, weight sum =", sum(weight))
+
+# print a summary
+print("JSON written, weight sum =", np.sum(weight))
\ No newline at end of file
diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py
index dedea5431e..4c7e37479b 100644
--- a/lmdeploy/pytorch/backends/cuda/moe.py
+++ b/lmdeploy/pytorch/backends/cuda/moe.py
@@ -21,6 +21,8 @@
                           FusedMoEW8A8Impl)
 
 logger = get_logger('lmdeploy')
+import os
+enable_eplb = os.environ.get('EPLB_ENABLED', '0') == '1'
 
 
 class TritonFusedMoEImpl(FusedMoEImpl):
@@ -407,7 +409,9 @@ def __init__(self,
                  num_experts: int,
                  hidden_dim: int,
                  block_size: int = 128,
-                 out_dtype: torch.dtype = torch.bfloat16):
+                 out_dtype: torch.dtype = torch.bfloat16,
+                 layer_idx: int = 0):
+        self.layer_idx = layer_idx
         self.experts = DeepEPExpertsGroupedGEMM(num_experts, ep_size, [block_size, block_size])
         self.token_dispatcher = TokenDispatcherBuilder.build(
             group=ep_group,
@@ -534,9 +538,9 @@ def ep_expert_list(self, world_size: int, rank: int, num_groups: int=None, num_n
         if enable_eplb:
             num_groups, num_nodes, weight = self._load_ep_mapping("ep_mapping_json_prefill.json")
             phy2log, log2phy, logcnt = self.rebalance_experts(weight, num_experts, num_groups, num_nodes, self.world_size)
-            self.phy2log = phy2log[0].to('cuda')
-            self.log2phy = log2phy[0].to('cuda')
-            self.logcnt = logcnt[0].to('cuda')
+            self.phy2log = phy2log[self.layer_idx].to('cuda')
+            self.log2phy = log2phy[self.layer_idx].to('cuda')
+            self.logcnt = logcnt[self.layer_idx].to('cuda')
             expert_per_rank = (self.num_experts + world_size - 1) // world_size
             first_expert = rank * expert_per_rank
             last_expert = min(first_expert + expert_per_rank, self.num_experts)
@@ -632,7 +636,9 @@ def __init__(self,
                  num_experts: int,
                  hidden_dim: int,
                  block_size: int = 128,
-                 out_dtype: torch.dtype = torch.bfloat16):
+                 out_dtype: torch.dtype = torch.bfloat16,
+                 layer_idx: int = 0):
+        self.layer_idx = layer_idx
         self.num_experts = num_experts
         self.experts = DeepEPExpertsDeepGEMM(num_experts, ep_size, block_size, out_dtype)
         self.token_dispatcher = DeepEPTokenDispatcherLowLatency(
@@ -758,9 +764,9 @@ def ep_expert_list(self, world_size: int, rank: int):
         if enable_eplb:
             num_groups, num_nodes, weight = self._load_ep_mapping("ep_mapping_json_decode.json")
             phy2log, log2phy, logcnt = self.rebalance_experts(weight, num_experts, num_groups, num_nodes, world_size)
-            self.phy2log = phy2log[0].to('cuda')
-            self.log2phy = log2phy[0].to('cuda')
-            self.logcnt = logcnt[0].to('cuda')
+            self.phy2log = phy2log[self.layer_idx].to('cuda')
+            self.log2phy
= log2phy[self.layer_idx].to('cuda') + self.logcnt = logcnt[self.layer_idx].to('cuda') expert_per_rank = (self.num_experts + world_size - 1) // world_size first_expert = rank * expert_per_rank last_expert = min(first_expert + expert_per_rank, self.num_experts) @@ -904,10 +910,10 @@ def fusedmoe_build(self, low_latency_mode: bool = False): chunk_size=16 * 1024) elif low_latency_mode: return FusedMoELowLatency(self.ep_size, self.ep_group, self.num_experts, self.hidden_dim, self.block_size, - self.out_dtype) + self.out_dtype, layer_idx=self.layer_idx) else: return FusedMoENormal(self.ep_size, self.ep_group, self.num_experts, self.hidden_dim, self.block_size, - self.out_dtype) + self.out_dtype, layer_idx=self.layer_idx) class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder): From 983bf97c21543dfa0ac9bd58e3e3715021014935 Mon Sep 17 00:00:00 2001 From: Ychoukewen <2369589402@qq.com> Date: Thu, 8 May 2025 08:07:27 +0000 Subject: [PATCH 09/12] generate eplb.json --- ep_mapping_json_decode.json | 3102 +++++++++++++++++ ep_mapping_json_gen.py | 2 +- ep_mapping_json_path.json | 262 -- ep_mapping_json_prefill.json | 3102 +++++++++++++++++ lmdeploy/pytorch/backends/cuda/moe.py | 3 + lmdeploy/pytorch/models/deepseek_v2.py | 1 + .../weight_loader/model_weight_loader.py | 24 +- 7 files changed, 6221 insertions(+), 275 deletions(-) create mode 100644 ep_mapping_json_decode.json delete mode 100644 ep_mapping_json_path.json create mode 100644 ep_mapping_json_prefill.json diff --git a/ep_mapping_json_decode.json b/ep_mapping_json_decode.json new file mode 100644 index 0000000000..6bd9f883de --- /dev/null +++ b/ep_mapping_json_decode.json @@ -0,0 +1,3102 @@ +{ + "num_groups": 4, + "num_nodes": 1, + "weight": [ + [ + 131, + 94, + 151, + 140, + 100, + 154, + 154, + 103, + 82, + 101, + 132, + 81, + 109, + 117, + 81, + 143, + 139, + 100, + 112, + 155, + 137, + 101, + 128, + 138, + 121, + 139, + 159, + 94, + 141, + 141, + 126, + 141, + 130, + 134, + 143, + 82, + 130, + 86, + 100, + 152, + 118, + 97, + 83, + 139, + 93, + 88, + 132, + 81, + 139, + 150, + 123, + 87, + 126, + 114, + 157, + 160, + 115, + 129, + 83, + 81, + 85, + 133, + 83, + 133, + 142, + 97, + 123, + 113, + 153, + 141, + 93, + 127, + 94, + 151, + 157, + 141, + 119, + 159, + 132, + 103, + 105, + 139, + 120, + 108, + 94, + 124, + 144, + 150, + 88, + 80, + 87, + 142, + 90, + 160, + 87, + 114, + 114, + 112, + 84, + 120, + 107, + 86, + 152, + 151, + 91, + 113, + 112, + 127, + 102, + 141, + 116, + 123, + 114, + 144, + 126, + 157, + 82, + 80, + 84, + 93, + 106, + 88, + 158, + 94, + 121, + 156, + 130, + 142, + 131, + 83, + 102, + 94, + 122, + 108, + 115, + 92, + 111, + 150, + 138, + 107, + 145, + 121, + 124, + 141, + 136, + 85, + 107, + 107, + 123, + 109, + 141, + 154, + 141, + 80, + 106, + 141, + 156, + 82, + 149, + 151, + 106, + 88, + 141, + 116, + 130, + 123, + 103, + 158, + 138, + 111, + 131, + 141, + 137, + 131, + 91, + 118, + 81, + 82, + 135, + 160, + 138, + 81, + 81, + 133, + 80, + 98, + 81, + 132, + 123, + 111, + 149, + 111, + 147, + 134, + 154, + 135, + 96, + 117, + 103, + 148, + 149, + 90, + 95, + 152, + 138, + 149, + 159, + 82, + 99, + 138, + 115, + 98, + 146, + 98, + 99, + 150, + 131, + 112, + 119, + 118, + 80, + 90, + 136, + 129, + 102, + 110, + 121, + 86, + 95, + 139, + 81, + 80, + 127, + 91, + 148, + 116, + 111, + 88, + 98, + 127, + 159, + 82, + 99, + 103, + 133, + 112, + 103, + 154, + 151, + 115, + 117, + 104, + 97, + 145, + 133, + 114 + ], + [ + 159, + 140, + 120, + 112, + 147, + 112, + 93, + 100, + 127, + 99, + 87, + 86, + 146, + 96, 
+ 112, + 127, + 155, + 138, + 101, + 109, + 117, + 130, + 133, + 87, + 106, + 106, + 100, + 109, + 107, + 143, + 148, + 140, + 127, + 98, + 83, + 114, + 143, + 128, + 96, + 123, + 109, + 125, + 85, + 116, + 103, + 125, + 132, + 139, + 142, + 111, + 112, + 146, + 97, + 104, + 133, + 137, + 146, + 125, + 103, + 111, + 126, + 102, + 145, + 106, + 81, + 96, + 112, + 88, + 122, + 127, + 118, + 121, + 105, + 129, + 104, + 103, + 92, + 139, + 86, + 136, + 115, + 124, + 99, + 144, + 87, + 95, + 93, + 155, + 94, + 145, + 111, + 142, + 130, + 104, + 137, + 142, + 141, + 101, + 137, + 137, + 128, + 131, + 121, + 149, + 94, + 133, + 139, + 87, + 132, + 139, + 84, + 147, + 85, + 126, + 134, + 119, + 131, + 95, + 92, + 109, + 98, + 96, + 142, + 98, + 137, + 134, + 141, + 102, + 88, + 91, + 80, + 137, + 80, + 113, + 127, + 80, + 95, + 140, + 143, + 142, + 148, + 101, + 146, + 155, + 105, + 95, + 130, + 136, + 108, + 157, + 148, + 126, + 141, + 148, + 155, + 95, + 127, + 118, + 112, + 102, + 89, + 148, + 113, + 131, + 89, + 98, + 137, + 80, + 148, + 83, + 95, + 103, + 159, + 81, + 111, + 103, + 91, + 129, + 114, + 112, + 112, + 140, + 130, + 122, + 91, + 146, + 144, + 112, + 119, + 153, + 122, + 123, + 108, + 92, + 91, + 125, + 81, + 114, + 160, + 87, + 105, + 153, + 113, + 86, + 147, + 137, + 154, + 108, + 115, + 100, + 115, + 89, + 152, + 103, + 143, + 128, + 115, + 103, + 102, + 141, + 116, + 91, + 134, + 92, + 102, + 109, + 96, + 141, + 92, + 138, + 98, + 128, + 91, + 140, + 98, + 155, + 88, + 150, + 107, + 157, + 131, + 95, + 148, + 91, + 104, + 131, + 132, + 102, + 95, + 136, + 118, + 132, + 121, + 137, + 118, + 93 + ], + [ + 84, + 114, + 154, + 97, + 155, + 88, + 153, + 137, + 96, + 86, + 125, + 92, + 119, + 121, + 88, + 129, + 106, + 145, + 84, + 108, + 116, + 117, + 87, + 144, + 96, + 150, + 124, + 83, + 115, + 149, + 110, + 98, + 140, + 133, + 118, + 153, + 98, + 118, + 146, + 124, + 92, + 137, + 99, + 151, + 140, + 118, + 80, + 82, + 156, + 141, + 142, + 104, + 135, + 112, + 117, + 85, + 137, + 123, + 124, + 111, + 124, + 140, + 126, + 100, + 159, + 154, + 115, + 98, + 99, + 136, + 97, + 126, + 128, + 93, + 94, + 110, + 80, + 133, + 82, + 95, + 136, + 154, + 91, + 153, + 95, + 151, + 155, + 103, + 107, + 87, + 115, + 87, + 137, + 139, + 129, + 107, + 120, + 143, + 106, + 142, + 96, + 152, + 112, + 156, + 108, + 92, + 125, + 114, + 85, + 148, + 126, + 104, + 145, + 89, + 135, + 109, + 84, + 112, + 144, + 97, + 128, + 90, + 105, + 142, + 138, + 106, + 128, + 156, + 112, + 80, + 100, + 134, + 85, + 160, + 148, + 84, + 82, + 132, + 102, + 132, + 116, + 153, + 153, + 96, + 157, + 152, + 80, + 130, + 124, + 156, + 83, + 141, + 144, + 111, + 113, + 151, + 118, + 105, + 113, + 133, + 82, + 129, + 91, + 144, + 133, + 84, + 136, + 96, + 126, + 102, + 158, + 93, + 145, + 154, + 130, + 117, + 143, + 117, + 129, + 109, + 158, + 130, + 142, + 131, + 117, + 158, + 109, + 130, + 160, + 84, + 108, + 83, + 89, + 135, + 96, + 153, + 96, + 148, + 113, + 85, + 132, + 145, + 156, + 122, + 154, + 102, + 134, + 159, + 154, + 95, + 87, + 83, + 83, + 135, + 104, + 146, + 146, + 106, + 111, + 129, + 140, + 130, + 98, + 100, + 84, + 121, + 140, + 101, + 100, + 149, + 80, + 84, + 91, + 125, + 113, + 128, + 157, + 124, + 106, + 152, + 105, + 126, + 135, + 142, + 127, + 140, + 160, + 105, + 115, + 80, + 87, + 131, + 158, + 126, + 135, + 93 + ], + [ + 107, + 157, + 81, + 105, + 93, + 138, + 135, + 86, + 82, + 102, + 97, + 117, + 94, + 143, + 107, + 153, + 118, + 136, + 96, + 123, + 104, + 96, + 92, + 104, + 147, + 89, + 146, 
+ 97, + 113, + 87, + 119, + 121, + 120, + 85, + 131, + 105, + 143, + 138, + 135, + 138, + 149, + 112, + 132, + 101, + 100, + 149, + 149, + 83, + 154, + 141, + 141, + 103, + 134, + 88, + 82, + 110, + 119, + 115, + 103, + 85, + 145, + 154, + 83, + 158, + 85, + 130, + 141, + 136, + 145, + 158, + 154, + 87, + 105, + 130, + 124, + 123, + 84, + 149, + 105, + 147, + 98, + 99, + 91, + 126, + 80, + 93, + 143, + 117, + 116, + 90, + 156, + 82, + 112, + 85, + 129, + 89, + 84, + 102, + 89, + 123, + 81, + 92, + 119, + 81, + 144, + 142, + 152, + 96, + 88, + 154, + 94, + 103, + 117, + 114, + 128, + 148, + 141, + 139, + 129, + 157, + 154, + 88, + 113, + 155, + 114, + 80, + 119, + 143, + 101, + 139, + 143, + 151, + 90, + 93, + 139, + 109, + 114, + 116, + 84, + 157, + 105, + 141, + 83, + 121, + 97, + 119, + 151, + 118, + 93, + 111, + 130, + 117, + 102, + 142, + 94, + 104, + 96, + 145, + 157, + 132, + 130, + 118, + 130, + 149, + 85, + 146, + 86, + 130, + 151, + 121, + 143, + 94, + 108, + 112, + 106, + 115, + 108, + 117, + 136, + 106, + 134, + 112, + 147, + 145, + 89, + 84, + 153, + 117, + 92, + 110, + 126, + 131, + 135, + 94, + 108, + 87, + 84, + 108, + 126, + 147, + 155, + 124, + 81, + 106, + 115, + 115, + 105, + 122, + 106, + 148, + 99, + 90, + 153, + 117, + 85, + 151, + 102, + 126, + 125, + 91, + 92, + 141, + 139, + 122, + 155, + 147, + 84, + 116, + 151, + 110, + 88, + 130, + 108, + 157, + 119, + 120, + 90, + 102, + 80, + 125, + 100, + 115, + 133, + 136, + 80, + 142, + 133, + 134, + 119, + 94, + 100, + 126, + 152, + 132, + 88, + 153 + ], + [ + 131, + 136, + 105, + 120, + 114, + 142, + 104, + 154, + 117, + 81, + 86, + 113, + 96, + 122, + 138, + 130, + 133, + 103, + 104, + 150, + 131, + 149, + 112, + 128, + 108, + 142, + 101, + 105, + 107, + 128, + 150, + 160, + 128, + 99, + 142, + 140, + 128, + 150, + 80, + 92, + 130, + 135, + 141, + 111, + 109, + 108, + 128, + 124, + 109, + 95, + 119, + 98, + 97, + 80, + 157, + 126, + 145, + 117, + 130, + 142, + 83, + 80, + 87, + 108, + 134, + 82, + 111, + 89, + 153, + 113, + 134, + 111, + 129, + 86, + 87, + 144, + 136, + 146, + 138, + 151, + 133, + 146, + 130, + 87, + 113, + 114, + 157, + 111, + 125, + 95, + 147, + 116, + 133, + 93, + 134, + 127, + 86, + 153, + 86, + 112, + 102, + 98, + 98, + 115, + 108, + 139, + 81, + 80, + 126, + 148, + 99, + 90, + 81, + 146, + 91, + 99, + 84, + 116, + 117, + 88, + 132, + 123, + 103, + 160, + 153, + 109, + 138, + 160, + 160, + 93, + 88, + 119, + 145, + 104, + 152, + 101, + 83, + 105, + 137, + 108, + 116, + 154, + 149, + 97, + 121, + 120, + 117, + 113, + 96, + 116, + 104, + 155, + 106, + 136, + 110, + 85, + 119, + 91, + 132, + 150, + 89, + 124, + 96, + 105, + 141, + 125, + 143, + 81, + 133, + 144, + 130, + 132, + 115, + 105, + 108, + 100, + 90, + 119, + 90, + 115, + 138, + 118, + 133, + 134, + 103, + 101, + 148, + 135, + 112, + 115, + 99, + 160, + 141, + 152, + 105, + 145, + 152, + 152, + 119, + 96, + 80, + 140, + 122, + 121, + 104, + 118, + 114, + 82, + 123, + 130, + 91, + 98, + 123, + 138, + 128, + 140, + 96, + 153, + 136, + 134, + 126, + 91, + 141, + 159, + 87, + 100, + 160, + 159, + 149, + 151, + 104, + 91, + 94, + 138, + 105, + 105, + 126, + 111, + 89, + 95, + 150, + 96, + 102, + 105, + 86, + 93, + 158, + 86, + 88, + 127, + 151, + 138, + 118, + 97, + 138, + 96 + ], + [ + 93, + 110, + 103, + 139, + 124, + 82, + 116, + 122, + 119, + 134, + 118, + 94, + 83, + 104, + 92, + 112, + 95, + 121, + 145, + 134, + 121, + 113, + 109, + 92, + 92, + 97, + 111, + 118, + 125, + 108, + 141, + 141, + 136, + 95, + 135, + 89, + 109, + 104, + 84, + 144, 
+ 128, + 82, + 124, + 93, + 109, + 147, + 97, + 141, + 116, + 104, + 127, + 144, + 132, + 158, + 152, + 94, + 128, + 147, + 91, + 138, + 116, + 140, + 122, + 149, + 118, + 135, + 142, + 125, + 90, + 141, + 156, + 104, + 150, + 131, + 83, + 138, + 151, + 99, + 142, + 133, + 153, + 136, + 120, + 82, + 85, + 84, + 84, + 133, + 126, + 128, + 88, + 99, + 140, + 114, + 129, + 141, + 96, + 82, + 111, + 92, + 152, + 147, + 93, + 90, + 135, + 146, + 108, + 88, + 146, + 104, + 155, + 121, + 88, + 158, + 119, + 104, + 84, + 90, + 138, + 158, + 117, + 151, + 84, + 127, + 97, + 149, + 116, + 139, + 127, + 156, + 107, + 148, + 156, + 160, + 119, + 124, + 141, + 137, + 146, + 135, + 119, + 119, + 156, + 112, + 85, + 98, + 159, + 119, + 138, + 138, + 127, + 120, + 92, + 124, + 132, + 121, + 142, + 106, + 117, + 139, + 86, + 84, + 124, + 137, + 110, + 102, + 121, + 89, + 111, + 132, + 109, + 118, + 110, + 81, + 157, + 156, + 91, + 118, + 145, + 92, + 125, + 106, + 96, + 139, + 143, + 114, + 149, + 109, + 150, + 153, + 142, + 107, + 139, + 158, + 116, + 101, + 147, + 99, + 147, + 143, + 96, + 91, + 101, + 96, + 89, + 144, + 101, + 129, + 93, + 143, + 126, + 108, + 155, + 115, + 131, + 155, + 86, + 108, + 90, + 112, + 85, + 158, + 111, + 130, + 154, + 137, + 105, + 125, + 108, + 83, + 160, + 92, + 118, + 94, + 108, + 108, + 154, + 111, + 81, + 126, + 157, + 101, + 154, + 106, + 146, + 160, + 135, + 100, + 126, + 159, + 138, + 85, + 98, + 136, + 148, + 82 + ], + [ + 80, + 137, + 117, + 94, + 113, + 140, + 114, + 143, + 104, + 103, + 91, + 97, + 94, + 159, + 106, + 151, + 139, + 125, + 96, + 132, + 87, + 158, + 106, + 125, + 129, + 150, + 82, + 143, + 123, + 132, + 103, + 139, + 82, + 99, + 125, + 94, + 104, + 157, + 150, + 96, + 88, + 117, + 127, + 139, + 142, + 81, + 151, + 90, + 143, + 156, + 160, + 151, + 100, + 97, + 143, + 117, + 111, + 90, + 124, + 112, + 120, + 87, + 90, + 130, + 120, + 96, + 155, + 125, + 111, + 158, + 159, + 133, + 99, + 112, + 153, + 119, + 111, + 81, + 112, + 156, + 109, + 112, + 156, + 105, + 101, + 145, + 117, + 125, + 94, + 123, + 128, + 154, + 140, + 146, + 85, + 88, + 85, + 152, + 111, + 120, + 155, + 87, + 151, + 129, + 141, + 141, + 86, + 154, + 159, + 83, + 85, + 137, + 101, + 105, + 82, + 120, + 139, + 93, + 154, + 91, + 91, + 92, + 104, + 124, + 98, + 134, + 124, + 87, + 132, + 134, + 111, + 130, + 123, + 149, + 97, + 101, + 116, + 135, + 138, + 82, + 107, + 153, + 114, + 140, + 133, + 142, + 158, + 86, + 155, + 128, + 110, + 137, + 140, + 103, + 99, + 116, + 83, + 99, + 139, + 103, + 129, + 128, + 130, + 115, + 145, + 158, + 113, + 115, + 85, + 118, + 121, + 137, + 91, + 82, + 97, + 119, + 100, + 94, + 105, + 87, + 121, + 123, + 93, + 157, + 81, + 116, + 148, + 117, + 120, + 141, + 139, + 93, + 81, + 149, + 96, + 107, + 145, + 114, + 84, + 144, + 119, + 155, + 135, + 115, + 124, + 102, + 126, + 110, + 90, + 89, + 146, + 93, + 159, + 153, + 89, + 155, + 80, + 84, + 99, + 87, + 122, + 155, + 80, + 136, + 143, + 107, + 116, + 126, + 155, + 116, + 128, + 123, + 135, + 107, + 158, + 85, + 140, + 99, + 157, + 81, + 152, + 116, + 99, + 105, + 108, + 80, + 114, + 149, + 138, + 148, + 112, + 130, + 157, + 91, + 99, + 132 + ], + [ + 122, + 129, + 108, + 136, + 118, + 135, + 122, + 129, + 102, + 123, + 117, + 102, + 124, + 131, + 159, + 129, + 87, + 80, + 87, + 108, + 84, + 126, + 113, + 144, + 134, + 122, + 153, + 109, + 96, + 105, + 87, + 126, + 155, + 151, + 125, + 147, + 112, + 80, + 98, + 110, + 159, + 93, + 132, + 127, + 148, + 141, + 145, + 143, + 149, + 142, + 152, + 112, + 
120, + 127, + 92, + 152, + 156, + 110, + 154, + 147, + 83, + 98, + 149, + 92, + 102, + 154, + 80, + 133, + 150, + 94, + 144, + 155, + 150, + 94, + 100, + 98, + 139, + 112, + 128, + 114, + 123, + 90, + 151, + 157, + 153, + 145, + 125, + 88, + 82, + 105, + 151, + 127, + 145, + 149, + 107, + 105, + 104, + 91, + 84, + 127, + 116, + 158, + 123, + 94, + 122, + 127, + 149, + 125, + 122, + 100, + 102, + 159, + 126, + 94, + 104, + 115, + 154, + 130, + 94, + 147, + 100, + 92, + 98, + 113, + 157, + 147, + 131, + 106, + 108, + 107, + 114, + 87, + 95, + 115, + 109, + 97, + 81, + 128, + 156, + 83, + 154, + 86, + 153, + 118, + 92, + 110, + 150, + 142, + 135, + 121, + 140, + 121, + 135, + 115, + 147, + 86, + 133, + 80, + 123, + 124, + 126, + 92, + 101, + 100, + 88, + 105, + 88, + 94, + 89, + 121, + 151, + 103, + 159, + 128, + 105, + 116, + 147, + 129, + 103, + 98, + 139, + 118, + 153, + 97, + 98, + 115, + 153, + 134, + 148, + 126, + 128, + 112, + 100, + 160, + 101, + 127, + 116, + 94, + 114, + 150, + 111, + 90, + 89, + 152, + 157, + 156, + 136, + 104, + 81, + 143, + 114, + 86, + 97, + 146, + 121, + 118, + 120, + 154, + 103, + 139, + 125, + 98, + 147, + 121, + 111, + 118, + 115, + 115, + 82, + 152, + 87, + 149, + 120, + 147, + 92, + 128, + 139, + 81, + 139, + 84, + 123, + 106, + 111, + 100, + 87, + 151, + 129, + 146, + 123, + 137, + 82, + 149, + 81, + 123, + 81, + 160 + ], + [ + 114, + 109, + 87, + 101, + 123, + 137, + 115, + 148, + 156, + 110, + 126, + 134, + 159, + 141, + 109, + 134, + 95, + 91, + 96, + 95, + 91, + 87, + 82, + 95, + 109, + 102, + 84, + 144, + 109, + 100, + 149, + 101, + 151, + 111, + 112, + 95, + 155, + 117, + 93, + 115, + 102, + 145, + 81, + 152, + 144, + 154, + 146, + 131, + 141, + 102, + 105, + 147, + 134, + 85, + 114, + 101, + 122, + 94, + 139, + 91, + 128, + 121, + 117, + 124, + 114, + 109, + 109, + 157, + 85, + 146, + 159, + 86, + 131, + 93, + 152, + 94, + 100, + 114, + 145, + 81, + 114, + 157, + 99, + 89, + 104, + 145, + 104, + 133, + 112, + 115, + 93, + 144, + 159, + 133, + 131, + 82, + 120, + 85, + 104, + 92, + 87, + 100, + 147, + 104, + 158, + 147, + 124, + 139, + 91, + 111, + 92, + 110, + 157, + 99, + 97, + 139, + 142, + 85, + 111, + 122, + 81, + 139, + 102, + 110, + 84, + 113, + 89, + 81, + 90, + 129, + 110, + 135, + 158, + 152, + 88, + 100, + 152, + 107, + 141, + 86, + 159, + 95, + 83, + 83, + 119, + 96, + 131, + 131, + 115, + 117, + 129, + 142, + 142, + 141, + 86, + 106, + 93, + 150, + 84, + 134, + 81, + 141, + 100, + 158, + 100, + 89, + 137, + 118, + 84, + 144, + 92, + 132, + 124, + 111, + 115, + 160, + 139, + 90, + 142, + 100, + 131, + 134, + 105, + 147, + 113, + 131, + 131, + 108, + 117, + 121, + 100, + 121, + 92, + 143, + 96, + 88, + 120, + 80, + 81, + 139, + 127, + 108, + 89, + 137, + 139, + 100, + 107, + 154, + 113, + 87, + 138, + 151, + 128, + 158, + 96, + 139, + 124, + 148, + 111, + 123, + 144, + 87, + 156, + 151, + 122, + 80, + 130, + 103, + 103, + 96, + 84, + 99, + 100, + 137, + 110, + 90, + 131, + 94, + 86, + 107, + 130, + 137, + 84, + 84, + 87, + 143, + 89, + 117, + 160, + 87, + 104, + 152, + 145, + 158, + 129, + 146 + ], + [ + 128, + 115, + 147, + 157, + 120, + 141, + 89, + 100, + 115, + 91, + 81, + 92, + 132, + 111, + 95, + 111, + 111, + 90, + 89, + 94, + 96, + 119, + 97, + 134, + 152, + 104, + 138, + 157, + 101, + 109, + 123, + 111, + 130, + 125, + 100, + 111, + 118, + 91, + 126, + 97, + 100, + 92, + 100, + 87, + 125, + 86, + 102, + 154, + 101, + 101, + 108, + 92, + 81, + 108, + 158, + 131, + 106, + 129, + 158, + 102, + 123, + 156, + 101, + 122, + 116, + 140, 
+ 135, + 82, + 142, + 98, + 102, + 108, + 109, + 118, + 104, + 87, + 137, + 95, + 151, + 80, + 84, + 82, + 144, + 83, + 143, + 89, + 100, + 125, + 106, + 145, + 120, + 92, + 88, + 115, + 137, + 133, + 91, + 160, + 157, + 93, + 113, + 132, + 99, + 154, + 95, + 149, + 135, + 137, + 83, + 83, + 99, + 89, + 103, + 105, + 116, + 133, + 100, + 153, + 117, + 125, + 83, + 139, + 136, + 124, + 99, + 96, + 150, + 83, + 93, + 82, + 159, + 137, + 134, + 83, + 103, + 113, + 109, + 98, + 110, + 139, + 139, + 98, + 128, + 142, + 101, + 103, + 155, + 109, + 137, + 110, + 143, + 98, + 127, + 124, + 134, + 134, + 97, + 142, + 126, + 154, + 136, + 124, + 112, + 96, + 149, + 83, + 99, + 126, + 94, + 149, + 126, + 157, + 154, + 88, + 95, + 148, + 140, + 89, + 129, + 82, + 156, + 149, + 160, + 130, + 98, + 121, + 114, + 125, + 105, + 148, + 143, + 116, + 128, + 100, + 153, + 143, + 84, + 149, + 83, + 118, + 114, + 80, + 103, + 136, + 149, + 120, + 133, + 125, + 139, + 139, + 136, + 110, + 103, + 115, + 109, + 103, + 109, + 86, + 114, + 134, + 93, + 89, + 159, + 95, + 133, + 156, + 149, + 123, + 152, + 124, + 135, + 136, + 100, + 147, + 129, + 112, + 86, + 149, + 94, + 111, + 154, + 99, + 94, + 127, + 150, + 100, + 121, + 105, + 93, + 136, + 116, + 137, + 100, + 148, + 82, + 116 + ], + [ + 128, + 98, + 91, + 140, + 120, + 112, + 129, + 147, + 150, + 112, + 160, + 103, + 140, + 140, + 107, + 137, + 122, + 120, + 118, + 128, + 96, + 149, + 149, + 144, + 153, + 84, + 81, + 101, + 136, + 145, + 153, + 95, + 130, + 118, + 155, + 119, + 117, + 139, + 138, + 89, + 143, + 142, + 128, + 143, + 111, + 86, + 100, + 99, + 105, + 147, + 94, + 158, + 93, + 106, + 119, + 137, + 107, + 105, + 142, + 145, + 104, + 82, + 151, + 123, + 121, + 85, + 118, + 131, + 102, + 129, + 122, + 98, + 148, + 130, + 90, + 91, + 153, + 87, + 103, + 120, + 115, + 151, + 134, + 88, + 126, + 94, + 127, + 121, + 139, + 149, + 91, + 81, + 86, + 159, + 82, + 82, + 106, + 119, + 112, + 80, + 87, + 111, + 127, + 83, + 149, + 98, + 111, + 88, + 91, + 90, + 125, + 95, + 80, + 142, + 100, + 109, + 92, + 113, + 83, + 107, + 138, + 106, + 148, + 104, + 82, + 147, + 104, + 99, + 113, + 105, + 84, + 140, + 88, + 129, + 85, + 139, + 123, + 90, + 145, + 126, + 152, + 82, + 155, + 109, + 100, + 107, + 141, + 112, + 114, + 131, + 113, + 120, + 130, + 126, + 137, + 154, + 136, + 102, + 127, + 101, + 96, + 148, + 90, + 94, + 114, + 112, + 95, + 113, + 145, + 119, + 107, + 126, + 118, + 149, + 152, + 99, + 80, + 151, + 142, + 146, + 83, + 137, + 104, + 126, + 117, + 94, + 142, + 100, + 136, + 132, + 125, + 93, + 155, + 96, + 118, + 88, + 156, + 158, + 85, + 102, + 138, + 125, + 117, + 93, + 151, + 141, + 114, + 98, + 109, + 157, + 104, + 119, + 93, + 105, + 101, + 127, + 130, + 82, + 121, + 98, + 152, + 107, + 124, + 150, + 157, + 142, + 152, + 98, + 126, + 93, + 117, + 132, + 140, + 115, + 119, + 160, + 109, + 88, + 143, + 112, + 131, + 93, + 154, + 132, + 138, + 157, + 120, + 147, + 147, + 87, + 81, + 120, + 134, + 95, + 141, + 127 + ], + [ + 109, + 91, + 114, + 90, + 143, + 129, + 100, + 148, + 127, + 99, + 146, + 123, + 112, + 152, + 140, + 88, + 147, + 114, + 160, + 144, + 87, + 124, + 126, + 119, + 155, + 133, + 155, + 125, + 126, + 94, + 98, + 105, + 131, + 101, + 100, + 130, + 117, + 99, + 108, + 130, + 83, + 134, + 150, + 107, + 83, + 101, + 155, + 124, + 146, + 145, + 113, + 134, + 80, + 151, + 117, + 85, + 104, + 96, + 140, + 91, + 85, + 106, + 87, + 144, + 136, + 119, + 93, + 114, + 127, + 135, + 128, + 92, + 140, + 128, + 125, + 141, + 87, + 122, + 99, + 
118, + 128, + 143, + 141, + 133, + 88, + 116, + 128, + 120, + 96, + 100, + 130, + 138, + 118, + 134, + 120, + 118, + 92, + 143, + 90, + 98, + 90, + 156, + 125, + 124, + 80, + 119, + 98, + 155, + 151, + 151, + 140, + 91, + 140, + 112, + 124, + 148, + 87, + 104, + 125, + 109, + 126, + 128, + 126, + 97, + 145, + 147, + 112, + 115, + 98, + 107, + 119, + 155, + 94, + 120, + 118, + 81, + 142, + 108, + 153, + 145, + 130, + 154, + 156, + 89, + 93, + 130, + 152, + 118, + 107, + 82, + 149, + 106, + 81, + 117, + 82, + 95, + 119, + 81, + 160, + 120, + 91, + 114, + 144, + 134, + 124, + 158, + 141, + 128, + 104, + 102, + 103, + 117, + 107, + 100, + 133, + 97, + 115, + 124, + 152, + 132, + 109, + 139, + 111, + 146, + 146, + 107, + 123, + 158, + 120, + 142, + 102, + 154, + 127, + 111, + 136, + 112, + 103, + 118, + 113, + 135, + 107, + 116, + 86, + 147, + 116, + 135, + 144, + 144, + 126, + 99, + 98, + 108, + 153, + 139, + 97, + 114, + 122, + 90, + 128, + 117, + 143, + 148, + 87, + 113, + 127, + 135, + 155, + 144, + 110, + 83, + 129, + 148, + 109, + 138, + 146, + 154, + 124, + 83, + 105, + 106, + 125, + 143, + 130, + 129, + 85, + 138, + 94, + 109, + 121, + 116, + 121, + 114, + 85, + 147, + 97, + 152 + ] + ] +} \ No newline at end of file diff --git a/ep_mapping_json_gen.py b/ep_mapping_json_gen.py index 7982cdf807..450ade3d30 100644 --- a/ep_mapping_json_gen.py +++ b/ep_mapping_json_gen.py @@ -17,7 +17,7 @@ } # 写入 JSON 文件 -output_path = "/nvme1/liudongyan/workspace/lmdeploy/ep_mapping_json_prefill.json" +output_path = "/opt/workspace/workspace/lmdeploy_internLM/lmdeploy/ep_mapping_json_decode.json" with open(output_path, "w") as f: json.dump(data, f, indent=2) diff --git a/ep_mapping_json_path.json b/ep_mapping_json_path.json deleted file mode 100644 index 2058c3fb20..0000000000 --- a/ep_mapping_json_path.json +++ /dev/null @@ -1,262 +0,0 @@ -{ - "num_groups": 4, - "num_nodes": 1, - "weight": [[ - 131, - 94, - 151, - 140, - 100, - 154, - 154, - 103, - 82, - 101, - 132, - 81, - 109, - 117, - 81, - 143, - 139, - 100, - 112, - 155, - 137, - 101, - 128, - 138, - 121, - 139, - 159, - 94, - 141, - 141, - 126, - 141, - 130, - 134, - 143, - 82, - 130, - 86, - 100, - 152, - 118, - 97, - 83, - 139, - 93, - 88, - 132, - 81, - 139, - 150, - 123, - 87, - 126, - 114, - 157, - 160, - 115, - 129, - 83, - 81, - 85, - 133, - 83, - 133, - 142, - 97, - 123, - 113, - 153, - 141, - 93, - 127, - 94, - 151, - 157, - 141, - 119, - 159, - 132, - 103, - 105, - 139, - 120, - 108, - 94, - 124, - 144, - 150, - 88, - 80, - 87, - 142, - 90, - 160, - 87, - 114, - 114, - 112, - 84, - 120, - 107, - 86, - 152, - 151, - 91, - 113, - 112, - 127, - 102, - 141, - 116, - 123, - 114, - 144, - 126, - 157, - 82, - 80, - 84, - 93, - 106, - 88, - 158, - 94, - 121, - 156, - 130, - 142, - 131, - 83, - 102, - 94, - 122, - 108, - 115, - 92, - 111, - 150, - 138, - 107, - 145, - 121, - 124, - 141, - 136, - 85, - 107, - 107, - 123, - 109, - 141, - 154, - 141, - 80, - 106, - 141, - 156, - 82, - 149, - 151, - 106, - 88, - 141, - 116, - 130, - 123, - 103, - 158, - 138, - 111, - 131, - 141, - 137, - 131, - 91, - 118, - 81, - 82, - 135, - 160, - 138, - 81, - 81, - 133, - 80, - 98, - 81, - 132, - 123, - 111, - 149, - 111, - 147, - 134, - 154, - 135, - 96, - 117, - 103, - 148, - 149, - 90, - 95, - 152, - 138, - 149, - 159, - 82, - 99, - 138, - 115, - 98, - 146, - 98, - 99, - 150, - 131, - 112, - 119, - 118, - 80, - 90, - 136, - 129, - 102, - 110, - 121, - 86, - 95, - 139, - 81, - 80, - 127, - 91, - 148, - 116, - 111, - 88, - 98, - 127, - 159, - 82, - 99, - 103, - 
133, - 112, - 103, - 154, - 151, - 115, - 117, - 104, - 97, - 145, - 133, - 114 - ]] -} \ No newline at end of file diff --git a/ep_mapping_json_prefill.json b/ep_mapping_json_prefill.json new file mode 100644 index 0000000000..6bd9f883de --- /dev/null +++ b/ep_mapping_json_prefill.json @@ -0,0 +1,3102 @@ +{ + "num_groups": 4, + "num_nodes": 1, + "weight": [ + [ + 131, + 94, + 151, + 140, + 100, + 154, + 154, + 103, + 82, + 101, + 132, + 81, + 109, + 117, + 81, + 143, + 139, + 100, + 112, + 155, + 137, + 101, + 128, + 138, + 121, + 139, + 159, + 94, + 141, + 141, + 126, + 141, + 130, + 134, + 143, + 82, + 130, + 86, + 100, + 152, + 118, + 97, + 83, + 139, + 93, + 88, + 132, + 81, + 139, + 150, + 123, + 87, + 126, + 114, + 157, + 160, + 115, + 129, + 83, + 81, + 85, + 133, + 83, + 133, + 142, + 97, + 123, + 113, + 153, + 141, + 93, + 127, + 94, + 151, + 157, + 141, + 119, + 159, + 132, + 103, + 105, + 139, + 120, + 108, + 94, + 124, + 144, + 150, + 88, + 80, + 87, + 142, + 90, + 160, + 87, + 114, + 114, + 112, + 84, + 120, + 107, + 86, + 152, + 151, + 91, + 113, + 112, + 127, + 102, + 141, + 116, + 123, + 114, + 144, + 126, + 157, + 82, + 80, + 84, + 93, + 106, + 88, + 158, + 94, + 121, + 156, + 130, + 142, + 131, + 83, + 102, + 94, + 122, + 108, + 115, + 92, + 111, + 150, + 138, + 107, + 145, + 121, + 124, + 141, + 136, + 85, + 107, + 107, + 123, + 109, + 141, + 154, + 141, + 80, + 106, + 141, + 156, + 82, + 149, + 151, + 106, + 88, + 141, + 116, + 130, + 123, + 103, + 158, + 138, + 111, + 131, + 141, + 137, + 131, + 91, + 118, + 81, + 82, + 135, + 160, + 138, + 81, + 81, + 133, + 80, + 98, + 81, + 132, + 123, + 111, + 149, + 111, + 147, + 134, + 154, + 135, + 96, + 117, + 103, + 148, + 149, + 90, + 95, + 152, + 138, + 149, + 159, + 82, + 99, + 138, + 115, + 98, + 146, + 98, + 99, + 150, + 131, + 112, + 119, + 118, + 80, + 90, + 136, + 129, + 102, + 110, + 121, + 86, + 95, + 139, + 81, + 80, + 127, + 91, + 148, + 116, + 111, + 88, + 98, + 127, + 159, + 82, + 99, + 103, + 133, + 112, + 103, + 154, + 151, + 115, + 117, + 104, + 97, + 145, + 133, + 114 + ], + [ + 159, + 140, + 120, + 112, + 147, + 112, + 93, + 100, + 127, + 99, + 87, + 86, + 146, + 96, + 112, + 127, + 155, + 138, + 101, + 109, + 117, + 130, + 133, + 87, + 106, + 106, + 100, + 109, + 107, + 143, + 148, + 140, + 127, + 98, + 83, + 114, + 143, + 128, + 96, + 123, + 109, + 125, + 85, + 116, + 103, + 125, + 132, + 139, + 142, + 111, + 112, + 146, + 97, + 104, + 133, + 137, + 146, + 125, + 103, + 111, + 126, + 102, + 145, + 106, + 81, + 96, + 112, + 88, + 122, + 127, + 118, + 121, + 105, + 129, + 104, + 103, + 92, + 139, + 86, + 136, + 115, + 124, + 99, + 144, + 87, + 95, + 93, + 155, + 94, + 145, + 111, + 142, + 130, + 104, + 137, + 142, + 141, + 101, + 137, + 137, + 128, + 131, + 121, + 149, + 94, + 133, + 139, + 87, + 132, + 139, + 84, + 147, + 85, + 126, + 134, + 119, + 131, + 95, + 92, + 109, + 98, + 96, + 142, + 98, + 137, + 134, + 141, + 102, + 88, + 91, + 80, + 137, + 80, + 113, + 127, + 80, + 95, + 140, + 143, + 142, + 148, + 101, + 146, + 155, + 105, + 95, + 130, + 136, + 108, + 157, + 148, + 126, + 141, + 148, + 155, + 95, + 127, + 118, + 112, + 102, + 89, + 148, + 113, + 131, + 89, + 98, + 137, + 80, + 148, + 83, + 95, + 103, + 159, + 81, + 111, + 103, + 91, + 129, + 114, + 112, + 112, + 140, + 130, + 122, + 91, + 146, + 144, + 112, + 119, + 153, + 122, + 123, + 108, + 92, + 91, + 125, + 81, + 114, + 160, + 87, + 105, + 153, + 113, + 86, + 147, + 137, + 154, + 108, + 115, + 100, + 115, + 89, + 152, + 103, + 143, + 128, 
+ 115, + 103, + 102, + 141, + 116, + 91, + 134, + 92, + 102, + 109, + 96, + 141, + 92, + 138, + 98, + 128, + 91, + 140, + 98, + 155, + 88, + 150, + 107, + 157, + 131, + 95, + 148, + 91, + 104, + 131, + 132, + 102, + 95, + 136, + 118, + 132, + 121, + 137, + 118, + 93 + ], + [ + 84, + 114, + 154, + 97, + 155, + 88, + 153, + 137, + 96, + 86, + 125, + 92, + 119, + 121, + 88, + 129, + 106, + 145, + 84, + 108, + 116, + 117, + 87, + 144, + 96, + 150, + 124, + 83, + 115, + 149, + 110, + 98, + 140, + 133, + 118, + 153, + 98, + 118, + 146, + 124, + 92, + 137, + 99, + 151, + 140, + 118, + 80, + 82, + 156, + 141, + 142, + 104, + 135, + 112, + 117, + 85, + 137, + 123, + 124, + 111, + 124, + 140, + 126, + 100, + 159, + 154, + 115, + 98, + 99, + 136, + 97, + 126, + 128, + 93, + 94, + 110, + 80, + 133, + 82, + 95, + 136, + 154, + 91, + 153, + 95, + 151, + 155, + 103, + 107, + 87, + 115, + 87, + 137, + 139, + 129, + 107, + 120, + 143, + 106, + 142, + 96, + 152, + 112, + 156, + 108, + 92, + 125, + 114, + 85, + 148, + 126, + 104, + 145, + 89, + 135, + 109, + 84, + 112, + 144, + 97, + 128, + 90, + 105, + 142, + 138, + 106, + 128, + 156, + 112, + 80, + 100, + 134, + 85, + 160, + 148, + 84, + 82, + 132, + 102, + 132, + 116, + 153, + 153, + 96, + 157, + 152, + 80, + 130, + 124, + 156, + 83, + 141, + 144, + 111, + 113, + 151, + 118, + 105, + 113, + 133, + 82, + 129, + 91, + 144, + 133, + 84, + 136, + 96, + 126, + 102, + 158, + 93, + 145, + 154, + 130, + 117, + 143, + 117, + 129, + 109, + 158, + 130, + 142, + 131, + 117, + 158, + 109, + 130, + 160, + 84, + 108, + 83, + 89, + 135, + 96, + 153, + 96, + 148, + 113, + 85, + 132, + 145, + 156, + 122, + 154, + 102, + 134, + 159, + 154, + 95, + 87, + 83, + 83, + 135, + 104, + 146, + 146, + 106, + 111, + 129, + 140, + 130, + 98, + 100, + 84, + 121, + 140, + 101, + 100, + 149, + 80, + 84, + 91, + 125, + 113, + 128, + 157, + 124, + 106, + 152, + 105, + 126, + 135, + 142, + 127, + 140, + 160, + 105, + 115, + 80, + 87, + 131, + 158, + 126, + 135, + 93 + ], + [ + 107, + 157, + 81, + 105, + 93, + 138, + 135, + 86, + 82, + 102, + 97, + 117, + 94, + 143, + 107, + 153, + 118, + 136, + 96, + 123, + 104, + 96, + 92, + 104, + 147, + 89, + 146, + 97, + 113, + 87, + 119, + 121, + 120, + 85, + 131, + 105, + 143, + 138, + 135, + 138, + 149, + 112, + 132, + 101, + 100, + 149, + 149, + 83, + 154, + 141, + 141, + 103, + 134, + 88, + 82, + 110, + 119, + 115, + 103, + 85, + 145, + 154, + 83, + 158, + 85, + 130, + 141, + 136, + 145, + 158, + 154, + 87, + 105, + 130, + 124, + 123, + 84, + 149, + 105, + 147, + 98, + 99, + 91, + 126, + 80, + 93, + 143, + 117, + 116, + 90, + 156, + 82, + 112, + 85, + 129, + 89, + 84, + 102, + 89, + 123, + 81, + 92, + 119, + 81, + 144, + 142, + 152, + 96, + 88, + 154, + 94, + 103, + 117, + 114, + 128, + 148, + 141, + 139, + 129, + 157, + 154, + 88, + 113, + 155, + 114, + 80, + 119, + 143, + 101, + 139, + 143, + 151, + 90, + 93, + 139, + 109, + 114, + 116, + 84, + 157, + 105, + 141, + 83, + 121, + 97, + 119, + 151, + 118, + 93, + 111, + 130, + 117, + 102, + 142, + 94, + 104, + 96, + 145, + 157, + 132, + 130, + 118, + 130, + 149, + 85, + 146, + 86, + 130, + 151, + 121, + 143, + 94, + 108, + 112, + 106, + 115, + 108, + 117, + 136, + 106, + 134, + 112, + 147, + 145, + 89, + 84, + 153, + 117, + 92, + 110, + 126, + 131, + 135, + 94, + 108, + 87, + 84, + 108, + 126, + 147, + 155, + 124, + 81, + 106, + 115, + 115, + 105, + 122, + 106, + 148, + 99, + 90, + 153, + 117, + 85, + 151, + 102, + 126, + 125, + 91, + 92, + 141, + 139, + 122, + 155, + 147, + 84, + 116, + 151, + 110, 
+ 88, + 130, + 108, + 157, + 119, + 120, + 90, + 102, + 80, + 125, + 100, + 115, + 133, + 136, + 80, + 142, + 133, + 134, + 119, + 94, + 100, + 126, + 152, + 132, + 88, + 153 + ], + [ + 131, + 136, + 105, + 120, + 114, + 142, + 104, + 154, + 117, + 81, + 86, + 113, + 96, + 122, + 138, + 130, + 133, + 103, + 104, + 150, + 131, + 149, + 112, + 128, + 108, + 142, + 101, + 105, + 107, + 128, + 150, + 160, + 128, + 99, + 142, + 140, + 128, + 150, + 80, + 92, + 130, + 135, + 141, + 111, + 109, + 108, + 128, + 124, + 109, + 95, + 119, + 98, + 97, + 80, + 157, + 126, + 145, + 117, + 130, + 142, + 83, + 80, + 87, + 108, + 134, + 82, + 111, + 89, + 153, + 113, + 134, + 111, + 129, + 86, + 87, + 144, + 136, + 146, + 138, + 151, + 133, + 146, + 130, + 87, + 113, + 114, + 157, + 111, + 125, + 95, + 147, + 116, + 133, + 93, + 134, + 127, + 86, + 153, + 86, + 112, + 102, + 98, + 98, + 115, + 108, + 139, + 81, + 80, + 126, + 148, + 99, + 90, + 81, + 146, + 91, + 99, + 84, + 116, + 117, + 88, + 132, + 123, + 103, + 160, + 153, + 109, + 138, + 160, + 160, + 93, + 88, + 119, + 145, + 104, + 152, + 101, + 83, + 105, + 137, + 108, + 116, + 154, + 149, + 97, + 121, + 120, + 117, + 113, + 96, + 116, + 104, + 155, + 106, + 136, + 110, + 85, + 119, + 91, + 132, + 150, + 89, + 124, + 96, + 105, + 141, + 125, + 143, + 81, + 133, + 144, + 130, + 132, + 115, + 105, + 108, + 100, + 90, + 119, + 90, + 115, + 138, + 118, + 133, + 134, + 103, + 101, + 148, + 135, + 112, + 115, + 99, + 160, + 141, + 152, + 105, + 145, + 152, + 152, + 119, + 96, + 80, + 140, + 122, + 121, + 104, + 118, + 114, + 82, + 123, + 130, + 91, + 98, + 123, + 138, + 128, + 140, + 96, + 153, + 136, + 134, + 126, + 91, + 141, + 159, + 87, + 100, + 160, + 159, + 149, + 151, + 104, + 91, + 94, + 138, + 105, + 105, + 126, + 111, + 89, + 95, + 150, + 96, + 102, + 105, + 86, + 93, + 158, + 86, + 88, + 127, + 151, + 138, + 118, + 97, + 138, + 96 + ], + [ + 93, + 110, + 103, + 139, + 124, + 82, + 116, + 122, + 119, + 134, + 118, + 94, + 83, + 104, + 92, + 112, + 95, + 121, + 145, + 134, + 121, + 113, + 109, + 92, + 92, + 97, + 111, + 118, + 125, + 108, + 141, + 141, + 136, + 95, + 135, + 89, + 109, + 104, + 84, + 144, + 128, + 82, + 124, + 93, + 109, + 147, + 97, + 141, + 116, + 104, + 127, + 144, + 132, + 158, + 152, + 94, + 128, + 147, + 91, + 138, + 116, + 140, + 122, + 149, + 118, + 135, + 142, + 125, + 90, + 141, + 156, + 104, + 150, + 131, + 83, + 138, + 151, + 99, + 142, + 133, + 153, + 136, + 120, + 82, + 85, + 84, + 84, + 133, + 126, + 128, + 88, + 99, + 140, + 114, + 129, + 141, + 96, + 82, + 111, + 92, + 152, + 147, + 93, + 90, + 135, + 146, + 108, + 88, + 146, + 104, + 155, + 121, + 88, + 158, + 119, + 104, + 84, + 90, + 138, + 158, + 117, + 151, + 84, + 127, + 97, + 149, + 116, + 139, + 127, + 156, + 107, + 148, + 156, + 160, + 119, + 124, + 141, + 137, + 146, + 135, + 119, + 119, + 156, + 112, + 85, + 98, + 159, + 119, + 138, + 138, + 127, + 120, + 92, + 124, + 132, + 121, + 142, + 106, + 117, + 139, + 86, + 84, + 124, + 137, + 110, + 102, + 121, + 89, + 111, + 132, + 109, + 118, + 110, + 81, + 157, + 156, + 91, + 118, + 145, + 92, + 125, + 106, + 96, + 139, + 143, + 114, + 149, + 109, + 150, + 153, + 142, + 107, + 139, + 158, + 116, + 101, + 147, + 99, + 147, + 143, + 96, + 91, + 101, + 96, + 89, + 144, + 101, + 129, + 93, + 143, + 126, + 108, + 155, + 115, + 131, + 155, + 86, + 108, + 90, + 112, + 85, + 158, + 111, + 130, + 154, + 137, + 105, + 125, + 108, + 83, + 160, + 92, + 118, + 94, + 108, + 108, + 154, + 111, + 81, + 126, + 157, + 101, + 
154, + 106, + 146, + 160, + 135, + 100, + 126, + 159, + 138, + 85, + 98, + 136, + 148, + 82 + ], + [ + 80, + 137, + 117, + 94, + 113, + 140, + 114, + 143, + 104, + 103, + 91, + 97, + 94, + 159, + 106, + 151, + 139, + 125, + 96, + 132, + 87, + 158, + 106, + 125, + 129, + 150, + 82, + 143, + 123, + 132, + 103, + 139, + 82, + 99, + 125, + 94, + 104, + 157, + 150, + 96, + 88, + 117, + 127, + 139, + 142, + 81, + 151, + 90, + 143, + 156, + 160, + 151, + 100, + 97, + 143, + 117, + 111, + 90, + 124, + 112, + 120, + 87, + 90, + 130, + 120, + 96, + 155, + 125, + 111, + 158, + 159, + 133, + 99, + 112, + 153, + 119, + 111, + 81, + 112, + 156, + 109, + 112, + 156, + 105, + 101, + 145, + 117, + 125, + 94, + 123, + 128, + 154, + 140, + 146, + 85, + 88, + 85, + 152, + 111, + 120, + 155, + 87, + 151, + 129, + 141, + 141, + 86, + 154, + 159, + 83, + 85, + 137, + 101, + 105, + 82, + 120, + 139, + 93, + 154, + 91, + 91, + 92, + 104, + 124, + 98, + 134, + 124, + 87, + 132, + 134, + 111, + 130, + 123, + 149, + 97, + 101, + 116, + 135, + 138, + 82, + 107, + 153, + 114, + 140, + 133, + 142, + 158, + 86, + 155, + 128, + 110, + 137, + 140, + 103, + 99, + 116, + 83, + 99, + 139, + 103, + 129, + 128, + 130, + 115, + 145, + 158, + 113, + 115, + 85, + 118, + 121, + 137, + 91, + 82, + 97, + 119, + 100, + 94, + 105, + 87, + 121, + 123, + 93, + 157, + 81, + 116, + 148, + 117, + 120, + 141, + 139, + 93, + 81, + 149, + 96, + 107, + 145, + 114, + 84, + 144, + 119, + 155, + 135, + 115, + 124, + 102, + 126, + 110, + 90, + 89, + 146, + 93, + 159, + 153, + 89, + 155, + 80, + 84, + 99, + 87, + 122, + 155, + 80, + 136, + 143, + 107, + 116, + 126, + 155, + 116, + 128, + 123, + 135, + 107, + 158, + 85, + 140, + 99, + 157, + 81, + 152, + 116, + 99, + 105, + 108, + 80, + 114, + 149, + 138, + 148, + 112, + 130, + 157, + 91, + 99, + 132 + ], + [ + 122, + 129, + 108, + 136, + 118, + 135, + 122, + 129, + 102, + 123, + 117, + 102, + 124, + 131, + 159, + 129, + 87, + 80, + 87, + 108, + 84, + 126, + 113, + 144, + 134, + 122, + 153, + 109, + 96, + 105, + 87, + 126, + 155, + 151, + 125, + 147, + 112, + 80, + 98, + 110, + 159, + 93, + 132, + 127, + 148, + 141, + 145, + 143, + 149, + 142, + 152, + 112, + 120, + 127, + 92, + 152, + 156, + 110, + 154, + 147, + 83, + 98, + 149, + 92, + 102, + 154, + 80, + 133, + 150, + 94, + 144, + 155, + 150, + 94, + 100, + 98, + 139, + 112, + 128, + 114, + 123, + 90, + 151, + 157, + 153, + 145, + 125, + 88, + 82, + 105, + 151, + 127, + 145, + 149, + 107, + 105, + 104, + 91, + 84, + 127, + 116, + 158, + 123, + 94, + 122, + 127, + 149, + 125, + 122, + 100, + 102, + 159, + 126, + 94, + 104, + 115, + 154, + 130, + 94, + 147, + 100, + 92, + 98, + 113, + 157, + 147, + 131, + 106, + 108, + 107, + 114, + 87, + 95, + 115, + 109, + 97, + 81, + 128, + 156, + 83, + 154, + 86, + 153, + 118, + 92, + 110, + 150, + 142, + 135, + 121, + 140, + 121, + 135, + 115, + 147, + 86, + 133, + 80, + 123, + 124, + 126, + 92, + 101, + 100, + 88, + 105, + 88, + 94, + 89, + 121, + 151, + 103, + 159, + 128, + 105, + 116, + 147, + 129, + 103, + 98, + 139, + 118, + 153, + 97, + 98, + 115, + 153, + 134, + 148, + 126, + 128, + 112, + 100, + 160, + 101, + 127, + 116, + 94, + 114, + 150, + 111, + 90, + 89, + 152, + 157, + 156, + 136, + 104, + 81, + 143, + 114, + 86, + 97, + 146, + 121, + 118, + 120, + 154, + 103, + 139, + 125, + 98, + 147, + 121, + 111, + 118, + 115, + 115, + 82, + 152, + 87, + 149, + 120, + 147, + 92, + 128, + 139, + 81, + 139, + 84, + 123, + 106, + 111, + 100, + 87, + 151, + 129, + 146, + 123, + 137, + 82, + 149, + 81, + 123, + 81, 
+ 160 + ], + [ + 114, + 109, + 87, + 101, + 123, + 137, + 115, + 148, + 156, + 110, + 126, + 134, + 159, + 141, + 109, + 134, + 95, + 91, + 96, + 95, + 91, + 87, + 82, + 95, + 109, + 102, + 84, + 144, + 109, + 100, + 149, + 101, + 151, + 111, + 112, + 95, + 155, + 117, + 93, + 115, + 102, + 145, + 81, + 152, + 144, + 154, + 146, + 131, + 141, + 102, + 105, + 147, + 134, + 85, + 114, + 101, + 122, + 94, + 139, + 91, + 128, + 121, + 117, + 124, + 114, + 109, + 109, + 157, + 85, + 146, + 159, + 86, + 131, + 93, + 152, + 94, + 100, + 114, + 145, + 81, + 114, + 157, + 99, + 89, + 104, + 145, + 104, + 133, + 112, + 115, + 93, + 144, + 159, + 133, + 131, + 82, + 120, + 85, + 104, + 92, + 87, + 100, + 147, + 104, + 158, + 147, + 124, + 139, + 91, + 111, + 92, + 110, + 157, + 99, + 97, + 139, + 142, + 85, + 111, + 122, + 81, + 139, + 102, + 110, + 84, + 113, + 89, + 81, + 90, + 129, + 110, + 135, + 158, + 152, + 88, + 100, + 152, + 107, + 141, + 86, + 159, + 95, + 83, + 83, + 119, + 96, + 131, + 131, + 115, + 117, + 129, + 142, + 142, + 141, + 86, + 106, + 93, + 150, + 84, + 134, + 81, + 141, + 100, + 158, + 100, + 89, + 137, + 118, + 84, + 144, + 92, + 132, + 124, + 111, + 115, + 160, + 139, + 90, + 142, + 100, + 131, + 134, + 105, + 147, + 113, + 131, + 131, + 108, + 117, + 121, + 100, + 121, + 92, + 143, + 96, + 88, + 120, + 80, + 81, + 139, + 127, + 108, + 89, + 137, + 139, + 100, + 107, + 154, + 113, + 87, + 138, + 151, + 128, + 158, + 96, + 139, + 124, + 148, + 111, + 123, + 144, + 87, + 156, + 151, + 122, + 80, + 130, + 103, + 103, + 96, + 84, + 99, + 100, + 137, + 110, + 90, + 131, + 94, + 86, + 107, + 130, + 137, + 84, + 84, + 87, + 143, + 89, + 117, + 160, + 87, + 104, + 152, + 145, + 158, + 129, + 146 + ], + [ + 128, + 115, + 147, + 157, + 120, + 141, + 89, + 100, + 115, + 91, + 81, + 92, + 132, + 111, + 95, + 111, + 111, + 90, + 89, + 94, + 96, + 119, + 97, + 134, + 152, + 104, + 138, + 157, + 101, + 109, + 123, + 111, + 130, + 125, + 100, + 111, + 118, + 91, + 126, + 97, + 100, + 92, + 100, + 87, + 125, + 86, + 102, + 154, + 101, + 101, + 108, + 92, + 81, + 108, + 158, + 131, + 106, + 129, + 158, + 102, + 123, + 156, + 101, + 122, + 116, + 140, + 135, + 82, + 142, + 98, + 102, + 108, + 109, + 118, + 104, + 87, + 137, + 95, + 151, + 80, + 84, + 82, + 144, + 83, + 143, + 89, + 100, + 125, + 106, + 145, + 120, + 92, + 88, + 115, + 137, + 133, + 91, + 160, + 157, + 93, + 113, + 132, + 99, + 154, + 95, + 149, + 135, + 137, + 83, + 83, + 99, + 89, + 103, + 105, + 116, + 133, + 100, + 153, + 117, + 125, + 83, + 139, + 136, + 124, + 99, + 96, + 150, + 83, + 93, + 82, + 159, + 137, + 134, + 83, + 103, + 113, + 109, + 98, + 110, + 139, + 139, + 98, + 128, + 142, + 101, + 103, + 155, + 109, + 137, + 110, + 143, + 98, + 127, + 124, + 134, + 134, + 97, + 142, + 126, + 154, + 136, + 124, + 112, + 96, + 149, + 83, + 99, + 126, + 94, + 149, + 126, + 157, + 154, + 88, + 95, + 148, + 140, + 89, + 129, + 82, + 156, + 149, + 160, + 130, + 98, + 121, + 114, + 125, + 105, + 148, + 143, + 116, + 128, + 100, + 153, + 143, + 84, + 149, + 83, + 118, + 114, + 80, + 103, + 136, + 149, + 120, + 133, + 125, + 139, + 139, + 136, + 110, + 103, + 115, + 109, + 103, + 109, + 86, + 114, + 134, + 93, + 89, + 159, + 95, + 133, + 156, + 149, + 123, + 152, + 124, + 135, + 136, + 100, + 147, + 129, + 112, + 86, + 149, + 94, + 111, + 154, + 99, + 94, + 127, + 150, + 100, + 121, + 105, + 93, + 136, + 116, + 137, + 100, + 148, + 82, + 116 + ], + [ + 128, + 98, + 91, + 140, + 120, + 112, + 129, + 147, + 150, + 112, + 160, + 103, 
+ 140, + 140, + 107, + 137, + 122, + 120, + 118, + 128, + 96, + 149, + 149, + 144, + 153, + 84, + 81, + 101, + 136, + 145, + 153, + 95, + 130, + 118, + 155, + 119, + 117, + 139, + 138, + 89, + 143, + 142, + 128, + 143, + 111, + 86, + 100, + 99, + 105, + 147, + 94, + 158, + 93, + 106, + 119, + 137, + 107, + 105, + 142, + 145, + 104, + 82, + 151, + 123, + 121, + 85, + 118, + 131, + 102, + 129, + 122, + 98, + 148, + 130, + 90, + 91, + 153, + 87, + 103, + 120, + 115, + 151, + 134, + 88, + 126, + 94, + 127, + 121, + 139, + 149, + 91, + 81, + 86, + 159, + 82, + 82, + 106, + 119, + 112, + 80, + 87, + 111, + 127, + 83, + 149, + 98, + 111, + 88, + 91, + 90, + 125, + 95, + 80, + 142, + 100, + 109, + 92, + 113, + 83, + 107, + 138, + 106, + 148, + 104, + 82, + 147, + 104, + 99, + 113, + 105, + 84, + 140, + 88, + 129, + 85, + 139, + 123, + 90, + 145, + 126, + 152, + 82, + 155, + 109, + 100, + 107, + 141, + 112, + 114, + 131, + 113, + 120, + 130, + 126, + 137, + 154, + 136, + 102, + 127, + 101, + 96, + 148, + 90, + 94, + 114, + 112, + 95, + 113, + 145, + 119, + 107, + 126, + 118, + 149, + 152, + 99, + 80, + 151, + 142, + 146, + 83, + 137, + 104, + 126, + 117, + 94, + 142, + 100, + 136, + 132, + 125, + 93, + 155, + 96, + 118, + 88, + 156, + 158, + 85, + 102, + 138, + 125, + 117, + 93, + 151, + 141, + 114, + 98, + 109, + 157, + 104, + 119, + 93, + 105, + 101, + 127, + 130, + 82, + 121, + 98, + 152, + 107, + 124, + 150, + 157, + 142, + 152, + 98, + 126, + 93, + 117, + 132, + 140, + 115, + 119, + 160, + 109, + 88, + 143, + 112, + 131, + 93, + 154, + 132, + 138, + 157, + 120, + 147, + 147, + 87, + 81, + 120, + 134, + 95, + 141, + 127 + ], + [ + 109, + 91, + 114, + 90, + 143, + 129, + 100, + 148, + 127, + 99, + 146, + 123, + 112, + 152, + 140, + 88, + 147, + 114, + 160, + 144, + 87, + 124, + 126, + 119, + 155, + 133, + 155, + 125, + 126, + 94, + 98, + 105, + 131, + 101, + 100, + 130, + 117, + 99, + 108, + 130, + 83, + 134, + 150, + 107, + 83, + 101, + 155, + 124, + 146, + 145, + 113, + 134, + 80, + 151, + 117, + 85, + 104, + 96, + 140, + 91, + 85, + 106, + 87, + 144, + 136, + 119, + 93, + 114, + 127, + 135, + 128, + 92, + 140, + 128, + 125, + 141, + 87, + 122, + 99, + 118, + 128, + 143, + 141, + 133, + 88, + 116, + 128, + 120, + 96, + 100, + 130, + 138, + 118, + 134, + 120, + 118, + 92, + 143, + 90, + 98, + 90, + 156, + 125, + 124, + 80, + 119, + 98, + 155, + 151, + 151, + 140, + 91, + 140, + 112, + 124, + 148, + 87, + 104, + 125, + 109, + 126, + 128, + 126, + 97, + 145, + 147, + 112, + 115, + 98, + 107, + 119, + 155, + 94, + 120, + 118, + 81, + 142, + 108, + 153, + 145, + 130, + 154, + 156, + 89, + 93, + 130, + 152, + 118, + 107, + 82, + 149, + 106, + 81, + 117, + 82, + 95, + 119, + 81, + 160, + 120, + 91, + 114, + 144, + 134, + 124, + 158, + 141, + 128, + 104, + 102, + 103, + 117, + 107, + 100, + 133, + 97, + 115, + 124, + 152, + 132, + 109, + 139, + 111, + 146, + 146, + 107, + 123, + 158, + 120, + 142, + 102, + 154, + 127, + 111, + 136, + 112, + 103, + 118, + 113, + 135, + 107, + 116, + 86, + 147, + 116, + 135, + 144, + 144, + 126, + 99, + 98, + 108, + 153, + 139, + 97, + 114, + 122, + 90, + 128, + 117, + 143, + 148, + 87, + 113, + 127, + 135, + 155, + 144, + 110, + 83, + 129, + 148, + 109, + 138, + 146, + 154, + 124, + 83, + 105, + 106, + 125, + 143, + 130, + 129, + 85, + 138, + 94, + 109, + 121, + 116, + 121, + 114, + 85, + 147, + 97, + 152 + ] + ] +} \ No newline at end of file diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py index 4c7e37479b..04dc658e9c 100644 
--- a/lmdeploy/pytorch/backends/cuda/moe.py +++ b/lmdeploy/pytorch/backends/cuda/moe.py @@ -545,6 +545,9 @@ def ep_expert_list(self, world_size: int, rank: int, num_groups: int=None, num_n first_expert = rank * expert_per_rank last_expert = min(first_expert + expert_per_rank, self.num_experts) sliced_phy2log = self.phy2log[first_expert:last_expert].tolist() + if rank == 0: + logger.info(f"ep_expert_list: {sliced_phy2log}") + logger.info(f"len_ep_expert_list: {len(sliced_phy2log)}") return sliced_phy2log else: diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py index 12f8cad5e5..e856fc8a19 100644 --- a/lmdeploy/pytorch/models/deepseek_v2.py +++ b/lmdeploy/pytorch/models/deepseek_v2.py @@ -1227,6 +1227,7 @@ def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_di param = params_dict[name] load_weight(param, loaded_weight) + def _load_weight_attention(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter], update_pe_mapping: List): """load weight attention.""" diff --git a/lmdeploy/pytorch/weight_loader/model_weight_loader.py b/lmdeploy/pytorch/weight_loader/model_weight_loader.py index e242cdbf58..e5d7a66cec 100644 --- a/lmdeploy/pytorch/weight_loader/model_weight_loader.py +++ b/lmdeploy/pytorch/weight_loader/model_weight_loader.py @@ -16,18 +16,18 @@ def load_weight(param: torch.nn.Parameter, loaded_weight: torch.Tensor, **kwargs): """load weight.""" - # expert_id = kwargs.get('expert_id', None) - # # for debug - # shard_id = kwargs.get('shard_id', '?') - # rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 - - # if expert_id is not None and hasattr(param, 'expert_list'): - # if expert_id not in param.expert_list: - # print(f"[Rank {rank}] 🔁 Skip Expert {expert_id} for param {param.shape}") - # return - # else: - # layer_idx = getattr(param, 'layer_idx', '?') - # print(f"[Rank {rank}] ✅ Load Expert {expert_id} for Layer {layer_idx} ({shard_id})") + expert_id = kwargs.get('expert_id', None) + # for debug + shard_id = kwargs.get('shard_id', '?') + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + + if expert_id is not None and hasattr(param, 'expert_list'): + if expert_id not in param.expert_list: + print(f"[Rank {rank}] 🔁 Skip Expert {expert_id} for param {param.shape}") + return + else: + layer_idx = getattr(param, 'layer_idx', '?') + print(f"[Rank {rank}] ✅ Load Expert {expert_id} for Layer {layer_idx} ({shard_id})") if hasattr(param, 'weight_loader'): param.weight_loader(param, loaded_weight, **kwargs) From 571e2e0e26a1e0efe4722f98c3c97a289f7e1dd6 Mon Sep 17 00:00:00 2001 From: Ychoukewen <2369589402@qq.com> Date: Thu, 8 May 2025 10:37:33 +0000 Subject: [PATCH 10/12] add expert_list() to FusedDeepEpMoEBlockedF8Impl class and run well for num_hidden_layer=4 --- lmdeploy/pytorch/backends/cuda/moe.py | 446 +++++++----------- lmdeploy/pytorch/models/deepseek_v2.py | 33 +- lmdeploy/pytorch/nn/moe.py | 6 +- .../weight_loader/model_weight_loader.py | 5 +- 4 files changed, 209 insertions(+), 281 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py index 04dc658e9c..42d3bf6a42 100644 --- a/lmdeploy/pytorch/backends/cuda/moe.py +++ b/lmdeploy/pytorch/backends/cuda/moe.py @@ -410,7 +410,9 @@ def __init__(self, hidden_dim: int, block_size: int = 128, out_dtype: torch.dtype = torch.bfloat16, - layer_idx: int = 0): + layer_idx: int = 0, + log2phy: torch.Tensor = None, + logcnt: torch.Tensor = 
None): self.layer_idx = layer_idx self.experts = DeepEPExpertsGroupedGEMM(num_experts, ep_size, [block_size, block_size]) self.token_dispatcher = TokenDispatcherBuilder.build( @@ -420,142 +422,8 @@ def __init__(self, hidden_size=hidden_dim, params_dtype=out_dtype, ) - - def balanced_packing(self, weight: torch.Tensor, num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]: - - num_layers, num_groups = weight.shape - assert num_groups % num_packs == 0 - groups_per_pack = num_groups // num_packs - - if groups_per_pack == 1: - pack_index = torch.arange(weight.size(-1), dtype=torch.int64, device=weight.device).expand(weight.shape) - rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) - return pack_index, rank_in_pack - - indices = weight.float().sort(-1, descending=True).indices.cpu() - pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device='cpu') - rank_in_pack = torch.full_like(pack_index, fill_value=-1) - for i in range(num_layers): - pack_weights = [0] * num_packs - pack_items = [0] * num_packs - for group in indices[i]: - pack = min((i for i in range(num_packs) if pack_items[i] < groups_per_pack), - key=pack_weights.__getitem__) - assert pack_items[pack] < groups_per_pack - pack_index[i, group] = pack - rank_in_pack[i, group] = pack_items[pack] - pack_weights[pack] += weight[i, group] - pack_items[pack] += 1 - return pack_index, rank_in_pack - - def replicate_experts(self, weight: torch.Tensor, num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - n, num_log = weight.shape - num_redundant = num_phy - num_log - assert num_redundant >= 0 - device = weight.device - phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) - rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) - logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) - arangen = torch.arange(n, dtype=torch.int64, device=device) - for i in range(num_log, num_phy): - redundant_indices = (weight / logcnt).max(dim=-1).indices - phy2log[:, i] = redundant_indices - rank[:, i] = logcnt[arangen, redundant_indices] - logcnt[arangen, redundant_indices] += 1 - return phy2log, rank, logcnt - - def rebalance_experts_hierarchical(self, weight: torch.Tensor, num_physical_experts: int, num_groups: int, - num_nodes: int, num_gpus: int): - - num_layers, num_logical_experts = weight.shape - assert num_logical_experts % num_groups == 0 - group_size = num_logical_experts // num_groups - assert num_groups % num_nodes == 0 - groups_per_node = num_groups // num_nodes - assert num_gpus % num_nodes == 0 - assert num_physical_experts % num_gpus == 0 - phy_experts_per_gpu = num_physical_experts // num_gpus - - def inverse(perm: torch.Tensor) -> torch.Tensor: - inv = torch.empty_like(perm) - inv.scatter_(1, perm, torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(perm.shape)) - return inv - - tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) - group_pack_index, group_rank_in_pack = self.balanced_packing(tokens_per_group, num_nodes) - log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) * group_size).unsqueeze(-1) + - torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)).flatten(-2) - mlog2log = inverse(log2mlog) - tokens_per_mlog = weight.gather(-1, mlog2log).view(-1, num_logical_experts // num_nodes) - phy2mlog, phyrank, mlogcnt = self.replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes) - tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) - 
pack_index, rank_in_pack = self.balanced_packing(tokens_per_phy, num_gpus // num_nodes) - phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack - pphy2phy = inverse(phy2pphy) - pphy2mlog = phy2mlog.gather(-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes] - pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + - torch.arange(0, num_logical_experts, num_logical_experts // num_nodes).view(1, -1, 1)).flatten(-2) - pphy2log = mlog2log.gather(-1, pphy2mlog) - pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) - logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) - return pphy2log, pphyrank, logcnt - - def rebalance_experts(self, weight: torch.Tensor, num_replicas: int, num_groups: int, num_nodes: int, - num_gpus: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - num_layers, num_logical_experts = weight.shape - weight = weight.float().cpu() - if num_groups % num_nodes == 0: - # use hierarchical load-balance policy - phy2log, phyrank, logcnt = self.rebalance_experts_hierarchical(weight, num_replicas, num_groups, num_nodes, - num_gpus) - else: - # use global load-balance policy - phy2log, phyrank, logcnt = self.replicate_experts(weight, num_replicas) - maxlogcnt = logcnt.max().item() - log2phy: torch.Tensor = torch.full((num_layers, num_logical_experts, maxlogcnt), - -1, - dtype=torch.int64, - device=logcnt.device) - log2phy.view(num_layers, -1).scatter_( - -1, phy2log * maxlogcnt + phyrank, - torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(num_layers, -1)) - return phy2log, log2phy, logcnt - - def _load_ep_mapping(self, json_path: str): - """Load expert partition metadata from JSON (class-internal).""" - import json - with open(json_path, 'r') as f: - data = json.load(f) - num_groups = data['num_groups'] - num_nodes = data['num_nodes'] - weight = torch.tensor(data['weight'], dtype=torch.float32, device='cuda') - return num_groups, num_nodes, weight - - def ep_expert_list(self, world_size: int, rank: int, num_groups: int=None, num_nodes: int=None, weight: torch.Tensor=None): - """experts list of current rank.""" - if enable_eplb: - num_groups, num_nodes, weight = self._load_ep_mapping("ep_mapping_json_prefill.json") - phy2log, log2phy, logcnt = self.rebalance_experts(weight, num_experts, num_groups, num_nodes, self.world_size) - self.phy2log = phy2log[self.layer_idx].to('cuda') - self.log2phy = log2phy[self.layer_idx].to('cuda') - self.logcnt = logcnt[self.layer_idx].to('cuda') - expert_per_rank = (self.num_experts + world_size - 1) // world_size - first_expert = rank * expert_per_rank - last_expert = min(first_expert + expert_per_rank, self.num_experts) - sliced_phy2log = self.phy2log[first_expert:last_expert].tolist() - if rank == 0: - logger.info(f"ep_expert_list: {sliced_phy2log}") - logger.info(f"len_ep_expert_list: {len(sliced_phy2log)}") - - return sliced_phy2log - else: - num_experts = self.num_experts - expert_per_rank = (num_experts + world_size - 1) // world_size - first_expert = rank * expert_per_rank - last_expert = min(first_expert + expert_per_rank, num_experts) - return list(range(first_expert, last_expert)) + self.log2phy = log2phy + self.logcnt = logcnt def forward(self, hidden_states: torch.Tensor, @@ -640,7 +508,9 @@ def __init__(self, hidden_dim: int, block_size: int = 128, out_dtype: torch.dtype = torch.bfloat16, - layer_idx: int = 0): + layer_idx: int = 0, + log2phy: torch.Tensor = None, + logcnt: torch.Tensor = None): self.layer_idx = layer_idx self.num_experts = num_experts self.experts = 
DeepEPExpertsDeepGEMM(num_experts, ep_size, block_size, out_dtype) @@ -651,136 +521,8 @@ def __init__(self, hidden_size=hidden_dim, params_dtype=out_dtype, ) - - def balanced_packing(self, weight: torch.Tensor, num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]: - - num_layers, num_groups = weight.shape - assert num_groups % num_packs == 0 - groups_per_pack = num_groups // num_packs - - if groups_per_pack == 1: - pack_index = torch.arange(weight.size(-1), dtype=torch.int64, device=weight.device).expand(weight.shape) - rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) - return pack_index, rank_in_pack - - indices = weight.float().sort(-1, descending=True).indices.cpu() - pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device='cpu') - rank_in_pack = torch.full_like(pack_index, fill_value=-1) - for i in range(num_layers): - pack_weights = [0] * num_packs - pack_items = [0] * num_packs - for group in indices[i]: - pack = min((i for i in range(num_packs) if pack_items[i] < groups_per_pack), - key=pack_weights.__getitem__) - assert pack_items[pack] < groups_per_pack - pack_index[i, group] = pack - rank_in_pack[i, group] = pack_items[pack] - pack_weights[pack] += weight[i, group] - pack_items[pack] += 1 - return pack_index, rank_in_pack - - - def replicate_experts(self, weight: torch.Tensor, num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - n, num_log = weight.shape - num_redundant = num_phy - num_log - assert num_redundant >= 0 - device = weight.device - phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) - rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) - logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) - arangen = torch.arange(n, dtype=torch.int64, device=device) - for i in range(num_log, num_phy): - redundant_indices = (weight / logcnt).max(dim=-1).indices - phy2log[:, i] = redundant_indices - rank[:, i] = logcnt[arangen, redundant_indices] - logcnt[arangen, redundant_indices] += 1 - return phy2log, rank, logcnt - - - def rebalance_experts_hierarchical(self, weight: torch.Tensor, num_physical_experts: int, num_groups: int, num_nodes: int, num_gpus: int): - - num_layers, num_logical_experts = weight.shape - assert num_logical_experts % num_groups == 0 - group_size = num_logical_experts // num_groups - assert num_groups % num_nodes == 0 - groups_per_node = num_groups // num_nodes - assert num_gpus % num_nodes == 0 - assert num_physical_experts % num_gpus == 0 - phy_experts_per_gpu = num_physical_experts // num_gpus - - def inverse(perm: torch.Tensor) -> torch.Tensor: - inv = torch.empty_like(perm) - inv.scatter_(1, perm, torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(perm.shape)) - return inv - - tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) - group_pack_index, group_rank_in_pack = self.balanced_packing(tokens_per_group, num_nodes) - log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) * group_size).unsqueeze(-1) + - torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)).flatten(-2) - mlog2log = inverse(log2mlog) - tokens_per_mlog = weight.gather(-1, mlog2log).view(-1, num_logical_experts // num_nodes) - phy2mlog, phyrank, mlogcnt = self.replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes) - tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) - pack_index, rank_in_pack = self.balanced_packing(tokens_per_phy, num_gpus // num_nodes) - phy2pphy = pack_index * 
phy_experts_per_gpu + rank_in_pack - pphy2phy = inverse(phy2pphy) - pphy2mlog = phy2mlog.gather(-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes] - pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + - torch.arange(0, num_logical_experts, num_logical_experts // num_nodes).view(1, -1, 1)).flatten(-2) - pphy2log = mlog2log.gather(-1, pphy2mlog) - pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) - logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) - return pphy2log, pphyrank, logcnt - - def rebalance_experts(self, weight: torch.Tensor, num_replicas: int, num_groups: int, - num_nodes: int, num_gpus: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - num_layers, num_logical_experts = weight.shape - weight = weight.float().cpu() - if num_groups % num_nodes == 0: - # use hierarchical load-balance policy - phy2log, phyrank, logcnt = self.rebalance_experts_hierarchical(weight, num_replicas, - num_groups, num_nodes, num_gpus) - else: - # use global load-balance policy - phy2log, phyrank, logcnt = self.replicate_experts(weight, num_replicas) - maxlogcnt = logcnt.max().item() - log2phy: torch.Tensor = torch.full((num_layers, num_logical_experts, maxlogcnt), - -1, dtype=torch.int64, device=logcnt.device) - log2phy.view(num_layers, -1).scatter_(-1, phy2log * maxlogcnt + phyrank, - torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(num_layers, -1)) - return phy2log, log2phy, logcnt - - def _load_ep_mapping(self, json_path: str): - """Load expert partition metadata from JSON (class-internal).""" - import json - with open(json_path, 'r') as f: - data = json.load(f) - num_groups = data['num_groups'] - num_nodes = data['num_nodes'] - weight = torch.tensor(data['weight'], dtype=torch.float32, device='cuda') - return num_groups, num_nodes, weight - - def ep_expert_list(self, world_size: int, rank: int): - """experts list of current rank.""" - if enable_eplb: - num_groups, num_nodes, weight = self._load_ep_mapping("ep_mapping_json_decode.json") - phy2log, log2phy, logcnt = self.rebalance_experts(weight, num_experts, num_groups, num_nodes, world_size) - self.phy2log = phy2log[self.layer_idx].to('cuda') - self.log2phy = log2phy[self.layer_idx].to('cuda') - self.logcnt = logcnt[self.layer_idx].to('cuda') - expert_per_rank = (self.num_experts + world_size - 1) // world_size - first_expert = rank * expert_per_rank - last_expert = min(first_expert + expert_per_rank, self.num_experts) - sliced_phy2log = self.phy2log[first_expert:last_expert].tolist() - return sliced_phy2log - else: - num_experts = self.num_experts - expert_per_rank = (num_experts + world_size - 1) // world_size - first_expert = rank * expert_per_rank - last_expert = min(first_expert + expert_per_rank, num_experts) - return list(range(first_expert, last_expert)) + self.log2phy = log2phy + self.logcnt = logcnt def forward(self, hidden_states: torch.Tensor, @@ -862,6 +604,9 @@ def __init__(self, self.block_size = block_size self.out_dtype = out_dtype self.layer_idx = layer_idx + self.log2phy = None + self.logcnt = None + self.phy2log = None try: import deep_gemm DeepEPExpertsDeepGEMM.deep_gemm = deep_gemm @@ -870,14 +615,17 @@ def __init__(self, self.use_deep_gemm = False logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM') - try: - from dlblas.layers.moe.ep_moe import build_deepep_moe - self.use_dlblas = True - self.build_deepep_moe = build_deepep_moe - except ImportError: - self.use_dlblas = False - logger.warning('For higher 
performance, please install dlBLAS https://github.com/DeepLink-org/dlBLAS') + # try: + # from dlblas.layers.moe.ep_moe import build_deepep_moe + # self.use_dlblas = True + # self.build_deepep_moe = build_deepep_moe + # except ImportError: + # self.use_dlblas = False + # logger.warning('For higher performance, please install dlBLAS https://github.com/DeepLink-org/dlBLAS') + self.use_dlblas = False + logger.warning('For higher performance, please install dlBLAS https://github.com/DeepLink-org/dlBLAS') + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, @@ -899,6 +647,150 @@ def forward(self, def do_renormalize(self, topk_weights): return _renormalize(topk_weights, self.renormalize) + def balanced_packing(self, weight: torch.Tensor, num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]: + + num_layers, num_groups = weight.shape + assert num_groups % num_packs == 0 + groups_per_pack = num_groups // num_packs + + if groups_per_pack == 1: + pack_index = torch.arange(weight.size(-1), dtype=torch.int64, device=weight.device).expand(weight.shape) + rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) + return pack_index, rank_in_pack + + indices = weight.float().sort(-1, descending=True).indices.cpu() + pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device='cpu') + rank_in_pack = torch.full_like(pack_index, fill_value=-1) + for i in range(num_layers): + pack_weights = [0] * num_packs + pack_items = [0] * num_packs + for group in indices[i]: + pack = min((i for i in range(num_packs) if pack_items[i] < groups_per_pack), + key=pack_weights.__getitem__) + assert pack_items[pack] < groups_per_pack + pack_index[i, group] = pack + rank_in_pack[i, group] = pack_items[pack] + pack_weights[pack] += weight[i, group] + pack_items[pack] += 1 + return pack_index, rank_in_pack + + def replicate_experts(self, weight: torch.Tensor, num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + n, num_log = weight.shape + num_redundant = num_phy - num_log + assert num_redundant >= 0 + device = weight.device + phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) + rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) + logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) + arangen = torch.arange(n, dtype=torch.int64, device=device) + for i in range(num_log, num_phy): + redundant_indices = (weight / logcnt).max(dim=-1).indices + phy2log[:, i] = redundant_indices + rank[:, i] = logcnt[arangen, redundant_indices] + logcnt[arangen, redundant_indices] += 1 + return phy2log, rank, logcnt + + def rebalance_experts_hierarchical(self, weight: torch.Tensor, num_physical_experts: int, num_groups: int, + num_nodes: int, num_gpus: int): + + num_layers, num_logical_experts = weight.shape + assert num_logical_experts % num_groups == 0 + group_size = num_logical_experts // num_groups + assert num_groups % num_nodes == 0 + groups_per_node = num_groups // num_nodes + assert num_gpus % num_nodes == 0 + assert num_physical_experts % num_gpus == 0 + phy_experts_per_gpu = num_physical_experts // num_gpus + + def inverse(perm: torch.Tensor) -> torch.Tensor: + inv = torch.empty_like(perm) + inv.scatter_(1, perm, torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(perm.shape)) + return inv + + tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) + group_pack_index, group_rank_in_pack = self.balanced_packing(tokens_per_group, num_nodes) + log2mlog = (((group_pack_index * groups_per_node + 
group_rank_in_pack) * group_size).unsqueeze(-1) + + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)).flatten(-2) + mlog2log = inverse(log2mlog) + tokens_per_mlog = weight.gather(-1, mlog2log).view(-1, num_logical_experts // num_nodes) + phy2mlog, phyrank, mlogcnt = self.replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes) + tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) + pack_index, rank_in_pack = self.balanced_packing(tokens_per_phy, num_gpus // num_nodes) + phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack + pphy2phy = inverse(phy2pphy) + pphy2mlog = phy2mlog.gather(-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes] + pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + + torch.arange(0, num_logical_experts, num_logical_experts // num_nodes).view(1, -1, 1)).flatten(-2) + pphy2log = mlog2log.gather(-1, pphy2mlog) + pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) + logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) + return pphy2log, pphyrank, logcnt + + def rebalance_experts(self, weight: torch.Tensor, num_replicas: int, num_groups: int, num_nodes: int, + num_gpus: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + num_layers, num_logical_experts = weight.shape + weight = weight.float().cpu() + if num_groups % num_nodes == 0: + # use hierarchical load-balance policy + phy2log, phyrank, logcnt = self.rebalance_experts_hierarchical(weight, num_replicas, num_groups, num_nodes, + num_gpus) + else: + # use global load-balance policy + phy2log, phyrank, logcnt = self.replicate_experts(weight, num_replicas) + maxlogcnt = logcnt.max().item() + log2phy: torch.Tensor = torch.full((num_layers, num_logical_experts, maxlogcnt), + -1, + dtype=torch.int64, + device=logcnt.device) + log2phy.view(num_layers, -1).scatter_( + -1, phy2log * maxlogcnt + phyrank, + torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(num_layers, -1)) + return phy2log, log2phy, logcnt + + def _load_ep_mapping(self, json_path: str): + """Load expert partition metadata from JSON (class-internal).""" + import json + with open(json_path, 'r') as f: + data = json.load(f) + num_groups = data['num_groups'] + num_nodes = data['num_nodes'] + weight = torch.tensor(data['weight'], dtype=torch.float32, device='cuda') + return num_groups, num_nodes, weight + + def ep_expert_list(self, world_size: int, rank: int, num_groups: int=None, num_nodes: int=None, weight: torch.Tensor=None): + """experts list of current rank.""" + print("=======================ep_expert_list in FusedDeepEpMoEBlockedF8Impl is called======================") + if enable_eplb: + # step_ctx = get_step_ctx_manager().current_context() + # low_latency_mode = step_ctx.is_decoding and self.use_deep_gemm + # TODO: distinguish prefill from decode at initialization and load the corresponding json file + # if not low_latency_mode: + # num_groups, num_nodes, weight = self._load_ep_mapping("ep_mapping_json_prefill.json") + # else: + # num_groups, num_nodes, weight = self._load_ep_mapping("ep_mapping_json_decode.json") + num_groups, num_nodes, weight = self._load_ep_mapping("ep_mapping_json_prefill.json") + phy2log, log2phy, logcnt = self.rebalance_experts(weight, self.num_experts, num_groups, num_nodes, world_size) + self.phy2log = phy2log[self.layer_idx].to('cuda') + self.log2phy = log2phy[self.layer_idx].to('cuda') + self.logcnt = logcnt[self.layer_idx].to('cuda') + expert_per_rank = (self.num_experts + world_size - 1) // world_size + first_expert = rank * expert_per_rank + last_expert = 
min(first_expert + expert_per_rank, self.num_experts) + sliced_phy2log = self.phy2log[first_expert:last_expert].tolist() + if rank == 0: + logger.info(f"ep_expert_list: {sliced_phy2log}") + logger.info(f"len_ep_expert_list: {len(sliced_phy2log)}") + + return sliced_phy2log + else: + num_experts = self.num_experts + expert_per_rank = (num_experts + world_size - 1) // world_size + first_expert = rank * expert_per_rank + last_expert = min(first_expert + expert_per_rank, num_experts) + return list(range(first_expert, last_expert)) + def fusedmoe_build(self, low_latency_mode: bool = False): if self.use_dlblas: return self.build_deepep_moe(low_latency_mode, @@ -913,10 +805,10 @@ def fusedmoe_build(self, low_latency_mode: bool = False): chunk_size=16 * 1024) elif low_latency_mode: return FusedMoELowLatency(self.ep_size, self.ep_group, self.num_experts, self.hidden_dim, self.block_size, - self.out_dtype, layer_idx=self.layer_idx) + self.out_dtype, self.layer_idx, self.log2phy, self.logcnt) else: return FusedMoENormal(self.ep_size, self.ep_group, self.num_experts, self.hidden_dim, self.block_size, - self.out_dtype, layer_idx=self.layer_idx) + self.out_dtype, self.layer_idx, self.log2phy, self.logcnt) class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder): diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py index e856fc8a19..5f6c11bfa7 100644 --- a/lmdeploy/pytorch/models/deepseek_v2.py +++ b/lmdeploy/pytorch/models/deepseek_v2.py @@ -1227,7 +1227,38 @@ def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_di param = params_dict[name] load_weight(param, loaded_weight) - + def _load_weight_experts_with_eplb(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter], + layer_expert_params_mapping: Dict[int, List]): + """Load weight for experts with layer-wise expert mapping.""" + # parse layer_idx from the weight name + strs = name.split(".") + if len(strs) < 3 or not strs[2].isdigit(): + raise ValueError(f"Cannot parse layer index from weight name {name}") + layer_idx = int(strs[2]) + + # not an MoE layer (e.g. an ordinary dense layer): load it with the original logic + if layer_idx not in layer_expert_params_mapping: + param = params_dict[name] + load_weight(param, loaded_weight) + return + + # look up the expert parameter mapping for this layer + expert_params_mapping = layer_expert_params_mapping[layer_idx] + + # match the weight against this layer's expert mapping + for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping: + if weight_name not in name: + continue + # on a match, replace weight_name with param_name to locate the target parameter + new_name = name.replace(weight_name, param_name) + param = params_dict[new_name] + load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id) + break + else: + # nothing in this layer matched either, so treat it as an ordinary parameter + param = params_dict[name] + load_weight(param, loaded_weight) + def _load_weight_attention(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter], update_pe_mapping: List): """load weight attention.""" diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index eb9a3f3a0e..e60f85b383 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -69,6 +69,7 @@ def __init__(self, ep: bool = False, layer_idx: int = 0): super().__init__() + self.layer_idx = layer_idx weight = torch.empty((num_experts, out_features, in_features), dtype=dtype, device=device) weight = torch.nn.Parameter(weight, requires_grad=False) if ep and enable_eplb and expert_list is not None: @@ -120,6 +121,7 @@ def weight_loader_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tenso def 
weight_loader_ep(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str): """weight loader.""" + world_size, rank = get_tp_world_rank() expert_list = self.expert_list if expert_id not in expert_list: return @@ -169,7 +171,7 @@ def weight_loader_ep(self, param: torch.nn.Parameter, loaded_weight: torch.Tenso else: raise RuntimeError(f'Unknown shard_id: {shard_id}') param_data.copy_(loaded_weight.to(param_data.dtype)) - print(f"[Rank {rank}] ✅ Loaded Expert {expert_id} for Layer {layer_idx} ({shard_id}) shape={param_data.shape}") + # print(f"[Rank {rank}] ✅ Loaded Expert {expert_id} for Layer {self.layer_idx} ({shard_id}) shape={param_data.shape}") def _gather_input(x: torch.Tensor, tp_sizes: List[int]): @@ -566,6 +568,8 @@ def __init__(self, layer_idx=layer_idx) if self.ep_size > 1: + if rank == 0: + print("================================geting expert list==========================") expert_list = self.impl.ep_expert_list(self.ep_size, rank) num_experts = len(expert_list) else: diff --git a/lmdeploy/pytorch/weight_loader/model_weight_loader.py b/lmdeploy/pytorch/weight_loader/model_weight_loader.py index e5d7a66cec..167086d364 100644 --- a/lmdeploy/pytorch/weight_loader/model_weight_loader.py +++ b/lmdeploy/pytorch/weight_loader/model_weight_loader.py @@ -23,11 +23,12 @@ def load_weight(param: torch.nn.Parameter, loaded_weight: torch.Tensor, **kwargs if expert_id is not None and hasattr(param, 'expert_list'): if expert_id not in param.expert_list: - print(f"[Rank {rank}] 🔁 Skip Expert {expert_id} for param {param.shape}") + # print(f"[Rank {rank}] 🔁 Skip Expert {expert_id} for param {param.shape}") return else: layer_idx = getattr(param, 'layer_idx', '?') - print(f"[Rank {rank}] ✅ Load Expert {expert_id} for Layer {layer_idx} ({shard_id})") + if rank == 0: + print(f"[Rank {rank}] ✅ Load Expert {expert_id} for Layer {layer_idx} ({shard_id})") if hasattr(param, 'weight_loader'): param.weight_loader(param, loaded_weight, **kwargs) From 1ae2409fe6ffa47fcf4308c150b9f1fdcbd23a47 Mon Sep 17 00:00:00 2001 From: Ychoukewen <2369589402@qq.com> Date: Fri, 9 May 2025 02:56:08 +0000 Subject: [PATCH 11/12] add print and set num_hidden_layers=8 --- ep_mapping_json_decode.json | 2 +- lmdeploy/pytorch/models/deepseek_v2.py | 4 ++-- lmdeploy/pytorch/nn/moe.py | 4 +++- lmdeploy/pytorch/weight_loader/model_weight_loader.py | 5 +++-- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/ep_mapping_json_decode.json b/ep_mapping_json_decode.json index 6bd9f883de..1fb35d46b0 100644 --- a/ep_mapping_json_decode.json +++ b/ep_mapping_json_decode.json @@ -1,5 +1,5 @@ { - "num_groups": 4, + "num_groups": 8, "num_nodes": 1, "weight": [ [ diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py index 5f6c11bfa7..a195738f28 100644 --- a/lmdeploy/pytorch/models/deepseek_v2.py +++ b/lmdeploy/pytorch/models/deepseek_v2.py @@ -1119,7 +1119,7 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None): super().__init__() - config.num_hidden_layers = 4 + config.num_hidden_layers = 8 self.config = config self.quantization_config = getattr(config, 'quantization_config', None) self.dtype = dtype @@ -1425,7 +1425,7 @@ def __skip_nextn(name, nextn_keys): strs = name.split(".") if len(strs) >= 3 and str.isdigit(strs[2]): layer_number = int(strs[2]) - if layer_number >= 4: + if layer_number >= 8: continue # zcx end if 'rotary_emb.inv_freq' in name: diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index 
e60f85b383..45346c5969 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -171,7 +171,9 @@ def weight_loader_ep(self, param: torch.nn.Parameter, loaded_weight: torch.Tenso else: raise RuntimeError(f'Unknown shard_id: {shard_id}') param_data.copy_(loaded_weight.to(param_data.dtype)) - # print(f"[Rank {rank}] ✅ Loaded Expert {expert_id} for Layer {self.layer_idx} ({shard_id}) shape={param_data.shape}") + # log the expert ID, parameter name and weight shape loaded by each rank for each layer + param_name = f"Layer_{self.layer_idx}_{shard_id}_Expert_{expert_id}" + print(f"[Rank {rank}] ✅ Loaded Expert {expert_id} for {param_name} shape={param_data.shape}") def _gather_input(x: torch.Tensor, tp_sizes: List[int]): diff --git a/lmdeploy/pytorch/weight_loader/model_weight_loader.py b/lmdeploy/pytorch/weight_loader/model_weight_loader.py index 167086d364..e2ca7e57a3 100644 --- a/lmdeploy/pytorch/weight_loader/model_weight_loader.py +++ b/lmdeploy/pytorch/weight_loader/model_weight_loader.py @@ -22,13 +22,14 @@ def load_weight(param: torch.nn.Parameter, loaded_weight: torch.Tensor, **kwargs rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 if expert_id is not None and hasattr(param, 'expert_list'): + print("===================load_weight for eplb==================") if expert_id not in param.expert_list: # print(f"[Rank {rank}] 🔁 Skip Expert {expert_id} for param {param.shape}") return else: layer_idx = getattr(param, 'layer_idx', '?') - if rank == 0: - print(f"[Rank {rank}] ✅ Load Expert {expert_id} for Layer {layer_idx} ({shard_id})") + # if rank == 0: + # print(f"[Rank {rank}] ✅ Load Expert {expert_id} for Layer {layer_idx} ({shard_id})") if hasattr(param, 'weight_loader'): param.weight_loader(param, loaded_weight, **kwargs) From 4491c89a144da8a3ae46b63bea5b94c8e12548df Mon Sep 17 00:00:00 2001 From: Ychoukewen <2369589402@qq.com> Date: Mon, 12 May 2025 09:58:12 +0000 Subject: [PATCH 12/12] Commented out the code for modifying model layers; kept the original layer structure --- ep_mapping_json_decode.json | 2 +- lmdeploy/pytorch/models/deepseek_v2.py | 16 ++++++++-------- test.py | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/ep_mapping_json_decode.json b/ep_mapping_json_decode.json index 1fb35d46b0..6bd9f883de 100644 --- a/ep_mapping_json_decode.json +++ b/ep_mapping_json_decode.json @@ -1,5 +1,5 @@ { - "num_groups": 8, + "num_groups": 4, "num_nodes": 1, "weight": [ [ diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py index a195738f28..c3cad59abe 100644 --- a/lmdeploy/pytorch/models/deepseek_v2.py +++ b/lmdeploy/pytorch/models/deepseek_v2.py @@ -1119,7 +1119,7 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None): super().__init__() - config.num_hidden_layers = 8 + # config.num_hidden_layers = 8 self.config = config self.quantization_config = getattr(config, 'quantization_config', None) self.dtype = dtype @@ -1421,13 +1421,13 @@ def __skip_nextn(name, nextn_keys): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - # zcx begin - strs = name.split(".") - if len(strs) >= 3 and str.isdigit(strs[2]): - layer_number = int(strs[2]) - if layer_number >= 8: - continue - # zcx end + # # zcx begin + # strs = name.split(".") + # if len(strs) >= 3 and str.isdigit(strs[2]): + # layer_number = int(strs[2]) + # if layer_number >= 8: + # continue + # # zcx end if 'rotary_emb.inv_freq' in name: continue if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name): diff --git 
a/test.py b/test.py index ea92f5df99..886298ed04 100644 --- a/test.py +++ b/test.py @@ -35,8 +35,8 @@ def main(rank: int): backend_config = PytorchEngineConfig( tp=1, - dp=4, - ep=4, + dp=8, + ep=8, dp_rank=rank, eager_mode=True, )
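
Note on the mapping files touched above: ep_mapping_json_prefill.json and ep_mapping_json_decode.json share the schema consumed by _load_ep_mapping, namely a num_groups field, a num_nodes field, and a weight matrix with one row of per-logical-expert load counts per MoE layer. The following is a minimal, hypothetical sketch of a generator in the spirit of ep_mapping_json_gen.py; the layer count, expert count, and random load values are placeholders rather than measured statistics, and the output path is only an assumption.

import json
import random

# Placeholder dimensions: one weight row per MoE layer, one entry per logical expert.
num_moe_layers = 4          # assumption: matches the reduced layer count used while debugging
num_logical_experts = 256   # assumption: DeepSeek-V3 routed-expert count

data = {
    "num_groups": 4,
    "num_nodes": 1,
    # Synthetic per-expert load counts in the same 80-160 range as the checked-in files.
    "weight": [[random.randint(80, 160) for _ in range(num_logical_experts)]
               for _ in range(num_moe_layers)],
}

# Hypothetical output path; point it wherever ep_expert_list() expects to find the file.
with open("ep_mapping_json_prefill.json", "w") as f:
    json.dump(data, f, indent=2)

At start-up, FusedDeepEpMoEBlockedF8Impl.ep_expert_list() feeds this weight matrix to rebalance_experts(), takes the phy2log row for its layer_idx, and slices it by rank, so each DP/EP rank only loads the experts it owns.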