Is your feature request related to a problem? Please describe.
In GPTSFTChatDataset::collate_fn, a very large attention_mask tensor is returned (code: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py#L380-L381).
The tensor has shape (n, 1, seq, seq), where n is global_batch_size divided by data_parallel_size (i.e., micro_batch_size * gradient_accumulation_steps, if my understanding is correct) and seq is the maximum context length in the batch. The tensor therefore becomes very large when the context is long and gradient_accumulation_steps is high.
Much of this memory is wasted: all n subtensors of shape (1, seq, seq) are identical, so the tensor could be shrunk to 1/n of its current size. This would help a lot with long-context training.
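To make the redundancy and the sizes concrete, here is a minimal PyTorch sketch. The numbers are illustrative only, and it assumes the per-sample mask is the standard causal (lower-triangular) mask, which is what makes every slice identical:

```python
import torch

# Illustrative numbers only, not taken from any real config.
micro_batch_size = 2
grad_acc_steps = 4
n = micro_batch_size * grad_acc_steps   # global_batch_size // data_parallel_size
seq = 2048

# Assuming the per-sample mask is the usual causal (lower-triangular) mask,
# every (1, seq, seq) slice along the batch dimension is identical.
single_mask = torch.tril(torch.ones(seq, seq, dtype=torch.bool)).unsqueeze(0)  # (1, seq, seq)
batched_mask = single_mask.unsqueeze(0).repeat(n, 1, 1, 1)                     # (n, 1, seq, seq)

assert all(torch.equal(batched_mask[i], single_mask) for i in range(n))  # n identical copies

bytes_batched = batched_mask.numel() * batched_mask.element_size()
bytes_single = single_mask.numel() * single_mask.element_size()
print(f"{bytes_batched / 2**20:.0f} MiB batched vs {bytes_single / 2**20:.0f} MiB single")
# At seq = 32768 and n = 32, the same ratio means ~32 GiB vs ~1 GiB per data-parallel rank.
```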
Describe the solution you'd like
I'm not very familiar with the codebase, so my thought could be wrong here. In GPTSFTChatDataset::collate_fn, we could return the attention_mask as a single tensor of shape (1, seq, seq). We would then need to modify or extend the get_iterator_k_split method used in GPTSFTModel (https://github.com/NVIDIA/NeMo-Aligner/blob/main/nemo_aligner/models/nlp/gpt/gpt_sft_model.py#L86): instead of splitting attention_mask across micro-batches, the modified version would expand it to the actual (micro_batch_size, 1, seq, seq) shape on the fly, as sketched below.
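A rough sketch of what that split could look like. The wrapper name is made up, and it assumes collate_fn already returns the shared (1, seq, seq) mask and that get_iterator_k_split accepts a dict of tensors and yields per-micro-batch dicts, which is how it appears to be used in gpt_sft_model.py; the exact import path may also differ:

```python
import torch

# Hypothetical helper; the name and the exact import path are illustrative.
from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split


def split_with_shared_attention_mask(batch, num_microbatches, micro_batch_size):
    batch = dict(batch)                            # avoid mutating the caller's dict
    shared_mask = batch.pop('attention_mask')      # (1, seq, seq), shared by all samples
    seq = shared_mask.shape[-1]
    for micro_batch in get_iterator_k_split(batch, num_microbatches):
        # expand() returns a broadcasted view, so no (micro_batch_size, 1, seq, seq)
        # tensor is ever materialized.
        micro_batch['attention_mask'] = shared_mask.unsqueeze(0).expand(
            micro_batch_size, 1, seq, seq
        )
        yield micro_batch
```

Since expand only creates a view, each data-parallel rank would keep a single (seq, seq) mask in memory regardless of gradient_accumulation_steps.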
Describe alternatives you've considered
No
Additional context
No