Dynamic transformer #275
Conversation
Some small comments here and there, but LGTM otherwise, thank you!
torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)])
for sample_lens in sequence_lengths
]
[torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths]
Won't this break the document_mask below?
Broken merge, good catch!
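For reference, a minimal sketch of what the two constructions produce (the sample_lens value and the document_mask line are assumptions based on the diff context, not the PR's exact code):

import torch

# Hypothetical per-sample document lengths, as in the diff context.
sample_lens = [3, 2]

# Original line: one document index per token; this is what a document mask needs.
doc_ids = torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)])
# tensor([0, 0, 0, 1, 1])

# Line from the diff: within-document position ids rather than document ids.
pos_ids = torch.cat([torch.arange(x) for x in sample_lens])
# tensor([0, 1, 2, 0, 1])

# A document mask built from document ids restricts attention to tokens of the same
# document; building it from position ids instead would not have that meaning.
document_mask = doc_ids[:, None] == doc_ids[None, :]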
    return ramp_func

def _get_correction(self, beta: float, dim: int) -> float:
    return math.floor(
The original implementation uses floor for low, but ceil for high.
LGTM thanks!
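For context, a sketch along the lines of the public YaRN reference implementation, where the low end of the correction range is floored and the high end is ceiled (function and parameter names follow that reference, not necessarily this PR's code):

import math


def find_correction_dim(num_rotations: float, dim: int, base: float = 10000.0, max_position_embeddings: int = 2048) -> float:
    # Inverse of the rotary frequency formula: which dimension completes
    # `num_rotations` full rotations over the original context length.
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))


def find_correction_range(low_rot: float, high_rot: float, dim: int, base: float = 10000.0, max_position_embeddings: int = 2048) -> tuple[int, int]:
    # floor for the low end, ceil for the high end, then clamp to valid dimensions.
    low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)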
desc="The type of normalization to use, for example Layer Norm or RMS Norm.", | ||
hint=FieldHint.architecture, | ||
) | ||
@abc.abstractmethod |
@abc.abstractmethod is not needed here?
@@ -194,66 +160,50 @@ class TransformerPeftConfig(PeftConfig):
    )
Not sure this is the right place to raise this concern, but I think it's better not to have layer freezing as part of PEFT, and to keep only a single layer-freezing mechanism (e.g. the current lr scaling for different components and layers).
Hence, if a user wants to use PEFT (e.g. LoRA) in conjunction with layer freezing, they can do it by explicitly freezing the layers with the lr_scale parameter.
This way we keep the LoRA/PEFT logic simple and disentangle parameter freezing from PEFT.
That's a fair point, but LoRA is almost always used together with freezing, so it's really more convenient to handle them together. Also, I don't think arbitrary lr scaling parameters are a good long-term solution, and anyway there is still only one freezing mechanism in the background, so it's not too bad.
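To illustrate the idea being discussed, a minimal PyTorch sketch of freezing via learning-rate scaling (the parameter grouping below is generic, not Fast-LLM's actual config):

import torch

# Scaling a parameter group's learning rate to 0 is equivalent to freezing it,
# so LoRA plus freezing can be expressed through the single lr-scaling mechanism.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))
base_lr = 1e-3
param_groups = [
    {"params": model[0].parameters(), "lr": base_lr * 0.0},  # lr_scale = 0 -> effectively frozen
    {"params": model[1].parameters(), "lr": base_lr * 1.0},  # lr_scale = 1 -> trained normally
]
optimizer = torch.optim.AdamW(param_groups)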
def module_class(self):
    from fast_llm.layers.common.normalization import RMSNorm

    return RMSNorm


@config_class()
class PeftConfig(BaseModelConfig):
no default handling (from_dict)?
We don't really use that one and it doesn't have a registry; these are in TransformerPeftConfig instead.
✨ Description
Adjust the rotary embeddings, PEFT and normalization layers to use the new dynamic classes. Do some cleanup and refactoring for rotary embeddings. Add an option to disable normalization layers, because why not.
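As a rough illustration of the dynamic-class idea, a generic sketch (not Fast-LLM's actual API; every class and field name here is illustrative): a base config keeps a registry of named subclasses, deserialization picks the concrete config from a "type" field, and each concrete config resolves its runtime module class lazily.

import torch.nn as nn


class NormalizationConfig:
    _registry: dict[str, type] = {}

    def __init_subclass__(cls, *, name=None, **kwargs):
        super().__init_subclass__(**kwargs)
        if name is not None:
            NormalizationConfig._registry[name] = cls

    @classmethod
    def from_dict(cls, config: dict) -> "NormalizationConfig":
        # Pick the concrete config class dynamically from the "type" field.
        return cls._registry[config.get("type", "layer_norm")]()

    @property
    def module_class(self) -> type[nn.Module]:
        raise NotImplementedError


class LayerNormConfig(NormalizationConfig, name="layer_norm"):
    @property
    def module_class(self) -> type[nn.Module]:
        return nn.LayerNorm


class RMSNormConfig(NormalizationConfig, name="rms_norm"):
    @property
    def module_class(self) -> type[nn.Module]:
        return nn.RMSNorm  # requires a recent PyTorch; the PR's diff imports Fast-LLM's own RMSNorm instead


class NoNormConfig(NormalizationConfig, name="none"):
    # Mirrors the "option to disable normalization layers" mentioned above.
    @property
    def module_class(self) -> type[nn.Module]:
        return nn.Identity


# Usage: the config chosen from a dict decides which layer class gets built.
norm_layer = NormalizationConfig.from_dict({"type": "rms_norm"}).module_class(1024)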
🔍 Type of change
Select all that apply: