@@ -1,14 +1,8 @@
 from typing import Optional
 import torch
 import bmtrain as bmt
-from bmtrain.nn import (
-    Linear,
-    ColumnParallelLinear,
-    RowParallelLinear,
-)
+from bmtrain.nn import Linear
 import math
-from bmtrain.global_var import config
-from bmtrain.distributed import all_gather
 
 class Attention(bmt.DistributedModule):
     def __init__(self,
@@ -18,21 +12,14 @@ def __init__(self,
         ) -> None:
         super().__init__()
 
-        if config['tp_size'] > 1:
-            self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False)
-            self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False)
-            self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False)
-            self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)
-        else:
-            self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
-            self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
-            self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
-            self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)
+        self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+        self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+        self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
 
+        self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)
 
         self.softmax = torch.nn.Softmax(dim=-1)
         self.num_heads = num_heads
-        self.num_kv_heads = num_heads
         self.dim_head = dim_head
         self.dim_model = dim_model
 
@@ -45,50 +32,32 @@ def forward(self,
         batch_size, seq_q, dim_model = hidden_q.size()
         seq_kv = hidden_kv.size(1)
 
-        if isinstance(self.project_q, ColumnParallelLinear):
-            assert hidden_q.data_ptr() == hidden_kv.data_ptr()
-            hidden_q = bmt.nn.OpParallelLinear.apply(
-                hidden_q,
-                torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0),
-                torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0) if self.project_q.bias is not None else None,
-                True, False,
-                False, None
-            )
-            h_q, h_k, h_v = hidden_q.chunk(3, dim=-1)
-        else:
-            h_q : torch.Tensor = self.project_q(hidden_q)
-            h_k : torch.Tensor = self.project_k(hidden_q)
-            h_v : torch.Tensor = self.project_v(hidden_q)
-        if config['tp_size'] > 1:
-            #batch_size will changed in TensorParallel
-            batch_size = h_v.shape[0]
-
-        h_q = h_q.view(batch_size, seq_q, -1, self.dim_head)
-        h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head)
-        h_v = h_v.view(batch_size, seq_kv, -1, self.dim_head)
+        h_q : torch.Tensor = self.project_q(hidden_q)
+        h_k : torch.Tensor = self.project_k(hidden_kv)
+        h_v : torch.Tensor = self.project_v(hidden_kv)
+
+        h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head)
+        h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head)
+        h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head)
 
         h_q = h_q.permute(0, 2, 1, 3).contiguous()
         h_k = h_k.permute(0, 2, 1, 3).contiguous()
         h_v = h_v.permute(0, 2, 1, 3).contiguous()
 
-        h_q = h_q.view(-1, seq_q, self.dim_head)
-        h_k = h_k.view(-1, seq_kv, self.dim_head)
-        h_v = h_v.view(-1, seq_kv, self.dim_head)
+        h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head)
+        h_k = h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head)
+        h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head)
 
         score = torch.bmm(
             h_q, h_k.transpose(1, 2)
         )
         score = score / math.sqrt(self.dim_head)
 
-        score = score.view(batch_size, -1, seq_q, seq_kv)
+        score = score.view(batch_size, self.num_heads, seq_q, seq_kv)
 
         if position_bias is not None:
-            score = score + position_bias.view(batch_size, -1, seq_q, seq_kv)
-
-        if config['tp_size'] > 1:
-            with torch.no_grad():
-                mask = all_gather(mask, config['tp_comm']).flatten(0,1)
-
+            score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv)
+
         score = torch.where(
             mask.view(batch_size, 1, seq_q, seq_kv),
             score,
@@ -101,14 +70,14 @@ def forward(self,
             torch.scalar_tensor(0, device=score.device, dtype=score.dtype)
         )
 
-        score = score.view(-1, seq_q, seq_kv)
+        score = score.view(batch_size * self.num_heads, seq_q, seq_kv)
 
         h_out = torch.bmm(
             score, h_v
         )
-        h_out = h_out.view(batch_size, -1, seq_q, self.dim_head)
+        h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head)
         h_out = h_out.permute(0, 2, 1, 3).contiguous()
-        h_out = h_out.view(batch_size, seq_q, -1)
+        h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head)
 
         attn_out = self.project_out(h_out)
         return attn_out
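
For reference, here is a minimal usage sketch of the simplified (non-tensor-parallel) Attention layer after this change. The constructor and forward signatures are elided from the hunks above, so the keyword names used below (dim_model, dim_head, num_heads, bias, dtype; hidden_q, hidden_kv, mask, position_bias), as well as the import path of the layer, are assumptions inferred from the visible bodies; BMTrain layers also expect bmt.init_distributed() to have been called first (e.g. under torchrun) before any distributed parameters are created.

import torch
import bmtrain as bmt
from attention import Attention   # hypothetical module path for the layer in this diff

bmt.init_distributed(seed=0)       # must run before constructing bmt.DistributedModule layers

# Assumed constructor keywords, matching the names used in __init__'s body.
attn = Attention(dim_model=512, dim_head=64, num_heads=8, bias=True, dtype=torch.half)

batch_size, seq_q, seq_kv = 2, 16, 16
hidden = torch.randn(batch_size, seq_q, 512, dtype=torch.half, device="cuda")

# Boolean attention mask, True where attending is allowed (causal lower-triangular here).
mask = torch.tril(torch.ones(batch_size, seq_q, seq_kv, device="cuda")).bool()

# Self-attention: the same tensor serves as hidden_q and hidden_kv.
out = attn(hidden, hidden, mask, position_bias=None)
print(out.shape)                   # expected: (batch_size, seq_q, dim_model) == (2, 16, 512)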