-
Notifications
You must be signed in to change notification settings - Fork 14
compat mtp megatron_core main branch #92
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,9 +1,11 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Copyright (c) ModelScope Contributors. All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import transformer_engine | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import warnings | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from contextlib import nullcontext | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from functools import partial | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from megatron.core import InferenceParams | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from megatron.core import InferenceParams, parallel_state, tensor_parallel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from megatron.core.enums import Fp8Recipe | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from megatron.core.fp8_utils import get_fp8_context | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from megatron.core.packed_seq_params import PackedSeqParams | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gather_from_tensor_model_parallel_region, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -94,11 +96,6 @@ def forward( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-layer_number, dims=0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.config.recompute_granularity == 'full' and self.training: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states = self._checkpointed_forward( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| partial( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._proj_and_transformer_layer, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| packed_seq_params=packed_seq_params, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sequence_len_offset=sequence_len_offset, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states=hidden_states, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| decoder_input=decoder_input, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attention_mask=attention_mask, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -109,6 +106,8 @@ def forward( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_sin=rotary_pos_sin, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attention_bias=attention_bias, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference_params=inference_params, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| packed_seq_params=packed_seq_params, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sequence_len_offset=sequence_len_offset, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states = self._proj_and_transformer_layer( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -127,6 +126,161 @@ def forward( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return hidden_states, input_ids, position_ids, decoder_input | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Code borrowed from NVIDIA/Megatron-LM | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _checkpointed_forward( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| decoder_input: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attention_mask: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context_mask: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_emb: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_cos: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_sin: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attention_bias: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference_params: Optional[InferenceParams] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| packed_seq_params: Optional[PackedSeqParams] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sequence_len_offset: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Forward through ``_proj_and_transformer_layer`` with activation | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| recomputation. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Mirrors ``transformer_block._checkpointed_forward``: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * Non-tensor objects (``attention_bias``, ``inference_params``, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ``packed_seq_params``) are captured by the ``custom_forward`` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| closure; only tensor / ``None`` arguments flow positionally | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| through the underlying checkpoint primitive. This is required | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| by both backends: ``tensor_parallel.checkpoint`` because its | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ``save_for_backward`` only accepts tensors and ``None``, and | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ``te_checkpoint`` because its reentrant implementation only | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tracks positional tensor inputs as checkpoint inputs (kwarg | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tensors are not represented in the recompute backward path). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+165
to
+173
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The docstring incorrectly identifies
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * Quantized recipes (fp8, fp4) route through ``te_checkpoint``; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| everything else uses ``tensor_parallel.checkpoint``. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * Only ``fp8 + delayed scaling`` needs an outer quantization | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context entered before ``te_checkpoint``; see the | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ``outer_quantization_context`` block below. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def custom_forward( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| decoder_input, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attention_mask, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context_mask, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_emb, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_cos, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_sin, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sequence_len_offset, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self._proj_and_transformer_layer( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states=hidden_states, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| decoder_input=decoder_input, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attention_mask=attention_mask, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context=context, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context_mask=context_mask, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_emb=rotary_pos_emb, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_cos=rotary_pos_cos, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_sin=rotary_pos_sin, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attention_bias=attention_bias, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference_params=inference_params, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| packed_seq_params=packed_seq_params, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sequence_len_offset=sequence_len_offset, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The def custom_forward(hidden_states, decoder_input):
# Get appropriate inner quantization context
if self.config.fp8:
inner_quantization_context = get_fp8_context(self.config, self.layer_number - 1)
elif self.config.fp4:
inner_quantization_context = get_fp4_context(self.config, self.layer_number - 1)
else:
inner_quantization_context = nullcontext()
with inner_quantization_context:
return self._proj_and_transformer_layer(
hidden_states=hidden_states,
decoder_input=decoder_input,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
) |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Decide the outer quantization context, matching | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # ``transformer_block._checkpointed_forward``. Only ``fp8 + delayed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # scaling`` needs an active context at the ``te_checkpoint`` entry | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # point: TE's ``_CheckpointFunction.forward`` samples | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # ``FP8GlobalStateManager.is_fp8_enabled()`` there to gate the | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # phase-1 amax-buffer stash that phase-2 backward looks up via | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # ``global_fp8_buffer_pos_fwd_recompute``. With fp8 only entered | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # *inside* ``_proj_and_transformer_layer``, TE samples fp8 as off, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # phase-1 skips the stash, and phase-2 raises ``KeyError``. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Non-delayed fp8 recipes (MXFP8BlockScaling, Float8CurrentScaling) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # and fp4 (NVFP4BlockScaling) treat the stash/lookup as a noop, so | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # the inner context entered inside ``_proj_and_transformer_layer`` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # is sufficient. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| outer_quantization_context = get_fp8_context(self.config) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| outer_quantization_context = nullcontext() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def checkpoint_handler(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # fp4 quantization is internally implemented via TE's | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # ``fp8_autocast`` (see ``fp4_utils.get_fp4_context``), so | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # quantized recompute on either fp8 or fp4 must go through | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # ``te_checkpoint``. Matches ``transformer_block``'s policy. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.config.fp8 or self.config.fp4: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from megatron.core.extensions.transformer_engine import te_checkpoint | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return te_checkpoint( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| custom_forward, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.config.distribute_saved_activations, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tensor_parallel.random.get_cuda_rng_tracker, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| parallel_state.get_tensor_model_parallel_group(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| decoder_input, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attention_mask, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context_mask, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_emb, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_cos, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_sin, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sequence_len_offset, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # tensor_parallel.checkpoint stashes args via autograd's | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # ``save_for_backward``, which only accepts tensors and ``None``. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Pass tensor / ``None`` args positionally and capture the | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # non-tensor objects (``attention_bias``, ``inference_params``, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # ``packed_seq_params``) via the ``custom_forward`` closure. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return tensor_parallel.checkpoint( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| custom_forward, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.config.distribute_saved_activations, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| decoder_input, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attention_mask, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context_mask, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_emb, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_cos, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_sin, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sequence_len_offset, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+234
to
+265
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update the checkpoint calls to match the simplified return te_checkpoint(
custom_forward,
self.config.distribute_saved_activations,
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
hidden_states,
decoder_input,
)
else:
return tensor_parallel.checkpoint(
custom_forward,
self.config.distribute_saved_activations,
hidden_states,
decoder_input,
) |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.config.recompute_method == 'uniform': | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Uniformly divide the total number of Transformer layers and checkpoint | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # the input activation of each divided chunk. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # A method to further reduce memory usage reducing checkpoints. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert (self.config.recompute_num_layers == 1), 'recompute_num_layers must be 1 for MTP recompute' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This assertion might be too restrictive if
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with outer_quantization_context: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| outputs = checkpoint_handler() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif self.config.recompute_method == 'block': | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # TODO: implement block-based recompute for MTP | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| warnings.warn("recompute_method == 'block' is not supported for MTP yet." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ' Skipping recompute.') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| outputs = self._proj_and_transformer_layer( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states=hidden_states, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| decoder_input=decoder_input, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attention_mask=attention_mask, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context=context, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context_mask=context_mask, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_emb=rotary_pos_emb, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_cos=rotary_pos_cos, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rotary_pos_sin=rotary_pos_sin, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attention_bias=attention_bias, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference_params=inference_params, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| packed_seq_params=packed_seq_params, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sequence_len_offset=sequence_len_offset, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+274
to
+292
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError('Invalid activation recompute method.') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return outputs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _concat_embeddings(self, hidden_states: torch.Tensor, decoder_input: torch.Tensor): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Concatenate the tokens before sending to transformer layer. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To fully support FP4 quantization recomputation as referenced in the
_checkpointed_forwardlogic, you should also importget_fp4_context.