
support async in nccl pg #211

Merged
@tushar00jain merged 1 commit into pytorch:main on Jun 12, 2025

Conversation

@tushar00jain (Contributor) commented on Jun 10, 2025

Summary:

  • set the same stream as the one used for the work in future continuations, so that unrelated streams don't pick up a dependency on the pg stream (otherwise they can become dependent on the allreduce stream)
  • wait on the work sent to pg's immediately on the fragment streams (used for allreduce) so that they depend on the pg stream but not on any future work submitted to those streams
  • copy grads before the allreduce so that the inner optimization can use the originals and no dependency is created between the default stream and the pg stream (see the sketch after this list)
  • add back support for quantized allreduce in the manager
  • change return types to be consistent with pg allreduce
  • the future returned from the quantization collectives hangs (likely because set_result is not called?), so changed it to return the future directly from the pg
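
A rough sketch of the fragment-stream pattern described above (hypothetical names, assuming a torch.distributed process group; this is not the actual torchft code):

```python
import torch
import torch.distributed as dist

# Hypothetical names for illustration only; not the torchft implementation.
fragment_stream = torch.cuda.Stream()

def queue_fragment_allreduce(pg: dist.ProcessGroup, grads: list[torch.Tensor]):
    # Let the fragment stream see the finished gradients before copying them.
    fragment_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(fragment_stream):
        # Copy the grads so the inner optimizer can keep using the originals
        # without tying the default stream to the pg stream.
        staged = [g.clone() for g in grads]
        work = pg.allreduce(staged)
        # Wait immediately on the fragment stream: it now depends on the pg
        # stream, but not on any future work submitted to the pg stream.
        work.wait()
    return staged
```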

Test Plan:

  • tested the changes with nccl pg
  • synchronizing on the recovery stream sometimes makes the cpu block on the collective (probably because some callback gets scheduled on the recovery stream? we need to stop synchronizing on the recovery stream when there is no need to)
  • calling `work.wait` on the work returned by the baby nccl pg makes the cpu block on the collective (because 2 contexts can't overlap?)
  • pg gloo needs us to call `future.wait` in the sync phase instead of the prepare phase, so we probably need a different wrapper
  • same for the baby gloo pg

Without Quantization

<img width="1188" alt="image" src="https://github.com/user-attachments/assets/8f8dd694-a972-4bc6-96a0-8a79627a4d5d" />

With Quantization

<img width="1123" alt="image" src="https://github.com/user-attachments/assets/b54288a3-9727-4956-89e7-c8b8775a98aa" />

Stack created with Sapling. Best reviewed with ReviewStack.

@facebook-github-bot added the CLA Signed label on Jun 10, 2025
@tushar00jain force-pushed the pr211 branch 15 times, most recently from 29833dc to adb94c2, on June 12, 2025 03:29
@tushar00jain changed the title from "integrate quantization in manager" to "support async in nccl pg" on Jun 12, 2025
@tushar00jain marked this pull request as ready for review on June 12, 2025 03:53
@d4l3k (Member) left a comment

These changes generally seem reasonable though as always with streams it's a bit hard to follow the logic. It would be nice to add some tests but I'm not sure how to unit test stream logic.

Maybe we could write a stream mock and then check for certain calls/ops but don't think that would play well with the normal operations unfortunately. I think we can land this as is unless you have some ideas

Comment on lines +381 to +383
work.wait()
fut = work.get_future()
Member: if we call work.wait() there's no need for get_future, just run the callback inline

Contributor Author: we're getting the fut here so that we can return it
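
For context, a rough sketch of the pattern under discussion (`pg` and `tensors` are placeholders, not the actual call site):

```python
# Rough sketch; `pg` and `tensors` are placeholders.
def allreduce_async(pg, tensors):
    work = pg.allreduce(tensors)
    # For NCCL, wait() only queues a wait on the current CUDA stream; it does
    # not block the CPU, so it is still useful to hand the future back to the
    # caller, who can chain continuations on it or wait on it later.
    work.wait()
    return work.get_future()
```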

@tushar00jain (Contributor Author)

> These changes generally seem reasonable though as always with streams it's a bit hard to follow the logic. It would be nice to add some tests but I'm not sure how to unit test stream logic.
>
> Maybe we could write a stream mock and then check for certain calls/ops but don't think that would play well with the normal operations unfortunately. I think we can land this as is unless you have some ideas

@d4l3k yeah that'll be something good to have but likely to be a lot of code so would prefer doing it separately. mocking will also need us to assume a lot of things about streams, which may or may not be correct. this is why looking at profiles has been useful. i have a bunch of feature asks from the profiler :)

@H-Huang (Member) left a comment

For my understanding, is what was happening before:

  1. computation stream
  2. (torch.distributed) collective stream
  3. diloco fragment stream (many)

1 was waiting on 2 which waited for 3

Now after the PR:

torchft does not queue any collectives on 2, 1 does not wait for 3

@tushar00jain (Contributor Author)

> For my understanding, is what was happening before:
>
>   1. computation stream
>   2. (torch.distributed) collective stream
>   3. diloco fragment stream (many)
>
> 1 was waiting on 2 which waited for 3
>
> Now after the PR:
>
> torchft does not queue any collectives on 2, 1 does not wait for 3

@H-Huang you mean because of the copied-out gradients?

@H-Huang (Member) commented on Jun 12, 2025

i mean after the stream changes, what was the changed behavior for the streams and what stream is the torchft allreduce running on?

@tushar00jain (Contributor Author) commented on Jun 12, 2025

@H-Huang there's also a 4th stream involved that runs the callbacks

  1. computation stream
  2. (torch.distributed) collective stream
  3. diloco fragment stream (many)
  4. callback stream
  5. recovery stream (sorry 1 more)

i think this is how it works

Before

  • 3 waits for 2
  • 4 waits for 2
  • 3 waits for 4
  • we sync 3' (for another fragment)
  • we sync 1

3 will block cpu if 4 == 3' or 4 == 1

After

we enforce 4 == 3, so

  • 3 waits for 2
  • we sync 3'
  • we sync 1

but i'm not sure now how the 5 is getting involved since 4 != 5
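
For context, one way the "4 == 3" constraint can be expressed, as a rough sketch (`fragment_stream` and `post_process` are hypothetical placeholders, not the torchft code):

```python
import torch

def chain_on_fragment_stream(
    fut: torch.futures.Future, fragment_stream: torch.cuda.Stream
):
    # Rough sketch, not the torchft implementation.
    def callback(f: torch.futures.Future):
        # The continuation starts on a pool ("callback") stream that is already
        # synchronized with the result; make the fragment stream wait on it and
        # launch the follow-up work there, so no extra stream picks up a
        # dependency on the allreduce/pg stream.
        fragment_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(fragment_stream):
            return post_process(f.value())  # post_process: hypothetical helper

    return fut.then(callback)
```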

@tushar00jain merged commit 5fe8f8b into pytorch:main on Jun 12, 2025 (17 of 18 checks passed)
@tushar00jain deleted the pr211 branch on June 12, 2025 at 21:16