diff --git a/examples/auto_parallel/README.md b/examples/auto_parallel/README.md index b04f68b4f..521e85f2c 100644 --- a/examples/auto_parallel/README.md +++ b/examples/auto_parallel/README.md @@ -28,9 +28,3 @@ should be replaced according to the real environment. The toolkit provides an auto-parallel solution for ERNIE-4.5 pre-training, including the hybrid parallelism training strategy. More advanced optimizations are on the way. - - -Currently, the auto-parallel intermediate API has some limitations under ongoing development: - -- Limited support for MOE -- Limited support for VPP in pipeline parallelism (default USE_VPP=0 in scripts; when USE_VPP=1, basic API are used for modeling) diff --git a/examples/auto_parallel/README_zh.md b/examples/auto_parallel/README_zh.md index 477bbd4ab..c6aee8107 100644 --- a/examples/auto_parallel/README_zh.md +++ b/examples/auto_parallel/README_zh.md @@ -26,7 +26,3 @@ - 注意,您需要将 `train_4p5_300B_A47B.sh` 中的 `master_ip` 与 `port` 根据您的环境进行替换。 该工具包提供了使用自动并行完成 ERNIE-4.5 预训练的方法,包括多维混合并行训练策略,更多的优化点和功能会基于此版本持续更新。 - -现在自动并行中层API存在一些局限性,正在进一步支持: -- 对 MOE 的支持不完备 -- 对流水线并行中的 VPP 优化支持不完备(脚本中默认 USE_VPP=0;当设置 USE_VPP=1 时,采用基础API完成组网) diff --git a/examples/auto_parallel/models/__init__.py b/examples/auto_parallel/models/__init__.py deleted file mode 100644 index 04418b69f..000000000 --- a/examples/auto_parallel/models/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from .top2_gate import * # noqa diff --git a/examples/auto_parallel/models/modeling.py b/examples/auto_parallel/models/modeling.py index 5f491f00b..a3c9314f7 100644 --- a/examples/auto_parallel/models/modeling.py +++ b/examples/auto_parallel/models/modeling.py @@ -16,7 +16,6 @@ import math import logging from typing import Optional, Tuple -import contextlib from copy import deepcopy @@ -29,8 +28,6 @@ from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd from paddle.distributed.fleet.layers.mpu.random import get_rng_state_tracker -from models.top2_gate import TopKGateFused - from paddle.distributed.auto_parallel.intermediate.tensor_parallel import ( PrepareLayerInput, ) @@ -42,16 +39,18 @@ from paddleformers.transformers.model_utils import PretrainedModel from models.moe_layer import ( + get_gate, MOELayer, - MoEStatics, + ErnieMLP, + ErnieMoeMLP, + ErnieMoeMLPFused, + TopKGateFused, ) from models.configuration import ErnieMoEConfig -from utils.training_utils import get_mesh from paddle.nn.functional.flash_attention import flash_attention from paddle.incubate.nn.functional import fused_rotary_position_embedding as fused_rope -from paddle.incubate.nn.functional import swiglu @dataclass @@ -178,15 +177,7 @@ def scaled_dot_product_attention( attn_weights = F.softmax_(attn_weights, axis=-1).astype(query_states.dtype) if config.attention_probs_dropout_prob > 0.0: - if config.tensor_parallel_degree > 1: - with get_rng_state_tracker().rng_state("local_seed"): - attn_weights = F.dropout( - attn_weights, - config.attention_probs_dropout_prob, - training=training, - mode="upscale_in_train", - ) - else: + with get_rng_state_tracker().rng_state("local_seed"): attn_weights = F.dropout( attn_weights, config.attention_probs_dropout_prob, @@ -241,74 +232,6 @@ def _expand_mask(mask, dtype, tgt_length): ) -def get_gate( - config: ErnieMoEConfig, - expert: Tuple[Tuple[int, nn.Layer]], - layer_idx: int, - ipp: int = 0, -) -> Tuple[nn.Layer, nn.LayerList]: - moe_num_experts = config.moe_num_experts - assert ( - moe_num_experts >= config.moe_world_size - ), f"expert moe_num_experts={moe_num_experts} >= moe_world_size={config.moe_world_size}" - assert ( - moe_num_experts % config.moe_world_size == 0 - ), f"expert moe_num_experts={moe_num_experts} % moe_world_size={config.moe_world_size} == 0" - moe_num_experts_per_device = moe_num_experts // config.moe_world_size - experts = nn.LayerList([]) - for expert_id, (experts_num, fc) in enumerate(expert): - assert experts_num % config.moe_world_size == 0 - experts_to_append = [] - if not hasattr(fc, "__len__"): - experts_to_append.append(fc) - if expert_id == 1: - with paddle.utils.unique_name.guard("_mm_deepcopy"): - for _ in range(experts_num - 1): - experts_to_append.append(deepcopy(fc)) - else: - for _ in range(experts_num - 1): - experts_to_append.append(deepcopy(fc)) - else: - experts_to_append = fc - for ex in experts_to_append: - for p in ex.parameters(): - p.expert_type = f"expert_type_{expert_id}" - experts.extend(experts_to_append) - - logger.info( - f"using moe-world-size: {config.moe_world_size} " - f"expert-per-device: {moe_num_experts_per_device} " - ) - if config.moe_use_hard_gate and moe_num_experts <= 2: - gate = None - logger.info("MOE-GATE:-hard-gate") - else: - logger.info(f"MOE-GATE:-{config.moe_gate}") - gate = TopKGateFused( - config, layer_idx=layer_idx, group=config.moe_group, ipp=ipp - ) - - lm_gate, lm_experts = None, None - logger.info(f"LM-experts-{lm_experts} -- experts-{experts}") - - index = 0 if config.moe_group == 
"dp" else 1 - ep_sub_meshes = dist.auto_parallel.api.split_mesh(get_mesh(ipp), index) - - for i, expert in enumerate(experts): - ep_group_id = i // moe_num_experts_per_device - if isinstance(expert, (ErnieMoeMLPFused, ErnieMoeMLP)): - experts[i].redistribute_expert( - ep_sub_meshes[ep_group_id], [dist.Replicate(), dist.Replicate()] - ) - experts[i].ep_group_id = ep_group_id - - if config.moe_use_aux_free: - moe_statics = MoEStatics(config, layer_idx) - else: - moe_statics = None - return gate, experts, lm_gate, lm_experts, moe_statics - - class RMSNorm(nn.Layer): def __init__(self, config, ipp=0): super().__init__() @@ -476,36 +399,6 @@ def apply_rotary_single(x, rope_emb): return x * rope_emb[0] + rotate_half_x * rope_emb[1] -class ErnieMLP(nn.Layer): - def __init__(self, config, ipp=None, do_shard_tensor=True): - super().__init__() - self.config = config - self.ipp = ipp - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - - self.gate_proj = nn.Linear( - self.hidden_size, self.intermediate_size, bias_attr=config.use_bias - ) - self.up_proj = nn.Linear( - self.hidden_size, self.intermediate_size, bias_attr=config.use_bias - ) - self.down_proj = nn.Linear( - self.intermediate_size, self.hidden_size, bias_attr=config.use_bias - ) - - self.fuse_swiglu = config.fuse_swiglu - - def forward(self, x): - if self.fuse_swiglu: - x = swiglu(self.gate_proj(x), self.up_proj(x)) - else: - x = F.silu(self.gate_proj(x)) * self.up_proj(x) - - out = self.down_proj(x) - return out - - class ErnieAttention(nn.Layer): def __init__(self, config, ipp: Optional[int] = None): super().__init__() @@ -696,119 +589,6 @@ def rope_attn( return attn_output, attn_weights, past_key_value -class ErnieMoeMLP(ErnieMLP): - """_summary_ - - Args: - ErnieMoeMLP (_type_): _description_ - """ - - def __init__(self, config, ipp=0): - """ - doc - """ - disable_ffn_model_parallel = getattr( - config, "disable_ffn_model_parallel", False - ) - if disable_ffn_model_parallel: - config = deepcopy(config) - config.tensor_parallel_degree = 1 - config.sequence_parallel = False - - super().__init__(config, ipp, do_shard_tensor=not disable_ffn_model_parallel) - self.moe_dropout_prob = config.moe_dropout_prob - self.fuse_swiglu = config.fuse_swiglu - - def redistribute_expert(self, mesh, placements): - """ - Place the experts on different devices. 
- """ - self.gate_proj.weight = dist.shard_tensor( - self.gate_proj.weight, mesh, placements - ) - self.up_proj.weight = dist.shard_tensor(self.up_proj.weight, mesh, placements) - self.down_proj.weight = dist.shard_tensor( - self.down_proj.weight, mesh, placements - ) - if self.config.use_bias: - self.gate_proj.bias = dist.shard_tensor( - self.gate_proj.bias, mesh, placements - ) - self.up_proj.bias = dist.shard_tensor(self.up_proj.bias, mesh, placements) - self.down_proj.bias = dist.shard_tensor( - self.down_proj.bias, mesh, placements - ) - - def forward(self, x): - if self.fuse_swiglu: - x = swiglu(self.gate_proj(x), self.up_proj(x)) - else: - x = F.silu(self.gate_proj(x)) * self.up_proj(x) - if self.moe_dropout_prob > 0: - with get_rng_state_tracker().rng_state("local_seed"): - x = F.dropout(x=x, p=self.moe_dropout_prob) - ret = self.down_proj(x) - return ret - - -class BMMLinear(nn.Layer): - def __init__(self, experts, d_in, d_out, use_bias=False): - super().__init__() - self.weight = self.create_parameter( - [experts, d_in, d_out], dtype=paddle.get_default_dtype() - ) - if use_bias: - self.bias = self.create_parameter( - [experts, d_out], dtype=paddle.get_default_dtype(), is_bias=True - ) - else: - self.bias = None - - def forward(self, x): - """x: [num_experts, Seq, dim]""" - if self.bias is not None: - return paddle.bmm(x, self.weight) + self.bias - return paddle.bmm(x, self.weight) - - -class ErnieMoeMLPFused(nn.Layer): - def __init__(self, config): - assert ( - hasattr(config, "disable_ffn_model_parallel") - or config.tensor_parallel_degree == 1 - ), f"fused mlp only suport mp-moe, mp={config.tensor_parallel_degree}" - assert config.fuse_attn_ffn, "fused mlp only support fuse_attn_ffn" - super().__init__() - self.moe_dropout_prob = config.moe_dropout_prob - self.num_local_experts = config.moe_num_experts // config.moe_world_size - logger.info( - f"fused-expert-weight-shape: {[self.num_local_experts, config.hidden_size, config.intermediate_size]}" - ) - - self.up_gate_proj = BMMLinear( - self.num_local_experts, config.hidden_size, config.intermediate_size * 2 - ) - self.down_proj = BMMLinear( - self.num_local_experts, config.intermediate_size, config.hidden_size - ) - self.fuse_swiglu = config.fuse_swiglu - - def __len__(self): - return self.num_local_experts - - def __iter__(self): - return (self for _ in range(1)) - - def forward(self, x): - if self.fuse_swiglu: - x = swiglu(self.up_gate_proj(x)) - else: - gate, x = self.up_gate_proj(x).chunk(2, axis=-1) - x = F.silu(gate) * x - x = self.down_proj(x) - return x - - class ErnieDecoderLayer(nn.Layer): """ ErnieDecoderLayer is a decoder layer in Ernie model. 
@@ -990,16 +770,7 @@ def forward( ) ) - if ( - self.config.tensor_parallel_degree > 1 - and self.config.hidden_dropout_prob > 0.0 - ): - current_seed = ( - "local_seed" if self.config.sequence_parallel else "global_seed" - ) - with get_rng_state_tracker().rng_state(current_seed): - hidden_states = self.residual_add1(hidden_states, residual) - else: + with get_rng_state_tracker().rng_state("local_seed"): hidden_states = self.residual_add1(hidden_states, residual) residual = hidden_states @@ -1017,16 +788,7 @@ def forward( hidden_states = self.mlp(hidden_states) gate_logits = None - if ( - self.config.tensor_parallel_degree > 1 - and self.config.hidden_dropout_prob > 0.0 - ): - current_seed = ( - "local_seed" if self.config.sequence_parallel else "global_seed" - ) - with get_rng_state_tracker().rng_state(current_seed): - hidden_states = self.residual_add2(hidden_states, residual) - else: + with get_rng_state_tracker().rng_state("local_seed"): hidden_states = self.residual_add2(hidden_states, residual) outputs = (hidden_states,) @@ -1069,10 +831,7 @@ class ErniePretrainedModel(PretrainedModel): def init_weights(self, layer): """Initialization hook""" - if self.config.tensor_parallel_degree > 1: - rng_tracker = get_rng_state_tracker().rng_state - else: - rng_tracker = contextlib.nullcontext + rng_tracker = get_rng_state_tracker().rng_state if isinstance( layer, @@ -1893,14 +1652,8 @@ def forward( def auto_dist_config(self, prefix=""): if prefix != "": assert prefix.endswith(".") - # if self.config.pipeline_parallel_degree <= 1: - # print(f"ernie use_intermediate_api:{self.config.use_intermediate_api}") - # print(f"ernie pp mode:{self.config.pipeline_schedule_mode}") ernie_prefix = prefix + "ernie." layers_prefix = "" - # else: - # ernie_prefix = prefix - # layers_prefix="layers.*." 
config = { "sp_config": { "parallelize_plan": { diff --git a/examples/auto_parallel/models/modeling_vpp.py b/examples/auto_parallel/models/modeling_vpp.py index 77cf2c128..12141892a 100644 --- a/examples/auto_parallel/models/modeling_vpp.py +++ b/examples/auto_parallel/models/modeling_vpp.py @@ -16,513 +16,54 @@ import math import logging from typing import Optional, Tuple -import contextlib -from copy import deepcopy -from dataclasses import dataclass import paddle import paddle.distributed as dist -import paddle.nn.functional as F from paddle.distributed import fleet from paddle import nn -from paddle.distributed.fleet.utils import recompute -from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd from paddle.distributed.fleet.layers.mpu.random import get_rng_state_tracker -from models.top2_gate import TopKGateFused - - -from paddleformers.transformers.model_outputs import ( - BaseModelOutputWithPastAndCrossAttentions as _BaseModelOutput, +from .modeling import ( + RMSNorm, + ErniePretrainedModel, + ErnieModel, + ErniePretrainingCriterion, + ReshardLayer, + ErnieLMHead, + ErnieDecoderLayer, + ErnieAttention, + ErnieForCausalLM, ) -from paddleformers.transformers.model_outputs import CausalLMOutputWithCrossAttentions -from paddleformers.transformers.model_utils import PretrainedModel -from models.moe_layer import ( - MOELayer, - MoEStatics, -) +from models.moe_layer import MOELayer, ErnieMLP from models.configuration import ErnieMoEConfig from utils.training_utils import get_mesh - -from paddle.nn.functional.flash_attention import flash_attention -from paddle.incubate.nn.functional import fused_rotary_position_embedding as fused_rope -from paddle.incubate.nn.functional import swiglu - - -@dataclass -class BaseModelOutputWithPastAndCrossAttentions(_BaseModelOutput): - router_loss: Optional[paddle.Tensor] = None - gate_logits: Optional[Tuple[paddle.Tensor]] = None - mtp_outputs: Optional[paddle.Tensor] = None - - -@dataclass -class CausalLMOutputWithCrossAttentionsErnie(CausalLMOutputWithCrossAttentions): - router_loss: Optional[paddle.Tensor] = None - - logger = logging.getLogger(__name__) __all__ = [ - "ErnieForCausalLM", + "ErnieForCausalLMVPP", ] -def calc_lm_head_logits( - config, - hidden_states, - weight, - bias, - sparse_label_idx=None, -): - """the core function to calc lm head""" - if config.sequence_parallel: - hcg = paddle.distributed.fleet.get_hybrid_communicate_group() - dp_rank = hcg.get_data_parallel_rank() - sharding_rank = hcg.get_sharding_parallel_rank() - if dp_rank <= 1 and sharding_rank <= 1: - hidden_states = dist.reshard( - hidden_states, - get_mesh(-1), - [dist.Replicate(), dist.Replicate()], - ) - else: - hidden_states = dist.reshard( - hidden_states, - get_mesh(-1), - [dist.Shard(1), dist.Replicate()], - ) - # [S, B, H] to [B, S, H] - hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) - hidden_states = hidden_states.reshape( - [-1, config.seqlen, hidden_states.shape[-1]] - ) - - logits = paddle.matmul( - hidden_states, weight, transpose_y=config.tie_word_embeddings - ) - if bias is not None: - logits += bias - - return logits - - -def masked_fill(x, mask, value): - y = paddle.full(x.shape, value, x.dtype) - return paddle.where(mask, y, x) - - -def scaled_dot_product_attention( - query_states, - key_states, - value_states, - attention_mask, - output_attentions, - config, - is_causal=True, - inbatch_pack_offset=None, - training=True, -): - bsz, q_len, num_heads, head_dim = query_states.shape - _, kv_seq_len, num_key_value_heads, _ = 
value_states.shape - - can_use_fa = config.use_flash_attn - - if can_use_fa: - attn_output, attn_weights = flash_attention( - query_states, - key_states, - value_states, - dropout=config.attention_probs_dropout_prob, - causal=is_causal and query_states.shape[1] != 1, - return_softmax=output_attentions, - ) - - attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) - return attn_output, attn_weights - else: - if query_states.shape[-2] != key_states.shape[-2]: - key_states = key_states.repeat_interleave( - num_heads // num_key_value_heads, axis=-2 - ) - if query_states.shape[-2] != value_states.shape[-2]: - value_states = value_states.repeat_interleave( - num_heads // num_key_value_heads, axis=-2 - ) - query_states = paddle.transpose(query_states, [0, 2, 1, 3]) / math.sqrt( - head_dim - ) - key_states = paddle.transpose(key_states, [0, 2, 1, 3]) - value_states = paddle.transpose(value_states, [0, 2, 1, 3]) - - attn_weights = paddle.matmul(query_states, key_states.transpose([0, 1, 3, 2])) - - if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]: - raise ValueError( - f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.shape}" - ) - - if attention_mask is None: - attention_mask = F.get_triangle_upper_mask(attn_weights) - - attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len]) - if attention_mask.shape != [bsz, 1, q_len, kv_seq_len]: - raise ValueError( - f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}" - ) - if training: - attn_weights = attention_mask + attn_weights - attn_weights = paddle.maximum( - attn_weights, - paddle.to_tensor( - float(paddle.finfo(query_states.dtype).min), - dtype=query_states.dtype, - ), - ) - - with paddle.amp.auto_cast(False): - attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype( - query_states.dtype - ) - - else: - attn_weights = attn_weights.cast(paddle.float32) - attention_mask = attention_mask.cast(paddle.float32) - attn_weights = attn_weights.add_(attention_mask) - attn_weights = F.softmax_(attn_weights, axis=-1).astype(query_states.dtype) - - if config.attention_probs_dropout_prob > 0.0: - if config.tensor_parallel_degree > 1: - with get_rng_state_tracker().rng_state("local_seed"): - attn_weights = F.dropout( - attn_weights, - config.attention_probs_dropout_prob, - training=training, - mode="upscale_in_train", - ) - else: - attn_weights = F.dropout( - attn_weights, - config.attention_probs_dropout_prob, - training=training, - mode="upscale_in_train", - ) - - attn_output = paddle.matmul(attn_weights, value_states) - attn_output = attn_output.transpose([0, 2, 1, 3]) - attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) - if output_attentions: - return attn_output, attn_weights - return attn_output, None - - -def _make_causal_mask(input_ids_shape, past_key_values_length, dtype): - batch_size, target_length = input_ids_shape - - mask = paddle.full((target_length, target_length), float(paddle.finfo(dtype).min)) - - mask_cond = paddle.arange(mask.shape[-1]) - mask = masked_fill( - mask, mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0 - ) - - if past_key_values_length > 0: - mask = paddle.concat( - [paddle.zeros([target_length, past_key_values_length]), mask], axis=-1 - ) - - return mask[None, None, :, :].expand( - [batch_size, 1, target_length, target_length + past_key_values_length] - ) - - -def _expand_mask(mask, dtype, tgt_length): - if mask.ndim == 4: - expanded_mask = mask 
- elif mask.ndim == 3: - expanded_mask = mask[:, None, :, :] - else: - batch_size, src_length = mask.shape[0], mask.shape[-1] - tgt_length = tgt_length if tgt_length is not None else src_length - - expanded_mask = mask[:, None, None, :].expand( - [batch_size, 1, tgt_length, src_length] - ) - - inverted_mask = 1.0 - expanded_mask - return masked_fill( - inverted_mask, inverted_mask.cast("bool"), float(paddle.finfo(dtype).min) - ) - - -def get_gate( - config: ErnieMoEConfig, - expert: Tuple[Tuple[int, nn.Layer]], - layer_idx: int, - ipp: int = 0, -) -> Tuple[nn.Layer, nn.LayerList]: - moe_num_experts = config.moe_num_experts - assert ( - moe_num_experts >= config.moe_world_size - ), f"expert moe_num_experts={moe_num_experts} >= moe_world_size={config.moe_world_size}" - assert ( - moe_num_experts % config.moe_world_size == 0 - ), f"expert moe_num_experts={moe_num_experts} % moe_world_size={config.moe_world_size} == 0" - moe_num_experts_per_device = moe_num_experts // config.moe_world_size - experts = nn.LayerList([]) - for expert_id, (experts_num, fc) in enumerate(expert): - assert experts_num % config.moe_world_size == 0 - experts_to_append = [] - if not hasattr(fc, "__len__"): - experts_to_append.append(fc) - if expert_id == 1: - with paddle.utils.unique_name.guard("_mm_deepcopy"): - for _ in range(experts_num - 1): - experts_to_append.append(deepcopy(fc)) - else: - for _ in range(experts_num - 1): - experts_to_append.append(deepcopy(fc)) - else: - experts_to_append = fc - for ex in experts_to_append: - for p in ex.parameters(): - p.expert_type = f"expert_type_{expert_id}" - experts.extend(experts_to_append) - - logger.info( - f"using moe-world-size: {config.moe_world_size} " - f"expert-per-device: {moe_num_experts_per_device} " - ) - if config.moe_use_hard_gate and moe_num_experts <= 2: - gate = None - logger.info("MOE-GATE:-hard-gate") - else: - logger.info(f"MOE-GATE:-{config.moe_gate}") - gate = TopKGateFused( - config, layer_idx=layer_idx, group=config.moe_group, ipp=ipp - ) - - lm_gate, lm_experts = None, None - logger.info(f"LM-experts-{lm_experts} -- experts-{experts}") - - index = 0 if config.moe_group == "dp" else 1 - ep_sub_meshes = dist.auto_parallel.api.split_mesh(get_mesh(ipp), index) - - for i, expert in enumerate(experts): - ep_group_id = i // moe_num_experts_per_device - if isinstance(expert, (ErnieMoeMLPFused, ErnieMoeMLP)): - experts[i].redistribute_expert( - ep_sub_meshes[ep_group_id], [dist.Replicate(), dist.Replicate()] - ) - experts[i].ep_group_id = ep_group_id - - if config.moe_use_aux_free: - moe_statics = MoEStatics(config, layer_idx) - else: - moe_statics = None - return gate, experts, lm_gate, lm_experts, moe_statics - - -class RMSNorm(nn.Layer): - def __init__(self, config, ipp=0): - super().__init__() - self.hidden_size = config.hidden_size - self.weight = paddle.create_parameter( - shape=[self.hidden_size], - dtype=paddle.get_default_dtype(), - default_initializer=nn.initializer.Constant(1.0), - ) - self.variance_epsilon = config.rms_norm_eps - self.config = config - - def forward(self, hidden_states): - if self.config.fuse_rms_norm: - return paddle.incubate.nn.functional.fused_rms_norm_ext( - hidden_states, self.weight, self.variance_epsilon - )[0] - with paddle.amp.auto_cast(False): - variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) - hidden_states = ( - paddle.rsqrt(variance + self.variance_epsilon) * hidden_states - ) - - if self.weight.dtype in [paddle.float16, paddle.bfloat16]: - hidden_states = paddle.cast(hidden_states, 
self.weight.dtype) - return hidden_states * self.weight - - class LayerNorm(nn.LayerNorm): - def __init__(self, config, ipp=0): super().__init__(config.hidden_size, epsilon=config.rms_norm_eps) - self.use_fast_ln = config.use_fast_ln - self.ipp = ipp if config.pipeline_parallel_degree > 1: self.weight = dist.shard_tensor( - self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Replicate()] + self.weight, get_mesh(ipp), [dist.Replicate(), dist.Replicate()] ) self.bias = dist.shard_tensor( - self.bias, get_mesh(self.ipp), [dist.Replicate(), dist.Replicate()] + self.bias, get_mesh(ipp), [dist.Replicate(), dist.Replicate()] ) -class RotaryEmbedding(nn.Layer): - - def __init__(self, dim, max_position_embeddings=4096, base=10000): - super().__init__() - self.base = base - self.max_position_embeddings = max_position_embeddings - inv_freq = 1.0 / ( - base ** (paddle.cast(paddle.arange(0, dim, 2), dtype="float32") / dim) - ) - - t = paddle.arange(max_position_embeddings, dtype="float32") - freqs = paddle.einsum("i,j->ij", t, inv_freq.cast("float32")) - emb = paddle.concat([freqs, freqs], axis=-1) - - self.cos_cached = emb.cos() - self.sin_cached = emb.sin() - - self._cast_to_low_precision = False - self._cast_to_low_precison = False - - def forward(self, x, seq_len=None): - return ( - self.cos_cached[:seq_len, :], - self.sin_cached[:seq_len, :], - ) - - @classmethod - def rotate_half(cls, x): - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return paddle.concat([-x2, x1], axis=-1) - - @classmethod - def apply_rotary_pos_emb(cls, q, k, cos, sin, offset: int = 0, position_ids=None): - if position_ids is not None: - assert offset == 0, offset - cos = F.embedding(position_ids, cos) - sin = F.embedding(position_ids, sin) - else: - cos = cos.unsqueeze(0) - sin = sin.unsqueeze(0) - cos = cos[:, offset : q.shape[1] + offset, None, :] - sin = sin[:, offset : q.shape[1] + offset, None, :] - - q_embed = paddle.add( - paddle.multiply(q, cos), paddle.multiply(cls.rotate_half(q), sin) - ) - k_embed = paddle.add( - paddle.multiply(k, cos), paddle.multiply(cls.rotate_half(k), sin) - ) - q_embed = q_embed.astype(q.dtype) - k_embed = k_embed.astype(k.dtype) - return q_embed, k_embed - - -class RopeEmbedding(nn.Layer): - def __init__(self, head_dim, compression_ratio=1.0, base=10000): - super().__init__() - self.head_dim = head_dim - self.compression_ratio = compression_ratio - self.base = base - - def forward(self, seq_length, position_ids=None): - indices = paddle.arange(0, self.head_dim, 2, dtype="float32") - indices = 1 / self.base ** (indices / self.head_dim) - if position_ids is None: - position_ids = paddle.arange(0, seq_length, 1, dtype="float32").unsqueeze(1) - position_ids = position_ids / self.compression_ratio - sinusoid_inp = position_ids * indices.unsqueeze(0) - else: - position_ids = position_ids / self.compression_ratio - seq_length = position_ids.shape[-1] - sinusoid_inp = position_ids.unsqueeze(-1).astype( - "float32" - ) * indices.unsqueeze(0) - pos_emb = paddle.concat( - [paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1 - ) - pos_emb = paddle.reshape(pos_emb, (-1, 1, seq_length, self.head_dim)) - pos_emb.stop_gradient = True - return pos_emb - - def apply_rotary(self, rp, q, k): - sin, cos = paddle.chunk(rp, 2, axis=-1) - sin_pos = paddle.reshape(paddle.stack([sin, sin], axis=-1), rp.shape) - cos_pos = paddle.reshape(paddle.stack([cos, cos], axis=-1), rp.shape) - rotate_half_q = paddle.reshape( - paddle.stack([-q[:, :, :, 1::2], q[:, :, :, 0::2]], axis=-1), - 
paddle.shape(q), - ) - query = paddle.add( - paddle.multiply(q.astype("float32"), cos_pos), - paddle.multiply(rotate_half_q.astype("float32"), sin_pos), - ) - rotate_half_k = paddle.reshape( - paddle.stack([-k[:, :, :, 1::2], k[:, :, :, 0::2]], axis=-1), - paddle.shape(k), - ) - key = paddle.add( - paddle.multiply(k.astype("float32"), cos_pos), - paddle.multiply(rotate_half_k.astype("float32"), sin_pos), - ) - return query, key - - def forward_single(self, position_ids): - batch_size, seq_length = position_ids.shape[:2] - rope_emb = paddle.zeros( - (2, batch_size, seq_length, 1, self.head_dim), dtype="float32" - ) - inv_freq = self.base ** ( - -paddle.arange(0, self.head_dim, 2, dtype="float32") / self.head_dim - ) - position_ids = position_ids.cast("float32") - position_ids = position_ids / self.compression_ratio - freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) - emb = paddle.stack([freqs, freqs], axis=-1).reshape( - (batch_size, seq_length, self.head_dim) - ) - emb = paddle.unsqueeze(emb, 2) - - rope_emb[0] = paddle.cos(emb) - rope_emb[1] = paddle.sin(emb) - return rope_emb - - @staticmethod - def apply_rotary_single(x, rope_emb): - rotate_half_x = paddle.reshape( - paddle.stack([-x[:, :, :, 1::2], x[:, :, :, 0::2]], axis=-1), - paddle.shape(x), - ) - return x * rope_emb[0] + rotate_half_x * rope_emb[1] - - -class ErnieMLP(nn.Layer): +class ErnieMLPVPP(ErnieMLP): def __init__(self, config, ipp=None, do_shard_tensor=True): - super().__init__() - self.config = config - self.ipp = ipp - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - - self.gate_proj = nn.Linear( - self.hidden_size, self.intermediate_size, bias_attr=config.use_bias - ) - self.up_proj = nn.Linear( - self.hidden_size, self.intermediate_size, bias_attr=config.use_bias - ) - self.down_proj = nn.Linear( - self.intermediate_size, self.hidden_size, bias_attr=config.use_bias - ) - + super().__init__(config, ipp, do_shard_tensor) if do_shard_tensor and ( self.config.tensor_parallel_degree > 1 or self.config.pipeline_parallel_degree > 1 @@ -560,80 +101,16 @@ def __init__(self, config, ipp=None, do_shard_tensor=True): [dist.Replicate(), dist.Replicate()], ) - self.fuse_swiglu = config.fuse_swiglu - def forward(self, x): - if self.fuse_swiglu: - x = swiglu(self.gate_proj(x), self.up_proj(x)) - else: - x = F.silu(self.gate_proj(x)) * self.up_proj(x) - - out = self.down_proj(x) + out = super().forward(x) if self.config.sequence_parallel: out = dist.reshard(out, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(0)]) return out -class ErnieAttention(nn.Layer): +class ErnieAttentionVPP(ErnieAttention): def __init__(self, config, ipp: Optional[int] = None): - super().__init__() - self.ipp = ipp - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.num_key_value_heads = config.num_key_value_heads - self.head_dim = self.hidden_size // self.num_heads - self.is_gqa = ( - config.num_key_value_heads is not None - and config.num_key_value_heads != self.num_heads - ) - self.fuse_rope = config.fuse_rope - - if self.is_gqa: - logger.info( - f"use GQA - num_heads: {self.num_heads}- num_key_value_heads: {self.num_key_value_heads}" - ) - assert ( - self.num_heads % self.num_key_value_heads == 0 - ), f"num_heads: {self.num_heads}, num_key_value_heads: {self.num_key_value_heads}" - kv_hidden_size = ( - self.hidden_size // self.num_heads * self.num_key_value_heads - ) - - self.q_proj = nn.Linear( - self.hidden_size, - self.hidden_size, - 
bias_attr=config.use_bias, - ) - self.k_proj = nn.Linear( - self.hidden_size, - self.hidden_size if not self.is_gqa else kv_hidden_size, - bias_attr=config.use_bias, - ) - self.v_proj = nn.Linear( - self.hidden_size, - self.hidden_size if not self.is_gqa else kv_hidden_size, - bias_attr=config.use_bias, - ) - self.o_proj = nn.Linear( - self.hidden_size, - self.hidden_size, - bias_attr=config.use_bias, - ) - if config.rope_reorder: - self.rotary_emb = RotaryEmbedding( - self.head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ) - else: - self.rotary_emb = RopeEmbedding( - self.head_dim, - compression_ratio=config.compression_ratio, - base=config.rope_theta, - ) - - self.config = config - + super().__init__(config, ipp) self.q_proj.weight = dist.shard_tensor( self.q_proj.weight, get_mesh(self.ipp), @@ -737,333 +214,17 @@ def forward( return attn_output, attn_weights, past_key_value - def rope_attn( - self, - mix_layer, - query_states, - key_states, - value_states, - attention_mask, - position_ids, - output_attentions=False, - past_key_value=None, - use_cache=False, - inbatch_pack_offset=None, - ): - if mix_layer is not None: - query_states, key_states, value_states = paddle.split(mix_layer, 3, axis=-1) - query_states_dtype = query_states.dtype - - kv_seq_len = key_states.shape[-3] - offset = 0 - if past_key_value is not None: - offset = past_key_value[0].shape[-3] - kv_seq_len += offset - - if self.config.rope_reorder: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - query_states, key_states = self.rotary_emb.apply_rotary_pos_emb( - query_states, - key_states, - cos, - sin, - position_ids=position_ids, - offset=offset if position_ids is None else 0, - ) - else: - if offset > 0 or position_ids is not None or not self.fuse_rope: - cos_sin = self.rotary_emb(kv_seq_len, position_ids).transpose( - [0, 2, 1, 3] - ) - if offset > 0 and position_ids is None: - cos_sin = cos_sin[:, offset:] - query_states, key_states = self.rotary_emb.apply_rotary( - cos_sin, query_states, key_states - ) - else: - bsz, q_len, num_heads, head_dim = query_states.shape - _, kv_seq_len, num_key_value_heads, _ = key_states.shape - if num_heads != num_key_value_heads: - query_states, _, _ = fused_rope(query_states, None, None) - key_states, _, _ = fused_rope(key_states, None, None) - else: - query_states, key_states, _ = fused_rope( - query_states, key_states, None - ) - - if use_cache: - query_states = query_states.astype(query_states_dtype) - key_states = key_states.astype(query_states_dtype) - if past_key_value is not None: - key_states = paddle.concat([past_key_value[0], key_states], axis=1) - value_states = paddle.concat([past_key_value[1], value_states], axis=1) - - past_key_value = [key_states, value_states] if use_cache else None - - attn_output, attn_weights = scaled_dot_product_attention( - query_states=query_states, - key_states=key_states, - value_states=value_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - config=self.config, - inbatch_pack_offset=inbatch_pack_offset, - training=self.training, - ) - return attn_output, attn_weights, past_key_value - - -class ErnieMoeMLP(ErnieMLP): - """_summary_ - - Args: - ErnieMoeMLP (_type_): _description_ - """ - - def __init__(self, config, ipp=0): - """ - doc - """ - disable_ffn_model_parallel = getattr( - config, "disable_ffn_model_parallel", False - ) - if disable_ffn_model_parallel: - config = deepcopy(config) - config.tensor_parallel_degree = 1 - config.sequence_parallel = 
False - - super().__init__(config, ipp, do_shard_tensor=not disable_ffn_model_parallel) - self.moe_dropout_prob = config.moe_dropout_prob - self.fuse_swiglu = config.fuse_swiglu - - def redistribute_expert(self, mesh, placements): - """ - Place the experts on different devices. - """ - self.gate_proj.weight = dist.shard_tensor( - self.gate_proj.weight, mesh, placements - ) - self.up_proj.weight = dist.shard_tensor(self.up_proj.weight, mesh, placements) - self.down_proj.weight = dist.shard_tensor( - self.down_proj.weight, mesh, placements - ) - if self.config.use_bias: - self.gate_proj.bias = dist.shard_tensor( - self.gate_proj.bias, mesh, placements - ) - self.up_proj.bias = dist.shard_tensor(self.up_proj.bias, mesh, placements) - self.down_proj.bias = dist.shard_tensor( - self.down_proj.bias, mesh, placements - ) - - def forward(self, x): - if self.fuse_swiglu: - x = swiglu(self.gate_proj(x), self.up_proj(x)) - else: - x = F.silu(self.gate_proj(x)) * self.up_proj(x) - if self.moe_dropout_prob > 0: - with get_rng_state_tracker().rng_state("local_seed"): - x = F.dropout(x=x, p=self.moe_dropout_prob) - ret = self.down_proj(x) - return ret - - -class BMMLinear(nn.Layer): - def __init__(self, experts, d_in, d_out, use_bias=False): - super().__init__() - self.weight = self.create_parameter( - [experts, d_in, d_out], dtype=paddle.get_default_dtype() - ) - if use_bias: - self.bias = self.create_parameter( - [experts, d_out], dtype=paddle.get_default_dtype(), is_bias=True - ) - else: - self.bias = None - - def forward(self, x): - """x: [num_experts, Seq, dim]""" - if self.bias is not None: - return paddle.bmm(x, self.weight) + self.bias - return paddle.bmm(x, self.weight) - - -class ErnieMoeMLPFused(nn.Layer): - def __init__(self, config): - assert ( - hasattr(config, "disable_ffn_model_parallel") - or config.tensor_parallel_degree == 1 - ), f"fused mlp only suport mp-moe, mp={config.tensor_parallel_degree}" - assert config.fuse_attn_ffn, "fused mlp only support fuse_attn_ffn" - super().__init__() - self.moe_dropout_prob = config.moe_dropout_prob - self.num_local_experts = config.moe_num_experts // config.moe_world_size - logger.info( - f"fused-expert-weight-shape: {[self.num_local_experts, config.hidden_size, config.intermediate_size]}" - ) - - self.up_gate_proj = BMMLinear( - self.num_local_experts, config.hidden_size, config.intermediate_size * 2 - ) - self.down_proj = BMMLinear( - self.num_local_experts, config.intermediate_size, config.hidden_size - ) - self.fuse_swiglu = config.fuse_swiglu - - def __len__(self): - return self.num_local_experts - - def __iter__(self): - return (self for _ in range(1)) - - def forward(self, x): - if self.fuse_swiglu: - x = swiglu(self.up_gate_proj(x)) - else: - gate, x = self.up_gate_proj(x).chunk(2, axis=-1) - x = F.silu(gate) * x - x = self.down_proj(x) - return x - -class ErnieDecoderLayer(nn.Layer): +class ErnieDecoderLayerVPP(ErnieDecoderLayer): """ - ErnieDecoderLayer is a decoder layer in Ernie model. - It is composed of self-attention, cross-attention and feedforward layers. + ErnieDecoderLayerVPP is ErnieDecoderLayer with sequence_parallel and tensor_parallel. """ def __init__(self, config, layer_idx=0, ipp=0): - """ - Initializes the ErnieBlock module. - - Args: - config (ErnieConfig): The model configuration. - layer_idx (int, optional): The index of this block in the model. Defaults to 0. - ipp (int, optional): The index of this block in the pipeline parallelism. Defaults to 0. 
- """ - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.ipp = ipp - self.hidden_size = config.hidden_size - self.self_attn = ErnieAttention(config, ipp) - self.use_moe = config.use_moe if hasattr(config, "use_moe") else False - if self.use_moe: - moe_layer_start_index = ( - min(config.moe_layer_start_index) - if isinstance(config.moe_layer_start_index, (tuple, list)) - else config.moe_layer_start_index - ) - moe_layer_end_index = ( - max(config.moe_layer_end_index) - if isinstance(config.moe_layer_end_index, (tuple, list)) - else config.moe_layer_end_index - ) - - if ( - self.use_moe - and ((layer_idx + 1) % config.moe_layer_interval == 0) - and layer_idx >= moe_layer_start_index - and layer_idx <= moe_layer_end_index - ): - self.create_moe_mlp_layer(layer_idx, ipp) - else: - self.mlp = ErnieMLP(config, ipp) - Norm = RMSNorm if config.use_rmsnorm else LayerNorm - self.input_layernorm = Norm(config, ipp) - self.post_attention_layernorm = Norm(config, ipp) - self.residual_add1 = FusedDropoutAdd( - config.hidden_dropout_prob, mode="upscale_in_train" - ) - self.residual_add2 = FusedDropoutAdd( - config.hidden_dropout_prob, mode="upscale_in_train" - ) - - def create_moe_mlp_layer(self, layer_idx, ipp): - _ex_cfg = deepcopy(self.config) - fc_cls = ErnieMoeMLPFused if _ex_cfg.moe_fuse_experts else ErnieMoeMLP - if _ex_cfg.moe_intermediate_size: - if isinstance(_ex_cfg.moe_intermediate_size, (tuple, list)): - assert isinstance(_ex_cfg.moe_num_experts, (tuple, list)) and len( - _ex_cfg.moe_num_experts - ) == len(_ex_cfg.moe_intermediate_size) - fc = [] - for _i, (num_experts, intermediate_size) in enumerate( - zip(_ex_cfg.moe_num_experts, _ex_cfg.moe_intermediate_size) - ): - _ex_cfg_real = deepcopy(_ex_cfg) - _ex_cfg_real.intermediate_size = intermediate_size - cur_modality_start_layer_idx = ( - self.config.moe_layer_start_index[_i] - if isinstance(self.config.moe_layer_start_index, (tuple, list)) - else self.config.moe_layer_start_index - ) - cur_modality_end_layer_idx = ( - self.config.moe_layer_end_index[_i] - if isinstance(self.config.moe_layer_end_index, (tuple, list)) - else self.config.moe_layer_end_index - ) - if ( - layer_idx >= cur_modality_start_layer_idx - and layer_idx <= cur_modality_end_layer_idx - ): - if _i == 1: - with paddle.utils.unique_name.guard( - f"mm_expert_{layer_idx}_" - ): - fc.append((num_experts, fc_cls(_ex_cfg_real))) - else: - fc.append((num_experts, fc_cls(_ex_cfg_real))) - else: - logger.info( - f"moe multimodal experts use Identity layer_idx: {layer_idx}" - ) - fc.append((num_experts, nn.Identity())) - else: - _ex_cfg.intermediate_size = _ex_cfg.moe_intermediate_size - fc = [(_ex_cfg.moe_num_experts, fc_cls(_ex_cfg))] - else: - fc = [(_ex_cfg.moe_num_experts, fc_cls(_ex_cfg))] - gate, experts, lm_gate, lm_experts, moe_statics = get_gate( - self.config, fc, layer_idx, self.ipp - ) - _sh_cfg = deepcopy(self.config) - - if _sh_cfg.moe_num_shared_experts > 0: - if _sh_cfg.moe_intermediate_size: - _sh_inter_size = ( - _sh_cfg.moe_intermediate_size[0] - if isinstance(_sh_cfg.moe_intermediate_size, (tuple, list)) - else _sh_cfg.moe_intermediate_size - ) - _sh_cfg.intermediate_size = ( - _sh_inter_size * _sh_cfg.moe_num_shared_experts - ) - else: - _sh_cfg.intermediate_size = ( - _sh_cfg.intermediate_size * _sh_cfg.moe_num_shared_experts - ) - _sh_cfg.disable_ffn_model_parallel = False - shared_experts = ErnieMoeMLP(_sh_cfg, ipp) - else: - shared_experts = None - - logger.info(f"moe-logging:{self.config.moe_logging}") - self.mlp = MOELayer( 
- gate, - experts, - layer_idx=layer_idx, - shared_experts=shared_experts, - group=self.config.moe_group, - recompute=self.config.use_recompute_moe, - k=self.config.moe_k, - all_to_all_dropout=self.config.moe_all_to_all_dropout, - group_experts=self.config.moe_group_experts, - moe_statics=moe_statics, - config=self.config, - ipp=self.ipp, - ) + super().__init__(config, layer_idx, ipp) + self.self_attn = ErnieAttentionVPP(config, ipp) + if isinstance(self.mlp, ErnieMLP): + self.mlp = ErnieMLPVPP(config, ipp) def forward( self, @@ -1077,20 +238,6 @@ def forward( token_type_ids: Optional[paddle.Tensor] = None, output_gate_logits=True, ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: - """ - Args: - hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`paddle.Tensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `cache` key value states are returned and can be used to speed up decoding - (see `cache`). - cache (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states - """ - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -1176,106 +323,18 @@ def forward( return outputs -class ErniePretrainedModel(PretrainedModel): +class ErnieModelVPP(ErnieModel): """ - ErniePretrainedModel is a pretrained model class for Ernie model. - It is composed of a encoder and a decoder. + ErnieModelVPP is a variant of ErnieModel that support vpp schedule. 
""" - config_class = ErnieMoEConfig - base_model_prefix = "ernie" - - def init_weights(self, layer): - """Initialization hook""" - if self.config.tensor_parallel_degree > 1: - rng_tracker = get_rng_state_tracker().rng_state - else: - rng_tracker = contextlib.nullcontext - - if isinstance( - layer, - ( - ErnieLMHead, - nn.Embedding, - nn.Linear, - paddle.incubate.nn.FusedLinear, - ), - ): - - with rng_tracker(): - dtype = paddle.get_default_dtype() - paddle.set_default_dtype("float32") - if layer.weight._is_initialized(): - if layer.weight.is_dist(): - layer.weight._local_value().set_value( - paddle.randn( - layer.weight._local_shape, dtype=layer.weight.dtype - ).scale(self.config.initializer_range) - ) - else: - layer.weight.set_value( - paddle.randn( - layer.weight.shape, dtype=layer.weight.dtype - ).scale(self.config.initializer_range) - ) - paddle.set_default_dtype(dtype) - logger.info( - f"dist-init-fc: shape={layer.weight.shape}, " - f" range={self.config.initializer_range}," - f' type={type(layer)},norm={layer.weight.astype("float32").norm()}' - ) - - elif isinstance(layer, TopKGateFused): - if not hasattr(layer, "weight"): - return - with rng_tracker("model_parallel_rng"): - dtype = paddle.get_default_dtype() - paddle.set_default_dtype("float32") - if self.config.moe_group_experts: - if layer.weight._is_initialized(): - layer.weight.set_value( - paddle.randn( - layer.weight.shape, dtype=layer.weight.dtype - ).scale(self.config.initializer_range) - ) - else: - if layer.weight._is_initialized(): - granularity = ( - 1 - if self.config.moe_intermediate_size == 0 - else self.config.intermediate_size - // self.config.moe_intermediate_size - ) - layer.weight.set_value( - paddle.randn( - [ - self.config.hidden_size, - self.config.moe_num_experts // granularity, - ], - dtype="float32", - ) - .scale(self.config.initializer_range) - .repeat_interleave(granularity, axis=-1) - ) - logger.info( - f"dist-init-moe_gate: shape={layer.weight.shape}, dtype={layer.weight.dtype} " - f"range={self.config.initializer_range},type={type(layer)}, " - f'norm={layer.weight.astype("float32").norm()}' - ) - - -class ErnieModel(ErniePretrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`ErnieDecoderLayer`] - Args: - config: ErnieMoEConfig - """ - - def __init__(self, config: ErnieMoEConfig, pp_layer_idx=None): - super().__init__(config) + def __init__(self, config: ErnieMoEConfig, pp_layer_idx=None, ipp=0): + super(ErniePretrainedModel, self).__init__(config) self.padding_idx = config.pad_token_id self.config = config - if config.pipeline_parallel_degree <= 1 or pp_layer_idx == 0: + self.layer = ErnieDecoderLayerVPP(config, pp_layer_idx, ipp) + + if pp_layer_idx == 0: self.vocab_size = config.vocab_size self.hidden_size = config.hidden_size self.embed_tokens = nn.Embedding( @@ -1287,21 +346,10 @@ def __init__(self, config: ErnieMoEConfig, pp_layer_idx=None): get_mesh(pp_idx=0), [dist.Replicate(), dist.Shard(1)], ) - if config.pipeline_parallel_degree <= 1: - self.layers = nn.LayerList() - for idx in range( - config.num_hidden_layers - 1 - if config.remove_tail_layer - else config.num_hidden_layers - ): - self.layers.append(ErnieDecoderLayer(config, idx)) - if ( - config.pipeline_parallel_degree <= 1 - or pp_layer_idx == self.config.num_hidden_layers - 1 - ): + if pp_layer_idx == self.config.num_hidden_layers - 1: Norm = RMSNorm if config.use_rmsnorm else LayerNorm - self.norm = Norm(config, -1) - self.lm_head = ErnieLMHead(config) + self.norm = Norm(config) + self.lm_head = ErnieLMHeadVPP(config) self.gradient_checkpointing = False @@ -1315,7 +363,7 @@ def __init__(self, config: ErnieMoEConfig, pp_layer_idx=None): Norm = RMSNorm if config.use_rmsnorm else LayerNorm self.mtp_block = nn.LayerList( [ - ErnieDecoderLayer(config, layer_idx, -1) + ErnieDecoderLayerVPP(config, layer_idx, -1) for layer_idx in range(self.config.multi_token_pred_depth) ] ) @@ -1352,228 +400,7 @@ def __init__(self, config: ErnieMoEConfig, pp_layer_idx=None): self.all_self_attns = None self.next_decoder_cache = None self.inputs_embeds_cur_depth_list = None - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @classmethod - def _prepare_decoder_attention_mask( - cls, attention_mask, input_shape, past_key_values_length, dtype - ): - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, past_key_values_length=past_key_values_length, dtype=dtype - ) - - if attention_mask is not None: - expanded_attn_mask = _expand_mask( - attention_mask, dtype, tgt_length=input_shape[-1] - ) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask + combined_attention_mask - ) - combined_attention_mask = paddle.maximum( - combined_attention_mask.astype(dtype), - paddle.to_tensor(float(paddle.finfo(dtype).min), dtype=dtype), - ) - return combined_attention_mask - - def recompute_training( - self, - layer_module, - hidden_states, - attention_mask, - position_ids, - output_attentions, - past_key_value, - use_cache, - inbatch_pack_offset, - token_type_ids, - ): - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_gate_logits=False) - - return custom_forward - - hidden_states = recompute( - create_custom_forward(layer_module), - hidden_states, - attention_mask, - position_ids, - output_attentions, - past_key_value, - use_cache, - inbatch_pack_offset, - token_type_ids, - use_reentrant=True, - ) - return hidden_states - - def embed_inputs(self, input_ids, attention_mask, position_ids): - inputs_embeds = self.inputs_embeds - - if input_ids is not None and 
inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - if self.past_key_values is None: - past_key_values = tuple([None] * self.config.num_hidden_layers) - - seq_length -= self.config.multi_token_pred_depth - seq_length_with_past = seq_length - cache_length = 0 - - if past_key_values[0] is not None: - cache_length = paddle.shape(past_key_values[0][0])[1] - seq_length_with_past += cache_length - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids).astype( - self.embed_tokens.weight.dtype - ) - - if self.config.multi_token_pred_depth > 0: - inputs_embeds_extra = inputs_embeds[ - :, -self.config.multi_token_pred_depth :, : - ] # [B, S, D] - inputs_embeds = inputs_embeds[:, : -self.config.multi_token_pred_depth, :] - inputs_embeds_ori = inputs_embeds - inputs_embeds_cur_depth_list = [] - for depth in range(self.config.multi_token_pred_depth): - inputs_embeds_cur_depth = paddle.concat( - [ - inputs_embeds_ori[:, (depth + 1) :, :], - inputs_embeds_extra[:, : (depth + 1), :], - ], - axis=1, - ) - inputs_embeds_cur_depth_list.append(inputs_embeds_cur_depth) - self.inputs_embeds_cur_depth_list = paddle.concat( - inputs_embeds_cur_depth_list - ) - - global_mesh = get_mesh(pp_idx=None) - if self.config.sequence_parallel: - inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) - - if position_ids is not None: - position_ids = dist.shard_tensor( - position_ids, - global_mesh, - [dist.Replicate() for _ in range(len(global_mesh._shape))], - ) - can_use_fa = self.config.use_flash_attn and flash_attention is not None - - if can_use_fa: - if attention_mask is not None: - attention_mask = None - - elif attention_mask is None: - attention_mask = paddle.ones( - (batch_size, seq_length_with_past), dtype=paddle.bool - ) - - if attention_mask is not None: - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, - (batch_size, seq_length), - cache_length, - inputs_embeds.dtype, - ) - attention_mask = dist.shard_tensor( - attention_mask, - global_mesh, - [dist.Replicate() for _ in range(len(global_mesh._shape))], - ) - - hidden_states = dist.reshard(inputs_embeds, get_mesh(0), self.placements) - - if self.config.multi_token_pred_depth > 0: - return ( - hidden_states, - attention_mask, - position_ids, - self.inputs_embeds_cur_depth_list, - ) - else: - return hidden_states, attention_mask, position_ids - - def decode_layer( - self, - decoder_layer, - hidden_states, - attention_mask, - position_ids, - all_router_loss=None, - ): - if self.config.output_hidden_states: - self.all_hidden_states += (hidden_states,) - has_gradient = not hidden_states.stop_gradient - position_ids_input = position_ids - attention_mask_input = attention_mask - token_type_ids_input = self.token_type_ids - - if self.config.use_recompute and has_gradient: - layer_outputs = self.recompute_training( - decoder_layer, - hidden_states, - attention_mask_input, - position_ids_input, - self.config.output_attentions, - self.past_key_values, - self.config.use_cache, - self.inbatch_pack_offset, - token_type_ids_input, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask_input, - position_ids_input, - self.config.output_attentions, - 
self.past_key_values, - self.config.use_cache, - self.inbatch_pack_offset, - token_type_ids_input, - ) - - if isinstance(layer_outputs, (tuple, list)): - hidden_states = layer_outputs[0] - else: - hidden_states = layer_outputs - - if self.config.use_cache: - self.next_decoder_cache += ( - layer_outputs[2 if self.config.output_attentions else 1], - ) - - if self.config.output_attentions: - self.all_self_attns += (layer_outputs[1],) - if hasattr(self.config, "use_moe") and self.config.use_moe: - if not (self.config.use_recompute and has_gradient): - layer_outputs, gate_logits = layer_outputs[:-1], layer_outputs[-1] - self.all_gate_logits = self.all_gate_logits + (gate_logits,) - router_loss = layer_outputs[-1] - if all_router_loss is not None: - all_router_loss += router_loss - return hidden_states, all_router_loss + self.reshard_replicate = ReshardLayer() def mtp_layer( self, hidden_states, inputs_embeds_cur_depth_list, attention_mask, position_ids @@ -1642,291 +469,6 @@ def mtp_layer( mtp_outputs = [self.norm(hidden_states) for hidden_states in mtp_outputs] return mtp_outputs - def forward( - self, - input_ids=None, - position_ids=None, - attention_mask=None, - inputs_embeds=None, - use_cache=None, - past_key_values=None, - output_attentions=False, - output_hidden_states=None, - return_dict=False, - inbatch_pack_offset=None, - token_type_ids=None, - **kwargs, - ): - self.inputs_embeds = inputs_embeds - self.past_key_values = past_key_values - self.inbatch_pack_offset = inbatch_pack_offset - self.token_type_ids = token_type_ids - self.inbatch_pack_offset = inbatch_pack_offset - if use_cache is not None: - self.config.use_cache = use_cache - if return_dict is not None: - self.config.return_dict = return_dict - if output_hidden_states is not None: - self.config.output_hidden_states = output_hidden_states - if output_attentions is not None: - self.config.output_attentions = output_attentions - - if self.config.multi_token_pred_depth > 0: - ( - hidden_states, - attention_mask, - position_ids, - inputs_embeds_cur_depth_list, - ) = self.embed_inputs(input_ids, attention_mask, position_ids) - else: - hidden_states, attention_mask, position_ids = self.embed_inputs( - input_ids, attention_mask, position_ids - ) - - self.all_hidden_states = () if output_hidden_states else None - self.all_self_attns = () if output_attentions else None - self.next_decoder_cache = () if use_cache else None - - all_router_loss = None - if hasattr(self.config, "use_moe") and self.config.use_moe: - all_router_loss = paddle.to_tensor(0.0) - - for idx, (decoder_layer) in enumerate(self.layers): - hidden_states, all_router_loss = self.decode_layer( - decoder_layer, - hidden_states, - attention_mask, - position_ids, - all_router_loss, - ) - - if use_cache and not (hasattr(self.config, "use_moe") and self.config.use_moe): - hidden_states = paddle.unsqueeze(hidden_states[:, -1, :], 1) - - # Multi Token Prediction - mtp_outputs = [] - if self.config.multi_token_pred_depth > 0: - inputs_embeds_cur_depth_list = paddle.split( - inputs_embeds_cur_depth_list, self.config.multi_token_pred_depth - ) - mtp_outputs = self.mtp_layer( - hidden_states, - inputs_embeds_cur_depth_list, - attention_mask, - position_ids, - ) - hidden_states, mtp_outputs = mtp_outputs[0], mtp_outputs[1:] - else: - hidden_states = self.norm(hidden_states) - - if output_hidden_states: - self.all_hidden_states += (hidden_states,) - - next_cache = self.next_decoder_cache if use_cache else None - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, 
- next_cache, - self.all_hidden_states, - self.all_self_attns, - all_router_loss, - self.all_gate_logits, - mtp_outputs, - ] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=self.all_hidden_states, - attentions=self.all_self_attns, - cross_attentions=None, - router_loss=all_router_loss, - gate_logits=self.all_gate_logits, - mtp_outputs=mtp_outputs, - ) - - -class ErniePretrainingCriterion(paddle.nn.Layer): - """ - Criterion for Ernie. - It calculates the final loss. - """ - - def __init__(self, config, return_tuple=True): - super(ErniePretrainingCriterion, self).__init__() - self.ignored_index = getattr(config, "ignored_index", -100) - self.config = config - self.return_tuple = return_tuple - - self.loss_func = paddle.nn.CrossEntropyLoss( - reduction="none", - ) - - def forward(self, prediction_scores, masked_lm_labels, router_loss=None): - """ - calculates the final loss - """ - if self.config.multi_token_pred_depth > 0: - # prediction_scores :[logits, mtp_logits] - logits = paddle.split( - prediction_scores, self.config.multi_token_pred_depth + 1 - ) - prediction_scores = logits[0] - mtp_logits = logits[1:] - masked_lm_labels_ori = masked_lm_labels - masked_lm_labels = masked_lm_labels[ - :, : -self.config.multi_token_pred_depth - ] - seq_length = masked_lm_labels.shape[1] - res = self.forward_impl(prediction_scores, masked_lm_labels) - if self.config.multi_token_pred_depth > 0: - mtp_loss_res = [] - for depth in range(self.config.multi_token_pred_depth): - prediction_scores_cur_depth = mtp_logits[depth] - masked_lm_labels_cur_depth = masked_lm_labels_ori[ - :, (depth + 1) : (depth + 1 + seq_length) - ] - res_cur_depth = self.forward_impl( - prediction_scores_cur_depth, - masked_lm_labels_cur_depth, - ) - mtp_loss_res.append(res_cur_depth) - - def add_loss(main_loss, loss): - return main_loss + loss - loss.detach() - - if self.return_tuple: - loss, loss_sum = res - if self.config.multi_token_pred_depth > 0: - loss = add_loss( - loss, - self.config.multi_token_pred_lambda - * sum([x[0] for x in mtp_loss_res]) - / len(mtp_loss_res), - ) - loss_sum = loss_sum + self.config.multi_token_pred_lambda * sum( - [x[1].detach() for x in mtp_loss_res] - ) / len(mtp_loss_res) - else: - loss, loss_sum = res, None - if self.config.multi_token_pred_depth > 0: - loss = add_loss( - loss, - self.config.multi_token_pred_lambda - * sum(mtp_loss_res) - / len(mtp_loss_res), - ) - if router_loss is not None: - loss = loss + router_loss - router_loss.detach() - if not self.return_tuple: - return loss - return loss, loss_sum - - def forward_impl(self, prediction_scores, masked_lm_labels): - with paddle.amp.auto_cast(False): - masked_lm_loss = self.loss_func( - prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(-1) - ) - lossmask = masked_lm_labels != self.ignored_index - - if (~lossmask).all(): - logger.warning( - f"encounter empty span when calculate loss, ignored_index={self.ignored_index}" - ) - loss = paddle.mean(masked_lm_loss) * 0.0 - loss_sum = masked_lm_loss.sum().detach() - else: - lossmask_ = lossmask.reshape([-1]).cast(paddle.float32) - masked_lm_loss_ = paddle.sum( - masked_lm_loss.cast(paddle.float32).reshape([-1]) * lossmask_ - ) - loss = masked_lm_loss_ / lossmask_.sum() - loss_sum = masked_lm_loss_.sum().detach() - - if not self.return_tuple: - if self.training: - return loss - return loss_sum - return loss, loss_sum - - -class ErnieLMHead(nn.Layer): - """ - ErnieLMHead is the 
linear layer used to project hidden state of decoder into word embeddings. - """ - - def __init__(self, config): - super(ErnieLMHead, self).__init__() - self.config = config - self.weight = self.create_parameter( - shape=( - [config.vocab_size, config.hidden_size] - if config.tie_word_embeddings - else [config.hidden_size, config.vocab_size] - ), - dtype=paddle.get_default_dtype(), - ) - - if ( - self.config.tensor_parallel_degree > 1 - or self.config.pipeline_parallel_degree > 1 - ): - self.weight = dist.shard_tensor( - self.weight, - get_mesh(-1), - [dist.Replicate(), dist.Shard(1)], - ) - self.weight.is_distributed = False - - logger.info( - f"output-weight:{self.weight.shape} config.tie_word_embeddings={config.tie_word_embeddings}" - ) - if config.weight_share_add_bias and config.use_bias: - self.bias = self.create_parameter( - shape=[config.vocab_size], - dtype=paddle.get_default_dtype(), - attr=paddle.ParamAttr( - initializer=paddle.nn.initializer.constant.Constant(0.0) - ), - ) - if ( - self.config.tensor_parallel_degree > 1 - or self.config.pipeline_parallel_degree > 1 - ): - self.bias = dist.shard_tensor( - self.bias, - get_mesh(-1), - [dist.Replicate(), dist.Shard(0)], - ) - self.bias.is_distributed = False - else: - self.bias = None - - if self.config.use_recompute_loss_fn: - logger.info( - "Using recompute_loss_fn, the calculation of logits will be moved into " - "loss_fn for memory optimization" - ) - - def forward(self, hidden_states): - return calc_lm_head_logits( - self.config, - hidden_states, - self.weight, - self.bias, - None, - ) - - -class ErnieModelPP(ErnieModel): - def __init__(self, config, layer_idx=0, ipp=0): - super().__init__(config, layer_idx) - self.layer = ErnieDecoderLayer(config, layer_idx, ipp) - def forward(self, args): attention_mask, position_ids = None, None if isinstance(args, tuple): @@ -1951,6 +493,24 @@ def forward(self, args): hidden_states, attention_mask, position_ids = self.embed_inputs( hidden_states, attention_mask, position_ids ) + global_mesh = get_mesh(pp_idx=None) + if self.config.sequence_parallel: + hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) + + if position_ids is not None: + position_ids = dist.shard_tensor( + position_ids, + global_mesh, + [dist.Replicate() for _ in range(len(global_mesh._shape))], + ) + if attention_mask is not None: + attention_mask = dist.shard_tensor( + attention_mask, + global_mesh, + [dist.Replicate() for _ in range(len(global_mesh._shape))], + ) + hidden_states = dist.reshard(hidden_states, get_mesh(0), self.placements) + hidden_states, _ = self.decode_layer( self.layer, hidden_states, attention_mask, position_ids ) @@ -1985,20 +545,69 @@ def forward(self, args): return hidden_states -class ErnieForCausalLM(ErniePretrainedModel): +class ErnieLMHeadVPP(ErnieLMHead): """ - ErnieForCausalLM is the model class for causal language modeling. 
+ ErnieLMHeadVPP is ErnieLMHead for vpp schedule with shard_tensor """ - _keys_to_ignore_on_load_missing = [r"lm_head.weight"] - def __init__(self, config): super().__init__(config) + if ( + self.config.tensor_parallel_degree > 1 + or self.config.pipeline_parallel_degree > 1 + ): + self.weight = dist.shard_tensor( + self.weight, + get_mesh(-1), + [dist.Replicate(), dist.Shard(1)], + ) + if self.bias: + if ( + self.config.tensor_parallel_degree > 1 + or self.config.pipeline_parallel_degree > 1 + ): + self.bias = dist.shard_tensor( + self.bias, + get_mesh(-1), + [dist.Replicate(), dist.Shard(0)], + ) + + def forward(self, hidden_states): + if self.config.sequence_parallel: + hcg = paddle.distributed.fleet.get_hybrid_communicate_group() + dp_rank = hcg.get_data_parallel_rank() + sharding_rank = hcg.get_sharding_parallel_rank() + if dp_rank <= 1 and sharding_rank <= 1: + hidden_states = dist.reshard( + hidden_states, + get_mesh(-1), + [dist.Replicate(), dist.Replicate()], + ) + else: + hidden_states = dist.reshard( + hidden_states, + get_mesh(-1), + [dist.Shard(1), dist.Replicate()], + ) + # [S, B, H] to [B, S, H] + hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) + hidden_states = hidden_states.reshape( + [-1, self.config.seqlen, hidden_states.shape[-1]] + ) + return super().forward(hidden_states) + + +class ErnieForCausalLMVPP(ErnieForCausalLM): + """ + ErnieForCausalLMVPP is the model class for causal language modeling for vpp pipeline schedule mode. + """ + + def __init__(self, config): + super(ErniePretrainedModel, self).__init__(config) config.initializer_range = math.sqrt(0.3333 / config.hidden_size) logger.info(f"Initializer-range is {config.initializer_range}") self.config = config self.criterion = ErniePretrainingCriterion(config, False) - self.tie_weights() if config.pipeline_parallel_degree > 1: @@ -2015,147 +624,15 @@ def __init__(self, config): target_stage = (idx // chunk_size) % pp_degree if target_stage == current_rank: stage_id = (idx // chunk_size) % pp_degree - self.layers.append(ErnieModelPP(config, idx, stage_id)) + self.layers.append(ErnieModelVPP(config, idx, stage_id)) else: self.layers.append(nn.Identity()) - else: - self.ernie = ErnieModel(config) - self.lm_head = ErnieLMHead(config) def _post_init(self, original_init, *args, **kwargs): - """ - Initialize weights and apply final processing - """ + decoder_layers = [] + for layer in self.layers: + if isinstance(layer, ErnieModelVPP): + decoder_layers.append(layer.layer) + layers = decoder_layers + self.ernie = type("ernie", (), {"layers": layers})() super()._post_init(self, original_init, *args, **kwargs) - factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) - logger.info(f"using post init div: factor:{factor}") - - def scale_by_factor_if_valid(w): - if w.is_dist() and w._is_initialized(): - w.scale_(factor) - - if self.config.pipeline_parallel_degree > 1: - decoder_layers = [] - for layer in self.layers: - if isinstance(layer, ErnieModelPP): - decoder_layers.append(layer.layer) - layers = decoder_layers - else: - layers = self.ernie.layers - if hasattr(self.config, "use_moe") and self.config.use_moe: - with paddle.no_grad(): - for left in layers: - if isinstance( - left.self_attn.o_proj, - (MOELayer), - ): - for e in left.self_attn.o_proj.experts: - if isinstance(e, ErnieMoeMLP): - scale_by_factor_if_valid(e.weight) - else: - scale_by_factor_if_valid(left.self_attn.o_proj.weight) - - if isinstance( - left.mlp, - (MOELayer), - ): - for e in left.mlp.experts: - if isinstance(e, ErnieMoeMLP): - 
scale_by_factor_if_valid(e.down_proj.weight) - else: - scale_by_factor_if_valid(left.mlp.down_proj.weight) - else: - with paddle.no_grad(): - for left in layers: - scale_by_factor_if_valid(left.self_attn.o_proj.weight) - scale_by_factor_if_valid(left.mlp.down_proj.weight) - - def get_input_embeddings(self): - return self.ernie.embed_tokens - - def set_input_embeddings(self, value): - self.ernie.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def forward( - self, - input_ids, - labels=None, - position_ids=None, - attention_mask=None, - inputs_embeds=None, - use_cache=False, - past_key_values=None, - output_attentions=None, - output_hidden_states=None, - return_dict=False, - ignored_index=0, - inbatch_pack_offset=None, - token_type_ids=None, - ): - if isinstance(input_ids, list): - input_ids, labels = input_ids[:2] - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - outputs = self.ernie( - input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - past_key_values=past_key_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, - inbatch_pack_offset=inbatch_pack_offset, - token_type_ids=token_type_ids, - ) - - hidden_states = outputs.last_hidden_state - mtp_outputs = outputs.mtp_outputs - logits = self.lm_head(hidden_states) - - mtp_logits = [logits] - if len(mtp_outputs) > 0: - for _hidden_states in mtp_outputs: - mtp_logits.append(self.lm_head(_hidden_states)) - logits = paddle.concat(mtp_logits) - - if return_dict: - if labels is not None: - loss, _ = self.criterion(logits, labels) - else: - loss = None - return CausalLMOutputWithCrossAttentionsErnie( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - router_loss=outputs.router_loss if self.config.use_moe else None, - ) - - assert labels is not None - router_loss = ( - outputs.router_loss - if hasattr(self.config, "use_moe") and self.config.use_moe - else None - ) - return self.criterion(logits, labels, router_loss) diff --git a/examples/auto_parallel/models/moe_layer.py b/examples/auto_parallel/models/moe_layer.py index 917e9a677..90ad09f3a 100644 --- a/examples/auto_parallel/models/moe_layer.py +++ b/examples/auto_parallel/models/moe_layer.py @@ -15,24 +15,29 @@ # limitations under the License. 
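The `ErnieForCausalLMVPP` constructor above assigns each decoder layer to a pipeline stage with `target_stage = (idx // chunk_size) % pp_degree`, building an `ErnieModelVPP` wrapper only on the owning rank and an `nn.Identity` placeholder elsewhere. The sketch below is a minimal, framework-free illustration of that round-robin mapping; it assumes `chunk_size = num_hidden_layers // (pp_degree * virtual_pp_degree)`, which is not shown in this hunk, and the helper name and arguments are illustrative only.

```python
def assign_layers_to_stages(num_hidden_layers: int, pp_degree: int, virtual_pp_degree: int):
    """Return {stage_id: [layer indices]} for an interleaved (VPP) pipeline schedule."""
    # Assumption: each stage owns virtual_pp_degree chunks of equal size.
    chunk_size = num_hidden_layers // (pp_degree * virtual_pp_degree)
    mapping = {stage: [] for stage in range(pp_degree)}
    for idx in range(num_hidden_layers):
        # Same formula as in ErnieForCausalLMVPP: layers rotate through the
        # stages one chunk at a time instead of being split contiguously.
        target_stage = (idx // chunk_size) % pp_degree
        mapping[target_stage].append(idx)
    return mapping


if __name__ == "__main__":
    # 8 layers, 2 pipeline stages, 2 virtual chunks per stage:
    # stage 0 -> [0, 1, 4, 5], stage 1 -> [2, 3, 6, 7]
    print(assign_layers_to_stages(8, 2, 2))
```

This interleaving is what distinguishes the VPP schedule from a plain contiguous pipeline split: each rank holds several non-adjacent chunks, which lets the scheduler overlap forward and backward micro-batches across chunks.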
import inspect -from typing import Tuple, List, Optional import logging +import numpy as np from contextlib import contextmanager +from typing import Tuple, List, Optional +from functools import partial +from copy import deepcopy import paddle -from paddle import nn import paddle.nn.functional as F - +import paddle.distributed as dist +from paddle import nn +from paddle.incubate.nn.functional import swiglu from paddle.distributed.communication.group import Group from paddle.distributed import fleet -import paddle.distributed as dist from paddle import Tensor from paddle.incubate.nn.functional import moe_combine, moe_gate_dispatch - +from paddle.utils import unique_name +from paddle.distributed.fleet.layers.mpu.random import get_rng_state_tracker from paddleformers.trainer.plugins.timer import get_timers from paddleformers.transformers.moe_layer import dispatching, combining -from models.top2_gate import TopKGateFused + from utils.training_utils import get_flatten_mesh, get_mesh, _reshard +from models.configuration import ErnieMoEConfig logger = logging.getLogger(__name__) @@ -705,3 +710,582 @@ def forward( [dist.Replicate(), dist.Replicate()], ) return combined_output, combine_weights, router_loss2, gate_logits + + +def get_gate( + config: ErnieMoEConfig, + expert: Tuple[Tuple[int, nn.Layer]], + layer_idx: int, + ipp: int = 0, +) -> Tuple[nn.Layer, nn.LayerList]: + moe_num_experts = config.moe_num_experts + assert ( + moe_num_experts >= config.moe_world_size + ), f"expert moe_num_experts={moe_num_experts} >= moe_world_size={config.moe_world_size}" + assert ( + moe_num_experts % config.moe_world_size == 0 + ), f"expert moe_num_experts={moe_num_experts} % moe_world_size={config.moe_world_size} == 0" + moe_num_experts_per_device = moe_num_experts // config.moe_world_size + experts = nn.LayerList([]) + for expert_id, (experts_num, fc) in enumerate(expert): + assert experts_num % config.moe_world_size == 0 + experts_to_append = [] + if not hasattr(fc, "__len__"): + experts_to_append.append(fc) + if expert_id == 1: + with paddle.utils.unique_name.guard("_mm_deepcopy"): + for _ in range(experts_num - 1): + experts_to_append.append(deepcopy(fc)) + else: + for _ in range(experts_num - 1): + experts_to_append.append(deepcopy(fc)) + else: + experts_to_append = fc + for ex in experts_to_append: + for p in ex.parameters(): + p.expert_type = f"expert_type_{expert_id}" + experts.extend(experts_to_append) + + logger.info( + f"using moe-world-size: {config.moe_world_size} " + f"expert-per-device: {moe_num_experts_per_device} " + ) + if config.moe_use_hard_gate and moe_num_experts <= 2: + gate = None + logger.info("MOE-GATE:-hard-gate") + else: + logger.info(f"MOE-GATE:-{config.moe_gate}") + gate = TopKGateFused( + config, layer_idx=layer_idx, group=config.moe_group, ipp=ipp + ) + + lm_gate, lm_experts = None, None + logger.info(f"LM-experts-{lm_experts} -- experts-{experts}") + + index = 0 if config.moe_group == "dp" else 1 + ep_sub_meshes = dist.auto_parallel.api.split_mesh(get_mesh(ipp), index) + + for i, expert in enumerate(experts): + ep_group_id = i // moe_num_experts_per_device + if isinstance(expert, (ErnieMoeMLPFused, ErnieMoeMLP)): + experts[i].redistribute_expert( + ep_sub_meshes[ep_group_id], [dist.Replicate(), dist.Replicate()] + ) + experts[i].ep_group_id = ep_group_id + + if config.moe_use_aux_free: + moe_statics = MoEStatics(config, layer_idx) + else: + moe_statics = None + return gate, experts, lm_gate, lm_experts, moe_statics + + +class ErnieMLP(nn.Layer): + def __init__(self, config, 
ipp=None, do_shard_tensor=True): + super().__init__() + self.config = config + self.ipp = ipp + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + self.gate_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias_attr=config.use_bias + ) + self.up_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias_attr=config.use_bias + ) + self.down_proj = nn.Linear( + self.intermediate_size, self.hidden_size, bias_attr=config.use_bias + ) + + self.fuse_swiglu = config.fuse_swiglu + + def forward(self, x): + if self.fuse_swiglu: + x = swiglu(self.gate_proj(x), self.up_proj(x)) + else: + x = F.silu(self.gate_proj(x)) * self.up_proj(x) + + out = self.down_proj(x) + return out + + +class ErnieMoeMLP(ErnieMLP): + """_summary_ + + Args: + ErnieMoeMLP (_type_): _description_ + """ + + def __init__(self, config, ipp=0): + """ + doc + """ + disable_ffn_model_parallel = getattr( + config, "disable_ffn_model_parallel", False + ) + if disable_ffn_model_parallel: + config = deepcopy(config) + config.tensor_parallel_degree = 1 + config.sequence_parallel = False + + super().__init__(config, ipp, do_shard_tensor=not disable_ffn_model_parallel) + self.moe_dropout_prob = config.moe_dropout_prob + self.fuse_swiglu = config.fuse_swiglu + + def redistribute_expert(self, mesh, placements): + """ + Place the experts on different devices. + """ + self.gate_proj.weight = dist.shard_tensor( + self.gate_proj.weight, mesh, placements + ) + self.up_proj.weight = dist.shard_tensor(self.up_proj.weight, mesh, placements) + self.down_proj.weight = dist.shard_tensor( + self.down_proj.weight, mesh, placements + ) + if self.config.use_bias: + self.gate_proj.bias = dist.shard_tensor( + self.gate_proj.bias, mesh, placements + ) + self.up_proj.bias = dist.shard_tensor(self.up_proj.bias, mesh, placements) + self.down_proj.bias = dist.shard_tensor( + self.down_proj.bias, mesh, placements + ) + + def forward(self, x): + if self.fuse_swiglu: + x = swiglu(self.gate_proj(x), self.up_proj(x)) + else: + x = F.silu(self.gate_proj(x)) * self.up_proj(x) + if self.moe_dropout_prob > 0: + with get_rng_state_tracker().rng_state("local_seed"): + x = F.dropout(x=x, p=self.moe_dropout_prob) + ret = self.down_proj(x) + return ret + + +class BMMLinear(nn.Layer): + def __init__(self, experts, d_in, d_out, use_bias=False): + super().__init__() + self.weight = self.create_parameter( + [experts, d_in, d_out], dtype=paddle.get_default_dtype() + ) + if use_bias: + self.bias = self.create_parameter( + [experts, d_out], dtype=paddle.get_default_dtype(), is_bias=True + ) + else: + self.bias = None + + def forward(self, x): + """x: [num_experts, Seq, dim]""" + if self.bias is not None: + return paddle.bmm(x, self.weight) + self.bias + return paddle.bmm(x, self.weight) + + +class ErnieMoeMLPFused(nn.Layer): + def __init__(self, config): + assert config.fuse_attn_ffn, "fused mlp only support fuse_attn_ffn" + super().__init__() + self.moe_dropout_prob = config.moe_dropout_prob + self.num_local_experts = config.moe_num_experts // config.moe_world_size + logger.info( + f"fused-expert-weight-shape: {[self.num_local_experts, config.hidden_size, config.intermediate_size]}" + ) + + self.up_gate_proj = BMMLinear( + self.num_local_experts, config.hidden_size, config.intermediate_size * 2 + ) + self.down_proj = BMMLinear( + self.num_local_experts, config.intermediate_size, config.hidden_size + ) + self.fuse_swiglu = config.fuse_swiglu + + def __len__(self): + return self.num_local_experts + + def __iter__(self): + 
return (self for _ in range(1)) + + def forward(self, x): + if self.fuse_swiglu: + x = swiglu(self.up_gate_proj(x)) + else: + gate, x = self.up_gate_proj(x).chunk(2, axis=-1) + x = F.silu(gate) * x + x = self.down_proj(x) + return x + + +def cal_aux_loss_func( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + global_aux_loss=False, + rank=None, + group=None, +): + if tokens_mask is not None and tokens_mask.dtype != gate_prob.dtype: + tokens_mask = tokens_mask.astype(gate_prob.dtype) + + scale = None + if dispatch_tokens_mask is not None: + seqlen_float = dispatch_tokens_mask.astype(gate_prob.dtype).sum() + if ( + tokens_mask is not None + and gate_prob.shape[0] != dispatch_tokens_mask.shape[0] + ): + scale = seqlen_float / paddle.clip(tokens_mask.sum(), min=1e-6) + elif tokens_mask is not None: + seqlen_float = tokens_mask.sum() + else: + seqlen_float = gate_prob.numel().astype(gate_prob.dtype) / num_experts + seqlen_float = paddle.clip(seqlen_float, min=1e-6) + + if len(dispatch_mask.shape) == 2: + dispatch_mask = dispatch_mask.sum(0) + ce = dispatch_mask.astype(gate_prob.dtype).detach() / seqlen_float + me = paddle.sum(gate_prob, axis=0) / seqlen_float + if global_aux_loss: + me_list, ce_list = [], [] + dist.all_gather(me_list, me, group=group) + dist.all_gather(ce_list, ce, group=group) + + me_list[rank] = me + ce_list[rank] = ce + me = paddle.stack(me_list).mean(0) + ce = paddle.stack(ce_list).mean(0) + + l_aux = paddle.sum(me * ce) * num_experts + if use_group: + l_aux = l_aux / moe_k + + if scale is not None: + l_aux = l_aux + (scale - 1) * l_aux.detach() + + return l_aux + + +def gate_detach_matmul(x, weight, use_fake_gate=False): + x = x.cast(paddle.float32) if x.dtype != paddle.float32 else x + score = F.linear(x, weight) + + if use_fake_gate: + score = paddle.randn(score.shape).astype(score.dtype) + score - score + return score + + +class TopKGateFused(nn.Layer): + + def __init__(self, config, layer_idx: int, group, ipp=0) -> None: + super().__init__() + self.config = config + assert not config.fuse_gate_detach_matmul, "matmul_bwd is not supported" + + self.use_fake_gate = config.use_fake_gate + if self.use_fake_gate: + logging.warning( + "You are use fake_gate, which is just for test, not for real training." 
+ ) + + self.model_dim = config.hidden_size + self.num_experts = config.moe_num_experts + self.num_experts_tensor = ( + sum(config.moe_num_experts) + if config.multimodel_experts + else config.moe_num_experts + ) + + self.cap = config.moe_capacity + self.group = group + + self.layer_idx = layer_idx + self.global_aux_loss = config.global_aux_loss + if self.global_aux_loss: + self.rank = dist.get_rank(self.group) + + self.use_token_type_bias = config.moe_use_token_type_bias + self.use_correction_bias = config.moe_use_aux_free + + self.ipp = ipp + + if config.moe_gate_act == "softmax": + self.act = partial(F.softmax, axis=-1) + elif config.moe_gate_act == "sigmoid": + self.act = F.sigmoid + else: + raise ValueError(f"{config.moe_gate_act} is not supported.") + + self.moe_aux_loss_lambda = paddle.to_tensor( + config.moe_aux_loss_lambda, dtype="float32" + ) + + if self.moe_aux_loss_lambda.ndim == 0: + self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.unsqueeze(0) + + self.moe_orthogonal_loss_lambda = paddle.to_tensor( + config.moe_orthogonal_loss_lambda, dtype="float32" + ) + + if self.moe_orthogonal_loss_lambda.ndim == 0: + self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.unsqueeze( + 0 + ) + + self.experts_type_ids = None + if config.moe_orthogonal_loss_lambda: + if hasattr(fleet.fleet, "_user_defined_strategy"): + strategy = fleet.fleet._user_defined_strategy + sharding_configs = strategy.hybrid_configs["sharding_configs"] + pp_config = strategy.hybrid_configs["pp_configs"] + assert ( + not sharding_configs.comm_overlap + and not pp_config.sharding_comm_overlap + ), "orthogonal loss will cause twice gradient accumulate, will break pp/sharding overlap" + + self.eps = paddle.to_tensor([1e-12], dtype="float32") + if config.multimodel_experts: + if config.moe_use_hard_gate: + self.num_experts_list = [] + self.experts_type_mask = [] + experts_ids = paddle.zeros( + [sum(self.num_experts)], dtype="int64" + ).reshape([config.moe_world_size, -1]) + offset = 0 + for i, expert_num in enumerate(self.num_experts): + experts_ids[ + :, offset : offset + expert_num // config.moe_world_size + ] = i + offset += expert_num // config.moe_world_size + self.experts_type_ids = experts_ids.reshape([-1]) + logger.info( + f"use moe_use_hard_gate, experts_ids: {self.experts_type_ids}" + ) + for i, expert_num in enumerate(self.num_experts): + self.experts_type_mask.append( + self.experts_type_ids == i, + ) + self.num_experts_list.append(expert_num) + else: + assert ( + not config.moe_group_experts + ), "group_experts must use hard_gate when multimodel_experts is True" + else: + self.num_experts_list = [self.num_experts] + self._create_gate_parameter() + logger.info( + f"{config.moe_gate}: w/ capacity: {self.cap} experts:{self.num_experts} " + f"use_token_type_bias:{self.use_token_type_bias} gate_act:{config.moe_gate_act} " + ) + + def _create_gate_parameter(self): + + if self.config.multimodel_experts: + + self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.expand( + len(self.num_experts) + ) + self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.expand( + len(self.num_experts) + ) + + for i, num_experts in enumerate(self.num_experts): + if i == 1: + with paddle.utils.unique_name.guard(f"mm_gate_{self.layer_idx}_"): + p = self.create_parameter( + shape=[self.model_dim, num_experts], + dtype="float32", + attr=paddle.ParamAttr( + name=unique_name.generate("moe_gate") + ), + ) + else: + p = self.create_parameter( + shape=[self.model_dim, num_experts], + dtype="float32", + 
attr=paddle.ParamAttr(name=unique_name.generate("moe_gate")), + ) + p.expert_type = f"expert_type_{i}" + self.add_parameter( + ("weight" if i == 0 else f"weight_{i}"), + p, + ) + else: + self.weight = self.create_parameter( + shape=[self.model_dim, self.num_experts], + dtype="float32", + attr=paddle.ParamAttr(name=unique_name.generate("moe_gate")), + ) + logger.info(f"moe-Gate, {self.weight}") + + if self.use_token_type_bias: + if self.config.multimodel_experts: + assert ( + not self.config.moe_use_hard_gate + ), "multimodel_experts with hard_gate is not support token_type_bias." + num_experts = ( + sum(self.num_experts) + if self.config.multimodel_experts + else self.num_experts + ) + bias_type_num = ( + len(self.num_experts) if self.config.multimodel_experts else 1 + ) + self.bias = self.create_parameter( + shape=[bias_type_num, num_experts], + dtype="float32", + attr=paddle.ParamAttr( + name=unique_name.generate("moe_gate_bias"), + initializer=paddle.nn.initializer.Assign( + np.zeros([bias_type_num, num_experts]) + ), + ), + ) + logger.info(f"using token type bias, bias: {self.bias},") + self._cast_to_low_precision = False + self._cast_to_low_precison = False + + def get_gate_weight(self, transform_weight): + if not self.config.multimodel_experts: + return self.weight + if not transform_weight: + return paddle.concat( + [ + getattr(self, "weight" if i == 0 else f"weight_{i}") + for i in range(len(self.num_experts)) + ], + -1, + ) + weight = paddle.zeros( + [ + self.model_dim, + self.config.moe_world_size, + sum(self.num_experts) // self.config.moe_world_size, + ], + dtype="float32", + ) + offset = 0 + for i, num_experts in enumerate(self.num_experts): + weight[ + :, :, offset : offset + num_experts // self.config.moe_world_size + ] = getattr(self, "weight" if i == 0 else f"weight_{i}").reshape( + [self.model_dim, self.config.moe_world_size, -1] + ) + offset += num_experts // self.config.moe_world_size + weight = weight.reshape([self.model_dim, -1]) + + return weight + + def _cal_aux_loss( + self, + gate_prob, + dispatch_mask, + num_experts=None, + use_group=None, + tokens_mask=None, + dispatch_tokens_mask=None, + ): + + if self.act is F.sigmoid: + gate_prob = gate_prob / gate_prob.sum(-1, keepdim=True) + + if self.use_correction_bias: + if tokens_mask is not None: + gate_prob_this_modality = gate_prob[tokens_mask.astype("bool")] + if gate_prob_this_modality.shape[0]: + _, top_idx = gate_prob_this_modality.topk( + k=self.config.moe_k, axis=-1 + ) + mask = paddle.zeros_like(gate_prob_this_modality).put_along_axis( + top_idx, paddle.to_tensor(1.0), axis=1 + ) + dispatch_mask = paddle.sum(mask.cast(paddle.int64), axis=0) + else: + dispatch_mask = paddle.zeros(gate_prob.shape[-1], dtype="int64") + dist.stream.all_reduce( + dispatch_mask, + group=self.group, + use_calc_stream=True, + ) + else: + _, top_idx = gate_prob.topk(k=self.config.moe_k, axis=-1) + + mask = paddle.zeros_like(gate_prob).put_along_axis( + top_idx, paddle.to_tensor(1.0), axis=1 + ) + dispatch_mask = paddle.sum(mask.cast(paddle.int64), axis=0) + + if num_experts is None: + num_experts = self.num_experts_tensor + if use_group is None: + use_group = self.config.moe_group_experts + + return cal_aux_loss_func( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + self.config.moe_k, + self.global_aux_loss, + self.rank if self.global_aux_loss else None, + self.group if self.global_aux_loss else None, + ) + + def forward( + self, + input: Tensor, + token_type_ids=None, + 
transform_weight=True, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """ + Args: + input: paddle.Tensor, hidden-states of layer + Retruns: + paddle.Tensor [Seq, Expert, Capacity]: float32, combine weights + paddle.Tensor [Seq, Expert, Capacity]: bool, dispatch mask + Tuple[paddle.Tensor]: `GateOutput` + """ + num_experts = ( + sum(self.num_experts) + if self.config.multimodel_experts + else self.num_experts + ) + if self.training: + cap = self.cap[0] + elif input.shape[0] < num_experts: + cap = self.cap[2] + else: + cap = self.cap[1] + num_tokens = input.shape[0] + global_capacity = int(cap * num_tokens // num_experts) + local_num_tokens = input._local_shape[0] + local_capacity = int(cap * local_num_tokens // num_experts) + + weight = self.get_gate_weight(transform_weight) + with paddle.amp.auto_cast(False): + input = _reshard( + input, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)] + ) + logits = gate_detach_matmul(input, weight, self.use_fake_gate) + logits = _reshard( + logits, get_flatten_mesh(get_mesh(self.ipp)), [dist.Shard(0)] + ) + if self.use_token_type_bias: + assert token_type_ids is not None + assert ( + token_type_ids.max() < self.bias.shape[0] + ), f"token_type_ids {token_type_ids.max()} >= bias shape {self.bias.shape[0]}" + bias = self.bias[token_type_ids] + logits = logits + bias + router_loss = paddle.zeros([1], dtype="float32") + router_loss.stop_gradient = False + + return logits, global_capacity, router_loss, local_capacity diff --git a/examples/auto_parallel/models/top2_gate.py b/examples/auto_parallel/models/top2_gate.py deleted file mode 100644 index 48f3a70ef..000000000 --- a/examples/auto_parallel/models/top2_gate.py +++ /dev/null @@ -1,399 +0,0 @@ -# !/usr/bin/env python3 - -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
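`cal_aux_loss_func` above implements the usual MoE load-balancing auxiliary loss: `me` is the mean router probability per expert, `ce` is the fraction of dispatched tokens per expert, and the loss is `num_experts * sum(me * ce)`, optionally rescaled by token masks, expert groups, or a global all-gather. The NumPy sketch below reproduces only the simplest case, with no masks and no `use_group`/`global_aux_loss` handling; the helper name and the top-1 dispatch used in the demo are illustrative, not part of this code.

```python
import numpy as np


def simple_aux_loss(gate_prob: np.ndarray, dispatch_mask: np.ndarray) -> float:
    """gate_prob: [num_tokens, num_experts] router probabilities.
    dispatch_mask: [num_experts] count of tokens routed to each expert."""
    num_tokens, num_experts = gate_prob.shape
    seqlen = max(float(num_tokens), 1e-6)
    ce = dispatch_mask.astype(np.float64) / seqlen   # dispatched fraction per expert
    me = gate_prob.sum(axis=0) / seqlen              # mean router probability per expert
    return float(num_experts * np.sum(me * ce))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    logits = rng.normal(size=(16, 4))
    prob = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
    dispatch = np.bincount(prob.argmax(axis=-1), minlength=4)  # top-1 routing
    print(simple_aux_loss(prob, dispatch))
```

For a perfectly balanced top-1 router the value approaches 1.0; routing that concentrates both probability mass and dispatched tokens on a few experts pushes it higher, which is what the gate is penalized for during training.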
- -from typing import Tuple -from functools import partial -import logging -import numpy as np -import paddle -from paddle import Tensor -import paddle.distributed as dist -import paddle.nn.functional as F -from paddle import nn -from paddle.utils import unique_name -from paddle.distributed import fleet -from utils.training_utils import get_mesh, get_flatten_mesh - -logger = logging.getLogger(__name__) - - -def cal_aux_loss_func( - gate_prob, - dispatch_mask, - tokens_mask, - dispatch_tokens_mask, - num_experts, - use_group, - moe_k, - global_aux_loss=False, - rank=None, - group=None, -): - if tokens_mask is not None and tokens_mask.dtype != gate_prob.dtype: - tokens_mask = tokens_mask.astype(gate_prob.dtype) - - scale = None - if dispatch_tokens_mask is not None: - seqlen_float = dispatch_tokens_mask.astype(gate_prob.dtype).sum() - if ( - tokens_mask is not None - and gate_prob.shape[0] != dispatch_tokens_mask.shape[0] - ): - scale = seqlen_float / paddle.clip(tokens_mask.sum(), min=1e-6) - elif tokens_mask is not None: - seqlen_float = tokens_mask.sum() - else: - seqlen_float = gate_prob.numel().astype(gate_prob.dtype) / num_experts - seqlen_float = paddle.clip(seqlen_float, min=1e-6) - - if len(dispatch_mask.shape) == 2: - dispatch_mask = dispatch_mask.sum(0) - ce = dispatch_mask.astype(gate_prob.dtype).detach() / seqlen_float - me = paddle.sum(gate_prob, axis=0) / seqlen_float - if global_aux_loss: - me_list, ce_list = [], [] - dist.all_gather(me_list, me, group=group) - dist.all_gather(ce_list, ce, group=group) - - me_list[rank] = me - ce_list[rank] = ce - me = paddle.stack(me_list).mean(0) - ce = paddle.stack(ce_list).mean(0) - - l_aux = paddle.sum(me * ce) * num_experts - if use_group: - l_aux = l_aux / moe_k - - if scale is not None: - l_aux = l_aux + (scale - 1) * l_aux.detach() - - return l_aux - - -def gate_detach_matmul(x, weight, use_fake_gate=False): - x = x.cast(paddle.float32) if x.dtype != paddle.float32 else x - score = F.linear(x, weight) - - if use_fake_gate: - score = paddle.randn(score.shape).astype(score.dtype) + score - score - return score - - -class TopKGateFused(nn.Layer): - - def __init__(self, config, layer_idx: int, group, ipp=0) -> None: - super().__init__() - self.config = config - assert not config.fuse_gate_detach_matmul, "matmul_bwd is not supported" - - self.use_fake_gate = config.use_fake_gate - if self.use_fake_gate: - logging.warning( - "You are use fake_gate, which is just for test, not for real training." 
- ) - - self.model_dim = config.hidden_size - self.num_experts = config.moe_num_experts - self.num_experts_tensor = ( - sum(config.moe_num_experts) - if config.multimodel_experts - else config.moe_num_experts - ) - - self.cap = config.moe_capacity - self.group = group - - self.layer_idx = layer_idx - self.global_aux_loss = config.global_aux_loss - if self.global_aux_loss: - self.rank = dist.get_rank(self.group) - - self.use_token_type_bias = config.moe_use_token_type_bias - self.use_correction_bias = config.moe_use_aux_free - - if config.moe_gate_act == "softmax": - self.act = partial(F.softmax, axis=-1) - elif config.moe_gate_act == "sigmoid": - self.act = F.sigmoid - else: - raise ValueError(f"{config.moe_gate_act} is not supported.") - - self.moe_aux_loss_lambda = paddle.to_tensor( - config.moe_aux_loss_lambda, dtype="float32" - ) - - if self.moe_aux_loss_lambda.ndim == 0: - self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.unsqueeze(0) - - self.moe_orthogonal_loss_lambda = paddle.to_tensor( - config.moe_orthogonal_loss_lambda, dtype="float32" - ) - - if self.moe_orthogonal_loss_lambda.ndim == 0: - self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.unsqueeze( - 0 - ) - - self.experts_type_ids = None - if config.moe_orthogonal_loss_lambda: - if hasattr(fleet.fleet, "_user_defined_strategy"): - strategy = fleet.fleet._user_defined_strategy - sharding_configs = strategy.hybrid_configs["sharding_configs"] - pp_config = strategy.hybrid_configs["pp_configs"] - assert ( - not sharding_configs.comm_overlap - and not pp_config.sharding_comm_overlap - ), "orthogonal loss will cause twice gradient accumulate, will break pp/sharding overlap" - - self.eps = paddle.to_tensor([1e-12], dtype="float32") - if config.multimodel_experts: - if config.moe_use_hard_gate: - self.num_experts_list = [] - self.experts_type_mask = [] - experts_ids = paddle.zeros( - [sum(self.num_experts)], dtype="int64" - ).reshape([config.moe_world_size, -1]) - offset = 0 - for i, expert_num in enumerate(self.num_experts): - experts_ids[ - :, offset : offset + expert_num // config.moe_world_size - ] = i - offset += expert_num // config.moe_world_size - self.experts_type_ids = experts_ids.reshape([-1]) - logger.info( - f"use moe_use_hard_gate, experts_ids: {self.experts_type_ids}" - ) - for i, expert_num in enumerate(self.num_experts): - self.experts_type_mask.append( - self.experts_type_ids == i, - ) - self.num_experts_list.append(expert_num) - else: - assert ( - not config.moe_group_experts - ), "group_experts must use hard_gate when multimodel_experts is True" - else: - self.num_experts_list = [self.num_experts] - self._create_gate_parameter() - logger.info( - f"{config.moe_gate}: w/ capacity: {self.cap} experts:{self.num_experts} " - f"use_token_type_bias:{self.use_token_type_bias} gate_act:{config.moe_gate_act} " - ) - - self.ipp = ipp - self.weight = dist.shard_tensor( - self.weight, get_flatten_mesh(get_mesh(self.ipp)), [dist.Replicate()] - ) - - def _create_gate_parameter(self): - - if self.config.multimodel_experts: - - self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.expand( - len(self.num_experts) - ) - self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.expand( - len(self.num_experts) - ) - - for i, num_experts in enumerate(self.num_experts): - if i == 1: - with paddle.utils.unique_name.guard(f"mm_gate_{self.layer_idx}_"): - p = self.create_parameter( - shape=[self.model_dim, num_experts], - dtype="float32", - attr=paddle.ParamAttr( - name=unique_name.generate("moe_gate") - ), - ) - 
else: - p = self.create_parameter( - shape=[self.model_dim, num_experts], - dtype="float32", - attr=paddle.ParamAttr(name=unique_name.generate("moe_gate")), - ) - p.expert_type = f"expert_type_{i}" - self.add_parameter( - ("weight" if i == 0 else f"weight_{i}"), - p, - ) - else: - self.weight = self.create_parameter( - shape=[self.model_dim, self.num_experts], - dtype="float32", - attr=paddle.ParamAttr(name=unique_name.generate("moe_gate")), - ) - logger.info(f"moe-Gate, {self.weight}") - - if self.use_token_type_bias: - if self.config.multimodel_experts: - assert ( - not self.config.moe_use_hard_gate - ), "multimodel_experts with hard_gate is not support token_type_bias." - num_experts = ( - sum(self.num_experts) - if self.config.multimodel_experts - else self.num_experts - ) - bias_type_num = ( - len(self.num_experts) if self.config.multimodel_experts else 1 - ) - self.bias = self.create_parameter( - shape=[bias_type_num, num_experts], - dtype="float32", - attr=paddle.ParamAttr( - name=unique_name.generate("moe_gate_bias"), - initializer=paddle.nn.initializer.Assign( - np.zeros([bias_type_num, num_experts]) - ), - ), - ) - logger.info(f"using token type bias, bias: {self.bias},") - self._cast_to_low_precision = False - self._cast_to_low_precison = False - - def get_gate_weight(self, transform_weight): - if not self.config.multimodel_experts: - return self.weight - if not transform_weight: - return paddle.concat( - [ - getattr(self, "weight" if i == 0 else f"weight_{i}") - for i in range(len(self.num_experts)) - ], - -1, - ) - weight = paddle.zeros( - [ - self.model_dim, - self.config.moe_world_size, - sum(self.num_experts) // self.config.moe_world_size, - ], - dtype="float32", - ) - offset = 0 - for i, num_experts in enumerate(self.num_experts): - weight[ - :, :, offset : offset + num_experts // self.config.moe_world_size - ] = getattr(self, "weight" if i == 0 else f"weight_{i}").reshape( - [self.model_dim, self.config.moe_world_size, -1] - ) - offset += num_experts // self.config.moe_world_size - weight = weight.reshape([self.model_dim, -1]) - - return weight - - def _cal_aux_loss( - self, - gate_prob, - dispatch_mask, - num_experts=None, - use_group=None, - tokens_mask=None, - dispatch_tokens_mask=None, - ): - - if self.act is F.sigmoid: - gate_prob = gate_prob / gate_prob.sum(-1, keepdim=True) - - if self.use_correction_bias: - if tokens_mask is not None: - gate_prob_this_modality = gate_prob[tokens_mask.astype("bool")] - if gate_prob_this_modality.shape[0]: - _, top_idx = gate_prob_this_modality.topk( - k=self.config.moe_k, axis=-1 - ) - mask = paddle.zeros_like(gate_prob_this_modality).put_along_axis( - top_idx, paddle.to_tensor(1.0), axis=1 - ) - dispatch_mask = paddle.sum(mask.cast(paddle.int64), axis=0) - else: - dispatch_mask = paddle.zeros(gate_prob.shape[-1], dtype="int64") - dist.stream.all_reduce( - dispatch_mask, - group=self.group, - use_calc_stream=True, - ) - else: - _, top_idx = gate_prob.topk(k=self.config.moe_k, axis=-1) - - mask = paddle.zeros_like(gate_prob).put_along_axis( - top_idx, paddle.to_tensor(1.0), axis=1 - ) - dispatch_mask = paddle.sum(mask.cast(paddle.int64), axis=0) - - if num_experts is None: - num_experts = self.num_experts_tensor - if use_group is None: - use_group = self.config.moe_group_experts - - return cal_aux_loss_func( - gate_prob, - dispatch_mask, - tokens_mask, - dispatch_tokens_mask, - num_experts, - use_group, - self.config.moe_k, - self.global_aux_loss, - self.rank if self.global_aux_loss else None, - self.group if self.global_aux_loss 
else None, - ) - - def forward( - self, - input: Tensor, - token_type_ids=None, - transform_weight=True, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - """ - Args: - input: paddle.Tensor, hidden-states of layer - Retruns: - paddle.Tensor [Seq, Expert, Capacity]: float32, combine weights - paddle.Tensor [Seq, Expert, Capacity]: bool, dispatch mask - Tuple[paddle.Tensor]: `GateOutput` - """ - num_experts = ( - sum(self.num_experts) - if self.config.multimodel_experts - else self.num_experts - ) - if self.training: - cap = self.cap[0] - elif input.shape[0] < num_experts: - cap = self.cap[2] - else: - cap = self.cap[1] - num_tokens = input.shape[0] - global_capacity = int(cap * num_tokens // num_experts) - local_num_tokens = input._local_shape[0] - local_capacity = int(cap * local_num_tokens // num_experts) - - weight = self.get_gate_weight(transform_weight) - with paddle.amp.auto_cast(False): - logits = gate_detach_matmul(input, weight, self.use_fake_gate) - if self.use_token_type_bias: - assert token_type_ids is not None - assert ( - token_type_ids.max() < self.bias.shape[0] - ), f"token_type_ids {token_type_ids.max()} >= bias shape {self.bias.shape[0]}" - bias = self.bias[token_type_ids] - logits = logits + bias - router_loss = paddle.zeros([1], dtype="float32") - router_loss.stop_gradient = False - - return logits, global_capacity, router_loss, local_capacity diff --git a/examples/auto_parallel/pretrain.py b/examples/auto_parallel/pretrain.py index 662a8c40e..e16ee23f5 100644 --- a/examples/auto_parallel/pretrain.py +++ b/examples/auto_parallel/pretrain.py @@ -56,20 +56,6 @@ from paddle.tensor.manipulation import reshape from typing import Literal, TypeAlias -# USE_VPP=0: Implement parallelism using the intermediate API. -# USE_VPP=1: Implement parallelism using the basic API; the intermediate API does not support VPP for the time being. -use_vpp = os.environ.get("USE_VPP", "0") -if use_vpp == "0": - from models.modeling import ErnieForCausalLM - - logger.info("Training with the intermediate API. Do not support VPP.") -elif use_vpp == "1": - from models.modeling_vpp import ErnieForCausalLM - - logger.info("Training VPP parallelism with the basic API") -else: - raise ValueError(f"Invalid environment args USE_VPP={use_vpp}") - _ReduceMode: TypeAlias = Literal["mean", "sum", "none"] @@ -550,8 +536,25 @@ def main(): tokenizer = setup_tokenizer(args, cfg) + vpp_degree = cfg.virtual_pp_degree + # vpp_degree==1: Implement parallelism using the intermediate API. + # vpp_degree>1: Implement parallelism using the basic API; the intermediate API does not support VPP for the time being. + assert vpp_degree >= 1, "vpp_degree must be greater than or equal to 1." + if vpp_degree == 1: + from models.modeling import ErnieForCausalLM, ErnieDecoderLayer + + logger.info("Training with the intermediate API. 
Do not support VPP.") + modle_class = ErnieForCausalLM + aux_free_class = ErnieDecoderLayer + elif vpp_degree > 1: + from models.modeling_vpp import ErnieForCausalLMVPP, ErnieDecoderLayerVPP + + logger.info("Training VPP parallelism with the basic API") + modle_class = ErnieForCausalLMVPP + aux_free_class = ErnieDecoderLayerVPP + with paddle.LazyGuard(): - model = ErnieForCausalLM(cfg) + model = modle_class(cfg) logger.info(f"Using model: {type(model)}, config: {model.config}") paddle.set_default_dtype("float32") @@ -568,7 +571,9 @@ def main(): logger.info("Adding aux free callback") callbacks += [ MoECorrectionBiasAdjustCallback( - args.moe_use_aux_free_update_coef, args.sequence_parallel + args.moe_use_aux_free_update_coef, + args.sequence_parallel, + aux_free_class, ) ] init_parameters(model) diff --git a/examples/auto_parallel/train_4p5_300B_A47B.sh b/examples/auto_parallel/train_4p5_300B_A47B.sh index ac06a15ec..79c31cafd 100644 --- a/examples/auto_parallel/train_4p5_300B_A47B.sh +++ b/examples/auto_parallel/train_4p5_300B_A47B.sh @@ -40,10 +40,6 @@ export PYTHONPATH=../../:$PYTHONPATH log_dir=output/paddle_distributed_logs -# USE_VPP=0: Implement parallelism using the intermediate API. -# USE_VPP=1: Implement parallelism using the basic API; the intermediate API does not support VPP for the time being. -export USE_VPP=1 - python -m paddle.distributed.launch \ --log_dir ${log_dir} \ --master : \ diff --git a/examples/auto_parallel/trainers/callbacks/moe_correction_bias_adjust_callback.py b/examples/auto_parallel/trainers/callbacks/moe_correction_bias_adjust_callback.py index 9f74fe0ba..e256934af 100644 --- a/examples/auto_parallel/trainers/callbacks/moe_correction_bias_adjust_callback.py +++ b/examples/auto_parallel/trainers/callbacks/moe_correction_bias_adjust_callback.py @@ -15,17 +15,17 @@ import paddle import paddle.distributed as dist -from models.modeling import ErnieDecoderLayer from models.moe_layer import MOELayer from paddle.distributed.fleet import fleet from paddleformers.trainer.trainer_callback import TrainerCallback class MoECorrectionBiasAdjustCallback(TrainerCallback): - def __init__(self, lr, use_sp): + def __init__(self, lr, use_sp, model_class): super().__init__() self.update_lr = float(lr) self.use_sp = use_sp + self.model_class = model_class def on_optimizer_end(self, args, state, control, **kwargs): model = kwargs["model"] @@ -35,7 +35,7 @@ def on_optimizer_end(self, args, state, control, **kwargs): def get_stat(layer): nonlocal usages, biases - if isinstance(layer, ErnieDecoderLayer): + if isinstance(layer, self.model_class): if not isinstance(layer.mlp, (MOELayer)): return assert hasattr( @@ -69,7 +69,7 @@ def get_stat(layer): def update_bias(layer): nonlocal usages, biases - if isinstance(layer, ErnieDecoderLayer): + if isinstance(layer, self.model_class): if not isinstance(layer.mlp, MOELayer): return with paddle.no_grad():