Dynamic transformer #275

Merged · 8 commits merged into main on Jun 12, 2025

Conversation

@jlamypoirier (Collaborator) commented May 27, 2025

✨ Description

Adjust the rotary embedding, PEFT, and normalization layers to use the new dynamic classes. Do some cleanup and refactoring of the rotary embeddings. Add an option to disable normalization layers, because why not.

🔍 Type of change

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

@jlamypoirier marked this pull request as ready for review on May 28, 2025 00:06
@RaymondLi0 (Contributor) left a review:

Some small comments here and there, but LGTM otherwise, thank you!

    torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)])
    for sample_lens in sequence_lengths
]
[torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths]
@RaymondLi0 (Contributor):

Won't this break the document_mask below?

@jlamypoirier (Collaborator, Author):

Broken merge, good catch!
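
For context on the snippet above, here is a minimal sketch (with hypothetical variable names and data; not Fast-LLM's exact code) of what the two comprehensions compute and how a document mask of the kind mentioned in the comment is typically derived from them:

```python
import torch

# Hypothetical packed batch: the first sequence packs samples of length 3 and 2,
# the second a single sample of length 5.
sequence_lengths = [[3, 2], [5]]

# One document id per token: the index of the sample each token belongs to.
document_ids = [
    torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)])
    for sample_lens in sequence_lengths
]  # [tensor([0, 0, 0, 1, 1]), tensor([0, 0, 0, 0, 0])]

# One position per token, restarting at 0 for every sample.
position_ids = [
    torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths
]  # [tensor([0, 1, 2, 0, 1]), tensor([0, 1, 2, 3, 4])]

# A document mask restricts attention to tokens of the same sample,
# which is why it depends on the document ids rather than the positions.
doc = document_ids[0]
document_mask = doc[:, None] == doc[None, :]  # (5, 5) block-diagonal boolean mask
```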

        return ramp_func

    def _get_correction(self, beta: float, dim: int) -> float:
        return math.floor(
@RaymondLi0 (Contributor):

The original implementation uses floor for low, but ceil for high.
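
For reference, a sketch of that correction-range computation as it appears in the reference YaRN implementation (names and signatures here are illustrative, not Fast-LLM's): the low boundary is rounded down with floor, the high boundary is rounded up with ceil.

```python
import math

def correction_dim(num_rotations: float, dim: int, base: float, max_position: int) -> float:
    # Head-dimension index whose rotary frequency completes `num_rotations`
    # full rotations over `max_position` tokens.
    return dim * math.log(max_position / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

def correction_range(beta_fast: float, beta_slow: float, dim: int, base: float, max_position: int) -> tuple[int, int]:
    # Reference YaRN: floor for the low end, ceil for the high end, clamped to valid indices.
    low = math.floor(correction_dim(beta_fast, dim, base, max_position))
    high = math.ceil(correction_dim(beta_slow, dim, base, max_position))
    return max(low, 0), min(high, dim - 1)
```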

@jlamypoirier requested a review from RaymondLi0 on June 12, 2025 21:02
@RaymondLi0 (Contributor) left a review:

LGTM thanks!

desc="The type of normalization to use, for example Layer Norm or RMS Norm.",
hint=FieldHint.architecture,
)
@abc.abstractmethod
@RaymondLi0 (Contributor):

@abc.abstractmethod is not needed here?
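
For illustration, the pattern under discussion looks roughly like the following (class names are hypothetical; only the RMSNorm import is taken from the diff). The question is whether the abstract marker on the base declaration is needed at all.

```python
import abc

class NormalizationConfig(abc.ABC):  # hypothetical base class, for illustration only
    @abc.abstractmethod
    def module_class(self):
        """Each concrete config returns the layer class it builds."""

class RMSNormalizationConfig(NormalizationConfig):  # hypothetical concrete config
    def module_class(self):
        from fast_llm.layers.common.normalization import RMSNorm  # import as shown in the diff

        return RMSNorm
```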

@@ -194,66 +160,50 @@ class TransformerPeftConfig(PeftConfig):
)
@RaymondLi0 (Contributor):

Not sure this is the right place to raise this concern, but I think it's better not to have layer freezing as part of PEFT, and to keep only a single layer-freezing mechanism (e.g. the current lr scaling for different components and layers).

Hence, if a user wants to use PEFT (e.g. LoRA) in conjunction with layer freezing, they could do it by explicitly freezing the layers with the lr_scale parameter.

This way we keep the LoRA/PEFT logic simple and disentangle parameter freezing from PEFT.

@jlamypoirier (Collaborator, Author):

That's a fair point, but LoRA is almost always used together with freezing, so it's really more convenient to do them together. Also, I don't think the arbitrary lr-scaling parameters are a good long-term solution, and in any case there is still only one freezing mechanism in the background, so it's not too bad.
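
As a concrete illustration of that point (plain PyTorch, not Fast-LLM's PEFT implementation): a LoRA adapter is almost always wrapped around a frozen base layer, so the freezing step naturally sits next to the adapter logic.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Minimal LoRA wrapper: the base weights are frozen, only the low-rank update trains."""

    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for param in self.base.parameters():
            param.requires_grad = False  # freezing almost always accompanies LoRA
        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * (x @ self.lora_a.t() @ self.lora_b.t())
```

The alternative raised above would keep the wrapper adapter-only and leave all freezing to a single mechanism (e.g. an lr scale of zero on the frozen components).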

    def module_class(self):
        from fast_llm.layers.common.normalization import RMSNorm

        return RMSNorm


@config_class()
class PeftConfig(BaseModelConfig):
@RaymondLi0 (Contributor):

no default handling (from_dict)?

@jlamypoirier (Collaborator, Author):

We don't really use that one, and it doesn't have a registry; those are in TransformerPeftConfig instead.
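
For readers unfamiliar with the reference, "default handling (from_dict)" refers to dynamic config selection of roughly this shape (an illustrative sketch only, not Fast-LLM's actual API): a registry maps a type key to a concrete config class, with a default when the key is absent.

```python
_registry: dict[str, type] = {}

def register(name: str):
    # Record a concrete config class under a type name.
    def wrap(cls: type) -> type:
        _registry[name] = cls
        return cls
    return wrap

@register("none")
class NoPeftConfig:
    pass

@register("lora")
class LoRAPeftConfig:
    def __init__(self, rank: int = 8):
        self.rank = rank

def config_from_dict(data: dict):
    # Pick the concrete class from the "type" key, falling back to a default.
    cls = _registry[data.pop("type", "none")]
    return cls(**data)

peft = config_from_dict({"type": "lora", "rank": 4})  # -> LoRAPeftConfig(rank=4)
```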

@jlamypoirier merged commit 016a308 into main on Jun 12, 2025
3 of 4 checks passed
@jlamypoirier deleted the dynamic_transformer branch on June 12, 2025 22:12