|
| 1 | +from typing import List, Tuple |
| 2 | +import os |
| 3 | +import torch |
| 4 | + |
| 5 | + |
# "FRAGMENT" extends an already-declared operator library rather than creating a
# new one: this module only registers extra backend (PrivateUse1/NPU)
# implementations for ops whose schemas are defined elsewhere under "fsdp".
lib = torch.library.Library("fsdp", "FRAGMENT")
| 7 | + |
| 8 | + |
| 9 | +@torch.library.impl(lib, "chunk_cat", "PrivateUse1") |
| 10 | +def chunk_cat( |
| 11 | + tensors: List[torch.Tensor], |
| 12 | + dim: int, |
| 13 | + num_chunks: int, |
| 14 | + out: torch.Tensor, |
| 15 | +) -> None: |
| 16 | + tensors = [tensor.contiguous() for tensor in tensors] |
| 17 | + out = out.contiguous() |
| 18 | + torch._chunk_cat(tensors, dim, num_chunks, out=out) |
| 19 | + |
| 20 | + |
| 21 | +@torch.library.impl(lib, "all_gather_copy_in", "PrivateUse1") |
| 22 | +def all_gather_copy_in_npu( |
| 23 | + all_gather_inputs: List[torch.Tensor], |
| 24 | + inp_split_sizes: List[int], |
| 25 | + all_gather_input_numel: int, |
| 26 | + world_size: int, |
| 27 | + rank: int, |
| 28 | + dtype: torch.dtype, |
| 29 | + device: torch.device, |
| 30 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 31 | + all_gather_output = torch.empty( |
| 32 | + (all_gather_input_numel * world_size,), dtype=dtype, device=device |
| 33 | + ) |
| 34 | + all_gather_input = all_gather_output.narrow( |
| 35 | + 0, all_gather_input_numel * rank, all_gather_input_numel |
| 36 | + ) |
| 37 | + foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) |
| 38 | + with torch.no_grad(): |
| 39 | + if foreach_copy_dsts[0].device == all_gather_inputs[0].device: |
| 40 | + torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs, non_blocking=True) |
| 41 | + else: |
| 42 | + torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs) |
| 43 | + return all_gather_input, all_gather_output |
| 44 | + |
| 45 | +@torch.library.impl(lib, "split_with_sizes_copy", "PrivateUse1") |
def split_with_sizes_copy(
    all_gather_output: torch.Tensor,
    all_gather_input_split_sizes: List[int],
    dim: int,
    out: List[torch.Tensor],
    num_expert: int = 128,
    hidden_size: int = 4096,
    moe_intermediate_size: int = 1536,
) -> None:
    """Split ``all_gather_output`` along ``dim`` into the preallocated ``out``
    tensors (PrivateUse1/NPU impl of ``fsdp::split_with_sizes_copy``).

    Default path: a single ``torch.split_with_sizes_copy``. When the gmm_nz
    optimization is enabled, the last two outputs (GMM weights) are instead
    produced by the fused ``npu_special_slice`` operator, which slices and
    converts to NZ format in one pass; the remaining outputs still go through
    ``torch.split_with_sizes_copy``.
    """
    # The gmm_nz optimization is enabled if and only if all of the following hold:
    # 1. the GROUPMM_NZ_TRANSPOSE environment switch is turned on
    # 2. len(all_gather_input_split_sizes) is greater than 1
    # 3. the last split weight is the one used by the GMM down_proj
    # NOTE(review): int(os.getenv(...)) raises ValueError for a non-integer
    # value such as "true" — confirm callers only ever set "0"/"1".
    enable_gmm_nz = int(os.getenv("GROUPMM_NZ_TRANSPOSE","0")) \
        and len(all_gather_input_split_sizes) > 1 \
        and out[-1].shape[0] * out[-1].shape[1] == num_expert * hidden_size * moe_intermediate_size

    if enable_gmm_nz:
        # Deferred import: special_op is only needed (and presumably only
        # available) when the fused-op path is enabled.
        from special_op import npu_special_slice
        num_rank = out[0].shape[0]
        total_size = sum(all_gather_input_split_sizes)

        # The last two split weights are used by GMM up_proj and down_proj.
        # NOTE(review): the names look swapped — ``up_size`` comes from
        # out[-1], yet condition 3 above says the *last* weight feeds
        # down_proj, and the resize_ below gives out[-1] the
        # (expert, intermediate, hidden) down_proj shape. Confirm intent.
        up_size = out[-1].shape[1]
        down_size = out[-2].shape[1]

        # Flat offsets of the two weights inside each rank's row.
        up_start = total_size - up_size
        down_start = up_start - down_size

        # Reshape the flat output buffers to the 3-D per-expert weight layouts
        # expected by the fused operator.
        out[-1].resize_(num_expert,moe_intermediate_size,hidden_size)
        out[-2].resize_(num_expert,hidden_size,moe_intermediate_size*2)

        # Weight slicing and the ND->NZ conversion use one fused operator.
        npu_special_slice(all_gather_output, dim, up_start, total_size, out[-1])
        npu_special_slice(all_gather_output, dim, down_start, up_start, out[-2])

        # Everything before the two GMM weights is copied out the normal way
        # (assumes all_gather_output is 2-D with rank as the leading dim —
        # TODO confirm against the caller).
        other_tensors = all_gather_output[:, :down_start].view(num_rank, -1)
        torch.split_with_sizes_copy(
            other_tensors, all_gather_input_split_sizes[:-2], dim=dim, out=out[:-2]
        )

        return

    # Default path: plain split-and-copy into the preallocated outputs.
    torch.split_with_sizes_copy(
        all_gather_output, all_gather_input_split_sizes, dim=dim, out=out
    )
0 commit comments