Commit 016a308

Dynamic transformer (#275)
Co-authored-by: RaymondLi0 <[email protected]>
1 parent 39b018a commit 016a308

File tree: 18 files changed, +708 -509 lines changed


fast_llm/config.py

Lines changed: 1 addition & 1 deletion

@@ -988,7 +988,7 @@ def __init_subclass__(cls):
             )
 
 
-class Configurable[ConfigType: Config]:
+class Configurable[ConfigType: Config](abc.ABC):
     config_class: typing.ClassVar[type[Config]] = Config
 
     def __init__(self, config: ConfigType, *args, **kwargs):

fast_llm/functional/rotary.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

fast_llm/layers/common/config.py

Lines changed: 85 additions & 54 deletions

@@ -1,3 +1,4 @@
+import abc
 import enum
 import typing
 
@@ -6,6 +7,8 @@
 from fast_llm.utils import Assert
 
 if typing.TYPE_CHECKING:
+    import torch
+
     from fast_llm.engine.config_utils.tensor_space import TensorDim
     from fast_llm.layers.common.linear import LinearBase, LinearLike
     from fast_llm.layers.common.normalization import LayerNorm, RMSNorm
@@ -35,26 +38,42 @@ class NormalizationImplementation(str, enum.Enum):
     triton = "triton"
 
 
-class NormalizationType(str, enum.Enum):
-    """
-    An enum for the available normalization layers.
-    TODO: Add no_norm type?
-    """
+@config_class(registry=True)
+class NormalizationConfig(BaseModelConfig):
+    pass
 
-    layer_norm = "layer_norm"
-    rms_norm = "rms_norm"
+    @abc.abstractmethod
+    def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module":
+        pass
 
+    @classmethod
+    def _from_dict(
+        cls,
+        default: dict[str, typing.Any],
+        strict: bool = True,
+        flat: bool = False,
+    ) -> typing.Self:
+        if cls is NormalizationConfig and cls.get_subclass(default.get("type")) is None:
+            # Default subclass.
+            return LayerNormalizationConfig._from_dict(default, strict, flat)
+        return super()._from_dict(default, strict=strict, flat=flat)
 
-@config_class(registry=True)
-class NormalizationConfig(BaseModelConfig):
+
+@config_class(dynamic_type={NormalizationConfig: "none"})
+class NoNormalizationConfig(NormalizationConfig):
     _abstract = False
 
-    # Normalization type
-    type: NormalizationType = Field(
-        default=NormalizationType.layer_norm,
-        desc="The type of normalization to use, for example Layer Norm or RMS Norm.",
-        hint=FieldHint.architecture,
-    )
+    @abc.abstractmethod
+    def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module":
+        return torch.nn.Identity()
+
+
+@config_class()
+class LayerNormalizationBaseConfig(NormalizationConfig):
+    """
+    Common configuration for layer norm and rms norm
+    """
+
     # TODO: Rename to normalization_epsilon
     epsilon: float = Field(
         default=1e-5,
@@ -81,7 +100,6 @@ class NormalizationConfig(BaseModelConfig):
     )
 
     def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm":
-        from fast_llm.layers.common.normalization import LayerNorm, RMSNorm
         from fast_llm.tensor import init_uniform_
 
         kwargs = {
@@ -96,14 +114,12 @@ def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "
             kwargs["weight_init_method"] = init_uniform_(
                 mean - self.initialization_range, mean + self.initialization_range
             )
-        if self.type == NormalizationType.layer_norm:
-            if self.initialization_range:
-                kwargs["bias_init_method"] = init_uniform_(-self.initialization_range, self.initialization_range)
-            return LayerNorm(**kwargs)
-        elif self.type == NormalizationType.rms_norm:
-            return RMSNorm(**kwargs)
-        else:
-            raise ValueError(self.type)
+        return self.module_class(**kwargs)
+
+    @property
+    @abc.abstractmethod
+    def module_class(self):
+        pass
 
     @classmethod
     def _from_dict(
@@ -120,27 +136,47 @@ def _from_dict(
         return super()._from_dict(default, strict, flat)
 
 
-for name in NormalizationType:
-    # We need this because we are using the reserved field name `type`.
-    # TODO: Implement proper dynamic typing.
-    NormalizationConfig.register_subclass(name.value, NormalizationConfig)
+@config_class(dynamic_type={NormalizationConfig: "layer_norm"})
+class LayerNormalizationConfig(LayerNormalizationBaseConfig):
+    _abstract = False
+
+    @property
+    def module_class(self):
+        from fast_llm.layers.common.normalization import LayerNorm
 
+        return LayerNorm
 
-class PeftType(str, enum.Enum):
-    # TODO : Use a dynamic config type instead.
-    none = "none"
-    lora = "lora"
+
+@config_class(dynamic_type={NormalizationConfig: "rms_norm"})
+class RMSNormalizationConfig(LayerNormalizationBaseConfig):
+    _abstract = False
+
+    @property
+    def module_class(self):
+        from fast_llm.layers.common.normalization import RMSNorm
+
+        return RMSNorm
 
 
 @config_class()
 class PeftConfig(BaseModelConfig):
+    @abc.abstractmethod
+    def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike":
+        pass
+
+
+@config_class()
+class NoPeftConfig(PeftConfig):
+    _abstract = False
+
+    def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike":
+        return linear
+
+
+@config_class()
+class LoRAConfig(PeftConfig):
     _abstract = False
 
-    type: PeftType = Field(
-        default=PeftType.none,
-        desc="The type of parameter-efficient fine tuning to use Only LoRA is supported at the moment.",
-        hint=FieldHint.core,
-    )
     rank: int = Field(
         default=8,
         desc="The LoRA rank, i.e. the size of the intermediate dimension.",
@@ -158,20 +194,15 @@ class PeftConfig(BaseModelConfig):
     )
 
     def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike":
-        if self.type == PeftType.none:
-            return linear
-        elif self.type == PeftType.lora:
-            from fast_llm.layers.common.peft import lora_linear
-
-            # TODO: Init method?
-            return lora_linear(
-                linear,
-                linear.weight.param_init_method,
-                linear.weight.param_init_method,
-                self.rank,
-                self.alpha,
-                self.dropout,
-                **kwargs,
-            )
-        else:
-            raise NotImplementedError(self.type)
+        from fast_llm.layers.common.peft import lora_linear
+
+        # TODO: Init method?
+        return lora_linear(
+            linear,
+            linear.weight.param_init_method,
+            linear.weight.param_init_method,
+            self.rank,
+            self.alpha,
+            self.dropout,
+            **kwargs,
+        )
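
Note on the recurring pattern in this file: the NormalizationType and PeftType enums (and their if/elif dispatch) are replaced by a registry of config subclasses keyed by a "type" string, with _from_dict falling back to a default subclass (LayerNormalizationConfig) when no known "type" is present, and each subclass resolving its concrete layer through get_layer / module_class. Below is a minimal, self-contained sketch of that idea; the *Sketch classes and the __init_subclass__-based registry are illustrative stand-ins, not the actual @config_class(dynamic_type=...) machinery.

import abc
import typing

import torch


class NormalizationConfigSketch(abc.ABC):
    # Maps a "type" string to the registered config subclass.
    _registry: typing.ClassVar[dict[str, type["NormalizationConfigSketch"]]] = {}

    def __init_subclass__(cls, type_: str | None = None, **kwargs):
        super().__init_subclass__(**kwargs)
        if type_ is not None:
            NormalizationConfigSketch._registry[type_] = cls

    def __init__(self, epsilon: float = 1e-5):
        self.epsilon = epsilon

    @classmethod
    def from_dict(cls, config: dict[str, typing.Any]) -> "NormalizationConfigSketch":
        # Missing or unknown "type" falls back to the default (layer_norm) subclass,
        # mirroring the _from_dict fallback to LayerNormalizationConfig in the diff.
        subclass = cls._registry.get(config.get("type", ""), LayerNormSketch)
        return subclass(epsilon=config.get("epsilon", 1e-5))

    @abc.abstractmethod
    def get_layer(self, hidden_size: int) -> torch.nn.Module: ...


class NoNormSketch(NormalizationConfigSketch, type_="none"):
    def get_layer(self, hidden_size: int) -> torch.nn.Module:
        return torch.nn.Identity()


class LayerNormSketch(NormalizationConfigSketch, type_="layer_norm"):
    def get_layer(self, hidden_size: int) -> torch.nn.Module:
        return torch.nn.LayerNorm(hidden_size, eps=self.epsilon)


class RMSNormSketch(NormalizationConfigSketch, type_="rms_norm"):
    def get_layer(self, hidden_size: int) -> torch.nn.Module:
        # torch.nn.RMSNorm requires a recent PyTorch (>= 2.4).
        return torch.nn.RMSNorm(hidden_size, eps=self.epsilon)


norm = NormalizationConfigSketch.from_dict({"type": "rms_norm"}).get_layer(1024)

The same split is applied to PEFT above: the type: PeftType field disappears, NoPeftConfig.apply_linear returns the linear layer unchanged, and LoRAConfig.apply_linear wraps it with lora_linear.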

fast_llm/layers/language_model/config.py

Lines changed: 2 additions & 1 deletion

@@ -6,6 +6,7 @@
 from fast_llm.engine.distributed.config import DistributedDimNames
 from fast_llm.functional.config import CrossEntropyImpl
 from fast_llm.layers.transformer.config import TransformerConfig
+from fast_llm.layers.transformer.rotary.config import NoRotaryConfig
 from fast_llm.utils import Assert
 
 
@@ -179,7 +180,7 @@ def _validate(self) -> None:
         self.transformer.validate()
         with self._set_implicit_default():
             if self.use_position_embeddings is None:
-                self.use_position_embeddings = not self.transformer.rotary.enabled
+                self.use_position_embeddings = isinstance(self.transformer.rotary, NoRotaryConfig)
             if self.init_method_std_embed is None:
                 self.init_method_std_embed = self.transformer.init_method_std
             if self.init_method_max_embed is None:
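
The default for use_position_embeddings follows the same shift away from an `enabled` flag: absolute position embeddings are only turned on by default when the rotary config is the "none" variant. A tiny sketch of that defaulting rule, with placeholder classes standing in for the configs under fast_llm/layers/transformer/rotary/config.py:

# Placeholder config classes; only the isinstance-based default is the point here.
class RotaryConfig: ...


class NoRotaryConfig(RotaryConfig): ...


def default_use_position_embeddings(rotary: RotaryConfig) -> bool:
    # Learned absolute position embeddings are the default only when
    # rotary embeddings are disabled (rotary config of type "none").
    return isinstance(rotary, NoRotaryConfig)


assert default_use_position_embeddings(NoRotaryConfig())
assert not default_use_position_embeddings(RotaryConfig())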

fast_llm/layers/transformer/attention.py

Lines changed: 12 additions & 14 deletions

@@ -6,8 +6,6 @@
 from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim
 from fast_llm.engine.config_utils.tensor_space import TensorSpace
 from fast_llm.functional.autograd import wrap_forward_backward
-from fast_llm.functional.rotary import apply_rotary_embeddings
-from fast_llm.functional.triton.rotary import triton_rotary_autograd_
 from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear
 from fast_llm.layers.transformer.config import (
     TransformerConfig,
@@ -134,6 +132,9 @@ def __init__(
         )
         self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward)
 
+        # Rotary embeddings.
+        self._rotary = self._config.rotary.build()
+
         # Output.
         self.dense = InputParallelLinear(
             self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense),
@@ -340,18 +341,15 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
         key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels)
         value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels)
 
-        if self._config.rotary.enabled:
-            if self._debug_transformer:
-                self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs)
-                self._debug_log(
-                    key,
-                    "key_rotary_input",
-                    self._KV_DIMS,
-                    kwargs,
-                )
-            rotary_fn = triton_rotary_autograd_ if self._config.rotary.triton else apply_rotary_embeddings
-            query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q])
-            key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k])
+        if self._debug_transformer:
+            self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs)
+            self._debug_log(
+                key,
+                "key_rotary_input",
+                self._KV_DIMS,
+                kwargs,
+            )
+        query, key = self._rotary(query, key, kwargs)
 
         window_size = self._decide_window_size()
 
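
Attention no longer branches on rotary.enabled or picks between the Triton and eager kernels itself: it builds a rotary module once from the config (self._config.rotary.build()) and always calls query, key = self._rotary(query, key, kwargs). A rough sketch of what such an interface can look like is below; the NoRotary/SimpleRotary classes, the kwargs keys, and the complex-multiplication layout are assumptions for illustration, not the new fast_llm.layers.transformer.rotary module (which may use non-interleaved pairs or Triton kernels).

import torch


class NoRotary(torch.nn.Module):
    # The "none" variant is a no-op, so callers never need an `if enabled:` branch.
    def forward(self, query: torch.Tensor, key: torch.Tensor, kwargs: dict):
        return query, key


class SimpleRotary(torch.nn.Module):
    # Applies precomputed complex frequencies; "rotary_freq_q"/"rotary_freq_k" are
    # placeholder keys mirroring TransformerKwargs.rotary_freq_q / rotary_freq_k.
    def forward(self, query: torch.Tensor, key: torch.Tensor, kwargs: dict):
        return self._apply(query, kwargs["rotary_freq_q"]), self._apply(key, kwargs["rotary_freq_k"])

    @staticmethod
    def _apply(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, heads, kv_channels); freqs: complex, broadcastable to
        # (batch, seq, 1, kv_channels // 2), e.g. produced with torch.polar.
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        return torch.view_as_real(x_complex * freqs).flatten(-2).type_as(x)


# Usage, matching the call site in Attention.forward:
query = torch.randn(2, 16, 8, 64)
key = torch.randn(2, 16, 2, 64)
freqs = torch.polar(torch.ones(1, 16, 1, 32), torch.rand(1, 16, 1, 32))
query, key = SimpleRotary()(query, key, {"rotary_freq_q": freqs, "rotary_freq_k": freqs})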

0 commit comments