diff --git a/.github/workflows/ci_gpu.yml b/.github/workflows/ci_gpu.yml index 7be73e94..144dca1b 100644 --- a/.github/workflows/ci_gpu.yml +++ b/.github/workflows/ci_gpu.yml @@ -33,6 +33,28 @@ jobs: # run: | # python3 -m unittest opacus.tests.multigpu_gradcheck.GradientComputationTest.test_gradient_correct + unittest_mixed_precision: + runs-on: 4-core-ubuntu-gpu-t4 + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Display Python version + run: python3 -c "import sys; print(sys.version)" + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.x' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + ./scripts/install_via_pip.sh -c + + - name: Run mixed precision unit tests + run: | + python3 -m unittest opacus.tests.mixed_precision_test integrationtest_py39_torch_release_cuda: runs-on: 4-core-ubuntu-gpu-t4 diff --git a/opacus/grad_sample/conv.py b/opacus/grad_sample/conv.py index b29c47ad..a2536549 100644 --- a/opacus/grad_sample/conv.py +++ b/opacus/grad_sample/conv.py @@ -40,6 +40,9 @@ def compute_conv_grad_sample( backprops: Backpropagations """ activations = activations[0] + + activations = activations.to(backprops.dtype) + n = activations.shape[0] if n == 0: # Empty batch diff --git a/opacus/grad_sample/dp_rnn.py b/opacus/grad_sample/dp_rnn.py index 3fe05876..f9690d55 100644 --- a/opacus/grad_sample/dp_rnn.py +++ b/opacus/grad_sample/dp_rnn.py @@ -39,6 +39,9 @@ def compute_rnn_linear_grad_sample( backprops: Backpropagations """ activations = activations[0] + + activations = activations.to(backprops.dtype) + ret = {} if layer.weight.requires_grad: gs = torch.einsum("n...i,n...j->nij", backprops, activations) diff --git a/opacus/grad_sample/embedding.py b/opacus/grad_sample/embedding.py index 9a2c2637..f6018e28 100644 --- a/opacus/grad_sample/embedding.py +++ b/opacus/grad_sample/embedding.py @@ -51,7 +51,10 @@ def compute_embedding_grad_sample( .reshape(batch_size, -1, layer.embedding_dim) ) grad_sample = torch.zeros( - batch_size, *layer.weight.shape, device=layer.weight.device + batch_size, + *layer.weight.shape, + device=layer.weight.device, + dtype=backprops.dtype ) grad_sample.scatter_add_( 1, index, backprops.reshape(batch_size, -1, layer.embedding_dim) @@ -65,7 +68,13 @@ def compute_embedding_grad_sample( def compute_embeddingbag_gradsampler(layer, inputs, backprops): index, offset = inputs batch_size = offset.shape[0] - gsm = torch.zeros(batch_size, layer.num_embeddings, layer.embedding_dim) + gsm = torch.zeros( + batch_size, + layer.num_embeddings, + layer.embedding_dim, + device=index.device, + dtype=backprops.dtype, + ) for i in range(batch_size): begin = offset[i] diff --git a/opacus/grad_sample/embedding_norm_sample.py b/opacus/grad_sample/embedding_norm_sample.py index 9e2ccf94..d49b2e6c 100644 --- a/opacus/grad_sample/embedding_norm_sample.py +++ b/opacus/grad_sample/embedding_norm_sample.py @@ -131,7 +131,10 @@ def compute_embedding_norm_sample( # Sum gradients over new index positions and compute squared gradient norms num_unique_paired_indices = unique_paired_indices.size(0) summed_gradients = torch.zeros( - num_unique_paired_indices, grad_values.size(-1), device=device + num_unique_paired_indices, + grad_values.size(-1), + device=device, + dtype=grad_values.dtype, ) summed_gradients = summed_gradients.index_add( 0, new_index_positions.to(device), grad_values @@ -139,7 +142,7 @@ def compute_embedding_norm_sample( sqr_gradient_sum = torch.sum(summed_gradients**2, dim=1) # Scatter add the 
squared sums back to their respective rows - result = torch.zeros(nrows, device=device) + result = torch.zeros(nrows, device=device, dtype=grad_values.dtype) unique_batch_ids = unique_paired_indices[:, 0].to(device) result.scatter_add_(0, unique_batch_ids, sqr_gradient_sum) diff --git a/opacus/grad_sample/functorch.py b/opacus/grad_sample/functorch.py index b74f87b1..9f221a47 100644 --- a/opacus/grad_sample/functorch.py +++ b/opacus/grad_sample/functorch.py @@ -13,6 +13,7 @@ # limitations under the License. import copy +from contextlib import nullcontext import torch import torch.nn as nn @@ -82,8 +83,19 @@ def compute_loss_stateless_model(params, activations, backprops): batched_activations = activations.unsqueeze(1) batched_backprops = backprops.unsqueeze(1) - output = flayer(params, batched_activations) - loss = (output * batched_backprops).sum() + # mixed precision logic + is_mixed = activations.dtype != params[0].dtype + mixed_lowest_dtype = activations.dtype + device_type = activations.device.type + + # use amp context if user is using mixed_precision, else proceed as usual + with ( + torch.amp.autocast(device_type=device_type, dtype=mixed_lowest_dtype) + if is_mixed + else nullcontext() + ): + output = flayer(params, batched_activations) + loss = (output * batched_backprops).sum() return loss ft_compute_grad = grad(compute_loss_stateless_model) @@ -105,9 +117,10 @@ def ft_compute_per_sample_gradient(layer, activations, backprops): if not hasattr(layer, "ft_compute_sample_grad"): prepare_layer(layer) - per_sample_grads = layer.ft_compute_sample_grad( - parameters, activations[0], backprops - ) + activations = activations[0] + if activations.dtype != backprops.dtype and activations.is_floating_point(): + activations = activations.to(backprops.dtype) + per_sample_grads = layer.ft_compute_sample_grad(parameters, activations, backprops) ret = {} for i_p, p in enumerate(parameters): diff --git a/opacus/grad_sample/linear.py b/opacus/grad_sample/linear.py index 6ab860ca..7105c5b0 100644 --- a/opacus/grad_sample/linear.py +++ b/opacus/grad_sample/linear.py @@ -39,6 +39,9 @@ def compute_linear_grad_sample( backprops: Backpropagations """ activations = activations[0] + + activations = activations.to(backprops.dtype) + ret = {} if layer.weight.requires_grad: gs = torch.einsum("n...i,n...j->nij", backprops, activations) @@ -61,6 +64,8 @@ def compute_linear_norm_sample( backprops: Backpropagations """ activations = activations[0] + activations = activations.to(backprops.dtype) + ret = {} if backprops.dim() == 2: diff --git a/opacus/optimizers/optimizer.py b/opacus/optimizers/optimizer.py index ac47e397..c82100b4 100644 --- a/opacus/optimizers/optimizer.py +++ b/opacus/optimizers/optimizer.py @@ -135,7 +135,7 @@ def _generate_noise( `2^53` (easy to break) but with `n=2`, we get `2^159`, which is hard enough for an attacker to break. 
""" - zeros = torch.zeros(reference.shape, device=reference.device) + zeros = torch.zeros(reference.shape, device=reference.device, dtype=reference.dtype) if std == 0: return zeros # TODO: handle device transfers: generator and reference tensor @@ -165,6 +165,7 @@ def _generate_noise( size=reference.shape, device=reference.device, generator=generator, + dtype=reference.dtype, ) @@ -451,6 +452,12 @@ def clip_and_accumulate(self): for p in self.params: _check_processed_flag(p.grad_sample) grad_sample = self._get_flat_grad_sample(p) + + # gradients should match the dtype of the optimizer parameters + # for mixed precision, optimizer parameters are usually in FP32 + # lower precision grads will be cast up to FP32 + grad_sample = grad_sample.to(p.dtype) + per_sample_clip_factor = per_sample_clip_factor.to(p.dtype) grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample) if p.summed_grad is not None: diff --git a/opacus/tests/mixed_precision_test.py b/opacus/tests/mixed_precision_test.py new file mode 100644 index 00000000..a702a171 --- /dev/null +++ b/opacus/tests/mixed_precision_test.py @@ -0,0 +1,391 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Dict, Optional, Type + +import torch +import torch.nn as nn +from opacus import PrivacyEngine +from torch.utils.data import DataLoader + +from .mixed_precision_utils import ( + AttentionModel, + ComplexModel, + Conv1DModel, + Conv2DModel, + Conv3DModel, + EmbeddingBagModel, + EmbeddingModel, + RNNModel, + SimpleLinearModel, + SimpleLinearReluModel, + create_random_data, +) + + +class MixedPrecisionTest(unittest.TestCase): + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available, skipping test") + def setUp(self): + + # Since this test only runs when CUDA is available, always use CUDA + self.device = torch.device("cuda") + self.input_dim = 4 + self.hidden_dim = 16 + self.output_dim = 4 + self.seq_len = 4 + self.batch_size = 2 + self.num_batches = 2 + + # Check if bfloat16 is supported + self.bf16_supported = hasattr(torch, "bfloat16") + + def _get_training_components( + self, + model: nn.Module, + dataloader: DataLoader, + grad_sample_mode: str, + dtype: torch.dtype, + ): + """ + Return training components (model, optimizer, criterion, dataloader) wrapped by PrivacyEngine. 
+        """
+        model = model.to(self.device)
+        model = model.to(dtype)
+        model = model.train()
+
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+        criterion = nn.CrossEntropyLoss()
+
+        privacy_engine = PrivacyEngine()
+
+        # Make the model private with the specified precision
+        if grad_sample_mode in ["hooks", "functorch", "ew"]:
+            model, optimizer, dataloader = privacy_engine.make_private(
+                module=model,
+                optimizer=optimizer,
+                data_loader=dataloader,
+                noise_multiplier=1,
+                max_grad_norm=1,
+                grad_sample_mode=grad_sample_mode,
+                poisson_sampling=False,
+            )
+        elif grad_sample_mode == "ghost":
+            model, optimizer, criterion, dataloader = privacy_engine.make_private(
+                module=model,
+                optimizer=optimizer,
+                data_loader=dataloader,
+                criterion=criterion,
+                max_grad_norm=1,
+                noise_multiplier=1,
+                grad_sample_mode="ghost",
+                poisson_sampling=False,
+            )
+
+        return model, optimizer, criterion, dataloader
+
+    def _train_mixed_precision(
+        self,
+        model: nn.Module,
+        dataloader: DataLoader,
+        dtype: torch.dtype,
+        grad_sample_mode: str = "hooks",
+    ):
+        """
+        Integration test for training a model with mixed precision (FP32+FP16 or FP32+BF16).
+        It checks dtypes of various training artifacts.
+        The expected behavior is that:
+            - model parameters are in full precision FP32
+            - model outputs are in low precision (BF16 or FP16)
+            - gradients are in high precision (FP32)
+
+        Args:
+            model (nn.Module): The neural network model to be trained.
+            dataloader (DataLoader): DataLoader providing the training data.
+            dtype (torch.dtype): The lower data type for mixed precision training (torch.float16 or torch.bfloat16).
+            grad_sample_mode (str): The mode for per-sample gradient computation, options include "hooks", "functorch", "ew", and "ghost".
+        """
+
+        model, optimizer, criterion, dataloader = self._get_training_components(
+            model, dataloader, grad_sample_mode, dtype=torch.float32
+        )
+        # model weights should be in high precision (fp32)
+        for p in model.parameters():
+            self.assertTrue(p.dtype == torch.float32)
+
+        for batch in dataloader:
+            x, y = batch
+            optimizer.zero_grad()
+            with torch.amp.autocast("cuda", dtype=dtype):
+                outputs = model(x)
+                self.assertTrue(outputs.dtype == dtype)
+
+                loss = criterion(outputs, y)
+            optimizer.zero_grad()
+            loss.backward()
+
+            for p in model.parameters():
+                # the gradients should have been cast up to high precision (fp32)
+                if p.grad is not None:
+                    self.assertTrue(p.grad.dtype == torch.float32)
+                # grad_sample and norm_sample could be either in FP32 or low precision depending on the parameter
+                # we do not explicitly cast them up to FP32, we only ensure that final gradients are cast up
+                if p.grad_sample is not None:
+                    self.assertTrue(p.grad_sample.dtype in [torch.float32, dtype])
+                if grad_sample_mode == "ghost" and p._norm_sample is not None:
+                    self.assertTrue(p._norm_sample.dtype in [torch.float32, dtype])
+
+            optimizer.step()
+
+    def _train_low_precision(
+        self,
+        model: nn.Module,
+        dataloader: DataLoader,
+        dtype: torch.dtype,
+        grad_sample_mode: str = "hooks",
+    ):
+        """
+        Runs an integration test for low precision training (BF16 or FP16).
+        Tests that model weights, outputs, and gradients are in the low precision dtype.
+
+        Args:
+            model (nn.Module): The neural network model to be trained.
+            dataloader (DataLoader): DataLoader providing the training data.
+            dtype (torch.dtype): The data type for low precision training (torch.float16 or torch.bfloat16).
+ grad_sample_mode (str): The mode for per-sample gradient computation, options include "hooks", "functorch", "ew", and "ghost". + """ + + model, optimizer, criterion, dataloader = self._get_training_components( + model, dataloader, grad_sample_mode, dtype=dtype + ) + + for p in model.parameters(): + self.assertTrue(p.dtype == dtype) + + for batch in dataloader: + optimizer.zero_grad() + x, y = batch + if x.is_floating_point(): # for embedding layers, keep input as int + x = x.to(dtype) + outputs = model(x) + self.assertTrue(outputs.dtype == dtype) + + loss = criterion(outputs, y) + loss.backward() + + # all gradients and gradient-related attributes should be in low precision + for p in model.parameters(): + if p.grad is not None: + self.assertTrue(p.grad.dtype == dtype) + if p.grad_sample is not None: + self.assertTrue(p.grad_sample.dtype == dtype) + if grad_sample_mode == "ghost" and p._norm_sample is not None: + self.assertTrue(p._norm_sample.dtype == dtype) + + optimizer.step() + + def _test_precision_training( + self, + model_class: Type[nn.Module], + model_kwargs: Optional[Dict] = None, + ): + """ + Integration tests for training a model with different precision settings: mixed and low precision. + It tests several layer types and architectures with all grad sample modes. + In particular, all layers implemented in Opacus are tested. + The test checks that model weights, outputs, and gradients are in the expected dtypes. + """ + if model_kwargs is None: + model_kwargs = {} + + # Create random data + dataloader, _ = create_random_data( + model_class, + batch_size=self.batch_size, + input_dim=self.input_dim, + output_dim=self.output_dim, + num_batches=self.num_batches, + seq_len=self.seq_len, + device=self.device, + ) + + low_precision_type = [torch.float16] + if self.bf16_supported: + low_precision_type.append(torch.bfloat16) + + # Test with low precision + for grad_sample_mode in ["hooks", "ghost", "functorch", "ew"]: + for dtype in low_precision_type: + # skip test for models with layers not supported by ew + if grad_sample_mode == "ew" and model_class in [ + SimpleLinearModel, + EmbeddingModel, + EmbeddingBagModel, + ComplexModel, + ]: + continue + # functorch does not support EmbeddingBagModel + if grad_sample_mode == "functorch" and model_class == EmbeddingBagModel: + continue + print( + f"Testing {model_class.__name__} model with low {dtype} precision and grad sample mode {grad_sample_mode}" + ) + self._train_low_precision( + model=model_class(**model_kwargs), # Create a fresh model + dataloader=dataloader, + dtype=dtype, + grad_sample_mode=grad_sample_mode, + ) + + # Test mixed FP32 + BF16/FP16 + for grad_sample_mode in ["hooks", "ghost", "functorch"]: + for dtype in low_precision_type: + if grad_sample_mode == "functorch" and model_class == EmbeddingBagModel: + continue + print( + f"Testing {model_class.__name__} with mixed FP32 + {dtype} precision and grad sample mode {grad_sample_mode}" + ) + self._train_mixed_precision( + model=model_class(**model_kwargs), # Create a fresh model + dataloader=dataloader, + dtype=dtype, + grad_sample_mode=grad_sample_mode, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available, skipping test") + def test_conv2d_model_precision(self): + """Test mixed and low precision training with 2D convolutional layer""" + self._test_precision_training( + model_class=Conv2DModel, + model_kwargs={ + "input_channels": 3, + "hidden_dim": self.hidden_dim, + "output_dim": self.output_dim, + }, + ) + + @unittest.skipIf(not 
torch.cuda.is_available(), "CUDA not available, skipping test") + def test_conv3d_model_precision(self): + """Test mixed and low precision training with 3D convolutional layer""" + self._test_precision_training( + model_class=Conv3DModel, + model_kwargs={ + "input_channels": 3, + "hidden_dim": self.hidden_dim, + "output_dim": self.output_dim, + }, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available, skipping test") + def test_rnn_model_precision(self): + """Test mixed and low precision training with RNN layers""" + self._test_precision_training( + model_class=RNNModel, + model_kwargs={ + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "output_dim": self.output_dim, + }, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available, skipping test") + def test_attention_model_precision(self): + """Test mixed and low precision training with attention layers""" + self._test_precision_training( + model_class=AttentionModel, + model_kwargs={ + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "output_dim": self.output_dim, + "num_heads": 4, + }, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available, skipping test") + def test_complex_model_precision(self): + """Test mixed and low precision training with a complex model combining multiple layer types""" + self._test_precision_training( + model_class=ComplexModel, + model_kwargs={ + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "output_dim": self.output_dim, + "seq_len": self.seq_len, + }, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available, skipping test") + def test_conv1d_model_precision(self): + """Test mixed precision training with 1D convolutional layer""" + self._test_precision_training( + model_class=Conv1DModel, + model_kwargs={ + "input_channels": 3, + "hidden_dim": self.hidden_dim, + "output_dim": self.output_dim, + }, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available, skipping test") + def test_linear_model_precision(self): + """Test mixed and low precision training with a simple linear model""" + self._test_precision_training( + model_class=SimpleLinearModel, + model_kwargs={ + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "output_dim": self.output_dim, + }, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available, skipping test") + def test_embedding_model_precision(self): + """Test mixed and low precision training with embedding layer""" + self._test_precision_training( + model_class=EmbeddingModel, + model_kwargs={ + "vocab_size": 100, + "embedding_dim": self.hidden_dim, + "output_dim": self.output_dim, + }, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available, skipping test") + def test_embedding_bag_model_precision(self): + """Test mixed and low precision training with embedding bag layer""" + self._test_precision_training( + model_class=EmbeddingBagModel, + model_kwargs={ + "vocab_size": 100, + "embedding_dim": self.hidden_dim, + "output_dim": self.output_dim, + }, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available, skipping test") + def test_linear_relu_model_precision(self): + """Test mixed and low precision training with a simple linear-relu model""" + self._test_precision_training( + model_class=SimpleLinearReluModel, + model_kwargs={ + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "output_dim": self.output_dim, + }, + ) + + +if __name__ == "__main__": + 
unittest.main() diff --git a/opacus/tests/mixed_precision_utils.py b/opacus/tests/mixed_precision_utils.py new file mode 100644 index 00000000..b43412df --- /dev/null +++ b/opacus/tests/mixed_precision_utils.py @@ -0,0 +1,293 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F +from opacus.layers import DPGRU, DPLSTM, DPRNN, DPMultiheadAttention +from torch.utils.data import DataLoader, TensorDataset + + +def create_random_data( + model_class: Type[nn.Module], + batch_size: int = 2, + input_dim: int = 4, + output_dim: int = 4, + num_batches: int = 2, + seq_len: int = 4, + device: torch.device = torch.device("cuda"), +) -> Tuple[DataLoader, torch.Tensor, torch.Tensor]: + """Create random data for different model types""" + + # Common y tensor creation for all model types + y = torch.randint( + 0, + output_dim, + (batch_size * num_batches,), + device=device, + ) + + # Dictionary mapping model classes to their input tensor creation + model_data_map = { + SimpleLinearModel: lambda: torch.randn( + batch_size * num_batches, input_dim, device=device + ), + SimpleLinearReluModel: lambda: torch.randn( + batch_size * num_batches, input_dim, device=device + ), + Conv1DModel: lambda: torch.randn( + batch_size * num_batches, 3, 16, device=device + ), + Conv2DModel: lambda: torch.randn( + batch_size * num_batches, 3, 16, 16, device=device + ), + Conv3DModel: lambda: torch.randn( + batch_size * num_batches, 3, 8, 8, 8, device=device + ), + RNNModel: lambda: torch.randn( + batch_size * num_batches, seq_len, input_dim, device=device + ), + AttentionModel: lambda: torch.randn( + batch_size * num_batches, seq_len, input_dim, device=device + ), + ComplexModel: lambda: torch.randn( + batch_size * num_batches, seq_len, input_dim, device=device + ), + EmbeddingModel: lambda: torch.randint( + 0, 100, (batch_size * num_batches, 10), device=device + ), + EmbeddingBagModel: lambda: torch.randint( + 0, 100, (batch_size * num_batches, 10), device=device + ), + } + + # Get the appropriate input tensor creation function + if model_class not in model_data_map: + raise ValueError(f"Unknown model class: {model_class.__name__}") + + x = model_data_map[model_class]() + dataset = TensorDataset(x, y) + dataloader = DataLoader(dataset, batch_size=batch_size) + + return dataloader, dataset + + +class SimpleLinearReluModel(nn.Module): + """Simple model with just two linear layers and a ReLU activation""" + + def __init__(self, input_dim=16, hidden_dim=32, output_dim=10): + super().__init__() + self.fc1 = nn.Linear(input_dim, hidden_dim) + self.ln1 = nn.LayerNorm(hidden_dim) + self.fc2 = nn.Linear(hidden_dim, output_dim) + + def forward(self, x): + x = self.fc1(x) + x = self.ln1(x) + x = F.relu(x) + x = self.fc2(x) + return x + + +class SimpleLinearModel(nn.Module): + """Simple model with linear layers and normalization layers""" + + def __init__(self, 
input_dim=16, hidden_dim=32, output_dim=10): + super().__init__() + self.fc1 = nn.Linear(input_dim, hidden_dim) + self.ln1 = nn.LayerNorm(hidden_dim) + self.fc2 = nn.Linear(hidden_dim, hidden_dim) + self.gn = nn.GroupNorm(num_groups=4, num_channels=hidden_dim) + self.fc3 = nn.Linear(hidden_dim, hidden_dim) + self.rms_norm = nn.RMSNorm(hidden_dim) + self.fc4 = nn.Linear(hidden_dim, output_dim) + + def forward(self, x): + x = self.fc1(x) + x = self.ln1(x) + x = F.relu(x) + x = self.fc2(x) + x_reshaped = x.unsqueeze(-1) + x = self.gn(x_reshaped) + x = x.squeeze(-1) + x = F.relu(x) + x = self.fc3(x) + x = self.rms_norm(x) + x = F.relu(x) + x = self.fc4(x) + return x + + +class Conv1DModel(nn.Module): + """Model with 1D convolutional layer and instance normalization""" + + def __init__(self, input_channels=3, hidden_dim=32, output_dim=10): + super().__init__() + self.conv = nn.Conv1d(input_channels, hidden_dim, kernel_size=3, padding=1) + self.in1d = nn.InstanceNorm1d(hidden_dim) + self.fc = nn.Linear(hidden_dim, output_dim) + + def forward(self, x): + x = self.conv(x) + x = self.in1d(x) + x = F.relu(x) + x = F.avg_pool1d(x, x.size(-1)) + x = x.squeeze(-1) + x = self.fc(x) + return x + + +class Conv2DModel(nn.Module): + """Model with 2D convolutional layer and instance normalization""" + + def __init__(self, input_channels=3, hidden_dim=32, output_dim=10): + super().__init__() + self.conv = nn.Conv2d(input_channels, hidden_dim, kernel_size=3, padding=1) + self.in2d = nn.InstanceNorm2d(hidden_dim) + self.fc = nn.Linear(hidden_dim, output_dim) + + def forward(self, x): + x = self.conv(x) + x = self.in2d(x) + x = F.relu(x) + x = F.avg_pool2d(x, (x.size(-2), x.size(-1))) + x = x.squeeze(-1).squeeze(-1) + x = self.fc(x) + return x + + +class Conv3DModel(nn.Module): + """Model with 3D convolutional layer and instance normalization""" + + def __init__(self, input_channels=3, hidden_dim=32, output_dim=10): + super().__init__() + self.conv = nn.Conv3d(input_channels, hidden_dim, kernel_size=3, padding=1) + self.in3d = nn.InstanceNorm3d(hidden_dim) + self.fc = nn.Linear(hidden_dim, output_dim) + + def forward(self, x): + x = self.conv(x) + x = self.in3d(x) + x = F.relu(x) + x = F.avg_pool3d(x, (x.size(-3), x.size(-2), x.size(-1))) + x = x.squeeze(-1).squeeze(-1).squeeze(-1) + x = self.fc(x) + return x + + +class RNNModel(nn.Module): + """Model with RNN layers (LSTM, GRU, and RNN) and LayerNorm between them""" + + def __init__(self, input_dim=16, hidden_dim=32, output_dim=10): + super().__init__() + self.lstm = DPLSTM(input_dim, hidden_dim, batch_first=True) + self.ln1 = nn.LayerNorm(hidden_dim) + self.gru = DPGRU(hidden_dim, hidden_dim, batch_first=True) + self.ln2 = nn.LayerNorm(hidden_dim) + self.rnn = DPRNN(hidden_dim, hidden_dim, batch_first=True, nonlinearity="tanh") + self.fc = nn.Linear(hidden_dim, output_dim) + + def forward(self, x): + x, _ = self.lstm(x) + x = self.ln1(x) + x, _ = self.gru(x) + x = self.ln2(x) + x, _ = self.rnn(x) + x = x[:, -1] + x = self.fc(x) + return x + + +class AttentionModel(nn.Module): + """Model with multihead attention layers""" + + def __init__(self, input_dim=16, hidden_dim=32, output_dim=10, num_heads=4): + super().__init__() + self.embedding = nn.Linear(input_dim, hidden_dim) + self.attention = DPMultiheadAttention(hidden_dim, num_heads) + self.fc = nn.Linear(hidden_dim, output_dim) + + def forward(self, x): + x = self.embedding(x) + attn_output, _ = self.attention(x, x, x) + x = torch.mean(attn_output, dim=1) + + x = self.fc(x) + return x + + +class 
EmbeddingModel(nn.Module): + def __init__(self, vocab_size=100, embedding_dim=16, output_dim=10): + super().__init__() + self.embedding = nn.Embedding(vocab_size, embedding_dim) + self.layer_norm = nn.LayerNorm(embedding_dim) + self.fc = nn.Linear(embedding_dim, output_dim) + + def forward(self, x): + x = self.embedding(x) + x = self.layer_norm(x) + x = torch.mean(x, dim=1) + x = F.relu(x) + x = self.fc(x) + return x + + +class EmbeddingBagModel(nn.Module): + def __init__(self, vocab_size=1000, embedding_dim=16, output_dim=10): + super().__init__() + self.embedding_bag = nn.EmbeddingBag(vocab_size, embedding_dim, mode="mean") + self.layer_norm = nn.LayerNorm(embedding_dim) + self.fc = nn.Linear(embedding_dim, output_dim) + + def forward(self, x): + batch_size = x.size(0) + seq_len = x.size(1) + x_flat = x.reshape(-1) + offsets = torch.arange(0, batch_size * seq_len, seq_len, device=x.device) + x = self.embedding_bag(x_flat, offsets) + + x = self.layer_norm(x) + x = F.relu(x) + x = self.fc(x) + return x + + +class ComplexModel(nn.Module): + """Model combining multiple layer types""" + + def __init__(self, input_dim=16, hidden_dim=32, output_dim=10, seq_len=10): + super().__init__() + self.seq_len = seq_len + self.fc_in = nn.Linear(input_dim, hidden_dim) + self.conv1d = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1) + self.lstm = DPLSTM(hidden_dim, hidden_dim, batch_first=True) + self.attention = DPMultiheadAttention(hidden_dim, 4) + self.fc_out = nn.Linear(hidden_dim, output_dim) + + def forward(self, x): + x = self.fc_in(x) + x_conv = x.transpose(1, 2) # [batch_size, hidden_dim, seq_len] + x_conv = self.conv1d(x_conv) # [batch_size, hidden_dim, seq_len] + x_conv = x_conv.transpose(1, 2) # [batch_size, seq_len, hidden_dim] + x = x + x_conv + x_rnn, _ = self.lstm(x) + x = x + x_rnn + x_attn, _ = self.attention(x, x, x) + x = x + x_attn + x = torch.mean(x, dim=1) + x = self.fc_out(x) + return x diff --git a/opacus/tests/multigpu_precision_test.py b/opacus/tests/multigpu_precision_test.py new file mode 100644 index 00000000..8bc6ec95 --- /dev/null +++ b/opacus/tests/multigpu_precision_test.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from opacus import PrivacyEngine +from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP + +from .mixed_precision_utils import ( + EmbeddingModel, + SimpleLinearReluModel, + create_random_data, +) +from .multigpu_gradcheck_test import cleanup, setup + + +def _get_training_components( + model_class: nn.Module, + model_kwargs: dict[str, int], + device: torch.device, + grad_sample_mode: str, +): + """ + Creates a model, optimizer, criterion, and dataloader for training. + The model is wrapped in DPDDP. + The optimizer, model, dataloader, and criterion are wrapped by the privacy engine. 
+    """
+    input_dim = model_kwargs.get("input_dim", 4)
+    output_dim = model_kwargs.get("output_dim", 4)
+    seq_len = model_kwargs.get("seq_len", 4)
+    batch_size = 2
+    num_batches = 2
+
+    dataloader, _ = create_random_data(
+        model_class,
+        batch_size=batch_size,
+        input_dim=input_dim,
+        output_dim=output_dim,
+        num_batches=num_batches,
+        seq_len=seq_len,
+        device=device,
+    )
+
+    model = model_class(**model_kwargs).to(device)
+    model = model.train()
+
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+    criterion = nn.CrossEntropyLoss()
+
+    privacy_engine = PrivacyEngine()
+
+    model = DPDDP(model)
+
+    if grad_sample_mode in ["hooks", "functorch", "ew"]:
+        model, optimizer, dataloader = privacy_engine.make_private(
+            module=model,
+            optimizer=optimizer,
+            data_loader=dataloader,
+            noise_multiplier=1,
+            max_grad_norm=1,
+            grad_sample_mode=grad_sample_mode,
+            poisson_sampling=False,
+        )
+    elif grad_sample_mode == "ghost":
+        model, optimizer, criterion, dataloader = privacy_engine.make_private(
+            module=model,
+            optimizer=optimizer,
+            data_loader=dataloader,
+            criterion=criterion,
+            max_grad_norm=1,
+            noise_multiplier=1,
+            grad_sample_mode="ghost",
+            poisson_sampling=False,
+        )
+
+    return model, optimizer, criterion, dataloader
+
+
+def run_mixed_precision_test(
+    rank: int,
+    world_size: int,
+    model_class: nn.Module,
+    model_kwargs: dict[str, int],
+    dtype: torch.dtype,
+    grad_sample_mode: str,
+):
+    """
+    Runs an integration test for distributed training with DPDDP and mixed precision training.
+    It checks dtypes of various training artifacts.
+    The expected behavior is that:
+        - model parameters are in full precision FP32
+        - model outputs are in low precision (BF16 or FP16)
+        - gradients are in high precision (FP32)
+
+    Args:
+        rank (int): The rank of the current process.
+        world_size (int): The number of processes participating in the job.
+        model_class (nn.Module): The neural network model to be trained.
+        model_kwargs (dict): The keyword arguments for the model.
+        dtype (torch.dtype): The data type for low precision training (torch.float16 or torch.bfloat16).
+        grad_sample_mode (str): The mode for per-sample gradient computation.
+    """
+
+    setup(rank, world_size)
+    device = torch.device(f"cuda:{rank}")
+
+    model, optimizer, criterion, dataloader = _get_training_components(
+        model_class, model_kwargs, device, grad_sample_mode
+    )
+
+    # Model weights should be in high precision (fp32)
+    model = model.to(torch.float32)
+    for p in model.parameters():
+        assert p.dtype == torch.float32
+
+    for batch in dataloader:
+        x, y = batch
+        optimizer.zero_grad()
+
+        with torch.amp.autocast("cuda", dtype=dtype):
+            outputs = model(x)
+            assert outputs.dtype == dtype
+
+            loss = criterion(outputs, y)
+        optimizer.zero_grad()
+        loss.backward()
+
+        # The gradients should have been cast up to high precision (fp32)
+        for p in model.parameters():
+            if p.grad is not None:
+                assert p.grad.dtype == torch.float32
+            if p.grad_sample is not None:
+                assert p.grad_sample.dtype in [torch.float32, dtype]
+            if grad_sample_mode == "ghost" and p._norm_sample is not None:
+                assert p._norm_sample.dtype in [torch.float32, dtype]
+
+        optimizer.step()
+
+    cleanup()
+
+
+def run_low_precision_test(
+    rank: int,
+    world_size: int,
+    model_class: nn.Module,
+    model_kwargs: dict[str, int],
+    dtype: torch.dtype,
+    grad_sample_mode: str,
+):
+    """
+    Runs an integration test for distributed training with DPDDP and low precision training.
+    Tests that model weights, outputs, and gradients are in the low precision dtype.
+
+    Args:
+        rank (int): The rank of the current process.
+        world_size (int): The number of processes participating in the job.
+        model_class (nn.Module): The neural network model to be trained.
+        model_kwargs (dict): The keyword arguments for the model.
+        dtype (torch.dtype): The data type for low precision training (torch.float16 or torch.bfloat16).
+        grad_sample_mode (str): The mode for per-sample gradient computation, options include "hooks", "functorch", "ew", and "ghost".
+    """
+    setup(rank, world_size)
+    device = torch.device(f"cuda:{rank}")
+
+    model, optimizer, criterion, dataloader = _get_training_components(
+        model_class, model_kwargs, device, grad_sample_mode
+    )
+
+    # Model weights should be in low precision
+    model = model.to(dtype)
+    for p in model.parameters():
+        assert p.dtype == dtype
+
+    for batch in dataloader:
+        x, y = batch
+        optimizer.zero_grad()
+
+        if x.is_floating_point():  # For embedding layers, keep input as int
+            x = x.to(dtype)
+        outputs = model(x)
+        assert outputs.dtype == dtype
+
+        loss = criterion(outputs, y)
+        optimizer.zero_grad()
+        loss.backward()
+
+        for p in model.parameters():
+            if p.grad is not None:
+                assert p.grad.dtype == dtype
+            if p.grad_sample is not None:
+                assert p.grad_sample.dtype == dtype
+            if grad_sample_mode == "ghost" and p._norm_sample is not None:
+                assert p._norm_sample.dtype == dtype
+
+        optimizer.step()
+
+    cleanup()
+
+
+class MultiGPUPrecisionTest(unittest.TestCase):
+    @unittest.skipIf(torch.cuda.device_count() < 2, "Need at least 2 GPUs")
+    def setUp(self):
+        self.input_dim = 4
+        self.hidden_dim = 16
+        self.output_dim = 4
+        self.seq_len = 4
+        self.bf16_supported = hasattr(torch, "bfloat16")
+
+    @unittest.skipIf(torch.cuda.device_count() < 2, "Need at least 2 GPUs")
+    def test_precision_training(
+        self,
+    ):
+        """
+        Runs an integration test for distributed training with DPDDP and mixed and low precision training.
+        Tests that model weights, outputs, and gradients are in the expected dtypes.
+        Tests all available grad sample modes.
+        """
+
+        model_kwargs_map = {
+            SimpleLinearReluModel: {
+                "input_dim": self.input_dim,
+                "hidden_dim": self.hidden_dim,
+                "output_dim": self.output_dim,
+            },
+            EmbeddingModel: {
+                "vocab_size": 100,
+                "embedding_dim": self.hidden_dim,
+                "output_dim": self.output_dim,
+            },
+        }
+
+        # test models sequentially since running tests in parallel fails for DDP
+        for model_class in [SimpleLinearReluModel, EmbeddingModel]:
+
+            model_kwargs = model_kwargs_map[model_class]
+            dtype = torch.bfloat16 if self.bf16_supported else torch.float16
+            world_size = 2
+            # test low precision training
+            for grad_sample_mode in ["ew", "functorch", "hooks", "ghost"]:
+                # "ew" is not supported for EmbeddingModel
+                if grad_sample_mode == "ew" and model_class == EmbeddingModel:
+                    continue
+                mp.spawn(
+                    run_low_precision_test,
+                    args=(
+                        world_size,
+                        model_class,
+                        model_kwargs,
+                        dtype,
+                        grad_sample_mode,
+                    ),
+                    nprocs=world_size,
+                    join=True,
+                )
+
+            # test mixed precision training
+            for grad_sample_mode in ["functorch", "hooks", "ghost"]:
+                mp.spawn(
+                    run_mixed_precision_test,
+                    args=(
+                        world_size,
+                        model_class,
+                        model_kwargs,
+                        dtype,
+                        grad_sample_mode,
+                    ),
+                    nprocs=world_size,
+                    join=True,
+                )
+
+
+if __name__ == "__main__":
+    unittest.main()
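
Note: the snippet below is a minimal, illustrative sketch (not part of the patch) of the end-user training loop that the changes above enable: parameters stay in FP32, the forward pass and loss run in a lower precision under torch.amp.autocast, and Opacus clips and noises per-sample gradients, with the final gradients cast up to FP32 in the optimizer. The model, data, and hyperparameters are hypothetical placeholders.

# Illustrative sketch, not part of the patch: mixed precision DP-SGD with Opacus.
# The model, data, and hyperparameters here are hypothetical placeholders.
import torch
import torch.nn as nn
from opacus import PrivacyEngine
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda")
# Parameters stay in FP32 (the default dtype).
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).to(device)
dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,)))
dataloader = DataLoader(dataset, batch_size=8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

privacy_engine = PrivacyEngine()
model, optimizer, dataloader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=dataloader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="hooks",
    poisson_sampling=False,  # mirrors the tests in this patch
)

for x, y in dataloader:
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    # Forward pass and loss run in BF16 under autocast; weights remain FP32.
    with torch.amp.autocast("cuda", dtype=torch.bfloat16):
        loss = criterion(model(x), y)
    loss.backward()
    # Per-sample gradients are clipped and noised; final gradients are FP32.
    optimizer.step()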