
Commit 11fcba8

iden-kalemaj authored and facebook-github-bot committed
Support mixed and low precision training (#764)
Summary:
Pull Request resolved: #764

We add support for mixed and low precision training in Opacus. Mixed precision training is supported with the "hooks" and "ghost" grad_sample_modes. Low precision training is additionally supported with "functorch".

Why there is no functorch support for mixed precision training: the backward pass with functorch performs both a forward and a backward pass to compute per-sample gradients. That forward pass happens outside of the torch.amp context, so it cannot handle mixed precision.

Support for low and mixed precision training is GPU dependent.

Differential Revision: D72415906
1 parent f3752c3 commit 11fcba8
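
For orientation, a minimal sketch of what mixed precision DP-SGD training looks like with this change, assuming the standard PrivacyEngine.make_private API and a CUDA device. The model, data, and hyperparameters below are illustrative, not taken from this commit:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

model = nn.Linear(16, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
data_loader = DataLoader(
    TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,))), batch_size=8
)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="hooks",  # per the summary: "hooks" and "ghost" support mixed precision
)

for x, y in data_loader:
    optimizer.zero_grad()
    # forward pass under autocast: activations are computed in FP16
    with torch.amp.autocast("cuda", dtype=torch.float16):
        loss = criterion(model(x.cuda()), y.cuda())
    # backward pass runs outside the autocast context, as torch.amp recommends
    loss.backward()
    optimizer.step()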

File tree

9 files changed: +1318 -13 lines changed


opacus/grad_sample/conv.py

Lines changed: 4 additions & 0 deletions

@@ -40,6 +40,10 @@ def compute_conv_grad_sample(
         backprops: Backpropagations
     """
     activations = activations[0]
+
+    if activations.dtype != backprops.dtype:
+        activations = activations.to(backprops.dtype)
+
     n = activations.shape[0]
     if n == 0:
         # Empty batch
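
The mismatch this cast guards against is easy to reproduce: under autocast, the activation captured by the grad-sample hook is the module input (often still FP32), while the backprops arrive in the autocast compute dtype. A standalone illustration, assuming a CUDA device (not part of the diff):

import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3).cuda()
x = torch.randn(4, 3, 16, 16, device="cuda", requires_grad=True)

with torch.amp.autocast("cuda", dtype=torch.float16):
    out = conv(x)  # the convolution runs in FP16 under autocast

out.sum().backward()
print(x.dtype)    # torch.float32 -- what the forward hook captured
print(out.dtype)  # torch.float16 -- gradients w.r.t. out share this dtype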

opacus/grad_sample/embedding.py

Lines changed: 7 additions & 2 deletions

@@ -51,7 +51,10 @@ def compute_embedding_grad_sample(
             .reshape(batch_size, -1, layer.embedding_dim)
         )
         grad_sample = torch.zeros(
-            batch_size, *layer.weight.shape, device=layer.weight.device
+            batch_size,
+            *layer.weight.shape,
+            device=layer.weight.device,
+            dtype=backprops.dtype
         )
         grad_sample.scatter_add_(
             1, index, backprops.reshape(batch_size, -1, layer.embedding_dim)

@@ -65,7 +68,9 @@ def compute_embedding_grad_sample(
 def compute_embeddingbag_gradsampler(layer, inputs, backprops):
     index, offset = inputs
     batch_size = offset.shape[0]
-    gsm = torch.zeros(batch_size, layer.num_embeddings, layer.embedding_dim)
+    gsm = torch.zeros(
+        batch_size, layer.num_embeddings, layer.embedding_dim, device=index.device
+    )

     for i in range(batch_size):
         begin = offset[i]
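
Note that the compute_embeddingbag_gradsampler change is about device placement rather than dtype: torch.zeros allocates on the CPU by default, so on GPU the per-sample accumulation below it would mix CPU and CUDA tensors. A small repro of the failure mode, assuming a CUDA device (shapes are illustrative):

import torch

index = torch.tensor([1, 2, 4, 5], device="cuda")
emb_grads = torch.randn(2, 4, device="cuda")

gsm_cpu = torch.zeros(2, 8, 4)                       # defaults to the CPU
gsm_gpu = torch.zeros(2, 8, 4, device=index.device)  # matches the inputs

# gsm_cpu[0, index[:2]] += emb_grads  # RuntimeError: indices on a different device
gsm_gpu[0, index[:2]] += emb_grads    # fine: everything lives on the GPU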

opacus/grad_sample/embedding_norm_sample.py

Lines changed: 5 additions & 2 deletions

@@ -131,15 +131,18 @@ def compute_embedding_norm_sample(
     # Sum gradients over new index positions and compute squared gradient norms
     num_unique_paired_indices = unique_paired_indices.size(0)
     summed_gradients = torch.zeros(
-        num_unique_paired_indices, grad_values.size(-1), device=device
+        num_unique_paired_indices,
+        grad_values.size(-1),
+        device=device,
+        dtype=grad_values.dtype,
     )
     summed_gradients = summed_gradients.index_add(
         0, new_index_positions.to(device), grad_values
     )
     sqr_gradient_sum = torch.sum(summed_gradients**2, dim=1)

     # Scatter add the squared sums back to their respective rows
-    result = torch.zeros(nrows, device=device)
+    result = torch.zeros(nrows, device=device, dtype=grad_values.dtype)
     unique_batch_ids = unique_paired_indices[:, 0].to(device)
     result.scatter_add_(0, unique_batch_ids, sqr_gradient_sum)
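
The new dtype= arguments matter because index_add and scatter_add_ require the accumulator and the source to share a dtype, so an FP32 buffer would reject low-precision gradients. A minimal repro (an assumption for illustration; bfloat16 stands in for any low-precision gradient dtype):

import torch

grad_values = torch.randn(2, 4, dtype=torch.bfloat16)  # low-precision gradients
idx = torch.tensor([0, 2])

acc_fp32 = torch.zeros(3, 4)
# acc_fp32.index_add(0, idx, grad_values)  # RuntimeError: mismatched scalar types

acc = torch.zeros(3, 4, dtype=grad_values.dtype)  # mirrors the fix above
summed = acc.index_add(0, idx, grad_values)       # accumulates in bfloat16
print(summed.dtype)  # torch.bfloat16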

opacus/grad_sample/functorch.py

Lines changed: 5 additions & 4 deletions

@@ -81,7 +81,6 @@ def compute_loss_stateless_model(params, activations, backprops):
             # If batch_first is False, the batch dimension is the second dimension
             batched_activations = activations.unsqueeze(1)
             batched_backprops = backprops.unsqueeze(1)
-
         output = flayer(params, batched_activations)
         loss = (output * batched_backprops).sum()
         return loss

@@ -105,9 +104,11 @@ def ft_compute_per_sample_gradient(layer, activations, backprops):
     if not hasattr(layer, "ft_compute_sample_grad"):
         prepare_layer(layer)

-    per_sample_grads = layer.ft_compute_sample_grad(
-        parameters, activations[0], backprops
-    )
+    activations = activations[0]
+    print(activations.dtype)
+    if activations.dtype != backprops.dtype and activations.is_floating_point():
+        activations = activations.to(backprops.dtype)
+    per_sample_grads = layer.ft_compute_sample_grad(parameters, activations, backprops)

     ret = {}
     for i_p, p in enumerate(parameters):
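
The extra is_floating_point() check exists because for some layers the saved activations are not floating point at all: nn.Embedding, for instance, receives int64 indices, and casting those to the backprops' dtype would break the lookup. A small illustration (shapes are assumptions):

import torch

indices = torch.tensor([[1, 3, 2]])  # Embedding "activations" are int64 indices
backprops = torch.randn(1, 3, 8, dtype=torch.float16)

if indices.dtype != backprops.dtype and indices.is_floating_point():
    indices = indices.to(backprops.dtype)  # skipped: ints are left untouched

print(indices.dtype)  # torch.int64 -- still usable as lookup indices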

opacus/grad_sample/linear.py

Lines changed: 6 additions & 1 deletion

@@ -39,6 +39,9 @@ def compute_linear_grad_sample(
         backprops: Backpropagations
     """
     activations = activations[0]
+    if activations.dtype != backprops.dtype:
+        activations = activations.to(backprops.dtype)
+
     ret = {}
     if layer.weight.requires_grad:
         gs = torch.einsum("n...i,n...j->nij", backprops, activations)

@@ -61,8 +64,10 @@ def compute_linear_norm_sample(
         backprops: Backpropagations
     """
     activations = activations[0]
-    ret = {}
+    if activations.dtype != backprops.dtype:
+        activations = activations.to(backprops.dtype)

+    ret = {}
     if backprops.dim() == 2:
         if layer.weight.requires_grad:
             g = torch.einsum("n...i,n...i->n", backprops, backprops)
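
The cast is needed because torch.einsum, like the matmul kernels underneath it, does not promote across floating dtypes, so FP16 or BF16 backprops combined with FP32 activations would fail. A minimal repro (an assumption for illustration, using bfloat16 so it runs on CPU):

import torch

backprops = torch.randn(4, 8, dtype=torch.bfloat16)  # low-precision output grads
activations = torch.randn(4, 16)                     # FP32 captured inputs

# torch.einsum("n...i,n...j->nij", backprops, activations)  # RuntimeError: mismatched dtypes

activations = activations.to(backprops.dtype)  # mirrors the fix above
gs = torch.einsum("n...i,n...j->nij", backprops, activations)
print(gs.shape, gs.dtype)  # torch.Size([4, 8, 16]) torch.bfloat16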

opacus/optimizers/optimizer.py

Lines changed: 10 additions & 4 deletions

@@ -134,7 +134,7 @@ def _generate_noise(
     `2^53` (easy to break) but with `n=2`, we get `2^159`, which is hard
     enough for an attacker to break.
     """
-    zeros = torch.zeros(reference.shape, device=reference.device)
+    zeros = torch.zeros(reference.shape, device=reference.device, dtype=reference.dtype)
     if std == 0:
         return zeros
     # TODO: handle device transfers: generator and reference tensor

@@ -164,6 +164,7 @@ def _generate_noise(
         size=reference.shape,
         device=reference.device,
         generator=generator,
+        dtype=reference.dtype,
     )


@@ -432,7 +433,6 @@ def clip_and_accumulate(self):
         Performs gradient clipping.
         Stores clipped and aggregated gradients into ``p.summed_grad``
         """
-
         if len(self.grad_samples[0]) == 0:
             # Empty batch
             per_sample_clip_factor = torch.zeros(

@@ -450,6 +450,14 @@ def clip_and_accumulate(self):
         for p in self.params:
             _check_processed_flag(p.grad_sample)
             grad_sample = self._get_flat_grad_sample(p)
+
+            # gradients should match the dtype of the optimizer parameters
+            # for mixed precision, optimizer parameters are usually in FP32
+            # lower precision grads will be cast up to FP32
+            if grad_sample.dtype != p.dtype:
+                grad_sample = grad_sample.to(p.dtype)
+            if per_sample_clip_factor.dtype != p.dtype:
+                per_sample_clip_factor = per_sample_clip_factor.to(p.dtype)
             grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)

             if p.summed_grad is not None:

@@ -463,7 +471,6 @@ def add_noise(self):
         """
         Adds noise to clipped gradients. Stores clipped and noised result in ``p.grad``
         """
-
         for p in self.params:
             _check_processed_flag(p.summed_grad)

@@ -474,7 +481,6 @@ def add_noise(self):
                 secure_mode=self.secure_mode,
             )
             p.grad = (p.summed_grad + noise).view_as(p)
-
             _mark_as_processed(p.summed_grad)

     def scale_grad(self):
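
Two dtype rules drive these changes: the noise must match the dtype of the gradient it is added to (otherwise type promotion yields an FP32 sum that cannot be assigned to a lower-precision p.grad), and per-sample gradients are cast up to the optimizer parameters' dtype, which under mixed precision is typically the FP32 master copy. A standalone sketch of the noise half, mirroring the _generate_noise change (values are illustrative):

import torch

summed_grad = torch.randn(10, dtype=torch.bfloat16)  # clipped DP gradient, low precision

noise = torch.normal(
    mean=0.0,
    std=1.0,
    size=summed_grad.shape,
    device=summed_grad.device,
    dtype=summed_grad.dtype,  # without this, FP32 noise would promote the sum to FP32
)

grad = summed_grad + noise
print(grad.dtype)  # torch.bfloat16 -- assignable to a bfloat16 parameter's .grad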

0 commit comments