
Commit 84d1cc1

Author: wentiange

[FEATURE] NPU training with GMM NZ
1 parent ba28809 commit 84d1cc1

7 files changed

Lines changed: 368 additions & 10 deletions


pta_patch/_fsdp_collectives.py

Lines changed: 91 additions & 0 deletions
from typing import List, Tuple
import os
import torch


lib = torch.library.Library("fsdp", "FRAGMENT")


@torch.library.impl(lib, "chunk_cat", "PrivateUse1")
def chunk_cat(
    tensors: List[torch.Tensor],
    dim: int,
    num_chunks: int,
    out: torch.Tensor,
) -> None:
    tensors = [tensor.contiguous() for tensor in tensors]
    out = out.contiguous()
    torch._chunk_cat(tensors, dim, num_chunks, out=out)


@torch.library.impl(lib, "all_gather_copy_in", "PrivateUse1")
def all_gather_copy_in_npu(
    all_gather_inputs: List[torch.Tensor],
    inp_split_sizes: List[int],
    all_gather_input_numel: int,
    world_size: int,
    rank: int,
    dtype: torch.dtype,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    all_gather_output = torch.empty(
        (all_gather_input_numel * world_size,), dtype=dtype, device=device
    )
    all_gather_input = all_gather_output.narrow(
        0, all_gather_input_numel * rank, all_gather_input_numel
    )
    foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    with torch.no_grad():
        if foreach_copy_dsts[0].device == all_gather_inputs[0].device:
            torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs, non_blocking=True)
        else:
            torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs)
    return all_gather_input, all_gather_output


@torch.library.impl(lib, "split_with_sizes_copy", "PrivateUse1")
def split_with_sizes_copy(
    all_gather_output: torch.Tensor,
    all_gather_input_split_sizes: List[int],
    dim: int,
    out: List[torch.Tensor],
    num_expert: int = 128,
    hidden_size: int = 4096,
    moe_intermediate_size: int = 1536,
) -> None:
    # Enable the gmm_nz optimization if and only if all of the following hold:
    # 1. the GROUPMM_NZ_TRANSPOSE switch is on
    # 2. all_gather_input_split_sizes has more than one entry
    # 3. the last split weight is the one used by the GMM down_proj
    enable_gmm_nz = int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0")) \
        and len(all_gather_input_split_sizes) > 1 \
        and out[-1].shape[0] * out[-1].shape[1] == num_expert * hidden_size * moe_intermediate_size

    if enable_gmm_nz:
        from special_op import npu_special_slice
        num_rank = out[0].shape[0]
        total_size = sum(all_gather_input_split_sizes)

        # The last two split weights are the ones used by the GMM up_proj and down_proj
        up_size = out[-1].shape[1]
        down_size = out[-2].shape[1]

        up_start = total_size - up_size
        down_start = up_start - down_size

        out[-1].resize_(num_expert, moe_intermediate_size, hidden_size)
        out[-2].resize_(num_expert, hidden_size, moe_intermediate_size * 2)

        # Slice the GMM weights and convert them to NZ format with the fused operator
        npu_special_slice(all_gather_output, dim, up_start, total_size, out[-1])
        npu_special_slice(all_gather_output, dim, down_start, up_start, out[-2])

        other_tensors = all_gather_output[:, :down_start].view(num_rank, -1)
        torch.split_with_sizes_copy(
            other_tensors, all_gather_input_split_sizes[:-2], dim=dim, out=out[:-2]
        )

        return

    torch.split_with_sizes_copy(
        all_gather_output, all_gather_input_split_sizes, dim=dim, out=out
    )
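The index arithmetic in the gmm_nz branch can be checked off-device. The numbers below are illustrative assumptions only, not values taken from the patch:

# Hedged, NPU-free sketch: the last two entries of the per-rank flat buffer are
# carved out for the fused gate/up and down weights; everything before them
# still goes through the regular torch.split_with_sizes_copy.
split_sizes = [5, 7, 48, 24]        # hypothetical per-rank split sizes
total_size = sum(split_sizes)       # 84
up_size, down_size = 24, 48         # stand-ins for out[-1].shape[1] and out[-2].shape[1]

up_start = total_size - up_size     # 60 -> out[-1] is sliced from columns [60, 84)
down_start = up_start - down_size   # 12 -> out[-2] is sliced from columns [12, 60)
# columns [0, 12) are copied into out[:-2] by torch.split_with_sizes_copy
print(down_start, up_start, total_size)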

special_op.py

Lines changed: 1 addition & 0 deletions
# Re-exports the fused slice + NZ-conversion operator so that the patched
# _fsdp_collectives.py can resolve `from special_op import npu_special_slice`.
from torch_npu import npu_special_slice

test_qwen3_235b_npu.sh

Lines changed: 37 additions & 0 deletions
set -x

CANN_DIR=/usr/local/Ascend/ascend-toolkit  # default CANN installation path
CANN_OPS_DIR=/tmp/cann-ops  # installation path for the CANN-ops package that enables the GMM_NZ patch
PTA_FSDP_DIR=/usr/local/lib/python3.11/site-packages/torch_npu/distributed/fsdp  # location of PTA's FSDP module, so the slice operator implementation can be replaced below

mkdir ${CANN_OPS_DIR}
./pta_patch/CANN-custom_ops--linux.aarch64.run --install-path=${CANN_OPS_DIR}
source ${CANN_OPS_DIR}/vendors/customize/bin/set_env.bash
source ${CANN_DIR}/set_env.sh

# Install the PTA 2.6.0 patch for GMM K-axis splitting
pip install /path/to/torch_npu-custom.whl --force-reinstall
cp ./pta_patch/_fsdp_collectives.py ${PTA_FSDP_DIR}

# Enable the GMM NZ switch
export GROUPMM_NZ_TRANSPOSE=1

export QWEN3_MOE_PATH=/path/to/qwen3_moe_weights
export ALPACA_PATH=/path/to/alpaca_dataset

export XTUNER_USE_FA3="1"
export HCCL_RDMA_TC=132


# Custom switch: 1 = on, 0 = off
export LINEAR_ONLY_SHARD=1

# MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK, LOGS_DIR and PROF_DIR are expected
# to be provided by the launch environment.
mkdir ${LOGS_DIR}

torchrun --nproc-per-node 16 \
    --master_addr=$MASTER_ADDR \
    --master_port=$MASTER_PORT \
    --nnodes=$WORLD_SIZE \
    --node_rank=$RANK \
    ci/scripts/test_sft_trainer_235B.py \
    ${PROF_DIR} | tee ${LOGS_DIR}/rank_${RANK}.log

xtuner/v1/model/base.py

Lines changed: 24 additions & 4 deletions
@@ -39,7 +39,7 @@
 from xtuner.v1.utils.loader import HFCheckpointLoader

 from .utils import ModelForwardExtraLogInfo
-
+import os

 logger = get_logger()

@@ -229,19 +229,34 @@ def safetensors_to_params(
     start: int | None,
     end: int | None,
     dim: int | None,
+    flag: str | None,
+    num_expert: int = 128,
+    hidden_size: int = 4096,
+    moe_intermediate_size: int = 1536,
 ):
     if len(safetensors) > 1:
         assert dim is not None, "Internal Error dim must not be None when len(safetensors) > 1"
         loaded_tensor = torch.cat(safetensors, dim=dim)
     else:
         loaded_tensor = safetensors[0]

+    if int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0")) == 1 and flag == 'fused':
+        out_feature, in_feature = safetensors[0].shape
+        out_feature = out_feature * 2 if out_feature == moe_intermediate_size else out_feature
+        loaded_tensor = loaded_tensor.view(-1, out_feature, in_feature).transpose(1, 2).contiguous().view(-1, out_feature).contiguous()
+
     if start is not None and end is not None:
         assert self.fsdp_config is not None, (
             "Internal Error. fsdp_config must not be None when start and end is not None"
         )
         start = min(start, loaded_tensor.shape[self.FSDP_SHARD_DIM])
         end = min(end, loaded_tensor.shape[self.FSDP_SHARD_DIM])
+
+        if int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0")) == 1 and torch.distributed.get_world_size() >= num_expert and loaded_tensor.shape[self.FSDP_SHARD_DIM] == hidden_size:
+            if torch.distributed.get_rank() % 4 >= 2:
+                start += hidden_size // 2
+                end += hidden_size // 2
+
     loaded_tensor_slice = loaded_tensor.index_select(
         dim=self.FSDP_SHARD_DIM, index=torch.arange(start, end, dtype=torch.int64, device=loaded_tensor.device)
     )
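The fused-weight re-layout added above can be sanity-checked without an NPU; the toy sizes here are assumptions for illustration, not the real model dimensions:

# Hedged sketch of the GROUPMM_NZ_TRANSPOSE re-layout for a fused expert weight:
# (num_expert * out_features, in_features) becomes (num_expert * in_features, out_features),
# i.e. each expert's block of the flat checkpoint tensor is transposed.
import torch

num_expert, in_feature, out_feature = 4, 8, 6   # toy sizes
fused = torch.randn(num_expert * out_feature, in_feature)

nz_layout = (
    fused.view(-1, out_feature, in_feature)
    .transpose(1, 2)
    .contiguous()
    .view(-1, out_feature)
    .contiguous()
)
print(fused.shape, nz_layout.shape)  # torch.Size([24, 8]) torch.Size([32, 6])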
@@ -956,12 +971,12 @@ def _load_same_hf_param(
         end = None

     self.safetensors_to_params(
-        [loaded_tensor], local_tensor, param_name=load_spec.name, start=start, end=end, dim=load_spec.dim
+        [loaded_tensor], local_tensor, param_name=load_spec.name, start=start, end=end, dim=load_spec.dim, flag='same',
     )
     return []

 def _load_fused_hf_param(
-    self, param: torch.Tensor, load_spec: LoadSpec, checkpoint_loader: HFCheckpointLoader
+    self, param: torch.Tensor, load_spec: LoadSpec, checkpoint_loader: HFCheckpointLoader, num_expert: int = 128
 ) -> list[str]:
     # For expert parallel
     # NOTE:
@@ -1004,6 +1019,10 @@ def _load_fused_hf_param(
     hf_keys_start = int(fsdp_start / hf_key_size)
     hf_keys_end = math.ceil(fsdp_end / hf_key_size)

+    if int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0")) == 1 and len(hf_keys) == num_expert * 2 and torch.distributed.get_world_size() >= num_expert:  # gate & up case; the down case must be excluded
+        hf_keys_start = int(torch.distributed.get_rank() // 4) * 2
+        hf_keys_end = hf_keys_start + 2
+
     # Empty pad by fsdp
     if hf_keys_start == hf_keys_end:
         return []
@@ -1038,7 +1057,7 @@ def _load_fused_hf_param(
         return missing_keys

     self.safetensors_to_params(
-        _loaded_tensor, local_tensor, param_name=load_spec.name, start=start, end=end, dim=load_spec.dim
+        _loaded_tensor, local_tensor, param_name=load_spec.name, start=start, end=end, dim=load_spec.dim, flag='fused',
     )
     return missing_keys

@@ -1084,6 +1103,7 @@ def _load_shard_hf_param(
         start=start,
         end=end,
         dim=load_spec.dim,
+        flag='shard',
     )
     return []

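The rank-to-slice mapping introduced in safetensors_to_params and _load_fused_hf_param is plain arithmetic and can be previewed without a process group; the loop bound below is arbitrary:

# Hedged sketch: with GROUPMM_NZ_TRANSPOSE=1 and world_size >= num_expert, each
# group of four consecutive ranks reads the same gate/up key pair, and the last
# two ranks of each group shift their column window by hidden_size // 2.
hidden_size = 4096

for rank in range(8):  # first eight ranks only, purely illustrative
    hf_keys_start = (rank // 4) * 2
    hf_keys_end = hf_keys_start + 2
    column_shift = hidden_size // 2 if rank % 4 >= 2 else 0
    print(rank, (hf_keys_start, hf_keys_end), column_shift)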

xtuner/v1/model/moe/moe.py

Lines changed: 25 additions & 3 deletions
@@ -47,7 +47,12 @@
 )
 from xtuner.v1.utils.activation_offload import async_save_on_cpu
 from xtuner.v1.utils.compile import maybe_compile
+import torch_npu
+import os

+if int(os.getenv("LINEAR_ONLY_SHARD", "0")) == 1:
+    from xtuner.v1.patch.fsdp_partial_shard import apply_fsdp_partial_shard_patch
+    apply_fsdp_partial_shard_patch()

 DEVICE = get_device()
 logger = get_logger()
@@ -167,7 +172,7 @@ def __init__(self, config: MoEConfig):
     else:
         self.z_loss = None

-    self.offload_stream = torch.cuda.Stream()
+    self.offload_stream = torch_npu.npu.Stream()  # was: torch.cuda.Stream()

 def _select_non_pad_router_logits(
     self,
@@ -688,6 +693,16 @@ def fully_shard(
     )
     num_recompute_layers = int(self.config.num_hidden_layers * self.fsdp_config.recompute_ratio)

+    def layer_to_fully_shard(layer):
+        return [
+            layer.self_attn.q_proj,
+            layer.self_attn.k_proj,
+            layer.self_attn.v_proj,
+            layer.self_attn.o_proj,
+            layer.experts.fused_w1w3,
+            layer.experts.fused_w2,
+        ]
+
     for layer_idx, layer in tqdm(self.layers.items(), desc="[FSDP Sharding]"):
         layer_idx = int(layer_idx)
         if layer_idx < num_recompute_layers - 1:
@@ -698,19 +713,26 @@
             reshard_after_forward = False
         else:
             reshard_after_forward = self.fsdp_config.reshard_after_forward
+        is_linear_only_shard = int(os.getenv("LINEAR_ONLY_SHARD", "0")) == 1
         fully_shard(
-            layer,
+            layer_to_fully_shard(layer) if is_linear_only_shard else layer,
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
             mp_policy=mp_policy,
             reshard_after_forward=reshard_after_forward,
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
+            **({'hook_module': layer} if is_linear_only_shard else {}),
         )

     for layer_cur, layer_next in zip(
         list(self.layers.values())[:-1],
         list(self.layers.values())[1:],
     ):
-        layer_cur.set_modules_to_forward_prefetch([layer_next])  # type: ignore
+        if not is_linear_only_shard:
+            layer_cur.set_modules_to_forward_prefetch([layer_next])  # type: ignore
+        else:
+            layer_cur_fully_shard_modules = layer_to_fully_shard(layer_cur)
+            layer_next_fully_shard_modules = layer_to_fully_shard(layer_next)
+            layer_cur_fully_shard_modules[0].set_modules_to_forward_prefetch([layer_next_fully_shard_modules[0]])

     fully_shard(
         self.embed_tokens,

xtuner/v1/module/grouped_linear/moe_group_linear.py

Lines changed: 42 additions & 3 deletions
@@ -7,6 +7,9 @@
 from xtuner.v1.float8.float8_gmm_tile_wise import TileWiseFloat8GroupedLinear
 from xtuner.v1.ops import group_gemm

+from torch.autograd import Function
+import torch_npu
+import os

 class GroupedLinear(nn.Module):
     # TODO: Missing example docs
@@ -22,7 +25,11 @@ def __init__(
     self.in_features = in_features
     self.out_features = out_features
     self.num_routed_experts = num_routed_experts
-    weight = torch.empty(num_routed_experts * out_features, in_features)
+    # Training-time NZ switch; note that it only brings a performance gain on NPU
+    if int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0")):
+        weight = torch.empty(num_routed_experts * in_features, out_features)
+    else:
+        weight = torch.empty(num_routed_experts * out_features, in_features)

     self.ep_mesh = ep_mesh
     if self.ep_mesh is not None and self.ep_mesh.size() > 1:
@@ -40,15 +47,47 @@

 def forward(self, x: torch.Tensor, tokens_per_expert: torch.Tensor, decoding: bool = False):
     weight = self.weight.to_local() if isinstance(self.weight, DTensor) else self.weight
-    weight = weight.view(-1, self.out_features, self.in_features)
-    out = group_gemm(x, weight, tokens_per_expert)
+
+    if int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0")):
+        weight = weight.view(-1, self.in_features, self.out_features)
+        out = NpuGMMOp.apply(weight, x, tokens_per_expert)
+    else:
+        weight = weight.view(-1, self.out_features, self.in_features)
+        out = group_gemm(x, weight, tokens_per_expert)

     if self.moe_bias:
         bias = self.bias.to_local() if isinstance(self.bias, DTensor) else self.bias
         out = out + bias.repeat_interleave(tokens_per_expert, dim=0)  # TODO: cannot be compiled
     return out


+class NpuGMMOp(Function):
+    @staticmethod
+    def forward(ctx, weight, x, tokens_per_expert):
+        if not int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0")):
+            weight = torch.transpose(weight, 1, 2)
+        ctx.save_for_backward(weight, x, tokens_per_expert)
+        outs = torch_npu.npu_grouped_matmul([x], [weight], group_list=tokens_per_expert, group_type=0, group_list_type=1, split_item=2)
+        return outs[0]
+
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        tensors = ctx.saved_tensors
+        weight = tensors[0]
+        input_tensor = tensors[1]
+        tokens_per_expert = tensors[2]
+        weight = torch.transpose(weight, 1, 2)
+        grad_input = torch_npu.npu_grouped_matmul([grad_output], [weight], group_list=tokens_per_expert,
+                                                  group_type=0, group_list_type=1, split_item=2)[0]
+        grad_weight = torch_npu.npu_grouped_matmul([input_tensor.T], [grad_output], bias=None, group_list=tokens_per_expert,
+                                                   split_item=3, group_type=2, group_list_type=1)[0]
+        if not int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0")):
+            grad_weight = torch.transpose(grad_weight, 1, 2)
+        return grad_weight, grad_input, None
+
+
+
 def build_grouped_linear(
     in_features: int,
     out_features: int,
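For readers without NPU access, the grouped matmul used in NpuGMMOp.forward can be mirrored with a plain-PyTorch reference. This is a hedged sketch that assumes group_list_type=1 means tokens_per_expert holds per-expert row counts, matching how it is passed above; it is not torch_npu's implementation:

# Hedged CPU reference for the grouped matmul semantics assumed by NpuGMMOp:
# x is (total_tokens, in_features), weight is (num_experts, in_features, out_features),
# tokens_per_expert gives how many consecutive rows of x belong to each expert.
import torch

def grouped_matmul_reference(x, weight, tokens_per_expert):
    outs, start = [], 0
    for expert_id, count in enumerate(tokens_per_expert.tolist()):
        outs.append(x[start:start + count] @ weight[expert_id])
        start += count
    return torch.cat(outs, dim=0)

x = torch.randn(10, 8)
weight = torch.randn(3, 8, 16)
tokens_per_expert = torch.tensor([4, 2, 4])
print(grouped_matmul_reference(x, weight, tokens_per_expert).shape)  # torch.Size([10, 16])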
