
Commit f2b44aa

iden-kalemaj authored and facebook-github-bot committed
Support mixed and low precision training (#764)
Summary:
We add support for mixed and low precision training in Opacus. Mixed precision training is supported with the "hooks" and "ghost" grad_sample_modes. Low precision training is additionally supported with the "functorch" and "ew" modes.

Why there is no functorch support for mixed precision training: the functorch backward pass performs both a forward and a backward pass to compute per-sample gradients. That forward pass happens outside of the torch.amp context, so it cannot handle mixed precision.

Support for low and mixed precision training is GPU dependent.

Differential Revision: D72415906
1 parent f3752c3 commit f2b44aa
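
For context, a minimal sketch of how mixed precision DP training can be driven end to end after this change. This example is not part of the commit; the model, data, hyperparameters, and the choice of grad_sample_mode="hooks" are illustrative assumptions, and a CUDA device is required.

```python
# Sketch only: FP32 parameters, autocast forward pass, Opacus computing
# per-sample gradients via the "hooks" grad_sample_mode.
import torch
import torch.nn.functional as F
from opacus import PrivacyEngine

model = torch.nn.Linear(16, 2).cuda()  # parameters stay in FP32
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = torch.utils.data.TensorDataset(
    torch.randn(32, 16), torch.randint(0, 2, (32,))
)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=8)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="hooks",  # per the summary, "ghost" also supports mixed precision
)

for x, y in data_loader:
    x, y = x.cuda(), y.cuda()
    optimizer.zero_grad()
    with torch.amp.autocast("cuda", dtype=torch.float16):
        loss = F.cross_entropy(model(x), y)  # forward runs in FP16
    loss.backward()  # per-sample grads may be FP16; clipping casts them up to FP32
    optimizer.step()
```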

File tree: 8 files changed, +987 -8 lines changed


opacus/grad_sample/conv.py

Lines changed: 4 additions & 0 deletions

@@ -40,6 +40,10 @@ def compute_conv_grad_sample(
         backprops: Backpropagations
     """
     activations = activations[0]
+
+    if activations.dtype != backprops.dtype:
+        activations = activations.to(backprops.dtype)
+
     n = activations.shape[0]
     if n == 0:
         # Empty batch

opacus/grad_sample/embedding.py

Lines changed: 11 additions & 2 deletions

@@ -51,7 +51,10 @@ def compute_embedding_grad_sample(
         .reshape(batch_size, -1, layer.embedding_dim)
     )
     grad_sample = torch.zeros(
-        batch_size, *layer.weight.shape, device=layer.weight.device
+        batch_size,
+        *layer.weight.shape,
+        device=layer.weight.device,
+        dtype=backprops.dtype
     )
     grad_sample.scatter_add_(
         1, index, backprops.reshape(batch_size, -1, layer.embedding_dim)
@@ -65,7 +68,13 @@ def compute_embedding_grad_sample(
 def compute_embeddingbag_gradsampler(layer, inputs, backprops):
     index, offset = inputs
     batch_size = offset.shape[0]
-    gsm = torch.zeros(batch_size, layer.num_embeddings, layer.embedding_dim)
+    gsm = torch.zeros(
+        batch_size,
+        layer.num_embeddings,
+        layer.embedding_dim,
+        device=index.device,
+        dtype=backprops.dtype,
+    )

     for i in range(batch_size):
         begin = offset[i]
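
Why dtype=backprops.dtype matters here: scatter_add_ requires self and src to share a dtype, and under autocast the backprops can arrive in FP16 while a plain torch.zeros(...) buffer defaults to FP32. A standalone illustration, not taken from the commit:

```python
# Illustration (not part of the diff): scatter_add_ raises a RuntimeError
# if the destination buffer and the source tensor have different dtypes.
import torch

backprops = torch.randn(2, 3, dtype=torch.float16)  # e.g. produced under autocast
index = torch.zeros(2, 3, dtype=torch.long)

# grad_sample = torch.zeros(2, 3)                 # FP32 default: would raise
grad_sample = torch.zeros(2, 3, dtype=backprops.dtype)  # mirrors the fix above
grad_sample.scatter_add_(1, index, backprops)
print(grad_sample.dtype)  # torch.float16
```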

opacus/grad_sample/embedding_norm_sample.py

Lines changed: 5 additions & 2 deletions

@@ -131,15 +131,18 @@ def compute_embedding_norm_sample(
     # Sum gradients over new index positions and compute squared gradient norms
     num_unique_paired_indices = unique_paired_indices.size(0)
     summed_gradients = torch.zeros(
-        num_unique_paired_indices, grad_values.size(-1), device=device
+        num_unique_paired_indices,
+        grad_values.size(-1),
+        device=device,
+        dtype=grad_values.dtype,
     )
     summed_gradients = summed_gradients.index_add(
         0, new_index_positions.to(device), grad_values
     )
     sqr_gradient_sum = torch.sum(summed_gradients**2, dim=1)

     # Scatter add the squared sums back to their respective rows
-    result = torch.zeros(nrows, device=device)
+    result = torch.zeros(nrows, device=device, dtype=grad_values.dtype)
     unique_batch_ids = unique_paired_indices[:, 0].to(device)
     result.scatter_add_(0, unique_batch_ids, sqr_gradient_sum)
opacus/grad_sample/functorch.py

Lines changed: 4 additions & 3 deletions

@@ -105,9 +105,10 @@ def ft_compute_per_sample_gradient(layer, activations, backprops):
     if not hasattr(layer, "ft_compute_sample_grad"):
         prepare_layer(layer)

-    per_sample_grads = layer.ft_compute_sample_grad(
-        parameters, activations[0], backprops
-    )
+    activations = activations[0]
+    if activations.dtype != backprops.dtype and activations.is_floating_point():
+        activations = activations.to(backprops.dtype)
+    per_sample_grads = layer.ft_compute_sample_grad(parameters, activations, backprops)

     ret = {}
     for i_p, p in enumerate(parameters):
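
The extra is_floating_point() check here appears intended to protect layers whose activations are integer tensors (for example embedding indices), which must never be cast to a floating dtype. The snippet below is an illustrative assumption about that intent, not code from the commit:

```python
# Illustration (assumed rationale, not part of the diff): integer
# activations keep their dtype because the cast is gated on
# is_floating_point().
import torch

activations = torch.tensor([0, 2, 1])             # integer indices
backprops = torch.randn(3, 4, dtype=torch.float16)

if activations.dtype != backprops.dtype and activations.is_floating_point():
    activations = activations.to(backprops.dtype)  # skipped for integer input

assert activations.dtype == torch.long             # indices survive intact
```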

opacus/grad_sample/linear.py

Lines changed: 6 additions & 0 deletions

@@ -39,6 +39,9 @@ def compute_linear_grad_sample(
         backprops: Backpropagations
     """
     activations = activations[0]
+    if activations.dtype != backprops.dtype:
+        activations = activations.to(backprops.dtype)
+
     ret = {}
     if layer.weight.requires_grad:
         gs = torch.einsum("n...i,n...j->nij", backprops, activations)
@@ -61,6 +64,9 @@ def compute_linear_norm_sample(
         backprops: Backpropagations
     """
     activations = activations[0]
+    if activations.dtype != backprops.dtype:
+        activations = activations.to(backprops.dtype)
+
     ret = {}

     if backprops.dim() == 2:
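
The casts in these two functions are needed because torch.einsum rejects mixed operand dtypes, which is exactly what autocast produces when saved activations are FP16 but backprops come back in FP32 (or vice versa). An illustrative reproduction, not from the commit:

```python
# Illustration (not part of the diff): einsum raises on mixed dtypes, so
# activations are cast to match backprops before the per-sample contraction.
import torch

activations = torch.randn(4, 3, dtype=torch.float16)  # saved under autocast
backprops = torch.randn(4, 2, dtype=torch.float32)

# torch.einsum("n...i,n...j->nij", backprops, activations)  # would raise
activations = activations.to(backprops.dtype)  # mirrors the fix above
gs = torch.einsum("n...i,n...j->nij", backprops, activations)
print(gs.shape, gs.dtype)  # torch.Size([4, 2, 3]) torch.float32
```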

opacus/optimizers/optimizer.py

Lines changed: 10 additions & 1 deletion

@@ -134,7 +134,7 @@ def _generate_noise(
     `2^53` (easy to break) but with `n=2`, we get `2^159`, which is hard
     enough for an attacker to break.
     """
-    zeros = torch.zeros(reference.shape, device=reference.device)
+    zeros = torch.zeros(reference.shape, device=reference.device, dtype=reference.dtype)
     if std == 0:
         return zeros
     # TODO: handle device transfers: generator and reference tensor
@@ -164,6 +164,7 @@ def _generate_noise(
         size=reference.shape,
         device=reference.device,
         generator=generator,
+        dtype=reference.dtype,
     )


@@ -450,6 +451,14 @@ def clip_and_accumulate(self):
         for p in self.params:
             _check_processed_flag(p.grad_sample)
             grad_sample = self._get_flat_grad_sample(p)
+
+            # gradients should match the dtype of the optimizer parameters
+            # for mixed precision, optimizer parameters are usually in FP32
+            # lower precision grads will be cast up to FP32
+            if grad_sample.dtype != p.dtype:
+                grad_sample = grad_sample.to(p.dtype)
+            if per_sample_clip_factor.dtype != p.dtype:
+                per_sample_clip_factor = per_sample_clip_factor.to(p.dtype)
             grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)

             if p.summed_grad is not None:
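
With dtype=reference.dtype threaded through _generate_noise, the DP noise matches the parameter dtype instead of defaulting to FP32. A quick illustrative check, not from the commit's tests:

```python
# Illustration (not part of the diff): noise generated for an FP16
# reference tensor is itself FP16 once dtype is passed through.
import torch

reference = torch.zeros(10, dtype=torch.float16)
noise = torch.normal(
    mean=0.0,
    std=1.0,
    size=reference.shape,
    device=reference.device,
    dtype=reference.dtype,
)
assert noise.dtype == reference.dtype
```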
