@@ -1,8 +1,14 @@
 from typing import Optional
 import torch
 import bmtrain as bmt
-from bmtrain.nn import Linear
+from bmtrain.nn import (
+    Linear,
+    ColumnParallelLinear,
+    RowParallelLinear,
+)
 import math
+from bmtrain.global_var import config
+from bmtrain.distributed import all_gather
 
 class Attention(bmt.DistributedModule):
     def __init__(self,
@@ -12,11 +18,17 @@ def __init__(self,
        ) -> None:
         super().__init__()
 
-        self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
-        self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
-        self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+        if config['tp_size'] > 1:
+            self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False)
+            self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False)
+            self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False)
+            self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)
+        else:
+            self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+            self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+            self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+            self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)
 
-        self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)
 
         self.softmax = torch.nn.Softmax(dim=-1)
         self.num_heads = num_heads
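
The Q/K/V projections switch to ColumnParallelLinear while project_out becomes RowParallelLinear because the column split shards the output features (the heads), letting each rank run attention on its own heads, while the row split shards the input features of the output projection so the partial results only need one reduction at the end. Below is a minimal single-process sketch of that identity, with hypothetical sizes and plain PyTorch in place of the BMTrain kernels (which add the actual all-gather/all-reduce and autograd handling):

```python
import torch

# Hypothetical sizes, chosen only for illustration.
dim_model, dim_head, num_heads, tp = 16, 4, 4, 2
x = torch.randn(3, dim_model)                          # [tokens, dim_model]
w_qkv = torch.randn(dim_head * num_heads, dim_model)   # column-parallel weight
w_out = torch.randn(dim_model, dim_head * num_heads)   # row-parallel weight

# Column parallel: shard the *output* features, so each "rank" owns a subset
# of heads and needs no communication before the attention math itself.
h_shards = [x @ w.t() for w in w_qkv.chunk(tp, dim=0)]

# Row parallel: shard the *input* features to match the column-parallel output;
# summing the partial products is what the all-reduce does in the real kernel.
y = sum(h @ w.t() for h, w in zip(h_shards, w_out.chunk(tp, dim=1)))

# The sharded computation matches the unsharded one.
torch.testing.assert_close(y, (x @ w_qkv.t()) @ w_out.t())
```
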
@@ -32,32 +44,48 @@ def forward(self,
         batch_size, seq_q, dim_model = hidden_q.size()
         seq_kv = hidden_kv.size(1)
 
-        h_q : torch.Tensor = self.project_q(hidden_q)
-        h_k : torch.Tensor = self.project_k(hidden_kv)
-        h_v : torch.Tensor = self.project_v(hidden_kv)
+        assert hidden_q.data_ptr() == hidden_kv.data_ptr()
 
-        h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head)
-        h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head)
-        h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head)
+        hidden_q = bmt.nn.OpParallelLinear.apply(
+            hidden_q,
+            torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0),
+            torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0),
+            True, False,
+            False, None
+        )
+
+        h_q, h_k, h_v = hidden_q.chunk(3, dim=-1)
+
+        if config['tp_size'] > 1:
+            # batch_size will have changed under tensor parallelism
+            batch_size = h_v.shape[0]
+
+        h_q = h_q.view(batch_size, seq_q, -1, self.dim_head)
+        h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head)
+        h_v = h_v.view(batch_size, seq_kv, -1, self.dim_head)
 
         h_q = h_q.permute(0, 2, 1, 3).contiguous()
         h_k = h_k.permute(0, 2, 1, 3).contiguous()
         h_v = h_v.permute(0, 2, 1, 3).contiguous()
 
-        h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head)
-        h_k = h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head)
-        h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head)
+        h_q = h_q.view(-1, seq_q, self.dim_head)
+        h_k = h_k.view(-1, seq_kv, self.dim_head)
+        h_v = h_v.view(-1, seq_kv, self.dim_head)
 
         score = torch.bmm(
             h_q, h_k.transpose(1, 2)
         )
         score = score / math.sqrt(self.dim_head)
 
-        score = score.view(batch_size, self.num_heads, seq_q, seq_kv)
+        score = score.view(batch_size, -1, seq_q, seq_kv)
 
         if position_bias is not None:
-            score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv)
-
+            score = score + position_bias.view(batch_size, -1, seq_q, seq_kv)
+
+        if config['tp_size'] > 1:
+            with torch.no_grad():
+                mask = all_gather(mask, config['tp_comm']).flatten(0, 1)
+
         score = torch.where(
             mask.view(batch_size, 1, seq_q, seq_kv),
             score,
@@ -70,14 +98,14 @@ def forward(self,
             torch.scalar_tensor(0, device=score.device, dtype=score.dtype)
         )
 
-        score = score.view(batch_size * self.num_heads, seq_q, seq_kv)
+        score = score.view(-1, seq_q, seq_kv)
 
         h_out = torch.bmm(
             score, h_v
         )
-        h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head)
+        h_out = h_out.view(batch_size, -1, seq_q, self.dim_head)
         h_out = h_out.permute(0, 2, 1, 3).contiguous()
-        h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head)
+        h_out = h_out.view(batch_size, seq_q, -1)
 
         attn_out = self.project_out(h_out)
         return attn_out
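
For a quick local check of the unchanged single-GPU path (tp_size == 1), something like the sketch below should work. It assumes the example's original constructor and forward signatures (`Attention(dim_model, dim_head, num_heads, bias, dtype)` and `forward(hidden_q, hidden_kv, mask, position_bias=None)`) and a launch via torchrun so that `bmt.init_distributed()` can set up its communicators; under tensor parallelism the caller additionally has to shard the inputs and mask across the tp group, which this sketch does not attempt.

```python
# Smoke test for the tp_size == 1 path; run as a single process under torchrun.
# `Attention` is the module defined in this file.
import torch
import bmtrain as bmt

bmt.init_distributed()   # tensor parallelism left at its default (disabled)

attn = Attention(dim_model=256, dim_head=32, num_heads=8, bias=True, dtype=torch.half)
bmt.init_parameters(attn)

x = torch.randn(2, 128, 256, dtype=torch.half, device="cuda")
mask = torch.ones(2, 128, 128, dtype=torch.bool, device="cuda")

# The new forward asserts hidden_q and hidden_kv share storage, so pass the
# same tensor for both (self-attention).
out = attn(x, x, mask)
print(out.shape)   # expected: torch.Size([2, 128, 256])
```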