Dynamic transformer #275
Changes from all commits: 8ce8674, 5513e48, 3971464, 9eb745c, 62abf27, 619a0da, 912431c, 3c8cf0b
@@ -1,3 +1,4 @@
import abc
import enum
import typing

@@ -6,6 +7,8 @@
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
    import torch

    from fast_llm.engine.config_utils.tensor_space import TensorDim
    from fast_llm.layers.common.linear import LinearBase, LinearLike
    from fast_llm.layers.common.normalization import LayerNorm, RMSNorm
@@ -35,26 +38,42 @@ class NormalizationImplementation(str, enum.Enum):
    triton = "triton"


class NormalizationType(str, enum.Enum):
    """
    An enum for the available normalization layers.
    TODO: Add no_norm type?
    """


@config_class(registry=True)
class NormalizationConfig(BaseModelConfig):
    pass

    layer_norm = "layer_norm"
    rms_norm = "rms_norm"

    @abc.abstractmethod
    def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module":
        pass

    @classmethod
    def _from_dict(
        cls,
        default: dict[str, typing.Any],
        strict: bool = True,
        flat: bool = False,
    ) -> typing.Self:
        if cls is NormalizationConfig and cls.get_subclass(default.get("type")) is None:
            # Default subclass.
            return LayerNormalizationConfig._from_dict(default, strict, flat)
        return super()._from_dict(default, strict=strict, flat=flat)


@config_class(registry=True)
class NormalizationConfig(BaseModelConfig):


@config_class(dynamic_type={NormalizationConfig: "none"})
class NoNormalizationConfig(NormalizationConfig):
    _abstract = False

    # Normalization type
    type: NormalizationType = Field(
        default=NormalizationType.layer_norm,
        desc="The type of normalization to use, for example Layer Norm or RMS Norm.",
        hint=FieldHint.architecture,
    )

    @abc.abstractmethod
    def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module":
        return torch.nn.Identity()


@config_class()
class LayerNormalizationBaseConfig(NormalizationConfig):
    """
    Common configuration for layer norm and rms norm
    """

    # TODO: Rename to normalization_epsilon
    epsilon: float = Field(
        default=1e-5,
@@ -81,7 +100,6 @@ class NormalizationConfig(BaseModelConfig):
    )

    def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm":
        from fast_llm.layers.common.normalization import LayerNorm, RMSNorm
        from fast_llm.tensor import init_uniform_

        kwargs = {
@@ -96,14 +114,12 @@ def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "
            kwargs["weight_init_method"] = init_uniform_(
                mean - self.initialization_range, mean + self.initialization_range
            )
        if self.type == NormalizationType.layer_norm:
            if self.initialization_range:
                kwargs["bias_init_method"] = init_uniform_(-self.initialization_range, self.initialization_range)
            return LayerNorm(**kwargs)
        elif self.type == NormalizationType.rms_norm:
            return RMSNorm(**kwargs)
        else:
            raise ValueError(self.type)
        return self.module_class(**kwargs)

    @property
    @abc.abstractmethod
    def module_class(self):
        pass
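The type-based branching in `get_layer` is replaced by a single `self.module_class(**kwargs)` call, so each concrete config only declares which module it builds. A minimal standalone sketch of that pattern, using `torch.nn.LayerNorm` as a stand-in module; these names and kwargs are illustrative, not Fast-LLM's actual classes or the real kwargs assembled above:

    import abc
    import torch

    class _NormConfigSketch(abc.ABC):
        def __init__(self, epsilon: float = 1e-5):
            self.epsilon = epsilon

        def get_layer(self, hidden_size: int) -> torch.nn.Module:
            # Shared kwargs assembly happens once here...
            kwargs = {"normalized_shape": hidden_size, "eps": self.epsilon}
            # ...and the subclass only decides which module class to instantiate.
            return self.module_class(**kwargs)

        @property
        @abc.abstractmethod
        def module_class(self) -> type[torch.nn.Module]: ...

    class _LayerNormConfigSketch(_NormConfigSketch):
        @property
        def module_class(self) -> type[torch.nn.Module]:
            return torch.nn.LayerNorm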
    @classmethod
    def _from_dict(
@@ -120,27 +136,47 @@ def _from_dict(
        return super()._from_dict(default, strict, flat)


for name in NormalizationType:
    # We need this because we are using the reserved field name `type`.
    # TODO: Implement proper dynamic typing.
    NormalizationConfig.register_subclass(name.value, NormalizationConfig)


@config_class(dynamic_type={NormalizationConfig: "layer_norm"})
class LayerNormalizationConfig(LayerNormalizationBaseConfig):
    _abstract = False

    @property
    def module_class(self):
        from fast_llm.layers.common.normalization import LayerNorm

        return LayerNorm
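The explicit `register_subclass` loop above is replaced by the `dynamic_type` argument of `@config_class`. A rough standalone sketch of the registry idea this relies on, purely illustrative and not Fast-LLM's actual `config_class` implementation:

    _REGISTRY: dict[type, dict[str, type]] = {}

    def register_dynamic_type(base: type, name: str):
        # Hypothetical decorator standing in for `@config_class(dynamic_type={base: name})`.
        def decorator(cls: type) -> type:
            _REGISTRY.setdefault(base, {})[name] = cls  # e.g. "layer_norm" -> LayerNormalizationConfig
            return cls
        return decorator

    def get_subclass(base: type, name: str | None) -> type | None:
        # Mirrors how `cls.get_subclass(default.get("type"))` may return None,
        # which `_from_dict` above uses to fall back to the default subclass.
        return _REGISTRY.get(base, {}).get(name)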
class PeftType(str, enum.Enum):
    # TODO : Use a dynamic config type instead.
    none = "none"
    lora = "lora"


@config_class(dynamic_type={NormalizationConfig: "rms_norm"})
class RMSNormalizationConfig(LayerNormalizationBaseConfig):
    _abstract = False

    @property
    def module_class(self):
        from fast_llm.layers.common.normalization import RMSNorm

        return RMSNorm
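Taken together, the registered keys ("none", "layer_norm", "rms_norm") plus the `_from_dict` fallback mean the concrete config class is selected from the serialized `type` key. A hedged usage sketch; `NormalizationConfig.from_dict` as the public entry point routing through `_from_dict` is an assumption here:

    # Assumed entry point; keys are the dynamic types registered in the diff above.
    layer_norm = NormalizationConfig.from_dict({"epsilon": 1e-6})    # no "type": falls back to LayerNormalizationConfig
    rms_norm = NormalizationConfig.from_dict({"type": "rms_norm"})   # registered key: RMSNormalizationConfig
    no_norm = NormalizationConfig.from_dict({"type": "none"})        # NoNormalizationConfig; get_layer returns torch.nn.Identity()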
@config_class()
class PeftConfig(BaseModelConfig):
Contributor: no default handling (from_dict)?

Collaborator (Author): We don't really use that one and it doesn't have a registry, these are in
    @abc.abstractmethod
    def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike":
        pass


@config_class()
class NoPeftConfig(PeftConfig):
    _abstract = False

    def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike":
        return linear


@config_class()
class LoRAConfig(PeftConfig):
    _abstract = False

    type: PeftType = Field(
        default=PeftType.none,
        desc="The type of parameter-efficient fine tuning to use. Only LoRA is supported at the moment.",
        hint=FieldHint.core,
    )
    rank: int = Field(
        default=8,
        desc="The LoRA rank, i.e. the size of the intermediate dimension.",

@@ -158,20 +194,15 @@ class PeftConfig(BaseModelConfig):
    )

    def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike":
        if self.type == PeftType.none:
            return linear
        elif self.type == PeftType.lora:
            from fast_llm.layers.common.peft import lora_linear

            # TODO: Init method?
            return lora_linear(
                linear,
                linear.weight.param_init_method,
                linear.weight.param_init_method,
                self.rank,
                self.alpha,
                self.dropout,
                **kwargs,
            )
        else:
            raise NotImplementedError(self.type)
        from fast_llm.layers.common.peft import lora_linear

        # TODO: Init method?
        return lora_linear(
            linear,
            linear.weight.param_init_method,
            linear.weight.param_init_method,
            self.rank,
            self.alpha,
            self.dropout,
            **kwargs,
        )
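For context on what the `lora_linear` call wraps: a minimal PyTorch sketch of the standard LoRA computation with the same rank/alpha/dropout knobs. This illustrates the technique only; it is not the implementation in fast_llm.layers.common.peft.

    import torch

    class LoRALinearSketch(torch.nn.Module):
        # y = W x + (alpha / rank) * B(A(dropout(x))), with the base weight frozen.
        def __init__(self, linear: torch.nn.Linear, rank: int = 8, alpha: float = 8.0, dropout: float = 0.0):
            super().__init__()
            self.linear = linear
            self.linear.weight.requires_grad_(False)  # base projection stays frozen
            self.lora_a = torch.nn.Linear(linear.in_features, rank, bias=False)
            self.lora_b = torch.nn.Linear(rank, linear.out_features, bias=False)
            torch.nn.init.zeros_(self.lora_b.weight)  # adapter starts as a no-op
            self.dropout = torch.nn.Dropout(dropout)
            self.scaling = alpha / rank

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x) + self.scaling * self.lora_b(self.lora_a(self.dropout(x)))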
Review comment: @abc.abstractmethod is not needed here?