diff --git a/docs/test_scripts/profiling/annotated_code/modulus_profiling_outputs/line_profiler/profiler_stats.txt b/docs/test_scripts/profiling/annotated_code/modulus_profiling_outputs/line_profiler/profiler_stats.txt new file mode 100644 index 0000000000..78b118d996 --- /dev/null +++ b/docs/test_scripts/profiling/annotated_code/modulus_profiling_outputs/line_profiler/profiler_stats.txt @@ -0,0 +1,116 @@ +Timer unit: 1e-09 s + +Total time: 0.0635549 s +File: /root/physicsnemo/docs/test_scripts/profiling/annotated_code/attn.py +Function: forward at line 31 + +Line # Hits Time Per Hit % Time Line Contents +============================================================== + 31 @profile + 32 def forward(self, x: torch.Tensor) -> torch.Tensor: + 33 + 34 8 47352.0 5919.0 0.1 B, N, C = x.shape + 35 8 41506771.0 5e+06 65.3 qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + 36 8 250098.0 31262.2 0.4 q, k, v = qkv.unbind(0) + 37 + 38 + 39 # This is not optimal code right here ... + 40 8 4760805.0 595100.6 7.5 q = q * self.scale + 41 8 7717546.0 964693.2 12.1 attn = q @ k.transpose(-2, -1) + 42 8 6836559.0 854569.9 10.8 attn = attn.softmax(dim=-1) + 43 8 467249.0 58406.1 0.7 attn = self.attn_drop(attn) + 44 8 681692.0 85211.5 1.1 x = attn @ v + 45 + 46 8 273367.0 34170.9 0.4 x = x.transpose(1, 2).reshape(B, N, C) + 47 8 822387.0 102798.4 1.3 x = self.proj(x) + 48 8 182694.0 22836.8 0.3 x = self.proj_drop(x) + 49 8 8381.0 1047.6 0.0 return x + +Total time: 0.00505685 s +File: /root/physicsnemo/docs/test_scripts/profiling/annotated_code/attn.py +Function: forward at line 65 + +Line # Hits Time Per Hit % Time Line Contents +============================================================== + 65 @profile + 66 def forward(self, x): + 67 8 493290.0 61661.2 9.8 x = self.fc1(x) + 68 8 3648444.0 456055.5 72.1 x = self.gelu(x) + 69 8 155234.0 19404.2 3.1 x = self.drop1(x) + 70 8 437770.0 54721.2 8.7 x = self.fc2(x) + 71 8 197332.0 24666.5 3.9 x = self.gelu(x) + 72 8 120381.0 15047.6 2.4 x = self.drop2(x) + 73 8 4401.0 550.1 0.1 return x + +Total time: 0.0855953 s +File: /root/physicsnemo/docs/test_scripts/profiling/annotated_code/attn.py +Function: forward at line 104 + +Line # Hits Time Per Hit % Time Line Contents +============================================================== + 104 @profile + 105 def forward(self, x: torch.Tensor) -> torch.Tensor: + 106 8 79767938.0 1e+07 93.2 x = x + self.attn(self.norm1(x)) + 107 8 5821609.0 727701.1 6.8 x = x + self.mlp(self.norm2(x)) + 108 8 5719.0 714.9 0.0 return x + +Total time: 4.54253 s +File: /root/physicsnemo/docs/test_scripts/profiling/annotated_code/workload.py +Function: workload at line 30 + +Line # Hits Time Per Hit % Time Line Contents +============================================================== + 30 @profile + 31 def workload(cfg): + 32 + 33 1 304686.0 304686.0 0.0 ds = RandomNoiseDataset(cfg["shape"]) + 34 + 35 2 202324.0 101162.0 0.0 loader = DataLoader( + 36 1 170.0 170.0 0.0 ds, + 37 1 42921.0 42921.0 0.0 batch_size=cfg["batch_size"], + 38 1 180.0 180.0 0.0 shuffle = True, + 39 ) + 40 + 41 + 42 # Initialize the model: + 43 3 79343062.0 3e+07 1.7 model = Block( + 44 1 79202.0 79202.0 0.0 dim = cfg["shape"][-1], + 45 1 73421.0 73421.0 0.0 num_heads = cfg.model["num_heads"], + 46 1 52831.0 52831.0 0.0 qkv_bias = cfg.model["qkv_bias"] , + 47 1 50570.0 50570.0 0.0 attn_drop = cfg.model["attn_drop"], + 48 1 50352.0 50352.0 0.0 proj_drop = cfg.model["proj_drop"], + 49 1 109143037.0 1e+08 2.4 ).to("cuda") + 50 + 51 1 107952.0 107952.0 0.0 if cfg["train"]: + 52 1 2059751411.0 2e+09 45.3 opt = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9) + 53 + 54 1 270.0 270.0 0.0 times = [] + 55 2 74021.0 37010.5 0.0 with Profiler() as p: + 56 1 1521.0 1521.0 0.0 start = time.perf_counter() + 57 9 2022090217.0 2e+08 44.5 for i, batch in enumerate(loader): + 58 8 74692.0 9336.5 0.0 image = batch["image"] + 59 8 47176297.0 6e+06 1.0 image = image.to("cuda") + 60 16 364408.0 22775.5 0.0 with annotate(domain="forward", color="blue"): + 61 8 85942255.0 1e+07 1.9 output = model(image) + 62 8 871577.0 108947.1 0.0 if cfg["train"]: + 63 8 1728264.0 216033.0 0.0 opt.zero_grad() + 64 # Compute the loss: + 65 8 24881952.0 3e+06 0.5 loss = loss_fn(output) + 66 # Do the gradient calculation: + 67 16 181082.0 11317.6 0.0 with annotate(domain="backward", color="green"): + 68 8 50119067.0 6e+06 1.1 loss.backward() + 69 # Apply the gradients + 70 8 58985347.0 7e+06 1.3 opt.step() + 71 8 35302.0 4412.8 0.0 p.step() + 72 8 27261.0 3407.6 0.0 end = time.perf_counter() + 73 8 352396.0 44049.5 0.0 print(f"Finished step {i} in {end - start:.4f} seconds") + 74 8 4790.0 598.8 0.0 times.append(end - start) + 75 8 6301.0 787.6 0.0 start = time.perf_counter() + 76 + 77 1 84802.0 84802.0 0.0 times = torch.tensor(times) + 78 # Drop first and last: + 79 1 117812.0 117812.0 0.0 avg_time = times[1:-1].mean() + 80 # compute throughput too: + 81 1 152063.0 152063.0 0.0 throughput = cfg["batch_size"] / avg_time + 82 1 60321.0 60321.0 0.0 print(f"Average time per iteration: {avg_time:.3f} ({throughput:.3f} examples / s)") + diff --git a/docs/test_scripts/profiling/fixed_data_loader/modulus_profiling_outputs/line_profiler/profiler_stats.txt b/docs/test_scripts/profiling/fixed_data_loader/modulus_profiling_outputs/line_profiler/profiler_stats.txt new file mode 100644 index 0000000000..b7e7647f8b --- /dev/null +++ b/docs/test_scripts/profiling/fixed_data_loader/modulus_profiling_outputs/line_profiler/profiler_stats.txt @@ -0,0 +1,117 @@ +Timer unit: 1e-09 s + +Total time: 0.0712444 s +File: /root/physicsnemo/docs/test_scripts/profiling/fixed_data_loader/attn.py +Function: forward at line 31 + +Line # Hits Time Per Hit % Time Line Contents +============================================================== + 31 @profile + 32 def forward(self, x: torch.Tensor) -> torch.Tensor: + 33 + 34 32 54830.0 1713.4 0.1 B, N, C = x.shape + 35 32 41511385.0 1e+06 58.3 qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + 36 32 368675.0 11521.1 0.5 q, k, v = qkv.unbind(0) + 37 + 38 + 39 # This is not optimal code right here ... + 40 32 5547725.0 173366.4 7.8 q = q * self.scale + 41 32 9591905.0 299747.0 13.5 attn = q @ k.transpose(-2, -1) + 42 32 7572284.0 236633.9 10.6 attn = attn.softmax(dim=-1) + 43 32 915209.0 28600.3 1.3 attn = self.attn_drop(attn) + 44 32 2071393.0 64731.0 2.9 x = attn @ v + 45 + 46 32 981847.0 30682.7 1.4 x = x.transpose(1, 2).reshape(B, N, C) + 47 32 1999401.0 62481.3 2.8 x = self.proj(x) + 48 32 598419.0 18700.6 0.8 x = self.proj_drop(x) + 49 32 31321.0 978.8 0.0 return x + +Total time: 0.00933511 s +File: /root/physicsnemo/docs/test_scripts/profiling/fixed_data_loader/attn.py +Function: forward at line 65 + +Line # Hits Time Per Hit % Time Line Contents +============================================================== + 65 @profile + 66 def forward(self, x): + 67 32 1681602.0 52550.1 18.0 x = self.fc1(x) + 68 32 4308446.0 134638.9 46.2 x = self.gelu(x) + 69 32 525102.0 16409.4 5.6 x = self.drop1(x) + 70 32 1591600.0 49737.5 17.0 x = self.fc2(x) + 71 32 773671.0 24177.2 8.3 x = self.gelu(x) + 72 32 441258.0 13789.3 4.7 x = self.drop2(x) + 73 32 13428.0 419.6 0.1 return x + +Total time: 0.0998171 s +File: /root/physicsnemo/docs/test_scripts/profiling/fixed_data_loader/attn.py +Function: forward at line 104 + +Line # Hits Time Per Hit % Time Line Contents +============================================================== + 104 @profile + 105 def forward(self, x: torch.Tensor) -> torch.Tensor: + 106 32 88075019.0 3e+06 88.2 x = x + self.attn(self.norm1(x)) + 107 32 11725880.0 366433.8 11.7 x = x + self.mlp(self.norm2(x)) + 108 32 16161.0 505.0 0.0 return x + +Total time: 4.07266 s +File: /root/physicsnemo/docs/test_scripts/profiling/fixed_data_loader/workload.py +Function: workload at line 30 + +Line # Hits Time Per Hit % Time Line Contents +============================================================== + 30 @profile + 31 def workload(cfg): + 32 + 33 1 171664.0 171664.0 0.0 ds = RandomNoiseDataset(cfg["shape"]) + 34 + 35 2 204305.0 102152.5 0.0 loader = DataLoader( + 36 1 180.0 180.0 0.0 ds, + 37 1 34230.0 34230.0 0.0 batch_size=cfg["batch_size"], + 38 1 140.0 140.0 0.0 shuffle = True, + 39 ) + 40 + 41 + 42 # Initialize the model: + 43 3 79993605.0 3e+07 2.0 model = Block( + 44 1 78242.0 78242.0 0.0 dim = cfg["shape"][-1], + 45 1 80352.0 80352.0 0.0 num_heads = cfg.model["num_heads"], + 46 1 53571.0 53571.0 0.0 qkv_bias = cfg.model["qkv_bias"] , + 47 1 51821.0 51821.0 0.0 attn_drop = cfg.model["attn_drop"], + 48 1 50711.0 50711.0 0.0 proj_drop = cfg.model["proj_drop"], + 49 1 117420554.0 1e+08 2.9 ).to("cuda") + 50 + 51 1 108593.0 108593.0 0.0 if cfg["train"]: + 52 1 2054675844.0 2e+09 50.5 opt = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9) + 53 + 54 1 260.0 260.0 0.0 times = [] + 55 2 52532.0 26266.0 0.0 with Profiler() as p: + 56 1 1390.0 1390.0 0.0 start = time.perf_counter() + 57 33 14661233.0 444279.8 0.4 for i, batch in enumerate(loader): + 58 32 65280.0 2040.0 0.0 image = batch["image"] + 59 32 148192.0 4631.0 0.0 image = image.to("cuda") + 60 64 392538.0 6133.4 0.0 with annotate(domain="forward", color="blue"): + 61 32 100282907.0 3e+06 2.5 output = model(image) + 62 32 1679425.0 52482.0 0.0 if cfg["train"]: + 63 32 3523561.0 110111.3 0.1 opt.zero_grad() + 64 # Compute the loss: + 65 32 27139497.0 848109.3 0.7 loss = loss_fn(output) + 66 # Do the gradient calculation: + 67 64 424059.0 6625.9 0.0 with annotate(domain="backward", color="green"): + 68 32 46323272.0 1e+06 1.1 loss.backward() + 69 # Apply the gradients + 70 32 86108878.0 3e+06 2.1 opt.step() + 71 32 68970.0 2155.3 0.0 p.step() + 72 32 1538055836.0 5e+07 37.8 torch.cuda.synchronize() + 73 32 31759.0 992.5 0.0 end = time.perf_counter() + 74 32 491312.0 15353.5 0.0 print(f"Finished step {i} in {end - start:.4f} seconds") + 75 32 14763.0 461.3 0.0 times.append(end - start) + 76 32 17920.0 560.0 0.0 start = time.perf_counter() + 77 + 78 1 56391.0 56391.0 0.0 times = torch.tensor(times) + 79 # Drop first and last: + 80 1 84261.0 84261.0 0.0 avg_time = times[1:-1].mean() + 81 # compute throughput too: + 82 1 77271.0 77271.0 0.0 throughput = cfg["batch_size"] / avg_time + 83 1 35361.0 35361.0 0.0 print(f"Average time per iteration: {avg_time:.3f} ({throughput:.3f} examples / s)") + diff --git a/docs/tutorials/domain_parallelism/domain_parallelism.rst b/docs/tutorials/domain_parallelism/domain_parallelism.rst new file mode 100644 index 0000000000..b9320ee1d2 --- /dev/null +++ b/docs/tutorials/domain_parallelism/domain_parallelism.rst @@ -0,0 +1,236 @@ +# Domain Parallelism and Shard Tensor + +In scientific AI, one of most challenging aspects in training a model is dealing with extremely high resolution data. In this tutorial, we'll explore what makes high resolution data so challenging to handle, for both training and inference, and why that's different from the scaling challenges in other domains (like NLP, image processing, etc.). We'll also take a technical look at how we're working to streamline high-resolution model training in `PhysicsNeMo`, and how you can leverage our tools for your own scientific workloads as well. + +## What makes scientific AI challenging? + +To understand why scientific AI hits unique challenges in training and inference on high resolution data, let's take a look at the computational and memory cost of training models and subsequently running inference. "Cost" here refers to two fundamental, high level concepts: computational cost is how much computing power is needed to complete an operation (and is, in general, a complicated interplay of GPU FLOPs, memory bandwidth, cache sizes, algorithm efficiencies, and more); memory costs refer to the amount of GPU memory required to perform the computations. + +For all AI models, the memory cost of inference is dominated by just two categories of use: + +1. **Model parameters** (weights, biases, encodings, etc.) all are required to be loaded into GPU memory for fast access during inference. For a model with N total parameters, each parameter requires 4 bytes in float32 precision, or 2 in float16/bfloat16. A rough approximation is that a 100M parameter model requires 400MB of memory in float32 precision. For Large Language Models with billions of parameters, even at inference time this is a large amount of memory. +2. **Data and Activations** represent the memory required to actually compute the layers and outputs of the model. For inference, the available memory has to be enough to hold the input data, output data, and model parameters as well as temporariliy accomodate memory of intermediate activations. As one layer's output is consumed by the next layer, the total memory needed to store activations typically never exceeds the requirements of the most memory-intensive layer. + +For scientific AI with high resolution data, the memory cost at inference quickly becomes dominated not by the model parameters but by the data. + +During training, the high resolution of the data is even ore challenging. For each layer during training, pytorch will typically save the some version of the layer's input or output as the "intermediate activation" for that layer. In practice, this is a computational optimization to enable the backwards pass to compute and propagate gradients more efficiently. Each layer, however, requires extra memory storage during training that is proportional to the resolution of the input data. + +As a cumulative effect, as models continue to stack up layers and save intermediate activations, the activation-related memory required training a model grows with both the depth of the model and the resolution of the input data. In contrast to Large Language Models, where the memory usage during training is dominated by the parameters, gradients, and optimizer states, for high resolution scientific AI models with modest parameter counts the memory usage is dominated by data size. + +To address this challenge, in PhysicsNeMo we have developed a domain-parallelism framework specifically designed to parallelize the high compute and memory costs of training and inferencing models on high resolution data. Named `ShardTensor`, and built on top of PyTorch's `DTensor` framework, `ShardTensor` allows models to divide expensive operations across multiple GPUs - parallelizing both the compute required as well as the storage of the intermediate activations. + +The remainder of this tutorial will focus on the high level concepts of `ShardTensor` and domain parallelism, and [implementing new layers with `ShardTensor`](#implementing-new-layers) TODO will be covered in a separate tutorial. + +## Starting with an Example + +As a high level example, let's consider a simple 2D convolution operation. There have been many tutorials on the mathematics and efficient computation of convolutions; let's not focus on that here. Instead, consider if the input data to the convolution is spread across two GPUs, and we want to correctly compute the ouput of the convolution but without ever realizing the input data on a single GPU. + +Just applying the convolution to each half provides incorrect results. We can simulate this, actually, in pytorch on one device: + +```python +import torch + +full_image = torch.randn(1, 8, 1024, 1024) + +left_image = full_image[:,:,:512,:] +right_image = full_image[:,:,512:,:] + +convolution_operator = torch.nn.Conv2d(8, 8, 3, stride=1, padding=1) + +full_output = convolution_operator(full_image) + +left_output = convolution_operator(left_image) +right_output = convolution_operator(right_image) + +recombined_output = torch.cat([left_output, right_output], dim=2) + +# Do the shapes agree? +print(full_output.shape) +print(recombined_output.shape) +# (they should!) + +# Do the values agree? +torch.allclose(full_output, recombined_output) +# (the do not!) +``` + +To understand why they don't agree, we can look at the location of the disagreement: + +```python +diff = full_output - recombined_output +b_locs, c_locs, h_locs, w_locs = torch.where( torch.abs(diff) > 1e-6) +print(torch.unique(b_locs)) +print(torch.unique(c_locs)) +print(torch.unique(h_locs)) +print(torch.unique(w_locs)) +``` + +This will produce the following output: + +``` +tensor([0]) +tensor([0, 1, 2, 3, 4, 5, 6, 7]) +tensor([511, 512]) +tensor([ 0, 1, 2, ..., 1021, 1022, 1023]) +``` + +We see in particular that along the height dimension (dim=2), the output is incorrect only along the pixels 511 and 512 - right where we split the data! The problem is that the convolution operator is a local operation, but splitting the data prevents it from seeing the correct neighboring pixels right at the border. You could fix this directly: + +```python + +# Slice off the data needed on the other image (around the center of the original image) +missing_left_data = right_image[:,:,0:1,:] +missing_right_data = left_image[:,:,-1:,:] + +# Add it to the correct image +padded_left_image = torch.cat([left_image, missing_left_data], 2) +padded_right_image = torch.cat([missing_right_data, right_image], 2) + +# Recompute convolutions +right_output = convolution_operator(padded_right_image)[:,:,1:,:] +left_output = convolution_operator(padded_left_image)[:,:,:-1,:] +# ^ Need to drop the extra pixels in the output here + +recombined_output = torch.cat([left_output, right_output], dim=2) + +# Now, the output works correctly: +torch.allclose(recombined_output, full_output) +# True +``` + +In the example above, for a simple convolution, we saw that just splitting the data and applying the base operation didn't give the results we needed. In general, this is true of many operations we see in AI models: splitting the data across GPUs requires extra operations or communication, depending on the operation, to get everything right. We also haven't even mentioned the gradients yet - to call `backward()` through this split operation across devices also requires extra operations and communication. But, in order to get the memory and potential computational benefits of domain parallelism, it's necessary. + +## How does `ShardTensor` help? + +PyTorch's `DTensor` interface already has an interface for a distributed tensor mechanism, and it's great - great enough, in fact, that `ShardTensor` is built upon it. However, `DTensor` is built with a different paradigm of parallelism in mind, including model parallelisms from DeepSpeed and MegaTron [CITATION NEEDED]. It has several shortcomings: notably, it can not accomodate data that isn't distributed uniformly or according to `torch.chunk` syntax. For scientific data, such as mesh data, point clouds, or anything else irregular, this is a nearly-immediate dead end for deploying domain parallelism. Further, `DTensor`'s mechanism for implementing parallelism is largely restricted to lower level `torch` operations - great for broad support in PyTorch, but not as accesible for most developers. + +With `ShardTensor`, we extend the functionality of `DTensor` in the ways needed to make domain parallelism simpler and easier to apply. In practice, this looks like the following, if we reuse the convolution example from before: + +```python +import torch + +from torch.distributed.tensor import ( + Shard, + distribute_module, +) + + +from physicsnemo.distributed import ( + DistributedManager, + ShardTensor, + scatter_tensor, + register_custom_ops +) + +register_custom_ops() + +DistributedManager.initialize() +dm = DistributedManager() + +mesh = dm.initialize_mesh((-1,), ("domain_parallel",)) + +if dm.rank == 0: + original_tensor = torch.randn(1, 8, 1024, 1024, device=dm.device, requires_grad=True) +else: + original_tensor = None + +# This is now a tensor across all GPUs, spread on the "height" dimension == 2 +sharded_tensor = scatter_tensor(original_tensor, 0, mesh, (Shard(2),), requires_grad=True) + +conv = torch.nn.Conv2d(8, 8, 3, stride=1, padding=1).to(dm.device) + +if dm.rank == 0: + # Get the single-device output: + rank_0_output = conv(original_tensor) + +# We tell pytorch that the convolution will work on distributed tensors: +# And, over the same mesh! +conv = distribute_module(conv, mesh) + +# Now, we can do the distributed convolution: +sharded_output = conv(sharded_tensor) + +# We can now gather the output back to the original tensor: +full_output = sharded_output.full_tensor() # This triggers a collective allgather. + +if dm.rank == 0: + + # Check that the output is the same as the single-device output: + diff = full_output - rank_0_output + assert torch.allclose(full_output, rank_0_output) + print(f"Global operation matches local! ") + +# We can even do gradients: +if dm.rank == 0: + rank_0_output.mean().backward() + original_tensor_grad = original_tensor.grad.data.clone() + +# Distribute gradients: +full_output.mean().backward() + +distributed_grad = sharded_tensor.grad +# distributed grad itsself is a sharded tensor: +full_grad = distributed_grad.full_tensor() +if dm.rank == 0: + # Check that the gradient is correct: + assert torch.allclose(original_tensor_grad, full_grad) + print(f"Gradient check passed!") + +print(f"Distributed grad sharding and local shape: {distributed_grad._spec.placements}, {distributed_grad.to_local().shape}") + +``` + +If you run this (`torchrun --nproc-per-node 4 conv_example.py`), you'll see the checks on output and gradients both pass. Further, the last line will print: + +```text +Distributed grad sharding and local shape: (Shard(dim=2),), torch.Size([1, 8, 256, 1024]) +``` + +Note that when running this, there was no need to perform manual communication or padding, in either the forward or backward pass. And, though we used a convolution, the details of the operation didn't need to be explicitly specified. In this case, it just worked. + +## How does `ShardTensor` work? + +At a high level, `DTensor` from pytorch is a concept of a local chunk of a tensor (stored as a `torch.Tensor`), and a `DTensorSpec` object which combines a `DeviceMesh` object representing the group of GPUs the tensor is on, and a description of how that global tensor is distributed (or replicated). `ShardTensor` extends this API with an addition to the specification to track the shape of each local tensor along sharding axes. This becomes important when the input data is something like a point cloud, rather than an evenly-distributed tensor. + +At run time, when an operation in `torch` has `DTensor` as input, pytorch will use a custom dispatcher in `DTensor` to route perform operations correctly on the inputs. `ShardTensor` extends this by intercepting a little higher than `DTensor`: operations can be intercepted at the functional level, or at the dispatch level, and if `ShardTensor` has no registered implementation it will fall back to DTensor. + +ShardTensor also has custom implementations of common reduction operations `sum` and `mean`, in order to properly intercept and distribute gradients correctly. This is why, in the example above, you can seamlessly call `mean().backward()` on a `ShardTensor` and the gradients will arrive to their proper sharding. + +There is a substantial amount of care needed to implement layers in `ShardTensor` (or `DTensor`!). If you're interested in doing so for your custom model, please check out a full tutorial on this subject: [implementing-new-layers](implementing-new-layers) TODO. + +# When Should You Use `ShardTensor`? + +`ShardTensor` and domain parallelism solve a very specific problem in Scientific AI: input data is such high resolution that models can't train, even at Batch Size of 1, due to memory limitations. And while that challenge can be partially surmounted with reduced precision and input spatial downsampling, not all models can tolerate those techniques without sacrificing accuracy. In this case, you should view `ShardTensor` as a solution to that problem: it will enable you to run training and inference on higher resolution data than a single GPU can accomodate. It is not the only technique for this, and in some cases it isn't the best choice. In this section we'll compare and contrast `ShardTensor` to some other techinques for high resolution data, which can highlight some strengths and weaknesses of `ShardTensor.` + +One other technique for high resolution data is **Pipeline Parallelism**. In pipeline parallelism, the model is divided across 2 or more devices, and each device contains full layers and activations, but to run the entire model the data is "pipelined": input data on GPU 0 is propagated through the local layers, and the outputs of the last layer on GPU 0 become the inputs to the first layer on GPU 1, and so on. Gradients can be computed by running the pipeline in reverse, as well. + +For some use cases, pipeline parallelism can be very powerful. But it also has some weaknesses that `ShardTensor` can avoid. Pipeline parallelism enables scaling of GPU memory resources but does not take much advantage of scaling up GPU compute resources without modifying the training loop. While GPU 0 is active, all other GPUs are waiting on input. And once GPU 0 passes data to GPU 1, GPU 0 sits idly until the backward pass or the next batch of data arrives. For large minibatch data, a good strategy could be to feed each batch of data sequentially: when data passes from GPU 0 to GPU 1, the next example can start processing on GPU 0. For inference on large datasets, this is quite efficient, but may will cause a computational "bubble" or stall everytime gradients are computed and the model is updated. + +With just one or several points in the model where pipeline parallelism divides your model, it is conceptually simple and each GPU has minimal communication overhead. However, not all models are well supported with pipeline parallelism (consider a UNet architecture). On the other hand, `ShardTensor` enables you to slice your model by dividing each and every layer over sharded inputs. In terms of model support, this makes more complicated architectures like UNet simple: the concatenation of features across the down/up sampling paths is unmodified in user space (and in fact it's pretty simple in low-level implementations too: it becomes a concat of the local tensor objects). On the other hand, because each layer introduces additional overhead of communication or coordination, a sharded layer can be less efficient than a purely-local layer. + +As a general rule, `ShardTensor` performs efficiently when the input data is large, and when the ratio of communication time to computation time is small. For some operations, like sequence-parallel attention via a Ring Mechanism [Ring Attention] TODO the benefits become clear, as shown below: the sharded model is faster after a certain input data size. More importantly, the sharded model is still **functional** after a massive input size: something pipeline parallelism could not acheive for a simple one-layer model. + +TODO - add plot of attention efficiency. + +Of course, a one-layer model isn't a good representation of actual user code. Instead, use this as a guiding principle: when the GPU kernels are long because the input data is large, `ShardTensor` will scale very efficiently. When GPU kernels are small, and a model launches many small kernels, `ShardTensor` will be functional but not as efficient. In these cases you may have slightly better scaling with pipeline or other parallelism. Note, however, that `ShardTensor` is still in development and performance optimizations for small kernels are ongoing. + +Another technique for dealing with high resolution input data during training is activation checkpointing. In this technique, during the forward pass, activations are moved from GPU memory to CPU memory to make more space available. They are restored during the backward pass when needed, and the rest of the backward pass continues. Compared to pipeline parallelism, this technique can better leverage parallelization across GPUs with standard Data-Parallel scaling. However, it can be limited by GPU/CPU transfer speeds and possible blocking operations. On NVIDIA GPUs with NCCL enabled, the peer-to-peer bandwidth can be significantly higher than CPU-GPU bandwidth (though not all - GraceHopper systems, for example, can efficiently and effectively take advantage of CPU memory offloading). Unlike `ShardTensor`, the offloading of activations may need hand tuning and optimization based on GPU system architecture. `ShardTensor` is designed to work with your model `as-is` to the greatest possible extent. + +In general, if your model meets all of these conditions, you should consider using `ShardTensor` for domain parallelism during training: + +- Your model has relatively large input size even at batch size of 1 - so large, in fact, that you run out of GPU memory trying to train the model. with batch size 1. + - If your model comfortably fits batch_size=1 training, you will have a simpler and more efficient training using PyTorch's DistributedDataParallel (link, TODO) +- Your model is composed of supported domain-parallel layers (convolutions, normalizations, upsampling/pooling/reductions, attention layers, etc.) + - Not every layer has a domain-parallel implementation in PhysicsNeMo. You can add it to your code yourself if it's simple (consider a P.R. if you do!) or ask for support on github. + - How do you know if a layer is supported? Pass a `ShardTensor` in like above and test it! +- You have multiple GPUs available (ideally connected with high-performance peer to peer path such as NCCL). + +For the best efficiency training with `ShardTensor`, look for: +- Your model is mostly composed of large, compute- or bandwidth-bound kernels rather than very small, low-latency kernels. +- Your model is composed of mostly non-blocking CUDA kernels, allowing the slightly higher overhead of domain parallelism to still fill the GPU queue efficiently. + +For inference, on the other hand, `ShardTensor` can still be useful for lower latency inference on extremely high resolution data. Especially if the model is primarly composed of compute- or bandwidth-bound kernels, and the commmunication overhead is small, `ShardTensor` can provide reductions of inference latency. + +# Summary + +In this tutorial, we saw details about PhysicsNeMo's `ShardTensor` object, and how it can be used to enable domain parallelism. For more behind-the-scenes details of how layers are enabled, see (implementing-new-layers)[implementing-new-layers] #TODO. For an example of combining domain parallelism with other parallelisms through FSDP, see [fsdp_and_shard_tensor](fsdp_and_shard_tensor.rst) TODO-fixlink. \ No newline at end of file diff --git a/docs/tutorials/profiling.rst b/docs/tutorials/profiling.rst index f380181cbd..8dbd1b5cb3 100644 --- a/docs/tutorials/profiling.rst +++ b/docs/tutorials/profiling.rst @@ -247,7 +247,6 @@ Next, take a look at the first instrumented version of the model code, compared to the original: .. code-block:: diff - *** attn_baseline.py 2025-01-27 07:41:37.749753000 -0800 --- attn_instrumented.py 2025-01-27 11:27:09.162202000 -0800 *************** diff --git a/examples/cfd/external_aerodynamics/aero_graph_net/inference_analysis/ahmed_body.ipynb b/examples/cfd/external_aerodynamics/aero_graph_net/inference_analysis/ahmed_body.ipynb index 567afc6d5f..e1e301b25e 100644 --- a/examples/cfd/external_aerodynamics/aero_graph_net/inference_analysis/ahmed_body.ipynb +++ b/examples/cfd/external_aerodynamics/aero_graph_net/inference_analysis/ahmed_body.ipynb @@ -100,7 +100,7 @@ "if Path(\"ahmed_body_mgn.zip\").is_file():\n", " pass\n", "else:\n", - " !wget 'https://api.ngc.nvidia.com/v2/models/nvidia/modulus/modulus_ahmed_body_meshgraphnet/versions/v0.2/files/ahmed_body_mgn.zip'\n", + " !wget 'https://api.ngc.nvidia.com/v2/models/nvidia/physicsnemo/modulus_ahmed_body_meshgraphnet/versions/v0.2/files/ahmed_body_mgn.zip'\n", " !unzip ahmed_body_mgn.zip\n", " !mv ahmed_body_mgn/* .\n", " !rm utils.py # TODO: hacky, remove the old utils.py" diff --git a/examples/cfd/external_aerodynamics/domino/README.md b/examples/cfd/external_aerodynamics/domino/README.md index ca97048f50..92df15836d 100644 --- a/examples/cfd/external_aerodynamics/domino/README.md +++ b/examples/cfd/external_aerodynamics/domino/README.md @@ -59,6 +59,50 @@ To train and test the DoMINO model on AWS dataset, follow these steps: 6. Download the validation results (saved in form of point clouds in `.vtp` / `.vtu` format), and visualize in Paraview. +### Training with Domain Parallelism + +DoMINO has support for training and inference using domain parallelism in physicsnemo, +via the `ShardTensor` mechanisms and pytorch's FSDP tools. `ShardTensor`, built on +PyTorch's `DTensor` object, is a domain-parallel-aware tensor that can live on multiple +GPUs and perform operations in a numerically consistent way. For more information +about the techniques of domain parallelism and `ShardTensor`, refer to physicsnemo +tutorials such as [`ShardTensor`](shard_tensor_tutorial.html). + +In DoMINO specifically, domain parallelism has been abled in two ways, which +can be used concurrently or separately. First, the input sampled volumetric +and surface points can be sharded to accomodate higher resolution point sampling +Second, the latent space of the model - typically a regularlized grid - can be +sharded to reduce computational complexity of the latent processing. When training +with sharded models in DoMINO, the primary objective is to enable higher +resolution inputs and larger latent spaces without sacrificing substantial compute time. + +When configuring DoMINO for sharded training, adjust the following parameters +from `src/conf/config.yaml`: + +```yaml +domain_parallelism: + domain_size: 2 + shard_grid: True + shard_points: True +``` + +The domain_size represents the number of GPUs used for each batch - setting +`domain_size: 1` is not advised since that is the standard training regime, +but with extra overhead. `shard_grid` and `shard_points` will enable domain +parallelism over the latent space and input/output points, respectively. + +Please see `src/train_sharded.py` for more details regarding the changes +from the standard training script required for domain parallel DoMINO training. + +As one last note regarding domain-parallel training: in the phase of the DoMINO +where the output solutions are calculated, the model can used two different +techniques (numerically identical) to calculate the output. Due to the +overhead of potential communication at each operation, it's recommended to +use the `one-loop` mode with `model.solution_calculation_mode` when doing +sharded training. This technique launches vectorized kernels with less +launch overhead at the cost of slightly more memory use. For non-sharded +gtraining, the `two-loop` setting is more optimal. + ## Retraining recipe for DoMINO model To enable retraining the DoMINO model from a pre-trained checkpoint, follow the steps: diff --git a/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml b/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml index 1aba6ed4f4..ecb3ec78e8 100644 --- a/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml +++ b/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml @@ -32,15 +32,19 @@ resume_dir: ${output}/models data_processor: # Data processor configurable parameters kind: drivaer_aws # must be either drivesim or drivaer_aws - output_dir: /lustre/rranade/modulus_dev/data/aws_data_all/ + output_dir: /user_data/datasets/domino_volume_data_zarr/ input_dir: /lustre/datasets/drivaer_aws/drivaer_data_full/ cached_dir: /lustre/cached/drivaer_aws/drivaer_data_full/ use_cache: false num_processors: 12 data: # Input directory for training and validation data - input_dir: /lustre/rranade/modulus_dev/data/aws_data_all/ - input_dir_val: /lustre/rranade/modulus_dev/data/aws_data_all_val/ + input_dir: /user_data/datasets/benchmark_datasets/domino_volume_data/ + input_dir_val: /user_data/datasets/domino_volume_data_val/ + # input_dir: /user_data/datasets/benchmark_datasets/domino_volume_data/ + # input_dir_val: //user_data/datasets/domino_volume_data_cleaned_val/ + # input_dir: /user_data/datasets/domino_volume_data_cleaned/ + # input_dir_val: //user_data/datasets/domino_volume_data_cleaned_val/ bounding_box: # Bounding box dimensions for computational domain min: [-3.5, -2.25 , -0.32] max: [8.5 , 2.25 , 3.00] @@ -48,6 +52,11 @@ data: # Input directory for training and validation data min: [-1.1, -1.2 , -0.32] max: [4.5 , 1.2 , 1.2] +domain_parallelism: + domain_size: 2 + shard_grid: True + shard_points: True + variables: surface: solution: @@ -64,9 +73,9 @@ variables: model: model_type: combined # train which model? surface, volume, combined loss_function: - loss_type: "mse" # mse or rmse + loss_type: "rmse" # mse or rmse area_weighing_factor: 10000 # Generally inverse of maximum area - interp_res: [128, 128, 128] # resolution of latent space 128, 64, 48 + interp_res: [192, 192, 192] # resolution of latent space 128, 64, 48 use_sdf_in_basis_func: true # SDF in basis function network positional_encoding: false # calculate positional encoding? volume_points_sample: 8192 # Number of points to sample in volume per epoch @@ -83,6 +92,7 @@ model: surf_loss_scaling: 5.0 # scale surface loss with this factor in combined mode vol_loss_scaling: 1.0 # scale volume loss with this factor in combined mode geometry_encoding_type: both # geometry encoder type, sdf, stl, both + solution_calculation_mode: two-loop # one-loop is better for sharded, two-loop is lower memory but more overhead resampling_surface_mesh: # resampling of surface mesh before constructing kd tree resample: false #false or true points: 1_000_000 # number of points @@ -118,8 +128,8 @@ model: num_modes: 5 train: # Training configurable parameters - epochs: 1000 - checkpoint_interval: 1 + epochs: 2 + checkpoint_interval: 100 dataloader: batch_size: 1 pin_memory: false # if the preprocessing is outputing GPU data, set this to false diff --git a/examples/cfd/external_aerodynamics/domino/src/train.py b/examples/cfd/external_aerodynamics/domino/src/train.py index 1c57058923..d4c98420d7 100644 --- a/examples/cfd/external_aerodynamics/domino/src/train.py +++ b/examples/cfd/external_aerodynamics/domino/src/train.py @@ -33,9 +33,8 @@ import torch import torchinfo -from typing import Literal +from typing import Literal, Dict -import apex import numpy as np import hydra from hydra.utils import to_absolute_path @@ -46,7 +45,7 @@ from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler from torch.utils.tensorboard import SummaryWriter -from nvtx import annotate as nvtx_annotate + import torch.cuda.nvtx as nvtx from physicsnemo.distributed import DistributedManager @@ -57,7 +56,6 @@ DoMINODataPipe, compute_scaling_factors, create_domino_dataset, - # domino_collate_fn, ) from physicsnemo.models.domino.model import DoMINO from physicsnemo.utils.domino.utils import * @@ -70,12 +68,6 @@ nvmlInit() -from physicsnemo.utils.profiling import profile, Profiler - -# Profiler().enable("line_profiler") -# Profiler().initialize() - - def loss_fn( output: torch.Tensor, target: torch.Tensor, @@ -218,8 +210,13 @@ def lift_loss_fn(output, target, area, normals, stream_velocity=None, padded_val output_true = target * mask * area * (vel_inlet) ** 2.0 output_pred = output * mask * area * (vel_inlet) ** 2.0 - pres_true = output_true[:, :, 0] * normals[:, :, 2] - pres_pred = output_pred[:, :, 0] * normals[:, :, 2] + normals = torch.select(normals, 2, 2) + # output_true_0 = output_true[:, :, 0] + output_true_0 = output_true.select(2, 0) + output_pred_0 = output_pred.select(2, 0) + + pres_true = output_true_0 * normals + pres_pred = output_pred_0 * normals wz_true = output_true[:, :, -1] wz_pred = output_pred[:, :, -1] @@ -252,6 +249,77 @@ def drag_loss_fn(output, target, area, normals, stream_velocity=None, padded_val return loss +def compute_loss_dict( + prediction_vol: torch.Tensor, + prediction_surf: torch.Tensor, + batch_inputs: Dict, + loss_fn_type: Dict, + integral_scaling_factor: float, + surf_loss_scaling: float, + vol_loss_scaling: float, +): + nvtx.range_push("Loss Calculation") + total_loss_terms = [] + loss_dict = {} + + if prediction_vol is not None: + target_vol = batch_inputs["volume_fields"] + + loss_vol = loss_fn( + prediction_vol, target_vol, loss_fn_type.loss_type, padded_value=-10 + ) + loss_dict["loss_vol"] = loss_vol + total_loss_terms.append(loss_vol) + + if prediction_surf is not None: + + target_surf = batch_inputs["surface_fields"] + surface_areas = batch_inputs["surface_areas"] + surface_areas = torch.unsqueeze(surface_areas, -1) + surface_normals = batch_inputs["surface_normals"] + stream_velocity = batch_inputs["stream_velocity"] + loss_surf = loss_fn_surface( + prediction_surf, + target_surf, + loss_fn_type.loss_type, + ) + loss_surf_area = loss_fn_area( + prediction_surf, + target_surf, + surface_normals, + surface_areas, + area_scaling_factor=loss_fn_type.area_weighing_factor, + loss_type=loss_fn_type.loss_type, + ) + + if loss_fn_type.loss_type == "mse": + loss_surf = loss_surf * surf_loss_scaling + loss_surf_area = loss_surf_area * surf_loss_scaling + + total_loss_terms.append(0.5 * loss_surf) + loss_dict["loss_surf"] = 0.5 * loss_surf + total_loss_terms.append(0.5 * loss_surf_area) + loss_dict["loss_surf_area"] = 0.5 * loss_surf_area + loss_integral = ( + integral_loss_fn( + prediction_surf, + target_surf, + surface_areas, + surface_normals, + stream_velocity, + padded_value=-10, + ) + ) * integral_scaling_factor + loss_dict["loss_integral"] = loss_integral + total_loss_terms.append(loss_integral) + + total_loss = sum(total_loss_terms) + loss_dict["total_loss"] = total_loss + nvtx.range_pop() + + return total_loss, loss_dict + + def validation_step( dataloader, model, @@ -272,62 +340,17 @@ def validation_step( with autocast(enabled=True): prediction_vol, prediction_surf = model(sampled_batched) - total_loss_terms = [] - if prediction_vol is not None: - target_vol = sampled_batched["volume_fields"] - - alternate_loss_vol = loss_fn( - prediction_vol, - target_vol, - loss_fn_type.loss_type, - padded_value=-10, - ) - total_loss_terms.append(alternate_loss_vol) - - if prediction_surf is not None: - target_surf = sampled_batched["surface_fields"] - surface_normals = sampled_batched["surface_normals"] - surface_areas = sampled_batched["surface_areas"] - stream_velocity = sampled_batched["stream_velocity"] - surface_areas = torch.unsqueeze(surface_areas, -1) - - loss_integral = ( - integral_loss_fn( - prediction_surf, - target_surf, - surface_areas, - surface_normals, - stream_velocity, - padded_value=-10, - ) - ) * integral_scaling_factor # * 0.0 - - alternate_loss_surf = loss_fn_surface( - prediction_surf, - target_surf, - loss_fn_type.loss_type, - ) - alternate_loss_surf_area = loss_fn_area( - prediction_surf, - target_surf, - surface_normals, - surface_areas, - area_scaling_factor=loss_fn_type.area_weighing_factor, - loss_type=loss_fn_type.loss_type, - ) - if loss_fn_type.loss_type == "mse": - alternate_loss_surf = alternate_loss_surf * surf_loss_scaling - alternate_loss_surf_area = ( - alternate_loss_surf_area * surf_loss_scaling - ) - - total_loss_terms.append(0.5 * alternate_loss_surf) - total_loss_terms.append(0.5 * alternate_loss_surf_area) - total_loss_terms.append(loss_integral) - - total_loss = sum(total_loss_terms) - - running_vloss += total_loss.item() + loss, loss_dict = compute_loss_dict( + prediction_vol, + prediction_surf, + sampled_batched, + loss_fn_type, + integral_scaling_factor, + surf_loss_scaling, + vol_loss_scaling, + ) + + running_vloss += loss.item() avg_vloss = running_vloss / (i_batch + 1) @@ -358,6 +381,7 @@ def train_epoch( loss_interval = 1 gpu_start_info = nvmlDeviceGetMemoryInfo(gpu_handle) + start_time = time.perf_counter() for i_batch, sample_batched in enumerate(dataloader): sampled_batched = dict_to_device(sample_batched, device) @@ -365,63 +389,17 @@ def train_epoch( with autocast(enabled=True): with nvtx.range("Model Forward Pass"): prediction_vol, prediction_surf = model(sampled_batched) - total_loss_terms = [] - nvtx.range_push("Loss Calculation") - if prediction_vol is not None: - target_vol = sampled_batched["volume_fields"] - - alternate_loss_vol = loss_fn( - prediction_vol, target_vol, loss_fn_type.loss_type, padded_value=-10 - ) - total_loss_terms.append(alternate_loss_vol) - - if prediction_surf is not None: - target_surf = sampled_batched["surface_fields"] - surface_areas = sampled_batched["surface_areas"] - surface_areas = torch.unsqueeze(surface_areas, -1) - surface_normals = sampled_batched["surface_normals"] - stream_velocity = sampled_batched["stream_velocity"] - alternate_loss_surf = loss_fn_surface( - prediction_surf, - target_surf, - loss_fn_type.loss_type, - ) - alternate_loss_surf_area = loss_fn_area( - prediction_surf, - target_surf, - surface_normals, - surface_areas, - area_scaling_factor=loss_fn_type.area_weighing_factor, - loss_type=loss_fn_type.loss_type, - ) + loss, loss_dict = compute_loss_dict( + prediction_vol, + prediction_surf, + sampled_batched, + loss_fn_type, + integral_scaling_factor, + surf_loss_scaling, + vol_loss_scaling, + ) - if loss_fn_type.loss_type == "mse": - alternate_loss_surf = alternate_loss_surf * surf_loss_scaling - alternate_loss_surf_area = ( - alternate_loss_surf_area * surf_loss_scaling - ) - - total_loss_terms.append(0.5 * alternate_loss_surf) - total_loss_terms.append(0.5 * alternate_loss_surf_area) - loss_integral = ( - integral_loss_fn( - prediction_surf, - target_surf, - surface_areas, - surface_normals, - stream_velocity, - padded_value=-10, - ) - ) * integral_scaling_factor # * 0.0 - total_loss_terms.append(loss_integral) - - total_loss = sum(total_loss_terms) - - nvtx.range_pop() - - # loss = loss_norm - loss = total_loss loss = loss / loss_interval scaler.scale(loss).backward() @@ -429,25 +407,32 @@ def train_epoch( scaler.step(optimizer) scaler.update() optimizer.zero_grad() + # Gather data and report running_loss += loss.item() - + elapsed_time = time.perf_counter() - start_time + start_time = time.perf_counter() gpu_end_info = nvmlDeviceGetMemoryInfo(gpu_handle) gpu_memory_used = gpu_end_info.used / (1024**3) gpu_memory_delta = (gpu_end_info.used - gpu_start_info.used) / (1024**3) logging_string = f"Device {device}, batch processed: {i_batch + 1}\n" - logging_string += f" total loss: {total_loss.item():.5f}\n" - if prediction_vol is not None: - logging_string += f" loss volume: {alternate_loss_vol.item():.5f}\n" - if prediction_surf is not None: - logging_string += f" loss surface: {alternate_loss_surf.item():.5f}\n" - logging_string += ( - f" loss surface area: {alternate_loss_surf_area.item():.5f}\n" - ) - logging_string += f" loss integral: {loss_integral.item():.5f}\n" - logging_string += f" GPU memory used: {gpu_memory_used} Gb\n" - logging_string += f" GPU memory delta: {gpu_memory_delta} Gb\n" + # Format the loss dict into a string: + loss_string = ( + " " + + "\t".join([f"{key.replace('loss_',''):<10}" for key in loss_dict.keys()]) + + "\n" + ) + loss_string += ( + " " + f"\t".join([f"{l.item():<10.2f}" for l in loss_dict.values()]) + "\n" + ) + + logging_string += loss_string + # for key, value in loss_dict.items(): + # logging_string += f" {key}: {value.item():.5f}\n" + logging_string += f" GPU memory used: {gpu_memory_used:.3f} Gb\n" + logging_string += f" GPU memory delta: {gpu_memory_delta:.3f} Gb\n" + logging_string += f" Time taken: {elapsed_time:.2f} seconds\n" logger.info(logging_string) gpu_start_info = nvmlDeviceGetMemoryInfo(gpu_handle) @@ -668,7 +653,7 @@ def main(cfg: DictConfig) -> None: ) epoch_end_time = time.perf_counter() logger.info( - f"Device {dist.device}, Epoch {epoch_number} took {epoch_end_time - epoch_start_time} seconds" + f"Device {dist.device}, Epoch {epoch_number} took {epoch_end_time - epoch_start_time:.3f} seconds" ) epoch_end_time = time.perf_counter() @@ -691,7 +676,7 @@ def main(cfg: DictConfig) -> None: f"Device {dist.device} " f"LOSS train {avg_loss:.5f} " f"valid {avg_vloss:.5f} " - f"Current lr {scheduler.get_last_lr()[0]}" + f"Current lr {scheduler.get_last_lr()[0]} " f"Integral factor {initial_integral_factor}" ) @@ -710,7 +695,8 @@ def main(cfg: DictConfig) -> None: if avg_vloss < best_vloss: # This only considers GPU: 0, is that okay? best_vloss = avg_vloss - print(f"Device { dist.device}, Best val loss {best_vloss}") + if dist.rank == 0: + print(f"Device { dist.device}, Best val loss {best_vloss}") if dist.rank == 0 and (epoch + 1) % cfg.train.checkpoint_interval == 0.0: save_checkpoint( diff --git a/examples/cfd/external_aerodynamics/domino/src/train_sharded.py b/examples/cfd/external_aerodynamics/domino/src/train_sharded.py new file mode 100644 index 0000000000..90423f71e1 --- /dev/null +++ b/examples/cfd/external_aerodynamics/domino/src/train_sharded.py @@ -0,0 +1,590 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code defines a distributed pipeline for training the DoMINO model on +CFD datasets. It includes the computation of scaling factors, instantiating +the DoMINO model and datapipe, automatically loading the most recent checkpoint, +training the model in parallel using DistributedDataParallel across multiple +GPUs, calculating the loss and updating model parameters using mixed precision. +This is a common recipe that enables training of combined models for surface and +volume as well either of them separately. Validation is also conducted every epoch, +where predictions are compared against ground truth values. The code logs training +and validation metrics to TensorBoard. The train tab in config.yaml can be used to +specify batch size, number of epochs and other training parameters. +""" + +import time +import os +import re +import torch +import torchinfo + +import apex +import numpy as np +import hydra +from hydra.utils import to_absolute_path +from omegaconf import DictConfig, OmegaConf + +from physicsnemo.distributed import register_custom_ops +from physicsnemo.distributed import ShardTensor + +register_custom_ops() + +from torch.cuda.amp import GradScaler, autocast + +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + ShardingStrategy, +) + +from contextlib import nullcontext + +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.distributed.tensor import distribute_module +from torch.utils.tensorboard import SummaryWriter +from nvtx import annotate as nvtx_annotate +import torch.cuda.nvtx as nvtx + +from physicsnemo.distributed import DistributedManager +from physicsnemo.launch.utils import load_checkpoint, save_checkpoint +from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper + +from physicsnemo.datapipes.cae.domino_datapipe import ( + compute_scaling_factors, + create_domino_dataset, +) +from physicsnemo.datapipes.cae.domino_sharded_datapipe import ( + create_sharded_domino_dataset, +) + +from physicsnemo.models.domino.model import DoMINO +from physicsnemo.utils.domino.utils import * + +# Bring these from the single-gpu script. +from train import ( + compute_loss_dict, +) + +# This is included for GPU memory tracking: +from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo +import time + +# Initialize NVML +nvmlInit() + + +from physicsnemo.utils.profiling import profile, Profiler + + +def validation_step( + dataloader, + model, + device, + use_sdf_basis=False, + use_surface_normals=False, + integral_scaling_factor=1.0, + loss_fn_type=None, + vol_loss_scaling=None, + surf_loss_scaling=None, +): + running_vloss = 0.0 + with torch.no_grad(): + for i_batch, sampled_batched in enumerate(dataloader): + # sampled_batched = dict_to_device(sample_batched, device) + + with autocast(enabled=True): + + prediction_vol, prediction_surf = model(sampled_batched) + loss, loss_dict = compute_loss_dict( + prediction_vol, + prediction_surf, + sampled_batched, + loss_fn_type, + integral_scaling_factor, + surf_loss_scaling, + vol_loss_scaling, + ) + running_vloss += loss.full_tensor() + + avg_vloss = running_vloss / (i_batch + 1) + + return avg_vloss.item() + + +@profile +def train_epoch( + dataloader, + model, + optimizer, + scaler, + tb_writer, + logger, + gpu_handles, + epoch_index, + device, + integral_scaling_factor, + loss_fn_type, + vol_loss_scaling=None, + surf_loss_scaling=None, +): + + dist = DistributedManager() + + running_loss = 0.0 + last_loss = 0.0 + loss_interval = 1 + + gpu_start_info = [nvmlDeviceGetMemoryInfo(gpu_handle) for gpu_handle in gpu_handles] + start_time = time.perf_counter() + for i_batch, sample_batched in enumerate(dataloader): + + sampled_batched = sample_batched + + with autocast(enabled=True): + with nvtx.range("Model Forward Pass"): + prediction_vol, prediction_surf = model(sampled_batched) + + nvtx.range_push("Loss Calculation") + # The loss calculation is the same as singel GPU + loss, loss_dict = compute_loss_dict( + prediction_vol, + prediction_surf, + sampled_batched, + loss_fn_type, + integral_scaling_factor, + surf_loss_scaling, + vol_loss_scaling, + ) + + loss = loss / loss_interval + scaler.scale(loss).backward() + + if ((i_batch + 1) % loss_interval == 0) or (i_batch + 1 == len(dataloader)): + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + # Gather data and report + running_loss += loss.full_tensor().item() + + gpu_end_info = [ + nvmlDeviceGetMemoryInfo(gpu_handle) for gpu_handle in gpu_handles + ] + gpu_memory_used = [ + gpu_end_info.used / (1024**3) for gpu_end_info in gpu_end_info + ] + gpu_memory_delta = [ + (gpu_end_info.used - gpu_start_info.used) / (1024**3) + for gpu_end_info, gpu_start_info in zip(gpu_end_info, gpu_start_info) + ] + elapsed_time = time.perf_counter() - start_time + start_time = time.perf_counter() + logging_string = f"Device {device}, batch processed: {i_batch + 1}\n" + + # Format the loss dict into a string (use full_tensor to reduce across the domain.): + # **** Note **** + # We have to use full_tensor to reduce across the domain. + # You could use `.to_local()` to use just the local gpus version. + # the full_tensor() reduction is only over the mesh domain.` + loss_string = ( + " " + + "\t".join([f"{key.replace('loss_',''):<10}" for key in loss_dict.keys()]) + + "\n" + ) + loss_string += ( + " " + + f"\t".join( + [f"{l.full_tensor().item():<10.2f}" for l in loss_dict.values()] + ) + + "\n" + ) + logging_string += loss_string + + mem_used_str = " ".join( + [f"{gpu_memory_used[i]:.2f}" for i in range(len(gpu_memory_used))] + ) + mem_delta_str = " ".join( + [f"{gpu_memory_delta[i]:.2f}" for i in range(len(gpu_memory_delta))] + ) + logging_string += f" GPU memory used: {mem_used_str} Gb\n" + logging_string += f" GPU memory delta: {mem_delta_str} Gb\n" + logging_string += f" Elapsed time: {elapsed_time:.2f} seconds\n" + logger.info(logging_string) + gpu_start_info = [ + nvmlDeviceGetMemoryInfo(gpu_handle) for gpu_handle in gpu_handles + ] + + last_loss = running_loss / (i_batch + 1) # loss per batch + if dist.rank == 0: + logger.info( + f" Device {device}, batch: {i_batch + 1}, loss norm: {loss.full_tensor().item():.5f}" + ) + tb_x = epoch_index * len(dataloader) + i_batch + 1 + tb_writer.add_scalar("Loss/train", last_loss, tb_x) + + return last_loss + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + + # initialize distributed manager + DistributedManager.initialize() + dist = DistributedManager() + + # Use this to monitor GPU memory usage for visible GPUs: + gpu_count = torch.cuda.device_count() + gpu_handles = [nvmlDeviceGetHandleByIndex(i) for i in range(gpu_count)] + + ################################# + # Mesh Creation + # For Sharded training, we utilize pytorch's device mesh. + # The distributed manager can create it for us. We'll use a mesh + # with two devices and the rest of the GPUs are the data-parallel + # dimension. + ################################# + + # The global mesh represents all the GPUs in the process, in a multi-dimensional grid. + # Think of the global mesh as a tensor, with rank = len(mesh_shape) + domain_size = int(cfg.domain_parallelism.domain_size) + # You can use -1 to one axis to indicate that you want to use all the GPUs in that dimension. + mesh = dist.initialize_mesh( + mesh_shape=(-1, domain_size), mesh_dim_names=("ddp", "domain") + ) + # This is a subset of all the GPUs, and will vary depending on the process. + # Think of this as slicing the global mesh along the domain axis. + # It will contain only the GPUs that this process is sharing data with. + domain_mesh = mesh["domain"] + + compute_scaling_factors( + cfg, cfg.data_processor.output_dir, use_cache=cfg.data_processor.use_cache + ) + model_type = cfg.model.model_type + + logger = PythonLogger("Train") + logger = RankZeroLoggingWrapper(logger, dist) + + logger.info(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") + + num_vol_vars = 0 + volume_variable_names = [] + if model_type == "volume" or model_type == "combined": + volume_variable_names = list(cfg.variables.volume.solution.keys()) + for j in volume_variable_names: + if cfg.variables.volume.solution[j] == "vector": + num_vol_vars += 3 + else: + num_vol_vars += 1 + else: + num_vol_vars = None + + num_surf_vars = 0 + surface_variable_names = [] + if model_type == "surface" or model_type == "combined": + surface_variable_names = list(cfg.variables.surface.solution.keys()) + num_surf_vars = 0 + for j in surface_variable_names: + if cfg.variables.surface.solution[j] == "vector": + num_surf_vars += 3 + else: + num_surf_vars += 1 + else: + num_surf_vars = None + + vol_save_path = os.path.join( + "outputs", cfg.project.name, "volume_scaling_factors.npy" + ) + surf_save_path = os.path.join( + "outputs", cfg.project.name, "surface_scaling_factors.npy" + ) + if os.path.exists(vol_save_path): + vol_factors = np.load(vol_save_path) + else: + vol_factors = None + + if os.path.exists(surf_save_path): + surf_factors = np.load(surf_save_path) + else: + surf_factors = None + + train_dataset = create_domino_dataset( + cfg, + "train", + volume_variable_names, + surface_variable_names, + vol_factors, + surf_factors, + ) + val_dataset = create_domino_dataset( + cfg, + "val", + volume_variable_names, + surface_variable_names, + vol_factors, + surf_factors, + ) + + ################################# + # Using a Sharded Dataset + ################################# + # Physicsnemo has a built-in wrapper for the DoMino dataset + # that allows for sharding the dataset across multiple GPUs. + # (it's nothing fancy - each rank that shares data loads the entire image, + # and then slices to it's own chunks) + train_dataset = create_sharded_domino_dataset( + train_dataset, + domain_mesh, # The dataloader needs to know the mesh for sharing data. + shard_point_cloud=cfg.domain_parallelism.shard_points, # We can shard the point + shard_grid=cfg.domain_parallelism.shard_grid, # Or the grid (or both) + ) + + val_dataset = create_sharded_domino_dataset( + val_dataset, + domain_mesh, + shard_point_cloud=cfg.domain_parallelism.shard_points, + shard_grid=cfg.domain_parallelism.shard_grid, + ) + + # The distributed sampler needs to know that the dataset is not + # being used in a usual way. We have to tell it how many "real" + # times the dataset is sharded (world size / shard_size). + # It also needs to know its rank in the global "ddp" dimension. + sampler_num_replicas = mesh["ddp"].size() + sampler_rank = mesh["ddp"].get_local_rank() + + train_sampler = DistributedSampler( + train_dataset, + num_replicas=sampler_num_replicas, + rank=sampler_rank, + **cfg.train.sampler, + ) + + val_sampler = DistributedSampler( + val_dataset, + num_replicas=sampler_num_replicas, + rank=sampler_rank, + **cfg.val.sampler, + ) + + train_dataloader = DataLoader( + train_dataset, + sampler=train_sampler, + **cfg.train.dataloader, + ) + val_dataloader = DataLoader( + val_dataset, + sampler=val_sampler, + **cfg.val.dataloader, + ) + + model = DoMINO( + input_features=3, + output_features_vol=num_vol_vars, + output_features_surf=num_surf_vars, + model_parameters=cfg.model, + ).to(dist.device) + model = torch.compile(model, disable=True) # TODO make this configurable + + # Print model summary (structure and parmeter count). + logger.info(f"Model summary:\n{torchinfo.summary(model, verbose=0, depth=2)}\n") + + if dist.world_size > 1: + # Instead of DDP, for sharding we use FSDP. It's possible to use FSDP in the DDP + # mode, but since it's not pure data parallel we have to me more careful. + + # First, distribute the model so that each GPU has the copy with DTensor weights: + model = distribute_module(model, domain_mesh) + + model = FSDP( + model, + device_mesh=mesh["ddp"], + sharding_strategy=ShardingStrategy.NO_SHARD, + ) + + # optimizer = apex.optimizers.FusedAdam(model.parameters(), lr=0.001) + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, milestones=[100, 200, 300, 400, 500, 600, 700, 800], gamma=0.5 + ) + + # Initialize the scaler for mixed precision + scaler = GradScaler() + + writer = SummaryWriter(os.path.join(cfg.output, "tensorboard")) + + epoch_number = 0 + + model_save_path = os.path.join(cfg.output, "models") + param_save_path = os.path.join(cfg.output, "param") + best_model_path = os.path.join(model_save_path, "best_model") + if dist.rank == 0: + create_directory(model_save_path) + create_directory(param_save_path) + create_directory(best_model_path) + + if dist.world_size > 1: + torch.distributed.barrier() + + init_epoch = load_checkpoint( + to_absolute_path(cfg.resume_dir), + models=model, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + device=dist.device, + ) + + if init_epoch != 0: + init_epoch += 1 # Start with the next epoch + epoch_number = init_epoch + + # retrive the smallest validation loss if available + numbers = [] + for filename in os.listdir(best_model_path): + match = re.search(r"\d+\.\d*[1-9]\d*", filename) + if match: + number = float(match.group(0)) + numbers.append(number) + + best_vloss = min(numbers) if numbers else 1_000_000.0 + + initial_integral_factor_orig = cfg.model.integral_loss_scaling_factor + + for epoch in range(init_epoch, cfg.train.epochs): + start_time = time.perf_counter() + logger.info(f"Device {dist.device}, epoch {epoch_number}:") + + train_sampler.set_epoch(epoch) + val_sampler.set_epoch(epoch) + + initial_integral_factor = initial_integral_factor_orig + + if epoch > 250: + surface_scaling_loss = 1.0 * cfg.model.surf_loss_scaling + else: + surface_scaling_loss = cfg.model.surf_loss_scaling + + model.train(True) + epoch_start_time = time.perf_counter() + avg_loss = train_epoch( + dataloader=train_dataloader, + model=model, + optimizer=optimizer, + scaler=scaler, + tb_writer=writer, + logger=logger, + gpu_handles=gpu_handles, + epoch_index=epoch, + device=dist.device, + integral_scaling_factor=initial_integral_factor, + loss_fn_type=cfg.model.loss_function, + vol_loss_scaling=cfg.model.vol_loss_scaling, + surf_loss_scaling=surface_scaling_loss, + ) + epoch_end_time = time.perf_counter() + logger.info( + f"Device {dist.device}, Epoch {epoch_number} took {epoch_end_time - epoch_start_time:.3f} seconds" + ) + + model.eval() + avg_vloss = validation_step( + dataloader=val_dataloader, + model=model, + device=dist.device, + use_sdf_basis=cfg.model.use_sdf_in_basis_func, + use_surface_normals=cfg.model.use_surface_normals, + integral_scaling_factor=initial_integral_factor, + loss_fn_type=cfg.model.loss_function, + vol_loss_scaling=cfg.model.vol_loss_scaling, + surf_loss_scaling=surface_scaling_loss, + ) + + scheduler.step() + logger.info( + f"Device {dist.device} " + f"LOSS train {avg_loss:.5f} " + f"valid {avg_vloss:.5f} " + f"Current lr {scheduler.get_last_lr()[0]}" + f"Integral factor {initial_integral_factor}" + ) + + if dist.rank == 0: + writer.add_scalars( + "Training vs. Validation Loss", + { + "Training": avg_loss, + # "Validation": avg_vloss + }, + epoch_number, + ) + writer.flush() + + # Track best performance, and save the model's state + if dist.world_size > 1: + torch.distributed.barrier() + + if avg_vloss < best_vloss: # This only considers GPU: 0, is that okay? + best_vloss = avg_vloss + # if dist.rank == 0: + save_checkpoint( + to_absolute_path(best_model_path), + models=model, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + epoch=str(best_vloss), # hacky way of using epoch to store metadata + ) + if dist.rank == 0: + print( + f"Device { dist.device}, Best val loss {best_vloss}, Time taken {time.perf_counter() - start_time:.3f}" + ) + + if dist.rank == 0 and (epoch + 1) % cfg.train.checkpoint_interval == 0.0: + save_checkpoint( + to_absolute_path(model_save_path), + models=model, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + epoch=epoch, + ) + + epoch_number += 1 + + if scheduler.get_last_lr()[0] == 1e-6: + print("Training ended") + exit() + + +if __name__ == "__main__": + main() diff --git a/examples/cfd/stokes_mgn/raw_dataset/download_dataset.sh b/examples/cfd/stokes_mgn/raw_dataset/download_dataset.sh index 1f5e3ebb5d..fee49136a9 100644 --- a/examples/cfd/stokes_mgn/raw_dataset/download_dataset.sh +++ b/examples/cfd/stokes_mgn/raw_dataset/download_dataset.sh @@ -17,7 +17,7 @@ Download Stokes flow dataset """ -wget --content-disposition 'https://api.ngc.nvidia.com/v2/resources/org/nvidia/team/modulus/modulus_datasets-stokes-flow/0.0/files?redirect=true&path=results_polygon.zip' -O results_polygon.zip +wget --content-disposition 'https://api.ngc.nvidia.com/v2/resources/org/nvidia/team/physicsnemo/modulus_datasets-stokes-flow/0.0/files?redirect=true&path=results_polygon.zip' -O results_polygon.zip unzip results_polygon.zip mv results ../ rm results_polygon.zip \ No newline at end of file diff --git a/examples/cfd/vortex_shedding_mesh_reduced/README.md b/examples/cfd/vortex_shedding_mesh_reduced/README.md index 4a9839eb97..036f82c785 100644 --- a/examples/cfd/vortex_shedding_mesh_reduced/README.md +++ b/examples/cfd/vortex_shedding_mesh_reduced/README.md @@ -61,8 +61,8 @@ per GPU is set to 10 for the sequence model training. Traing epochs is set as 20 To download the data , run ```bash -wget --content-disposition https://api.ngc.nvidia.com/v2/resources/nvidia/modulus/modulus_datasets_cylinder-flow/versions/v1/zip -O modulus_datasets_cylinder-flow_v1.zip -unzip modulus_datasets_cylinder-flow_v1.zip +wget --content-disposition https://api.ngc.nvidia.com/v2/resources/nvidia/physicsnemo/modulus_datasets_cylinder-flow/versions/v1/zip -O physicsnemo_datasets_cylinder-flow_v1.zip +unzip physicsnemo_datasets_cylinder-flow_v1.zip unzip dataset.zip ``` diff --git a/examples/healthcare/bloodflow_1d_mgn/raw_dataset/download_dataset.sh b/examples/healthcare/bloodflow_1d_mgn/raw_dataset/download_dataset.sh index c524d75e9d..a1a8dbbe67 100644 --- a/examples/healthcare/bloodflow_1d_mgn/raw_dataset/download_dataset.sh +++ b/examples/healthcare/bloodflow_1d_mgn/raw_dataset/download_dataset.sh @@ -17,8 +17,8 @@ Download dataset """ -wget --content-disposition https://api.ngc.nvidia.com/v2/resources/nvidia/modulus/modulus_datasets-cardiovascular-simulation/versions/0.0/zip -O modulus_datasets-cardiovascular-simulation_0.0.zip -unzip modulus_datasets-cardiovascular-simulation_0.0.zip +wget --content-disposition https://api.ngc.nvidia.com/v2/resources/nvidia/physicsnemo/modulus_datasets-cardiovascular-simulation/versions/0.0/zip -O physicsnemo_datasets-cardiovascular-simulation_0.0.zip +unzip physicsnemo_datasets-cardiovascular-simulation_0.0.zip unzip cardiovascular_dataset.zip mv cardiovascular_dataset/* . rm -r cardiovascular_dataset diff --git a/physicsnemo/distributed/__init__.py b/physicsnemo/distributed/__init__.py index bd30a617ba..ca8c1f183e 100644 --- a/physicsnemo/distributed/__init__.py +++ b/physicsnemo/distributed/__init__.py @@ -48,11 +48,13 @@ def register_custom_ops(): # These imports will register the custom ops with the ShardTensor class. # It's done here to avoid an import cycle. from .custom_ops import ( - sharded_mean_wrapper, + register_reduction_functions, unbind_rules, ) from .shard_utils import register_shard_wrappers + register_reduction_functions() + register_shard_wrappers() except ImportError: diff --git a/physicsnemo/distributed/_shard_redistribute.py b/physicsnemo/distributed/_shard_redistribute.py index dcf916321f..2ffced2066 100644 --- a/physicsnemo/distributed/_shard_redistribute.py +++ b/physicsnemo/distributed/_shard_redistribute.py @@ -22,9 +22,6 @@ import torch.distributed._functional_collectives as funcol from torch.distributed.device_mesh import DeviceMesh -from physicsnemo.distributed.autograd import ( - all_gather_v, -) from physicsnemo.utils.version_check import check_module_requirements # This is to make sure the torch minimum version is installed. @@ -78,22 +75,34 @@ def _to_replicate_tensor( Note: This function handles uneven sharding by using all_gather_v instead of regular all_gather """ - # Get the mesh for the group: mesh = current_spec.mesh group = mesh.get_group(mesh_dim) - # Get all sizes: - sizes = current_spec.sharding_sizes() - this_sizes = tuple(s[tensor_dim] for s in sizes[mesh_dim]) - # # Ensure contiguous data for the reduction: + # Ensure contiguous data for the reduction: local_tensor = local_tensor.contiguous() - # We can implement this with a straightforward allgather_v - local_tensor = all_gather_v( - local_tensor, sizes=this_sizes, dim=tensor_dim, group=group - ) - return local_tensor + # # Get all sizes: + # TODO: We don't need to summon all sizes across all mesh dimensions. + # Optimize the spec function to only get the sizes for the relevant mesh dimensions. + sizes = current_spec.sharding_shapes() + + # Consecutive redistributes _don't_ update full sizes. + # So, extract the shape from this tensor, and assume all other tensor + # dims match. + tensor_dim_shapes = tuple(s[tensor_dim] for s in sizes[mesh_dim]) + base_shapes = [list(local_tensor.shape) for _ in tensor_dim_shapes] + for i, t in enumerate(tensor_dim_shapes): + base_shapes[i][tensor_dim] = tensor_dim_shapes[i] + + # Create a spot for the output: + output = [ + torch.empty(s, device=local_tensor.device, dtype=local_tensor.dtype) + for s in base_shapes + ] + dist.all_gather(output, local_tensor, group=group) + + return torch.cat(output, dim=tensor_dim).contiguous() def _select_slice_from_replicate( @@ -121,7 +130,6 @@ def _select_slice_from_replicate( """ # TODO - This needs a rework to enable caching of shapes for a grad pass. - # We really only need the sizes from this dimension: tensor_dim = target_spec.placements[mesh_dim].dim mesh_size = target_spec.mesh.size(mesh_dim=mesh_dim) @@ -132,10 +140,16 @@ def _select_slice_from_replicate( # Split the tensor: if sizes is None: - chunks = torch.tensor_split(local_tensor, mesh_size, dim=tensor_dim) + # Use chunk, not split, when dividing without a plan + chunks = torch.chunk(local_tensor, mesh_size, dim=tensor_dim) else: - chunks = torch.tensor_split(local_tensor, sizes[:-1], dim=tensor_dim) - + # Convert sizes to cumulative sum using basic Python + chunk_starts = [] + running_sum = 0 + for size in sizes[:-1]: + running_sum += size + chunk_starts.append(running_sum) + chunks = torch.tensor_split(local_tensor, chunk_starts, dim=tensor_dim) return chunks[mesh_coord], sizes @@ -166,7 +180,7 @@ def _to_new_shard_dim( # First, we need to split the tensor along the target dimension: if size_hint is None: - chunks = torch.tensor_split(local_tensor, mesh_size, dim=target_dim) + chunks = torch.chunk(local_tensor, mesh_size, dim=target_dim) else: chunk_starts = list(accumulate(size_hint)) chunks = torch.tensor_split(local_tensor, chunk_starts[:-1], dim=target_dim) @@ -214,7 +228,7 @@ def redistribute_local_shard_tensor( *, async_op: bool = False, is_backward: bool = False, - target_sharding_sizes: Optional[dict[int, Tuple[torch.Size, ...]]] = {}, + target_sharding_shapes: Optional[dict[int, Tuple[torch.Size, ...]]] = {}, ) -> torch.Tensor: """ This redistribute the local tensor (torch.Tensor) from the current ShardTensorSpec to @@ -251,12 +265,12 @@ def redistribute_local_shard_tensor( # as the current, but is missing sharding sizes, we can use the current spec's sharding sizes. # if target_spec._sharding_sizes is None: # if target_spec.placements == current_spec.placements and target_spec.mesh == current_spec.mesh: - # target_spec._sharding_sizes = current_spec.sharding_sizes() + # target_spec._sharding_sizes = current_spec.sharding_shapes() # For sharded tensors, we use the same order of transformation as DTensor. # However, often we need to ignore the provided logical shape and substitute # a sharded shape instead. - # This is done by providing a target_sharding_sizes dict above. + # This is done by providing a target_sharding_shapes dict above. transform_infos = _gen_transform_infos(current_spec, target_spec) @@ -264,7 +278,6 @@ def redistribute_local_shard_tensor( return local_tensor for transform_info in transform_infos: - dist.barrier() i = transform_info.mesh_dim current, target = transform_info.src_dst_placements device_mesh.size(mesh_dim=i) @@ -306,8 +319,8 @@ def redistribute_local_shard_tensor( elif current.is_replicate(): # split the tensor and return the corresponding cloned local shard # Are there suggested placements for the shards? - if target_placement.dim in target_sharding_sizes: - size_hint = target_sharding_sizes[target_placement.dim] + if target_placement.dim in target_sharding_shapes: + size_hint = target_sharding_shapes[target_placement.dim] else: size_hint = None new_local_tensor, size_hint = _select_slice_from_replicate( @@ -319,9 +332,9 @@ def redistribute_local_shard_tensor( ) if ( size_hint is not None - and target_placement.dim in target_sharding_sizes + and target_placement.dim in target_sharding_shapes ): - target_sharding_sizes[target_placement.dim] = size_hint + target_sharding_shapes[target_placement.dim] = size_hint else: if not current.is_shard(): @@ -335,8 +348,8 @@ def redistribute_local_shard_tensor( # So, if the target tensor dimension is in there, # That is how we're going to shard the local tensor on the tensor_dim, # and it also defines how we'll receive the tensor . - if target_placement.dim in target_sharding_sizes: - size_hint = target_sharding_sizes[target_placement.dim] + if target_placement.dim in target_sharding_shapes: + size_hint = target_sharding_shapes[target_placement.dim] else: size_hint = None @@ -350,11 +363,11 @@ def redistribute_local_shard_tensor( ) if ( size_hint is None - and target_placement.dim in target_sharding_sizes + and target_placement.dim in target_sharding_shapes ): - target_sharding_sizes.pop(target_placement.dim) - if size_hint is not None and current.dim in target_sharding_sizes: - target_sharding_sizes.pop(current.dim) + target_sharding_shapes.pop(target_placement.dim) + if size_hint is not None and current.dim in target_sharding_shapes: + target_sharding_shapes.pop(current.dim) elif target.is_partial(): if current.is_replicate(): @@ -409,7 +422,7 @@ def get_tensor_sharding_shapes_by_dim( Generate a target spec from the current spec and target_placements. """ - target_sharding_sizes = {} + target_sharding_shapes = {} # Look through the target placements for shardings: for target_mesh_dim, target_placement in enumerate(target_placements): if isinstance(target_placement, Shard): @@ -427,12 +440,12 @@ def get_tensor_sharding_shapes_by_dim( # The tensor dim is the same in both current and target, # But the rest of the tensors dimensions may change. # Therefore only save the dimension on this axis. - current_shardings = current_spec.sharding_sizes()[current_mesh_dim] - target_sharding_sizes[target_tensor_dim] = [ + current_shardings = current_spec.sharding_shapes()[current_mesh_dim] + target_sharding_shapes[target_tensor_dim] = [ c[target_tensor_dim] for c in current_shardings ] - return target_sharding_sizes + return target_sharding_shapes class ShardRedistribute(torch.autograd.Function): @@ -472,8 +485,8 @@ def forward( if current_spec.placements != placements: - # We have to assume, here, that the current spec has correct sharding_sizes. - # Therefore, we can use the target placement + current sharding_sizes + # We have to assume, here, that the current spec has correct sharding_shapes. + # Therefore, we can use the target placement + current sharding_shapes # to get the target sharding sizes correctly. # target_spec = generate_target_spec_from_current_and_placements( @@ -489,17 +502,17 @@ def forward( # The target sharding sizes are potentially incomplete. # They're only provided for shardings that are the same in input/output. - target_sharding_sizes = get_tensor_sharding_shapes_by_dim( + target_sharding_shapes = get_tensor_sharding_shapes_by_dim( current_spec, placements ) - # ctx.target_sharding_sizes = target_sharding_sizes + # ctx.target_sharding_shapes = target_sharding_shapes local_tensor = input._local_tensor output = redistribute_local_shard_tensor( local_tensor, current_spec, target_spec, async_op=async_op, - target_sharding_sizes=target_sharding_sizes, + target_sharding_shapes=target_sharding_shapes, ) # Set the local shape: target_spec._local_shape = output.shape @@ -509,7 +522,7 @@ def forward( target_spec = current_spec return shard_tensor.ShardTensor( - output, + output.contiguous(), target_spec, requires_grad=input.requires_grad, ) @@ -539,17 +552,16 @@ def backward( async_op = ctx.async_op local_tensor = grad_output._local_tensor - target_sharding_sizes = get_tensor_sharding_shapes_by_dim( + target_sharding_shapes = get_tensor_sharding_shapes_by_dim( previous_spec, previous_spec.placements ) - output = redistribute_local_shard_tensor( local_tensor, current_spec, previous_spec, async_op=async_op, is_backward=True, - target_sharding_sizes=target_sharding_sizes, + target_sharding_shapes=target_sharding_shapes, ) # normalize the target placement to replicate if it is partial @@ -576,7 +588,6 @@ def backward( spec, requires_grad=grad_output.requires_grad, ) - return ( output_shard_tensor, None, diff --git a/physicsnemo/distributed/_shard_tensor_spec.py b/physicsnemo/distributed/_shard_tensor_spec.py index 59a120858e..60aefeac5f 100644 --- a/physicsnemo/distributed/_shard_tensor_spec.py +++ b/physicsnemo/distributed/_shard_tensor_spec.py @@ -15,12 +15,13 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch import torch.distributed as dist from torch.distributed.device_mesh import DeviceMesh +from physicsnemo.distributed.utils import compute_split_shapes from physicsnemo.utils.version_check import check_module_requirements check_module_requirements("physicsnemo.distributed.shard_tensor") @@ -47,15 +48,15 @@ class ShardTensorSpec(DTensorSpec): ---------- _local_shape : Optional[torch.Size] The shape of the local shard of the tensor - _sharding_sizes : Optional[dict[int, Tuple[torch.Size, ...]]] - Mapping from mesh dimension to shard sizes. Keys are mesh dimensions, + _sharding_shapes : Optional[dict[int, Tuple[torch.Size, ...]]] + Mapping from mesh dimension to shard shapes. Keys are mesh dimensions, values are tuples of torch.Size representing shard shapes along that dimension. - Shard sizes are only tracked along the sharded dimensions, not replicated dimensions. + Shard shapes are only tracked along the sharded dimensions, not replicated dimensions. """ _local_shape: Optional[torch.Size] = field(default_factory=lambda: None) - # This dict is a mapping from the mesh dimension to the shard sizes, _not_ the tensor index - _sharding_sizes: Optional[dict[int, Tuple[torch.Size, ...]]] = field( + # This dict is a mapping from the mesh dimension to the shard shapes, _not_ the tensor index + _sharding_shapes: Optional[dict[int, Tuple[torch.Size, ...]]] = field( default_factory=lambda: None ) @@ -67,7 +68,7 @@ def _hash_impl(self) -> int: Returns ------- int - Hash value incorporating mesh, placements, tensor metadata and sharding sizes + Hash value incorporating mesh, placements, tensor metadata and sharding shapes """ hash_items = [] @@ -78,8 +79,8 @@ def _hash_impl(self) -> int: hash_items.append(self.tensor_meta.shape) hash_items.append(self.tensor_meta.stride) hash_items.append(self.tensor_meta.dtype) - if self._sharding_sizes is not None: - hash_items.append(tuple(sorted(self._sharding_sizes.items()))) + if self._sharding_shapes is not None: + hash_items.append(tuple(sorted(self._sharding_shapes.items()))) hash_tuple = tuple(hash_items) return hash(hash_tuple) @@ -92,31 +93,39 @@ def __hash__(self) -> int: self._hash = self._hash_impl() return self._hash - def sharding_sizes( + def sharding_shapes( self, mesh_dim: Optional[int] = None ) -> dict[int, Tuple[torch.Size, ...]] | Tuple[torch.Size, ...]: - """Get the sizes of shards along specified mesh dimensions. + """Get the shapes of shards along specified mesh dimensions. Parameters ---------- mesh_dim : Optional[int] - If provided, return sizes only for this mesh dimension + If provided, return shapes only for this mesh dimension Returns ------- dict[int, Tuple[torch.Size, ...]] | Tuple[torch.Size, ...] - Dictionary of shard sizes by mesh dim, or tuple of sizes for specific dim + Dictionary of shard shapes by mesh dim, or tuple of shapes for specific dim """ - if self._sharding_sizes is None: - shard_shapes_by_dim, global_shape = _all_gather_shard_shapes( - self._local_shape, self.placements, self.mesh - ) - self._sharding_sizes = shard_shapes_by_dim - self.tensor_meta = self.tensor_meta._replace(shape=global_shape) + if self._sharding_shapes is None: + if mesh_dim is None: + shard_shapes_by_dim, global_shape = _all_gather_shard_shapes( + self._local_shape, self.placements, self.mesh + ) + self._sharding_shapes = shard_shapes_by_dim + self.tensor_meta = self.tensor_meta._replace(shape=global_shape) + else: + return _gather_shard_shapes_for_dim( + self._local_shape, + mesh_dim, + self.mesh.get_group(mesh_dim), + do_checks=False, + ) if mesh_dim is not None: - if mesh_dim in self._sharding_sizes: - return self._sharding_sizes[mesh_dim] - return self._sharding_sizes + if mesh_dim in self._sharding_shapes: + return self._sharding_shapes[mesh_dim] + return self._sharding_shapes def __eq__(self, other: object) -> bool: """Check if two ShardTensorSpecs are equal. @@ -130,7 +139,7 @@ def __eq__(self, other: object) -> bool: return False if not super().__eq__(other): return False - if self._sharding_sizes != other._sharding_sizes: + if self._sharding_shapes != other._sharding_shapes: return False return True @@ -193,7 +202,7 @@ def offsets(self, mesh_dim: Optional[int] = None) -> Tuple[int, ...] | int: placement = self.placements[loop_mesh_dim] # If the placement is not shard, offset is 0: if isinstance(placement, Shard): - shards = self._sharding_sizes[loop_mesh_dim] + shards = self._sharding_shapes[loop_mesh_dim] tensor_dim = placement.dim o = sum([s[tensor_dim] for s in shards[:coord]]) offsets.append(o) @@ -244,50 +253,89 @@ def _stride_from_contiguous_shape_C_style( return stride +def _gather_shard_shapes_for_dim( + local_shape: Union[torch.Size, torch.Tensor], + tensor_dim: int, + local_group: dist.ProcessGroup, + do_checks: bool = False, +) -> Tuple[torch.Tensor, ...]: + """Gather tensor shapes from all ranks in a process group for a given dimension. + + This function collects the shapes of tensor shards from all ranks in a process group + and performs optional validation checks on the gathered shapes. Uses NCCL, which requires + two way transfers between host and device. + + Args: + local_shape: Shape of the local tensor shard, either as torch.Size or torch.Tensor + tensor_dim: The tensor dimension being sharded + local_group: Process group to gather shapes from + do_checks: Whether to validate shape consistency across ranks + + Returns: + Tuple of torch.Sizes containing gathered shapes from all ranks + + Raises: + ValueError: If shape validation fails when do_checks=True + - Ranks have different tensor dimensions + - Non-sharded dimensions don't match across ranks + """ + local_size = dist.get_world_size(group=local_group) + + if not isinstance(local_shape, torch.Tensor): + shape = torch.tensor(local_shape, device="cpu", pin_memory=True) + + local_shape = shape.to(device="cuda", non_blocking=True) + + all_shapes = [ + torch.zeros_like(local_shape, device="cuda") for _ in range(local_size) + ] + + dist.all_gather(all_shapes, local_shape, group=local_group) + + all_shapes = [torch.Size(s.cpu().tolist()) for s in all_shapes] + + if do_checks: + # Check that all shapes are the same rank + if not all(len(local_shape) == len(all_s) for all_s in all_shapes): + raise ValueError( + "Rank mismatch detected when attempting to infer shapes and sizes" + ) + + # Every dimension must be equal for this list, along the sharded axis + for d in range(len(local_shape)): + if d == tensor_dim: + continue # skip the sharded dimension + if not all([local_shape[d] == all_s[d] for all_s in all_shapes]): + raise ValueError( + f"Dimension mismatch detected at non-sharded dimension {d}. " + "All local shapes must match except along sharded dimension." + ) + + return tuple(all_shapes) + + def _all_gather_shard_shapes( - local_shape: Tuple[int], + local_shape: torch.Size, placements: Tuple[Placement], target_mesh: DeviceMesh, + do_checks: bool = False, ): - shard_shapes_by_dim = {} global_shape = [s for s in local_shape] # We start by assuming the global shape is the local shape and fix it on sharded axes for mesh_axis, placement in enumerate(placements): if isinstance(placement, Shard): - tensor_dim = placement.dim + tensor_dim = placement.dim local_group = target_mesh.get_group(mesh_axis) - local_size = dist.get_world_size(group=local_group) - - all_shapes = [torch.Size()] * local_size - - # First, allgather the dimensions of each tensor to each rank: - # Possible collective of CPU-based objects! Could be slow if using separate hosts! - dist.all_gather_object(all_shapes, local_shape, local_group) - - # Check that all shapes are the same rank: - if not all([len(local_shape) == len(all_s) for all_s in all_shapes]): - raise ValueError( - "Rank mismatch detected when attempting to infer shapes and sizes" - ) - - # Every dimension must be equal for this list, along the sharded axis - for d in range(len(local_shape)): - if d == tensor_dim: - continue # skip the sharded dimension - if not all([local_shape[d] == all_s[d] for all_s in all_shapes]): - raise ValueError( - f"Dimension mismatch detected at non-sharded dimension {d}. " - "All local shapes must match except along sharded dimension." - ) - - # Build a list of local torch.Size on this axis for each shard to store: + shard_shapes_for_dim = _gather_shard_shapes_for_dim( + local_shape, tensor_dim, local_group, do_checks + ) local_meta = tuple( # torch.Size(tuple(s)) for s in zip(all_shapes) - all_shapes + shard_shapes_for_dim ) shard_shapes_by_dim[mesh_axis] = local_meta @@ -296,15 +344,75 @@ def _all_gather_shard_shapes( # we have to loop over each axis in the rank list # To check what placement is there. # This assumes full sharding: - global_shape[tensor_dim] = sum([all_s[tensor_dim] for all_s in all_shapes]) + global_shape[tensor_dim] = sum([all_s[tensor_dim] for all_s in local_meta]) return shard_shapes_by_dim, tuple(global_shape) +def compute_sharding_shapes_from_chunking_global_shape( + mesh: DeviceMesh, + placements: Tuple[Placement, ...], + global_shape: Tuple[int, ...], +) -> Dict[int, List[torch.Size]]: + """Compute shard sizes for each mesh dimension based on global shape. + + For each sharded dimension in the mesh, computes the chunk sizes that would result + from evenly dividing the global tensor shape. Returns a mapping from mesh dimensions + to lists of torch.Size objects representing the shape of each shard. + + Args: + mesh: Device mesh defining the process topology + placements: Tuple of placement specifications for each mesh dimension + global_shape: Global shape of the full tensor before sharding + + Returns: + Dict mapping mesh dimensions to lists of torch.Size objects representing + shard shapes for that dimension + + Raises: + ValueError: If placements length doesn't match mesh dimensions + """ + if len(placements) != mesh.ndim: + raise ValueError("Number of placements must match mesh dimensions") + + # First compute raw chunk sizes for each sharded dimension + temp_sharding_shapes: Dict[int, List[int]] = {} + for i in range(mesh.ndim): + if isinstance(placements[i], Shard): + # Compute the chunk size for this dimension: + input_dim = global_shape[placements[i].dim] + chunked_shapes = compute_split_shapes(input_dim, mesh.size(i)) + # This needs to be a tuple of torch.Size + + temp_sharding_shapes[i] = chunked_shapes + + # Initialize shapes for all sharded dimensions, but using the global shape. + # We will update next. + sharding_shapes = { + mesh_dim: [list(global_shape) for _ in chunks] + for mesh_dim, chunks in temp_sharding_shapes.items() + } + + # For every sharded dimension, update the tensor size along it's axis for all mesh dims + for mesh_dim in temp_sharding_shapes.keys(): + placement = placements[mesh_dim] + tensor_dim = placement.dim + for i, shard_size in enumerate(temp_sharding_shapes[mesh_dim]): + sharding_shapes[mesh_dim][i][tensor_dim] = shard_size + + # Convert to immutable torch.Size + return { + mesh_dim: [torch.Size(tuple(size)) for size in sizes] + for mesh_dim, sizes in sharding_shapes.items() + } + + def _infer_shard_tensor_spec_from_local_chunks( local_chunk: torch.Tensor, target_mesh: DeviceMesh, placements: Tuple[Placement, ...], + sharding_shapes: Union[str, Dict[int, List[Tuple[int, ...]]]] = "chunk", + global_shape: Optional[Tuple[int, ...]] = None, ) -> ShardTensorSpec: """ Use local sizes, target mesh, and specified placements to build a @@ -326,25 +434,72 @@ def _infer_shard_tensor_spec_from_local_chunks( of this spec is that each ShardTensor knows the shape and size of other shards, and can compute global offsets and reductions properly """ - - # # Only accept sharding placements (not replications or partial (aka pending)) - # if not all([p.is_shard() for p in placements]): - # raise ValueError( - # "Shard Tensor will only infer shape and strides for sharded tensors," - # "for replication use DTensor" - # ) + # Sharding_shapes, if a string, must be one of "blocking_infer" "chunk" "infer" + if isinstance(sharding_shapes, str) and sharding_shapes not in [ + "chunk", + "infer", + ]: + raise ValueError( + "If sharding_shapes is a string, it must be one of: 'chunk', 'infer'" + ) + + # if sharding_shapes is a chunk, global_shape must be provided + if sharding_shapes == "chunk" and global_shape is None: + raise ValueError("If sharding_shapes is 'chunk', global_shape must be provided") + + # Check if sharding_shapes is an empty dict + if isinstance(sharding_shapes, dict) and not sharding_shapes: + # Raise an error only if the placements contains a shard: + if any(isinstance(placement, Shard) for placement in placements): + raise ValueError("sharding_shapes as a dict cannot be empty") # Need to infer the placements on each dimension of the mesh. if len(placements) != target_mesh.ndim: raise ValueError("Mesh dimension must match placements length") + # If sharding_shapes is chunk, compute the chunk sizes from the global shape + if isinstance(sharding_shapes, str): + if sharding_shapes == "chunk": + # This is communication-free. It's the path from a properly-formated DTensorSpec. + shard_shapes_by_dim = compute_sharding_shapes_from_chunking_global_shape( + target_mesh, + placements, + list(global_shape), + ) + # Basic sanity check, make sure the inferred shape matches the + # local shape on the first sharded mesh dimension + mesh_rank = None + for mesh_dim, p in enumerate(placements): + if isinstance(p, Shard): + mesh_rank = target_mesh.get_coordinate()[mesh_dim] + break + + if mesh_rank is not None: + inferred_local_shape = shard_shapes_by_dim[mesh_dim][mesh_rank] + if inferred_local_shape != local_chunk.shape: + raise ValueError( + f"Rank {dist.get_rank()} expected local shape {inferred_local_shape} does not match tensor's local shape {local_chunk.shape}" + ) - local_shape = local_chunk.shape + if sharding_shapes == "infer": + # When unsure, this is a good option. + shard_shapes_by_dim, global_shape = _all_gather_shard_shapes( + local_chunk.shape, + placements, + target_mesh, + ) + else: + # We have been passed sharding shapes manually (yay! best performance) + # so infer the global shape from them + global_shape = list(local_chunk.shape) + for i in range(target_mesh.ndim): + if isinstance(placements[i], Shard): + # Sum the sides for this axis: + tensor_dim = placements[i].dim + global_shape[tensor_dim] = sum( + [s[tensor_dim] for s in sharding_shapes[i]] + ) - shard_shapes_by_dim, global_shape = _all_gather_shard_shapes( - local_shape, - placements, - target_mesh, - ) + shard_shapes_by_dim = sharding_shapes stride = _stride_from_contiguous_shape_C_style(global_shape) @@ -352,14 +507,12 @@ def _infer_shard_tensor_spec_from_local_chunks( global_meta = TensorMeta( shape=tuple(global_shape), stride=stride, dtype=local_chunk.dtype ) - # all_shard_meta = local_meta - - sharding_sizes = {dim: tuple(s) for dim, s in shard_shapes_by_dim.items()} + sharding_shapes = {dim: tuple(s) for dim, s in shard_shapes_by_dim.items()} return ShardTensorSpec( mesh=target_mesh, placements=placements, tensor_meta=global_meta, - _local_shape=local_shape, - _sharding_sizes=sharding_sizes, + _local_shape=local_chunk.shape, + _sharding_shapes=sharding_shapes, ) diff --git a/physicsnemo/distributed/custom_ops/__init__.py b/physicsnemo/distributed/custom_ops/__init__.py index bde7baa614..0bb3eeaa24 100644 --- a/physicsnemo/distributed/custom_ops/__init__.py +++ b/physicsnemo/distributed/custom_ops/__init__.py @@ -20,7 +20,7 @@ try: check_module_requirements("physicsnemo.distributed.shard_tensor") - from ._reductions import sharded_mean_wrapper + from ._reductions import register_reduction_functions from ._tensor_ops import unbind_rules except ImportError: diff --git a/physicsnemo/distributed/custom_ops/_reductions.py b/physicsnemo/distributed/custom_ops/_reductions.py index 7df891755d..561dfb1d07 100644 --- a/physicsnemo/distributed/custom_ops/_reductions.py +++ b/physicsnemo/distributed/custom_ops/_reductions.py @@ -14,7 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) import torch @@ -27,102 +38,535 @@ Shard, ) +# noqa: E402 from physicsnemo.distributed.shard_tensor import ShardTensor # noqa: E402 aten = torch.ops.aten +# Type variable for dimension parameter +DimT = TypeVar("DimT", None, int, Iterable[int]) + -class ShardedMean(torch.autograd.Function): +def normalize_dim( + dim: DimT, tensor_ndim: int, as_set: bool = False, handle_negatives: bool = True +) -> Union[Optional[Tuple[int, ...]], Set[int]]: """ - This is a custom mean operation that takes into account the fact that the - sharded tensor may be unevenly sharded. + Normalize dimension argument to a consistent form. - The strategy is to do a weighted mean, where the weight is the size of the - shard in the dimension being reduced. + Args: + dim: The dimension(s) to normalize. Can be None, int, or iterable of ints. + tensor_ndim: Number of dimensions in the tensor. + as_set: If True, return a set of dimensions instead of a tuple. + handle_negatives: If True, convert negative dimensions to positive ones. + Returns: + - None if dim is None and as_set is False + - A set of all dimensions if dim is None and as_set is True + - A tuple of dimensions (or set if as_set is True) """ + if dim is None: + if as_set: + return set(range(tensor_ndim)) + return None - @staticmethod - def forward(ctx, *args, **kwargs): - # The difference in the sharded mean is that we need to keep track - # of the portion of the global tensor held by the local shard. + # Convert to tuple if iterable + if isinstance(dim, Iterable) and not isinstance(dim, torch.Tensor): + dims = tuple(dim) + else: + dims = (dim,) - def mean_args( - input, dim=None, keepdim=False, dtype=None, out=None, *args, **kwargs - ): - # Sanitize the arguments: + # Handle negative dimensions + if handle_negatives: + dims = tuple(d % tensor_ndim for d in dims) - return input, dim, keepdim, dtype, out + # Return as set or tuple based on as_set flag + if as_set: + return set(dims) + return dims - input, dim, keepdim, dtype, out = mean_args(*args, **kwargs) - weight = 1.0 +def is_full_reduction(dim: DimT, tensor_ndim: int) -> bool: + """ + Determine if this is a full reduction. + + Args: + dim: The dimension(s) to check. Can be None, int, or iterable of ints. + tensor_ndim: Number of dimensions in the tensor. + + Returns: + bool: True if all dimensions are being reduced, False otherwise. + """ + if dim is None: + return True + if isinstance(dim, Iterable) and len(dim) == tensor_ndim: + return True + return False + - if dim is None: +def compute_result_placements( + tensor: ShardTensor, dim: DimT, reduction_name: str, keepdim: bool = False +) -> List[Union[Partial, Shard]]: + """ + Compute placement info for reduction result. - dim = range(len(input.shape)) + Args: + tensor: The input ShardTensor being reduced. + dim: The dimension(s) to reduce. Can be None, int, or iterable of ints. + reduction_name: Type of reduction operation ("sum", "avg", etc.). + keepdim: Whether to preserve reduced dimensions with size 1. - # Convert dim to a tuple if it's not already iterable - if not isinstance(dim, Iterable): - dim = (dim,) + Returns: + List[Union[Partial, Shard]]: Placement specifications for the result tensor. + """ + if is_full_reduction(dim, tensor.ndim): + return [ + Partial("sum" if reduction_name != "avg" else "avg") + for _ in range(tensor.device_mesh.ndim) + ] + + # Use enhanced normalize_dim to get dimensions as a set + dims = normalize_dim(dim, tensor.ndim, as_set=True) + + placements = [] + for p in tensor._spec.placements: + if isinstance(p, Shard): + shard_dim = p.dim + # Count how many reduction dims are less than this shard dim + num_lower = sum(1 for d in dims if d < shard_dim) + # If this sharded dim is being reduced, it becomes Partial + if shard_dim in dims: + placements.append(Partial(reduction_name)) + else: + # If keepdim is False, dims to the left are removed, so shift left + new_dim = shard_dim - num_lower if not keepdim else shard_dim + placements.append(Shard(new_dim)) + else: + placements.append(p) + return placements - denom = 1.0 - local_shape = input._local_tensor.shape - # For each dimension being reduced, multiply weight by local size - # and track global size for denominator + +def reduction_shape( + S: torch.Size, dim: DimT = None, keepdim: bool = False +) -> torch.Size: + """ + Calculate the resulting shape after a reduction operation. + + Args: + S: Original shape of the tensor. + dim: The dimension(s) to reduce. Can be None, int, or iterable of ints. + keepdim: Whether to preserve reduced dimensions with size 1. + + Returns: + torch.Size: The shape after reduction. + """ + shape = list(S) + if dim is None: + return torch.Size([1] * len(shape)) if keepdim else torch.Size([]) + + # Use enhanced normalize_dim to handle iterable and negative dims + dim = normalize_dim(dim, len(shape), handle_negatives=True) + + if keepdim: for d in dim: - if d < 0: - d += input.ndim - weight *= local_shape[d] - denom *= input.shape[d] + shape[d] = 1 + else: + for d in sorted(dim, reverse=True): + del shape[d] + return torch.Size(shape) + + +def compute_result_sharding_shapes( + tensor: ShardTensor, dim: DimT, keepdim: bool +) -> Dict[int, List[torch.Size]]: + """ + Compute sharding sizes for the result of a reduction operation. + + Args: + tensor: The input ShardTensor being reduced. + dim: The dimension(s) to reduce. Can be None, int, or iterable of ints. + keepdim: Whether to preserve reduced dimensions with size 1. + + Returns: + Dict[int, List[torch.Size]]: Mapping of mesh dimensions to sharding shapes. + """ + if is_full_reduction(dim, tensor.ndim): + return {} + else: + # Create a dictionary to store sharding sizes for dimensions that remain in the output + result_sharding_shapes = {} + + # Get the original sharding sizes + original_sharding_shapes = tensor._spec.sharding_shapes() + # Use normalize_dim directly + normalized_dim = normalize_dim(dim, tensor.ndim) + + for mesh_dim, sharding_shapes in original_sharding_shapes.items(): + result_sharding_shapes[mesh_dim] = [ + reduction_shape(shape, normalized_dim, keepdim) + for shape in sharding_shapes + ] + + return result_sharding_shapes + + +def create_sharded_grad_input( + local_grad_input: torch.Tensor, original_spec: Any +) -> ShardTensor: + """ + Create a ShardTensor from local gradient input. + + Args: + local_grad_input: The local gradient tensor. + original_spec: The original ShardTensor's spec to use for placement. + + Returns: + ShardTensor: A distributed tensor with the same sharding as the original input. + """ + return ShardTensor.from_local( + local_grad_input, + device_mesh=original_spec.mesh, + placements=original_spec.placements, + sharding_shapes=original_spec.sharding_shapes(), + ) + + +# Base class for sharded reductions +class ShardedReductionBase(torch.autograd.Function): + """Base class for implementing custom autograd functions for sharded tensor reductions.""" + + @staticmethod + def setup_ctx( + ctx: Any, tensor: ShardTensor, dim: DimT, keepdim: bool + ) -> Tuple[Optional[Tuple[int, ...]], bool]: + """ + Save common context information for backward pass. + + Args: + ctx: The autograd context object. + tensor: The input ShardTensor being reduced. + dim: The dimension(s) to reduce. + keepdim: Whether to preserve reduced dimensions with size 1. + + Returns: + Tuple[Optional[Tuple[int, ...]], bool]: Normalized dimension and keepdim flag. + """ + ctx.original_spec = tensor._spec + ctx.output_requires_grad = tensor.requires_grad + + # Normalize dim to tuple form + dim = normalize_dim(dim, tensor.ndim) + + # Ensure keepdim is a boolean + keepdim = bool(keepdim) + + ctx.dim = dim + ctx.keepdim = keepdim + ctx.is_full_reduction = is_full_reduction(dim, tensor.ndim) + + # Save the shape of the local tensor + ctx.local_grad_shape = tensor._local_tensor.shape + + return dim, keepdim + + +# Specific reduction implementations +class ShardedSum(ShardedReductionBase): + """ + Custom autograd function for sum reduction of sharded tensors. + Handles both forward and backward passes with proper gradient computation. + """ + + @staticmethod + def forward( + ctx: Any, + tensor: ShardTensor, + dim: DimT = None, + keepdim: bool = False, + dtype: Optional[torch.dtype] = None, + ) -> ShardTensor: + """ + Forward pass for sum reduction on ShardTensor. + + Args: + ctx: The autograd context object. + tensor: The input ShardTensor to be reduced. + dim: The dimension(s) to reduce. + keepdim: Whether to preserve reduced dimensions with size 1. + dtype: Output data type (optional). - # Compute the weight: - weight = weight / denom + Returns: + ShardTensor: The result of sum reduction. + """ + dim, keepdim = ShardedReductionBase.setup_ctx(ctx, tensor, dim, keepdim) - local_input = input._local_tensor + # Get local tensor + local_tensor = tensor._local_tensor + # Perform local sum + local_result = aten.sum(local_tensor, dim=dim, keepdim=keepdim, dtype=dtype) - # Now we can do the mean: - local_mean = aten.mean(local_input, dim=dim, keepdim=keepdim, dtype=dtype) + # Compute placements for the result + placements = compute_result_placements(tensor, dim, "sum") + output_sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim) - # If dim is None, placements will be partial across all mesh dims. - if dim is None: - placements = [Partial("sum") for _ in range(input.ndim)] + # Create result ShardTensor + result = ShardTensor.from_local( + local_result, + tensor.device_mesh, + placements, + sharding_shapes=output_sharding_shapes, + ) + + return result + + @staticmethod + def backward( + ctx: Any, grad_output: ShardTensor + ) -> Tuple[ShardTensor, None, None, None]: + """ + Backward pass for sum reduction. + + Args: + ctx: The autograd context object. + grad_output: Gradient of the loss with respect to the output. + + Returns: + Tuple containing gradients for each input in the forward pass. + """ + original_spec = ctx.original_spec + dim = ctx.dim + is_full_reduction = ctx.is_full_reduction + keepdim = ctx.keepdim + local_grad_shape = ctx.local_grad_shape + + # Get local grad output + local_grad_output = grad_output._local_tensor + + if is_full_reduction: + # For full reduction, broadcast to original size + grad_input = local_grad_output.expand(local_grad_shape) else: - # dim is not none, but make sure we only put partial on the dims we're reducing on. - placements = [] - for i_p, p in enumerate(input._spec.placements): - if isinstance(p, Shard) and p.dim in dim: - placements.append(Partial("sum")) - else: - placements.append(p) - - # Create a new ShardTensor with the same mesh and right placements - local_mean = ShardTensor.from_local( - weight * local_mean, # Scale by weight to account for local size - input.device_mesh, + # For dimension-specific reduction + if keepdim: + # Just expand along reduced dimensions + expand_shape = list(local_grad_shape) + grad_input = local_grad_output.expand(expand_shape) + else: + # Need to unsqueeze first + grad_shape = list(local_grad_output.shape) + for d in sorted(dim): + if d < 0: + d += original_spec.tensor_meta.ndim + grad_shape.insert(d, 1) + + grad_expanded = local_grad_output.reshape(grad_shape) + expand_shape = list(local_grad_shape) + grad_input = grad_expanded.expand(expand_shape) + + # Create ShardTensor from local grad + grad_input = create_sharded_grad_input(grad_input, original_spec) + # Return gradients for all inputs + return grad_input, None, None, None + + +class ShardedMean(ShardedReductionBase): + """ + Custom autograd function for mean reduction of sharded tensors. + Handles both forward and backward passes with proper gradient computation and scaling. + """ + + @staticmethod + def forward( + ctx: Any, + tensor: ShardTensor, + dim: DimT = None, + keepdim: bool = False, + dtype: Optional[torch.dtype] = None, + ) -> ShardTensor: + """ + Forward pass for mean reduction on ShardTensor. + + Args: + ctx: The autograd context object. + tensor: The input ShardTensor to be reduced. + dim: The dimension(s) to reduce. + keepdim: Whether to preserve reduced dimensions with size 1. + dtype: Output data type (optional). + + Returns: + ShardTensor: The result of mean reduction. + """ + dim, keepdim = ShardedReductionBase.setup_ctx(ctx, tensor, dim, keepdim) + + # Get local tensor + local_tensor = tensor._local_tensor + + # Compute proper weighting for mean + weight = 1.0 + + # Normalize dimensions for consistent handling + if is_full_reduction(dim, tensor.ndim): + # For full reduction, use all dimensions + reduction_dims = set(range(tensor.ndim)) + else: + # Only use the normalized dimensions for partial reduction + reduction_dims = dim + + # Calculate weight based on local vs global shape ratio for reduction dimensions + local_shape = local_tensor.shape + global_shape = tensor.shape + + for d in reduction_dims: + weight *= local_shape[d] / global_shape[d] + + # Perform local mean + local_result = aten.mean(local_tensor, dim=dim, keepdim=keepdim, dtype=dtype) + # Apply weighting + local_result = local_result * weight + + placements = compute_result_placements(tensor, dim, "sum") + output_sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim) + + # Create result ShardTensor + result = ShardTensor.from_local( + local_result, + tensor.device_mesh, placements, + sharding_shapes=output_sharding_shapes, ) - # print(f"Local mean: {local_mean}") - return local_mean + return result @staticmethod - def backward(ctx, grad_output): - return grad_output, None, None, None, None, None, None + def backward( + ctx: Any, grad_output: ShardTensor + ) -> Tuple[ShardTensor, None, None, None]: + """ + Backward pass for mean reduction. + + Args: + ctx: The autograd context object. + grad_output: Gradient of the loss with respect to the output. + + Returns: + Tuple containing gradients for each input in the forward pass. + """ + original_spec = ctx.original_spec + dim = ctx.dim + is_full_reduction = ctx.is_full_reduction + keepdim = ctx.keepdim + local_grad_shape = ctx.local_grad_shape + global_shape = original_spec.tensor_meta.shape + + # Get local grad output + local_grad_output = grad_output._local_tensor + + if is_full_reduction: + # For full reduction, broadcast to original size with scaling + factor = 1.0 / torch.prod(torch.tensor(global_shape)) + grad_input = local_grad_output.expand(local_grad_shape) * factor + else: + # For dimension-specific reduction + if keepdim: + # Just expand along reduced dimensions + expand_shape = list(local_grad_shape) + grad_input = local_grad_output.expand(expand_shape) + else: + # Need to unsqueeze first + grad_shape = list(local_grad_output.shape) + for d in sorted(dim): + if d < 0: + d += original_spec.tensor_meta.ndim + grad_shape.insert(d, 1) + + grad_expanded = local_grad_output.reshape(grad_shape) + expand_shape = list(local_grad_shape) + grad_input = grad_expanded.expand(expand_shape) + + # Apply scaling factor for mean + factor = 1.0 + for d in dim: + if d < 0: + d += original_spec.tensor_meta.ndim + factor /= global_shape[d] + grad_input = grad_input * factor + + # Create ShardTensor from local grad + grad_input = create_sharded_grad_input(grad_input, original_spec) + + # Return gradients for all inputs + return grad_input, None, None, None + + +# Create wrapper functions +def sum_wrapper( + tensor: ShardTensor, + dim: DimT = None, + keepdim: bool = False, + *args: Any, + **kwargs: Any +) -> ShardTensor: + """ + Wrapper function for ShardTensor sum reduction. + + Args: + tensor: Input ShardTensor to reduce. + dim: The dimension(s) to reduce. + keepdim: Whether to preserve reduced dimensions with size 1. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + Returns: + ShardTensor: Result of sum reduction. + """ + return ShardedSum.apply(tensor, dim, keepdim, *args, **kwargs) -def sharded_mean_wrapper(*args, **kwargs): - return ShardedMean.apply(*args, **kwargs) +# TODO - accept func, types, args, kwargs instead +def mean_wrapper( + tensor: ShardTensor, + dim: DimT = None, + keepdim: bool = False, + *args: Any, + **kwargs: Any +) -> ShardTensor: + """ + Wrapper function for ShardTensor mean reduction. + + Args: + tensor: Input ShardTensor to reduce. + dim: The dimension(s) to reduce. + keepdim: Whether to preserve reduced dimensions with size 1. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + ShardTensor: Result of mean reduction. + """ + return ShardedMean.apply(tensor, dim, keepdim, *args, **kwargs) -mean_ops = [ - aten.mean.default, - aten.mean.dim, - aten.mean.dtype_out, - aten.mean.names_dim, - aten.mean.names_out, - aten.mean.op, - aten.mean.out, -] -for op in mean_ops: - ShardTensor.register_function_handler(op, sharded_mean_wrapper) + +# Map the reduction ops to their handlers +reduction_mapping: Dict[str, Callable] = { + "sum": sum_wrapper, + "avg": mean_wrapper, + # "max": max_wrapper, + # "min": min_wrapper +} + + +def register_reduction_functions() -> None: + """ + Register reduction functions with the ShardTensor class. + + This function gatekeeps the registration of these ops + to ensure they don't fire unless wanted. + """ + # Register handlers for standalone functions and methods + ShardTensor.register_function_handler(torch.mean, mean_wrapper) + ShardTensor.register_function_handler(torch.Tensor.mean, mean_wrapper) + ShardTensor.register_function_handler(torch.sum, sum_wrapper) + ShardTensor.register_function_handler(torch.Tensor.sum, sum_wrapper) + # ShardTensor.register_function_handler(torch.max, max_wrapper) + # ShardTensor.register_function_handler(torch.Tensor.max, max_wrapper) + # ShardTensor.register_function_handler(torch.min, min_wrapper) + # ShardTensor.register_function_handler(torch.Tensor.min, min_wrapper) diff --git a/physicsnemo/distributed/custom_ops/_tensor_ops.py b/physicsnemo/distributed/custom_ops/_tensor_ops.py index 1d085eea03..6dd9aa54bb 100644 --- a/physicsnemo/distributed/custom_ops/_tensor_ops.py +++ b/physicsnemo/distributed/custom_ops/_tensor_ops.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import torch from physicsnemo.utils.version_check import check_module_requirements @@ -27,7 +28,9 @@ OutputSharding, RuntimeSchemaInfo, ) -from torch.distributed.tensor._ops.utils import register_prop_rule # noqa: E402 +from torch.distributed.tensor._ops.utils import ( # noqa: E402 + register_prop_rule, +) from torch.distributed.tensor.placement_types import ( # noqa: E402 Partial, Replicate, diff --git a/physicsnemo/distributed/manager.py b/physicsnemo/distributed/manager.py index 325b9687e7..1bd69c1a1e 100644 --- a/physicsnemo/distributed/manager.py +++ b/physicsnemo/distributed/manager.py @@ -520,6 +520,39 @@ def initialize_mesh( return self._global_mesh + @require_version("torch", "2.4") + def get_mesh_group(self, mesh: dist.DeviceMesh) -> dist.ProcessGroup: + """ + Get the process group for a given mesh. + + Creating a group is an expensive operation, so we cache the result manually. + + We hash the mesh and use that as the key. + """ + + key = hash(mesh) + + # Initialize a cache for the groups + if not hasattr(self, "_mesh_groups"): + self._mesh_groups = {} + + if key in self._mesh_groups.keys(): + return self._mesh_groups[key] + else: + + if mesh.ndim != 1: + # We need to get all ranks in this mesh and spawn a group. + # The mesh.mesh object is a GPU tensor and using it will block. + ranks = mesh.mesh.cpu() + ranks = list(ranks.flatten().tolist()) + group = dist.new_group(ranks=ranks, use_local_synchronization=True) + self._mesh_groups[key] = group + return group + + else: + self._mesh_groups[key] = mesh.get_group() + return mesh.get_group() + @staticmethod def setup( rank=0, diff --git a/physicsnemo/distributed/shard_tensor.py b/physicsnemo/distributed/shard_tensor.py index 38d7d30626..fa705b2031 100644 --- a/physicsnemo/distributed/shard_tensor.py +++ b/physicsnemo/distributed/shard_tensor.py @@ -15,7 +15,7 @@ # limitations under the License. from collections.abc import Iterable -from typing import Dict, List, Optional, Sequence, Tuple, Union, cast +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union, cast from warnings import warn import torch @@ -23,7 +23,7 @@ from torch.distributed.device_mesh import DeviceMesh, _mesh_resources from physicsnemo.distributed import DistributedManager -from physicsnemo.distributed.utils import compute_split_shapes, split_tensor_along_dim +from physicsnemo.utils.profiling import annotate, profile from physicsnemo.utils.version_check import check_module_requirements # Prevent importing this module if the minimum version of pytorch is not met. @@ -45,8 +45,11 @@ from physicsnemo.distributed._shard_tensor_spec import ( # noqa: E402 ShardTensorSpec, _infer_shard_tensor_spec_from_local_chunks, + _stride_from_contiguous_shape_C_style, ) +aten = torch.ops.aten + class _ToTorchTensor(torch.autograd.Function): """Autograd function to convert a ShardTensor to a regular PyTorch tensor. @@ -99,12 +102,23 @@ def backward( """ shard_tensor_spec = ctx.shard_tensor_spec mesh = shard_tensor_spec.mesh - - grad_placements = ctx.grad_placements or shard_tensor_spec.placements - + if ctx.grad_placements is not None: + if ctx.grad_placements != shard_tensor_spec.placements: + grad_placements = ctx.grad_placements + grad_sharding_shapes = "infer" + else: + # If the placements are the same as the input placements, + # we reuse the sharding sizes from the input placements. + grad_placements = ctx.grad_placements + grad_sharding_shapes = shard_tensor_spec._sharding_shapes + else: + grad_placements = shard_tensor_spec.placements + grad_sharding_shapes = shard_tensor_spec._sharding_shapes + if grad_sharding_shapes is None: + grad_sharding_shapes = "infer" # Generate a spec based on grad outputs and the expected placements: grad_tensor_spec = _infer_shard_tensor_spec_from_local_chunks( - grad_output, mesh, grad_placements + grad_output, mesh, grad_placements, grad_sharding_shapes ) return ( @@ -131,6 +145,7 @@ def forward( local_input: torch.Tensor, device_mesh: DeviceMesh, placements: Tuple[Placement, ...], + sharding_shapes: Union[str, Dict[int, List[Tuple[int, ...]]]] = "chunk", ) -> "ShardTensor": """Convert a local torch.Tensor to a ShardTensor in forward pass. @@ -139,17 +154,24 @@ def forward( local_input: Local tensor to convert to ShardTensor device_mesh: Device mesh specifying process groups placements: Tuple of placement rules for sharding + sharding_shapes: Controls how shard tensor spec is generated: + - "blocking_infer": Use blocking collective communication to infer shapes + - "chunk": Use torch.chunk shapes to infer shapes from global shape (no communication) + - "infer": Assume shapes are not even, but defer inference until needed. + Note that infer will launch async RPC calls to infer the shapes, but won't + block on them until they are called upon. + - Manual dict mapping mesh dim to list of shard shapes: Use provided shapes. Must pass on each rank! Returns: ShardTensor constructed from the local input tensor """ ctx.previous_placement = placements ctx.previous_mesh = device_mesh - # This function is simpler than the corresponding DTensor implementation on the surface - # because under the hood, we always do checks here. + # This function is simpler than the corresponding DTensor implementation on the surface + # because under the hood, we have some logic here to infer the sharding shapes. shard_tensor_spec = _infer_shard_tensor_spec_from_local_chunks( - local_input, device_mesh, placements + local_input, device_mesh, placements, sharding_shapes ) shard_tensor = ShardTensor( @@ -181,11 +203,18 @@ def backward( RuntimeError: If gradient tensor has different placement than original """ previous_placement = ctx.previous_placement - if grad_output.placements != previous_placement: - raise RuntimeError("Resharding gradients not yet implemented") + # Automatically redistribute to the previous placement as long as it's not a partial. + if not any(p.is_partial() for p in previous_placement): + grad_output = grad_output.redistribute( + grad_output._spec.mesh, previous_placement + ) + else: + raise RuntimeError( + "Resharding gradients with partial placements not implemented" + ) - return grad_output.to_local(), None, None + return grad_output.to_local(), None, None, None class ShardTensor(DTensor): @@ -221,7 +250,11 @@ class ShardTensor(DTensor): _spec: ShardTensorSpec __slots__ = ["_local_tensor", "_spec"] - _function_registry: Dict[torch._ops.OpOverload, callable] = {} + # For torch.ops.aten operators (low-level dispatch) + _dispatch_registry: Dict[torch._ops.OpOverload, Callable] = {} + + # For Python-level functions (torch.mean, tensor.mean, etc.) + _function_registry: Dict[Callable, Callable] = {} # Upon construction of any ShardTensor objects, this will be set to true. # Wrappers are triggered dynamically, so the wrapping will be pass-through @@ -238,14 +271,15 @@ def patches_enabled(cls) -> bool: return cls._enable_shard_patches @classmethod - def register_function_handler(cls, func: torch._ops.OpOverload, handler: callable): - """ - Register a custom handler for a specific function. + def register_dispatch_handler( + cls, op: torch._ops.OpOverload, handler: Callable + ) -> None: + """Register a handler for a specific PyTorch operator in the dispatch system.""" + cls._dispatch_registry[op] = handler - Args: - func: The function to intercept. - handler: The custom handler to call instead of the default dispatch. - """ + @classmethod + def register_function_handler(cls, func: Callable, handler: Callable) -> None: + """Register a handler for a Python-level function or method.""" cls._function_registry[func] = handler @staticmethod @@ -305,11 +339,9 @@ def __repr__(self) -> str: return f"ShardTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" @classmethod - def from_dtensor( - cls, dtensor: DTensor, force_sharding_inference: bool = False - ) -> "ShardTensor": + def from_dtensor(cls, dtensor: DTensor) -> "ShardTensor": """ - Convert a DTensor to a ShardTensor. + Convert a DTensor to a ShardTensor. We assume the DTensor is properly constructed. Args: dtensor: DTensor to convert @@ -317,82 +349,41 @@ def from_dtensor( Returns: Equivalent ShardTensor """ + # Always ensure sharding is turned on: + cls._enable_shard_patches = True # DTensor is locked to sharding a tensor according to chunk format. # We can use that to infer sharding sizes with no communication. - mesh = dtensor._spec.mesh - placements = dtensor._spec.placements - - if force_sharding_inference: - shard_tensor_spec = _infer_shard_tensor_spec_from_local_chunks( - dtensor._local_tensor, dtensor._spec.mesh, dtensor._spec.placements - ) - return ShardTensor.__new__( - cls, - local_tensor=dtensor._local_tensor, - spec=shard_tensor_spec, - requires_grad=dtensor.requires_grad, - ) - else: - temp_sharding_sizes = {} - for i in range(mesh.ndim): - if isinstance(placements[i], Shard): - # Compute the chunk size for this dimension: - input_dim = dtensor.shape[placements[i].dim] - chunked_shapes = compute_split_shapes(input_dim, mesh.size(i)) - # This needs to be a tuple of torch.Size - - temp_sharding_sizes[i] = chunked_shapes - - # To create the full, final sharding shapes, we update the global shape with - sharding_sizes = {} - # Initialize sharding_sizes with same keys as temp_sharding_sizes - # Each value is a list of torch.Size equal to mesh size for that dimension - for mesh_dim in temp_sharding_sizes.keys(): - placement = placements[mesh_dim] - # We should not have the mesh dim in this dict if it wasn't sharded above: - tensor_dim = placement.dim - - sharding_sizes[mesh_dim] = [ - torch.Size(dtensor.shape) for _ in temp_sharding_sizes[mesh_dim] - ] - # For each shard along this mesh dimension - for i, shard_size in enumerate(temp_sharding_sizes[mesh_dim]): - # Replace size at sharded dim with actual shard size - updated_shard_size = torch.Size( - tuple( - ( - shard_size if j == tensor_dim else s - for j, s in enumerate(sharding_sizes[mesh_dim][i]) - ) - ) - ) - sharding_sizes[mesh_dim][i] = updated_shard_size - - # Cast to tuples: - for mesh_dim in temp_sharding_sizes.keys(): - sharding_sizes[mesh_dim] = tuple(sharding_sizes[mesh_dim]) - - spec = ShardTensorSpec( - mesh=dtensor._spec.mesh, - placements=dtensor._spec.placements, - tensor_meta=dtensor._spec.tensor_meta, - _sharding_sizes=sharding_sizes, # Leave this to none for a lazy init and assume it's not breaking to make this cast. - _local_shape=dtensor._local_tensor.shape, - ) + # Create the spec by inferring the sharding sizes from the DTensor: + spec = _infer_shard_tensor_spec_from_local_chunks( + dtensor._local_tensor, + dtensor._spec.mesh, + dtensor._spec.placements, + sharding_shapes="chunk", + global_shape=dtensor.shape, + ) - cls._enable_shard_patches = True + return ShardTensor.__new__( + cls, + local_tensor=dtensor._local_tensor, + spec=spec, + requires_grad=dtensor.requires_grad, + ) - return ShardTensor.__new__( - cls, - local_tensor=dtensor._local_tensor, - spec=spec, - requires_grad=dtensor.requires_grad, - ) + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs={}): + with annotate(f"__torch_function___{func.__name__}"): + # Check for overrides: + if func in cls._function_registry and cls._enable_shard_patches: + res = cls._function_registry[func](func, types, args, kwargs) + return res + # Fall back to the default behavior: + return super().__torch_function__(func, types, args, kwargs) @classmethod @torch._disable_dynamo + @profile def __torch_dispatch__( cls, func: torch._ops.OpOverload, @@ -400,37 +391,71 @@ def __torch_dispatch__( args: Tuple[object, ...] = (), kwargs: Optional[Dict[str, object]] = None, ) -> Union["ShardTensor", Iterable["ShardTensor"], object]: - # Leverage DTensor Dispatch as much as possible, but, enable - # the ability to operate on this output in the future: - - if func in cls._function_registry: - return cls._function_registry[func](*args, **kwargs) - - dispatch_res = DTensor._op_dispatcher.dispatch(func, args, kwargs or {}) - - # dispatch_res = ShardTensor._op_dispatcher.dispatch(func, args, kwargs or {}) - - # Return a shard tensor instead of a dtensor. - # ShardTensor inherits from DTensor and can lazy-init from for efficiency - if isinstance(dispatch_res, DTensor): - return ShardTensor.from_dtensor(dispatch_res, force_sharding_inference=True) - - if isinstance(dispatch_res, Iterable): - return type(dispatch_res)( - ShardTensor.from_dtensor(d, force_sharding_inference=True) - if isinstance(d, DTensor) - else d - for d in dispatch_res - ) + with annotate(f"__torch_dispatch___{func.__name__}"): + # Leverage DTensor Dispatch as much as possible, but, enable + # the ability to operate on this output in the future: + if func in cls._dispatch_registry: + res = cls._dispatch_registry[func](*args, **kwargs) + return res + + # We assume that if we reach this point, the operator has not been + # intercepted by a wrapper or in the registry. So the DTensor + # default behavior is likely to be correct. + + if func == aten.view.default: + # For view, we need input tensors to be contiguous: + for arg in args: + if isinstance(arg, ShardTensor) or isinstance(arg, DTensor): + if not arg._local_tensor.is_contiguous(): + arg._local_tensor = arg._local_tensor.contiguous() + + dispatch_res = DTensor._op_dispatcher.dispatch(func, args, kwargs or {}) + + # Return a shard tensor instead of a dtensor. + def _convert_dtensor_with_input_check(dtensor, input_args): + """ + This function searches the input for ShardTensors that match output shapes. + It prevents collectives, since we can copy the sharding shapes for irregular shards. + + If no matches are found, it falls back to inference based on DTensor. + + This is only used when we already went back through the DTensor dispatch. + """ + # Check if this matches any input ShardTensor + for arg in input_args: + if ( + isinstance(arg, ShardTensor) + and dtensor._spec.tensor_meta == arg._spec.tensor_meta.shape + and dtensor._spec.placements == arg._spec.placements + ): + return ShardTensor.__new__( + ShardTensor, + local_tensor=dtensor._local_tensor, + spec=arg._spec, + requires_grad=dtensor.requires_grad, + ) + # Fall back to default conversion + return ShardTensor.from_dtensor(dtensor) + + if isinstance(dispatch_res, DTensor): + return _convert_dtensor_with_input_check(dispatch_res, args) + + if isinstance(dispatch_res, Iterable): + return type(dispatch_res)( + _convert_dtensor_with_input_check(d, args) + if isinstance(d, DTensor) + else d + for d in dispatch_res + ) - return dispatch_res + return dispatch_res @staticmethod def from_local( local_tensor: torch.Tensor, device_mesh: Optional[DeviceMesh] = None, placements: Optional[Sequence[Placement]] = None, - infer_shape: Optional[bool] = True, + sharding_shapes: Union[str, Dict[int, List[Tuple[int, ...]]]] = "infer", ) -> "ShardTensor": """ Generate a new ShardTensor from local torch tensors. Uses @@ -445,48 +470,44 @@ def from_local( of the same rank and concatable across the mesh dimensions device_mesh: Target Device Mesh, if not specified will use the current mesh placements: Target placements, must have same number of elements as device_mesh.ndim - infer_shape: If False, assumes even distribution like DTensor. Default True. - + sharding_shapes: Passed to _infer_shard_tensor_spec_from_local_chunks to control + how the sharding sizes are inferred. Returns: A new ShardTensor instance """ - if infer_shape: + # this turns on shard patches globally for this process. + ShardTensor._enable_shard_patches = True - # This implementation follows the pytorch DTensor Implementation Closely. - device_mesh = device_mesh or _mesh_resources.get_current_mesh() - device_type = device_mesh.device_type + # This implementation follows the pytorch DTensor Implementation Closely. + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type - # convert the local tensor to desired device base on device mesh's device_type - if device_type != local_tensor.device.type and not local_tensor.is_meta: - local_tensor = local_tensor.to(device_type) + # convert the local tensor to desired device base on device mesh's device_type + if device_type != local_tensor.device.type and not local_tensor.is_meta: + local_tensor = local_tensor.to(device_type) - # set default placements to replicated if not specified - if placements is None: - placements = [Replicate() for _ in range(device_mesh.ndim)] - else: - placements = list(placements) - for idx, placement in enumerate(placements): - # normalize shard dim to be positive - if placement.is_shard(): - placement = cast(Shard, placement) - if placement.dim < 0: - placements[idx] = Shard(placement.dim + local_tensor.ndim) - - # `from_local` is differentiable, and the gradient of the dist tensor this function - # created should flow back the gradients to the local_tensor, so we call an autograd - # function to construct the dist tensor instead. - ShardTensor._enable_shard_patches = True - return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func - local_tensor, - device_mesh, - tuple(placements), - ) + # set default placements to replicated if not specified + if placements is None: + placements = [Replicate() for _ in range(device_mesh.ndim)] else: - ShardTensor._enable_shard_patches = True - return ShardTensor.from_dtensor( - DTensor.from_local(local_tensor, device_mesh, placements) - ) + placements = list(placements) + for idx, placement in enumerate(placements): + # normalize shard dim to be positive + if placement.is_shard(): + placement = cast(Shard, placement) + if placement.dim < 0: + placements[idx] = Shard(placement.dim + local_tensor.ndim) + + # `from_local` is differentiable, and the gradient of the dist tensor this function + # created should flow back the gradients to the local_tensor, so we call an autograd + # function to construct the dist tensor instead. + return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func + local_tensor, + device_mesh, + tuple(placements), + sharding_shapes, + ) def offsets(self, mesh_dim: Optional[int] = None) -> List[int]: """ @@ -576,12 +597,35 @@ def full_tensor( ) return _ToTorchTensor.apply(redist_res, grad_placements) + def backward(self, *args, **kwargs): + + # Before calling backward, we need to resolve any partial placements. + new_placements = [] + # grad_placements = [] + needs_redistribute = False + for i, placement in enumerate(self._spec.placements): + if placement.is_partial(): + new_placements.append(Replicate()) + # grad_placements.append(Shard(i)) + needs_redistribute = True + else: + new_placements.append(placement) + # grad_placements.append(placement) + + if needs_redistribute: + self = self.redistribute(placements=new_placements) + + return self.to_local().backward(*args, **kwargs) + def scatter_tensor( tensor: torch.Tensor, global_src: int, mesh: DeviceMesh, placements: Tuple[Placement, ...], + global_shape: Optional[torch.Size] = None, + dtype: Optional[torch.dtype] = None, + requires_grad: bool = False, ) -> "ShardTensor": """ Take a tensor from source rank and distribute it across devices on the mesh according to placements. @@ -611,64 +655,60 @@ def scatter_tensor( is_src = dm.rank == global_src - # For multi-dimensional meshes, create a flattened process group - if mesh.ndim != 1: - global_ranks = mesh.mesh.flatten().tolist() - mesh_group = dist.new_group(ranks=global_ranks, use_local_synchronization=True) - else: - mesh_group = mesh.get_group() + # For multi-dimensional meshes, we use a flattened process group + mesh_group = dm.get_mesh_group(mesh) # Broadcast tensor metadata from source - axis_rank = dist.get_rank(mesh_group) - if dm.rank == global_src: - meta = [TensorMeta(tensor.shape, tensor.stride(), tensor.dtype)] - else: - meta = [None] - - dist.broadcast_object_list(meta, src=global_src, group=mesh_group) - dist.barrier(group=mesh_group) - local_meta = meta[0] + if global_shape is None or dtype is None: + if dm.rank == global_src: + meta = [TensorMeta(tensor.shape, tensor.stride(), tensor.dtype)] + else: + meta = [None] - # Cast the shape to a list to be mutable: - local_shape = list(local_meta.shape) + dist.broadcast_object_list(meta, src=global_src, group=mesh_group) - if is_src: - chunks = [tensor] + local_meta = meta[0] else: - chunks = None - - # Split tensor according to shard placements - for dim, placement in enumerate(placements): - if isinstance(placement, Shard): - tensor_dim = placement.dim - axis_rank = dist.get_rank(group=mesh.get_group(dim)) - axis_size = dist.get_world_size(group=mesh.get_group(dim)) - - sections = compute_split_shapes(local_shape[tensor_dim], axis_size) - - if is_src: - new_chunks = [] - for t in chunks: - new_chunks += split_tensor_along_dim(t, tensor_dim, axis_size) - chunks = new_chunks - local_shape[tensor_dim] = sections[axis_rank] - - # Convert the shape back to a tuple: - local_shape = tuple(local_shape) - - # Allocate local tensor - local_chunk = torch.empty( - local_shape, - dtype=local_meta.dtype, - device=torch.device(f"cuda:{dm.local_rank}"), + stride = _stride_from_contiguous_shape_C_style(global_shape) + local_meta = TensorMeta(global_shape, stride, dtype) + + # This needs to be optimized, but I want to get the whole pipeline optimized first. + # This only gets done when scatter_tensor is called and it should be relatively small + # in full applications. + + # What isn't optimmized? Broadcasting the full tensor when placement is likely + # Shard on at least one mesh dimension. It would be more efficient to iteratively + # scatter along Shard dimensions. BUT, the focus is on performance of full applications + # and this is a once-per-iteration cost. + + # Broadcast the tensor to all ranks + if tensor is None and not is_src: + # Tensor is allowed to be none if not on the root rank + tensor = torch.empty(local_meta.shape, dtype=local_meta.dtype, device=dm.device) + + dist.broadcast(tensor, src=global_src, group=mesh_group) + + # Create a fully-replicated spec: + spec = ShardTensorSpec( + mesh=mesh, + placements=[Replicate() for _ in range(mesh.ndim)], + tensor_meta=local_meta, + _sharding_shapes={}, ) - # Scatter chunks across mesh - dist.scatter(local_chunk, chunks, src=global_src, group=mesh_group) - - # Construct ShardTensor from local tensor - return ShardTensor.from_local( - local_tensor=local_chunk, - device_mesh=mesh, - placements=placements, + # Make a "fully-replicated" tensor on all ranks: + st = ShardTensor.__new__( + ShardTensor, + local_tensor=tensor, + spec=spec, + requires_grad=requires_grad, ) + + # Redistribute the tensor to the desired placements: + st = st.redistribute(mesh, placements, async_op=False) + # This is an unoptimal step but is functional: + if requires_grad: + st = st.detach() + st.requires_grad = True + + return st diff --git a/physicsnemo/distributed/shard_utils/__init__.py b/physicsnemo/distributed/shard_utils/__init__.py index 4c228ce650..cebe47f6c9 100644 --- a/physicsnemo/distributed/shard_utils/__init__.py +++ b/physicsnemo/distributed/shard_utils/__init__.py @@ -14,15 +14,34 @@ # See the License for the specific language governing permissions and # limitations under the License. +import torch + from physicsnemo.utils.version_check import check_module_requirements # Prevent importing this module if the minimum version of pytorch is not met. try: check_module_requirements("physicsnemo.distributed.shard_tensor") + from physicsnemo.distributed.shard_tensor import ShardTensor + def register_shard_wrappers(): + from .attention_patches import sdpa_wrapper from .conv_patches import generic_conv_nd_wrapper + from .index_ops import ( + index_select_wrapper, + select_backward_wrapper, + select_wrapper, + ) from .natten_patches import na2d_wrapper + from .normalization_patches import group_norm_wrapper + from .point_cloud_ops import ball_query_layer_wrapper + from .pooling_patches import generic_avg_pool_nd_wrapper + from .unpooling_patches import interpolate_wrapper + + ShardTensor.register_dispatch_handler(torch.ops.aten.select.int, select_wrapper) + ShardTensor.register_dispatch_handler( + torch.ops.aten.select_backward.default, select_backward_wrapper + ) except ImportError: pass diff --git a/physicsnemo/distributed/shard_utils/attention_patches.py b/physicsnemo/distributed/shard_utils/attention_patches.py new file mode 100644 index 0000000000..833cfbb2f9 --- /dev/null +++ b/physicsnemo/distributed/shard_utils/attention_patches.py @@ -0,0 +1,711 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import wrapt +from torch.autograd.profiler import record_function + +from physicsnemo.utils.version_check import check_module_requirements + +check_module_requirements("physicsnemo.distributed.shard_tensor") + + +from torch.distributed import DeviceMesh # noqa: E402 + +from physicsnemo.distributed import ShardTensor # noqa: E402 +from physicsnemo.distributed.shard_utils.patch_core import ( # noqa: E402 + MissingShardPatch, + UndeterminedShardingError, +) +from physicsnemo.distributed.shard_utils.ring import ( # noqa: E402 + RingPassingConfig, + perform_ring_iteration, +) + +aten = torch.ops.aten + + +def add_log_sumexp( + log_a: Optional[torch.Tensor], log_b: Optional[torch.Tensor] +) -> torch.Tensor: + """ + Add two log_sumexp values together. + + Args: + log_a: First log-space value, can be None + log_b: Second log-space value, can be None + + Returns: + torch.Tensor: Result of log(exp(log_a) + exp(log_b)) computed in a numerically stable way + + Think of this function as taking two values, A and B, + passed in via log form: log(A) and log(B). This function + will return log(A+B) in a numerically stable way. + + """ + if log_a is None or log_b is None: + return log_a if log_a is not None else log_b + + diff = torch.abs(log_a - log_b) + return torch.max(log_a, log_b) + torch.log(torch.exp(-diff) + 1.0) + + +def stable_signed_accumulate( + log_abs_global_O: Optional[torch.Tensor], + sign_global_O: Optional[torch.Tensor], + log_O: torch.Tensor, + sign_O: torch.Tensor, + log_A: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Accumulate two functions together, keeping track of the sign and log_abs. + + Args: + log_abs_global_O: Log of absolute value of accumulated output so far, can be None + sign_global_O: Sign of accumulated output so far, can be None + log_O: Log of absolute value of current output + sign_O: Sign of current output + log_A: Log of normalization factor for current output + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Updated (log_abs, sign) pair for accumulated output + + The block attention algorithm needs to continuously accumulate the output of each block, + however, the normalization is done in log space. This function accomodates that by + accumulating the output in log space using log space normalizations. Note that because + the output of an attention block can be negative, we must use both log(|O|) and sign(O) + for each term. + """ + if log_abs_global_O is None and sign_global_O is None: + return log_O + log_A, sign_O + + log_abs_T = log_O + log_A + sign_T = sign_O + + # Find larger magnitude term + max_log = torch.maximum(log_abs_global_O, log_abs_T) + min_log = torch.minimum(log_abs_global_O, log_abs_T) + + # If signs are the same, use log-sum-exp + same_sign = sign_global_O == sign_T + log_abs_new = torch.where( + same_sign, + max_log + torch.log1p(torch.exp(min_log - max_log)), # log-sum-exp + max_log + torch.log1p(-torch.exp(min_log - max_log)), # log-subtraction + ) + + # Determine new sign + sign_new = torch.where( + same_sign, + sign_global_O, + torch.where(log_abs_global_O >= log_abs_T, sign_global_O, sign_T), + ) + + return log_abs_new, sign_new + + +# Create a persistent communication stream for the ring attention: +comm_stream = torch.cuda.Stream() + + +class RingSDPA(torch.autograd.Function): + """ + Performs scaled dot product attention on sharded Q, K, V. + + The ring allreduce happens concurrently and overlapping with the computation, + for performance improvements. + + For details about the ring attention, see: https://arxiv.org/abs/2310.01889 + Note that the original implementation is a combination of JAX + flash attention + ring attention. + Here, instead, we leverage the underlying and built-in pytorch efficienct attention. + + A key difference with this algorithm is how we track the per-block normalizations. The pytorch + function returns log_sumexp, which we use for a running normalization. But it has to be kept in log + space to prevent underflow/overflow as well as precision issues. See the helper functions + `add_log_sumexp` and `stable_signed_accumulate` for more details. + """ + + # comm_stream = torch.cuda.Stream() + + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor], + mesh: DeviceMesh, + ring_config: RingPassingConfig, + attn_args: dict, + ) -> torch.Tensor: + """ + Forward pass for the ring attention implementation. This implementation + will overlap the communication with the computation. Note that there is an + explicit sync in each iteration to prevent the communication stream from getting + ahead of the computation stream, by waiting on the all_to_all operation to complete. + """ + + ctx.attn_args = attn_args + ctx.mesh = mesh + ctx.ring_config = ring_config + + # Create buffers to store outputs + log_global_output = None + sign_global_output = None + global_log_sumexp = None + + # For the first iteration, use local tensors + current_k, current_v = k, v + + # Pre-allocate a single combined buffer for k,v + # Find the dimension to concatenate on - should be a dim that preserves tensor structure + cat_dim = 0 # Usually batch or sequence dimension + + # Set up communication + local_group = mesh.get_group(ring_config.mesh_dim) + local_rank = mesh.get_local_rank(ring_config.mesh_dim) + local_size = dist.get_world_size(group=local_group) + + id_of_right = (local_rank + 1) % ring_config.mesh_size + id_of_left = (local_rank - 1) % ring_config.mesh_size + + # Create streams that persist for the duration of the ring computation. + compute_stream = torch.cuda.default_stream() + + for i in range(ring_config.mesh_size): + # Launch communication for the next iteration early + with record_function(f"sdpa_send_data_{i}_{dist.get_rank()}"): + if i < ring_config.mesh_size - 1: + + # Use a dedicated stream for communication + with torch.cuda.stream(comm_stream): + + send_tensors = [ + torch.empty((), device=q.device, dtype=q.dtype) + for _ in range(local_size) + ] + recv_tensors = [ + torch.empty((), device=q.device, dtype=q.dtype) + for _ in range(local_size) + ] + + # Combine k and v for communication + send_tensors[id_of_right] = torch.cat( + [current_k, current_v], dim=cat_dim + ).contiguous() + recv_tensors[id_of_left] = torch.empty_like( + send_tensors[id_of_right] + ) + + # Use async_op=True to enable overlapping + a2a_op = dist.all_to_all( + recv_tensors, send_tensors, group=local_group, async_op=True + ) + + # Mark these as used by the communication stream: + current_k.record_stream(comm_stream) + current_v.record_stream(comm_stream) + + # Perform computation on current k,v while communication happens + with record_function(f"sdpa_forward_{i}_{dist.get_rank()}"): + with torch.cuda.stream(compute_stream): + ( + output, + log_sumexp, + philox_seed, + philox_offset, + ) = aten._scaled_dot_product_efficient_attention( + q, + current_k, + current_v, + attn_mask, + compute_log_sumexp=True, + **attn_args, + ) + + # Add an extra dimension to the log_sumexp: + log_sumexp = log_sumexp.unsqueeze(-1) + log_output = torch.log(torch.abs(output)) + sign_output = torch.sign(output) + + log_global_output, sign_global_output = stable_signed_accumulate( + log_global_output, + sign_global_output, + log_output, + sign_output, + log_sumexp, + ) + + global_log_sumexp = add_log_sumexp(global_log_sumexp, log_sumexp) + + if i < ring_config.mesh_size - 1: + # Wait for communication operations to complete before allowing more work + a2a_op.wait() + + # compute_stream.wait_stream(comm_stream) + + # Explicit synchronization to ensure communication is complete + # Also makes sure that the attention computation is complete before changing current_k, current_v + # compute_stream.synchronize() + + current_k, current_v = torch.chunk( + recv_tensors[id_of_left], 2, dim=cat_dim + ) + + # Compute the final output + stable_output = sign_global_output * torch.exp( + log_global_output - global_log_sumexp + ) + + ctx.save_for_backward( + q, + k, + v, + attn_mask, + stable_output, + global_log_sumexp, + philox_seed, + philox_offset, + ) + ctx.grad_input_mask = (True, True, True, attn_mask is not None) + + return stable_output + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], + None, + None, + None, + ]: + """ + Backward pass for the ring SDPA. + + Currently, this is not overlapping communication with the computation. + Note that the backward pass has 2x communication: send k, v but also grad_k, grad_v. + + """ + ( + q, + k, + v, + attn_mask, + output, + log_sumexp, + philox_seed, + philox_offset, + ) = ctx.saved_tensors + attn_args = ctx.attn_args + + grad_q = torch.zeros_like( + q, device=q.device, memory_format=torch.contiguous_format + ) + grad_k = torch.zeros_like( + k, device=k.device, memory_format=torch.contiguous_format + ) + grad_v = torch.zeros_like( + v, device=v.device, memory_format=torch.contiguous_format + ) + grad_attn_mask = None + + # TODO: overlap communication with computation. + # This needs to be done in two stages. First, we can send k, v along the ring before computing + # the gradients. We also need to send grad_k, grad_v along the ring and accumulate them. + + # Since the next iteration's grad_k, grad_v do not depend on the current iteration's gradient + # outputs, we can still overlap. But we need two sync spots instead of one. + # Algorithm therefore looks like this: + # 1. If iteration != N-1, send k, v to the next GPU asycn after combining them into one tensor. + # 2. If iteration != 0, wait for grad_k, grad_v to be received from the previous GPU and split them. + # 2. Compute the gradients on the local block (grad_q, grad_k, grad_v) + # 3. Accumulate the gradients on the local block. + # 5. If iteration != N-1, wait for k, v to be received from the previous GPU (and split them) before the next iteration + # 4. If iteration != 0, send grad_k, grad_v to the next GPU after combining them into one tensor. + + for i in range(ctx.ring_config.mesh_size): + + ( + block_grad_q, + block_grad_k, + block_grad_v, + block_grad_attn_mask, + ) = aten._scaled_dot_product_efficient_attention_backward( + grad_output, + q, + k, + v, + attn_mask, + output, + log_sumexp, + philox_seed, + philox_offset, + grad_input_mask=ctx.grad_input_mask, + **attn_args, + ) + + grad_q += block_grad_q + grad_k += block_grad_k + grad_v += block_grad_v + + # Send k, v, grad_k, grad_v to the next rank: + k = perform_ring_iteration(k, ctx.mesh, ctx.ring_config) + v = perform_ring_iteration(v, ctx.mesh, ctx.ring_config) + grad_k = perform_ring_iteration(grad_k, ctx.mesh, ctx.ring_config) + grad_v = perform_ring_iteration(grad_v, ctx.mesh, ctx.ring_config) + + return grad_q, grad_k, grad_v, grad_attn_mask, None, None, None + + +class RingSDPABlocking(torch.autograd.Function): + """ + Performs scaled dot product attention on sharded Q, K, V. + + The ring allreduce happens in a blocking manner. This isn't more efficient, but + it is useful for understanding the algorithm and debugging. + + For details about the ring attention, see: https://arxiv.org/abs/2310.01889 + Note that the original implementation is a combination of JAX + flash attention + ring attention. + Here, instead, we leverage the underlying and built-in pytorch efficienct attention. + + A key difference with this algorithm is how we track the per-block normalizations. The pytorch + function returns log_sumexp, which we use for a running normalization. But it has to be kept in log + space to prevent underflow/overflow as well as precision issues. See the helper functions + `add_log_sumexp` and `stable_signed_accumulate` for more details. + """ + + # comm_stream = torch.cuda.Stream() + + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor], + mesh: DeviceMesh, + ring_config: RingPassingConfig, + attn_args: dict, + ) -> torch.Tensor: + """ + Forward pass for the ring attention implementation. This implementation + will overlap the communication with the computation. Note that there is an + explicit sync in each iteration to prevent the communication stream from getting + ahead of the computation stream, by waiting on the all_to_all operation to complete. + """ + + ctx.attn_args = attn_args + ctx.mesh = mesh + ctx.ring_config = ring_config + + # Create buffers to store outputs + log_global_output = None + sign_global_output = None + global_log_sumexp = None + + # For the first iteration, use local tensors + current_k, current_v = k, v + + for i in range(ring_config.mesh_size): + + # Perform computation on current k,v while communication happens + ( + output, + log_sumexp, + philox_seed, + philox_offset, + ) = aten._scaled_dot_product_efficient_attention( + q, + current_k, + current_v, + attn_mask, + compute_log_sumexp=True, + **attn_args, + ) + + # Add an extra dimension to the log_sumexp: + log_sumexp = log_sumexp.unsqueeze(-1) + log_output = torch.log(torch.abs(output)) + sign_output = torch.sign(output) + + log_global_output, sign_global_output = stable_signed_accumulate( + log_global_output, + sign_global_output, + log_output, + sign_output, + log_sumexp, + ) + + global_log_sumexp = add_log_sumexp(global_log_sumexp, log_sumexp) + + # send k and v to the next rank: + current_k = perform_ring_iteration(current_k, ctx.mesh, ctx.ring_config) + current_v = perform_ring_iteration(current_v, ctx.mesh, ctx.ring_config) + + # Compute the final output + stable_output = sign_global_output * torch.exp( + log_global_output - global_log_sumexp + ) + + ctx.save_for_backward( + q, + k, + v, + attn_mask, + stable_output, + global_log_sumexp, + philox_seed, + philox_offset, + ) + ctx.grad_input_mask = (True, True, True, attn_mask is not None) + + return stable_output + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], + None, + None, + None, + ]: + """ + Backward pass for the ring SDPA. + + Currently, this is not overlapping communication with the computation. + Note that the backward pass has 2x communication: send k, v but also grad_k, grad_v. + + """ + ( + q, + k, + v, + attn_mask, + output, + log_sumexp, + philox_seed, + philox_offset, + ) = ctx.saved_tensors + attn_args = ctx.attn_args + + grad_q = torch.zeros_like( + q, device=q.device, memory_format=torch.contiguous_format + ) + grad_k = torch.zeros_like( + k, device=k.device, memory_format=torch.contiguous_format + ) + grad_v = torch.zeros_like( + v, device=v.device, memory_format=torch.contiguous_format + ) + grad_attn_mask = None + + # TODO: overlap communication with computation. + # This needs to be done in two stages. First, we can send k, v along the ring before computing + # the gradients. We also need to send grad_k, grad_v along the ring and accumulate them. + + # Since the next iteration's grad_k, grad_v do not depend on the current iteration's gradient + # outputs, we can still overlap. But we need two sync spots instead of one. + # Algorithm therefore looks like this: + # 1. If iteration != N-1, send k, v to the next GPU asycn after combining them into one tensor. + # 2. If iteration != 0, wait for grad_k, grad_v to be received from the previous GPU and split them. + # 2. Compute the gradients on the local block (grad_q, grad_k, grad_v) + # 3. Accumulate the gradients on the local block. + # 5. If iteration != N-1, wait for k, v to be received from the previous GPU (and split them) before the next iteration + # 4. If iteration != 0, send grad_k, grad_v to the next GPU after combining them into one tensor. + + for i in range(ctx.ring_config.mesh_size): + + ( + block_grad_q, + block_grad_k, + block_grad_v, + block_grad_attn_mask, + ) = aten._scaled_dot_product_efficient_attention_backward( + grad_output, + q, + k, + v, + attn_mask, + output, + log_sumexp, + philox_seed, + philox_offset, + grad_input_mask=ctx.grad_input_mask, + **attn_args, + ) + + grad_q += block_grad_q + grad_k += block_grad_k + grad_v += block_grad_v + + # Send k, v, grad_k, grad_v to the next rank: + k = perform_ring_iteration(k, ctx.mesh, ctx.ring_config) + v = perform_ring_iteration(v, ctx.mesh, ctx.ring_config) + grad_k = perform_ring_iteration(grad_k, ctx.mesh, ctx.ring_config) + grad_v = perform_ring_iteration(grad_v, ctx.mesh, ctx.ring_config) + + return grad_q, grad_k, grad_v, grad_attn_mask, None, None, None + + +def ring_sdpa( + q: ShardTensor, + k: ShardTensor, + v: ShardTensor, + attn_mask: Optional[ShardTensor] = None, + **kwargs: dict, +) -> ShardTensor: + """ + High Level, differentiable function to compute neighborhood attention on a sharded tensor. + + Operation works like so: + - Figure out the size of halos needed. + - Apply the halo padding (differentiable) + - Perform the neighborhood attention on the padded tensor. (differentiable) + - "UnHalo" the output tensor (different from, say, convolutions) + - Return the updated tensor as a ShardTensor. + + """ + + mesh = q._spec.mesh + + # We can be confident of this because 1D meshes are enforced + mesh_dim = 0 + + local_group = mesh.get_group(mesh_dim) + local_size = dist.get_world_size(group=local_group) + + # Create a config object to simplify function args for message passing: + ring_config = RingPassingConfig( + mesh_dim=mesh_dim, + mesh_size=local_size, + communication_method="p2p", + ) + + # First, get the tensors locally and perform halos: + lq, lk, lv = q.to_local(), k.to_local(), v.to_local() + + if attn_mask is not None: + latn_mask = attn_mask.to_local() + else: + latn_mask = None + + x = RingSDPA.apply(lq, lk, lv, latn_mask, q._spec.mesh, ring_config, kwargs) + + # Convert back to ShardTensor + x = ShardTensor.from_local( + x, q._spec.mesh, q._spec.placements, q._spec.sharding_shapes() + ) + return x + + +# Make sure the module exists before importing it: + + +@wrapt.patch_function_wrapper( + "torch.nn.functional", + "scaled_dot_product_attention", + enabled=ShardTensor.patches_enabled, +) +def sdpa_wrapper( + wrapped: Any, instance: Any, args: tuple, kwargs: dict +) -> Union[torch.Tensor, ShardTensor]: + """Wrapper for natten.functional.na2d to support sharded tensors. + + Handles both regular torch.Tensor inputs and distributed ShardTensor inputs. + For regular tensors, passes through to the wrapped na2d function. + For ShardTensor inputs, handles adding halos and applying distributed na2d. + + Args: + wrapped: Original na2d function being wrapped + instance: Instance the wrapped function is bound to + args: Positional arguments containing query, key, value tensors + kwargs: Keyword arguments including kernel_size and dilation + + Returns: + Result tensor as either torch.Tensor or ShardTensor depending on input types + + Raises: + UndeterminedShardingError: If input tensor types are mismatched + """ + + q, k, v, attn_mask, kwargs = repackage_sdpa_args(*args, **kwargs) + + if all([type(_t) == torch.Tensor for _t in (q, k, v)]): + return wrapped(*args, **kwargs) + elif all([type(_t) == ShardTensor for _t in (q, k, v)]): + + # Make sure all tensors are on the same mesh + if not (q._spec.mesh == k._spec.mesh == v._spec.mesh): + raise MissingShardPatch("q, k, and v must all be on the same mesh") + + # Make sure the mesh is 1D + if q._spec.mesh.ndim != 1: + raise MissingShardPatch("q must be on a 1D mesh") + + return ring_sdpa(q, k, v, attn_mask, **kwargs) + + else: + raise UndeterminedShardingError( + "q, k, and v must all be the same types (torch.Tensor or ShardTensor)" + ) + + +def repackage_sdpa_args( + query: Union[torch.Tensor, ShardTensor], + key: Union[torch.Tensor, ShardTensor], + value: Union[torch.Tensor, ShardTensor], + attn_mask: Optional[Union[torch.Tensor, ShardTensor]] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = None, + enable_gqa: bool = False, + *args, + **kwargs, +) -> Tuple[ + Union[torch.Tensor, ShardTensor], + Union[torch.Tensor, ShardTensor], + Union[torch.Tensor, ShardTensor], + Union[torch.Tensor, ShardTensor], + dict, +]: + """ + Repackages scaled dot product attention arguments into standard format. + + """ + + if enable_gqa: + raise NotImplementedError("GQA is not implemented for sharded tensors") + + # Package all non-tensor parameters into a kwargs dictionary + return_kwargs = { + "dropout_p": dropout_p, + "is_causal": is_causal, + "scale": scale, + # "enable_gqa": enable_gqa, + } + + return query, key, value, attn_mask, return_kwargs diff --git a/physicsnemo/distributed/shard_utils/conv_patches.py b/physicsnemo/distributed/shard_utils/conv_patches.py index 6d850a66d2..24fae6215e 100644 --- a/physicsnemo/distributed/shard_utils/conv_patches.py +++ b/physicsnemo/distributed/shard_utils/conv_patches.py @@ -14,12 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.distributed as dist -import wrapt +from physicsnemo.utils.profiling import annotate, profile from physicsnemo.utils.version_check import check_module_requirements check_module_requirements("physicsnemo.distributed.shard_tensor") @@ -36,31 +36,49 @@ UndeterminedShardingError, ) -from .halo import ( # noqa: E402 - apply_grad_halo, - halo_padding_1d, - halo_unpadding_1d, - perform_halo_collective, -) +from .halo import HaloConfig, halo_padding # noqa: E402 from .patch_core import promote_to_iterable # noqa: E402 -__all__ = [ - "conv1d_wrapper", - "conv2d_wrapper", - "conv3d_wrapper", -] +aten = torch.ops.aten + +@profile +def conv_output_shape( + L_in: int, padding: int, stride: int, kernel_size: int, dilation: int +) -> int: + """Calculate the output length of a 1D convolution operation. + + This function computes the resulting length of a 1D tensor after applying + a convolution with the given parameters. -def conv_output_shape(L_in, p, s, k, d): - L_out = (L_in + 2 * p - d * (k - 1) - 1) / s + 1 + Args: + L_in: Input length + padding: Padding size (on each side) + stride: Convolution stride + kernel_size: Size of the convolution kernel + dilation: Dilation factor for the kernel + + Returns: + The length of the output tensor after convolution + """ + L_out = (L_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1 return int(L_out) +@profile def compute_halo_from_kernel_stride_and_dilation( - kernel_size: int, stride: int, dilation: int + kernel_size: int, + stride: int, + dilation: int, + padding: Union[int, str], + transposed: bool, ) -> int: """Compute the halo size needed for a convolution kernel along a single dimension. + At a high level, the halo is equal to half the receptive field of the kernel. + There are some subtleties with even vs odd kernel sizes and the conventions of + where a kernel starts getting applied. + Args: kernel_size: Size of convolution kernel along this dimension stride: Convolution stride along this dimension @@ -70,228 +88,321 @@ def compute_halo_from_kernel_stride_and_dilation( Required halo size on each side of a data chunk Raises: - MissingShardPatch: If kernel configuration is not supported for sharding + MissingShardPatch: If kernel configuration is not supported for sharding, + specifically for even kernels without matching stride """ # Special case: even kernel with matching stride and no dilation needs no halo if kernel_size % 2 == 0: - if kernel_size == stride and dilation == 1: + if kernel_size == stride and dilation == 1 and padding == 0: return 0 else: raise MissingShardPatch( - "Sharded Convolution is not implemented for even kernels without matching stride" + "Sharded Convolution is not implemented for even kernels without matching stride and padding 0. " + "If you need this functionality, please open an issue at https://github.com/NVIDIA/PhysicsNemo/issues" ) - if dilation != 1: - raise MissingShardPatch( - "Sharded Convolution is not implemented for dilation != 1" - ) + if transposed: + # Support currently only for even kernels with padding 0 and stride = kernel_size + if kernel_size % 2 != 0 or padding != 0 or stride != kernel_size: + raise MissingShardPatch( + "Sharded Convolution is not implemented for transposed convolutions with non-matching stride or padding. " + "If you need this functionality, please open an issue at https://github.com/NVIDIA/PhysicsNemo/issues" + ) # The receptive field is how far in the input a pixel in the output can see # It's used to calculate how large the halo computation has to be receptive_field = dilation * (kernel_size - 1) + 1 - # The number of halo pixels is the casting `int(receptive field/2)` - # Why? Assuming a filter in the output image is centered in the input image, - # we have only half of it's filter to the left. - # Even kernels: - if kernel_size % 2 == 0: - halo_size = int(receptive_field / 2 - 1) - else: - halo_size = int(receptive_field / 2) + # For odd kernels, the halo size is half the receptive field (integer division) + # This represents how many pixels we need from neighboring ranks on each side + halo_size = receptive_field // 2 return halo_size -def shard_to_haloed_local_for_convNd( - input: ShardTensor, - kernel_shape: Tuple[int, ...], - stride: Union[int, Tuple[int, ...]], - padding: Union[int, Tuple[int, ...]], - dilation: Union[int, Tuple[int, ...]], - groups: int = 1, - **extra_kwargs, -) -> torch.Tensor: - """Converts a sharded tensor to a local tensor with halo regions for convolution. +@profile +def padding_from_str_and_params( + padding: str, + input_shape: Tuple[int, ...], + kernel_size: int, + stride: int, + dilation: int, +) -> int: + """Convert a string padding specification to a numerical value. + + Args: + padding: String padding specification + conv_kwargs: Dictionary of convolution arguments + dim: Dimension index + """ + + if padding == "same": + total_padding = max( + 0, + ( + (input_shape - 1) * stride + + 1 + + (kernel_size - 1) * dilation + - input_shape + ), + ) + return total_padding // 2 + elif padding == "valid": + return 0 + elif padding == "none": + return 0 + else: + raise ValueError(f"Invalid padding specification: {padding}") + - Takes a sharded tensor and adds appropriate halo regions based on the convolution - parameters, so that when the convolution is applied locally it produces the same - result as if applied globally then sharded. +@profile +def compute_halo_configs_from_conv_args( + input: ShardTensor, + kernel_size: Tuple[int, ...], + conv_kwargs: Dict[str, Any], +) -> List[HaloConfig]: + """Compute halo configurations for a sharded tensor based on convolution arguments. Args: - input: ShardTensor to add halos to - kernel_shape: Shape of convolution kernel - stride: Convolution stride (int or tuple matching kernel dims) - padding: Convolution padding (int or tuple matching kernel dims) - dilation: Convolution dilation (int or tuple matching kernel dims) - groups: Number of convolution groups (default: 1) + input: The sharded tensor that will be used in convolution + kernel_size: Tuple of kernel dimensions for the convolution + conv_kwargs: Dictionary of convolution arguments including stride, + padding, dilation, and groups Returns: - Local torch.Tensor with added halo regions + List of HaloConfig objects for each sharded dimension - Raises: - ValueError: If tensor is sharded along batch/channel dims or invalid dims + Note: + This function updates conv_kwargs in place, setting padding to 0 for sharded dimensions. """ - mesh = input._spec.mesh + placements = input._spec.placements - if mesh.ndim != len(placements): - raise ValueError("Mesh dimensions must match number of placements") + stride = conv_kwargs["stride"] + dilation = conv_kwargs["dilation"] + + # This is to update and set the padding to 0 on the sharded dims: + padding = conv_kwargs["padding"] - # Extract parameters for sharded dimensions only - h_kernel = [] - h_stride = [] - h_padding = [] - h_dilation = [] + if isinstance(padding, str): + # Convert this to numerical values: + padding = [ + padding_from_str_and_params( + padding, input.shape[i], kernel_size[i], stride[i], dilation[i] + ) + for i in range(len(kernel_size)) + ] + else: + # Ensure it's a list: + padding = list(padding) - for p in placements: + # All parameters are assumed to be iterables of the same length + halo_configs = [] + + for mesh_dim, p in enumerate(placements): if not isinstance(p, Shard): continue tensor_dim = p.dim - if tensor_dim in [0, 1]: - raise ValueError("Cannot shard convolution along batch/channel dimensions") - - if tensor_dim >= len(kernel_shape) + 2: - raise ValueError("Invalid tensor dimension for kernel rank") + if tensor_dim in [0, 1]: # Skip batch and channel dimensions + continue - # Convert from NCHW indexing to kernel indexing - kernel_idx = tensor_dim - 2 - h_kernel.append(kernel_shape[kernel_idx]) - h_stride.append(stride[kernel_idx]) - h_padding.append(padding[kernel_idx]) - h_dilation.append(dilation[kernel_idx]) + # Map tensor dimension to kernel dimension (accounting for batch, channel dims) + kernel_dim = tensor_dim - 2 + if kernel_dim >= len(kernel_size): + continue - # Compute required halo size for each sharded dim - halo_size = tuple( - compute_halo_from_kernel_stride_and_dilation(k, s, d) - for k, s, d in zip(h_kernel, h_stride, h_dilation) - ) + # Compute halo size for this dimension + halo_size = compute_halo_from_kernel_stride_and_dilation( + kernel_size[kernel_dim], + stride[kernel_dim], + dilation[kernel_dim], + padding[kernel_dim], + conv_kwargs["transposed"], + ) - # Set edge padding type based on convolution padding - edge_padding_t = "zeros" if any(h_padding) else "none" + if halo_size > 0: - # Add halos via collective communication - local_input = HaloPaddingConvND.apply(input, halo_size, edge_padding_t, padding) + # Create a halo config for this dimension - return local_input + halo_configs.append( + HaloConfig( + mesh_dim=mesh_dim, + tensor_dim=tensor_dim, + halo_size=halo_size, + edge_padding_size=padding[kernel_dim], + communication_method="a2a", + async_op=True, + ) + ) + # Set the padding to 0 on the sharded dims: + padding[kernel_dim] = 0 + # Update the padding before returning: + conv_kwargs["padding"] = tuple(padding) -class HaloPaddingConvND(torch.autograd.Function): - """Autograd wrapper for distributed convolution-centric halo padding. + return halo_configs - Handles halo padding for distributed convolutions using ShardTensor concept - (local tensor + device mesh + shard placements). Forward pass gathers adjacent regions - from neighboring devices. Backward pass distributes gradients outward. - Supports multi-dimensional halo passing with compatible mesh and halo parameters.""" +@profile +def compute_output_shape( + sharding_shape: Tuple[int, ...], + conv_kwargs: Dict[str, Any], + kernel_size: Tuple[int, ...], +) -> Tuple[int, ...]: + """ + For a specified input shape, determine the output shape after a convolution. + Handles both regular and transposed convolutions. + """ + output_shape = [] + tensor_rank = len(sharding_shape[2:]) + for tensor_dim in range(tensor_rank): + if not conv_kwargs["transposed"]: + # Regular convolution + num = ( + sharding_shape[tensor_dim + 2] + + 2 * conv_kwargs["padding"][tensor_dim] + - (kernel_size[tensor_dim] - 1) * conv_kwargs["dilation"][tensor_dim] + - 1 + ) + o = num / conv_kwargs["stride"][tensor_dim] + 1 + else: + # Transposed convolution + output_padding = conv_kwargs.get("output_padding", (0,) * tensor_rank)[ + tensor_dim + ] + o = (sharding_shape[tensor_dim + 2] - 1) * conv_kwargs["stride"][tensor_dim] + o = o - 2 * conv_kwargs["padding"][tensor_dim] + o = o + conv_kwargs["dilation"][tensor_dim] * (kernel_size[tensor_dim] - 1) + o = o + output_padding + 1 - @staticmethod - def forward( - ctx, - stensor: ShardTensor, - halo: tuple[int, ...], - edge_padding_t: str, - edge_padding_s: tuple[int, ...], - ) -> torch.Tensor: - """Forward pass of distributed halo padding. + output_shape.append(int(o)) - Args: - stensor: Input ShardTensor - halo: Halo sizes for each dimension - edge_padding_t: Edge padding type ("zeros" or "none") - edge_padding_s: Edge padding sizes + return tuple(output_shape) - Returns: - Padded local tensor - Raises: - ValueError: If halo size does not match mesh rank - """ - mesh = stensor.device_mesh - if len(halo) != mesh.ndim: - raise ValueError( - f"Halo size ({len(halo)}) must match mesh rank ({mesh.ndim})" - ) +@profile +def partial_conv_nd( + input: ShardTensor, + weight: torch.nn.Parameter, + bias: Optional[torch.nn.Parameter], + conv_kwargs: Dict[str, Any], +) -> ShardTensor: + """Perform a convolution on a sharded tensor with halo exchange. + + This high-level, differentiable function computes a convolution on a sharded tensor + by performing these steps: + 1. Calculate the size of halos needed + 2. Apply halo padding (differentiable) + 3. Perform convolution on the padded tensor with padding=0 on sharded dimensions + 4. Return the result as a ShardTensor - placements = stensor.placements - local_tensor = stensor.to_local() - - # Apply halo padding for each sharded dimension - for mesh_dim in range(mesh.ndim): - if isinstance(placements[mesh_dim], Shard): - tensor_dim = placements[mesh_dim].dim - local_tensor = halo_padding_1d( - local_tensor, - mesh, - mesh_dim, - tensor_dim, - halo[mesh_dim], - edge_padding_t, - edge_padding_s[mesh_dim], - ) + Args: + input: The sharded input tensor + weight: Convolution filter weights + bias: Optional bias parameter + conv_kwargs: Dictionary of convolution parameters (stride, padding, etc.) - ctx.halo = halo - ctx.spec = stensor._spec - ctx.requires_input_grad = stensor.requires_grad + Returns: + Resulting ShardTensor after convolution operation + """ + with annotate("partial_conv_nd"): + kernel_size = weight.shape[2:] - return local_tensor + # This will produce one config per sharded dim + # It also *updates* conv_kwargs in place to set padding to 0 on the sharded dims + halo_configs = compute_halo_configs_from_conv_args( + input, kernel_size, conv_kwargs + ) - @staticmethod - def backward( - ctx, grad_output: torch.Tensor - ) -> tuple[ShardTensor, None, None, None]: - """Backward pass of distributed halo padding. + # We get one halo_config per sharded dim. + sharding_shapes = input._spec.sharding_shapes() + # # First, update the shapes to take into account the halo and edge paddings: + + # Create a mapping from mesh_dim to halo_config for easy lookup + halo_config_map = {config.mesh_dim: config for config in halo_configs} + + real_input_shapes = {} + for mesh_dim, sharing_tuple in sharding_shapes.items(): + # If this mesh_dim doesn't need halos, just copy the original shapes + if mesh_dim not in halo_config_map: + real_input_shapes[mesh_dim] = sharing_tuple + continue + + tensor_dim = halo_config_map[mesh_dim].tensor_dim + real_input_shapes[mesh_dim] = [] + for i, s in enumerate(sharing_tuple): + padding = halo_config_map[mesh_dim].halo_size + + if i == 0 or i == len(sharing_tuple) - 1: + # On the edge of the split, the additional size is halo + edge padding + padding += halo_config_map[mesh_dim].edge_padding_size + else: + # Otherwise, its 2xhalo size added on. + padding += halo_config_map[mesh_dim].halo_size + + updated_shape = list(s) + updated_shape[tensor_dim] += padding + + real_input_shapes[mesh_dim].append(tuple(updated_shape)) + + input_spec = input._spec + local_input = input.to_local() + + with annotate("halo_padding"): + # Apply the halo padding to the input tensor + for halo_config in halo_configs: + local_input = halo_padding(local_input, input._spec.mesh, halo_config) + + with annotate("perform_convolution"): + # Perform the convolution on the padded tensor + local_output = perform_convolution( + local_input, weight, bias, input_spec, conv_kwargs + ) - Args: - grad_output: Gradient tensor from downstream + batch_channel_shape = tuple(local_output.shape[:2]) + # Update the output shapes to take into account the batch anc channel dims: + real_output_shapes = { + dim: tuple( + batch_channel_shape + compute_output_shape(s, conv_kwargs, kernel_size) + for s in real_input_shapes[dim] + ) + for dim in real_input_shapes + } - Returns: - Tuple of (gradient ShardTensor, None, None, None) - """ - spec = ctx.spec - mesh = spec.mesh - placements = spec.placements - halo = ctx.halo - - # Process gradients for each sharded dimension - for mesh_dim in range(mesh.ndim): - if isinstance(placements[mesh_dim], Shard): - tensor_dim = placements[mesh_dim].dim - - # Unpad gradients and get halo slices - grad_input, grad_halos = halo_unpadding_1d( - grad_output, - mesh, - mesh_dim, - tensor_dim, - halo[mesh_dim], - return_slices=True, - ) + # Convert the local output to a ShardTensor + with annotate("partial_conv_nd.from_local"): + output = ShardTensor.from_local( + local_output, + input_spec.mesh, + input_spec.placements, + sharding_shapes=real_output_shapes, + ) - # Exchange and accumulate gradient halos - halo_from_left, halo_from_right = perform_halo_collective( - mesh, mesh_dim, *grad_halos - ) - grad_input = apply_grad_halo( - mesh, - mesh_dim, - tensor_dim, - grad_input, - halo_from_left, - halo_from_right, - ) + return output - # Wrap gradient in ShardTensor - grad_tensor = ShardTensor( - grad_input, - spec, - requires_grad=grad_input.requires_grad, - ) - return grad_tensor, None, None, None +@profile +def perform_convolution( + inputs: torch.Tensor, + weights: torch.nn.Parameter, + bias: Optional[torch.nn.Parameter], + input_spec: "ShardTensorSpec", + conv_kwargs: Dict[str, Any], +) -> torch.Tensor: + """Apply a convolution operation using the PartialConvND autograd function. + Args: + inputs: Input tensor to convolve + weights: Convolution filter weights + bias: Optional bias tensor + input_spec: Specification for output ShardTensor + conv_kwargs: Dictionary of convolution parameters -aten = torch.ops.aten + Returns: + ShardTensor containing the convolution result + """ + return PartialConvND.apply(inputs, weights, bias, input_spec, conv_kwargs) class PartialConvND(torch.autograd.Function): @@ -304,14 +415,15 @@ class PartialConvND(torch.autograd.Function): """ @staticmethod + @profile def forward( ctx, inputs: torch.Tensor, weights: torch.nn.Parameter, bias: Optional[torch.nn.Parameter], - output_spec: "ShardTensorSpec", - conv_kwargs: dict, - ) -> "ShardTensor": + input_spec: "ShardTensorSpec", + conv_kwargs: Dict[str, Any], + ) -> torch.Tensor: """Forward pass of the distributed convolution. Args: @@ -319,57 +431,31 @@ def forward( inputs: Input tensor to convolve weights: Convolution filter weights bias: Optional bias tensor - output_spec: Specification for output ShardTensor + input_spec: Specification for output ShardTensor conv_kwargs: Dictionary of convolution parameters (stride, padding, etc.) Returns: ShardTensor containing the convolution result """ # Save spec for backward pass - ctx.spec = output_spec + ctx.spec = input_spec # Save local tensors to avoid distributed dispatch in backward pass ctx.save_for_backward(inputs, weights, bias) - # Get sharded output dimensions by checking placements - sharded_output_dims = [] - for i, placement in enumerate(output_spec.placements): - if isinstance(placement, Shard): - sharded_output_dims.append(placement.dim) - - # Force padding to 0 _along sharded dims only_ since padding is - # handled by halo exchange - # Check if input is channels first (NCHW) or channels last (NHWC) - if inputs.is_contiguous(memory_format=torch.contiguous_format): - offset = 2 - elif inputs.is_contiguous(memory_format=torch.channels_last): - offset = 1 - else: - raise ValueError("Input tensor must be channels first or channels last") - - padding = list(conv_kwargs["padding"]) - - # Update padding arguments. Set to 0 on sharded dims only: - for i, p in enumerate(padding): - if i + offset in sharded_output_dims: - padding[i] = 0 - - conv_kwargs["padding"] = tuple(padding) + # print type of inputs ctx.conv_kwargs = conv_kwargs # Perform local convolution on this shard local_chunk = aten.convolution.default(inputs, weights, bias, **conv_kwargs) - # Wrap result in ShardTensor with specified distribution - output = ShardTensor.from_local( - local_chunk, output_spec.mesh, output_spec.placements - ) - ctx.requires_input_grad = inputs.requires_grad - return output + # return output + return local_chunk @staticmethod + @profile def backward( - ctx, grad_output: "ShardTensor" + ctx, grad_output: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], None, None]: """Backward pass for distributed convolution. @@ -391,10 +477,14 @@ def backward( bias is not None, # bias gradient if bias exists ) + # Cast, in case the precision is off: + weight = weight.to(dtype=grad_output.dtype) + bias = bias.to(dtype=grad_output.dtype) if bias is not None else None + local_chunk = local_chunk.to(dtype=grad_output.dtype) # Compute local gradients - local_grad_output = grad_output._local_tensor + # local_grad_output = grad_output._local_tensor grad_input, grad_weight, grad_bias = aten.convolution_backward( - local_grad_output, + grad_output, local_chunk, weight, bias, @@ -411,51 +501,31 @@ def backward( return grad_input, grad_weight, grad_bias, None, None -@wrapt.patch_function_wrapper( - "torch.nn.functional", "conv1d", enabled=ShardTensor.patches_enabled -) -def conv1d_wrapper(wrapped, instance, args, kwargs): - - return generic_conv_nd_wrapper(wrapped, instance, args, kwargs) - - -@wrapt.patch_function_wrapper( - "torch.nn.functional", "conv2d", enabled=ShardTensor.patches_enabled -) -def conv2d_wrapper(wrapped, instance, args, kwargs): - - return generic_conv_nd_wrapper(wrapped, instance, args, kwargs) +def generic_conv_nd_wrapper(func: callable, types: tuple, args: tuple, kwargs: dict): + """Wrapper function for N-dimensional convolution operations supporting shardtensors. - -@wrapt.patch_function_wrapper( - "torch.nn.functional", "conv3d", enabled=ShardTensor.patches_enabled -) -def conv3d_wrapper(wrapped, instance, args, kwargs): - - return generic_conv_nd_wrapper(wrapped, instance, args, kwargs) - - -def generic_conv_nd_wrapper(wrapped, instance, args, kwargs): - """Generic wrapper for torch N-dimensional convolution operations. - - Handles both regular torch.Tensor inputs and distributed ShardTensor inputs. - For regular tensors, passes through to the wrapped convolution. - For ShardTensor inputs, handles gathering weights/bias and applying distributed - convolution with halo regions. + This function dispatches convolution operations to appropriate implementations based on input types. + It handles both regular and transposed convolutions, and supports both torch.Tensor and ShardTensor inputs. Args: - wrapped: Original convolution function being wrapped - instance: Instance the wrapped function is bound to - args: Positional arguments for convolution - kwargs: Keyword arguments for convolution + func: The convolution function to be wrapped (conv1d, conv2d, etc.) + types: Tuple of input types (unused) + args: Positional arguments to the convolution function + kwargs: Keyword arguments to the convolution function Returns: - Convolution result as either torch.Tensor or ShardTensor + The result of the convolution operation Raises: - UndeterminedShardingError: If input tensor types are invalid + UndeterminedShardingError: If input, weight, or bias have invalid types """ - input, weight, bias, conv_kwargs = repackage_conv_args(*args, **kwargs) + + if "transpose" in func.__name__: + input, weight, bias, conv_kwargs = repackage_conv_transposed_args( + *args, **kwargs + ) + else: + input, weight, bias, conv_kwargs = repackage_conv_args(*args, **kwargs) # Handle regular torch tensor inputs if ( @@ -463,7 +533,7 @@ def generic_conv_nd_wrapper(wrapped, instance, args, kwargs): and type(weight) == torch.nn.parameter.Parameter and (bias is None or type(bias) == torch.nn.parameter.Parameter) ): - return wrapped(*args, **kwargs) + return func(*args, **kwargs) # Handle distributed ShardTensor inputs elif type(input) == ShardTensor: @@ -477,18 +547,14 @@ def generic_conv_nd_wrapper(wrapped, instance, args, kwargs): # Promote scalar args to match kernel dimensions promotables = ["stride", "padding", "dilation", "output_padding"] + conv_kwargs = { key: promote_to_iterable(p, kernel_shape) if key in promotables else p for key, p in conv_kwargs.items() } - # Add halos and perform distributed convolution - local_input = shard_to_haloed_local_for_convNd( - input, kernel_shape, **conv_kwargs - ) - output_spec = input._spec - x = PartialConvND.apply(local_input, weight, bias, output_spec, conv_kwargs) - return x + # Use the convolution args to compute the sharded halo + return partial_conv_nd(input, weight, bias, conv_kwargs) else: msg = ( @@ -501,6 +567,7 @@ def generic_conv_nd_wrapper(wrapped, instance, args, kwargs): raise UndeterminedShardingError(msg) +@profile def repackage_conv_args( input: Union[torch.Tensor, ShardTensor], weight: Union[torch.Tensor, DTensor], @@ -509,7 +576,6 @@ def repackage_conv_args( padding: Union[int, Tuple[int, ...]] = 0, dilation: Union[int, Tuple[int, ...]] = 1, groups: int = 1, - transposed: bool = False, output_padding: Union[int, Tuple[int, ...]] = 0, *args, **kwargs, @@ -550,16 +616,86 @@ def repackage_conv_args( "stride": stride, "padding": padding, "dilation": dilation, - "transposed": transposed, + "transposed": False, + "groups": groups, + "output_padding": output_padding, + } + + return input, weight, bias, return_kwargs + + +@profile +def repackage_conv_transposed_args( + input: Union[torch.Tensor, ShardTensor], + weight: Union[torch.Tensor, DTensor], + bias: Union[torch.Tensor, DTensor, None] = None, + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[int, Tuple[int, ...]] = 0, + output_padding: Union[int, Tuple[int, ...]] = 0, + groups: int = 1, + dilation: Union[int, Tuple[int, ...]] = 1, + *args, + **kwargs, +) -> Tuple[ + Union[torch.Tensor, ShardTensor], + Union[torch.Tensor, DTensor], + Union[torch.Tensor, DTensor, None], + dict, +]: + """Repackages convolution arguments into standard format. + + Takes the full set of arguments that could be passed to a convolution operation + and separates them into core tensor inputs (input, weight, bias) and + configuration parameters packaged as a kwargs dict. + + Args: + input: Input tensor to convolve + weight: Convolution kernel weights + bias: Optional bias tensor + stride: Convolution stride length(s) + padding: Input padding size(s) + dilation: Kernel dilation factor(s) + groups: Number of convolution groups + transposed: Whether this is a transposed convolution + output_padding: Additional output padding for transposed convs + *args: Additional positional args (unused) + **kwargs: Additional keyword args (unused) + + Returns: + Tuple containing: + - Input tensor + - Weight tensor + - Bias tensor (or None) + - Dict of convolution configuration parameters + """ + # Package all non-tensor parameters into a kwargs dictionary + return_kwargs = { + "stride": stride, + "padding": padding, + "dilation": dilation, "output_padding": output_padding, "groups": groups, + "transposed": True, } return input, weight, bias, return_kwargs -# This will become the future implementation, or similar. -# Why not today? Because the backwards pass in DTensor has an explicit (and insufficient) -# hard coded implementation for the backwards pass. -# When that switch happens, the order in the arg repackaging will need to be updated. -# ShardTensor.register_function_handler(aten.convolution.default, generic_conv_nd_wrapper) +ShardTensor.register_function_handler( + torch.nn.functional.conv1d, generic_conv_nd_wrapper +) +ShardTensor.register_function_handler( + torch.nn.functional.conv2d, generic_conv_nd_wrapper +) +ShardTensor.register_function_handler( + torch.nn.functional.conv3d, generic_conv_nd_wrapper +) +ShardTensor.register_function_handler( + torch.nn.functional.conv_transpose1d, generic_conv_nd_wrapper +) +ShardTensor.register_function_handler( + torch.nn.functional.conv_transpose2d, generic_conv_nd_wrapper +) +ShardTensor.register_function_handler( + torch.nn.functional.conv_transpose3d, generic_conv_nd_wrapper +) diff --git a/physicsnemo/distributed/shard_utils/halo.py b/physicsnemo/distributed/shard_utils/halo.py index de5445d3de..b4d6d4853d 100644 --- a/physicsnemo/distributed/shard_utils/halo.py +++ b/physicsnemo/distributed/shard_utils/halo.py @@ -14,179 +14,447 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union + +from dataclasses import dataclass +from typing import Literal, Optional, Tuple import torch import torch.distributed as dist +from torch.autograd.profiler import record_function from torch.distributed.device_mesh import DeviceMesh -from physicsnemo.utils.version_check import check_module_requirements +from physicsnemo.utils.profiling import profile -check_module_requirements("physicsnemo.distributed.shard_tensor") +"""Halo exchange utilities for distributed tensor operations. +This module provides functionality for halo padding operations in distributed computing +environments. Halo padding is a technique used in distributed tensor operations where +each process needs access to a small region of data (the "halo") from neighboring +processes to perform local computations correctly. -def halo_unpadding_1d( - local_tensor: torch.Tensor, +The module includes: +- HaloConfig: Configuration class for halo exchange parameters +- Autograd-compatible functions for forward and backward passes +- Primitives for halo exchange operations +- Utility functions for slicing and applying halo regions + +""" + + +@dataclass +class HaloConfig: + """Configuration for halo padding operations. + + This class encapsulates all parameters needed for halo exchange operations, + making it easier to pass consistent configurations between functions. + + Attributes: + mesh_dim (int): Mesh dimension for this padding operation + tensor_dim (int): Tensor dimension to pad/unpad + halo_size (int): Size of halo padding (assumed symmetric) + edge_padding_size (int): Edge padding size (puts 0s on the edge tensors) + communication_method (str): Method for exchanging halos ("p2p" or "a2a") + """ + + mesh_dim: int + tensor_dim: int + halo_size: int + edge_padding_size: int = 0 + async_op: bool = False + + CommMethod = Literal["p2p", "a2a"] + VALID_COMM_METHODS = ["p2p", "a2a"] + communication_method: CommMethod = "a2a" + + def __post_init__(self): + """Validate configuration parameters after initialization. + + Raises: + ValueError: If invalid padding type or communication method is specified + """ + + if self.communication_method not in self.VALID_COMM_METHODS: + raise ValueError( + f"Invalid communication method: {self.communication_method}. " + f"Must be one of {self.VALID_COMM_METHODS}" + ) + + if self.async_op and self.communication_method == "p2p": + raise ValueError( + "Async halo padding is not supported with p2p communication. " + "Must be a2a." + ) + + +@profile +def halo_padding( + tensor: torch.Tensor, mesh: DeviceMesh, - mesh_dim: int, - tensor_dim: int, - halo_t: int, - edge_padding_t: Optional[str] = "zeros", - edge_padding_s: Optional[int] = 0, - return_slices: bool = False, -) -> Union[ - torch.Tensor, - Tuple[torch.Tensor, Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]], -]: - """Removes halo padding from a tensor in 1D. - - This is the backward pass of distributed halo padding. Can be chained to remove - halo padding in multiple dimensions if needed. + halo_config: HaloConfig, +) -> torch.Tensor: + """High-level, differentiable function to apply halo padding with gradient support. Args: - local_tensor: Local tensor chunk to unpad - mesh: Device mesh containing sharding information - mesh_dim: Mesh dimension for this unpadding operation - tensor_dim: Tensor dimension to unpad - halo_t: Size of halo padding to remove (assumed symmetric) - edge_padding_t: Edge padding type (currently unused) - edge_padding_s: Edge padding size (only valid with zeros padding, currently unused) - return_slices: Whether to return removed halo slices + tensor: torch.Tensor to apply halo padding to + mesh: DeviceMesh containing device information that the halo is performed on + halo_config: Configuration object containing all halo parameters Returns: - Unpadded tensor if return_slices=False, otherwise tuple of: - - Unpadded tensor - - Tuple of (front slice, end slice) containing removed halos + Padded tensor with halos added locally to each chunk. This is *not* a ShardTensor - it + is a torch.Tensor that has had local edges replicated from neighboring ranks. """ - # Validate mesh dimension - if mesh_dim is None: - if mesh.ndim != 1: - raise ValueError( - f"Halo padding requires `dim` for mesh size > 1 (got shape {mesh.shape})" - ) - mesh_dim = 0 - # Get process group info - local_group = mesh.get_group(mesh_dim) - local_rank = mesh.get_local_rank(mesh_dim) - local_size = dist.get_world_size(group=local_group) + return HaloPadding.apply(tensor, mesh, halo_config) - # Get shape of dimension being unpadded - dim_shape = local_tensor.shape[tensor_dim] - # Calculate slice boundaries - start = halo_t if local_rank != 0 else 0 - end = dim_shape - halo_t if local_rank != local_size - 1 else dim_shape +@profile +def unhalo_padding( + tensor: torch.Tensor, + mesh: DeviceMesh, + halo_config: HaloConfig, +) -> torch.Tensor: + """High-level, differentiable function to apply unhalo padding with gradient support. - if return_slices: - # Get removed halo slices for non-edge ranks - front_slice = None - if local_rank != 0: - front_indices = torch.arange(0, start).to(local_tensor.device) - front_slice = local_tensor.index_select( - tensor_dim, front_indices - ).contiguous() + This function removes halo regions from a tensor according to the provided configuration. + It is the inverse operation of halo_padding and maintains differentiability for gradients. - end_slice = None - if local_rank != local_size - 1: - end_indices = torch.arange(end, dim_shape).to(local_tensor.device) - end_slice = local_tensor.index_select(tensor_dim, end_indices).contiguous() + Args: + tensor: Padded tensor with halos to be removed + mesh: DeviceMesh containing device information for the operation + halo_config: Configuration object containing all halo parameters - # Remove halo padding - indices = torch.arange(start, end).to(local_tensor.device) - local_tensor = local_tensor.index_select(tensor_dim, indices).contiguous() + Returns: + Tensor with halo regions removed according to the configuration + """ + + return UnHaloPadding.apply(tensor, mesh, halo_config) + + +class HaloPadding(torch.autograd.Function): + """Autograd Function for applying and removing halo padding. + + This class handles the forward and backward passes for halo padding operations, + maintaining proper gradient flow between distributed tensors. + """ + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + tensor: torch.Tensor, + mesh: DeviceMesh, + config: HaloConfig, + ) -> torch.Tensor: + """Add halo padding to a tensor in the forward pass. + + Args: + ctx: Autograd context for saving tensors/variables for backward + tensor: torch.Tensor to apply halo padding to + mesh: DeviceMesh containing device information that the halo is performed on + config: HaloConfig defining padding parameters + + Returns: + Padded tensor with halos added locally to each chunk + """ + + # Save context for backward pass + ctx.mesh = mesh + ctx.config = config + + padded_tensor = halo_padding_fwd_primitive( + tensor, + mesh, + config, + ) + return padded_tensor + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor + ) -> Tuple[torch.Tensor, None, None]: + """Handle gradients by removing halo padding and applying halo gradients. + + Args: + ctx: Autograd context from forward pass + grad_output: Gradient tensor with halo padding + + Returns: + Tuple of (gradient for input Tensor, None) + """ + mesh = ctx.mesh + config = ctx.config + + grad_input = halo_padding_bwd_primitive( + grad_output, + mesh, + config, + ) + + return grad_input, None, None + + +class UnHaloPadding(torch.autograd.Function): + """Autograd Function for removing halo padding with gradient support. + + This class implements the forward and backward passes for unhalo padding operations. + In the forward pass, it removes halo regions from the input tensor according to the + configuration. In the backward pass, it adds zero padding in the halo regions to + maintain the correct shape for gradient propagation. + + This is the inverse operation of HaloPadding and maintains differentiability. + """ + + @staticmethod + def forward( + ctx, tensor: torch.Tensor, mesh: DeviceMesh, config: HaloConfig + ) -> torch.Tensor: + """ + Forward pass for unhalo padding. + + Conceptually, this is a truncated version of the bwd pass of halo padding. + It is actually collective-free in the forward pass, since we just cut pieces off. + + We still require the mesh to save it for the backward pass. + + Args: + ctx: Autograd context for saving tensors/variables for backward + tensor: torch.Tensor to apply halo padding to + mesh: DeviceMesh containing device information that the halo is performed on + config: HaloConfig defining padding parameters + """ + + # Save context for backward pass + ctx.mesh = mesh + ctx.config = config + + # Chop off the halos + _left, unpadded_tensor, _right = slice_halo_regions( + tensor, + mesh, + config, + ) + + ctx.left_shape = _left.shape + ctx.right_shape = _right.shape + + return unpadded_tensor - if return_slices: - return local_tensor, (front_slice, end_slice) + @staticmethod + def backward( + ctx, + grad_output: torch.Tensor, + ) -> Tuple[torch.Tensor, None, None]: + """ + Backward pass for unhalo padding. - return local_tensor + In the backward pass, we need to add zero tensors where we previously + removed the halo regions in the forward pass. This effectively pads + the gradient with zeros in the halo regions. + Args: + ctx: Autograd context containing saved tensors/variables from forward + grad_output: Gradient of the loss with respect to the output of forward -def halo_padding_1d( + Returns: + Tuple containing: + - Gradient with respect to the input tensor + - None for mesh parameter (not differentiable) + - None for config parameter (not differentiable) + """ + + left_zeros = torch.zeros( + ctx.left_shape, device=grad_output.device, dtype=grad_output.dtype + ) + right_zeros = torch.zeros( + ctx.right_shape, device=grad_output.device, dtype=grad_output.dtype + ) + + grad_input = apply_halo_tensors( + ctx.mesh, + ctx.config, + grad_output, + left_zeros, + right_zeros, + ) + + return grad_input, None, None + + +@profile +def halo_padding_fwd_primitive( local_tensor: torch.Tensor, mesh: DeviceMesh, - mesh_dim: int, - tensor_dim: int, - halo_t: int, - edge_padding_t: Optional[str] = "zeros", - edge_padding_s: Optional[int] = 0, -) -> torch.Tensor: # pragma: no cover - """Adds halo padding to a tensor in 1D. + halo_config: HaloConfig, +) -> torch.Tensor: + """ + Forward primitive for halo padding. - This is the forward pass of distributed halo padding. Can be chained to add - halo padding in multiple dimensions if needed. + Halo padding is meant for operations + that are applying a localized function (like convolution, but need not be conv) + to a spatially sharded tensor. During the forward pass, the inputs from the + neighboring tensors are copied from remote regions and appended to the local image. Args: - local_tensor: Local tensor chunk to pad + local_tensor: The local tensor chunk to pad with halos mesh: Device mesh containing sharding information - mesh_dim: Mesh dimension for this padding operation - tensor_dim: Tensor dimension to pad - halo_t: Size of halo padding to add (assumed symmetric) - edge_padding_t: Edge padding type (zeros, reflect, replicate, circular, none) - edge_padding_s: Edge padding size (only valid with zeros padding) + halo_config: HaloConfig defining padding parameters Returns: - Padded tensor with halos added locally to each chunk - - Note: - Coalescing the padded tensor directly without consuming the halo will produce - invalid results. + Padded tensor with halos from neighboring ranks """ - valid_padding = ["zeros", "reflect", "replicate", "circular", "none"] - if edge_padding_t not in valid_padding: - raise ValueError(f"Invalid edge padding: {edge_padding_t}") - - if edge_padding_s != 0 and edge_padding_t != "zeros": - raise NotImplementedError( - f"Edge padding size != 0 only supported with zeros padding " - f"(got size={edge_padding_s}, type={edge_padding_t})" - ) - - # Validate mesh dimension - if mesh_dim is None: - if mesh.ndim != 1: - raise ValueError( - f"Halo padding requires `dim` for mesh size > 1 (got shape {mesh.shape})" - ) + # It's not optimized, but we pull of the halo from both sides currently. One + # gets discarded on the edge ranks. But, it would have to wait + # for the other ranks to make this selection anyways. # Select halo regions to exchange - left_indices = torch.arange(0, halo_t).to(local_tensor.device) - max_index = local_tensor.shape[tensor_dim] + left_indices = torch.arange(0, halo_config.halo_size, device=local_tensor.device) + max_index = local_tensor.shape[halo_config.tensor_dim] right_indices = max_index - 1 - left_indices right_indices = torch.flip(right_indices, (0,)) - halo_to_left = local_tensor.index_select(tensor_dim, left_indices).contiguous() - halo_to_right = local_tensor.index_select(tensor_dim, right_indices).contiguous() + # Collectives need contiguous data. So we enforce that here. + halo_to_left = local_tensor.index_select( + halo_config.tensor_dim, left_indices + ).contiguous() + halo_to_right = local_tensor.index_select( + halo_config.tensor_dim, right_indices + ).contiguous() # Exchange halos between ranks halo_from_left, halo_from_right = perform_halo_collective( - mesh, mesh_dim, halo_to_left, halo_to_right + mesh, + halo_config.mesh_dim, + halo_to_left, + halo_to_right, + halo_config.communication_method, + halo_config.async_op, ) # Combine local tensor with received halos - padded_output = unpack_halo_tensors( + padded_output = apply_halo_tensors( mesh, - mesh_dim, - tensor_dim, + halo_config, + local_tensor, halo_from_left, halo_from_right, - local_tensor, - edge_padding_s, - edge_padding_t, ) - return torch.cat(padded_output, dim=tensor_dim) + return padded_output + + +@profile +def slice_halo_regions( + local_tensor: torch.Tensor, + mesh: DeviceMesh, + halo_config: HaloConfig, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Splits a tensor into left halo, center, and right halo regions. + + This primitive function divides the input tensor along the specified dimension into + three parts: left halo region, central tensor (without halos), and right halo region. + The slicing boundaries are determined based on the rank in the mesh and the halo configuration. + + "left" and "right" do not necessarily correspond to spatial locations, but instead + think of "left" as the region closer to rank 0 and "right" closer to rank N-1. + + Args: + local_tensor: Input tensor to be sliced + mesh: DeviceMesh containing device information + halo_config: Configuration defining halo parameters and dimensions + + Returns: + Tuple of (left_slice, central_slice, right_slice) tensors + """ + + # Get process group info + local_group = mesh.get_group(halo_config.mesh_dim) + local_rank = mesh.get_local_rank(halo_config.mesh_dim) + local_size = dist.get_world_size(group=local_group) + + # Get shape of dimension being unpadded + dim_shape = local_tensor.shape[halo_config.tensor_dim] + + # Calculate slice boundaries + start = halo_config.halo_size if local_rank != 0 else halo_config.edge_padding_size + end = ( + dim_shape - halo_config.halo_size + if local_rank != local_size - 1 + else dim_shape - halo_config.edge_padding_size + ) + + left_slice, central_slice, right_slice = torch.tensor_split( + local_tensor, [start, end], dim=halo_config.tensor_dim + ) + + return left_slice, central_slice, right_slice + + +@profile +def halo_padding_bwd_primitive( + grad_output: torch.Tensor, + mesh: DeviceMesh, + halo_config: HaloConfig, +) -> torch.Tensor: + """Backward primitive for halo padding. + + Recall the forward pass is a concatenation of neighboring regions. + The backward pass takes the gradients of the padded images, slices + off the pieces that represent the halos, performs a halo collective, + and *adds* the gradients to their original positions in the local grads. + + Args: + grad_output: Gradient tensor from upstream operations + mesh: Device mesh containing sharding information + halo_config: HaloConfig defining padding parameters + + Returns: + Gradient tensor with halo contributions applied + """ + + grad_to_left, local_grad, grad_to_right = slice_halo_regions( + grad_output, + mesh, + halo_config, + ) + + # Exchange halos between ranks + grad_from_left, grad_from_right = perform_halo_collective( + mesh, + halo_config.mesh_dim, + grad_to_left.contiguous(), + grad_to_right.contiguous(), + halo_config.communication_method, + ) + + # Apply halo gradients + final_grad_local = apply_grad_halo( + mesh, + halo_config, + local_grad, + grad_from_left, + grad_from_right, + ) + return final_grad_local + +@profile def perform_halo_collective( mesh: DeviceMesh, mesh_dim: int, halo_to_left: torch.Tensor, halo_to_right: torch.Tensor, - method: str = "a2a", + method: Literal["p2p", "a2a"] = "a2a", + async_op: bool = False, ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: """Performs collective communication to exchange halo regions between ranks. + There is an assumption made here that messages are symmetric between paired + processes in terms of message size. So, size(message_to_right) == size(message_from_left). + This assumption is used when preparing buffers for the incoming messages. + + If messages aren't being sent in one direction, it's expected to take + in an empty tensor with the proper device and dtype still. + Args: mesh: Device mesh for communication mesh_dim: Mesh dimension for exchange @@ -197,13 +465,17 @@ def perform_halo_collective( Returns: Tuple of (halo from left, halo from right) tensors """ + + # We get the dtype and device from the first non-None tensor + # Do not use this as the generalized template - we don't assume a + # rank is sending equal amounts of data left and right. Only + # assume the messages are symmetric between template_halo = next( (x for x in [halo_to_left, halo_to_right] if x is not None), None ) - if template_halo is None: - raise ValueError( - "At least one of halo_to_left or halo_to_right must not be None" - ) + + dtype = template_halo.dtype + device = template_halo.device # Get process group info local_group = mesh.get_group(mesh_dim) @@ -272,12 +544,10 @@ def perform_halo_collective( elif method == "a2a": # All-to-all communication all_to_all_send = [ - torch.empty(0, dtype=template_halo.dtype, device=template_halo.device) - for _ in range(local_size) + torch.empty(0, dtype=dtype, device=device) for _ in range(local_size) ] all_to_all_recv = [ - torch.empty(0, dtype=template_halo.dtype, device=template_halo.device) - for _ in range(local_size) + torch.empty(0, dtype=dtype, device=device) for _ in range(local_size) ] # Set up send/recv buffers @@ -286,7 +556,7 @@ def perform_halo_collective( all_to_all_send[local_rank - 1] = halo_to_left # Receive one right (need to initialize an empty buffer of the right size): all_to_all_recv[local_rank - 1] = torch.zeros_like( - template_halo + halo_to_left ).contiguous() if local_rank != local_size - 1: @@ -294,11 +564,18 @@ def perform_halo_collective( all_to_all_send[local_rank + 1] = halo_to_right # Receive one from the right: all_to_all_recv[local_rank + 1] = torch.zeros_like( - template_halo + halo_to_right ).contiguous() # Perform exchange - dist.all_to_all(all_to_all_recv, all_to_all_send, group=local_group) + with record_function("all_to_all_queue_and_wait"): + request = dist.all_to_all( + all_to_all_recv, all_to_all_send, group=local_group, async_op=async_op + ) + + if async_op: + # According to the docs, this will wait until the collectives are enqueued and it's safe to use the recv buffers. + request.wait() # Extract received halos halo_from_left = all_to_all_recv[local_rank - 1] if local_rank != 0 else None @@ -309,34 +586,29 @@ def perform_halo_collective( return halo_from_left, halo_from_right -def unpack_halo_tensors( +@profile +def apply_halo_tensors( mesh: DeviceMesh, - mesh_dim: int, - target_dim: int, + halo_config: HaloConfig, + local_tensor: torch.Tensor, halo_from_left: Optional[torch.Tensor], halo_from_right: Optional[torch.Tensor], - local_tensor: torch.Tensor, - edge_padding_s: Optional[int], - edge_padding_t: str, -) -> List[torch.Tensor]: +) -> torch.Tensor: """Combines local tensor with received halos and edge padding. Args: mesh: Device mesh for process info - mesh_dim: Mesh dimension being padded - target_dim: Tensor dimension being padded + halo_config: HaloConfig defining padding parameters + local_tensor: Local tensor chunk halo_from_left: Halo received from left rank halo_from_right: Halo received from right rank - local_tensor: Local tensor chunk - edge_padding_s: Edge padding size - edge_padding_t: Edge padding type Returns: - List of tensors to concatenate for final padded result + Padded tensor with halos from neighboring ranks """ # Get process group info - local_group = mesh.get_group(mesh_dim) - local_rank = mesh.get_local_rank(mesh_dim) + local_group = mesh.get_group(halo_config.mesh_dim) + local_rank = mesh.get_local_rank(halo_config.mesh_dim) local_size = dist.get_world_size(group=local_group) padded_output = [] @@ -345,68 +617,49 @@ def unpack_halo_tensors( if local_rank != 0: padded_output.append(halo_from_left) else: - if edge_padding_t == "zeros": - if edge_padding_s is None: - padded_output.append(torch.zeros_like(halo_from_right)) - else: - shape = list(halo_from_right.shape) - shape[target_dim] = edge_padding_s - zeros = torch.zeros( - shape, device=halo_from_right.device, dtype=halo_from_right.dtype - ) - padded_output.append(zeros) - elif edge_padding_t == "reflect": - padded_output.append(halo_from_right.flip(target_dim)) - elif edge_padding_t == "replicate": - raise NotImplementedError("Replicate padding not implemented") - elif edge_padding_t == "circular": - padded_output.append(halo_from_right) - elif edge_padding_t == "none": - pass - - # Add local tensor + if halo_config.edge_padding_size > 0: + shape = list(halo_from_right.shape) + shape[halo_config.tensor_dim] = halo_config.edge_padding_size + zeros = torch.zeros( + shape, device=halo_from_right.device, dtype=halo_from_right.dtype + ) + padded_output.append(zeros) + + # Add the original, now central tensor padded_output.append(local_tensor) # Add right padding if local_rank != local_size - 1: padded_output.append(halo_from_right) else: - if edge_padding_t == "zeros": - if edge_padding_s is None: - padded_output.append(torch.zeros_like(halo_from_left)) - else: - shape = list(halo_from_left.shape) - shape[target_dim] = edge_padding_s - zeros = torch.zeros( - shape, device=halo_from_left.device, dtype=halo_from_left.dtype - ) - padded_output.append(zeros) - elif edge_padding_t == "reflect": - padded_output = halo_from_left.flip(target_dim) - elif edge_padding_t == "replicate": - raise NotImplementedError("Replicate padding not implemented") - elif edge_padding_t == "circular": - padded_output.append(halo_from_left) - elif edge_padding_t == "none": - pass + if halo_config.edge_padding_size > 0: + shape = list(halo_from_left.shape) + shape[halo_config.tensor_dim] = halo_config.edge_padding_size + zeros = torch.zeros( + shape, device=halo_from_left.device, dtype=halo_from_left.dtype + ) + padded_output.append(zeros) - return padded_output + return torch.cat(padded_output, dim=halo_config.tensor_dim) +@profile def apply_grad_halo( mesh: DeviceMesh, - mesh_dim: int, - tensor_dim: int, + halo_config: HaloConfig, grad_input: torch.Tensor, halo_from_left: torch.Tensor, halo_from_right: torch.Tensor, ) -> torch.Tensor: """Applies halo gradients to input gradient tensor. + The forward pass of a halo is padding to edges. The backward + pass is to trim add the halos to the edges of the local region + (in the same locations that were sent previously). + Args: mesh: Device mesh for process info - mesh_dim: Mesh dimension for halo - tensor_dim: Tensor dimension for halo + halo_config: HaloConfig defining padding parameters grad_input: Input gradient tensor halo_from_left: Gradient from left halo halo_from_right: Gradient from right halo @@ -415,20 +668,27 @@ def apply_grad_halo( Updated gradient tensor with halo gradients applied """ # Get process group info - local_group = mesh.get_group(mesh_dim) - local_rank = mesh.get_local_rank(mesh_dim) + local_group = mesh.get_group(halo_config.mesh_dim) + local_rank = mesh.get_local_rank(halo_config.mesh_dim) local_size = dist.get_world_size(group=local_group) # Apply right halo gradient if local_rank != local_size - 1: - start_idx = grad_input.shape[tensor_dim] - halo_from_right.shape[tensor_dim] - length = halo_from_right.shape[tensor_dim] - grad_input.narrow(tensor_dim, start_idx, length).add_(halo_from_right) + start_idx = ( + grad_input.shape[halo_config.tensor_dim] + - halo_from_right.shape[halo_config.tensor_dim] + ) + length = halo_from_right.shape[halo_config.tensor_dim] + grad_input.narrow(halo_config.tensor_dim, start_idx, length).add_( + halo_from_right + ) # Apply left halo gradient if local_rank != 0: start_idx = 0 - length = halo_from_left.shape[tensor_dim] - grad_input.narrow(tensor_dim, start_idx, length).add_(halo_from_left) + length = halo_from_left.shape[halo_config.tensor_dim] + grad_input.narrow(halo_config.tensor_dim, start_idx, length).add_( + halo_from_left + ) return grad_input diff --git a/physicsnemo/distributed/shard_utils/index_ops.py b/physicsnemo/distributed/shard_utils/index_ops.py new file mode 100644 index 0000000000..152339a9ff --- /dev/null +++ b/physicsnemo/distributed/shard_utils/index_ops.py @@ -0,0 +1,491 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Tuple, Union + +import torch +import wrapt + +from physicsnemo.utils.version_check import check_module_requirements + +check_module_requirements("physicsnemo.distributed.shard_tensor") + +from torch.distributed.tensor.placement_types import ( # noqa: E402 + Replicate, + Shard, +) + +from physicsnemo.distributed import ShardTensor # noqa: E402 +from physicsnemo.distributed._shard_tensor_spec import ( # noqa: E402 + ShardTensorSpec, + TensorMeta, + _stride_from_contiguous_shape_C_style, +) +from physicsnemo.distributed.shard_utils.patch_core import ( # noqa: E402 + MissingShardPatch, +) + +aten = torch.ops.aten + +__all__ = [ + "index_select_wrapper", +] + + +class ShardedIndexSelect(torch.autograd.Function): + """ + Autograd function implementing a differentiable index_select operation for ShardTensors. + + This class provides both forward and backward pass implementations to enable + gradient computation through the index_select operation when working with + distributed sharded tensors. + """ + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + tensor: ShardTensor, + dim: int, + index: ShardTensor, + ) -> ShardTensor: + """ + Implementation of a differentiable index select operation on ShardTensors. + + This requires collectives and temporarily utilizing the full shape. + It could be optimized, for large tensors, to use a ring and smarter indexing. + + Parameters + ---------- + ctx : torch.autograd.function.FunctionCtx + Context object to store information for backward pass + tensor : ShardTensor + Input tensor to select from + dim : int + Dimension along which to index + index : ShardTensor + Indices to select + + Returns + ------- + ShardTensor + Output tensor containing the selected elements + + Raises + ------ + MissingShardPatch + If the index sharding strategy is not implemented + """ + # This is the simplest implementation, to enable functionality. + # It could be optimized for very large tensors to ensure performace. + + # We save the local version of the index and the input tensor spec for the backwards pass + + ctx.spec = tensor._spec + ctx.grad_shape = tensor._local_tensor.shape + ctx.dim = dim + + # First - Make sure we have the full input tensor + # Triggers an all_gather(_v) for (uneven) tensors. + local_tensor = tensor.full_tensor() + + # Perform the index select using the local values of the index: + local_index = index.to_local() + ctx.save_for_backward(index) + + # Get everything requested from the local index: + local_values = aten.index_select(local_tensor, dim, local_index) + + # Now, we do gymnastics to make sure the output is correctly sharded. + # Because index is one dimensional, by requirement of the underlying function, + # it's not as annoying as it could be. + index_placement = index._spec.placements[0] + + if index_placement.is_shard(): + # Then, we return a tensor sharded along dim aka Shard(dim). + # Size per rank is easy to compute, no communication needed. + output_size = list(tensor.shape) + output_shard_sizes = {} + for mesh_dim, index_shard_sizes in index._spec.sharding_shapes().items(): + output_shard_sizes[mesh_dim] = [] + for local_chunk_size in index_shard_sizes: + this_shard_size = output_size + this_shard_size[dim] = local_chunk_size[0] + # Make sure it's a tuple: + output_shard_sizes[mesh_dim].append( + torch.Size(tuple(this_shard_size)) + ) + # Make sure it's a tuple: + output_shard_sizes[mesh_dim] = tuple(output_shard_sizes[mesh_dim]) + + ctx.output_shard_sizes = output_shard_sizes + + return_tensor = ShardTensor.from_local( + local_values, + device_mesh=tensor._spec.mesh, + placements=[ + Shard(dim), + ], + sharding_shapes=output_shard_sizes, + ) + return return_tensor + elif index_placement.is_replicate(): + # The output sharding should match the sharding of the original tensor. + output_size = list(tensor.shape) + + # Replace the output size along the indexing dim with the right size: + output_size[dim] = local_values.shape[dim] + # Cast to shard tensor (as replicated, right now): + output = ShardTensor.from_local( + local_values, + device_mesh=tensor._spec.mesh, + placements=[ + Replicate(), + ], + ) + + # Redistribute to the original sharding of the input tensor: + output = output.redistribute(tensor._spec.mesh, tensor._spec.placements) + + return output + + else: + raise MissingShardPatch( + f"Index select is not implemented for {index_placement} sharding." + ) + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, grad_output: ShardTensor + ) -> Tuple[ShardTensor, None, None]: + """ + Backward pass for the index_select operation on ShardTensors. + + The backward pass sends gradients appropriately to the input tensor. + Therefore, its sharding should match the input tensor's sharding. + + Parameters + ---------- + ctx : torch.autograd.function.FunctionCtx + Context object containing saved tensors and attributes from forward pass + grad_output : ShardTensor + Gradient of the loss with respect to the output of forward pass + + Returns + ------- + Tuple[ShardTensor, None, None] + Tuple containing: + - Gradient with respect to input tensor + - None for dim parameter (not differentiable) + - None for index parameter (not differentiable) + """ + (index,) = ctx.saved_tensors + spec = ctx.spec + dim = ctx.dim + + local_index = index.full_tensor() + + grad_inputs = torch.zeros( + spec.tensor_meta.shape, device=grad_output._local_tensor.device + ) + # local_grad_output = grad_output.to_local() + local_grad_output = grad_output.full_tensor() + + grad_inputs = aten.index_add(grad_inputs, dim, local_index, local_grad_output) + + # Now, grad_inputs is replicated on all devices. + # Shard it along the original sharding of the input tensor. + grad_inputs = ShardTensor.from_local( + grad_inputs, + device_mesh=spec.mesh, + placements=[ + Replicate(), + ], + ) + grad_inputs = grad_inputs.redistribute(spec.mesh, spec.placements) + + return grad_inputs, None, None + + +def sharded_index_select( + tensor: ShardTensor, + dim: int, + index: ShardTensor, +) -> ShardTensor: + """ + Performs an index_select operation on ShardTensors with autograd support. + + This is a thin wrapper around the ShardedIndexSelect autograd function + to make the operation differentiable. + + Parameters + ---------- + tensor : ShardTensor + Input tensor to select from + dim : int + Dimension along which to index + index : ShardTensor + Indices to select + + Returns + ------- + ShardTensor + Output tensor containing the selected elements + """ + return ShardedIndexSelect.apply(tensor, dim, index) + + +@wrapt.patch_function_wrapper( + "torch", + "index_select", + enabled=ShardTensor.patches_enabled, +) +def index_select_wrapper( + wrapped: Any, instance: Any, args: tuple, kwargs: dict +) -> Union[ShardTensor, torch.Tensor]: + """ + Wrapper for index_select operation that handles both ShardTensors and regular Tensors. + + This function dispatches to the appropriate implementation based on the input types. + For ShardTensors, it uses sharded_index_select, otherwise falls back to torch's index_select. + + + Returns + ------- + Union[ShardTensor, torch.Tensor] + Output tensor containing the selected elements + + Raises + ------ + TypeError + If the input combination is not supported + """ + + # Extract the tensor and index from the arguments + tensor, dim, index = args + + if isinstance(tensor, ShardTensor) and isinstance(index, ShardTensor): + return sharded_index_select(tensor, dim, index) + elif isinstance(tensor, torch.Tensor) and isinstance(index, torch.Tensor): + return torch.index_select(tensor, dim, index) + else: + raise TypeError( + f"Unsupported input types: tensor {type(tensor)}, index {type(index)}" + ) + + +def sharded_select_helper(tensor: ShardTensor, dim: int, index: int) -> ShardTensor: + """ + This function contains the logic for performing a select operation on a ShardTensor. + """ + + # if the chunking dimension is along a dimension that is sharded, we have to handle that. + # If it's along an unsharded dimension, there is nearly nothing to do. + + input_spec = tensor._spec + + input_placements = input_spec.placements + + shards = [s for s in input_placements if isinstance(s, Shard)] + + # We are reducing tensor rank and returning one sharding per tensor: + original_shape = list(input_spec.shape) + + if dim in [i.dim for i in shards]: + raise MissingShardPatch( + "No implementation for aten.select.int along sharding axis yet." + ) + + else: + + # We are reducing tensor rank: + original_shape.pop(dim) + output_stride = _stride_from_contiguous_shape_C_style(original_shape) + + # Need to create a new global meta: + new_meta = TensorMeta( + torch.Size(tuple(original_shape)), + stride=output_stride, + dtype=input_spec.tensor_meta.dtype, + ) + # The placements get adjusted too + new_placements = [] + for p in input_spec.placements: + if p.is_replicate(): + new_placements.append(p) + elif p.is_shard(): + if p.dim > dim: + new_placements.append(Shard(p.dim - 1)) + else: + new_placements.append(p) + elif p.is_partial(): + raise MissingShardPatch( + "Partial placement not supported yet for select" + ) + + # We can directly compute the sizes from the input spec sharding sizes: + # Since the constraint above prevents selecting along a sharded dimension, + # we can be sure that none of these adjusted shapes will be sharded. + output_shard_sizes = {} + for mesh_dim, index_shard_sizes in input_spec.sharding_shapes().items(): + output_shard_sizes[mesh_dim] = [] + for local_chunk_size in index_shard_sizes: + local_chunk_size_list = list(local_chunk_size) + local_chunk_size_list.pop(dim) + output_shard_sizes[mesh_dim].append( + torch.Size(tuple(local_chunk_size_list)) + ) + output_shard_sizes[mesh_dim] = tuple(output_shard_sizes[mesh_dim]) + + output_spec = ShardTensorSpec( + mesh=input_spec.mesh, + placements=tuple(new_placements), + tensor_meta=new_meta, + _sharding_shapes=output_shard_sizes, + ) + # Finally, actually perform the select: + local_result = aten.select.int(tensor._local_tensor, dim, index) + + return ShardTensor( + local_result, + output_spec, + requires_grad=False, # This will get adjusted after the dispatcher + ) + + +def sharded_select_backward_helper( + grad_output: ShardTensor, input_sizes: torch.Size, dim: int, index: int +) -> ShardTensor: + """ + This function contains the logic for performing a gradient of a select operation on a ShardTensor. + + We shard the gradients analogously to the output gradients. + + """ + + # if the chunking dimension is along a dimension that is sharded, we have to handle that. + # If it's along an unsharded dimension, there is nearly nothing to do. + + input_placements = grad_output._spec.placements + + output_stride = _stride_from_contiguous_shape_C_style(input_sizes) + + # Need to create a new global meta: + new_meta = TensorMeta( + torch.Size(tuple(input_sizes)), + stride=output_stride, + dtype=grad_output._spec.tensor_meta.dtype, + ) + + new_placements = input_placements + # The placements get adjusted too + new_placements = [] + for p in grad_output._spec.placements: + if p.is_replicate(): + new_placements.append(p) + elif p.is_shard(): + if p.dim >= dim: + new_placements.append(Shard(p.dim + 1)) + else: + new_placements.append(p) + elif p.is_partial(): + raise Exception("Partial placement not supported yet for select_backward") + + # Next, calculate the sharding sizes for the output tensor: + output_shard_sizes = {} + for mesh_dim, index_shard_sizes in grad_output._spec.sharding_shapes().items(): + output_shard_sizes[mesh_dim] = [] + for local_chunk_size in index_shard_sizes: + # We need to insert input_sizes[dim] at index: + local_chunk_size_list = list(local_chunk_size) + local_chunk_size_list.insert(dim, input_sizes[dim]) + output_shard_sizes[mesh_dim].append( + torch.Size(tuple(local_chunk_size_list)) + ) + output_shard_sizes[mesh_dim] = tuple(output_shard_sizes[mesh_dim]) + + output_spec = ShardTensorSpec( + mesh=grad_output._spec.mesh, + placements=tuple(new_placements), + tensor_meta=new_meta, + _sharding_shapes=output_shard_sizes, + ) + + # Finally, make sure we use the correct local size: + mesh_rank = grad_output._spec.mesh.get_local_rank() + if len(output_shard_sizes.keys()) > 0: + local_output_size = output_shard_sizes[0][mesh_rank] + else: + # Fall back to the global shape if nothing is sharded: + local_output_size = output_spec.tensor_meta.shape + + # Now, compute the local result: + local_result = aten.select_backward( + grad_output._local_tensor, local_output_size, dim, index + ) + + return ShardTensor( + local_result, + output_spec, + requires_grad=False, # This will get adjusted after the dispatcher + ) + + +def select_wrapper(tensor, dim, index): + + if not isinstance(dim, int): + raise TypeError(f"Dim must be an int, got {type(dim)}") + if not isinstance(index, int): + raise TypeError(f"Index must be an int, got {type(index)}") + + # This is a _dispatch_level_ wrapper, so we're intercepting aten.select.int + + if isinstance(tensor, ShardTensor): + # if the select index is not sharded, just perform locally, repackage, and return. + + # Perform the select locally: + return sharded_select_helper(tensor, dim, index) + + elif isinstance(tensor, torch.Tensor): + return aten.select.int(tensor, dim, index) + else: + raise MissingShardPatch(f"Unsupported tensor type: {type(tensor)}") + + +def select_backward_wrapper(grad_output, input_sizes, dim, index): + """ + Backward function for select operation. + + Args: + grad_output: Gradient from the downstream operation + input_sizes: Original tensor sizes before select operation + dim: Dimension along which select was performed + index: Index that was selected + + Returns: + Gradient with respect to the input tensor + """ + # Create a zero tensor with the original input shape + # Place the gradient at the selected index + if isinstance(grad_output, ShardTensor): + # Handle ShardTensor case + return sharded_select_backward_helper(grad_output, input_sizes, dim, index) + elif isinstance(grad_output, torch.Tensor): + # Regular tensor case + return aten.select_backward.default(grad_output, input_sizes, dim, index) + else: + raise MissingShardPatch( + f"Unsupported tensor types: grad_output {type(grad_output)}" + ) diff --git a/physicsnemo/distributed/shard_utils/natten_patches.py b/physicsnemo/distributed/shard_utils/natten_patches.py index 75c8b8bbdf..7dcf58ffe2 100644 --- a/physicsnemo/distributed/shard_utils/natten_patches.py +++ b/physicsnemo/distributed/shard_utils/natten_patches.py @@ -15,7 +15,7 @@ # limitations under the License. import importlib.util -from typing import Any, Tuple, Union +from typing import Any, Callable, List, Tuple, Union import torch import wrapt @@ -27,16 +27,16 @@ from torch.distributed.tensor.placement_types import Shard # noqa: E402 from physicsnemo.distributed import ShardTensor # noqa: E402 +from physicsnemo.distributed.shard_utils.halo import ( # noqa: E402 + HaloConfig, + halo_padding, + unhalo_padding, +) from physicsnemo.distributed.shard_utils.patch_core import ( # noqa: E402 MissingShardPatch, UndeterminedShardingError, ) -from .halo import ( # noqa: E402 - halo_padding_1d, - halo_unpadding_1d, -) - __all__ = ["na2d_wrapper"] @@ -74,267 +74,117 @@ def compute_halo_from_kernel_and_dilation(kernel_size: int, dilation: int) -> in return halo -def shard_to_haloed_local( - q: ShardTensor, k: ShardTensor, v: ShardTensor, kernel_size: int, dilation: int = 1 -) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], list[int]]: - """Add halo regions to query, key and value tensors for neighborhood attention. - - For neighborhood attention, each tensor needs access to neighboring values within - the kernel window. This function adds halo regions to sharded q/k/v tensors by - gathering values from adjacent ranks. +def compute_halo_configs_from_natten_args( + example_input: ShardTensor, + kernel_size: int, + dilation: int, +) -> List[HaloConfig]: + """Compute halo configurations for a sharded tensor based on convolution arguments. Args: - q: Query tensor, sharded across device mesh - k: Key tensor, must be sharded same as query - v: Value tensor, must be sharded same as query - kernel_size: Size of attention window - dilation: Dilation factor for attention window, must be 1 + example_input: The sharded tensor that will be used in neighborhood attention + kernel_size: Size of attention kernel window + dilation: Dilation factor for attention kernel Returns: - Tuple containing: - - Tuple of (padded_q, padded_k, padded_v) local tensors with halos - - List of halo sizes for each mesh dimension - - Raises: - ValueError: If q/k/v are not sharded on same mesh + List of HaloConfig objects for each sharded dimension """ - # Verify q/k/v use same device mesh - if q._spec.mesh != k._spec.mesh: - raise ValueError("Mismatched mesh not supported in na2d") - if q._spec.mesh != v._spec.mesh: - raise ValueError("Mismatched mesh not supported in na2d") - # Compute required halo size from kernel parameters halo_size = compute_halo_from_kernel_and_dilation(kernel_size, dilation) - # Get device mesh and create halo params - mesh = q._spec.mesh - halo = [halo_size] * mesh.ndim - edge_padding_s = [0] * mesh.ndim - edge_padding_t = "none" - - # TODO: Verify q/k/v have identical sharding - - # Add halos to each tensor - local_padded_q = HaloPaddingND.apply( - q, - halo, - edge_padding_t, - edge_padding_s, - ) - - local_padded_k = HaloPaddingND.apply( - k, - halo, - edge_padding_t, - edge_padding_s, - ) - - local_padded_v = HaloPaddingND.apply( - v, - halo, - edge_padding_t, - edge_padding_s, - ) + placements = example_input._spec.placements - return (local_padded_q, local_padded_k, local_padded_v), halo + halo_configs = [] + for mesh_dim, p in enumerate(placements): + if not isinstance(p, Shard): + continue -class HaloPaddingND(torch.autograd.Function): - """Autograd wrapper for distributed halo padding. + tensor_dim = p.dim + if tensor_dim in [ + 0, + ]: # Skip batch dim + continue - Handles halo padding for distributed tensors using ShardTensor concept - (local tensor + device mesh + shard placements). Forward pass gathers adjacent regions - from neighboring devices. Backward pass distributes gradients outward. + # Compute required halo size from kernel parameters + halo_size = compute_halo_from_kernel_and_dilation(kernel_size, dilation) - Supports multi-dimensional halo passing with compatible mesh and halo parameters. - """ - - @staticmethod - def forward( - ctx, - stensor: ShardTensor, - halo: Tuple[int, ...], - edge_padding_t: str, - edge_padding_s: Tuple[int, ...], - ) -> torch.Tensor: - """Forward pass of distributed halo padding. - - Args: - ctx: Autograd context for saving tensors - stensor: Input ShardTensor - halo: Halo sizes for each dimension - edge_padding_t: Edge padding type ("zeros" or "none") - edge_padding_s: Edge padding sizes - - Returns: - Padded local tensor - - Raises: - ValueError: If halo size doesn't match mesh dimensions - """ - mesh = stensor.device_mesh - if len(halo) != mesh.ndim: - raise ValueError( - f"Halo size ({len(halo)}) must match mesh rank ({mesh.ndim})" - ) - - placements = stensor.placements - local_tensor = stensor.to_local() - - # Apply halo padding for each sharded dimension - for mesh_dim in range(mesh.ndim): - if isinstance(placements[mesh_dim], Shard): - tensor_dim = placements[mesh_dim].dim - local_tensor = halo_padding_1d( - local_tensor, - mesh, - mesh_dim, - tensor_dim, - halo[mesh_dim], - edge_padding_t, - edge_padding_s[0], + if halo_size > 0: + # Create a halo config for this dimension + halo_configs.append( + HaloConfig( + mesh_dim=mesh_dim, + tensor_dim=tensor_dim, + halo_size=halo_size, + edge_padding_size=0, # Always 0 for natten + communication_method="a2a", ) + ) - # Save context for backward pass - ctx.halo = halo - ctx.spec = stensor._spec - ctx.requires_input_grad = stensor.requires_grad - - return local_tensor - - @staticmethod - def backward( - ctx, grad_output: torch.Tensor - ) -> Tuple[ShardTensor, None, None, None]: - """Backward pass of distributed halo padding. - - Args: - ctx: Autograd context containing saved tensors - grad_output: Gradient tensor from downstream + return halo_configs - Returns: - Tuple containing: - - Gradient for input tensor - - None for other inputs (halo, padding_type, padding_size) - """ - spec = ctx.spec - mesh = spec.mesh - placements = spec.placements - halo = ctx.halo - - # Remove halos from gradients in reverse order - for mesh_dim in range(mesh.ndim): - if isinstance(placements[mesh_dim], Shard): - tensor_dim = placements[mesh_dim].dim - grad_output = halo_unpadding_1d( - grad_output, mesh, mesh_dim, tensor_dim, halo[mesh_dim] - ) - - # Wrap gradient in ShardTensor - grad_tensor = ShardTensor( - grad_output, - spec, - requires_grad=grad_output.requires_grad, - ) - return grad_tensor, None, None, None +def partial_na2d( + q: ShardTensor, + k: ShardTensor, + v: ShardTensor, + kernel_size: int, + dilation: int, + base_func: Callable, +) -> ShardTensor: + """ + High Level, differentiable function to compute neighborhood attention on a sharded tensor. + Operation works like so: + - Figure out the size of halos needed. + - Apply the halo padding (differentiable) + - Perform the neighborhood attention on the padded tensor. (differentiable) + - "UnHalo" the output tensor (different from, say, convolutions) + - Return the updated tensor as a ShardTensor. -class UnSliceHaloND(torch.autograd.Function): - """Autograd function to remove halo regions from a tensor after halo computation. + Args: + q: Query tensor as ShardTensor + k: Key tensor as ShardTensor + v: Value tensor as ShardTensor + kernel_size: Size of attention kernel window + dilation: Dilation factor for attention kernel + base_func: The base neighborhood attention function to call with padded tensors - Used to trim off unnecessary halo sections after operations like neighborhood attention - that require halo regions for computation but not in the final output. + Returns: + ShardTensor containing the result of neighborhood attention - Forward pass removes halo regions by unpadding along sharded dimensions. - Backward pass adds halo regions back via padding to match the original shape. + Raises: + MissingShardPatch: If kernel configuration is not supported for sharding + UndeterminedShardingError: If input tensor types are mismatched """ - @staticmethod - def forward( - ctx, - local_tensor: torch.Tensor, - halo: tuple[int, ...], - mesh: torch.distributed.device_mesh.DeviceMesh, - placements: tuple[torch.distributed.tensor.placement_types.Placement, ...], - ) -> "ShardTensor": - """Forward pass to remove halo regions. - - Args: - ctx: Autograd context for saving tensors - local_tensor: Input tensor with halo regions - halo: Tuple of halo sizes for each mesh dimension - mesh: Device mesh for distributed computation - placements: Tuple of placement specs for each mesh dimension - - Returns: - ShardTensor with halo regions removed - - Raises: - ValueError: If halo size does not match mesh rank - """ - # Save context for backward pass - ctx.halo = halo - ctx.mesh = mesh - ctx.placements = placements - - if len(halo) != mesh.ndim: - raise ValueError( - f"Halo size ({len(halo)}) must match mesh rank ({mesh.ndim})" - ) + # First, get the tensors locally and perform halos: + lq, lk, lv = q.to_local(), k.to_local(), v.to_local() - # Remove halos along sharded dimensions - for mesh_dim in range(mesh.ndim): - if isinstance(placements[mesh_dim], Shard): - tensor_dim = placements[mesh_dim].dim - local_tensor = halo_unpadding_1d( - local_tensor, mesh, mesh_dim, tensor_dim, halo[mesh_dim] - ) + # Compute halo configs for these tensors. We can assume + # the halo configs are the same for q/k/v and just do it once: - # Convert to ShardTensor - stensor = ShardTensor.from_local(local_tensor, mesh, placements) - return stensor + halo_configs = compute_halo_configs_from_natten_args(q, kernel_size, dilation) - @staticmethod - def backward( - ctx, grad_output: "ShardTensor" - ) -> tuple[torch.Tensor, None, None, None]: - """Backward pass to add halo regions back. + # Apply the halo padding to the input tensor + for halo_config in halo_configs: + lq = halo_padding(lq, q._spec.mesh, halo_config) + lk = halo_padding(lk, k._spec.mesh, halo_config) + lv = halo_padding(lv, v._spec.mesh, halo_config) - Args: - ctx: Autograd context with saved tensors - grad_output: Gradient tensor from downstream + # Apply native na2d operation + x = base_func(lq, lk, lv, kernel_size, dilation) - Returns: - Tuple containing: - - Gradient tensor with halo regions added back - - None for other inputs (halo, mesh, placements) - """ - mesh = ctx.mesh - halo = ctx.halo - placements = ctx.placements - - # Configure padding parameters - edge_padding_s = [0] * len(halo) - edge_padding_t = "none" - - # Add halos back via padding - local_tensor = grad_output.to_local() - for mesh_dim in range(mesh.ndim): - if isinstance(placements[mesh_dim], Shard): - tensor_dim = placements[mesh_dim].dim - local_tensor = halo_padding_1d( - local_tensor, - mesh, - mesh_dim, - tensor_dim, - halo[mesh_dim], - edge_padding_t, - edge_padding_s[mesh_dim], - ) + # Remove halos and convert back to ShardTensor + # x = UnSliceHaloND.apply(x, halo, q._spec) + for halo_config in halo_configs: + x = unhalo_padding(x, q._spec.mesh, halo_config) - return local_tensor, None, None, None + # Convert back to ShardTensor + x = ShardTensor.from_local( + x, q._spec.mesh, q._spec.placements, q._spec.sharding_shapes() + ) + return x # Make sure the module exists before importing it: @@ -382,15 +232,8 @@ def fetch_qkv( if all([type(_t) == torch.Tensor for _t in (q, k, v)]): return wrapped(*args, **kwargs) elif all([type(_t) == ShardTensor for _t in (q, k, v)]): - # This applies a halo layer and returns local torch tensors: - (lq, lk, lv), halo = shard_to_haloed_local(q, k, v, kernel_size, dilation) - - # Apply native na2d operation - x = wrapped(lq, lk, lv, kernel_size, dilation) - # Remove halos and convert back to ShardTensor - x = UnSliceHaloND.apply(x, halo, q._spec.mesh, q._spec.placements) - return x + return partial_na2d(q, k, v, kernel_size, dilation, base_func=wrapped) else: raise UndeterminedShardingError( diff --git a/physicsnemo/distributed/shard_utils/normalization_patches.py b/physicsnemo/distributed/shard_utils/normalization_patches.py new file mode 100644 index 0000000000..12ba611aa0 --- /dev/null +++ b/physicsnemo/distributed/shard_utils/normalization_patches.py @@ -0,0 +1,303 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import torch +import torch.distributed as dist +import wrapt +from torch.distributed.tensor import DTensor + +from physicsnemo.distributed import ShardTensor, ShardTensorSpec +from physicsnemo.distributed.manager import DistributedManager + +__all__ = [ + "group_norm_wrapper", +] + + +from physicsnemo.distributed.shard_utils.patch_core import ( + UndeterminedShardingError, +) + +aten = torch.ops.aten + + +class PartialGroupNorm(torch.autograd.Function): + """Custom autograd function for applying group normalization to sharded tensors. + + This implementation extends group normalization functionality to work with distributed + ShardTensor inputs by: + 1. Computing local statistics on each shard + 2. Synchronizing statistics across all shards + 3. Applying the global statistics to normalize each local shard + + The implementation ensures that the result is mathematically equivalent to running + group normalization on the full, unsharded tensor, while maintaining the distributed + nature of the computation. + + This class is used by the group_norm_wrapper function to intercept and handle + torch.nn.functional.group_norm calls with ShardTensor inputs. + """ + + @staticmethod + def forward( + ctx, + input: torch.Tensor, + spec: ShardTensorSpec, + num_groups: int, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + eps: float, + ) -> ShardTensor: + """Applies group normalization over a sharded tensor. + + Args: + ctx: Autograd context + input: Input tensor of shape [N, C, *] + spec: Sharding specification for the input tensor + num_groups: Number of groups to separate the channels into + weight: Optional scale parameter of shape [C] + bias: Optional bias parameter of shape [C] + eps: Small constant added to denominator for numerical stability + + Returns: + Normalized tensor of same shape as input + """ + # Save for backward + ctx.num_groups = num_groups + ctx.eps = eps + ctx.spec = spec + + # The syntax is: + # local_output, mean, rstd = torch.ops.aten.native_group_norm( + # input: Tensor, # [N, C, *spatial] + # weight: Optional[Tensor], + # bias: Optional[Tensor], + # N: int, + # C: int, + # HxW: int, + # group: int, + # eps: float + # ) + + N, C = input.shape[0], input.shape[1] + + HxW = input.numel() // (N * C) + + local_output, mean, rstd = aten.native_group_norm( + input, weight, bias, N, C, HxW, num_groups, eps + ) + + # Sync the mean and rstd across all ranks + # Note that the variance has to be inverted to make it a linear sync: + + global_mean = mean.clone() + global_var = (1.0 / (rstd**2)) - eps + + # If the mesh is 2D, we still want to reduce this over entire tensor. + # The DistributedManager provides a caching mechanism for getting a mesh-wide group: + group = DistributedManager().get_mesh_group(spec.mesh) + + # TODO - unevenly sharded tensors need a *weighted* reduction here!! + count = len(dist.get_process_group_ranks(group)) + + # Could merge these if needed. They are probably small, + # so paying more for latency than bandwidth. + dist.all_reduce(global_mean, op=dist.ReduceOp.SUM, group=group) + dist.all_reduce(global_var, op=dist.ReduceOp.SUM, group=group) + + # Compute final global statistics + global_mean = global_mean / count + global_var = global_var / count + + global_rstd = torch.rsqrt(global_var + eps) + + # Correct the output from global stats: + + original_shape = input.shape + + broadcast_shape = (N, num_groups, -1) + + scale_factor = (global_rstd / rstd).view(broadcast_shape) + + # Correct to the globally normalized output: + local_output = ( + local_output.view(broadcast_shape) + - global_mean.view(broadcast_shape) + + mean.view(broadcast_shape) + ) * scale_factor + + local_output = local_output.view(original_shape) + + # Now, apply the weight and + if weight is not None: + local_output = local_output * weight.view(1, -1, *([1] * (input.dim() - 2))) + if bias is not None: + local_output = local_output + bias.view(1, -1, *([1] * (input.dim() - 2))) + + ctx.save_for_backward(input, weight, bias) + ctx.global_mean = global_mean + ctx.global_invstd = global_rstd + + ctx.grad_mask = ( + input.requires_grad, + weight is not None and weight.requires_grad, + bias is not None and bias.requires_grad, + ) + + return ShardTensor.from_local( + local_output, + spec.mesh, + spec.placements, + sharding_shapes=spec.sharding_shapes(), + ) + + @staticmethod + def backward( + ctx, grad_output: ShardTensor + ) -> Tuple[ + torch.Tensor, None, None, Optional[torch.Tensor], Optional[torch.Tensor], None + ]: + """Backward pass for group normalization. + + Args: + ctx: Autograd context containing saved variables + grad_output: Gradient of the loss with respect to the output + + Returns: + Tuple containing gradients for inputs, None, None, weights, bias, and None + """ + input, weight, _ = ctx.saved_tensors + num_groups = ctx.num_groups + N, C = input.shape[0], input.shape[1] + HxW = input.numel() // (N * C) + + local_grad_output = grad_output._local_tensor.contiguous() + + grad_input, grad_weight, grad_bias = aten.native_group_norm_backward( + local_grad_output, + input=input, + mean=ctx.global_mean, + rstd=ctx.global_invstd, + weight=weight, + # bias, + N=N, + C=C, + HxW=HxW, + group=num_groups, + output_mask=ctx.grad_mask, + ) + + spec = ctx.spec + group = DistributedManager().get_mesh_group(spec.mesh) + + # Only reduce if grad_weight or grad_bias is not None + if grad_weight is not None: + dist.all_reduce(grad_weight, group=group) + if grad_bias is not None: + dist.all_reduce(grad_bias, group=group) + + return grad_input, None, None, grad_weight, grad_bias, None + + +@wrapt.patch_function_wrapper("torch.nn.functional", "group_norm") +def group_norm_wrapper( + wrapped, instance, args, kwargs +) -> Union[torch.Tensor, ShardTensor]: + """Wrapper for torch.nn.functional.group_norm that handles ShardTensor inputs. + + This function intercepts calls to group_norm and either: + 1. Passes regular torch.Tensor inputs to the original function + 2. Handles ShardTensor inputs with the PartialGroupNorm custom implementation + + Args: + wrapped: Original group_norm function + instance: Instance reference (unused) + args: Positional arguments to group_norm + kwargs: Keyword arguments to group_norm + + Returns: + Normalized tensor (either torch.Tensor or ShardTensor) + """ + input, num_groups, weight, bias, eps = repackage_group_norm_args(*args, **kwargs) + + # Handle regular torch tensor inputs + if ( + isinstance(input, torch.Tensor) + and not isinstance(input, ShardTensor) + and ( + isinstance(weight, (torch.nn.parameter.Parameter, torch.Tensor)) + or weight is None + ) + and ( + bias is None + or isinstance(bias, (torch.nn.parameter.Parameter, torch.Tensor)) + ) + ): + output = wrapped(*args, **kwargs) + return output + + # Handle distributed ShardTensor inputs + elif isinstance(input, ShardTensor): + # Gather any distributed weights/bias + if isinstance(weight, (ShardTensor, DTensor)): + weight = weight.full_tensor() + if isinstance(bias, (ShardTensor, DTensor)): + bias = bias.full_tensor() + + output_spec = input._spec + x = PartialGroupNorm.apply( + input.to_local(), output_spec, num_groups, weight, bias, eps + ) + + return x + + else: + msg = ( + "input, weight, bias (if not None) must all be the valid types " + "(torch.Tensor or ShardTensor), but got " + f"{type(input)}, " + f"{type(weight)}, " + f"{type(bias)}, " + ) + raise UndeterminedShardingError(msg) + + +def repackage_group_norm_args( + input: torch.Tensor, + num_groups: int, + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + eps: float = 1e-05, + *args, + **kwargs, +) -> Tuple[torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor], float]: + """Repackage arguments for group_norm function into a standardized format. + + Args: + input: Input tensor of shape [N, C, *] + num_groups: Number of groups to separate the channels into + weight: Optional scale parameter of shape [C] + bias: Optional bias parameter of shape [C] + eps: Small constant added to denominator for numerical stability + *args: Additional positional arguments (unused) + **kwargs: Additional keyword arguments (unused) + + Returns: + Tuple of (input, num_groups, weight, bias, eps) + """ + return input, num_groups, weight, bias, eps diff --git a/physicsnemo/distributed/shard_utils/patch_core.py b/physicsnemo/distributed/shard_utils/patch_core.py index 5bf4bea40f..8661bcb671 100644 --- a/physicsnemo/distributed/shard_utils/patch_core.py +++ b/physicsnemo/distributed/shard_utils/patch_core.py @@ -47,6 +47,10 @@ def promote_to_iterable(input_obj, target_iterable): An iterable of the same type as the target iterable. """ + # Don't do anything to strings: + if isinstance(input_obj, str): + return input_obj + # If input_obj is a string or not iterable, wrap it in the target's type. if isinstance(input_obj, str) or not isinstance(input_obj, Iterable): # Also extend it with copies to the same length: diff --git a/physicsnemo/distributed/shard_utils/point_cloud_ops.py b/physicsnemo/distributed/shard_utils/point_cloud_ops.py new file mode 100644 index 0000000000..9dd95e1a6c --- /dev/null +++ b/physicsnemo/distributed/shard_utils/point_cloud_ops.py @@ -0,0 +1,615 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Tuple, Union + +import torch +import torch.distributed as dist +import warp as wp +import wrapt + +from physicsnemo.models.layers.ball_query import ( + _ball_query_backward_primative_, + _ball_query_forward_primative_, + ball_query_layer, +) +from physicsnemo.utils.version_check import check_module_requirements + +check_module_requirements("physicsnemo.distributed.shard_tensor") + +from torch.distributed.tensor.placement_types import ( # noqa: E402 + Replicate, + Shard, +) + +from physicsnemo.distributed import ShardTensor # noqa: E402 +from physicsnemo.distributed.shard_utils.patch_core import ( # noqa: E402 + MissingShardPatch, + UndeterminedShardingError, +) +from physicsnemo.distributed.shard_utils.ring import ( # noqa: E402 + RingPassingConfig, + perform_ring_iteration, +) + +wp.config.quiet = True + +__all__ = ["ball_query_layer_wrapper"] + + +def ring_ball_query( + points1: ShardTensor, + points2: ShardTensor, + bq_kwargs: Any, +) -> Tuple[ShardTensor, ShardTensor, ShardTensor]: + """ + Performs ball query operation on points distributed across ranks in a ring configuration. + + Args: + points1: First set of points as a ShardTensor + points2: Second set of points as a ShardTensor + lengths1: Lengths of each batch in points1 + lengths2: Lengths of each batch in points2 + wrapped: The original ball query function to call on each rank + *args: Additional positional arguments to pass to the wrapped function + **kwargs: Additional keyword arguments to pass to the wrapped function + + Returns: + Tuple of (mapping, num_neighbors, outputs) as ShardTensors + """ + mesh = points1._spec.mesh + # We can be confident of this because 1D meshes are enforced + mesh_dim = 0 + + local_group = mesh.get_group(mesh_dim) + local_size = dist.get_world_size(group=local_group) + + # Create a config object to simplify function args for message passing: + ring_config = RingPassingConfig( + mesh_dim=mesh_dim, + mesh_size=local_size, + communication_method="p2p", + ring_direction="forward", + ) + + # Now, get the inputs locally: + local_points1 = points1.to_local() + local_points2 = points2.to_local() + + # Get the shard sizes for the point cloud going around the ring. + # We've already checked that the mesh is 1D so call the '0' index. + p2_shard_sizes = points2._spec.sharding_shapes()[0] + + # Call the differentiable version of the ring-ball-query: + mapping_shard, num_neighbors_shard, outputs_shard = RingBallQuery.apply( + local_points1, + local_points2, + mesh, + ring_config, + p2_shard_sizes, + bq_kwargs, + ) + + # TODO + # the output shapes can be computed directly from the input sharding of points1 + # Requires a little work to fish out parameters but that's it. + # For now, using blocking inference to get the output shapes. + + # For the output shapes, we can compute the output sharding if needed. If the placement + # is Replicate, just infer since there aren't shardings. + if type(points1._spec.placements[0]) == Replicate: + map_shard_shapes = "infer" + neighbors_shard_shapes = "infer" + outputs_shard_shapes = "infer" + elif type(points1._spec.placements[0]) == Shard: + + p1_shard_sizes = points1._spec.sharding_shapes()[0] + + # This conversion to shard tensor can be done explicitly computing the output shapes. + + b = mapping_shard.shape[0] + mp = mapping_shard.shape[-1] + d = points1.shape[-1] + mapping_shard_output_sharding = { + 0: tuple(torch.Size([b, s[1], mp]) for s in p1_shard_sizes), + } + num_neighbors_shard_output_sharding = { + 0: tuple(torch.Size([b, s[1]]) for s in p1_shard_sizes), + } + outputs_shard_output_sharding = { + 0: tuple(torch.Size([b, s[1], mp, d]) for s in p1_shard_sizes), + } + + map_shard_shapes = mapping_shard_output_sharding + # map_shard_shapes = "infer" + neighbors_shard_shapes = num_neighbors_shard_output_sharding + outputs_shard_shapes = outputs_shard_output_sharding + + # Convert back to ShardTensor + mapping_shard = ShardTensor.from_local( + mapping_shard, points1._spec.mesh, points1._spec.placements, map_shard_shapes + ) + num_neighbors_shard = ShardTensor.from_local( + num_neighbors_shard, + points1._spec.mesh, + points1._spec.placements, + neighbors_shard_shapes, + ) + outputs_shard = ShardTensor.from_local( + outputs_shard, + points1._spec.mesh, + points1._spec.placements, + outputs_shard_shapes, + ) + return mapping_shard, num_neighbors_shard, outputs_shard + + +def merge_outputs( + current_mapping: Union[torch.Tensor, None], + current_num_neighbors: Union[torch.Tensor, None], + current_outputs: Union[torch.Tensor, None], + incoming_mapping: torch.Tensor, + incoming_num_neighbors: torch.Tensor, + incoming_outputs: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Perform a gather/scatter operation on the mapping and outputs tensors. + This is an _inplace_ operation on the current tensors, assuming they are not None + + Args: + current_mapping: Current mapping tensor or None + current_num_neighbors: Current number of neighbors tensor or None + current_outputs: Current outputs tensor or None + incoming_mapping: Incoming mapping tensor to merge + incoming_num_neighbors: Incoming number of neighbors tensor to merge + incoming_outputs: Incoming outputs tensor to merge + + Returns: + Tuple of merged (mapping, num_neighbors, outputs) tensors + """ + + @wp.kernel + def merge_mapping_and_outputs( + current_m: wp.array3d(dtype=wp.int32), + current_nn: wp.array2d(dtype=wp.int32), + current_o: wp.array4d(dtype=wp.float32), + incoming_m: wp.array3d(dtype=wp.int32), + incoming_nn: wp.array2d(dtype=wp.int32), + incoming_o: wp.array4d(dtype=wp.float32), + max_neighbors: int, + ): + # This is a kernel that is essentially doing a gather/scatter operation. + + # Which points are we looking at? + tid = wp.tid() + + # How many neighbors do we have? + num_neighbors = current_nn[0, tid] + available_space = max_neighbors - num_neighbors + + # How many neighbors do we have in the incoming tensor? + incoming_num_neighbors = incoming_nn[0, tid] + + # Can't add more neighbors than we have space for: + neighbors_to_add = min(incoming_num_neighbors, available_space) + + # Now, copy the incoming neighbors to offset locations in the current tensor: + for i in range(neighbors_to_add): + + # incoming has no offset + # current has offset of num_neighbors + current_m[0, tid, num_neighbors + i] = incoming_m[0, tid, i] + current_o[0, tid, num_neighbors + i, 0] = incoming_o[0, tid, i, 0] + current_o[0, tid, num_neighbors + i, 1] = incoming_o[0, tid, i, 1] + current_o[0, tid, num_neighbors + i, 2] = incoming_o[0, tid, i, 2] + + # Finally, update the number of neighbors: + current_nn[0, tid] = num_neighbors + incoming_num_neighbors + return + + if ( + current_mapping is None + and current_num_neighbors is None + and current_outputs is None + ): + return incoming_mapping, incoming_num_neighbors, incoming_outputs + + _, n_points, max_neighbors = current_mapping.shape + + # This is a gather/scatter operation: + # We need to merge the incoming values into the current arrays. The arrays + # are essentially a ragged tensor that has been padded to a consistent shape. + # What happens here is: + # - Compare the available space in current tensors to the number of incoming values. + # - If there are more values coming in than there is space, they are truncated. + # - Using the available space, determine the section in the incoming tensor to gather. + # - Using the (trucated) size of incoming values, determine the region of the current tensor for scatter. + # - gather / scatter from incoming to current. + # - Update the current num neighbors correctly + + wp.launch( + merge_mapping_and_outputs, + dim=n_points, + inputs=[ + wp.from_torch(current_mapping, return_ctype=True), + wp.from_torch(current_num_neighbors, return_ctype=True), + wp.from_torch(current_outputs, return_ctype=True), + wp.from_torch(incoming_mapping, return_ctype=True), + wp.from_torch(incoming_num_neighbors, return_ctype=True), + wp.from_torch(incoming_outputs, return_ctype=True), + max_neighbors, + ], + ) + + return current_mapping, current_num_neighbors, current_outputs + + +class RingBallQuery(torch.autograd.Function): + """ + Custom autograd function for performing ball query operations in a distributed ring configuration. + + Handles the forward pass of ball queries across multiple ranks, enabling distributed computation + of nearest neighbors for point clouds. + """ + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + points1: torch.Tensor, + points2: torch.Tensor, + mesh: Any, + ring_config: RingPassingConfig, + shard_sizes: list, + bq_kwargs: Any, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward pass for distributed ball query computation. + + Args: + ctx: Context for saving variables for backward pass + points1: First set of points + points2: Second set of points + lengths1: Lengths of each batch in points1 + lengths2: Lengths of each batch in points2 + mesh: Distribution mesh specification + ring_config: Configuration for ring passing + shard_sizes: Sizes of each shard across ranks + wrapped: The original ball query function + *args: Additional positional arguments for the wrapped function + **kwargs: Additional keyword arguments for the wrapped function + + Returns: + Tuple of (mapping, num_neighbors, outputs) tensors + """ + ctx.mesh = mesh + ctx.ring_config = ring_config + + # Create buffers to store outputs + current_mapping = None + current_num_neighbors = None + current_outputs = None + + # For the first iteration, use local tensors + current_p1, current_p2 = (points1, points2) + + mesh_rank = mesh.get_local_rank() + + # Get all the ranks in the mesh: + world_size = ring_config.mesh_size + + # Store results from each rank to merge in the correct order + rank_results = [None] * world_size + # For uneven point clouds, the global stide is important: + strides = [s[1] for s in shard_sizes] + + ctx.k = bq_kwargs["k"] + ctx.radius = bq_kwargs["radius"] + ctx.hash_grid = bq_kwargs["hash_grid"] + + for i in range(world_size): + + source_rank = (mesh_rank - i) % world_size + + # local_mapping, local_num_neighbors, local_outputs = ball_query_layer( + # current_p1, current_p2, current_l1, current_l2, **bq_kwargs + # ) + ( + local_mapping, + local_num_neighbors, + local_outputs, + ) = _ball_query_forward_primative_( + current_p1[0], + current_p2[0], + ctx.k, + ctx.radius, + ctx.hash_grid, + ) + # Store the result with its source rank + rank_results[source_rank] = ( + local_mapping, + local_num_neighbors, + local_outputs, + ) + # strides.append(current_p2.shape[1]) + + # For point clouds, we need to pass the size of the incoming shard. + next_source_rank = (source_rank - 1) % world_size + + # TODO - this operation should be done async and checked for completion at the start of the next loop. + if i != world_size - 1: + # Don't do a ring on the last iteration. + current_p2 = perform_ring_iteration( + current_p2, + ctx.mesh, + ctx.ring_config, + recv_shape=shard_sizes[next_source_rank], + ) + + # Now merge the results in rank order (0, 1, 2, ...) + stride = 0 + for r in range(world_size): + if rank_results[r] is not None: + local_mapping, local_num_neighbors, local_outputs = rank_results[r] + + current_mapping, current_num_neighbors, current_outputs = merge_outputs( + current_mapping, + current_num_neighbors, + current_outputs, + local_mapping + stride, + local_num_neighbors, + local_outputs, + ) + + stride += strides[r] + ctx.save_for_backward( + points1, points2, current_mapping, current_num_neighbors, current_outputs + ) + + return current_mapping, current_num_neighbors, current_outputs + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + mapping_grad: torch.Tensor, + num_neighbors_grad: torch.Tensor, + outputs_grad: torch.Tensor, + ) -> Tuple[None, ...]: + """ + Backward pass for distributed ring ball query computation. + + Args: + ctx: Context containing saved variables from forward pass + grad_output: Gradients from subsequent layers + + Returns: + Gradients for inputs (currently not implemented) + """ + + raise NotImplementedError("Backward pass for ring ball query not implemented.") + + ( + points1, + points2, + current_mapping, + current_num_neighbors, + current_outputs, + ) = ctx.saved_tensors + + # We need to do a ring again in the backward direction. + # The backward pass is computed locally, and then the gradients + # and p2 are moved along the ring together. + # for i in range(world_size): + # Calculate which source rank this data is from + + local_p2_grad = _ball_query_backward_primative_( + points1[0], + points2[0], + current_mapping, + current_num_neighbors, + current_outputs, + mapping_grad, + num_neighbors_grad, + outputs_grad, + ) + local_p1_grad = torch.zeros_like(points1) + + return ( + local_p1_grad, + local_p2_grad.unsqueeze(0), + None, + None, + None, + None, + ) + + +@wrapt.patch_function_wrapper( + "physicsnemo.models.layers.ball_query", + "ball_query_layer", + enabled=ShardTensor.patches_enabled, +) +def ball_query_layer_wrapper( + wrapped: Any, instance: Any, args: tuple, kwargs: dict +) -> Union[ + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + Tuple[ShardTensor, ShardTensor, ShardTensor], +]: + """ + Wrapper for BallQueryLayer.forward to support sharded tensors. + + Handles 4 situations, based on the sharding of points 1 and points 2: + - Points 2 is sharded: a ring computation is performed. + - Points 1 is sharded: each rank contains a partial output, + which is returned sharded like Points 1. + - Points 1 is replicated: each rank returns the full output, + even though the input points 2 is sharded. + - Points 1 is replicated: No ring is needed. + - Points 1 is sharded: each rank contains a partial output, + which is returned sharded like Points 1. + - Points 1 is replicated: each rank returns the full output, + even though the input points 2 is sharded. + + All input sharding has to be over a 1D mesh. 2D Point cloud sharding + is not supported at this time. + + Regardless of the input sharding, the output will always be sharded like + points 1, and the output points will always have queried every input point + like in the non-sharded case. + + Args: + wrapped: Original forward method + instance: BallQueryLayer instance + args: Positional arguments (points1, points2, lengths1, lengths2) + kwargs: Keyword arguments + + Returns: + Tuple of (mapping, num_neighbors, outputs) as torch.Tensor or ShardTensor + """ + + points1, points2, bq_kwargs = repackage_ball_query_args(*args, **kwargs) + + # If inputs are ShardTensors, handle them appropriately + if all(isinstance(t, ShardTensor) for t in (points1, points2)): + + # Make sure all meshes are the same + if points1._spec.mesh != points2._spec.mesh: + raise MissingShardPatch( + "point_cloud_ops.ball_query_layer_wrapper: All point inputs must be on the same mesh" + ) + + # make sure all meshes are 1D + if points1._spec.mesh.ndim != 1: + raise MissingShardPatch( + "point_cloud_ops.ball_query_layer_wrapper: All point inputs must be on 1D meshes" + ) + + # Do we need a ring? + points2_placement = points2._spec.placements[0] + if isinstance(points2_placement, Shard): + # We need a ring + mapping, num_neighbors, outputs = ring_ball_query( + points1, points2, bq_kwargs + ) + else: + # No ring is needed + # Call the original function with local tensors + + local_p1 = points1.to_local() + local_p2 = points2.to_local() + k = bq_kwargs["k"] + + mapping, num_neighbors, outputs = ball_query_layer( + local_p1, + local_p2, + **bq_kwargs, + ) + + b = points1.shape[0] + + mapping_placement = {} + num_neighbors_placement = {} + outputs_placement = {} + + for k, s in points1._spec.sharding_shapes().items(): + n_points = [int(_s[1]) for _s in s] + mapping_placement[k] = tuple(torch.Size([b, np, k]) for np in n_points) + num_neighbors_placement[k] = tuple( + torch.Size([b, np]) for np in n_points + ) + outputs_placement[k] = tuple( + torch.Size([b, np, k, 3]) for np in n_points + ) + + mapping = ShardTensor.from_local( + mapping, + points1._spec.mesh, + points1._spec.placements, + sharding_shapes=mapping_placement, + ) + num_neighbors = ShardTensor.from_local( + num_neighbors, + points1._spec.mesh, + points1._spec.placements, + sharding_shapes=num_neighbors_placement, + ) + outputs = ShardTensor.from_local( + outputs, + points1._spec.mesh, + points1._spec.placements, + sharding_shapes=outputs_placement, + ) + + return mapping, num_neighbors, outputs + + # If inputs are regular torch tensors, just call the original function + elif all(isinstance(t, torch.Tensor) for t in (points1, points2)): + return ball_query_layer(points1, points2, **bq_kwargs) + + # If inputs are mixed types, raise an error + else: + raise UndeterminedShardingError( + "points1 and points2 must be the same types (torch.Tensor or ShardTensor)" + ) + + +def repackage_ball_query_args( + points1: Union[torch.Tensor, ShardTensor], + points2: Union[torch.Tensor, ShardTensor], + k: int, + radius: float, + hash_grid: wp.HashGrid, + *args: Any, + **kwargs: Any, +) -> Tuple[ + Union[torch.Tensor, ShardTensor], + Union[torch.Tensor, ShardTensor], + Union[torch.Tensor, ShardTensor], + Union[torch.Tensor, ShardTensor], + dict, +]: + """Repackages ball query arguments into a standard format. + + Takes the arguments that could be passed to a ball query operation + and separates them into core tensor inputs (points1, points2, lengths1, lengths2) + and configuration parameters packaged as a kwargs dict. + + Args: + points1: First set of points + points2: Second set of points + lengths1: Lengths of each batch in points1 + lengths2: Lengths of each batch in points2 + *args: Additional positional args + **kwargs: Additional keyword args + + Returns: + Tuple containing: + - points1 tensor + - points2 tensor + - Dict of ball query configuration parameters + """ + # Extract any additional parameters that might be in kwargs + # or use defaults if not provided + return_kwargs = { + "k": k, + "radius": radius, + "hash_grid": hash_grid, + } + + # Add any explicitly passed parameters + if kwargs: + return_kwargs.update(kwargs) + + return points1, points2, return_kwargs diff --git a/physicsnemo/distributed/shard_utils/pooling_patches.py b/physicsnemo/distributed/shard_utils/pooling_patches.py new file mode 100644 index 0000000000..32b62a9380 --- /dev/null +++ b/physicsnemo/distributed/shard_utils/pooling_patches.py @@ -0,0 +1,434 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import wrapt +from torch.distributed.tensor.placement_types import Shard + +from physicsnemo.distributed import ShardTensor +from physicsnemo.distributed.shard_utils.patch_core import ( + MissingShardPatch, + UndeterminedShardingError, +) + +aten = torch.ops.aten + +__all__ = [ + "avg_pool3d_wrapper", + "max_pool3d_wrapper", +] + + +def compute_output_shape(input_shape, pool_kwargs): + """Compute the output shape of a pooling operation. + + Args: + input_shape: Shape of the input tensor + pool_kwargs: Keyword arguments for the pooling operation + + Returns: + tuple: Output shape after pooling operation + """ + # Extract pooling parameters + kernel_size = pool_kwargs.get("kernel_size") + stride = pool_kwargs.get("stride", kernel_size) + padding = pool_kwargs.get("padding", 0) + + # Handle scalar parameters + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * (len(input_shape) - 2) + if isinstance(stride, int): + stride = (stride,) * (len(input_shape) - 2) + if isinstance(padding, int): + padding = (padding,) * (len(input_shape) - 2) + + # Batch and channel dimensions remain unchanged + output_shape = list(input_shape[:2]) + + # Compute spatial dimensions + for i, (size, k, s, p) in enumerate( + zip(input_shape[2:], kernel_size, stride, padding) + ): + output_size = ((size + 2 * p - k) // s) + 1 + output_shape.append(output_size) + + return tuple(output_shape) + + +@wrapt.patch_function_wrapper( + "torch.nn.functional", "avg_pool3d", enabled=ShardTensor.patches_enabled +) +def avg_pool3d_wrapper(wrapped, instance, args, kwargs): + return generic_avg_pool_nd_wrapper(wrapped, instance, args, kwargs) + + +@wrapt.patch_function_wrapper( + "torch.nn.functional", "avg_pool2d", enabled=ShardTensor.patches_enabled +) +def avg_pool2d_wrapper(wrapped, instance, args, kwargs): + return generic_avg_pool_nd_wrapper(wrapped, instance, args, kwargs) + + +@wrapt.patch_function_wrapper( + "torch.nn.functional", "avg_pool1d", enabled=ShardTensor.patches_enabled +) +def avg_pool1d_wrapper(wrapped, instance, args, kwargs): + return generic_avg_pool_nd_wrapper(wrapped, instance, args, kwargs) + + +def repackage_pool_args( + input: Union[torch.Tensor, ShardTensor], + kernel_size: Union[int, Tuple[int, ...]], + stride: Union[int, Tuple[int, ...]] = None, + padding: Union[int, Tuple[int, ...]] = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + divisor_override: Optional[int] = None, + *args, + **kwargs, +) -> Tuple[Union[torch.Tensor, ShardTensor], Dict[str, Any]]: + """Repackages pooling arguments into standard format. + + Takes the full set of arguments that could be passed to an avg_pool operation + and separates them into the input tensor and configuration parameters + packaged as a kwargs dict. + + Args: + input: Input tensor to pool + kernel_size: Size of the pooling window + stride: Stride of the pooling window, defaults to kernel_size + padding: Padding added to both sides of the input + ceil_mode: When True, will use ceil instead of floor to compute the output shape + count_include_pad: When True, will include the zero-padding in the averaging calculation + divisor_override: If specified, will be used as divisor, otherwise kernel_size is used + *args: Additional positional args (unused) + **kwargs: Additional keyword args (unused) + + Returns: + Tuple containing: + - Input tensor + - Dict of pooling configuration parameters + """ + # Handle stride=None case (defaults to kernel_size) + if stride is None: + stride = kernel_size + + # Package all non-tensor parameters into a kwargs dictionary + return_kwargs = { + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "ceil_mode": ceil_mode, + "count_include_pad": count_include_pad, + } + + # Only add divisor_override if it's not None + if divisor_override is not None: + return_kwargs["divisor_override"] = divisor_override + + return input, return_kwargs + + +def generic_avg_pool_nd_wrapper(wrapped, instance, args, kwargs): + """Generic wrapper for torch N-dimensional pooling operations. + + Handles both regular torch.Tensor inputs and distributed ShardTensor inputs. + For regular tensors, passes through to the wrapped pooling function. + For ShardTensor inputs, handles applying distributed pooling. + + Args: + wrapped: Original pooling function being wrapped + instance: Instance the wrapped function is bound to + args: Positional arguments for pooling + kwargs: Keyword arguments for pooling + + Returns: + Pooling result as either torch.Tensor or ShardTensor + + Raises: + UndeterminedShardingError: If input tensor types are invalid + """ + + # Extract the input tensor and package the remaining arguments + input, pool_kwargs = repackage_pool_args(*args, **kwargs) + + # Handle regular torch tensor inputs + if type(input) == torch.Tensor: + return wrapped(*args, **kwargs) + + # Handle distributed ShardTensor inputs + elif type(input) == ShardTensor: + + # For pooling, the main challenge is to predict the output shape + + # Get the local tensor: + local_input = input.to_local() + + local_pooled_output = wrapped(local_input, **pool_kwargs) + + # Reject cases where stride != kernel_size + if pool_kwargs.get("stride") != pool_kwargs.get("kernel_size"): + raise MissingShardPatch( + "Stride must equal kernel_size for pooling operations" + ) + + # Check divisibility by stride only for sharded dimensions + stride = pool_kwargs.get("stride") + if isinstance(stride, int): + # Assuming channels first ... + stride = (stride,) * (len(local_input.shape) - 2) + + for mesh_dim, placement in enumerate(input._spec.placements): + if isinstance(placement, Shard): + # This dimension is sharded on this mesh dimension + shard_dim = placement.dim + # Skip batch and channel dimensions (first two dims) + if shard_dim >= 2: + spatial_dim = shard_dim - 2 # Convert to spatial dimension index + # Get the sizes for this mesh dimension + shard_shapes = input._spec.sharding_shapes()[mesh_dim] + for shard_shape in shard_shapes: + if ( + spatial_dim < len(shard_shape) - 2 + ): # Check if dimension is valid + spatial_size = shard_shape[shard_dim] + stride_for_dim = stride[spatial_dim] + if spatial_size % stride_for_dim != 0: + raise UndeterminedShardingError( + f"Sharded dimension {shard_dim} with local size {spatial_size} " + f"must be divisible by stride {stride_for_dim}" + ) + + # Compute the sharding shapes: + updated_placements = {} + for mesh_dim, shard_shapes in input._spec.sharding_shapes().items(): + updated_shard_shapes = [ + compute_output_shape(shard_shape, pool_kwargs) + for shard_shape in shard_shapes + ] + updated_placements[mesh_dim] = updated_shard_shapes + + output = ShardTensor.from_local( + local_pooled_output, + input._spec.mesh, + input._spec.placements, + sharding_shapes=updated_placements, + ) + return output + # Use the convolution args to compute the sharded halo + + else: + msg = ( + "input must be a valid type " + "(torch.Tensor or ShardTensor), but got " + f"{type(input)}" + ) + raise UndeterminedShardingError(msg) + + +@wrapt.patch_function_wrapper( + "torch.nn.functional", "max_pool3d", enabled=ShardTensor.patches_enabled +) +def max_pool3d_wrapper(wrapped, instance, args, kwargs): + return generic_max_pool_nd_wrapper(wrapped, instance, args, kwargs) + + +@wrapt.patch_function_wrapper( + "torch.nn.functional", "max_pool2d", enabled=ShardTensor.patches_enabled +) +def max_pool2d_wrapper(wrapped, instance, args, kwargs): + return generic_max_pool_nd_wrapper(wrapped, instance, args, kwargs) + + +@wrapt.patch_function_wrapper( + "torch.nn.functional", "max_pool1d", enabled=ShardTensor.patches_enabled +) +def max_pool1d_wrapper(wrapped, instance, args, kwargs): + return generic_max_pool_nd_wrapper(wrapped, instance, args, kwargs) + + +def repackage_max_pool_args( + input: Union[torch.Tensor, ShardTensor], + kernel_size: Union[int, Tuple[int, ...]], + stride: Union[int, Tuple[int, ...]] = None, + padding: Union[int, Tuple[int, ...]] = 0, + dilation: Union[int, Tuple[int, ...]] = 1, + ceil_mode: bool = False, + return_indices: bool = False, + *args, + **kwargs, +) -> Tuple[Union[torch.Tensor, ShardTensor], Dict[str, Any]]: + """Repackages max pooling arguments into standard format. + + Takes the full set of arguments that could be passed to a max_pool operation + and separates them into the input tensor and configuration parameters + packaged as a kwargs dict. + + Args: + input: Input tensor to pool + kernel_size: Size of the pooling window + stride: Stride of the pooling window, defaults to kernel_size + padding: Padding added to both sides of the input + dilation: Controls the spacing between kernel elements + ceil_mode: When True, will use ceil instead of floor to compute the output shape + return_indices: When True, returns indices of max locations along with outputs + *args: Additional positional args (unused) + **kwargs: Additional keyword args (unused) + + Returns: + Tuple containing: + - Input tensor + - Dict of pooling configuration parameters + """ + # Handle stride=None case (defaults to kernel_size) + if stride is None: + stride = kernel_size + + # Package all non-tensor parameters into a kwargs dictionary + return_kwargs = { + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "dilation": dilation, + "ceil_mode": ceil_mode, + "return_indices": return_indices, + } + + return input, return_kwargs + + +def generic_max_pool_nd_wrapper(wrapped, instance, args, kwargs): + """Generic wrapper for torch N-dimensional max pooling operations. + + Handles both regular torch.Tensor inputs and distributed ShardTensor inputs. + For regular tensors, passes through to the wrapped pooling function. + For ShardTensor inputs, handles applying distributed pooling. + + Args: + wrapped: Original pooling function being wrapped + instance: Instance the wrapped function is bound to + args: Positional arguments for pooling + kwargs: Keyword arguments for pooling + + Returns: + Pooling result as either torch.Tensor or ShardTensor (and indices if return_indices=True) + + Raises: + UndeterminedShardingError: If input tensor types are invalid + """ + + # Extract the input tensor and package the remaining arguments + input, pool_kwargs = repackage_max_pool_args(*args, **kwargs) + + # Handle regular torch tensor inputs + if type(input) == torch.Tensor: + return wrapped(*args, **kwargs) + + # Handle distributed ShardTensor inputs + elif type(input) == ShardTensor: + # Get the local tensor: + local_input = input.to_local() + + # Call the local pooling operation + local_pooled_output = wrapped(local_input, **pool_kwargs) + + # Handle return_indices case + return_indices = pool_kwargs.get("return_indices", False) + if return_indices: + local_pooled_output, indices = local_pooled_output + + # Everything below here is computing output meta data + + # Reject cases where stride != kernel_size + if pool_kwargs.get("stride") != pool_kwargs.get("kernel_size"): + raise MissingShardPatch( + "Stride must equal kernel_size for pooling operations" + ) + + # Check divisibility by stride only for sharded dimensions + stride = pool_kwargs.get("stride") + if isinstance(stride, int): + # Assuming channels first ... + stride = (stride,) * (len(local_input.shape) - 2) + + for mesh_dim, placement in enumerate(input._spec.placements): + if isinstance(placement, Shard): + # This dimension is sharded on this mesh dimension + shard_dim = placement.dim + # Skip batch and channel dimensions (first two dims) + if shard_dim >= 2: + spatial_dim = shard_dim - 2 # Convert to spatial dimension index + # Get the sizes for this mesh dimension + shard_shapes = input._spec.sharding_shapes()[mesh_dim] + for shard_shape in shard_shapes: + if ( + spatial_dim < len(shard_shape) - 2 + ): # Check if dimension is valid + spatial_size = shard_shape[shard_dim] + stride_for_dim = stride[spatial_dim] + if spatial_size % stride_for_dim != 0: + raise UndeterminedShardingError( + f"Sharded dimension {shard_dim} with local size {spatial_size} " + f"must be divisible by stride {stride_for_dim}" + ) + + # Compute the sharding shapes: + updated_placements = {} + for mesh_dim, shard_shapes in input._spec.sharding_shapes().items(): + updated_shard_shapes = [ + compute_output_shape(shard_shape, pool_kwargs) + for shard_shape in shard_shapes + ] + updated_placements[mesh_dim] = updated_shard_shapes + + output = ShardTensor.from_local( + local_pooled_output, + input._spec.mesh, + input._spec.placements, + sharding_shapes=updated_placements, + ) + + if return_indices: + # Also create a ShardTensor for indices with the same sharding + indices_output = ShardTensor.from_local( + indices, + input._spec.mesh, + input._spec.placements, + sharding_shapes=updated_placements, + ) + return output, indices_output + else: + return output + + else: + msg = ( + "input must be a valid type " + "(torch.Tensor or ShardTensor), but got " + f"{type(input)}" + ) + raise UndeterminedShardingError(msg) + + +# Write a function to extract the default args for avg_pool_nd + + +# This will become the future implementation, or similar. +# Why not today? Because the backwards pass in DTensor has an explicit (and insufficient) +# hard coded implementation for the backwards pass. +# When that switch happens, the order in the arg repackaging will need to be updated. +# ShardTensor.register_function_handler(aten.convolution.default, generic_conv_nd_wrapper) diff --git a/physicsnemo/distributed/shard_utils/ring.py b/physicsnemo/distributed/shard_utils/ring.py new file mode 100644 index 0000000000..6cafdb550d --- /dev/null +++ b/physicsnemo/distributed/shard_utils/ring.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Literal, Union + +import torch +import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh + + +@dataclass +class RingPassingConfig: + """Configuration for ring-based communication operations. + + This class encapsulates all parameters needed for ring communication patterns, + making it easier to pass consistent configurations between functions. + + Attributes: + mesh_dim (int): Mesh dimension for the ring communication + tensor_dim (int): Tensor dimension involved in the communication + mesh_size (int): Size of the mesh for this communication + communication_method (str): Method for exchanging data ("p2p" or "a2a") + """ + + VALID_COMM_METHODS = ["p2p", "a2a"] + VALID_RING_DIRECTIONS = ["forward", "backward"] + + mesh_dim: int + mesh_size: int + ring_direction: Literal["forward", "backward"] = "forward" + + communication_method: Literal["p2p", "a2a"] = "a2a" + + def __post_init__(self) -> None: + """Validate configuration parameters after initialization. + + Raises: + ValueError: If invalid communication method is specified + """ + + if self.communication_method not in self.VALID_COMM_METHODS: + raise ValueError( + f"Invalid communication method: {self.communication_method}. " + f"Must be one of {self.VALID_COMM_METHODS}" + ) + + if self.ring_direction not in self.VALID_RING_DIRECTIONS: + raise ValueError( + f"Invalid ring direction: {self.ring_direction}. " + f"Must be one of {self.VALID_RING_DIRECTIONS}" + ) + + +def perform_ring_iteration( + tensor: torch.Tensor, + mesh: DeviceMesh, + ring_config: RingPassingConfig, + recv_shape: Union[torch.Size, None] = None, +) -> torch.Tensor: + """ + Performs a ring collective communication where all tensors are the same size. + + Tensors are sent to the next rank in the ring, and wrap around from rank N-1 to rank 0. + This implements a single step of ring communication where each process sends data to + its neighbor and receives data from its other neighbor. + + Args: + tensor (torch.Tensor): The tensor to be sent in this ring communication step + mesh (DeviceMesh): Device mesh that defines the distributed process group + ring_config (RingPassingConfig): Configuration for the ring communication pattern + + Returns: + torch.Tensor: The tensor received from the previous rank in the ring + """ + + dtype = tensor.dtype + device = tensor.device + + # Get process group info + local_group = mesh.get_group(ring_config.mesh_dim) + local_rank = mesh.get_local_rank(ring_config.mesh_dim) + local_size = dist.get_world_size(group=local_group) + + # Point-to-point communication + local_id_for_send = local_rank + 1 if local_rank < local_size - 1 else 0 + local_id_for_recv = local_rank - 1 if local_rank > 0 else local_size - 1 + + id_for_send = dist.get_global_rank(group=local_group, group_rank=local_id_for_send) + id_for_recv = dist.get_global_rank(group=local_group, group_rank=local_id_for_recv) + + if ring_config.ring_direction == "reverse": + # Swap + id_for_send, id_for_recv = id_for_recv, id_for_send + + if recv_shape is None: + tensor_recv = torch.empty_like(tensor) + else: + tensor_recv = torch.empty(recv_shape, dtype=dtype, device=device) + + if ring_config.communication_method == "p2p": + + p2p_op_list = [] + torch.cuda.set_device(tensor.device) + + # Post receive + p2p_op_list.append( + dist.P2POp( + op=dist.irecv, + tensor=tensor_recv, + peer=id_for_recv, + group=local_group, + ) + ) + + # Post sends + p2p_op_list.append( + dist.P2POp( + op=dist.isend, + tensor=tensor.contiguous(), + peer=id_for_send, + group=local_group, + ) + ) + + # Ensure all communication completes + reqs = dist.batch_isend_irecv(p2p_op_list) + for req in reqs: + req.wait() + + elif ring_config.communication_method == "a2a": + # All-to-all communication + all_to_all_send = [ + torch.empty(0, dtype=dtype, device=device) for _ in range(local_size) + ] + all_to_all_recv = [ + torch.empty(0, dtype=dtype, device=device) for _ in range(local_size) + ] + + all_to_all_recv[id_for_recv] = tensor_recv + all_to_all_send[id_for_send] = tensor + + # Perform exchange + dist.all_to_all(all_to_all_recv, all_to_all_send, group=local_group) + + return tensor_recv diff --git a/physicsnemo/distributed/shard_utils/unpooling_patches.py b/physicsnemo/distributed/shard_utils/unpooling_patches.py new file mode 100644 index 0000000000..e64b3cb69a --- /dev/null +++ b/physicsnemo/distributed/shard_utils/unpooling_patches.py @@ -0,0 +1,334 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import torch +import wrapt +from torch.autograd.profiler import record_function +from torch.distributed.tensor.placement_types import Shard + +from physicsnemo.distributed import ShardTensor +from physicsnemo.distributed.shard_utils.halo import ( + HaloConfig, + halo_padding, + unhalo_padding, +) +from physicsnemo.distributed.shard_utils.patch_core import ( + UndeterminedShardingError, +) + +__all__ = [ + "interpolate_wrapper", +] + + +@wrapt.patch_function_wrapper( + "torch.nn.functional", "interpolate", enabled=ShardTensor.patches_enabled +) +def interpolate_wrapper( + wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] +) -> Union[torch.Tensor, ShardTensor]: + return generic_interpolate_wrapper(wrapped, instance, args, kwargs) + + +def repackage_interpolate_args( + input: Union[torch.Tensor, ShardTensor], + size: Optional[Union[int, Tuple[int, ...]]] = None, + scale_factor: Optional[Union[float, Tuple[float, ...]]] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + antialias: bool = False, + *args: Any, + **kwargs: Any, +) -> Tuple[Union[torch.Tensor, ShardTensor], Dict[str, Any]]: + """Repackages interpolation arguments into standard format. + + Args: + input: Input tensor to interpolate + size: Output spatial size + scale_factor: Multiplier for spatial size + mode: Algorithm used for upsampling + align_corners: Geometrically, whether corners are aligned + recompute_scale_factor: Whether to recompute scale_factor + antialias: Whether to use anti-aliasing + *args: Additional positional args (unused) + **kwargs: Additional keyword args (unused) + + Returns: + Tuple containing: + - Input tensor + - Dict of interpolation configuration parameters + """ + # Package all non-tensor parameters into a kwargs dictionary + return_kwargs = { + "size": size, + "scale_factor": scale_factor, + "mode": mode, + "align_corners": align_corners, + "recompute_scale_factor": recompute_scale_factor, + "antialias": antialias, + } + + return input, return_kwargs + + +def compute_interpolate_output_shape( + input_shape: Tuple[int, ...], interp_kwargs: Dict[str, Any] +) -> Tuple[int, ...]: + """Compute the output shape of an interpolation operation. + + Args: + input_shape: Shape of the input tensor + interp_kwargs: Keyword arguments for the interpolation operation + + Returns: + tuple: Output shape after interpolation operation + """ + size = interp_kwargs.get("size") + scale_factor = interp_kwargs.get("scale_factor") + + # Batch and channel dimensions remain unchanged + output_shape = list(input_shape[:2]) + + if size is not None: + # If size is provided, use it directly + if isinstance(size, int): + # If size is a single integer, use it for the last dimension only + output_shape.extend(list(input_shape[2:-1])) + output_shape.append(size) + else: + # If size is a sequence, it specifies output size for all spatial dimensions + output_shape.extend(list(size)) + elif scale_factor is not None: + # If scale_factor is provided, compute output sizes + if isinstance(scale_factor, (int, float)): + # Single scale factor for all spatial dimensions + spatial_dims = [int(dim * scale_factor) for dim in input_shape[2:]] + else: + # Separate scale factor for each spatial dimension + spatial_dims = [ + int(dim * scale_factor[i]) for i, dim in enumerate(input_shape[2:]) + ] + output_shape.extend(spatial_dims) + else: + # If neither is provided, output shape is the same as input + output_shape.extend(list(input_shape[2:])) + + return tuple(output_shape) + + +def compute_halo_sizes( + input_shape: Tuple[int, ...], + placements: Sequence[Any], + interp_kwargs: Dict[str, Any], +) -> Dict[int, Tuple[int, int]]: + """Compute the necessary halo sizes for different interpolation modes. + + Args: + input_shape: Shape of the input tensor + placements: Placements from the ShardTensor spec + interp_kwargs: Keyword arguments for the interpolation operation + + Returns: + dict: Halo sizes for each sharded dimension + """ + mode = interp_kwargs.get("mode", "nearest") + + # Default halo sizes based on interpolation mode + # For most modes, we need at least 1 element of overlap + default_halo = { + "nearest": 1, + "linear": 2, + "bilinear": 2, + "bicubic": 4, + "trilinear": 2, + "area": 1, + } + + halo_size = default_halo.get(mode, 1) + halo_sizes = {} + + # Determine which dimensions are sharded + for mesh_dim, placement in enumerate(placements): + if isinstance(placement, Shard): + shard_dim = placement.dim + # Only add halos for spatial dimensions (skip batch and channel) + if shard_dim >= 2: + # The halo might need to be asymmetric based on scale_factor + # This is a simplified implementation; might need refinement + halo_sizes[shard_dim] = (halo_size, halo_size) + + return halo_sizes + + +def compute_halo_configs_from_interpolate_args( + input: ShardTensor, + interp_kwargs: Dict[str, Any], +) -> List[HaloConfig]: + """Compute halo configurations for a sharded tensor based on interpolation arguments. + + Args: + input: The sharded tensor that will be used in interpolation + interp_kwargs: Dictionary of interpolation arguments including mode, size, + scale_factor, etc. + + Returns: + List of HaloConfig objects for each sharded dimension + """ + # Get the placements from the input tensor's spec + placements = input._spec.placements + + # Compute halo sizes using the existing function + halo_sizes = compute_halo_sizes(input.shape, placements, interp_kwargs) + + # Create halo configs from the computed sizes + halo_configs = [] + + for tensor_dim, (left_halo, right_halo) in halo_sizes.items(): + # Find which mesh dimension this tensor dimension is sharded on + for mesh_dim, p in enumerate(placements): + if isinstance(p, Shard) and p.dim == tensor_dim: + # Create a halo config for this dimension + halo_configs.append( + HaloConfig( + mesh_dim=mesh_dim, + tensor_dim=tensor_dim, + halo_size=max( + left_halo, right_halo + ), # Using max as a simplification + edge_padding_size=0, # No explicit padding in interpolation + communication_method="a2a", + ) + ) + break + + return halo_configs + + +def partial_interpolate_nd( + input: ShardTensor, + interp_kwargs: Dict[str, Any], +) -> ShardTensor: + """Perform a convolution on a sharded tensor with halo exchange. + + This high-level, differentiable function computes a convolution on a sharded tensor + by performing these steps: + 1. Calculate the size of halos needed + 2. Apply halo padding (differentiable) + 3. Perform convolution on the padded tensor with padding=0 on sharded dimensions + 4. Return the result as a ShardTensor + + Args: + input: The sharded input tensor + weight: Convolution filter weights + bias: Optional bias parameter + conv_kwargs: Dictionary of convolution parameters (stride, padding, etc.) + + Returns: + Resulting ShardTensor after convolution operation + """ + + # This will produce one config per sharded dim + # It also *updates* conv_kwargs in place to set padding to 0 on the sharded dims + halo_configs = compute_halo_configs_from_interpolate_args(input, interp_kwargs) + + local_input = input.to_local() + + # Apply the halo padding to the input tensor + for halo_config in halo_configs: + local_input = halo_padding(local_input, input._spec.mesh, halo_config) + + unhalo_configs = [] + for h in halo_configs: + unhalo_configs.append( # noqa PERF401 + HaloConfig( + mesh_dim=h.mesh_dim, + tensor_dim=h.tensor_dim, + halo_size=int(interp_kwargs["scale_factor"] * h.halo_size), + edge_padding_size=h.edge_padding_size, + communication_method=h.communication_method, + ) + ) + + # Perform the convolution on the padded tensor + output = torch.nn.functional.interpolate(local_input, **interp_kwargs) + + # Remove halos and convert back to ShardTensor + # x = UnSliceHaloND.apply(x, halo, q._spec) + for halo_config in unhalo_configs: + output = unhalo_padding(output, input._spec.mesh, halo_config) + + result_shapes = {} + for mesh_dim, sharding_shape in input._spec.sharding_shapes().items(): + updated_shapes = tuple( + torch.Size(compute_interpolate_output_shape(s, interp_kwargs)) + for s in sharding_shape + ) + result_shapes[mesh_dim] = updated_shapes + + with record_function("upsampling.from_local"): + # Convert back to ShardTensor + output = ShardTensor.from_local( + output, + input._spec.mesh, + input._spec.placements, + sharding_shapes=result_shapes, + ) + + return output + + +def generic_interpolate_wrapper( + wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] +) -> Union[torch.Tensor, ShardTensor]: + """Wrapper for torch.nn.functional.interpolate. + + Handles both regular torch.Tensor inputs and distributed ShardTensor inputs. + For regular tensors, passes through to the wrapped interpolation function. + For ShardTensor inputs, handles distributed interpolation with halo exchanges. + + Args: + wrapped: Original interpolation function being wrapped + instance: Instance the wrapped function is bound to + args: Positional arguments for interpolation + kwargs: Keyword arguments for interpolation + + Returns: + Interpolation result as either torch.Tensor or ShardTensor + + Raises: + UndeterminedShardingError: If input tensor types are invalid + """ + # Extract the input tensor and package the remaining arguments + input, interp_kwargs = repackage_interpolate_args(*args, **kwargs) + + # Handle regular torch tensor inputs + if type(input) == torch.Tensor: + return wrapped(*args, **kwargs) + + # Handle distributed ShardTensor inputs + elif type(input) == ShardTensor: + # Use the convolution args to compute the sharded halo + return partial_interpolate_nd(input, interp_kwargs) + else: + msg = ( + "input must be a valid type " + "(torch.Tensor or ShardTensor), but got " + f"{type(input)}" + ) + raise UndeterminedShardingError(msg) diff --git a/physicsnemo/models/domino/model.py b/physicsnemo/models/domino/model.py index 5eb2b890a1..258dfadc39 100644 --- a/physicsnemo/models/domino/model.py +++ b/physicsnemo/models/domino/model.py @@ -29,17 +29,34 @@ import torch.nn.functional as F from physicsnemo.models.layers.ball_query import BallQueryLayer +from physicsnemo.utils.profiling import profile def fourier_encode(coords, num_freqs): """Function to caluculate fourier features""" # Create a range of frequencies - freqs = torch.exp(torch.linspace(0, math.pi, num_freqs)) + freqs = torch.exp(torch.linspace(0, math.pi, num_freqs, device=coords.device)) # Generate sine and cosine features features = [torch.sin(coords * f) for f in freqs] + [ torch.cos(coords * f) for f in freqs ] - return torch.cat(features, dim=-1) + ret = torch.cat(features, dim=-1) + return ret + + +def fourier_encode_vectorized(coords, freqs): + """Vectorized Fourier feature encoding""" + D = coords.shape[-1] + F = freqs.shape[0] + + # freqs = torch.exp(torch.linspace(0, math.pi, num_freqs, device=coords.device)) # [F] + freqs = freqs[None, None, :, None] # reshape to [*, F, 1] for broadcasting + + coords = coords.unsqueeze(-2) # [*, 1, D] + scaled = (coords * freqs).reshape(*coords.shape[:-2], D * F) # [*, D, F] + features = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1) # [*, D, 2F] + + return features.reshape(*coords.shape[:-2], D * 2 * F) # [*, D * 2F] def calculate_pos_encoding(nx, d=8): @@ -96,6 +113,7 @@ def __init__( self.ball_query_layer = BallQueryLayer(neighbors_in_radius, radius) self.grid_resolution = grid_resolution + @profile def forward( self, x: torch.Tensor, p_grid: torch.Tensor, reverse_mapping: bool = True ) -> tuple[torch.Tensor, torch.Tensor]: @@ -125,26 +143,16 @@ def forward( nx, ny, nz = self.grid_resolution p_grid = torch.reshape(p_grid, (batch_size, nx * ny * nz, 3)) - # p1 = nx * ny * nz - # p2 = x.shape[1] if reverse_mapping: - # lengths1 = torch.full((batch_size,), p1, dtype=torch.int32) - # lengths2 = torch.full((batch_size,), p2, dtype=torch.int32) mapping, num_neighbors, outputs = self.ball_query_layer( p_grid, x, - # lengths1, - # lengths2, ) else: - # lengths1 = torch.full((batch_size,), p2, dtype=torch.int32) - # lengths2 = torch.full((batch_size,), p1, dtype=torch.int32) mapping, num_neighbors, outputs = self.ball_query_layer( x, p_grid, - # lengths1, - # lengths2, ) return mapping, outputs @@ -182,6 +190,7 @@ def __init__( self.activation = F.relu + @profile def forward( self, x: torch.Tensor, radius: float = 0.025, neighbors_in_radius: int = 10 ) -> torch.Tensor: @@ -252,6 +261,7 @@ def __init__(self, input_filters: int, model_parameters): self.upsample = nn.Upsample(scale_factor=2, mode="nearest") self.activation = F.relu + @profile def forward(self, x: torch.Tensor) -> torch.Tensor: """ Process geometry information through the 3D CNN network. @@ -367,6 +377,7 @@ def __init__(self, input_features: int, radii, model_parameters=None): self.radii = radii self.hops = geometry_rep.geo_conv.hops + @profile def forward( self, x: torch.Tensor, p_grid: torch.Tensor, sdf: torch.Tensor ) -> torch.Tensor: @@ -442,12 +453,17 @@ def __init__(self, input_features: int, model_parameters=None): self.fc1 = nn.Linear(input_features_calculated, base_layer) self.fc2 = nn.Linear(base_layer, int(base_layer)) self.fc3 = nn.Linear(int(base_layer), int(base_layer)) - self.bn1 = nn.BatchNorm1d(base_layer) - self.bn2 = nn.BatchNorm1d(int(base_layer)) - self.bn3 = nn.BatchNorm1d(int(base_layer)) + # self.bn1 = nn.BatchNorm1d(base_layer) + # self.bn2 = nn.BatchNorm1d(int(base_layer)) + # self.bn3 = nn.BatchNorm1d(int(base_layer)) self.activation = F.relu + self.register_buffer( + "freqs", torch.exp(torch.linspace(0, math.pi, self.num_modes)) + ) + + @profile def forward(self, x: torch.Tensor) -> torch.Tensor: """ Transform point features into a basis function representation. @@ -459,7 +475,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Tensor containing basis function coefficients """ if self.fourier_features: - facets = torch.cat((x, fourier_encode(x, self.num_modes)), axis=-1) + # facets = torch.cat((x, fourier_encode(x, self.num_modes)), axis=-1) + facets = torch.cat((x, fourier_encode_vectorized(x, self.freqs)), axis=-1) else: facets = x facets = self.activation(self.fc1(facets)) @@ -507,6 +524,7 @@ def __init__(self, input_features: int, model_parameters=None): self.activation = F.relu + @profile def forward(self, x: torch.Tensor) -> torch.Tensor: """ Encode physical parameters into a latent representation. @@ -563,10 +581,10 @@ def __init__( self.fc3 = nn.Linear(int(base_layer), int(base_layer)) self.fc4 = nn.Linear(int(base_layer), int(base_layer)) self.fc5 = nn.Linear(int(base_layer), self.output_features) - self.bn1 = nn.BatchNorm1d(base_layer) - self.bn2 = nn.BatchNorm1d(int(base_layer)) - self.bn3 = nn.BatchNorm1d(int(base_layer)) - self.bn4 = nn.BatchNorm1d(int(base_layer)) + # self.bn1 = nn.BatchNorm1d(base_layer) + # self.bn2 = nn.BatchNorm1d(int(base_layer)) + # self.bn3 = nn.BatchNorm1d(int(base_layer)) + # self.bn4 = nn.BatchNorm1d(int(base_layer)) self.activation = F.relu def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -611,6 +629,7 @@ def __init__( self.fc2 = nn.Linear(base_layer, self.output_features) self.activation = F.relu + @profile def forward(self, x): out = self.activation(self.fc1(x)) out = self.fc2(out) @@ -762,7 +781,7 @@ def __init__( raise ValueError( "At least one of `output_features_vol` or `output_features_surf` must be specified" ) - + self.solution_calculation_mode = model_parameters.solution_calculation_mode self.num_variables_vol = output_features_vol self.num_variables_surf = output_features_surf self.grid_resolution = model_parameters.interp_res @@ -983,6 +1002,7 @@ def __init__( ) ) + @profile def position_encoder( self, encoding_node: torch.Tensor, @@ -1010,6 +1030,7 @@ def position_encoder( x = self.fc_p2(x) return x + @profile def geo_encoding_local( self, encoding_g, volume_mesh_centers, p_grid, mode="volume" ): @@ -1131,42 +1152,116 @@ def calculate_solution_with_neighbors( axis=-1, ) - for f in range(num_variables): - for p in range(num_sample_points): - if p == 0: - volume_m_c = surface_mesh_centers - else: - volume_m_c = surface_mesh_neighbors[:, :, p - 1] + 1e-6 - noise = surface_mesh_centers - volume_m_c - dist = torch.sqrt( - noise[:, :, 0:1] ** 2.0 - + noise[:, :, 1:2] ** 2.0 - + noise[:, :, 2:3] ** 2.0 + if ( + self.solution_calculation_mode == "one-loop" + or self.solution_calculation_mode == "compare" + ): + encoding_list = [ + encoding_node.unsqueeze(2).expand(-1, -1, num_sample_points, -1), + encoding_g.unsqueeze(2).expand(-1, -1, num_sample_points, -1), + ] + + for f in range(num_variables): + + one_loop_centers_expanded = surface_mesh_centers.unsqueeze(2) + + one_loop_noise = one_loop_centers_expanded - ( + surface_mesh_neighbors + 1e-6 + ) + one_loop_noise = torch.norm(one_loop_noise, dim=-1, keepdim=True) + + # Doing it this way prevents the intermediate one_loop_basis_f from being stored in memory for the rest of the function. + agg_output = agg_model[f]( + torch.cat( + ( + nn_basis[f]( + torch.cat( + ( + one_loop_centers_expanded, + surface_mesh_neighbors + 1e-6, + ), + axis=2, + ) + ), + *encoding_list, + ), + axis=-1, + ) + ) + + one_loop_output_center, one_loop_output_neighbor = torch.split( + agg_output, [1, num_sample_points - 1], dim=2 + ) + one_loop_output_neighbor = one_loop_output_neighbor * ( + 1.0 / one_loop_noise + ) + + one_loop_output_center = one_loop_output_center.squeeze(2) + one_loop_output_neighbor = one_loop_output_neighbor.sum(2) + one_loop_dist_sum = torch.sum(1.0 / one_loop_noise, dim=2) + + # Stop here + if num_sample_points > 1: + one_loop_output_res = ( + 0.5 * one_loop_output_center + + 0.5 * one_loop_output_neighbor / one_loop_dist_sum ) - basis_f = nn_basis[f](volume_m_c) - output = torch.cat((basis_f, encoding_node, encoding_g), axis=-1) - if self.encode_parameters: - output = torch.cat((output, param_encoding), axis=-1) - if p == 0: - output_center = agg_model[f](output) else: - if p == 1: - output_neighbor = agg_model[f](output) * (1.0 / dist) - dist_sum = 1.0 / dist + one_loop_output_res = one_loop_output_center + if f == 0: + one_loop_output_all = one_loop_output_res + else: + one_loop_output_all = torch.cat( + (one_loop_output_all, one_loop_output_res), axis=-1 + ) + + if self.solution_calculation_mode != "compare": + return one_loop_output_all + + if ( + self.solution_calculation_mode == "two-loop" + or self.solution_calculation_mode == "compare" + ): + for f in range(num_variables): + for p in range(num_sample_points): + if p == 0: + volume_m_c = surface_mesh_centers else: - output_neighbor += agg_model[f](output) * (1.0 / dist) - dist_sum += 1.0 / dist - if num_sample_points > 1: - output_res = 0.5 * output_center + 0.5 * output_neighbor / dist_sum - else: - output_res = output_center - if f == 0: - output_all = output_res - else: - output_all = torch.cat((output_all, output_res), axis=-1) + volume_m_c = surface_mesh_neighbors[:, :, p - 1] + 1e-6 + noise = surface_mesh_centers - volume_m_c + dist = torch.norm(noise, dim=-1, keepdim=True) + + basis_f = nn_basis[f](volume_m_c) + output = torch.cat((basis_f, encoding_node, encoding_g), axis=-1) + if self.encode_parameters: + output = torch.cat((output, param_encoding), axis=-1) + if p == 0: + output_center = agg_model[f](output) + else: + if p == 1: + output_neighbor = agg_model[f](output) * (1.0 / dist) + dist_sum = 1.0 / dist + else: + output_neighbor += agg_model[f](output) * (1.0 / dist) + dist_sum += 1.0 / dist + if num_sample_points > 1: + output_res = 0.5 * output_center + 0.5 * output_neighbor / dist_sum + else: + output_res = output_center + if f == 0: + output_all = output_res + else: + output_all = torch.cat((output_all, output_res), axis=-1) + if self.solution_calculation_mode != "compare": + return output_all + if self.solution_calculation_mode == "compare": + print( + f"NEIGHBORS: 2-loop vs. 1-loop Agreement? {torch.allclose(one_loop_output_all, output_all)}" + ) return output_all + # @to_local_tensors def calculate_solution( self, volume_mesh_centers, @@ -1206,44 +1301,171 @@ def calculate_solution( params = torch.cat((inlet_velocity, air_density), axis=-1) param_encoding = self.parameter_model(params) - for f in range(num_variables): - for p in range(num_sample_points): - if p == 0: - volume_m_c = volume_mesh_centers - else: - noise = torch.rand_like(volume_mesh_centers) - noise = 2 * (noise - 0.5) - noise = noise / noise_intensity - dist = torch.sqrt( - noise[:, :, 0:1] ** 2.0 - + noise[:, :, 1:2] ** 2.0 - + noise[:, :, 2:3] ** 2.0 + if self.solution_calculation_mode == "compare": + full_random_noise = torch.rand( + ( + num_variables, + num_sample_points, + ) + + tuple(volume_mesh_centers.shape), + dtype=volume_mesh_centers.dtype, + device=volume_mesh_centers.device, + ) + + if ( + self.solution_calculation_mode == "one-loop" + or self.solution_calculation_mode == "compare" + ): + + # Stretch these out to num_sample_points + one_loop_encoding_node = encoding_node.unsqueeze(0).expand( + num_sample_points, -1, -1, -1 + ) + one_loop_encoding_g = encoding_g.unsqueeze(0).expand( + num_sample_points, -1, -1, -1 + ) + + if self.encode_parameters: + one_loop_other_terms = ( + one_loop_encoding_node, + one_loop_encoding_g, + param_encoding, + ) + else: + one_loop_other_terms = (one_loop_encoding_node, one_loop_encoding_g) + + for f in range(num_variables): + + if self.solution_calculation_mode == "one-loop": + one_loop_volume_mesh_centers_expanded = ( + volume_mesh_centers.unsqueeze(0).expand( + num_sample_points, -1, -1, -1 + ) + ) + # Bulk_random_noise has shape (num_sample_points, batch_size, num_points, 3) + one_loop_bulk_random_noise = torch.rand_like( + one_loop_volume_mesh_centers_expanded + ) + + elif self.solution_calculation_mode == "compare": + one_loop_bulk_random_noise = full_random_noise[f] + + one_loop_bulk_random_noise = 2 * (one_loop_bulk_random_noise - 0.5) + one_loop_bulk_random_noise = ( + one_loop_bulk_random_noise / noise_intensity + ) + one_loop_bulk_dist = torch.norm( + one_loop_bulk_random_noise, dim=-1, keepdim=True + ) + + _, one_loop_bulk_dist = torch.split( + one_loop_bulk_dist, [1, num_sample_points - 1], dim=0 + ) + + # Set the first sample point to 0.0: + one_loop_bulk_random_noise[0] = torch.zeros_like( + one_loop_bulk_random_noise[0] + ) + + # Add the noise to the expanded volume_mesh_centers: + one_loop_volume_m_c = volume_mesh_centers + one_loop_bulk_random_noise + # If this looks overly complicated - it is. + # But, this makes sure that the memory used to store the output of both nn_basis[f] + # as well as the output of torch.cat can be deallocated immediately. + # Apply the aggregation model and distance scaling: + one_loop_output = agg_model[f]( + torch.cat( + (nn_basis[f](one_loop_volume_m_c), *one_loop_other_terms), + axis=-1, ) - volume_m_c = volume_mesh_centers + noise - basis_f = nn_basis[f](volume_m_c) - output = torch.cat((basis_f, encoding_node, encoding_g), axis=-1) - if self.encode_parameters: - output = torch.cat((output, param_encoding), axis=-1) - if p == 0: - output_center = agg_model[f](output) + ) + + # select off the first, unperturbed term: + one_loop_output_center, one_loop_output_neighbor = torch.split( + one_loop_output, [1, num_sample_points - 1], dim=0 + ) + + # Scale the neighbor terms by the distance: + one_loop_output_neighbor = one_loop_output_neighbor / one_loop_bulk_dist + + one_loop_dist_sum = torch.sum(1.0 / one_loop_bulk_dist, dim=0) + + # Adjust shapes: + one_loop_output_center = one_loop_output_center.squeeze(1) + one_loop_output_neighbor = one_loop_output_neighbor.sum(0) + + # Compare: + if num_sample_points > 1: + one_loop_output_res = ( + 0.5 * one_loop_output_center + + 0.5 * one_loop_output_neighbor / one_loop_dist_sum + ) + else: + one_loop_output_res = one_loop_output_center + if f == 0: + one_loop_output_all = one_loop_output_res else: - if p == 1: - output_neighbor = agg_model[f](output) * (1.0 / dist) - dist_sum = 1.0 / dist + one_loop_output_all = torch.cat( + (one_loop_output_all, one_loop_output_res), axis=-1 + ) + + if self.solution_calculation_mode != "compare": + return one_loop_output_all + + if ( + self.solution_calculation_mode == "two-loop" + or self.solution_calculation_mode == "compare" + ): + + for f in range(num_variables): + for p in range(num_sample_points): + if p == 0: + volume_m_c = volume_mesh_centers else: - output_neighbor += agg_model[f](output) * (1.0 / dist) - dist_sum += 1.0 / dist - if num_sample_points > 1: - output_res = 0.5 * output_center + 0.5 * output_neighbor / dist_sum - else: - output_res = output_center - if f == 0: - output_all = output_res - else: - output_all = torch.cat((output_all, output_res), axis=-1) + if self.solution_calculation_mode == "two-loop": + noise = torch.rand_like(volume_mesh_centers) + elif self.solution_calculation_mode == "compare": + # Reuse the bulk random noise for precise comparison + noise = full_random_noise[f, p] + noise = 2 * (noise - 0.5) + noise = noise / noise_intensity + dist = torch.norm(noise, dim=-1, keepdim=True) + + volume_m_c = volume_mesh_centers + noise + # print(f"volume_m_c shape: {volume_m_c.shape}") + basis_f = nn_basis[f](volume_m_c) + output = torch.cat((basis_f, encoding_node, encoding_g), axis=-1) + if self.encode_parameters: + output = torch.cat((output, param_encoding), axis=-1) + if p == 0: + output_center = agg_model[f](output) + else: + if p == 1: + output_neighbor = agg_model[f](output) * (1.0 / dist) + dist_sum = 1.0 / dist + else: + output_neighbor += agg_model[f](output) * (1.0 / dist) + dist_sum += 1.0 / dist + if num_sample_points > 1: + output_res = 0.5 * output_center + 0.5 * output_neighbor / dist_sum + else: + output_res = output_center + if f == 0: + output_all = output_res + else: + output_all = torch.cat((output_all, output_res), axis=-1) + + if self.solution_calculation_mode == "two-loop": + return output_all + + if self.solution_calculation_mode == "compare": + print( + f"STANDARD: 2-loop vs. 1-loop Agreement? {torch.allclose(one_loop_output_all, output_all)}" + ) return output_all + @profile def forward( self, data_dict, diff --git a/physicsnemo/models/layers/ball_query.py b/physicsnemo/models/layers/ball_query.py index cc07ab861f..059c15c6cb 100644 --- a/physicsnemo/models/layers/ball_query.py +++ b/physicsnemo/models/layers/ball_query.py @@ -14,251 +14,293 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple import torch import warp as wp -class BallQuery(torch.autograd.Function): +@wp.kernel +def ball_query( + points1: wp.array(dtype=wp.vec3), + points2: wp.array(dtype=wp.vec3), + grid: wp.uint64, + k: wp.int32, + radius: wp.float32, + mapping: wp.array3d(dtype=wp.int32), + num_neighbors: wp.array2d(dtype=wp.int32), +): """ - Warp based Ball Query. - - Note: only differentiable with respect to points1 and points2. - """ - - @wp.kernel - def ball_query( - points1: wp.array(dtype=wp.vec3), - points2: wp.array(dtype=wp.vec3), - grid: wp.uint64, - k: wp.int32, - radius: wp.float32, - mapping: wp.array3d(dtype=wp.int32), - num_neighbors: wp.array2d(dtype=wp.int32), - ): - """ - Performs ball query operation to find neighboring points within a specified radius. - - For each point in points1, finds up to k neighboring points from points2 that are - within the specified radius. Uses a hash grid for efficient spatial queries. - - Note that the neighbors found are not strictly guaranteed to be the closest k neighbors, - in the event that more than k neighbors are found within the radius. - - Args: - points1: Array of query points - points2: Array of points to search - grid: Pre-computed hash grid for accelerated spatial queries - k: Maximum number of neighbors to find for each query point - radius: Maximum search radius for finding neighbors - mapping: Output array to store indices of neighboring points. Should be instantiated as zeros(1, len(points1), k) - num_neighbors: Output array to store the number of neighbors found for each query point. Should be instantiated as zeros(1, len(points1)) - """ - tid = wp.tid() - - # Get position from points1 - pos = points1[tid] - - # particle contact - neighbors = wp.hash_grid_query(id=grid, point=pos, max_dist=radius) - - # Keep track of the number of neighbors found - neighbors_found = wp.int32(0) - - # loop through neighbors to compute density - for index in neighbors: - # Check if outside the radius - pos2 = points2[index] - if wp.length(pos - pos2) > radius: - continue - - # Add neighbor to the list - mapping[0, tid, neighbors_found] = index - - # Increment the number of neighbors found - neighbors_found += 1 + Performs ball query operation to find neighboring points within a specified radius. - # Break if we have found enough neighbors - if neighbors_found == k: - num_neighbors[0, tid] = k - break + For each point in points1, finds up to k neighboring points from points2 that are + within the specified radius. Uses a hash grid for efficient spatial queries. - # Set the number of neighbors - num_neighbors[0, tid] = neighbors_found + Note that the neighbors found are not strictly guaranteed to be the closest k neighbors, + in the event that more than k neighbors are found within the radius. - @wp.kernel - def sparse_ball_query( - points2: wp.array(dtype=wp.vec3), - mapping: wp.array3d(dtype=wp.int32), - num_neighbors: wp.array2d(dtype=wp.int32), - outputs: wp.array4d(dtype=wp.float32), - ): - tid = wp.tid() - - # Get number of neighbors - k = num_neighbors[0, tid] + Args: + points1: Array of query points + points2: Array of points to search + grid: Pre-computed hash grid for accelerated spatial queries + k: Maximum number of neighbors to find for each query point + radius: Maximum search radius for finding neighbors + mapping: Output array to store indices of neighboring points. Should be instantiated as zeros(1, len(points1), k) + num_neighbors: Output array to store the number of neighbors found for each query point. Should be instantiated as zeros(1, len(points1)) + """ + tid = wp.tid() + + # Get position from points1 + pos = points1[tid] + + # particle contact + neighbors = wp.hash_grid_query(id=grid, point=pos, max_dist=radius) + + # Keep track of the number of neighbors found + neighbors_found = wp.int32(0) + + # loop through neighbors to compute density + for index in neighbors: + # Check if outside the radius + pos2 = points2[index] + if wp.length(pos - pos2) > radius: + continue + + # Add neighbor to the list + mapping[0, tid, neighbors_found] = index + + # Increment the number of neighbors found + neighbors_found += 1 + + # Break if we have found enough neighbors + if neighbors_found == k: + num_neighbors[0, tid] = k + break + + # Set the number of neighbors + num_neighbors[0, tid] = neighbors_found + + +@wp.kernel +def sparse_ball_query( + points2: wp.array(dtype=wp.vec3), + mapping: wp.array3d(dtype=wp.int32), + num_neighbors: wp.array2d(dtype=wp.int32), + outputs: wp.array4d(dtype=wp.float32), +): + tid = wp.tid() + + # Get number of neighbors + k = num_neighbors[0, tid] + + # Loop through neighbors + for _k in range(k): + # Get point2 index + index = mapping[0, tid, _k] + + # Get position from points2 + pos = points2[index] + + # Set the output + outputs[0, tid, _k, 0] = pos[0] + outputs[0, tid, _k, 1] = pos[1] + outputs[0, tid, _k, 2] = pos[2] + + +def _ball_query_forward_primative_( + points1: torch.Tensor, + points2: torch.Tensor, + k: int, + radius: float, + hash_grid: wp.HashGrid, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # Create output tensors: + mapping = torch.zeros( + (1, points1.shape[0], k), + dtype=torch.int32, + device=points1.device, + requires_grad=False, + ) + num_neighbors = torch.zeros( + (1, points1.shape[0]), + dtype=torch.int32, + device=points1.device, + requires_grad=False, + ) + outputs = torch.zeros( + (1, points1.shape[0], k, 3), + dtype=torch.float32, + device=points1.device, + requires_grad=(points1.requires_grad or points2.requires_grad), + ) + + # Convert from torch to warp + points1 = wp.from_torch(points1, dtype=wp.vec3, requires_grad=points1.requires_grad) + points2 = wp.from_torch(points2, dtype=wp.vec3, requires_grad=points2.requires_grad) + + wp_mapping = wp.from_torch(mapping, dtype=wp.int32, requires_grad=False) + wp_num_neighbors = wp.from_torch(num_neighbors, dtype=wp.int32, requires_grad=False) + wp_outputs = wp.from_torch( + outputs, + dtype=wp.float32, + requires_grad=(points1.requires_grad or points2.requires_grad), + ) + + # Build the grid + hash_grid.build(points2, radius) + + # Run the kernel to get mapping + wp.launch( + ball_query, + inputs=[ + points1, + points2, + hash_grid.id, + k, + radius, + ], + outputs=[ + wp_mapping, + wp_num_neighbors, + ], + dim=[points1.shape[0]], + ) + + # Run the kernel to get outputs + wp.launch( + sparse_ball_query, + inputs=[ + points2, + wp_mapping, + wp_num_neighbors, + ], + outputs=[ + wp_outputs, + ], + dim=[points1.shape[0]], + ) + + return mapping, num_neighbors, outputs + + +def _ball_query_backward_primative_( + points1, + points2, + mapping, + num_neighbors, + outputs, + grad_mapping, + grad_num_neighbors, + grad_outputs, +) -> Tuple[torch.Tensor, torch.Tensor]: + + p2_grad = torch.zeros_like(points2) + + # Run the kernel in adjoint mode + wp.launch( + sparse_ball_query, + inputs=[ + wp.from_torch(points2, dtype=wp.vec3, requires_grad=points2.requires_grad), + wp.from_torch(mapping, dtype=wp.int32, requires_grad=False), + wp.from_torch(num_neighbors, dtype=wp.int32, requires_grad=False), + ], + outputs=[ + wp.from_torch(outputs, dtype=wp.float32, requires_grad=False), + ], + adj_inputs=[ + wp.from_torch(p2_grad, dtype=wp.vec3, requires_grad=points2.requires_grad), + wp.from_torch( + grad_mapping, dtype=wp.int32, requires_grad=mapping.requires_grad + ), + wp.from_torch( + grad_num_neighbors, + dtype=wp.int32, + requires_grad=num_neighbors.requires_grad, + ), + ], + adj_outputs=[ + wp.from_torch(grad_outputs, dtype=wp.float32), + ], + dim=[points1.shape[0]], + adjoint=True, + ) + + return p2_grad - # Loop through neighbors - for _k in range(k): - # Get point2 index - index = mapping[0, tid, _k] - # Get position from points2 - pos = points2[index] +class BallQuery(torch.autograd.Function): + """ + Warp based Ball Query. - # Set the output - outputs[0, tid, _k, 0] = pos[0] - outputs[0, tid, _k, 1] = pos[1] - outputs[0, tid, _k, 2] = pos[2] + Note: only differentiable with respect to points1 and points2. + """ @staticmethod def forward( ctx, points1: torch.Tensor, points2: torch.Tensor, - # lengths1: torch.Tensor, - # lengths2: torch.Tensor, k: int, radius: float, hash_grid: wp.HashGrid, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Only works for batch size 1 if points1.shape[0] != 1: - raise AssertionError("nly works for batch size 1") + raise AssertionError("Ball Query only works for batch size 1") - # Convert from torch to warp - ctx.points1 = wp.from_torch( - points1[0], dtype=wp.vec3, requires_grad=points1.requires_grad - ) - ctx.points2 = wp.from_torch( - points2[0], dtype=wp.vec3, requires_grad=points2.requires_grad - ) - # ctx.lengths1 = wp.from_torch(lengths1, dtype=wp.int32, requires_grad=False) - # ctx.lengths2 = wp.from_torch(lengths2, dtype=wp.int32, requires_grad=False) ctx.k = k ctx.radius = radius - # Allocate the mapping and outputs - mapping = torch.zeros( - 1, - points1.shape[1], - k, - dtype=torch.int32, - device="cuda", - requires_grad=False, - ) - ctx.mapping = wp.from_torch(mapping, dtype=wp.int32, requires_grad=False) - num_neighbors = torch.zeros( - 1, - points1.shape[1], - dtype=torch.int32, - device="cuda", - requires_grad=False, - ) - ctx.num_neighbors = wp.from_torch( - num_neighbors, dtype=wp.int32, requires_grad=False - ) - outputs = torch.zeros( - 1, - points1.shape[1], - k, - 3, - dtype=torch.float32, - device="cuda", - requires_grad=(points1.requires_grad or points2.requires_grad), - ) - ctx.outputs = wp.from_torch( - outputs, - dtype=wp.float32, - requires_grad=(points1.requires_grad or points2.requires_grad), - ) - # Make grid ctx.hash_grid = hash_grid - # Build the grid - ctx.hash_grid.build(ctx.points2, radius) - - ctx.dim = [ctx.points1.shape[0]] - - # Run the kernel to get mapping - wp.launch( - BallQuery.ball_query, - inputs=[ - ctx.points1, - ctx.points2, - ctx.hash_grid.id, - k, - radius, - ], - outputs=[ - ctx.mapping, - ctx.num_neighbors, - ], - dim=ctx.dim, - ) - - # Run the kernel to get outputs - wp.launch( - BallQuery.sparse_ball_query, - inputs=[ - ctx.points2, - ctx.mapping, - ctx.num_neighbors, - ], - outputs=[ - ctx.outputs, - ], - dim=ctx.dim, + # Apply the primitive. Note the batch index is removed. + mapping, num_neighbors, outputs = _ball_query_forward_primative_( + points1[0], + points2[0], + k, + radius, + hash_grid, ) + ctx.save_for_backward(points1, points2, mapping, num_neighbors, outputs) - return ( - wp.to_torch(ctx.mapping), - wp.to_torch(ctx.num_neighbors), - wp.to_torch(ctx.outputs), - ) + return mapping, num_neighbors, outputs @staticmethod - def backward( - ctx, - grad_mapping: torch.Tensor, - grad_num_neighbors: torch.Tensor, - grad_outputs: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, None, None, None]: - # Map incoming torch grads to our output variable - ctx.outputs.grad = wp.from_torch(grad_outputs, dtype=wp.float32) - - # Run the kernel in adjoint mode - wp.launch( - BallQuery.sparse_ball_query, - inputs=[ - ctx.points2, - ctx.mapping, - ctx.num_neighbors, - ], - outputs=[ - ctx.outputs, - ], - adj_inputs=[ctx.points2.grad, ctx.mapping.grad, ctx.num_neighbors.grad], - adj_outputs=[ - ctx.outputs.grad, - ], - dim=ctx.dim, - adjoint=True, + def backward(ctx, grad_mapping, grad_num_neighbors, grad_outputs): + + points1, points2, mapping, num_neighbors, outputs = ctx.saved_tensors + # Apply the primitive + p2_grad = _ball_query_backward_primative_( + points1[0], + points2[0], + mapping, + num_neighbors, + outputs, + grad_mapping, + grad_num_neighbors, + grad_outputs, ) + p2_grad = p2_grad.unsqueeze(0) # Return the gradients return ( - wp.to_torch(ctx.points1.grad).unsqueeze(0), - wp.to_torch(ctx.points2.grad).unsqueeze(0), - # None, - # None, + torch.zeros_like(points1), + p2_grad, None, None, None, ) +def ball_query_layer(points1, points2, k, radius, hash_grid): + """ + Wrapper for BallQuery.apply to support a functional interface. + """ + return BallQuery.apply(points1, points2, k, radius, hash_grid) + + class BallQueryLayer(torch.nn.Module): """ Torch layer for differentiable and accelerated Ball Query @@ -295,11 +337,9 @@ def forward( - num_neighbors: Tensor containing the number of neighbors found for each query point - outputs: Tensor containing features or coordinates of the neighboring points """ - return BallQuery.apply( + return ball_query_layer( points1, points2, - # lengths1, - # lengths2, self.k, self.radius, self.hash_grid, @@ -323,8 +363,6 @@ def save_point_cloud(points, name): points1 = torch.rand(n, p1, d, device="cuda", requires_grad=True) points2 = torch.rand(n, p2, d, device="cuda", requires_grad=True) - # lengths1 = torch.full((n,), p1, dtype=torch.int32).cuda() - # lengths2 = torch.full((n,), p2, dtype=torch.int32).cuda() k = 256 # maximum number of neighbors radius = 0.1 @@ -336,8 +374,6 @@ def save_point_cloud(points, name): mapping, num_neighbors, outputs = layer( points1, points2, - # lengths1, - # lengths2, ) for i in range(20): @@ -345,14 +381,10 @@ def save_point_cloud(points, name): p2 += 100 points1 = torch.rand(n, p1, d, device="cuda", requires_grad=False) points2 = torch.rand(n, p2, d, device="cuda", requires_grad=False) - # lengths1 = torch.full((n,), p1, dtype=torch.int32).cuda() - # lengths2 = torch.full((n,), p2, dtype=torch.int32).cuda() mapping, num_neighbors, outputs = layer( points1, points2, - # lengths1, - # lengths2, ) # Perform matrix multiplication as comparison for timing diff --git a/test/distributed/test_ball_query_shard_tensor.py b/test/distributed/test_ball_query_shard_tensor.py new file mode 100644 index 0000000000..5afe020261 --- /dev/null +++ b/test/distributed/test_ball_query_shard_tensor.py @@ -0,0 +1,288 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from physicsnemo.distributed import DistributedManager +from physicsnemo.models.domino.model import BQWarp +from physicsnemo.utils.version_check import check_module_requirements + +try: + check_module_requirements("physicsnemo.distributed.shard_tensor") + +except ImportError: + pytest.skip( + "Skipping test because physicsnemo.distributed.shard_tensor is not available", + allow_module_level=True, + ) + + +from distributed_utils_for_testing import modify_environment # noqa: E402 +from test_shard_tensor_initialization import init_dist +from torch.distributed.tensor import distribute_module # noqa: E402 +from torch.distributed.tensor.placement_types import ( # noqa: E402 + Replicate, + Shard, +) + +from physicsnemo.distributed import ( + register_custom_ops, # noqa: E402 + scatter_tensor, +) + +register_custom_ops() + + +def convert_input_dict_to_shard_tensor( + input_dict, point_placements, grid_placements, mesh +): + # Strategy: convert the point clouds to replicated tensors, and + # grid objects to sharded tensors + non_sharded_keys = [ + "surface_min_max", + "volume_min_max", + "stream_velocity", + "air_density", + ] + + sharded_dict = {} + + for key, value in input_dict.items(): + # Skip non-tensor values + if not isinstance(value, torch.Tensor): + continue + + # Skip keys that should not be sharded + if key in non_sharded_keys: + sharded_dict[key] = scatter_tensor( + value, + 0, + mesh, + [ + Replicate(), + ], + global_shape=value.shape, + dtype=value.dtype, + requires_grad=value.requires_grad, + ) + continue + + if "grid" in key: + sharded_dict[key] = scatter_tensor( + value, + 0, + mesh, + grid_placements, + global_shape=value.shape, + dtype=value.dtype, + requires_grad=value.requires_grad, + ) + else: + sharded_dict[key] = scatter_tensor( + value, + 0, + mesh, + point_placements, + global_shape=value.shape, + dtype=value.dtype, + requires_grad=value.requires_grad, + ) + + return sharded_dict + + +def run_ball_query_module(model, data_dict, reverse_mapping): + geo_centers = data_dict["geometry_coordinates"] + + # Bounding box grid + s_grid = data_dict["surf_grid"] + + # Scaling factors + surf_max = data_dict["surface_min_max"][:, 1] + surf_min = data_dict["surface_min_max"][:, 0] + + # Normalize based on BBox around surface (car) + geo_centers_surf = 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 + + mapping, outputs = model(geo_centers_surf, s_grid, reverse_mapping) + + return mapping, outputs + + +def run_sharded_ball_query_layer_forward( + rank, num_gpus, shard_points, shard_grid, reverse_mapping +): + with modify_environment( + RANK=f"{rank}", + WORLD_SIZE=f"{num_gpus}", + MASTER_ADDR="localhost", + MASTER_PORT=str(13245), + LOCAL_RANK=f"{rank % torch.cuda.device_count()}", + ): + init_dist(rank, num_gpus) + dm = DistributedManager() + + device = dm.device + + # Create the input dict: + bsize = 1 + npoints = 17 + nx, ny, nz = 12, 6, 4 + # This is pretty aggressive, it'd never actually be this many. + # But it enables checking the ring ball query deterministically. + if reverse_mapping: + num_neigh = npoints + else: + num_neigh = nx * ny * nz + geom_centers = torch.randn(bsize, npoints, 3).to(device) + surf_grid = torch.randn(bsize, nx, ny, nz, 3).to(device) + surf_grid_max_min = torch.randn(bsize, 2, 3).to(device) + input_dict = { + "geometry_coordinates": geom_centers, + "surf_grid": surf_grid, + "surface_min_max": surf_grid_max_min, + } + + # To make this work, we need to broadcast the input_dict to all GPUs + # Easiest to shard it and then pull it together into a full_tensor on each GPU + + global_mesh = dm.initialize_mesh([-1], ["domain"]) + domain_mesh = global_mesh["domain"] + + # Define the sharding placements: + point_placement = (Shard(1),) if shard_points else (Replicate(),) + grid_placement = (Shard(1),) if shard_grid else (Replicate(),) + + # Convert the input dict to sharded tensors: + sharded_input_dict = convert_input_dict_to_shard_tensor( + input_dict, point_placement, grid_placement, domain_mesh + ) + + # Get the single_gpu input_dict again, but now it's identical on all GPUs + input_dict = { + key: value.full_tensor() for key, value in sharded_input_dict.items() + } + + # Create the model: + model = BQWarp( + input_features=3, + grid_resolution=[nx, ny, nz], + radius=1.0, + neighbors_in_radius=num_neigh, + ).to(device) + + single_gpu_mapping, single_gpu_outputs = run_ball_query_module( + model, input_dict, reverse_mapping=reverse_mapping + ) + + # Initialize a mesh: + + # Convert the model to a distributed model: + # Since the model has no parameters, this might not be necessary. + model = distribute_module(model, device_mesh=domain_mesh) + + sharded_mapping, sharded_outputs = run_ball_query_module( + model, sharded_input_dict, reverse_mapping + ) + + # This ball query function is tricky - we may or may not preserve order. + # To ensure the mapping is correct, we take the sorted values + # along the point dimension and compare. + + sorted_single_gpu_mapping, sorted_single_gpu_mapping_indices = torch.sort( + single_gpu_mapping, dim=-1, descending=True + ) + sorted_sharded_mapping, sorted_sharded_mapping_indices = torch.sort( + sharded_mapping.full_tensor(), dim=-1, descending=True + ) + + # print(f"sorted_sharded_mapping: {sorted_sharded_mapping}") + # mapping_diff = sorted_single_gpu_mapping - sorted_sharded_mapping + # batch_loc, point_loc, ind_loc = torch.where(mapping_diff != 0) + # print(f"sorted_single_gpu_mapping: {sorted_single_gpu_mapping[batch_loc, point_loc]}") + # print(f"sorted_sharded_mapping: {sorted_sharded_mapping[batch_loc, point_loc]}") + assert torch.allclose(sorted_single_gpu_mapping, sorted_sharded_mapping) + + # To check the outputs, we apply the sorted indexes into the outputs + # and validate the sorted version. + + # Apply the sort to the output tensors too: + single_gpu_output_sort_indices = sorted_single_gpu_mapping_indices.unsqueeze( + -1 + ).expand(-1, -1, -1, sharded_outputs.shape[-1]) + sorted_single_gpu_outputs = single_gpu_outputs.gather( + 2, index=single_gpu_output_sort_indices + ) + + sharded_output_sort_indices = sorted_sharded_mapping_indices.unsqueeze( + -1 + ).expand(-1, -1, -1, sharded_outputs.shape[-1]) + sorted_sharded_outputs = sharded_outputs.full_tensor().gather( + 2, index=sharded_output_sort_indices + ) + + assert torch.allclose(sorted_single_gpu_outputs, sorted_sharded_outputs) + + if reverse_mapping: + correct_placement = grid_placement + else: + correct_placement = point_placement + + mapping_placement_correct = ( + sharded_mapping._spec.placements == correct_placement + ) + sharded_outputs_placement_correct = ( + sharded_outputs.placements == correct_placement + ) + + assert mapping_placement_correct + assert sharded_outputs_placement_correct + + DistributedManager().cleanup() + + +@pytest.mark.multigpu +@pytest.mark.timeout(120) +@pytest.mark.parametrize("shard_points", [True, False]) +@pytest.mark.parametrize("shard_grid", [True, False]) +@pytest.mark.parametrize("reverse_mapping", [True, False]) +def test_shard_tensor_ball_query(shard_points, shard_grid, reverse_mapping): + """ + This test is meant to ensure ShardTensor can be initialized correctly + from local data. Checks the following: + + """ + num_gpus = torch.cuda.device_count() + num_gpus = 3 + if num_gpus < 2: + pytest.skip("Not enough GPUs available for distributed tests") + + torch.multiprocessing.set_start_method("spawn", force=True) + + torch.multiprocessing.spawn( + run_sharded_ball_query_layer_forward, + args=(num_gpus, shard_points, shard_grid, reverse_mapping), + nprocs=num_gpus, + join=True, + daemon=True, + ) + + +if __name__ == "__main__": + test_shard_tensor_ball_query( + shard_points=True, shard_grid=True, reverse_mapping=True + ) diff --git a/test/distributed/test_shard_tensor_grad_sharding.py b/test/distributed/test_shard_tensor_grad_sharding.py index de4641dd2d..d82b46f4e9 100644 --- a/test/distributed/test_shard_tensor_grad_sharding.py +++ b/test/distributed/test_shard_tensor_grad_sharding.py @@ -52,20 +52,6 @@ def run_shard_tensor_detach(rank, num_gpus, mesh_names, mesh_sizes, uneven, verb ) shard_tensor_detached = shard_tensor.detach() - print(f"Original spec: {shard_tensor._spec} of type {type(shard_tensor._spec)}") - print( - f"Detached spec: {shard_tensor_detached._spec} of type {type(shard_tensor_detached._spec)}" - ) - - print( - f"Original sharding sizes: {shard_tensor._spec.sharding_sizes()}", - flush=True, - ) - print( - f"Detached sharding sizes: {shard_tensor_detached._spec.sharding_sizes()}", - flush=True, - ) - # Detaching should not change the original data nor should it change the spec: assert shard_tensor._spec == shard_tensor_detached._spec @@ -77,6 +63,7 @@ def run_shard_tensor_detach(rank, num_gpus, mesh_names, mesh_sizes, uneven, verb @pytest.mark.multigpu +@pytest.mark.timeout(120) @pytest.mark.parametrize("data_parallel_size", [-1]) @pytest.mark.parametrize("domain_H", [2, 4]) @pytest.mark.parametrize("domain_W", [1, 2]) @@ -169,8 +156,8 @@ def loss(_input): # Now compute the sharded gradients with FULL TENSOR LOSS: sharded_loss = loss(shard_tensor) - sharded_loss.backward() + # Check if shard_tensor requires grad assert shard_tensor.requires_grad, "ShardTensor should require grad" assert shard_tensor.grad is not None @@ -180,6 +167,7 @@ def loss(_input): @pytest.mark.multigpu +@pytest.mark.timeout(120) @pytest.mark.parametrize("data_parallel_size", [-1]) @pytest.mark.parametrize("domain_H", [2, 4]) @pytest.mark.parametrize("domain_W", [1, 2]) @@ -289,6 +277,7 @@ def loss(_input): @pytest.mark.multigpu +@pytest.mark.timeout(120) @pytest.mark.parametrize("data_parallel_size", [-1]) @pytest.mark.parametrize("domain_H", [2, 4]) @pytest.mark.parametrize("domain_W", [1, 2]) @@ -336,10 +325,3 @@ def test_shard_tensor_input_gradient_local_loss( join=True, daemon=True, ) - - -if __name__ == "__main__": - - # test_shard_tensor_detach(-1,2, 1, True) - test_shard_tensor_input_gradient_local_loss(-1, 2, 1, True) - test_shard_tensor_input_gradient_full_loss(-1, 2, 1, True) diff --git a/test/distributed/test_shard_tensor_initialization.py b/test/distributed/test_shard_tensor_initialization.py index aadd8ce5ea..92e67827de 100644 --- a/test/distributed/test_shard_tensor_initialization.py +++ b/test/distributed/test_shard_tensor_initialization.py @@ -89,21 +89,21 @@ def run_shard_tensor_initialization_from_data_rank( ) # Create the raw data on the first rank of the first dimension of the domain mesh: - source = dist.get_global_rank(domain_mesh.get_group(0), 0) - source = int(domain_mesh.mesh.min()) + first_axis_group = domain_mesh.get_group(0) + first_axis_ranks = dist.get_process_group_ranks(first_axis_group) + source = min(first_axis_ranks) if rank == source: raw_data = torch.randn( global_shape, device=torch.device(f"cuda:{dm.local_rank}") ) else: - raw_data = torch.empty(0) + raw_data = None st = scatter_tensor(raw_data, source, domain_mesh, placements) # Check that the local shape matches the expected shape: local_data = st.to_local() - print(f"local shape: {local_data.shape}") # Check the dimensions on the sharded mesh: checked_dims = [] for mesh_dim, placement in enumerate(placements): @@ -158,7 +158,10 @@ def run_shard_tensor_initialization_from_all_dtensor( st = ShardTensor.from_dtensor(dt) - assert torch.allclose(dt.full_tensor(), st.full_tensor()) + dt_full = dt.full_tensor() + st_full = st.full_tensor() + + assert torch.allclose(dt_full, st_full) # on the "source" rank of the mesh, we should have agreement with raw data. # on the "not-source" rank of the mesh, we shouldn't @@ -202,9 +205,7 @@ def run_shard_tensor_initialization_from_local_chunks( local_shape = list(global_shape) first_shard_dim = placements[0].dim replacement_size = int(random.uniform(0.5, 1.5) * local_shape[first_shard_dim]) - local_shape[first_shard_dim] = replacement_size - # replace the dimension with a new one # Create the raw data everywhere, but it will mostly get thrown away @@ -213,10 +214,13 @@ def run_shard_tensor_initialization_from_local_chunks( local_shape, device=torch.device(f"cuda:{dm.local_rank}") ) st = ShardTensor.from_local( - raw_data, device_mesh=domain_mesh, placements=placements, infer_shape=True + raw_data, + device_mesh=domain_mesh, + placements=placements, + sharding_shapes="infer", ) - # Data comes back ok: + # Local data comes back ok: assert torch.allclose(st.to_local(), raw_data) # Gather the shapes along the random placement and make sure they agree: @@ -239,7 +243,6 @@ def run_shard_tensor_initialization_from_local_chunks( index = index.to(raw_data.device) local_slice = st.full_tensor().index_select(placements[0].dim, index) - # Slice out what should be the original tensor agreement_with_original_data = torch.allclose(local_slice, raw_data) @@ -306,7 +309,7 @@ def test_shard_tensor_initialization_from_data_rank( @pytest.mark.parametrize( "domain_W", [ - 1, + 1, # Lock this to 1. This test will randomize the shape of one axis of the local tensor, a 2D mesh breaks that. ], ) def test_shard_tensor_initialization_from_local_chunks( diff --git a/test/distributed/test_shard_tensor_redistribute.py b/test/distributed/test_shard_tensor_redistribute.py index 070809a018..b0ffa8bf78 100644 --- a/test/distributed/test_shard_tensor_redistribute.py +++ b/test/distributed/test_shard_tensor_redistribute.py @@ -84,7 +84,10 @@ def shard_tensor_factory(mesh_names, mesh_sizes, requires_grad=False, uneven=Tru ) st = ShardTensor.from_local( - raw_data, device_mesh=domain_mesh, placements=placements, infer_shape=True + raw_data, + device_mesh=domain_mesh, + placements=placements, + sharding_shapes="infer", ) return st @@ -361,8 +364,3 @@ def test_shard_tensor_redistribute2d( join=True, daemon=True, ) - - -if __name__ == "__main__": - - test_shard_tensor_reduction(-1, 2, 2, torch.sum) diff --git a/test/distributed/test_shard_tensor_reductions.py b/test/distributed/test_shard_tensor_reductions.py new file mode 100644 index 0000000000..c4aaefc8a7 --- /dev/null +++ b/test/distributed/test_shard_tensor_reductions.py @@ -0,0 +1,200 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from physicsnemo.utils.version_check import check_module_requirements + +try: + check_module_requirements("physicsnemo.distributed.shard_tensor") + ST_AVAILABLE = True +except ImportError: + pytest.skip( + "Skipping test because physicsnemo.distributed.shard_tensor is not available", + allow_module_level=True, + ) + + +if ST_AVAILABLE: + from test_shard_tensor_initialization import ( + init_dist, + ) + from torch.distributed.tensor.placement_types import Shard + + from physicsnemo.distributed import register_custom_ops, scatter_tensor + + register_custom_ops() + + +import torch +from distributed_utils_for_testing import modify_environment + +from physicsnemo.distributed import DistributedManager + + +def run_shard_tensor_reduction( + rank, num_gpus, mesh_names, mesh_sizes, op, backward, dim, in_place, verbose +): + + with modify_environment( + RANK=f"{rank}", + WORLD_SIZE=f"{num_gpus}", + MASTER_ADDR="localhost", + MASTER_PORT=str(13245), + LOCAL_RANK=f"{rank % torch.cuda.device_count()}", + ): + init_dist(rank, num_gpus) + + dm = DistributedManager() + + # Create a random-valued tensor of at least rank 3: + full_input = torch.randn(2, 27, 2, requires_grad=backward).to(dm.device) + + # Scatter it: + global_mesh = dm.initialize_mesh(mesh_sizes, mesh_names) # noqa: F841 + placements = (Shard(1),) + shard_tensor = scatter_tensor( + full_input, + 0, + global_mesh, + placements, + global_shape=full_input.shape, + dtype=full_input.dtype, + requires_grad=backward, + ) + + if verbose: + print( + f"Shard tensor global shape: {shard_tensor.shape} and local shape: {shard_tensor._local_tensor.shape}" + ) + + # For this test, we're testing that the reduction of the tensor works correctly + + # This means we're calling things like `shard_tensor.max()` or `shard_tensor.mean()` + # and looking to get the right answers + + # Note that calling `full_tensor` is already checked in the initialize file but that's + # also, technically, a reduction. + + args = () + kwargs = {"dim": dim} + + full_input = shard_tensor.full_tensor().detach().requires_grad_(True) + + if in_place: + if op == "sum": + partial_result = shard_tensor.sum(*args, **kwargs) + full_result = full_input.sum(*args, **kwargs) + elif op == "min": + partial_result = shard_tensor.min(*args, **kwargs) + full_result = full_input.min(*args, **kwargs) + elif op == "max": + partial_result = shard_tensor.max(*args, **kwargs) + full_result = full_input.max(*args, **kwargs) + elif op == "mean": + partial_result = shard_tensor.mean(*args, **kwargs) + full_result = full_input.mean(*args, **kwargs) + else: + raise ValueError(f"Unsupported operation: {op}") + else: + if op == "sum": + partial_result = torch.sum(shard_tensor, *args, **kwargs) + full_result = torch.sum(full_input, *args, **kwargs) + elif op == "min": + partial_result = torch.min(shard_tensor, *args, **kwargs) + full_result = torch.min(full_input, *args, **kwargs) + elif op == "max": + partial_result = torch.max(shard_tensor, *args, **kwargs) + full_result = torch.max(full_input, *args, **kwargs) + elif op == "mean": + partial_result = torch.mean(shard_tensor, *args, **kwargs) + full_result = torch.mean(full_input, *args, **kwargs) + else: + raise ValueError(f"Unsupported operation: {op}") + resolved_partial_result = partial_result.full_tensor() + + if verbose: + print(f"Partial first: {resolved_partial_result}") + print(f"All gather first: {full_result}") + + assert torch.allclose(resolved_partial_result, full_result, atol=1e-6) + + if backward: + if len(full_result.shape) != 0: + full_result.sum().backward() + else: + full_result.backward() + + standard_grads = full_input.grad + + if len(partial_result.shape) != 0: + partial_result.sum().backward() + else: + partial_result.backward() + + sharded_grads = shard_tensor.grad.full_tensor() + + # Ensure gradient values agree: + assert torch.allclose(standard_grads, sharded_grads) + + # Make sure that the sharded gradients have the same placement and sharding sizes as the original tensor + assert shard_tensor.grad._spec.placements == shard_tensor._spec.placements + assert ( + shard_tensor.grad._spec.sharding_shapes() + == shard_tensor._spec.sharding_shapes() + ) + + print("Success!") + DistributedManager().cleanup() + + +@pytest.mark.multigpu +@pytest.mark.parametrize("op", ["sum", "mean"]) +@pytest.mark.parametrize("backward", [True, False]) +@pytest.mark.parametrize("dim", [None, 0, (0, 1)]) +@pytest.mark.parametrize("in_place", [True, False]) +def test_shard_tensor_reduction(op, backward, dim, in_place): + """ + This test ensures that reductions work correctly on ShardTensors. + + Reductions are implemented with a custom autograd function which intercepts + the call path at ShardTensor.__torch_function__. This isn't strictly + necessary for most reductions in the forward pass, but the backward pass + has incorrectly sharded gradients. The custom function ensures the + output of the reduction is correctly sharded. + """ + num_gpus = torch.cuda.device_count() + assert num_gpus >= 2, "Not enough GPUs available for test" + + mesh_names = ["domain"] + mesh_sizes = [-1] + + verbose = True # Change to True for debug + + torch.multiprocessing.set_start_method("spawn", force=True) + + torch.multiprocessing.spawn( + run_shard_tensor_reduction, + args=(num_gpus, mesh_names, mesh_sizes, op, backward, dim, in_place, verbose), + nprocs=num_gpus, + join=True, + daemon=True, + ) + + +if __name__ == "__main__": + test_shard_tensor_reduction(op="sum", backward=True, dim=(0,), in_place=False)