From 5ebd7ec289504df518989e47136eb0eac4c68dc6 Mon Sep 17 00:00:00 2001
From: xulei1
Date: Mon, 16 Dec 2024 19:21:17 +0800
Subject: [PATCH 1/3] edp with cc+acc is ok

---
 .../transformer_layer_infer_template.py       |  15 +-
 .../layer_weights/meta_weights/__init__.py    |   4 +
 .../meta_weights/fused_moe_weight.py          |   4 +-
 .../layer_weights/meta_weights/mm_weight.py   |  63 ++++++++
 lightllm/common/deepseek2_mem_manager.py      |  18 ++-
 .../layer_infer/transformer_layer_infer.py    |  68 +++++----
 .../layer_weights/transformer_layer_weight.py | 141 ++++++++++++++++--
 .../llama/layer_infer/post_layer_infer.py     |   6 +-
 .../llama/layer_infer/pre_layer_infer.py      |  15 +-
 .../layer_infer/transformer_layer_infer.py    |  14 +-
 .../pre_and_post_layer_weight.py              |  14 +-
 11 files changed, 297 insertions(+), 65 deletions(-)

diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py
index bf3b210e8..b3d103534 100755
--- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py
+++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py
@@ -7,6 +7,8 @@
 from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv
 from typing import Tuple
 
+import os
+
 
 class TransformerLayerInferTpl(TransformerLayerInfer):
     """ """
@@ -21,6 +23,7 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode):
         self.tp_o_head_num_ = -1
         self.head_dim_ = -1
         self.embed_dim_ = -1
+        self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true"
         return
 
     def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
@@ -79,7 +82,7 @@ def _context_attention(self, input_embding, infer_state: InferStateInfo, layer_w
         o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
         q = None
         o = self._get_o(o, infer_state, layer_weight)
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
         input_embding.add_(o.view(-1, self.embed_dim_))
         return
@@ -88,7 +91,7 @@ def _context_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight
         input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         ffn_out = self._ffn(input1, infer_state, layer_weight)
         input1 = None
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return
@@ -102,7 +105,7 @@ def _token_attention(self, input_embding, infer_state: InferStateInfo, layer_wei
         o = self._token_attention_kernel(q, infer_state, layer_weight)
         q = None
         o = self._get_o(o, infer_state, layer_weight)
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
         input_embding.add_(o.view(-1, self.embed_dim_))
         return
@@ -111,7 +114,7 @@ def _token_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight):
         input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         ffn_out = self._ffn(input1, infer_state, layer_weight)
         input1 = None
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return
@@ -125,7 +128,7 @@ def _splitfuse_attention(self, input_embding, infer_state: SplitFuseInferStateIn
         o = self._splitfuse_attention_kernel(q, infer_state, layer_weight)
         q = None
         o = self._get_o(o, infer_state, layer_weight)
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
         input_embding.add_(o.view(-1, self.embed_dim_))
         return
@@ -134,7 +137,7 @@ def _splitfuse_ffn(self, input_embdings, infer_state: SplitFuseInferStateInfo, l
         input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         ffn_out = self._ffn(input1, infer_state, layer_weight)
         input1 = None
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py
index c6b1ab500..973de48da 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py
@@ -11,6 +11,10 @@
     MultiCOLMMWeight,
     ROWBMMWeight,
     COLBMMWeight,
+    MultiCOLMMWeightNoTp,
+    ROWBMMWeightNoTp,
+    COLBMMWeightNoTp,
+    COLMMWeightNoTp,
 )
 from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
 from .fused_moe_weight import FusedMoeWeight
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py
index ba6d7d028..5526eb383 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py
@@ -153,7 +153,7 @@ def _load_hf_weights_etp(self, weights):
             self.expert_down_proj_etp[i_experts_ep, :] = self.experts_up_projs[i_experts_ep]
 
     def load_hf_weights(self, weights):
-        if os.environ.get("ETP_MODE_ENABLED") == "true":
+        if os.environ.get("ETP_MODE_ENABLED") == "true" or os.environ.get("EDP_MODE_ENABLED") == "true":
             self._load_hf_weights_etp(weights)
         else:
             for i_experts in range(self.n_routed_experts):
@@ -184,7 +184,7 @@ def _cuda(self, cpu_tensor):
         return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_)
 
     def verify_load(self):
-        if os.environ.get("ETP_MODE_ENABLED") == "true":
+        if os.environ.get("ETP_MODE_ENABLED") == "true" or os.environ.get("EDP_MODE_ENABLED") == "true":
             return True
         else:
             return self.w1 is not None and self.w2 is not None
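Everything in this first patch hinges on one switch: when EDP_MODE_ENABLED is "true", each rank keeps full, unsplit weights, so the tensor-parallel all-reduce after attention and FFN is skipped and the ETP-style expert loading path is reused. The sketch below is not part of the patch; it uses illustrative class and method names and only isolates that gating pattern:

    import os
    import torch
    import torch.distributed as dist

    class LayerInferSketch:
        """Illustrative only: mirrors the tp_split_ gating introduced above."""

        def __init__(self, world_size: int):
            self.world_size_ = world_size
            # In EDP mode every rank holds the full projection weights, so each
            # rank's partial output is already complete and no reduction is needed.
            self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true"

        def reduce_if_tensor_parallel(self, partial_out: torch.Tensor) -> torch.Tensor:
            if self.world_size_ > 1 and self.tp_split_:
                dist.all_reduce(partial_out, op=dist.ReduceOp.SUM, async_op=False)
            return partial_out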
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py
index c91dd8acb..8c39102b1 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py
@@ -319,3 +319,66 @@ def __init__(
 
     def _post_load_weights(self):
         self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)
+
+
+class COLMMWeightNoTp(MMWeight):
+    def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
+        super().__init__(weight_name, data_type, split_n_embed, bias_name)
+        self.start = 0
+        self.end = split_n_embed
+    def load_hf_weights(self, weights):
+        weight = None
+        if self.weight_name in weights:
+            weight = weights[self.weight_name].to(self.data_type_)
+            self.weight = weight[:, self.start : self.end]
+        if self.bias_name in weights:
+            bias = weights[self.bias_name]
+            self.bias = bias.to(self.data_type_).cuda(self.tp_rank_)
+        if weight is None:
+            return
+        self._post_load_weights()
+        return
+
+class MultiCOLMMWeightNoTp(MultiROWMMWeightNoTP):
+    def __init__(self, weight_names, data_type, split_n_embed, bias_names=[]):
+        super().__init__(weight_names, data_type, split_n_embed, bias_names)
+    def load_hf_weights(self, weights):
+        weight = None
+        for i in range(len(self.weight_names)):
+            if self.weight_names[i] in weights:
+                weight = weights[self.weight_names[i]].to(self.data_type_)
+                self.weights[i] = weight[:, self.starts[i] : self.ends[i]]
+            if self.has_bias and self.bias_names[i] in weights:
+                bias = weights[self.bias_names[i]].to(self.data_type_)
+                self.biases[i] = bias[:, self.starts[i] : self.ends[i]]
+        self._fuse()
+        return
+
+class ROWBMMWeightNoTp(BMMWeight):
+    load_hf_weights = ROWMMWeight.load_hf_weights
+    def __init__(
+        self,
+        weight_name,
+        data_type,
+        split_n_embed,
+        bias_name=None,
+    ):
+        super().__init__(weight_name, data_type, split_n_embed, bias_name)
+        self.start = 0
+        self.end = split_n_embed
+
+class COLBMMWeightNoTp(BMMWeight):
+    load_hf_weights = COLMMWeightNoTp.load_hf_weights
+    def __init__(
+        self,
+        weight_name,
+        data_type,
+        split_n_embed,
+        bias_name=None,
+    ):
+        super().__init__(weight_name, data_type, split_n_embed, bias_name)
+        self.start = 0
+        self.end = split_n_embed
+    def _post_load_weights(self):
+        self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)
+
\ No newline at end of file
diff --git a/lightllm/common/deepseek2_mem_manager.py b/lightllm/common/deepseek2_mem_manager.py
index 8df5a61e2..736f05f40 100644
--- a/lightllm/common/deepseek2_mem_manager.py
+++ b/lightllm/common/deepseek2_mem_manager.py
@@ -14,9 +14,25 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
         self.kv_buffer = torch.empty((layer_num, size, head_num, head_dim), dtype=dtype, device="cuda")
         # todo, etp or edp use the same work buffer here
         # also it can be used for any kernels as a work buffer without saved info only
-        if os.environ.get("ETP_MODE_ENABLED") == "true":
+        if os.environ.get("ETP_MODE_ENABLED") == "true" or os.environ.get("EDP_MODE_ENABLED") == "true":
             self.work_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.bfloat16, device="cuda")
             self.work_buffer.share_memory_()
+            import lightllm_moe_etp_kernel
+            import torch.distributed as dist
+
+            rank_id = dist.get_rank()
+            world_size = dist.get_world_size()
+
+            #lightllm_moe_etp_kernel.enableP2P(world_size, rank_id)
+
+            handle = lightllm_moe_etp_kernel.get_handle(self.work_buffer.contiguous(), rank_id)
+            handles = [None] * world_size
+            dist.all_gather_object(handles, handle)
+            self.handles_work_buffer = handles
+
+            lightllm_moe_etp_kernel.init_system(world_size, rank_id,
+                                                self.work_buffer.contiguous(),
+                                                handles )
 
     def alloc_kv_move_buffer(self, max_req_total_len):
         self.kv_move_buffer = torch.empty(
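The memory-manager change makes every rank export an IPC-style handle for its shared 1 GiB work buffer and broadcast it to the other ranks before the custom kernel wires them together. get_handle and init_system come from the project's lightllm_moe_etp_kernel extension exactly as spelled in the hunk above; the generic, reusable part is the handle exchange itself, sketched below with the extension call passed in as a plain callable:

    import torch
    import torch.distributed as dist

    def exchange_work_buffer_handles(work_buffer: torch.Tensor, get_handle):
        """Gather one opaque buffer handle per rank, ordered by rank id.

        `get_handle` stands in for lightllm_moe_etp_kernel.get_handle; any
        picklable handle object works with dist.all_gather_object.
        """
        rank_id = dist.get_rank()
        world_size = dist.get_world_size()
        handle = get_handle(work_buffer.contiguous(), rank_id)
        handles = [None] * world_size
        # Every rank ends up with the full, rank-ordered list of handles.
        dist.all_gather_object(handles, handle)
        return handles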
diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
index 4b686c602..5d789c110 100644
--- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -69,6 +69,9 @@ def __init__(
         self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
         self.mla_type = "ACCM"
 
+
+        self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true"
+
         return
 
     def _bind_attention(self):
@@ -78,8 +81,8 @@ def _bind_attention(self):
         )
         self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
         if self.is_moe:
-            if os.environ.get("ETP_MODE_ENABLED") == "true":
-                self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_etp, self)
+            if os.environ.get("ETP_MODE_ENABLED") == "true" or os.environ.get("EDP_MODE_ENABLED") == "true":
+                self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_etp_edp, self)
             else:
                 self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn, self)
         else:
@@ -120,6 +123,7 @@ def _get_qkv(
             self.mla_type = layer_weight.mla_type
             if self.mla_type == "ACCM":
                 q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+
         layer_weight.kv_a_proj_with_mqa_.mm(
             input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim)
         )
@@ -155,7 +159,7 @@ def _CC_method(
     ):
         num_local_heads = self.num_heads
         num_local_kv_heads = self.num_kv_heads
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             num_local_heads //= self.world_size_
             num_local_kv_heads //= self.world_size_
         if infer_state.use_dynamic_prompt_cache:
@@ -187,7 +191,7 @@ def _ACC_method(
         q_nope, q_rope = q
         num_local_heads = self.num_heads
         num_local_kv_heads = self.num_kv_heads
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             num_local_heads //= self.world_size_
             num_local_kv_heads //= self.world_size_
         # ACC
@@ -275,6 +279,10 @@ def _context_attention_kernel_origin(
         self, q: Tuple[torch.Tensor, torch.Tensor], kv, infer_state: Deepseek2InferStateInfo, layer_weight, out=None
     ) -> torch.Tensor:
         q_nope, q_rope = q
+
+        #not support edp yet
+        assert self.tp_split_ == True
+
         o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
 
         if infer_state.use_dynamic_prompt_cache:
@@ -440,7 +448,7 @@ def _splitfuse_attention_kernel_with_CC(
             torch.cuda.default_stream().wait_event(infer_state.end_event)
         return o_tensor
 
-    def _moe_ffn_etp(
+    def _moe_ffn_etp_edp(
         self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ) -> torch.Tensor:
         world_size_ = self.world_size_
@@ -461,36 +469,38 @@ def _moe_ffn_etp(
             num_tokens, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype
         )
 
-        # router_logits_len = hidden_states.shape[0]*layer_weight.moe_gate.shape[1]
         router_logits = layer_weight.moe_gate.mm(hidden_states)
 
         # now some parameter is not supported yet
         # assert gating_normalize_prob is False
         # assert num_expert_groups<=1
+        if os.environ.get("ETP_MODE_ENABLED") == "true" :
+            from lightllm_moe_etp_kernel import moe_fused_all as moe_fused_all
+        elif os.environ.get("EDP_MODE_ENABLED") == "true":
+            from lightllm_moe_etp_kernel import moe_fused_all_edp as moe_fused_all
+
+        moe_fused_all(
+                router_logits.contiguous(),
+                hidden_states.contiguous(),
+                layer_weight.gate_up_proj.weight.contiguous(),  # transpose
+                layer_weight.down_proj.weight.contiguous(),  # transpose
+                layer_weight.experts.expert_gate_up_proj_etp.contiguous(),
+                layer_weight.experts.expert_down_proj_etp.contiguous(),
+                infer_state.mem_manager.work_buffer.contiguous(),
+                infer_state.mem_manager.work_buffer.nelement(),
+                final_hidden_states.contiguous(),
+                rank_self,
+                gating_scaling_factor,
+                num_experts,
+                num_experts_per_token,
+                num_tokens,
+                world_size_,
+                hidden_dim,
+                layer_weight.gate_up_proj.weight.size(1) // 2,
+                layer_weight.experts.expert_gate_up_proj_etp.size(1) // 2,
+                self.n_shared_experts is not None,
+        )
 
-        import lightllm_moe_etp_kernel
-
-        lightllm_moe_etp_kernel.moe_fused_all(
-            router_logits.contiguous(),
-            hidden_states.contiguous(),
-            layer_weight.gate_up_proj.weight.contiguous(),  # transpose
-            layer_weight.down_proj.weight.contiguous(),  # transpose
-            layer_weight.experts.expert_gate_up_proj_etp.contiguous(),
-            layer_weight.experts.expert_down_proj_etp.contiguous(),
-            infer_state.mem_manager.work_buffer.contiguous(),
-            infer_state.mem_manager.work_buffer.nelement(),
-            final_hidden_states.contiguous(),
-            rank_self,
-            gating_scaling_factor,
-            num_experts,
-            num_experts_per_token,
-            num_tokens,
-            world_size_,
-            hidden_dim,
-            layer_weight.gate_up_proj.weight.size(1) // 2,
-            layer_weight.experts.expert_gate_up_proj_etp.size(1) // 2,
-            self.n_shared_experts is not None,
-        )
         router_logits = None
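Which fused kernel gets bound as _ffn is decided purely by environment variables: ETP keeps the original moe_fused_all entry point, EDP switches to moe_fused_all_edp, and both come from the project's custom lightllm_moe_etp_kernel extension with the long positional argument list shown in the hunk. A small sketch of just the dispatch (the kernel names are taken from the patch; everything else here is illustrative):

    import os

    def select_fused_moe_kernel():
        """Return the fused-MoE entry point for the current parallel mode."""
        if os.environ.get("ETP_MODE_ENABLED") == "true":
            from lightllm_moe_etp_kernel import moe_fused_all
            return moe_fused_all
        if os.environ.get("EDP_MODE_ENABLED") == "true":
            from lightllm_moe_etp_kernel import moe_fused_all_edp
            return moe_fused_all_edp
        raise RuntimeError("expected ETP_MODE_ENABLED or EDP_MODE_ENABLED to be 'true'")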
diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
index 995ad1f11..847aec9cb 100644
--- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
@@ -6,15 +6,21 @@
     ROWMMWeight,
     ROWMMWeightNoTP,
     MultiROWMMWeight,
+    MultiROWMMWeightNoTP,
     COLMMWeight,
+    COLMMWeightNoTp,
     MultiCOLMMWeight,
+    MultiCOLMMWeightNoTp,
     NormWeight,
     FusedMoeWeight,
     ROWBMMWeight,
+    ROWBMMWeightNoTp,
     COLBMMWeight,
+    COLBMMWeightNoTp,
 )
 from functools import partial
+import os
 
 def fuse_q_kb(self, layer_weight):
     if not (self.weight is None and all(w is not None for w in self.weights)):
         return
@@ -74,6 +80,8 @@ def __init__(
     ):
         self.disable_qk_absorb = disable_qk_absorb
         self.disable_vo_absorb = disable_vo_absorb
+        self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true"
+
         super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
         # mla_type = "ACCM", "MIX"
         # MIX: prefill uses CC, decoding uses ACC
@@ -89,7 +97,9 @@ def _parse_config(self):
             and self.layer_num_ >= self.network_config_["first_k_dense_replace"]
             and self.layer_num_ % self.network_config_["moe_layer_freq"] == 0
         )
-        self.tp_q_head_num_ = self.network_config_["num_attention_heads"] // self.world_size_
+        self.tp_q_head_num_ = self.network_config_["num_attention_heads"]
+        if self.tp_split_ :
+            self.tp_q_head_num_ //= self.world_size_
         self.n_routed_experts = self.network_config_["n_routed_experts"]
         self.q_lora_rank = self.network_config_["q_lora_rank"]
         self.qk_nope_head_dim = self.network_config_["qk_nope_head_dim"]
@@ -104,7 +114,10 @@ def _init_weight_names(self):
             self.rope_weight_name = f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight"
 
     def _init_weight(self):
-        self._init_qkvo()
+        if self.tp_split_ :
+            self._init_qkvo()
+        else:
+            self._init_qkvo_dp()
         if self.is_moe:
             self._init_moe()
         else:
@@ -112,12 +125,13 @@ def _init_weight(self):
         self._init_norm()
 
     def _load_q_rope(self, q_weight_):
-        q_split_n_embed_with_rope = (
-            (self.qk_nope_head_dim + self.qk_rope_head_dim) * self.num_attention_heads // self.world_size_
-        )
-        q_weight_ = q_weight_[
-            q_split_n_embed_with_rope * self.tp_rank_ : q_split_n_embed_with_rope * (self.tp_rank_ + 1), :
-        ]
+        if self.tp_split_:
+            q_split_n_embed_with_rope = (
+                (self.qk_nope_head_dim + self.qk_rope_head_dim) * self.num_attention_heads // self.world_size_
+            )
+            q_weight_ = q_weight_[
+                q_split_n_embed_with_rope * self.tp_rank_ : q_split_n_embed_with_rope * (self.tp_rank_ + 1), :
+            ]
         q_weight_ = q_weight_.transpose(0, 1).contiguous()
         q_nope_proj_, q_rope_proj_ = torch.split(
             q_weight_.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
@@ -239,11 +253,104 @@ def _init_qkvo(self):
             q_split_n_embed,
         )
 
-    def _load_mlp(self, mlp_prefix, split_inter_size):
-        self.gate_up_proj = MultiROWMMWeight(
-            [f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], self.data_type_, split_inter_size
-        )
-        self.down_proj = COLMMWeight(f"{mlp_prefix}.down_proj.weight", self.data_type_, split_inter_size)
+    def _init_qkvo_dp(self):
+        q_split_n_embed = self.qk_nope_head_dim * self.tp_q_head_num_
+        q_split_n_embed_with_rope = (
+            (self.qk_nope_head_dim + self.qk_rope_head_dim) * self.num_attention_heads
+        )
+        if self.q_lora_rank is None:
+            if not self.disable_qk_absorb: # acc
+                self.fuse_qk_weight_ = MultiROWMMWeightNoTP(
+                    [
+                        f"model.layers.{self.layer_num_}.self_attn.q_proj.weight",
+                        f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight",
+                    ],
+                    self.data_type_,
+                    [q_split_n_embed_with_rope, self.tp_q_head_num_],
+                )
+                self.fuse_qk_weight_._fuse = partial(fuse_q_kb, self.fuse_qk_weight_, self)
+            else: # cc
+                self.q_weight_ = ROWMMWeightNoTP(
+                    f"model.layers.{self.layer_num_}.self_attn.q_proj.weight",
+                    self.data_type_,
+                    q_split_n_embed_with_rope,
+                )
+        else:
+            self.q_a_proj_ = ROWMMWeightNoTP(
+                f"model.layers.{self.layer_num_}.self_attn.q_a_proj.weight",
+                self.data_type_,
+                self.q_lora_rank,
+            )
+            if not self.disable_qk_absorb:
+                self.fuse_qk_weight_ = MultiROWMMWeightNoTP(
+                    [
+                        f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight",
+                        f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight",
+                    ],
+                    self.data_type_,
+                    [q_split_n_embed_with_rope, self.tp_q_head_num_],
+                )
+                self.fuse_qk_weight_._fuse = partial(fuse_q_kb, self.fuse_qk_weight_, self)
+            else:
+                self.q_b_proj_ = ROWMMWeightNoTP(
+                    f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight",
+                    self.data_type_,
+                    q_split_n_embed_with_rope,
+                )
+
+        self.q_rope_proj_ = ROWMMWeightNoTP(
+            f"model.layers.{self.layer_num_}.self_attn.q_rope_proj.weight",
+            self.data_type_,
+            self.qk_rope_head_dim * self.tp_q_head_num_,
+        )
+
+        self.kv_a_proj_with_mqa_ = ROWMMWeightNoTP(
+            f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight",
+            self.data_type_,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+        )
+        if self.disable_qk_absorb:
+            self.k_b_proj_ = ROWBMMWeightNoTp(
+                f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight",
+                self.data_type_,
+                split_n_embed=self.tp_q_head_num_,
+            )
+        if not self.disable_vo_absorb:
+            self.fuse_vo_weight_ = MultiCOLMMWeightNoTp(
+                [
+                    f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight",
+                    f"model.layers.{self.layer_num_}.self_attn.o_proj.weight",
+                ],
+                self.data_type_,
+                [self.tp_q_head_num_, q_split_n_embed],
+            )
+            self.fuse_vo_weight_._fuse = partial(fuse_vb_o, self.fuse_vo_weight_, self)
+        else:
+            self.v_b_proj_ = COLBMMWeightNoTp(
+                f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight",
+                self.data_type_,
+                split_n_embed=self.tp_q_head_num_,
+            )
+        if self.disable_vo_absorb:
+            self.o_weight_ = COLMMWeightNoTp(
+                f"model.layers.{self.layer_num_}.self_attn.o_proj.weight",
+                self.data_type_,
+                q_split_n_embed,
+            )
+
+
+    def _load_mlp(self, mlp_prefix, split_inter_size):
+        if self.tp_split_ :
+            self.gate_up_proj = MultiROWMMWeight(
+                [f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], self.data_type_, split_inter_size
+            )
+            self.down_proj = COLMMWeight(f"{mlp_prefix}.down_proj.weight", self.data_type_, split_inter_size)
+        else:
+            self.gate_up_proj = MultiROWMMWeightNoTP(
+                [f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], self.data_type_, split_inter_size
+            )
+            self.down_proj = COLMMWeightNoTp(f"{mlp_prefix}.down_proj.weight", self.data_type_, split_inter_size)
+
 
     def _init_moe(self):
         moe_intermediate_size = self.network_config_["moe_intermediate_size"]
@@ -251,8 +358,9 @@ def _init_moe(self):
f"model.layers.{self.layer_num_}.mlp.gate.weight", self.data_type_, moe_intermediate_size ) shared_intermediate_size = moe_intermediate_size * self.network_config_["n_shared_experts"] - shared_split_inter_size = shared_intermediate_size // self.world_size_ - self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", shared_split_inter_size) + + num_shards = self.world_size_ if self.tp_split_ else 1 + self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", shared_intermediate_size // num_shards) self.experts = FusedMoeWeight( gate_proj_name="gate_proj", @@ -262,12 +370,13 @@ def _init_moe(self): n_routed_experts=self.n_routed_experts, split_inter_size=moe_intermediate_size // self.world_size_, data_type=self.data_type_, + ) def _init_ffn(self): inter_size = self.network_config_["intermediate_size"] - split_inter_size = inter_size // self.world_size_ - self._load_mlp(f"model.layers.{self.layer_num_}.mlp", split_inter_size) + num_shards = self.world_size_ if self.tp_split_ else 1 + self._load_mlp(f"model.layers.{self.layer_num_}.mlp", inter_size // num_shards) def _init_norm(self): self.att_norm_weight_ = NormWeight(f"model.layers.{self.layer_num_}.input_layernorm.weight", self.data_type_) diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py index a642a0fe0..8c7b8f5c0 100644 --- a/lightllm/models/llama/layer_infer/post_layer_infer.py +++ b/lightllm/models/llama/layer_infer/post_layer_infer.py @@ -12,6 +12,8 @@ from lightllm.common.basemodel import PostLayerInferTpl from lightllm.utils.infer_utils import mark_cost_time +import os + class LlamaPostLayerInfer(PostLayerInferTpl): """ """ @@ -21,6 +23,8 @@ def __init__(self, tp_rank, world_size, network_config, mode): self.eps_ = network_config["rms_norm_eps"] self.vocab_size_ = network_config["vocab_size"] self.embed_dim_ = network_config["n_embed"] + self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true" + return def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor: @@ -89,7 +93,7 @@ def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_ torch.mm(layer_weight.lm_head_weight_, last_input, out=logic_batch) last_input = None - if self.world_size_ == 1: + if self.world_size_ == 1 or self.tp_split_ == False: gather_data = logic_batch else: gather_data = self.alloc_tensor((self.vocab_size_, token_num), dtype=input_embdings_dtype) diff --git a/lightllm/models/llama/layer_infer/pre_layer_infer.py b/lightllm/models/llama/layer_infer/pre_layer_infer.py index f60fa6127..bb53b68f3 100644 --- a/lightllm/models/llama/layer_infer/pre_layer_infer.py +++ b/lightllm/models/llama/layer_infer/pre_layer_infer.py @@ -9,14 +9,21 @@ from lightllm.utils.infer_utils import mark_cost_time from lightllm.models.llama.triton_kernel.embedding import embedding +import os class LlamaPreLayerInfer(PreLayerInferTpl): """ """ def __init__(self, tp_rank, world_size, network_config, mode): super().__init__(tp_rank, world_size, network_config, mode) - tp_vob_ids = np.linspace(0, network_config["vocab_size"], self.world_size_ + 1, dtype=np.int64) - self.vob_start_id_, self.vob_end_id_ = int(tp_vob_ids[self.tp_rank_]), int(tp_vob_ids[self.tp_rank_ + 1]) + self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true" + + if self.tp_split_: + tp_vob_ids = np.linspace(0, network_config["vocab_size"], self.world_size_ + 1, dtype=np.int64) + self.vob_start_id_, self.vob_end_id_ = int(tp_vob_ids[self.tp_rank_]), 
diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py
index a642a0fe0..8c7b8f5c0 100644
--- a/lightllm/models/llama/layer_infer/post_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/post_layer_infer.py
@@ -12,6 +12,8 @@
 from lightllm.common.basemodel import PostLayerInferTpl
 from lightllm.utils.infer_utils import mark_cost_time
 
+import os
+
 
 class LlamaPostLayerInfer(PostLayerInferTpl):
     """ """
@@ -21,6 +23,8 @@ def __init__(self, tp_rank, world_size, network_config, mode):
         self.eps_ = network_config["rms_norm_eps"]
         self.vocab_size_ = network_config["vocab_size"]
         self.embed_dim_ = network_config["n_embed"]
+        self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true"
+
         return
 
     def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor:
@@ -89,7 +93,7 @@ def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_
         torch.mm(layer_weight.lm_head_weight_, last_input, out=logic_batch)
 
         last_input = None
-        if self.world_size_ == 1:
+        if self.world_size_ == 1 or self.tp_split_ == False:
             gather_data = logic_batch
         else:
             gather_data = self.alloc_tensor((self.vocab_size_, token_num), dtype=input_embdings_dtype)
diff --git a/lightllm/models/llama/layer_infer/pre_layer_infer.py b/lightllm/models/llama/layer_infer/pre_layer_infer.py
index f60fa6127..bb53b68f3 100644
--- a/lightllm/models/llama/layer_infer/pre_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/pre_layer_infer.py
@@ -9,14 +9,21 @@
 from lightllm.utils.infer_utils import mark_cost_time
 from lightllm.models.llama.triton_kernel.embedding import embedding
 
+import os
 
 class LlamaPreLayerInfer(PreLayerInferTpl):
     """ """
 
     def __init__(self, tp_rank, world_size, network_config, mode):
         super().__init__(tp_rank, world_size, network_config, mode)
-        tp_vob_ids = np.linspace(0, network_config["vocab_size"], self.world_size_ + 1, dtype=np.int64)
-        self.vob_start_id_, self.vob_end_id_ = int(tp_vob_ids[self.tp_rank_]), int(tp_vob_ids[self.tp_rank_ + 1])
+        self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true"
+
+        if self.tp_split_:
+            tp_vob_ids = np.linspace(0, network_config["vocab_size"], self.world_size_ + 1, dtype=np.int64)
+            self.vob_start_id_, self.vob_end_id_ = int(tp_vob_ids[self.tp_rank_]), int(tp_vob_ids[self.tp_rank_ + 1])
+        else:
+            self.vob_start_id_, self.vob_end_id_ = 0, network_config["vocab_size"]
+
         return
 
     def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
@@ -24,7 +31,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
             (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
         )
         embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(input_embdings, op=dist.ReduceOp.SUM, async_op=False)
         return input_embdings
 
@@ -33,7 +40,7 @@ def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weigh
             (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
         )
         embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(input_embdings, op=dist.ReduceOp.SUM, async_op=False)
         return input_embdings
diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
index bc8ab44fb..4eb4bc80c 100755
--- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -25,6 +25,7 @@
 from lightllm.common.basemodel import TransformerLayerInferTpl
 from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_dequantize_kv
 
+import os
 
 class LlamaTransformerLayerInfer(TransformerLayerInferTpl):
     """ """
@@ -32,9 +33,16 @@ class LlamaTransformerLayerInfer(TransformerLayerInferTpl):
     def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         super().__init__(layer_num, tp_rank, world_size, network_config, mode)
         self.eps_ = network_config["rms_norm_eps"]
-        self.tp_q_head_num_ = network_config["num_attention_heads"] // self.world_size_
-        self.tp_k_head_num_ = network_config["num_key_value_heads"] // self.world_size_
-        self.tp_v_head_num_ = network_config["num_key_value_heads"] // self.world_size_
+
+        self.tp_q_head_num_ = network_config["num_attention_heads"]
+        self.tp_k_head_num_ = network_config["num_key_value_heads"]
+        self.tp_v_head_num_ = network_config["num_key_value_heads"]
+        if not os.environ.get("EDP_MODE_ENABLED") == "true":
+            self.tp_q_head_num_ //= world_size
+            self.tp_k_head_num_ //= world_size
+            self.tp_v_head_num_ //= world_size
+
+
         self.tp_o_head_num_ = self.tp_q_head_num_
         self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"]
         self.embed_dim_ = network_config["hidden_size"]
diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py
index 25e9bd10c..d1e2e70ca 100644
--- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py
+++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py
@@ -2,17 +2,25 @@
 import numpy as np
 
 from lightllm.common.basemodel import PreAndPostLayerWeight
+import os
 
 
 class LlamaPreAndPostLayerWeight(PreAndPostLayerWeight):
     def __init__(self, tp_rank, world_size, data_type, network_config, mode):
         super().__init__(tp_rank, world_size, data_type, network_config, mode)
+        self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true"
         return
 
     def load_hf_weights(self, weights):
         vob_size = self.network_config_["vocab_size"]
-        split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64)
-        split_start = split_indexes[self.tp_rank_]
-        split_end = split_indexes[self.tp_rank_ + 1]
+
+        if self.tp_split_:
+            split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64)
+            split_start = split_indexes[self.tp_rank_]
+            split_end = split_indexes[self.tp_rank_ + 1]
+        else:
+            split_start = 0
+            split_end = vob_size
+
         if "model.embed_tokens.weight" in weights:
             self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :])
         tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False)
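Patch 1 ends by applying the same rule to the vocabulary: with TP the embedding table is split into contiguous id ranges via np.linspace and the partial embeddings are summed across ranks, while in EDP mode every rank keeps ids 0..vocab_size and the reduction is skipped. The helper below restates that range computation on its own (illustrative function, same arithmetic as the hunk above):

    import numpy as np

    def vocab_range(vocab_size: int, world_size: int, tp_rank: int, tp_split: bool):
        """Return the [start, end) vocabulary ids owned by one rank."""
        if tp_split:
            bounds = np.linspace(0, vocab_size, world_size + 1, dtype=np.int64)
            return int(bounds[tp_rank]), int(bounds[tp_rank + 1])
        return 0, vocab_size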
From b95bb347772b920c6c0d92dd5dcd0165c8cdbf75 Mon Sep 17 00:00:00 2001
From: xulei1
Date: Mon, 16 Dec 2024 19:36:16 +0800
Subject: [PATCH 2/3] pre-commit

---
 .../layer_weights/meta_weights/mm_weight.py   |  9 +++-
 lightllm/common/deepseek2_mem_manager.py      |  6 +--
 .../layer_infer/transformer_layer_infer.py    | 53 +++++++++----------
 .../layer_weights/transformer_layer_weight.py | 20 +++----
 .../llama/layer_infer/post_layer_infer.py     |  4 +-
 .../llama/layer_infer/pre_layer_infer.py      |  1 +
 .../layer_infer/transformer_layer_infer.py    |  2 +-
 .../pre_and_post_layer_weight.py              |  5 +-
 8 files changed, 50 insertions(+), 50 deletions(-)

diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py
index 8c39102b1..dd812304d 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py
@@ -326,6 +326,7 @@ def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
         super().__init__(weight_name, data_type, split_n_embed, bias_name)
         self.start = 0
         self.end = split_n_embed
+
     def load_hf_weights(self, weights):
         weight = None
         if self.weight_name in weights:
@@ -339,9 +340,11 @@ def load_hf_weights(self, weights):
         self._post_load_weights()
         return
 
+
 class MultiCOLMMWeightNoTp(MultiROWMMWeightNoTP):
     def __init__(self, weight_names, data_type, split_n_embed, bias_names=[]):
         super().__init__(weight_names, data_type, split_n_embed, bias_names)
+
     def load_hf_weights(self, weights):
         weight = None
         for i in range(len(self.weight_names)):
@@ -354,8 +357,10 @@ def load_hf_weights(self, weights):
         self._fuse()
         return
 
+
 class ROWBMMWeightNoTp(BMMWeight):
     load_hf_weights = ROWMMWeight.load_hf_weights
+
     def __init__(
         self,
         weight_name,
@@ -367,8 +372,10 @@ def __init__(
         self.start = 0
         self.end = split_n_embed
 
+
 class COLBMMWeightNoTp(BMMWeight):
     load_hf_weights = COLMMWeightNoTp.load_hf_weights
+
     def __init__(
         self,
         weight_name,
@@ -379,6 +386,6 @@ def __init__(
         super().__init__(weight_name, data_type, split_n_embed, bias_name)
         self.start = 0
         self.end = split_n_embed
+
     def _post_load_weights(self):
         self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)
-
\ No newline at end of file
diff --git a/lightllm/common/deepseek2_mem_manager.py b/lightllm/common/deepseek2_mem_manager.py
index 736f05f40..3dab29851 100644
--- a/lightllm/common/deepseek2_mem_manager.py
+++ b/lightllm/common/deepseek2_mem_manager.py
@@ -23,16 +23,14 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
             rank_id = dist.get_rank()
             world_size = dist.get_world_size()
 
-            #lightllm_moe_etp_kernel.enableP2P(world_size, rank_id)
+            # lightllm_moe_etp_kernel.enableP2P(world_size, rank_id)
 
             handle = lightllm_moe_etp_kernel.get_handle(self.work_buffer.contiguous(), rank_id)
             handles = [None] * world_size
             dist.all_gather_object(handles, handle)
             self.handles_work_buffer = handles
 
-            lightllm_moe_etp_kernel.init_system(world_size, rank_id,
-                                                self.work_buffer.contiguous(),
-                                                handles )
+            lightllm_moe_etp_kernel.init_system(world_size, rank_id, self.work_buffer.contiguous(), handles)
 
     def alloc_kv_move_buffer(self, max_req_total_len):
         self.kv_move_buffer = torch.empty(
diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
index 5d789c110..73d0ebea4 100644
--- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -69,7 +69,6 @@ def __init__(
         self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
         self.mla_type = "ACCM"
 
-
         self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true"
 
         return
@@ -123,7 +122,6 @@ def _get_qkv(
             self.mla_type = layer_weight.mla_type
             if self.mla_type == "ACCM":
                 q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
-
         layer_weight.kv_a_proj_with_mqa_.mm(
             input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim)
         )
@@ -280,9 +278,9 @@ def _context_attention_kernel_origin(
     ) -> torch.Tensor:
         q_nope, q_rope = q
 
-        #not support edp yet
-        assert self.tp_split_ == True
-
+        # not support edp yet
+        # assert self.tp_split_ == True
+
         o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
 
         if infer_state.use_dynamic_prompt_cache:
@@ -474,33 +472,32 @@ def _moe_ffn_etp_edp(
         # now some parameter is not supported yet
         # assert gating_normalize_prob is False
         # assert num_expert_groups<=1
-        if os.environ.get("ETP_MODE_ENABLED") == "true" :
+        if os.environ.get("ETP_MODE_ENABLED") == "true":
             from lightllm_moe_etp_kernel import moe_fused_all as moe_fused_all
-        elif os.environ.get("EDP_MODE_ENABLED") == "true": 
+        elif os.environ.get("EDP_MODE_ENABLED") == "true":
             from lightllm_moe_etp_kernel import moe_fused_all_edp as moe_fused_all
 
         moe_fused_all(
-                router_logits.contiguous(),
-                hidden_states.contiguous(),
-                layer_weight.gate_up_proj.weight.contiguous(),  # transpose
-                layer_weight.down_proj.weight.contiguous(),  # transpose
-                layer_weight.experts.expert_gate_up_proj_etp.contiguous(),
-                layer_weight.experts.expert_down_proj_etp.contiguous(),
-                infer_state.mem_manager.work_buffer.contiguous(),
-                infer_state.mem_manager.work_buffer.nelement(),
-                final_hidden_states.contiguous(),
-                rank_self,
-                gating_scaling_factor,
-                num_experts,
-                num_experts_per_token,
-                num_tokens,
-                world_size_,
-                hidden_dim,
-                layer_weight.gate_up_proj.weight.size(1) // 2,
-                layer_weight.experts.expert_gate_up_proj_etp.size(1) // 2,
-                self.n_shared_experts is not None,
-        )
-
+            router_logits.contiguous(),
+            hidden_states.contiguous(),
+            layer_weight.gate_up_proj.weight.contiguous(),  # transpose
+            layer_weight.down_proj.weight.contiguous(),  # transpose
+            layer_weight.experts.expert_gate_up_proj_etp.contiguous(),
+            layer_weight.experts.expert_down_proj_etp.contiguous(),
+            infer_state.mem_manager.work_buffer.contiguous(),
+            infer_state.mem_manager.work_buffer.nelement(),
+            final_hidden_states.contiguous(),
+            rank_self,
+            gating_scaling_factor,
+            num_experts,
+            num_experts_per_token,
+            num_tokens,
+            world_size_,
+            hidden_dim,
+            layer_weight.gate_up_proj.weight.size(1) // 2,
+            layer_weight.experts.expert_gate_up_proj_etp.size(1) // 2,
+            self.n_shared_experts is not None,
+        )
         router_logits = None
diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
index 847aec9cb..b7024ec1b 100644
--- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
@@ -22,6 +22,7 @@
 
 import os
 
+
 def fuse_q_kb(self, layer_weight):
     if not (self.weight is None and all(w is not None for w in self.weights)):
         return
@@ -98,7 +99,7 @@ def _parse_config(self):
             and self.layer_num_ % self.network_config_["moe_layer_freq"] == 0
         )
         self.tp_q_head_num_ = self.network_config_["num_attention_heads"]
-        if self.tp_split_ :
+        if self.tp_split_:
             self.tp_q_head_num_ //= self.world_size_
         self.n_routed_experts = self.network_config_["n_routed_experts"]
         self.q_lora_rank = self.network_config_["q_lora_rank"]
@@ -114,7 +115,7 @@ def _init_weight_names(self):
             self.rope_weight_name = f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight"
 
     def _init_weight(self):
-        if self.tp_split_ :
+        if self.tp_split_:
             self._init_qkvo()
         else:
             self._init_qkvo_dp()
@@ -255,9 +256,7 @@ def _init_qkvo(self):
 
     def _init_qkvo_dp(self):
         q_split_n_embed = self.qk_nope_head_dim * self.tp_q_head_num_
-        q_split_n_embed_with_rope = (
-            (self.qk_nope_head_dim + self.qk_rope_head_dim) * self.num_attention_heads
-        )
+        q_split_n_embed_with_rope = (self.qk_nope_head_dim + self.qk_rope_head_dim) * self.num_attention_heads
         if self.q_lora_rank is None:
             if not self.disable_qk_absorb: # acc
                 self.fuse_qk_weight_ = MultiROWMMWeightNoTP(
@@ -269,7 +268,7 @@ def _init_qkvo_dp(self):
                     [q_split_n_embed_with_rope, self.tp_q_head_num_],
                 )
                 self.fuse_qk_weight_._fuse = partial(fuse_q_kb, self.fuse_qk_weight_, self)
-            else: # cc
+            else:  # cc
                 self.q_weight_ = ROWMMWeightNoTP(
                     f"model.layers.{self.layer_num_}.self_attn.q_proj.weight",
                     self.data_type_,
@@ -338,9 +337,8 @@ def _init_qkvo_dp(self):
                 q_split_n_embed,
             )
 
-
     def _load_mlp(self, mlp_prefix, split_inter_size):
-        if self.tp_split_ :
+        if self.tp_split_:
             self.gate_up_proj = MultiROWMMWeight(
                 [f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], self.data_type_, split_inter_size
             )
@@ -351,7 +349,6 @@ def _load_mlp(self, mlp_prefix, split_inter_size):
             )
             self.down_proj = COLMMWeightNoTp(f"{mlp_prefix}.down_proj.weight", self.data_type_, split_inter_size)
 
-
     def _init_moe(self):
         moe_intermediate_size = self.network_config_["moe_intermediate_size"]
         self.moe_gate = ROWMMWeightNoTP(
@@ -359,7 +356,7 @@ def _init_moe(self):
         )
         shared_intermediate_size = moe_intermediate_size * self.network_config_["n_shared_experts"]
 
-        num_shards = self.world_size_ if self.tp_split_ else 1 
+        num_shards = self.world_size_ if self.tp_split_ else 1
         self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", shared_intermediate_size // num_shards)
 
         self.experts = FusedMoeWeight(
@@ -370,12 +367,11 @@ def _init_moe(self):
             n_routed_experts=self.n_routed_experts,
             split_inter_size=moe_intermediate_size // self.world_size_,
             data_type=self.data_type_,
-
         )
 
     def _init_ffn(self):
         inter_size = self.network_config_["intermediate_size"]
-        num_shards = self.world_size_ if self.tp_split_ else 1 
+        num_shards = self.world_size_ if self.tp_split_ else 1
         self._load_mlp(f"model.layers.{self.layer_num_}.mlp", inter_size // num_shards)
 
     def _init_norm(self):
diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py
index 8c7b8f5c0..f31761788 100644
--- a/lightllm/models/llama/layer_infer/post_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/post_layer_infer.py
@@ -24,7 +24,7 @@ def __init__(self, tp_rank, world_size, network_config, mode):
         self.vocab_size_ = network_config["vocab_size"]
         self.embed_dim_ = network_config["n_embed"]
         self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true"
-        
+
         return
 
     def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor:
@@ -93,7 +93,7 @@ def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_
         torch.mm(layer_weight.lm_head_weight_, last_input, out=logic_batch)
 
         last_input = None
-        if self.world_size_ == 1 or self.tp_split_ == False:
+        if self.world_size_ == 1 or not self.tp_split_:
             gather_data = logic_batch
         else:
             gather_data = self.alloc_tensor((self.vocab_size_, token_num), dtype=input_embdings_dtype)
diff --git a/lightllm/models/llama/layer_infer/pre_layer_infer.py b/lightllm/models/llama/layer_infer/pre_layer_infer.py
index bb53b68f3..cee3c9b1a 100644
--- a/lightllm/models/llama/layer_infer/pre_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/pre_layer_infer.py
@@ -11,6 +11,7 @@
 
 import os
 
+
 class LlamaPreLayerInfer(PreLayerInferTpl):
     """ """
 
diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
index 4eb4bc80c..216a8767e 100755
--- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -27,6 +27,7 @@
 
 import os
 
+
 class LlamaTransformerLayerInfer(TransformerLayerInferTpl):
     """ """
 
@@ -42,7 +43,6 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
             self.tp_k_head_num_ //= world_size
             self.tp_v_head_num_ //= world_size
 
-
         self.tp_o_head_num_ = self.tp_q_head_num_
         self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"]
         self.embed_dim_ = network_config["hidden_size"]
diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py
index d1e2e70ca..06b2277a1 100644
--- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py
+++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py
@@ -2,7 +2,8 @@
 import numpy as np
 
 from lightllm.common.basemodel import PreAndPostLayerWeight
-import os 
+import os
+
 
 class LlamaPreAndPostLayerWeight(PreAndPostLayerWeight):
     def __init__(self, tp_rank, world_size, data_type, network_config, mode):
@@ -12,7 +13,7 @@ def __init__(self, tp_rank, world_size, data_type, network_config, mode):
 
     def load_hf_weights(self, weights):
         vob_size = self.network_config_["vocab_size"]
-        
+
         if self.tp_split_:
             split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64)
             split_start = split_indexes[self.tp_rank_]
From 25d7a72c47d6d4d2881a3f24660cc57882728824 Mon Sep 17 00:00:00 2001
From: xulei1
Date: Wed, 18 Dec 2024 18:46:23 +0800
Subject: [PATCH 3/3] edp etp fused

---
 .../layer_infer/transformer_layer_infer.py    | 19 ++++++++++++++-----
 .../layer_infer/transformer_layer_infer.py    |  6 +++++-
 2 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
index 73d0ebea4..67da6576b 100644
--- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -466,16 +466,23 @@ def _moe_ffn_etp_edp(
         final_hidden_states = torch.empty(
             num_tokens, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype
         )
 
-        router_logits = layer_weight.moe_gate.mm(hidden_states)
 
         # now some parameter is not supported yet
         # assert gating_normalize_prob is False
        # assert num_expert_groups<=1
+        is_etp = True
         if os.environ.get("ETP_MODE_ENABLED") == "true":
-            from lightllm_moe_etp_kernel import moe_fused_all as moe_fused_all
+            router_logits = layer_weight.moe_gate.mm(hidden_states)
         elif os.environ.get("EDP_MODE_ENABLED") == "true":
-            from lightllm_moe_etp_kernel import moe_fused_all_edp as moe_fused_all
+            router_logits = infer_state.mem_manager.work_buffer[ -(num_tokens*num_experts_per_token+hidden_states.nelement()):-hidden_states.nelement()].view( num_tokens ,num_experts_per_token)
+            router_logits = layer_weight.moe_gate.mm(hidden_states,out=router_logits)
+            is_etp = False
+
+        #print(" hid state addr ", infer_state.mem_manager.work_buffer.data_ptr(),
+        #      hidden_states.data_ptr(),
+        #      hidden_states.shape()
+        #      )
 
         moe_fused_all(
             router_logits.contiguous(),
@@ -497,8 +504,10 @@ def _moe_ffn_etp_edp(
             layer_weight.gate_up_proj.weight.size(1) // 2,
             layer_weight.experts.expert_gate_up_proj_etp.size(1) // 2,
             self.n_shared_experts is not None,
+            is_etp
         )
-        router_logits = None
+        if os.environ.get("ETP_MODE_ENABLED") == "true":
+            router_logits = None
 
         return final_hidden_states.view(num_tokens, hidden_dim)
 
diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
index 216a8767e..4ec2fd00f 100755
--- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -127,7 +127,11 @@ def _att_norm(
     def _ffn_norm(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
-        out = self.alloc_tensor(input.shape, input.dtype)
+        if not os.environ.get("EDP_MODE_ENABLED") == "true":
+            out = self.alloc_tensor(input.shape, input.dtype)
+        else:
+            num_ele = input.nelement()
+            out = self.infer_state.mem_manager.work_buffer[ -num_ele: ].view(input.shape)
         rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, out=out)
         return out
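The last commit stops allocating separate tensors in EDP mode and instead carves its scratch tensors out of the tail of the shared work buffer: the FFN-norm output is taken from the final elements, and the router logits from the slice immediately before them, so the fused kernel can locate both inside the same buffer. A standalone sketch of that slicing arithmetic (the buffer layout is inferred from the indices used in the hunks above, and the names are illustrative):

    import torch

    def carve_scratch_from_tail(work_buffer, num_tokens, num_experts_per_token, hidden_dim):
        """Reserve router-logit and hidden-state views at the end of a flat buffer."""
        hidden_numel = num_tokens * hidden_dim
        logits_numel = num_tokens * num_experts_per_token
        # Hidden states occupy the very end of the buffer ...
        hidden_view = work_buffer[-hidden_numel:].view(num_tokens, hidden_dim)
        # ... and the router logits sit immediately before them.
        logits_view = work_buffer[-(logits_numel + hidden_numel):-hidden_numel].view(
            num_tokens, num_experts_per_token
        )
        return logits_view, hidden_view

One practical consequence, visible in the commented-out address print, is that the Python side and the kernel must agree on this layout exactly; any extra allocation between the two views would silently corrupt the logits.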