Skip to content

Enable DoMINO parallelization via ShardTensor #838

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 66 commits into
base: domino
Choose a base branch
from

Conversation

coreyjadams
Copy link
Collaborator

Description

This PR is the initial version of domain parallelism for DoMINO. Currently, the forward pass (inference) is supported though there are some numerical instabilities to track down with Conv3d and certain shapes.

Included in this PR are a number of other pending parallelization operations, many of which are required for DoMINO but not all. These include:

  • Sequence parallel attention via ring attention
  • normalization layers (group Norm, layer norm via DTensor)
  • ConvTranspose
  • Upsampling via torch interpolate
  • MaxPooling and AvgPooling

The halo passing (and new ring algorithm) are also reorganized. This should help make the halo algorithm more readable and maintainable.

There are some remaining optimizations to be done:

  • The ring message passing should be made into an async op
  • The convolution and ball query wrappers should avoid "infer" for shard tensor creation to avoid blocking

To include testing of all the distributed algorithms would add several hundred distributed tests. With the fork-every-test set up, this is just not feasible in the CI. The tests currently exist in another repository for validating numerical correctness of external operations. PhysicsNeMo-implemented operations (BallQuery) have tests implemented here.

(After the release of 25.03 and renaming, I had to do some manual merging on many files. so there are changed files unrelated to this PR)

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

External testing repo, to be opened or perhaps merged.

coreyjadams and others added 27 commits February 20, 2025 07:49
* Stashing profiling work

* Torch profile works but is very slow.  line profiler not functional at this time

* Enablement of profiling tool with pytorch profiler, as a context manager.  Still several TBD Objects but this implementation will capture a torch profile.

* Moving profiling tools into a directory to make separate tools more clearly separated as well as enable easier extensions.

* Profiling tools work with torch profiler and line_profiler.  nsys has a crash that I haven't resolved yet.

* Fix line profiling construction

* Begin instrumenting figconvnet and adding tutorials on modulus profiling tools

* Remove annotations and force all annotations to conform to nvtx.  Simpler, for now, and the most (only?) useful annotation tool

* Updating profiling tutorial

* Minor updates to profiling interfaces

* only adding some profiling hooks to figconvnet

* Add profiling hooks to mesh graph net.

* Set TELayerNorm to default layer norm in MeshGraphNet

* Nearly finished profiling tutorial and tooling example.  Just need to add images.

* Final (first) draft of the profiling tutorial and clean up profiler code slightly.  Ready for draft PR

* Add tests to the profiler tools to check functionality.  Thanks Cursor!

Some minor updtes to the tools themselves to accomodate instance clearing and refreshing.

* Update changelog for profiling tools

* Update profiler files to (hopefully) pass CI checks

* Remove profiling parts from capture.py for later integration

* Update __init__.py

Remove nvtx wrapper

* Add extra line to make linting happy...

* When cuda is not available (mostly CI), emit a warning and switch to native layer norm.

* Make the default as LayerNorm so tests will pass.  Needs more care in the test, I think, about TELayerNorm

* Very minor fixes per review

* Resolve most comments from PR review.  One to go (profiler state to become a literal)

* Change profiler state tracker to a single state with an enum type.

* Two changes made here:
- the exit stack moves from a class variable to an instance variable
- The double-check locking mechanism in the registry becomes a single lock and check.

* Make sure the exit stack init is actually in __init__ and not initialize()
* Enable mesh-based parallelism as the configuration backend, even for simple DDP sharding

* Fix small typo in docstring

* Remove  unnecessary  functions with new interface

* Adding first implementation of ShardTensor prototype.  Still several pieces are WIP but this has basic functionality supported for creation and forward usage.

* Working implementation of ShardTensor, though still somewhate incomplete.

* Adding work-in-progress examples.  Be careful of sharp edges!

* A few more example pieces before natten will work out of the box.  Most of the ops have been validated, all that remains is to  wrap the na2d function call to ensure it will dispatch properly.

* Fix naming scheme

* Minor name change

* Add monkey patching for na2d operation with shard tensors

* Fix bug in shard tensor inference of globla size.  CHeck agains sharding in unbind op rules.

* Enable backwards gradients for halo sharding and natten patch

* Convolution 2d backwards works, though would be  better to catch torch.ops.aten.convolution.default.

* Fix missing import and ensure tensors are contiguous before allgather_v

* Clean up and remove unnecessary noise and printouts for debugging

* Unify (and correct!) the sharded convolution implementation.  There was also a minor bug in the backward
pass that got more pronounced with smaller data: grad inputs were failing to properly collect
haloed gradients and add them on the edges.  Now fixed.

* Remove noise from sharding utils.

* For smaller tensors, the alltoall step of halo reductions might be significant overhead.
I'm implementing here an option to switch to peer to peer message passing, since it might
benefit from stream utilization in layers like natten.na2d.

It's a developer choice currently, not a user choice.

* Remove shard_utils file, it is a subfolder.

* Add modulus ShardTensor api documentation

* Clean up doc strings, type annotations and mesh implementation.  No significant functionality changes in this commit.

* Add significant docstring / type annotation cleanup to ShardTensor.

Add `scatter_tensor` function to enable more easy transition to shard tensor.
This function allows users to maintain data pipelines (on one rank) and easily
scatter that data to a domain mesh.

* Remove neighborhood attention prototypes

* Remove the rest of these examples since they are outdated and unnecessary

* Mostly, this commit is adding type annotations and doc strings.

But also, this adjusts the shard tensor mechanism for tracking shard info to use
a dict instead of a list of tuples.

* Clean up and document conv patches.
No real code changes applied here.

* clean up and improve documentation and type hints for shard utils worker functions

* Adding basic tests for shard tensor initialization and redistribution.

There appears to be one corner case in redistribute to fix.  TBD.

Tests for grad propogation are coming.

* Add full working example of multilevel parallelism with pytorch
FSDP and modulus ShardTensor

* Add missing type annotations

* Ensure scatter_tensor is available to import from modulus.distributed

* Update changelog and ensure wrapt is a optional dependency

* Update fsdp_and_shard_tensor.rst

Update tutorial based on feedback from @pzharrington

* Update __init__.py

Remove wildcard import.

* Update shard_tensor.py

fix spacing

* This is an essential bug fix for a missing import

* Update branch to pass CI tests.

* This commit provides several pieces:

- First, the ability to transpose the sharding dimensions is supported.  For square submeshs, 2x2 for example,
the output sharding will match the input sharding if it's uneven.  This can only be supported if the number of
devices in the output mesh dimension is equal to the input dimension, hence the restriction on square submeshes.
Other scenarios will apply dtensor-like chunk syntax, but return a shard tensor tracking that split.  Comprehensive
tests on 1D and 2D meshes are included here.  No testing is done at this time on 3D sharding / meshes.

- Second, the issues with torch.mean are intercepted and fixed.  This uses a new dispatch intercept (below)
and applies a weight to the mean, and converts the Partial placement to a Partial(sum) with the weight applied.
This has a bug that appears to be present in DTensor too: reductions over non-sharded dimensions appear to falter.
To be fixed in a future release.

- Third, ShardTensor has a new class attribute to accomodate operator interceptions.  The only applied function
at this time are variants of aten.mean, however, it is expected to convert all monkey patching to this syntax.

* Update monkey patching to ensure patches get applied by modulus, and don't require
them to trigger elsewhere.  If ShardTensor is used, the patches get applied.

Also, minor updates to docs.

* Codify ShardTensor and FSDP in tutorials.

* Apparently, codify'ing in rst requires double ticks.

* This commit fixes gradient propagation for unevenly sharded tensors.  Tests are coming in the next commit immediately after.

* Add tests for shard tensor: initialization, resharding, and gradient sharding.

Further, fixed an annoying bug in other distributed tests where OS environs weren't cleared after testing, and tsome tests would fail but only if others ran first.

Now, all distributed tests use a context manager to change OS environment variables locally only.

* Two things done here:
- Enable dynamic (off by default) wrapping of layers by shard tensor.  they get turned on automatically when a shard tensor is created.
- Rename the utils to manage env variables.

Tests are failing with unusual CPU errors on ORD.  Moving to github runners ...

* Disable patched operations by default.
- First, there is a bug upstream in pytorch.  The profile will currently fall over with stack=True, so its off by default here.
- Second, refactoring the state led to some subtle logic errors, where previously it was
  possible to have enabled and initialized both be true.  That broke.  This commit fixes
  by assuming an numerical state progression, so it's now a __ge__ comparision and if the
  state is ENABLED for example, @Property(initialized()) evaluates to true.
A pytorch DeviceMesh provides syntax to to access the ProcessGroup across a mesh dimension.
Sometimes, it is most easy to access a group that includes all devices of a mesh.  This isn't
included in the upstream syntax, so it's built out here.  Because the cost of creating a group
is not small, the groups are cached using the devicemesh itself (hashed) as a key.

In 1D meshes, the underlying group for that mesh is returned rather than creating a new group.
… better.

The main performance issue with ShardTensor appears to be blocking DtoH and HtoD
copies for infering the sharding shapes and sizes.

Construction of a shard tensor now takes an optional argument `sharding_shapes`
to dictate how sharding shapes are determined.

"infer" will use group-wide communication to allgather the shapes on each sharded
mesh dimension.

"chunk" will assume DTensor-like chunking.  A sanity check will be performed,
that the local shape matches the computed shape.  In the event that the local
shape matches but only on one rank, this could lead to a hang - it is because
the input DTensor would have been incorrectly formatted.  No communication
is done with "chunk" method unless the sanity check fails.

Sharding shapes can be passed directly, too.  Global shape will be inferred
in this case.

Additionally, `scatter_tensor` has been made a little more versatile
at the cost of slightly worse performance.  There are possible optimizations
but unlikely to provide serious performance benefits yet.
…ity of the halo communication.

It is incorporated into both conv* as well as natten.  Other operations that require a halo
on image-like data should be supportable with this update more easily.
…anch,

after it diverged from the renamed release branch.

It also fixes a small typo in a warning message in BallQuery.
Now, a = b[:,:,3] will work on sharded tensors as long as you aren't
selecting on the sharded dimension.
…nd maintain.

This reorganization isolates the layers of halo passing by conceptual
operation (building the halos, communicate halos, apply halos, slice
off residuals, etc).  The code is a little longer but the upshot
is the availability of high level functions `halo_padding` and `unhalo_padding`
which are both differentiable and easily applied.

Also introduces a ring message passing function.  Note that this function
is synchronous (so is the halo) and while the halo needs to be synchronous,
ring message passing often does not.  It's included nevertheless as a
simple, easy to use version to enable debugging of the overlapped version.

Both Halos and Rings and now configured with light dataclass objects
to make the number of arguments passed around simpler, and easier to maintain
state between forward and backward passes.
…bor upsampling,

attention layers via sequence parallelism, and a semi-parallel version of BallQuery
from physicsnemo.

With this commit, the DoMINO model can be used in a domain-parallel way with ShardTensor.
Some optimizations and one further level of parallelism remain.
…ded tensors).

It also includes some edge case fixes for the ball query ring ops, which change minor
details in a the attention patches.
@coreyjadams coreyjadams added the ! - Release PRs or Issues releating to a release label Apr 4, 2025
@coreyjadams coreyjadams marked this pull request as draft April 4, 2025 18:11
@coreyjadams
Copy link
Collaborator Author

/blossom-ci

coreyjadams and others added 30 commits April 16, 2025 06:48
…e can switch to

threaded reading when the preprocessing is ready and available.
Ensure compute_scale_factors works even with GPU preprocessing.
in a context on device != 0 would allocate memory on device 0.
… each iteration. Otherwise, it leads to a memory leak.
…chunk_aligned_

and read directly into a numpy buffer.  This enables better multithreading since each
thread only interfaces with one zarr chunk.
Limited by the threadpool and IO speed.  It'd be nice to
stream right into pinned memory but it seems to be too
large data reads for that pool.  TBD.
…ed from the standard

pipeline, with several extra pieces of information:
- the domain mesh over which the data pipeline is sharded
- Whether to shard point-like outputs (volume fields, surface fields, etc)
- Whether to shard grid-like outputs

This commit also includes some minor refinements to the standard pipeline
to make bootstraping a sharded version functional.
shardtensor's tools.

Finish removing length variables
… that are

significantly simpler and shorter, while producing numerically consistent
results.

Original functions are maintained in this commit and the training script
compares individual loss components as well as total loss.
torch.chunk do not share the same splitting of tensors.  When
redistribute is called, without a "plan" for chunking, it needs
to use torch.chunk to ensure the shapes are what DTensor and the
size validation expect.

This also changes the behavior of the checking: a simple check that
the local shape matches the spec's stored shape along the first mesh
dimension.  If this fails, the code now crashes.  Previously, it was
possible to fail on only some ranks and the collectives became
disordered across ranks.
…e and destination

however this didn't account properly for non-global meshes.  This commit uses local ranks
for determing the ID of source/destination, though they are converted to global
indexing to send the messages.
- torch select and index select now intercept at the torch level, instead
of at aten.  This ensures proper backwards pass gradient sharding.
- Mean and Sum reductions now completely ignore DTensor implementation
on ShardTensors.  The motivation for this is that the backwards pass
in DTensor will not shard gradients properly. It's not an issue in
DTensor but is problematic with domain parallelism.
… sharding in DoMINO.

- shard_tensor will now shard gradients in the backward pass when converting from torch.Tensor.
- unpooling patches was updated to calculate output shapes with no communication overhead.
- point cloud ops raises an exception if the backwards call in RingBallQuery is called.
  its not implemented correctly, yet, but also not used yet.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
! - Release PRs or Issues releating to a release
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants