+import abc
 import enum
 import typing
 
 from fast_llm.utils import Assert
 
 if typing.TYPE_CHECKING:
+    import torch
+
     from fast_llm.engine.config_utils.tensor_space import TensorDim
     from fast_llm.layers.common.linear import LinearBase, LinearLike
     from fast_llm.layers.common.normalization import LayerNorm, RMSNorm
@@ -35,26 +38,42 @@ class NormalizationImplementation(str, enum.Enum):
     triton = "triton"
 
 
-class NormalizationType(str, enum.Enum):
-    """
-    An enum for the available normalization layers.
-    TODO: Add no_norm type?
-    """
+@config_class(registry=True)
+class NormalizationConfig(BaseModelConfig):
+    pass
 
-    layer_norm = "layer_norm"
-    rms_norm = "rms_norm"
+    @abc.abstractmethod
+    def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module":
+        pass
 
+    @classmethod
+    def _from_dict(
+        cls,
+        default: dict[str, typing.Any],
+        strict: bool = True,
+        flat: bool = False,
+    ) -> typing.Self:
+        if cls is NormalizationConfig and cls.get_subclass(default.get("type")) is None:
+            # Default subclass.
+            return LayerNormalizationConfig._from_dict(default, strict, flat)
+        return super()._from_dict(default, strict=strict, flat=flat)
 
-@config_class(registry=True)
-class NormalizationConfig(BaseModelConfig):
+
+@config_class(dynamic_type={NormalizationConfig: "none"})
+class NoNormalizationConfig(NormalizationConfig):
     _abstract = False
 
-    # Normalization type
-    type: NormalizationType = Field(
-        default=NormalizationType.layer_norm,
-        desc="The type of normalization to use, for example Layer Norm or RMS Norm.",
-        hint=FieldHint.architecture,
-    )
+    def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module":
+        import torch  # runtime import; the module-level torch import is TYPE_CHECKING-only
+        return torch.nn.Identity()
+
+
+@config_class()
+class LayerNormalizationBaseConfig(NormalizationConfig):
+    """
+    Common configuration for layer norm and rms norm
+    """
+
     # TODO: Rename to normalization_epsilon
     epsilon: float = Field(
         default=1e-5,
@@ -81,7 +100,6 @@ class NormalizationConfig(BaseModelConfig):
     )
 
     def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm":
-        from fast_llm.layers.common.normalization import LayerNorm, RMSNorm
         from fast_llm.tensor import init_uniform_
 
         kwargs = {
@@ -96,14 +114,12 @@ def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "
             kwargs["weight_init_method"] = init_uniform_(
                 mean - self.initialization_range, mean + self.initialization_range
             )
-        if self.type == NormalizationType.layer_norm:
-            if self.initialization_range:
-                kwargs["bias_init_method"] = init_uniform_(-self.initialization_range, self.initialization_range)
-            return LayerNorm(**kwargs)
-        elif self.type == NormalizationType.rms_norm:
-            return RMSNorm(**kwargs)
-        else:
-            raise ValueError(self.type)
+        return self.module_class(**kwargs)
+
+    @property
+    @abc.abstractmethod
+    def module_class(self):
+        pass
 
     @classmethod
     def _from_dict(
@@ -120,27 +136,47 @@ def _from_dict(
         return super()._from_dict(default, strict, flat)
 
 
-for name in NormalizationType:
-    # We need this because we are using the reserved field name `type`.
-    # TODO: Implement proper dynamic typing.
-    NormalizationConfig.register_subclass(name.value, NormalizationConfig)
+@config_class(dynamic_type={NormalizationConfig: "layer_norm"})
+class LayerNormalizationConfig(LayerNormalizationBaseConfig):
+    _abstract = False
+
+    @property
+    def module_class(self):
+        from fast_llm.layers.common.normalization import LayerNorm
 
+        return LayerNorm
 
-class PeftType(str, enum.Enum):
-    # TODO: Use a dynamic config type instead.
-    none = "none"
-    lora = "lora"
+
+@config_class(dynamic_type={NormalizationConfig: "rms_norm"})
+class RMSNormalizationConfig(LayerNormalizationBaseConfig):
+    _abstract = False
+
+    @property
+    def module_class(self):
+        from fast_llm.layers.common.normalization import RMSNorm
+
+        return RMSNorm
 
 
 @config_class()
 class PeftConfig(BaseModelConfig):
+    @abc.abstractmethod
+    def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike":
+        pass
+
+
+@config_class()
+class NoPeftConfig(PeftConfig):
+    _abstract = False
+
+    def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike":
+        return linear
+
+
+@config_class()
+class LoRAConfig(PeftConfig):
     _abstract = False
 
-    type: PeftType = Field(
-        default=PeftType.none,
-        desc="The type of parameter-efficient fine tuning to use Only LoRA is supported at the moment.",
-        hint=FieldHint.core,
-    )
     rank: int = Field(
         default=8,
         desc="The LoRA rank, i.e. the size of the intermediate dimension.",
@@ -158,20 +194,15 @@ class PeftConfig(BaseModelConfig):
     )
 
     def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike":
-        if self.type == PeftType.none:
-            return linear
-        elif self.type == PeftType.lora:
-            from fast_llm.layers.common.peft import lora_linear
-
-            # TODO: Init method?
-            return lora_linear(
-                linear,
-                linear.weight.param_init_method,
-                linear.weight.param_init_method,
-                self.rank,
-                self.alpha,
-                self.dropout,
-                **kwargs,
-            )
-        else:
-            raise NotImplementedError(self.type)
+        from fast_llm.layers.common.peft import lora_linear
+
+        # TODO: Init method?
+        return lora_linear(
+            linear,
+            linear.weight.param_init_method,
+            linear.weight.param_init_method,
+            self.rank,
+            self.alpha,
+            self.dropout,
+            **kwargs,
+        )
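
For reviewers, a minimal usage sketch of the new registry-based normalization selection. The import path, the direct call to the internal `_from_dict` hook shown above, and `hidden_dim` are assumptions; callers would normally go through whatever public from-dict entry point the config base class provides.

# Hypothetical illustration of the dynamic "type" dispatch added in this diff.
from fast_llm.layers.common.config import (  # assumed module path
    LayerNormalizationConfig,
    NoNormalizationConfig,
    NormalizationConfig,
    RMSNormalizationConfig,
)

# "type" selects the registered subclass.
rms = NormalizationConfig._from_dict({"type": "rms_norm", "epsilon": 1e-6})
assert isinstance(rms, RMSNormalizationConfig)

none = NormalizationConfig._from_dict({"type": "none"})
assert isinstance(none, NoNormalizationConfig)

# Omitting "type" falls back to LayerNormalizationConfig, matching the old layer_norm default.
default = NormalizationConfig._from_dict({})
assert isinstance(default, LayerNormalizationConfig)

# Each config then builds its own module (hidden_dim: a TensorDim built elsewhere):
#   default.get_layer(hidden_dim) -> LayerNorm
#   rms.get_layer(hidden_dim)     -> RMSNorm
#   none.get_layer(hidden_dim)    -> torch.nn.Identity()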
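Similarly, a sketch of the PEFT split. Keyword construction of the configs, the import path, and the alpha/dropout values are assumptions; only `rank` (default 8) is visible in this diff.

# Hypothetical illustration of the PeftConfig / NoPeftConfig / LoRAConfig split.
from fast_llm.layers.common.config import LoRAConfig, NoPeftConfig  # assumed module path

no_peft = NoPeftConfig()
lora = LoRAConfig(rank=8, alpha=16.0, dropout=0.05)  # alpha/dropout values chosen for the sketch

# With `linear` a LinearBase instance:
#   no_peft.apply_linear(linear)  -> returns `linear` unchanged
#   lora.apply_linear(linear)     -> lora_linear(linear, ..., rank, alpha, dropout)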