Skip to content

Commit 328d758

Browse files
Bump tolerances for per_sample_grads tutorial
1 parent caa0094 commit 328d758

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

intermediate_source/per_sample_grads.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,10 @@ def compute_loss(params, buffers, sample, target):
169169
# results of hand processing each one individually:
170170

171171
for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
172-
assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)
172+
print((per_sample_grad - ft_per_sample_grad).max(), (per_sample_grad - ft_per_sample_grad).mean())
173+
# assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=1.5e-2, rtol=1e-5)
174+
175+
assert False
173176

174177
######################################################################
175178
# A quick note: there are limitations around what types of functions can be

0 commit comments

Comments (0)