Enable DoMINO parallelization via ShardTensor #838


Draft · wants to merge 66 commits into base: domino

66 commits
07f8baa
Profiling (#787)
coreyjadams Feb 20, 2025
e735f11
Enable Domain Parallelism with ShardTensor (#784)
coreyjadams Feb 20, 2025
48f2b21
name change
ktangsali Feb 21, 2025
6d7af37
name change docs
ktangsali Feb 21, 2025
e1a9ce8
This commit addresses two issues:
coreyjadams Feb 21, 2025
44b3bb8
Minor fixes and updates to the profiling utility.
coreyjadams Feb 25, 2025
c9ecd53
Add functionality to distributed manager to provide mesh-wide groups.
coreyjadams Feb 25, 2025
dfac94e
Performance enhancements to shard tensor. Not fully optimized yet bu…
coreyjadams Feb 26, 2025
b2067a2
Hot fix - interface for mesh names was incorrect.
coreyjadams Feb 28, 2025
19238db
Small updates to ShardTensor and redistribution methods.
coreyjadams Feb 28, 2025
42af72a
This commit improves the functionality, readability, and maintainabil…
coreyjadams Mar 11, 2025
bfcc018
Add support for a select group of conv_transpose shapes where kernel …
coreyjadams Mar 11, 2025
db3ff9f
Enable group normalization with shard tensor.
coreyjadams Mar 11, 2025
48dab53
Add attention mechanism (scaled_dot_product_attention) to supported S…
coreyjadams Mar 12, 2025
9f83671
Add average pooling functionality for select shapes.
coreyjadams Mar 12, 2025
90a2ec0
Enable pooling, normalization, and attention patches when registering…
coreyjadams Mar 12, 2025
942ec2a
Remove printouts ...
coreyjadams Mar 12, 2025
61ae414
Merge branch 'main' into shardTensorFeature
coreyjadams Mar 31, 2025
9f982a0
Merge branch modulus:main into shardTensorFeature
coreyjadams Mar 31, 2025
67bc29d
Merge branch 'NVIDIA:main' into shardTensorFeature
coreyjadams Apr 1, 2025
e67ec56
Merge branch 'NVIDIA:main' into shardTensorFeature
coreyjadams Apr 2, 2025
17e3ffc
This commit addresses issues that arose in the merge of my feature br…
coreyjadams Apr 2, 2025
018e5d9
Add a sharding propagation for aten.select.int.
coreyjadams Apr 2, 2025
4ad0dae
Reorganize the halo and ring message passing to be easier to follow a…
coreyjadams Apr 2, 2025
be2ef41
Merge branch 'shardTensorFeature' of github.com:coreyjadams/physicsne…
coreyjadams Apr 2, 2025
1268634
This commit adds support for Max Pooling, Unpooling via nearest neigh…
coreyjadams Apr 2, 2025
b0a82af
This commit adds tests for RingBallQuery (which is ball query on shar…
coreyjadams Apr 4, 2025
6573fc4
make sure that convolutions and ball query compute output shapes and …
coreyjadams Apr 4, 2025
6bd2791
Add profiling hooks to convolution wrapper and halo padding.
coreyjadams Apr 7, 2025
aab6a71
Merge branch 'NVIDIA:main' into shardTensorFeature
coreyjadams Apr 8, 2025
b6711fd
Disable the `length` variables in BallQuery. They are unused, but st…
coreyjadams Apr 8, 2025
0d9be0d
This commit applies some reorganizations to the ball query layer to e…
coreyjadams Apr 10, 2025
ddb2a00
Merge branch 'NVIDIA:main' into domino_perf
coreyjadams Apr 10, 2025
daec55c
Optimizations and efficiency improvements in the domino datapipe. Hi…
coreyjadams Apr 14, 2025
9e93054
Merge remote-tracking branch 'upstream/domino' into domino_perf
coreyjadams Apr 15, 2025
13daff1
Remove obsolete and unused dataclasses - it's a flat config hierarchy…
coreyjadams Apr 15, 2025
610bc2a
This commit enables reading the old-style pickled files by default. …
coreyjadams Apr 16, 2025
473f7d2
Provide more robust reading of pickled files.
coreyjadams Apr 16, 2025
33bf72c
Fix several small bugs: the dataloader sometimes implicitly uses cupy…
coreyjadams Apr 16, 2025
3309bc3
Fix issue if using CPU data loading.
coreyjadams Apr 16, 2025
2c94eb1
Ensure all gpu preprocessing is directed to the proper device
coreyjadams Apr 16, 2025
0d1412e
Ensure that the dataloader doesn't waste GPU memory. Previously, loa…
coreyjadams Apr 16, 2025
10c42b9
Enable zarr readers. Use file path to toggle which type of file to r…
coreyjadams Apr 17, 2025
e869c8e
Improve logging and track memory leak. Enable zarr.
coreyjadams Apr 17, 2025
b07a897
Add GPU monitoring to the training script, and recreate the knn class…
coreyjadams Apr 17, 2025
60068d2
Merge branch 'domino' into domino_perf
coreyjadams Apr 17, 2025
4279259
Merge branch 'domino_perf' into shardTensorFeature
coreyjadams Apr 17, 2025
6fcdb07
Enforce the determinism request in the domino pipeline.
coreyjadams Apr 17, 2025
dac4734
This commit makes an improvement to the zarr reading: reads are now _…
coreyjadams Apr 17, 2025
c05f235
Put ALL zarr chunk reads into futures and thread the IO.
coreyjadams Apr 17, 2025
439deac
Introduce a Sharded data pipeline for DoMINO. This class is construc…
coreyjadams Apr 18, 2025
275d7c5
Merge branch 'domino_perf' into shardTensorFeature
coreyjadams Apr 18, 2025
9b91893
Update ball query module to call to the functional interface to leverage
coreyjadams Apr 21, 2025
5e45e4a
Merge remote-tracking branch 'upstream/domino' into shardTensorFeature
coreyjadams Apr 21, 2025
207a578
This commit creates alternative versions of the domino loss functions…
coreyjadams Apr 23, 2025
b9236f7
Remove older loss functions and consolidate script.
coreyjadams Apr 23, 2025
1edf788
Merge branch 'domino_loss_fn' into shardTensorFeature
coreyjadams Apr 23, 2025
3960424
Merge loss function updates.
coreyjadams Apr 23, 2025
138adac
This commit addresses a bug in shard tensor: torch.tensor_split and
coreyjadams Apr 28, 2025
b2ab2c0
Ensure the backwards gradient computations use consistent types.
coreyjadams Apr 28, 2025
3c21174
In ring calculations, the global rank was being used to compute sourc…
coreyjadams Apr 28, 2025
3075207
Implement sharded version of torch's index_select.
coreyjadams Apr 28, 2025
5efead5
Merge branch 'domino' into shardTensorFeature
coreyjadams Apr 28, 2025
5c9ece9
This commit enables the following pieces:
coreyjadams Apr 29, 2025
21b97f6
This commit handles some of the final updates required to enable full…
coreyjadams Apr 30, 2025
d0aa534
Add profiling hooks to the domino model.
coreyjadams Apr 30, 2025
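
The commit list above centers on ShardTensor, which (per the PR title) parallelizes DoMINO across a device mesh, with halo and ring communication layered on top of sharded tensors. As a hedged orientation sketch only, the underlying idiom resembles upstream PyTorch's DTensor API shown below; this is not this PR's ShardTensor interface, and the shapes are illustrative.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# Intended to run under torchrun with one GPU per rank (an assumption).
dist.init_process_group(backend="nccl")
mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("domain",))

# A stand-in for a DoMINO point cloud: shard the point dimension across ranks.
points = torch.randn(1, 1_000_000, 3)
sharded = distribute_tensor(points, device_mesh=mesh, placements=[Shard(1)])

# Each rank now holds only its local slice; collectives (halos, rings)
# reconstruct neighborhood data when an operation needs it.
print(sharded.to_local().shape)
```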
@@ -0,0 +1,116 @@
Timer unit: 1e-09 s

Total time: 0.0635549 s
File: /root/physicsnemo/docs/test_scripts/profiling/annotated_code/attn.py
Function: forward at line 31

Line # Hits Time Per Hit % Time Line Contents
==============================================================
31 @profile
32 def forward(self, x: torch.Tensor) -> torch.Tensor:
33
34 8 47352.0 5919.0 0.1 B, N, C = x.shape
35 8 41506771.0 5e+06 65.3 qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
36 8 250098.0 31262.2 0.4 q, k, v = qkv.unbind(0)
37
38
39 # This is not optimal code right here ...
40 8 4760805.0 595100.6 7.5 q = q * self.scale
41 8 7717546.0 964693.2 12.1 attn = q @ k.transpose(-2, -1)
42 8 6836559.0 854569.9 10.8 attn = attn.softmax(dim=-1)
43 8 467249.0 58406.1 0.7 attn = self.attn_drop(attn)
44 8 681692.0 85211.5 1.1 x = attn @ v
45
46 8 273367.0 34170.9 0.4 x = x.transpose(1, 2).reshape(B, N, C)
47 8 822387.0 102798.4 1.3 x = self.proj(x)
48 8 182694.0 22836.8 0.3 x = self.proj_drop(x)
49 8 8381.0 1047.6 0.0 return x

Total time: 0.00505685 s
File: /root/physicsnemo/docs/test_scripts/profiling/annotated_code/attn.py
Function: forward at line 65

Line # Hits Time Per Hit % Time Line Contents
==============================================================
65 @profile
66 def forward(self, x):
67 8 493290.0 61661.2 9.8 x = self.fc1(x)
68 8 3648444.0 456055.5 72.1 x = self.gelu(x)
69 8 155234.0 19404.2 3.1 x = self.drop1(x)
70 8 437770.0 54721.2 8.7 x = self.fc2(x)
71 8 197332.0 24666.5 3.9 x = self.gelu(x)
72 8 120381.0 15047.6 2.4 x = self.drop2(x)
73 8 4401.0 550.1 0.1 return x

Total time: 0.0855953 s
File: /root/physicsnemo/docs/test_scripts/profiling/annotated_code/attn.py
Function: forward at line 104

Line # Hits Time Per Hit % Time Line Contents
==============================================================
104 @profile
105 def forward(self, x: torch.Tensor) -> torch.Tensor:
106 8 79767938.0 1e+07 93.2 x = x + self.attn(self.norm1(x))
107 8 5821609.0 727701.1 6.8 x = x + self.mlp(self.norm2(x))
108 8 5719.0 714.9 0.0 return x

Total time: 4.54253 s
File: /root/physicsnemo/docs/test_scripts/profiling/annotated_code/workload.py
Function: workload at line 30

Line # Hits Time Per Hit % Time Line Contents
==============================================================
30 @profile
31 def workload(cfg):
32
33 1 304686.0 304686.0 0.0 ds = RandomNoiseDataset(cfg["shape"])
34
35 2 202324.0 101162.0 0.0 loader = DataLoader(
36 1 170.0 170.0 0.0 ds,
37 1 42921.0 42921.0 0.0 batch_size=cfg["batch_size"],
38 1 180.0 180.0 0.0 shuffle = True,
39 )
40
41
42 # Initialize the model:
43 3 79343062.0 3e+07 1.7 model = Block(
44 1 79202.0 79202.0 0.0 dim = cfg["shape"][-1],
45 1 73421.0 73421.0 0.0 num_heads = cfg.model["num_heads"],
46 1 52831.0 52831.0 0.0 qkv_bias = cfg.model["qkv_bias"] ,
47 1 50570.0 50570.0 0.0 attn_drop = cfg.model["attn_drop"],
48 1 50352.0 50352.0 0.0 proj_drop = cfg.model["proj_drop"],
49 1 109143037.0 1e+08 2.4 ).to("cuda")
50
51 1 107952.0 107952.0 0.0 if cfg["train"]:
52 1 2059751411.0 2e+09 45.3 opt = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
53
54 1 270.0 270.0 0.0 times = []
55 2 74021.0 37010.5 0.0 with Profiler() as p:
56 1 1521.0 1521.0 0.0 start = time.perf_counter()
57 9 2022090217.0 2e+08 44.5 for i, batch in enumerate(loader):
58 8 74692.0 9336.5 0.0 image = batch["image"]
59 8 47176297.0 6e+06 1.0 image = image.to("cuda")
60 16 364408.0 22775.5 0.0 with annotate(domain="forward", color="blue"):
61 8 85942255.0 1e+07 1.9 output = model(image)
62 8 871577.0 108947.1 0.0 if cfg["train"]:
63 8 1728264.0 216033.0 0.0 opt.zero_grad()
64 # Compute the loss:
65 8 24881952.0 3e+06 0.5 loss = loss_fn(output)
66 # Do the gradient calculation:
67 16 181082.0 11317.6 0.0 with annotate(domain="backward", color="green"):
68 8 50119067.0 6e+06 1.1 loss.backward()
69 # Apply the gradients
70 8 58985347.0 7e+06 1.3 opt.step()
71 8 35302.0 4412.8 0.0 p.step()
72 8 27261.0 3407.6 0.0 end = time.perf_counter()
73 8 352396.0 44049.5 0.0 print(f"Finished step {i} in {end - start:.4f} seconds")
74 8 4790.0 598.8 0.0 times.append(end - start)
75 8 6301.0 787.6 0.0 start = time.perf_counter()
76
77 1 84802.0 84802.0 0.0 times = torch.tensor(times)
78 # Drop first and last:
79 1 117812.0 117812.0 0.0 avg_time = times[1:-1].mean()
80 # compute throughput too:
81 1 152063.0 152063.0 0.0 throughput = cfg["batch_size"] / avg_time
82 1 60321.0 60321.0 0.0 print(f"Average time per iteration: {avg_time:.3f} ({throughput:.3f} examples / s)")

@@ -0,0 +1,117 @@
Timer unit: 1e-09 s

Total time: 0.0712444 s
File: /root/physicsnemo/docs/test_scripts/profiling/fixed_data_loader/attn.py
Function: forward at line 31

Line # Hits Time Per Hit % Time Line Contents
==============================================================
31 @profile
32 def forward(self, x: torch.Tensor) -> torch.Tensor:
33
34 32 54830.0 1713.4 0.1 B, N, C = x.shape
35 32 41511385.0 1e+06 58.3 qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
36 32 368675.0 11521.1 0.5 q, k, v = qkv.unbind(0)
37
38
39 # This is not optimal code right here ...
40 32 5547725.0 173366.4 7.8 q = q * self.scale
41 32 9591905.0 299747.0 13.5 attn = q @ k.transpose(-2, -1)
42 32 7572284.0 236633.9 10.6 attn = attn.softmax(dim=-1)
43 32 915209.0 28600.3 1.3 attn = self.attn_drop(attn)
44 32 2071393.0 64731.0 2.9 x = attn @ v
45
46 32 981847.0 30682.7 1.4 x = x.transpose(1, 2).reshape(B, N, C)
47 32 1999401.0 62481.3 2.8 x = self.proj(x)
48 32 598419.0 18700.6 0.8 x = self.proj_drop(x)
49 32 31321.0 978.8 0.0 return x

Total time: 0.00933511 s
File: /root/physicsnemo/docs/test_scripts/profiling/fixed_data_loader/attn.py
Function: forward at line 65

Line # Hits Time Per Hit % Time Line Contents
==============================================================
65 @profile
66 def forward(self, x):
67 32 1681602.0 52550.1 18.0 x = self.fc1(x)
68 32 4308446.0 134638.9 46.2 x = self.gelu(x)
69 32 525102.0 16409.4 5.6 x = self.drop1(x)
70 32 1591600.0 49737.5 17.0 x = self.fc2(x)
71 32 773671.0 24177.2 8.3 x = self.gelu(x)
72 32 441258.0 13789.3 4.7 x = self.drop2(x)
73 32 13428.0 419.6 0.1 return x

Total time: 0.0998171 s
File: /root/physicsnemo/docs/test_scripts/profiling/fixed_data_loader/attn.py
Function: forward at line 104

Line # Hits Time Per Hit % Time Line Contents
==============================================================
104 @profile
105 def forward(self, x: torch.Tensor) -> torch.Tensor:
106 32 88075019.0 3e+06 88.2 x = x + self.attn(self.norm1(x))
107 32 11725880.0 366433.8 11.7 x = x + self.mlp(self.norm2(x))
108 32 16161.0 505.0 0.0 return x

Total time: 4.07266 s
File: /root/physicsnemo/docs/test_scripts/profiling/fixed_data_loader/workload.py
Function: workload at line 30

Line # Hits Time Per Hit % Time Line Contents
==============================================================
30 @profile
31 def workload(cfg):
32
33 1 171664.0 171664.0 0.0 ds = RandomNoiseDataset(cfg["shape"])
34
35 2 204305.0 102152.5 0.0 loader = DataLoader(
36 1 180.0 180.0 0.0 ds,
37 1 34230.0 34230.0 0.0 batch_size=cfg["batch_size"],
38 1 140.0 140.0 0.0 shuffle = True,
39 )
40
41
42 # Initialize the model:
43 3 79993605.0 3e+07 2.0 model = Block(
44 1 78242.0 78242.0 0.0 dim = cfg["shape"][-1],
45 1 80352.0 80352.0 0.0 num_heads = cfg.model["num_heads"],
46 1 53571.0 53571.0 0.0 qkv_bias = cfg.model["qkv_bias"] ,
47 1 51821.0 51821.0 0.0 attn_drop = cfg.model["attn_drop"],
48 1 50711.0 50711.0 0.0 proj_drop = cfg.model["proj_drop"],
49 1 117420554.0 1e+08 2.9 ).to("cuda")
50
51 1 108593.0 108593.0 0.0 if cfg["train"]:
52 1 2054675844.0 2e+09 50.5 opt = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
53
54 1 260.0 260.0 0.0 times = []
55 2 52532.0 26266.0 0.0 with Profiler() as p:
56 1 1390.0 1390.0 0.0 start = time.perf_counter()
57 33 14661233.0 444279.8 0.4 for i, batch in enumerate(loader):
58 32 65280.0 2040.0 0.0 image = batch["image"]
59 32 148192.0 4631.0 0.0 image = image.to("cuda")
60 64 392538.0 6133.4 0.0 with annotate(domain="forward", color="blue"):
61 32 100282907.0 3e+06 2.5 output = model(image)
62 32 1679425.0 52482.0 0.0 if cfg["train"]:
63 32 3523561.0 110111.3 0.1 opt.zero_grad()
64 # Compute the loss:
65 32 27139497.0 848109.3 0.7 loss = loss_fn(output)
66 # Do the gradient calculation:
67 64 424059.0 6625.9 0.0 with annotate(domain="backward", color="green"):
68 32 46323272.0 1e+06 1.1 loss.backward()
69 # Apply the gradients
70 32 86108878.0 3e+06 2.1 opt.step()
71 32 68970.0 2155.3 0.0 p.step()
72 32 1538055836.0 5e+07 37.8 torch.cuda.synchronize()
73 32 31759.0 992.5 0.0 end = time.perf_counter()
74 32 491312.0 15353.5 0.0 print(f"Finished step {i} in {end - start:.4f} seconds")
75 32 14763.0 461.3 0.0 times.append(end - start)
76 32 17920.0 560.0 0.0 start = time.perf_counter()
77
78 1 56391.0 56391.0 0.0 times = torch.tensor(times)
79 # Drop first and last:
80 1 84261.0 84261.0 0.0 avg_time = times[1:-1].mean()
81 # compute throughput too:
82 1 77271.0 77271.0 0.0 throughput = cfg["batch_size"] / avg_time
83 1 35361.0 35361.0 0.0 print(f"Average time per iteration: {avg_time:.3f} ({throughput:.3f} examples / s)")
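
The main difference between this profile and the previous one is the explicit `torch.cuda.synchronize()` (37.8% of the run) before each timer read. CUDA kernels launch asynchronously, so in the first profile the GPU work was silently absorbed by whichever later call happened to block, inflating the dataloader line to 44.5%. A minimal sketch of the timing pattern, assuming only stock PyTorch:

```python
import time
import torch

def timed_step(step_fn) -> float:
    """Time one training step; `step_fn` is assumed to launch CUDA work."""
    start = time.perf_counter()
    step_fn()
    # Block until every queued kernel finishes. Without this, the timer
    # mostly measures kernel-launch overhead, not the GPU work itself.
    torch.cuda.synchronize()
    return time.perf_counter() - start
```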

1 change: 0 additions & 1 deletion docs/tutorials/profiling.rst
@@ -247,7 +247,6 @@ Next, take a look at the first instrumented version of the model code,
compared to the original:

.. code-block:: diff

*** attn_baseline.py 2025-01-27 07:41:37.749753000 -0800
--- attn_instrumented.py 2025-01-27 11:27:09.162202000 -0800
***************
@@ -100,7 +100,7 @@
"if Path(\"ahmed_body_mgn.zip\").is_file():\n",
" pass\n",
"else:\n",
" !wget 'https://api.ngc.nvidia.com/v2/models/nvidia/modulus/modulus_ahmed_body_meshgraphnet/versions/v0.2/files/ahmed_body_mgn.zip'\n",
" !wget 'https://api.ngc.nvidia.com/v2/models/nvidia/physicsnemo/modulus_ahmed_body_meshgraphnet/versions/v0.2/files/ahmed_body_mgn.zip'\n",
" !unzip ahmed_body_mgn.zip\n",
" !mv ahmed_body_mgn/* .\n",
" !rm utils.py # TODO: hacky, remove the old utils.py"
20 changes: 12 additions & 8 deletions examples/cfd/external_aerodynamics/domino/src/conf/config.yaml
@@ -32,15 +32,19 @@ resume_dir: ${output}/models

data_processor: # Data processor configurable parameters
  kind: drivaer_aws # must be either drivesim or drivaer_aws
-  output_dir: /lustre/rranade/modulus_dev/data/aws_data_all/
+  output_dir: /user_data/datasets/domino_volume_data_zarr/
  input_dir: /lustre/datasets/drivaer_aws/drivaer_data_full/
  cached_dir: /lustre/cached/drivaer_aws/drivaer_data_full/
  use_cache: false
  num_processors: 12

data: # Input directory for training and validation data
-  input_dir: /lustre/rranade/modulus_dev/data/aws_data_all/
-  input_dir_val: /lustre/rranade/modulus_dev/data/aws_data_all_val/
+  input_dir: /user_data/datasets/benchmark_datasets/domino_volume_data/
+  input_dir_val: /user_data/datasets/domino_volume_data_val/
+  # input_dir: /user_data/datasets/benchmark_datasets/domino_volume_data/
+  # input_dir_val: //user_data/datasets/domino_volume_data_cleaned_val/
+  # input_dir: /user_data/datasets/domino_volume_data_cleaned/
+  # input_dir_val: //user_data/datasets/domino_volume_data_cleaned_val/
  bounding_box: # Bounding box dimensions for computational domain
    min: [-3.5, -2.25 , -0.32]
    max: [8.5 , 2.25 , 3.00]
@@ -64,13 +68,13 @@ variables:
model:
  model_type: combined # train which model? surface, volume, combined
  loss_function:
-    loss_type: "mse" # mse or rmse
+    loss_type: "rmse" # mse or rmse
    area_weighing_factor: 10000 # Generally inverse of maximum area
  interp_res: [128, 128, 128] # resolution of latent space 128, 64, 48
  use_sdf_in_basis_func: true # SDF in basis function network
  positional_encoding: false # calculate positional encoding?
-  volume_points_sample: 8192 # Number of points to sample in volume per epoch
-  surface_points_sample: 8192 # Number of points to sample on surface per epoch
+  volume_points_sample: 200000 # Number of points to sample in volume per epoch
+  surface_points_sample: 200000 # Number of points to sample on surface per epoch
  surface_sampling_algorithm: area_weighted # random or area_weighted
  geom_points_sample: 300_000 # Number of points to sample on STL per epoch
  surface_neighbors: true # Pre-compute surface neighborhood from input data
@@ -118,8 +122,8 @@ model:
  num_modes: 5

train: # Training configurable parameters
-  epochs: 1000
-  checkpoint_interval: 1
+  epochs: 1
+  checkpoint_interval: 100
  dataloader:
    batch_size: 1
    pin_memory: false # if the preprocessing is outputting GPU data, set this to false
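
For orientation, a minimal sketch of how a DoMINO-style entrypoint under `src/` might consume this config via Hydra; the decorator arguments and key accesses are assumptions based on the `src/conf/config.yaml` layout, not the repo's actual training script.

```python
import hydra
from omegaconf import DictConfig

@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    # Values touched by this diff, read back through the config tree:
    print(cfg.model.loss_function.loss_type)   # "rmse"
    print(cfg.model.volume_points_sample)      # 200000
    print(cfg.train.epochs)                    # 1
    print(cfg.train.checkpoint_interval)       # 100

if __name__ == "__main__":
    main()
```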