Commit 414f938

iden-kalemaj authored and facebook-github-bot committed
Support mixed and low precision training (#764)
Summary: Pull Request resolved: #764

We add support for mixed and low precision training in Opacus. Mixed precision training is supported with the "hooks", "ghost", and "functorch" grad_sample_modes. Low precision training is additionally supported with the "ew" mode. Support for low and mixed precision training is GPU dependent.

Differential Revision: D72415906
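Below is a minimal usage sketch of what this commit enables; it is not part of the commit itself. The training loop runs the forward pass and loss under torch.amp.autocast while parameters, per-sample gradient clipping, and noise stay in FP32. The model, data, and hyperparameters are illustrative placeholders; only the PrivacyEngine.make_private call with grad_sample_mode reflects the existing Opacus API.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

# Hedged sketch (not code from this commit): mixed precision DP-SGD with Opacus.
# Assumes a CUDA device; model, data, and hyperparameters are placeholders.
device = "cuda"
model = nn.Linear(16, 2).to(device)                      # FP32 parameters
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
data_loader = DataLoader(dataset, batch_size=8)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="hooks",  # per the summary, "ghost" and "functorch" also support mixed precision
)

for x, y in data_loader:
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    # Forward pass and loss in reduced precision; parameters remain FP32.
    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
        loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()   # per-sample gradients captured by the Opacus hooks
    optimizer.step()  # clipping and noise are applied in the parameter dtype (FP32)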
1 parent ef04ad9 commit 414f938

File tree

8 files changed: +975 -10 lines changed


opacus/grad_sample/conv.py

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,9 @@ def compute_conv_grad_sample(
         backprops: Backpropagations
     """
     activations = activations[0]
+
+    activations = activations.to(backprops.dtype)
+
     n = activations.shape[0]
     if n == 0:
         # Empty batch

opacus/grad_sample/embedding.py

Lines changed: 11 additions & 2 deletions
@@ -51,7 +51,10 @@ def compute_embedding_grad_sample(
             .reshape(batch_size, -1, layer.embedding_dim)
         )
         grad_sample = torch.zeros(
-            batch_size, *layer.weight.shape, device=layer.weight.device
+            batch_size,
+            *layer.weight.shape,
+            device=layer.weight.device,
+            dtype=backprops.dtype
         )
         grad_sample.scatter_add_(
             1, index, backprops.reshape(batch_size, -1, layer.embedding_dim)
@@ -65,7 +68,13 @@ def compute_embedding_grad_sample(
 def compute_embeddingbag_gradsampler(layer, inputs, backprops):
     index, offset = inputs
     batch_size = offset.shape[0]
-    gsm = torch.zeros(batch_size, layer.num_embeddings, layer.embedding_dim)
+    gsm = torch.zeros(
+        batch_size,
+        layer.num_embeddings,
+        layer.embedding_dim,
+        device=index.device,
+        dtype=backprops.dtype,
+    )

     for i in range(batch_size):
         begin = offset[i]

opacus/grad_sample/embedding_norm_sample.py

Lines changed: 5 additions & 2 deletions
@@ -131,15 +131,18 @@ def compute_embedding_norm_sample(
     # Sum gradients over new index positions and compute squared gradient norms
     num_unique_paired_indices = unique_paired_indices.size(0)
     summed_gradients = torch.zeros(
-        num_unique_paired_indices, grad_values.size(-1), device=device
+        num_unique_paired_indices,
+        grad_values.size(-1),
+        device=device,
+        dtype=grad_values.dtype,
     )
     summed_gradients = summed_gradients.index_add(
         0, new_index_positions.to(device), grad_values
     )
     sqr_gradient_sum = torch.sum(summed_gradients**2, dim=1)

     # Scatter add the squared sums back to their respective rows
-    result = torch.zeros(nrows, device=device)
+    result = torch.zeros(nrows, device=device, dtype=grad_values.dtype)
     unique_batch_ids = unique_paired_indices[:, 0].to(device)
     result.scatter_add_(0, unique_batch_ids, sqr_gradient_sum)

opacus/grad_sample/functorch.py

Lines changed: 18 additions & 5 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.

 import copy
+from contextlib import nullcontext

 import torch
 import torch.nn as nn
@@ -82,8 +83,19 @@ def compute_loss_stateless_model(params, activations, backprops):
         batched_activations = activations.unsqueeze(1)
         batched_backprops = backprops.unsqueeze(1)

-        output = flayer(params, batched_activations)
-        loss = (output * batched_backprops).sum()
+        # mixed precision logic
+        is_mixed = activations.dtype != params[0].dtype
+        mixed_lowest_dtype = activations.dtype
+        device_type = activations.device.type
+
+        # use amp context if user is using mixed_precision, else proceed as usual
+        with (
+            torch.amp.autocast(device_type=device_type, dtype=mixed_lowest_dtype)
+            if is_mixed
+            else nullcontext()
+        ):
+            output = flayer(params, batched_activations)
+            loss = (output * batched_backprops).sum()
         return loss

     ft_compute_grad = grad(compute_loss_stateless_model)
@@ -105,9 +117,10 @@ def ft_compute_per_sample_gradient(layer, activations, backprops):
     if not hasattr(layer, "ft_compute_sample_grad"):
         prepare_layer(layer)

-    per_sample_grads = layer.ft_compute_sample_grad(
-        parameters, activations[0], backprops
-    )
+    activations = activations[0]
+    if activations.dtype != backprops.dtype and activations.is_floating_point():
+        activations = activations.to(backprops.dtype)
+    per_sample_grads = layer.ft_compute_sample_grad(parameters, activations, backprops)

     ret = {}
     for i_p, p in enumerate(parameters):

opacus/grad_sample/linear.py

Lines changed: 6 additions & 0 deletions
@@ -39,6 +39,9 @@ def compute_linear_grad_sample(
         backprops: Backpropagations
     """
     activations = activations[0]
+
+    activations = activations.to(backprops.dtype)
+
     ret = {}
     if layer.weight.requires_grad:
         gs = torch.einsum("n...i,n...j->nij", backprops, activations)
@@ -61,6 +64,9 @@ def compute_linear_norm_sample(
         backprops: Backpropagations
     """
     activations = activations[0]
+    if activations.dtype != backprops.dtype:
+        activations = activations.to(backprops.dtype)
+
     ret = {}

     if backprops.dim() == 2:

opacus/optimizers/optimizer.py

Lines changed: 8 additions & 1 deletion
@@ -134,7 +134,7 @@ def _generate_noise(
         `2^53` (easy to break) but with `n=2`, we get `2^159`, which is hard
         enough for an attacker to break.
     """
-    zeros = torch.zeros(reference.shape, device=reference.device)
+    zeros = torch.zeros(reference.shape, device=reference.device, dtype=reference.dtype)
     if std == 0:
         return zeros
     # TODO: handle device transfers: generator and reference tensor
@@ -164,6 +164,7 @@ def _generate_noise(
         size=reference.shape,
         device=reference.device,
         generator=generator,
+        dtype=reference.dtype,
     )

@@ -450,6 +451,12 @@ def clip_and_accumulate(self):
         for p in self.params:
             _check_processed_flag(p.grad_sample)
             grad_sample = self._get_flat_grad_sample(p)
+
+            # gradients should match the dtype of the optimizer parameters
+            # for mixed precision, optimizer parameters are usually in FP32
+            # lower precision grads will be cast up to FP32
+            grad_sample = grad_sample.to(p.dtype)
+            per_sample_clip_factor = per_sample_clip_factor.to(p.dtype)
             grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)

             if p.summed_grad is not None:
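The clip_and_accumulate change above casts both the per-sample gradients and the clip factors to the parameter dtype before the einsum aggregation. A standalone sketch of that pattern (not Opacus code; shapes and values are illustrative): FP16 per-sample gradients are clipped against an FP32 parameter, so both einsum operands are first cast up to FP32.

import torch

# Standalone sketch (not Opacus code): clip FP16 per-sample gradients against an
# FP32 parameter, mirroring the casts added in clip_and_accumulate above.
max_grad_norm = 1.0
param = torch.zeros(4, dtype=torch.float32)            # optimizer parameter kept in FP32
grad_sample = torch.randn(8, 4, dtype=torch.float16)   # per-sample grads from a low precision backward

per_sample_norms = grad_sample.to(torch.float32).norm(2, dim=1)
per_sample_clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(max=1.0)

# Cast both operands up to the parameter dtype so the einsum sees matching dtypes
# and the clipped, aggregated gradient lands in FP32.
grad_sample = grad_sample.to(param.dtype)
per_sample_clip_factor = per_sample_clip_factor.to(param.dtype)
grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
assert grad.dtype == param.dtype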
