Add mixed precision tests to ci_gpu #768

Open · wants to merge 2 commits into main
22 changes: 22 additions & 0 deletions .github/workflows/ci_gpu.yml
@@ -33,6 +33,28 @@ jobs:
# run: |
# python3 -m unittest opacus.tests.multigpu_gradcheck.GradientComputationTest.test_gradient_correct

unittest_mixed_precision:
runs-on: 4-core-ubuntu-gpu-t4
steps:
- name: Checkout
uses: actions/checkout@v2

- name: Display Python version
run: python3 -c "import sys; print(sys.version)"

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.x'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
./scripts/install_via_pip.sh -c

- name: Run mixed precision unit tests
run: |
python3 -m unittest opacus.tests.mixed_precision_test

integrationtest_py39_torch_release_cuda:
runs-on: 4-core-ubuntu-gpu-t4
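
For context, a minimal sketch of the kind of check the new unittest_mixed_precision job presumably runs: a forward/backward pass under autocast through a GradSampleModule-wrapped layer, verifying that per-sample gradients are populated. The real tests live in opacus.tests.mixed_precision_test; the layer choice, sizes, and assertion below are illustrative assumptions, not the test file itself.

    # Hedged sketch, not the actual test suite (requires a CUDA device,
    # matching the GPU runner used by the job).
    import torch
    import torch.nn as nn
    from opacus.grad_sample import GradSampleModule

    model = GradSampleModule(nn.Linear(16, 4)).cuda()
    x = torch.randn(8, 16, device="cuda")

    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(x).sum()
    loss.backward()  # backward runs outside the autocast region, as usual

    for p in model.parameters():
        assert p.grad_sample.shape[0] == 8  # one gradient row per sample
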
3 changes: 3 additions & 0 deletions opacus/grad_sample/conv.py
@@ -40,6 +40,9 @@ def compute_conv_grad_sample(
backprops: Backpropagations
"""
activations = activations[0]

activations = activations.to(backprops.dtype)

n = activations.shape[0]
if n == 0:
# Empty batch
3 changes: 3 additions & 0 deletions opacus/grad_sample/dp_rnn.py
@@ -39,6 +39,9 @@ def compute_rnn_linear_grad_sample(
backprops: Backpropagations
"""
activations = activations[0]

activations = activations.to(backprops.dtype)

ret = {}
if layer.weight.requires_grad:
gs = torch.einsum("n...i,n...j->nij", backprops, activations)
13 changes: 11 additions & 2 deletions opacus/grad_sample/embedding.py
@@ -51,7 +51,10 @@ def compute_embedding_grad_sample(
.reshape(batch_size, -1, layer.embedding_dim)
)
grad_sample = torch.zeros(
batch_size, *layer.weight.shape, device=layer.weight.device
batch_size,
*layer.weight.shape,
device=layer.weight.device,
dtype=backprops.dtype
)
grad_sample.scatter_add_(
1, index, backprops.reshape(batch_size, -1, layer.embedding_dim)
@@ -65,7 +68,13 @@ def compute_embedding_grad_sample(
def compute_embeddingbag_gradsampler(layer, inputs, backprops):
index, offset = inputs
batch_size = offset.shape[0]
gsm = torch.zeros(batch_size, layer.num_embeddings, layer.embedding_dim)
gsm = torch.zeros(
batch_size,
layer.num_embeddings,
layer.embedding_dim,
device=index.device,
dtype=backprops.dtype,
)

for i in range(batch_size):
begin = offset[i]
7 changes: 5 additions & 2 deletions opacus/grad_sample/embedding_norm_sample.py
@@ -131,15 +131,18 @@ def compute_embedding_norm_sample(
# Sum gradients over new index positions and compute squared gradient norms
num_unique_paired_indices = unique_paired_indices.size(0)
summed_gradients = torch.zeros(
num_unique_paired_indices, grad_values.size(-1), device=device
num_unique_paired_indices,
grad_values.size(-1),
device=device,
dtype=grad_values.dtype,
)
summed_gradients = summed_gradients.index_add(
0, new_index_positions.to(device), grad_values
)
sqr_gradient_sum = torch.sum(summed_gradients**2, dim=1)

# Scatter add the squared sums back to their respective rows
result = torch.zeros(nrows, device=device)
result = torch.zeros(nrows, device=device, dtype=grad_values.dtype)
unique_batch_ids = unique_paired_indices[:, 0].to(device)
result.scatter_add_(0, unique_batch_ids, sqr_gradient_sum)

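
Both embedding hunks above allocate their accumulation buffers in the dtype of the incoming gradients (and, for the EmbeddingBag path, on the index device) because scatter_add_ and index_add expect the destination and source tensors to share a dtype; with the float32 default, the float16 gradients produced under autocast would hit a dtype mismatch. A standalone illustration of the allocation pattern, with made-up shapes:

    # Not Opacus code: shows why the buffer is created with the source dtype.
    import torch

    src = torch.ones(2, 3, 4, dtype=torch.float16)    # e.g. backprops under autocast
    index = torch.zeros(2, 3, 4, dtype=torch.long)

    buf = torch.zeros(2, 5, 4, dtype=src.dtype)       # match the source dtype up front
    buf.scatter_add_(1, index, src)                   # a default float32 buffer would raise here
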
23 changes: 18 additions & 5 deletions opacus/grad_sample/functorch.py
@@ -13,6 +13,7 @@
# limitations under the License.

import copy
from contextlib import nullcontext

import torch
import torch.nn as nn
@@ -82,8 +83,19 @@ def compute_loss_stateless_model(params, activations, backprops):
batched_activations = activations.unsqueeze(1)
batched_backprops = backprops.unsqueeze(1)

output = flayer(params, batched_activations)
loss = (output * batched_backprops).sum()
# mixed precision logic
is_mixed = activations.dtype != params[0].dtype
mixed_lowest_dtype = activations.dtype
device_type = activations.device.type

# use amp context if user is using mixed_precision, else proceed as usual
with (
torch.amp.autocast(device_type=device_type, dtype=mixed_lowest_dtype)
if is_mixed
else nullcontext()
):
output = flayer(params, batched_activations)
loss = (output * batched_backprops).sum()
return loss

ft_compute_grad = grad(compute_loss_stateless_model)
@@ -105,9 +117,10 @@ def ft_compute_per_sample_gradient(layer, activations, backprops):
if not hasattr(layer, "ft_compute_sample_grad"):
prepare_layer(layer)

per_sample_grads = layer.ft_compute_sample_grad(
parameters, activations[0], backprops
)
activations = activations[0]
if activations.dtype != backprops.dtype and activations.is_floating_point():
activations = activations.to(backprops.dtype)
per_sample_grads = layer.ft_compute_sample_grad(parameters, activations, backprops)

ret = {}
for i_p, p in enumerate(parameters):
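
The functorch change above enters an autocast region only when the activations arrive in a lower precision than the parameters, and otherwise falls through a nullcontext. The same pattern in isolation (the helper name is illustrative, not part of Opacus):

    from contextlib import nullcontext

    import torch

    def maybe_autocast(activations: torch.Tensor, params):
        # Enter autocast only on a genuine dtype mismatch; a no-op context otherwise.
        if activations.dtype != params[0].dtype:
            return torch.amp.autocast(
                device_type=activations.device.type, dtype=activations.dtype
            )
        return nullcontext()

Inside compute_loss_stateless_model this context wraps the functional forward call, presumably so that the recomputed output is produced in the same low precision as the original AMP forward that generated the backprops.
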
5 changes: 5 additions & 0 deletions opacus/grad_sample/linear.py
@@ -39,6 +39,9 @@ def compute_linear_grad_sample(
backprops: Backpropagations
"""
activations = activations[0]

activations = activations.to(backprops.dtype)

ret = {}
if layer.weight.requires_grad:
gs = torch.einsum("n...i,n...j->nij", backprops, activations)
@@ -61,6 +64,8 @@ def compute_linear_norm_sample(
backprops: Backpropagations
"""
activations = activations[0]
activations = activations.to(backprops.dtype)

ret = {}

if backprops.dim() == 2:
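
The one-line cast added in conv.py, dp_rnn.py, and linear.py above aligns the saved activations with the dtype of the backprops before the per-sample einsum: under autocast the hooks can see float16 activations paired with float32 backprops (or the reverse), and the underlying matmul kernels generally reject mixed float16/float32 inputs. In isolation, with arbitrary sizes:

    # Not Opacus code: the dtype alignment before the per-sample gradient einsum.
    import torch

    backprops = torch.randn(8, 4)                           # gradients w.r.t. the layer output
    activations = torch.randn(8, 16, dtype=torch.float16)   # saved during an autocast forward

    activations = activations.to(backprops.dtype)           # cast to a single dtype
    per_sample_grads = torch.einsum("n...i,n...j->nij", backprops, activations)
    print(per_sample_grads.shape)  # torch.Size([8, 4, 16]): one gradient per sample
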
9 changes: 8 additions & 1 deletion opacus/optimizers/optimizer.py
@@ -135,7 +135,7 @@ def _generate_noise(
`2^53` (easy to break) but with `n=2`, we get `2^159`, which is hard
enough for an attacker to break.
"""
zeros = torch.zeros(reference.shape, device=reference.device)
zeros = torch.zeros(reference.shape, device=reference.device, dtype=reference.dtype)
if std == 0:
return zeros
# TODO: handle device transfers: generator and reference tensor
@@ -165,6 +165,7 @@ def _generate_noise(
size=reference.shape,
device=reference.device,
generator=generator,
dtype=reference.dtype,
)


@@ -451,6 +452,12 @@ def clip_and_accumulate(self):
for p in self.params:
_check_processed_flag(p.grad_sample)
grad_sample = self._get_flat_grad_sample(p)

# gradients should match the dtype of the optimizer parameters
# for mixed precision, optimizer parameters are usually in FP32
# lower precision grads will be cast up to FP32
grad_sample = grad_sample.to(p.dtype)
per_sample_clip_factor = per_sample_clip_factor.to(p.dtype)
grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)

if p.summed_grad is not None:
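
Taken together, the optimizer changes keep everything downstream of the hooks in the parameter dtype: noise is drawn in the dtype of the reference parameter, and lower-precision per-sample gradients (plus their clip factors) are cast up, typically to FP32, before clipping and accumulation. A condensed sketch of that flow under those assumptions (the function name and flat grad_sample layout are made up; this is not the DPOptimizer implementation, and the std == 0 early return is omitted):

    import torch

    def clip_accumulate_and_noise(p, grad_sample, per_sample_clip_factor, std, generator=None):
        # Cast per-sample grads and clip factors to the parameter dtype (usually FP32).
        grad_sample = grad_sample.to(p.dtype)
        per_sample_clip_factor = per_sample_clip_factor.to(p.dtype)
        grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)

        # Draw noise directly in the parameter's dtype and on its device.
        noise = torch.normal(
            mean=0.0,
            std=std,
            size=p.shape,
            device=p.device,
            generator=generator,
            dtype=p.dtype,
        )
        return grad + noise
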