Conversation
2. Patch FP8, as torchtitan has a bug that may lead to an error.
3. Fix CP init; we now pass the correct position_ids to the transformer so it can do CP.
Walkthrough
The pull request updates several files to improve code clarity through formatting changes.
Actionable comments posted: 1
🧹 Nitpick comments (2)
train.py (2)
339-353: Track temporary float8 workaround.
A temporary workaround has been added to handle a bug in torchtitan that affects float8 processing when float8 is not enabled and torchao is not installed.
Would you like me to create an issue to track this temporary workaround and ensure it's removed once the bug is fixed in torchtitan?
619-621: Track float8 precomputation workaround.
The float8 dynamic scale precomputation for FSDP is part of the temporary workaround for the torchtitan bug.
Would you like me to create an issue to track this temporary workaround and ensure it's updated once the bug is fixed in torchtitan?
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- flame/checkpoint.py (2 hunks)
- flame/parallelisms/pipeline_fla.py (2 hunks)
- train.py (16 hunks)
✅ Files skipped from review due to trivial changes (2)
- flame/checkpoint.py
- flame/parallelisms/pipeline_fla.py
🧰 Additional context used
🪛 Ruff (0.8.2)
train.py
30-30: torchtitan.model_converter.build_model_converters imported but unused
Remove unused import: torchtitan.model_converter.build_model_converters
(F401)
🪛 GitHub Actions: pr
train.py
[error] 19-19: flake8: 'torchtitan.model_converter.build_model_converters' imported but unused (F401)
🔇 Additional comments (1)
train.py (1)
471-494: LGTM! Improved logging clarity.
The logging changes enhance readability and provide more comprehensive training progress information, including:
- Sequence length
- Gradient accumulation steps
- Batch sizes
- Total optimization steps
- Warmup steps
- Model parameters
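As a rough illustration of how those logged quantities relate, here is a back-of-the-envelope sketch (the function and variable names are illustrative, not the ones train.py actually uses):

```python
def training_summary(token_budget, seq_len, micro_batch_size,
                     grad_accum_steps, dp_degree):
    """Relate the logged training quantities to each other.

    One optimization step consumes one micro-batch per gradient
    accumulation step on each data-parallel rank.
    """
    tokens_per_opt_step = seq_len * micro_batch_size * grad_accum_steps * dp_degree
    total_opt_steps = token_budget // tokens_per_opt_step
    return tokens_per_opt_step, total_opt_steps

per_step, total = training_summary(
    token_budget=1_000_000, seq_len=2048,
    micro_batch_size=1, grad_accum_steps=4, dp_degree=8)
# per_step = 2048 * 1 * 4 * 8 = 65536; total = 1_000_000 // 65536 = 15
```

Logging all of these together, as the change does, lets a reader sanity-check any one value against the others.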
```python
from flame.utils import device_module, device_type
from torchtitan.float8 import Float8Converter
from torchtitan.logging import init_logger, logger
from torchtitan.model_converter import build_model_converters
```
Remove unused import.
The import build_model_converters from torchtitan.model_converter is not used in the code.
Apply this diff to fix the issue:
```diff
-from torchtitan.model_converter import build_model_converters
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
 from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig
 from torchtitan.logging import init_logger, logger
-from torchtitan.optimizer import OptimizersContainer, SchedulersContainer
+from torchtitan.optimizer import OptimizersContainer, LRSchedulersContainer
```
```diff
 model_parts: List[nn.Module],
 optimizers: OptimizersContainer,
-lr_schedulers: SchedulersContainer,
+lr_schedulers: LRSchedulersContainer,
```
```diff
-from torchtitan.parallelisms.pipelining_utils import (build_pipeline_schedule,
-                                                      generate_split_points,
-                                                      stage_ids_this_rank)
+from torchtitan.parallelisms.pipeline import (
```
```diff
 from flame.parallelisms.pipeline_fla import pipeline_fla
 from flame.utils import device_module, device_type
-from torchtitan.float8 import Float8Handler
+from torchtitan.float8 import Float8Converter
```
```diff
-# swap to Float8Linear based on float8 configs
-float8_handler.convert_to_float8_training(model)
-"""
+# !TODO[flame]: torchtitan@57387af0e0e6173e7c0f3a38ac5db1134bb376d5 introduces a bug that cannot handle the case:
```
Actual changes.
If FP8 is not enabled, it is better to skip the conversion to FP8 entirely: if torchao is not installed, the Float8Converter may throw an error even though FP8 is disabled.
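A minimal sketch of the guard this change amounts to (the function name and flag are hypothetical, not flame's actual API):

```python
def maybe_convert_to_float8(model, float8_enabled):
    """Apply float8 conversion only when it is explicitly enabled.

    Returning early keeps the disabled path free of any float8
    machinery, so a missing torchao install can never raise here.
    """
    if not float8_enabled:
        return model  # no-op: model is left untouched
    try:
        import torchao  # noqa: F401 -- float8 kernels live in torchao
    except ImportError as exc:
        raise RuntimeError(
            "float8 training was requested but torchao is not installed"
        ) from exc
    # ... the torchtitan-specific Float8 conversion would happen here ...
    return model
```

The key point is that the import only happens on the enabled path, so environments without torchao are unaffected when FP8 is off.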
```diff
-cp_buffers=[input_ids, labels] + [m.freqs_cis for m in model_parts],
-cp_seq_dims=[1, 1] + [0 for _ in model_parts],
-cp_no_restore_buffers={input_ids, labels},
+cp_buffers=[input_ids, labels, position_ids],
```
Actual changes.
position_ids can now be distributed via CP. A model in the HF-Llama style can take position_ids and apply the correct RoPE. (We still need a fix in FLA-Attention to correctly handle position_ids.)
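To illustrate why position_ids must be distributed alongside the other CP buffers, here is a plain-Python sketch of a contiguous sequence-dimension split (simplified and hypothetical: PyTorch's context-parallel utilities operate on tensors and may use a load-balanced, non-contiguous split):

```python
def shard_positions(position_ids, cp_rank, cp_world_size):
    """Return the contiguous slice of position_ids owned by one CP rank.

    Because each rank keeps its absolute positions, an HF-Llama style
    model can apply RoPE correctly on its local shard of the sequence.
    """
    seq_len = len(position_ids)
    assert seq_len % cp_world_size == 0, "sequence must split evenly"
    chunk = seq_len // cp_world_size
    return position_ids[cp_rank * chunk:(cp_rank + 1) * chunk]

# With seq_len=8 and 4 CP ranks, rank 1 owns positions [2, 3].
print(shard_positions(list(range(8)), cp_rank=1, cp_world_size=4))
```

If position_ids were not in cp_buffers, every rank would see positions starting at 0 and rotary embeddings would be wrong for all but the first shard.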