
Memory inefficiency when loading attention_mask, causing dataloader OOM with long context #488

Open
shensimeteor opened this issue Jan 22, 2025 · 0 comments


Is your feature request related to a problem? Please describe.

In GPTSFTChatDataset::collate_fn, it seems a huge attention_mask tensor is returned (code:
https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py#L380-L381)

The tensor has shape (n, 1, seq, seq), where n is global_batch_size divided by data_parallel_size (i.e., micro_batch_size * gradient_accumulation_steps, if my understanding is correct) and seq is the max context length in the batch. The tensor therefore becomes very large when the context length is long and gradient_accumulation_steps is large.

Much of this memory is wasted: all n subtensors of shape (1, seq, seq) are identical, so the tensor could be shrunk to 1/n of its current size. That would help a lot with long-context training.
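
To make the scale concrete, here is a rough back-of-the-envelope estimate (illustrative numbers only; the actual dtype and sizes depend on the config):

```python
# Illustrative estimate; assumes a 1-byte (bool) mask and example sizes.
n = 16               # micro_batch_size * gradient_accumulation_steps
seq = 32_768         # max context length in the batch
bytes_per_elem = 1   # torch.bool

full_mask = n * 1 * seq * seq * bytes_per_elem    # current shape (n, 1, seq, seq)
shared_mask = 1 * seq * seq * bytes_per_elem      # proposed shape (1, seq, seq)
print(f"current: {full_mask / 2**30:.0f} GiB, shared: {shared_mask / 2**30:.0f} GiB")
# current: 16 GiB, shared: 1 GiB
```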

Describe the solution you'd like

I'm not very familiar with the codebase, so my thinking could be wrong here. In GPTSFTChatDataset::collate_fn, we could store the attention_mask as a single tensor of shape (1, seq, seq). We would then need to modify or extend the get_iterator_k_split method used in GPTSFTModel (https://github.com/NVIDIA/NeMo-Aligner/blob/main/nemo_aligner/models/nlp/gpt/gpt_sft_model.py#L86): instead of splitting the attention_mask across micro-batches, the modified or extended version would create the actual (micro_batch, 1, seq, seq) attention_mask for each split, as sketched below.
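
A minimal sketch of the expansion step, assuming collate_fn is changed to return a single (1, seq, seq) mask. expand_shared_attention_mask is a hypothetical helper, not an existing NeMo/NeMo-Aligner function:

```python
import torch

def expand_shared_attention_mask(shared_mask: torch.Tensor, micro_batch_size: int) -> torch.Tensor:
    # shared_mask: (1, seq, seq), identical for every sample in the batch.
    # expand() returns a broadcasted view, so no extra memory is allocated
    # unless a downstream op materializes a copy.
    seq = shared_mask.shape[-1]
    return shared_mask.view(1, 1, seq, seq).expand(micro_batch_size, 1, seq, seq)
```

The modified or extended get_iterator_k_split would then call something like this per micro-batch instead of slicing a pre-built (n, 1, seq, seq) tensor.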

Describe alternatives you've considered

No

Additional context

No
