fast_llm/config.py: 2 changes (1 addition, 1 deletion)
@@ -988,7 +988,7 @@ def __init_subclass__(cls):
)


class Configurable[ConfigType: Config]:
class Configurable[ConfigType: Config](abc.ABC):
config_class: typing.ClassVar[type[Config]] = Config

def __init__(self, config: ConfigType, *args, **kwargs):
fast_llm/functional/rotary.py: 24 changes (0 additions, 24 deletions)

This file was deleted.

fast_llm/layers/common/config.py: 139 changes (85 additions, 54 deletions)
@@ -1,3 +1,4 @@
import abc
import enum
import typing

@@ -6,6 +7,8 @@
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
import torch

from fast_llm.engine.config_utils.tensor_space import TensorDim
from fast_llm.layers.common.linear import LinearBase, LinearLike
from fast_llm.layers.common.normalization import LayerNorm, RMSNorm
@@ -35,26 +38,42 @@ class NormalizationImplementation(str, enum.Enum):
triton = "triton"


class NormalizationType(str, enum.Enum):
"""
An enum for the available normalization layers.
TODO: Add no_norm type?
"""
@config_class(registry=True)
class NormalizationConfig(BaseModelConfig):
pass

layer_norm = "layer_norm"
rms_norm = "rms_norm"
@abc.abstractmethod
def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module":
pass

@classmethod
def _from_dict(
cls,
default: dict[str, typing.Any],
strict: bool = True,
flat: bool = False,
) -> typing.Self:
if cls is NormalizationConfig and cls.get_subclass(default.get("type")) is None:
# Default subclass.
return LayerNormalizationConfig._from_dict(default, strict, flat)
return super()._from_dict(default, strict=strict, flat=flat)

@config_class(registry=True)
class NormalizationConfig(BaseModelConfig):

@config_class(dynamic_type={NormalizationConfig: "none"})
class NoNormalizationConfig(NormalizationConfig):
_abstract = False

# Normalization type
type: NormalizationType = Field(
default=NormalizationType.layer_norm,
desc="The type of normalization to use, for example Layer Norm or RMS Norm.",
hint=FieldHint.architecture,
)
@abc.abstractmethod
[Review comment, Contributor] @abc.abstractmethod is not needed here?

def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module":
return torch.nn.Identity()


@config_class()
class LayerNormalizationBaseConfig(NormalizationConfig):
"""
Common configuration for layer norm and rms norm
"""

# TODO: Rename to normalization_epsilon
epsilon: float = Field(
default=1e-5,
@@ -81,7 +100,6 @@ class NormalizationConfig(BaseModelConfig):
)

def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm":
from fast_llm.layers.common.normalization import LayerNorm, RMSNorm
from fast_llm.tensor import init_uniform_

kwargs = {
@@ -96,14 +114,12 @@ def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "
kwargs["weight_init_method"] = init_uniform_(
mean - self.initialization_range, mean + self.initialization_range
)
if self.type == NormalizationType.layer_norm:
if self.initialization_range:
kwargs["bias_init_method"] = init_uniform_(-self.initialization_range, self.initialization_range)
return LayerNorm(**kwargs)
elif self.type == NormalizationType.rms_norm:
return RMSNorm(**kwargs)
else:
raise ValueError(self.type)
return self.module_class(**kwargs)

@property
@abc.abstractmethod
def module_class(self):
pass

@classmethod
def _from_dict(
@@ -120,27 +136,47 @@ def _from_dict(
return super()._from_dict(default, strict, flat)


for name in NormalizationType:
# We need this because we are using the reserved field name `type`.
# TODO: Implement proper dynamic typing.
NormalizationConfig.register_subclass(name.value, NormalizationConfig)
@config_class(dynamic_type={NormalizationConfig: "layer_norm"})
class LayerNormalizationConfig(LayerNormalizationBaseConfig):
_abstract = False

@property
def module_class(self):
from fast_llm.layers.common.normalization import LayerNorm

return LayerNorm

class PeftType(str, enum.Enum):
# TODO : Use a dynamic config type instead.
none = "none"
lora = "lora"

@config_class(dynamic_type={NormalizationConfig: "rms_norm"})
class RMSNormalizationConfig(LayerNormalizationBaseConfig):
_abstract = False

@property
def module_class(self):
from fast_llm.layers.common.normalization import RMSNorm

return RMSNorm
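
Note: a minimal usage sketch of the new dynamic "type" registry (not part of the diff). It assumes Config.from_dict is the public wrapper around _from_dict and that configs can be built from plain dicts; hidden_dim is a placeholder TensorDim.

# Sketch only: selecting a normalization layer through the dynamic registry.
# "layer_norm", "rms_norm" and "none" are the keys registered above via
# @config_class(dynamic_type={NormalizationConfig: ...}).
norm_config = NormalizationConfig.from_dict({"type": "rms_norm", "epsilon": 1e-6})
assert isinstance(norm_config, RMSNormalizationConfig)

# Omitting "type" falls back to LayerNormalizationConfig (see NormalizationConfig._from_dict).
default_config = NormalizationConfig.from_dict({})

# get_layer builds the torch module for the hidden dimension.
norm = norm_config.get_layer(hidden_dim)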


@config_class()
class PeftConfig(BaseModelConfig):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no default handling (from_dict)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't really use that one and it doesn't have a registry, these are in TransformerPeftConfig instead.

@abc.abstractmethod
def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike":
pass


@config_class()
class NoPeftConfig(PeftConfig):
_abstract = False

def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike":
return linear


@config_class()
class LoRAConfig(PeftConfig):
_abstract = False

type: PeftType = Field(
default=PeftType.none,
desc="The type of parameter-efficient fine tuning to use Only LoRA is supported at the moment.",
hint=FieldHint.core,
)
rank: int = Field(
default=8,
desc="The LoRA rank, i.e. the size of the intermediate dimension.",
@@ -158,20 +194,15 @@ class PeftConfig(BaseModelConfig):
)

def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike":
if self.type == PeftType.none:
return linear
elif self.type == PeftType.lora:
from fast_llm.layers.common.peft import lora_linear

# TODO: Init method?
return lora_linear(
linear,
linear.weight.param_init_method,
linear.weight.param_init_method,
self.rank,
self.alpha,
self.dropout,
**kwargs,
)
else:
raise NotImplementedError(self.type)
from fast_llm.layers.common.peft import lora_linear

# TODO: Init method?
return lora_linear(
linear,
linear.weight.param_init_method,
linear.weight.param_init_method,
self.rank,
self.alpha,
self.dropout,
**kwargs,
)
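
Note: a rough usage sketch of the split PeftConfig hierarchy (not part of the diff), based only on the classes above. Direct keyword construction of the configs and the some_linear variable are assumptions for illustration.

# Sketch only: the config subclass now selects the behavior instead of a PeftType enum.
no_peft = NoPeftConfig()
lora = LoRAConfig(rank=8, alpha=16.0, dropout=0.0)

# some_linear is assumed to be a LinearBase instance (e.g. a key/value projection).
assert no_peft.apply_linear(some_linear) is some_linear  # no-op path
wrapped = lora.apply_linear(some_linear)  # wrapped with lora_linear
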
fast_llm/layers/language_model/config.py: 3 changes (2 additions, 1 deletion)
@@ -6,6 +6,7 @@
from fast_llm.engine.distributed.config import DistributedDimNames
from fast_llm.functional.config import CrossEntropyImpl
from fast_llm.layers.transformer.config import TransformerConfig
from fast_llm.layers.transformer.rotary.config import NoRotaryConfig
from fast_llm.utils import Assert


@@ -179,7 +180,7 @@ def _validate(self) -> None:
self.transformer.validate()
with self._set_implicit_default():
if self.use_position_embeddings is None:
self.use_position_embeddings = not self.transformer.rotary.enabled
self.use_position_embeddings = isinstance(self.transformer.rotary, NoRotaryConfig)
if self.init_method_std_embed is None:
self.init_method_std_embed = self.transformer.init_method_std
if self.init_method_max_embed is None:
fast_llm/layers/transformer/attention.py: 26 changes (12 additions, 14 deletions)
@@ -6,8 +6,6 @@
from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim
from fast_llm.engine.config_utils.tensor_space import TensorSpace
from fast_llm.functional.autograd import wrap_forward_backward
from fast_llm.functional.rotary import apply_rotary_embeddings
from fast_llm.functional.triton.rotary import triton_rotary_autograd_
from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear
from fast_llm.layers.transformer.config import (
TransformerConfig,
@@ -134,6 +132,9 @@ def __init__(
)
self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward)

# Rotary embeddings.
self._rotary = self._config.rotary.build()

# Output.
self.dense = InputParallelLinear(
self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense),
@@ -340,18 +341,15 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels)
value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels)

if self._config.rotary.enabled:
if self._debug_transformer:
self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs)
self._debug_log(
key,
"key_rotary_input",
self._KV_DIMS,
kwargs,
)
rotary_fn = triton_rotary_autograd_ if self._config.rotary.triton else apply_rotary_embeddings
query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q])
key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k])
if self._debug_transformer:
self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs)
self._debug_log(
key,
"key_rotary_input",
self._KV_DIMS,
kwargs,
)
query, key = self._rotary(query, key, kwargs)

window_size = self._decide_window_size()

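
Note: the rotary module built by self._config.rotary.build() is only seen here through its call site, query, key = self._rotary(query, key, kwargs); the real implementations live under fast_llm.layers.transformer.rotary (see the NoRotaryConfig import in language_model/config.py) and are not shown in this diff. A hypothetical no-op module matching that interface would look roughly like:

import torch


class NoRotary(torch.nn.Module):
    # Hypothetical stand-in for whatever NoRotaryConfig.build() returns;
    # the actual class name and base are assumptions.
    def forward(
        self, query: torch.Tensor, key: torch.Tensor, kwargs: dict
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # No rotary embedding: queries and keys pass through unchanged.
        return query, key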