@@ -756,6 +756,7 @@ def __init__(
         softcap_value = 50.,
         use_flex_attn = False,
         gate_values = True,
+        laser = False,
         learned_value_residual_mix = False
     ):
         super().__init__()
@@ -783,6 +784,8 @@ def __init__(
 
         self.softcap_value = softcap_value
 
+        self.laser = laser
+
         self.dropout = nn.Dropout(dropout)
 
         self.to_out = nn.Sequential(
@@ -844,6 +847,12 @@ def forward(
         if exists(rotary_emb):
             q, k = tuple(apply_rotary_emb(rotary_emb, t, freqs_seq_dim = -2) for t in (q, k))
 
+        # laser attention
+
+        if self.laser:
+            v_max = v.amax(dim = -2, keepdim = True).detach()
+            v = (v - v_max).exp()
+
         # whether to use flex attention or not
 
         if should_use_flex_attn:
@@ -878,6 +887,11 @@ def forward(
 
         out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
 
+        # laser attention
+
+        if self.laser:
+            out = log(out) + v_max
+
         # maybe gate values
 
         if exists(self.to_gates):
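
The two added blocks above together implement the laser trick: the values are exponentiated before the attention-weighted sum, and the result is mapped back through a log, so the output is a log of an attention-weighted sum of exp(values) rather than a plain weighted sum. Subtracting the per-feature max of the values before the exp and adding it back after the log only guards against overflow and does not change the result. A minimal standalone sketch of the same computation, assuming plain scaled-dot-product softmax attention and torch (the function and variable names below are illustrative, not from this file):

import torch

def laser_attention(q, k, v):
    # attention weights as usual
    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1) * scale).softmax(dim = -1)

    # laser: weighted sum in exp-value space, stabilized by the per-feature max over the sequence
    v_max = v.amax(dim = -2, keepdim = True).detach()
    out = attn @ (v - v_max).exp()
    return out.log() + v_max

# equivalent to log(attn @ v.exp()), just without the overflow risk
q, k, v = (torch.randn(1, 4, 8, 16) for _ in range(3))
attn = (q @ k.transpose(-2, -1) * 16 ** -0.5).softmax(dim = -1)
assert torch.allclose(laser_attention(q, k, v), (attn @ v.exp()).log(), atol = 1e-5)

Detaching v_max is purely an optimization: since the stabilizer cancels exactly in the forward pass, it contributes no gradient signal of its own.
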
@@ -908,6 +922,7 @@ def __init__(
         ff_expansion_factor = 4,
         attn_kwargs: dict = dict(),
         ff_kwargs: dict = dict(),
+        attn_laser = False,
         unet_skips = True,
         use_flex_attn = False
     ):
@@ -932,7 +947,7 @@ def __init__(
 
             skip_proj = Linear(dim * 2, dim, bias = False) if is_latter_half and unet_skips else None
 
-            attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout, use_flex_attn = use_flex_attn, learned_value_residual_mix = not is_first, **attn_kwargs)
+            attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout, use_flex_attn = use_flex_attn, learned_value_residual_mix = not is_first, laser = attn_laser, **attn_kwargs)
 
             ff = FeedForward(dim = dim, expansion_factor = ff_expansion_factor, **ff_kwargs)
 
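
Usage-wise, the new flag threads through from the wrapper shown in the last two hunks down to every Attention layer it constructs. A hypothetical example (the wrapper class name and the dim/depth arguments below are placeholders for whatever this file actually exposes; only attn_laser and laser come from this change):

model = TransformerWrapper(    # placeholder name for the class whose __init__ is shown above
    dim = 512,
    depth = 8,
    attn_laser = True          # every Attention layer is built with laser = True
)

# or directly on a single attention layer
attn = Attention(dim = 512, dim_head = 64, heads = 8, laser = True)
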