Update tolerances for 8bit optimizer tests
matthewdouglas committed Sep 20, 2024
commit f8d206f (1 parent: c2e7749)
1 changed file: tests/test_optim.py (9 additions, 10 deletions)
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -344,7 +344,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
     errors = []
     relerrors = []
 
-    for i in range(50):
+    for i in range(100):
         g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
         p1.grad = g.clone().float()
         p2.grad = g.clone()
@@ -353,8 +353,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
         torch_optimizer.step()
 
         # since Lion can have pretty noisy updates where things lie at the boundary
-        # and AdEMAMix can diverge as well, allow up to 0.05% errors.
-        assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=int(p1.numel() * 5e-4))
+        assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)
 
         dequant_states = []
         for name1, name2, qmap, max_val in str2statenames[optim_name]:
@@ -398,11 +397,11 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
         err = torch.abs(p1 - p2)
         relerr = err / (torch.abs(p1) + 1e-9)
         if g.dtype == torch.bfloat16:
-            assert err.mean() < 0.00015
-            assert relerr.mean() < 0.0020  # 0.0016
+            assert err.mean() <= 0.00017
+            assert relerr.mean() <= 0.0016
         else:
-            assert err.mean() < 0.00016  # 0.00012
-            assert relerr.mean() < 0.0016  # 0.0012
+            assert err.mean() < 0.00006
+            assert relerr.mean() < 0.0006
 
         errors.append(err.mean().item())
         relerrors.append(relerr.mean().item())
@@ -460,9 +459,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
 
                 num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
                 assert num_not_close.sum().item() < 20
-            # since Lion can have pretty noisy updates where things lie at the boundary
-            # and AdEMAMix can also be noisy, allow up to 0.05%.
-            assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=int(p1.numel() * 5e-04))
+
+            # Lion can have pretty noisy updates where things lie at the boundary
+            assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)
 
             # the parameters diverge quickly. Here we keep them close
             # together so we can test against the Adam error
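For context, assert_most_approx_close is a helper from the repository's test suite that tolerates a bounded number of elementwise mismatches before failing, so max_error_count=0 makes the parameter comparison fully strict. The sketch below only illustrates that behavior and is not the repository's exact implementation; the signature and the fallback to torch.testing.assert_close are assumptions.

import torch

def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
    # Sketch: count elements that fall outside the given tolerances.
    not_close = ~torch.isclose(a, b, rtol=rtol, atol=atol)
    error_count = int(not_close.sum().item())
    if error_count > max_error_count:
        # Assumed fallback: re-run the strict check so the failure report
        # shows which elements mismatch and by how much.
        print(f"Too many values not close: {error_count} > {max_error_count}")
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)

Read this way, the diff replaces a mismatch budget of roughly 0.05% of the parameter elements (int(p1.numel() * 5e-4)) with a budget of zero, while the per-element tolerances patol and prtol are left unchanged.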
