Skip to content

Commit

Permalink
Update test for adaptive eigenbasis computation frequency criterion
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Feb 23, 2025
1 parent 6546d54 commit 8411ed6
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,39 +1101,47 @@ def test_adaptive_amortized_computation_frequency_criterion_below_tolerance(
self,
) -> None:
"""Test adaptive amortized computation frequency criterion."""
test_factor_matrix = torch.tensor(
test_factor_matrix_triu = torch.tensor(
[
[0.1728, 0.7989, -1.3391, -0.2319, -0.6411],
[1.5930, 0.1929, -0.3534, 0.9371, -1.7551],
[1.4669, -0.2361, 0.5761, 0.6107, 0.6555],
[0.2715, -1.1916, 0.8504, -0.2460, 1.0288],
[-0.1769, 1.0129, 0.1652, -0.7164, -1.2969],
],
)
test_factor_matrix_eigenvectors = torch.linalg.eigh(
).triu()
test_factor_matrix = test_factor_matrix_triu + test_factor_matrix_triu.T
test_factor_matrix_eigenvectors: Tensor = torch.linalg.eigh(
test_factor_matrix
).eigenvectors
criterion = torch.linalg.matrix_norm(
(
test_factor_matrix_eigenvectors.T
@ test_factor_matrix
@ test_factor_matrix_eigenvectors
).fill_diagonal_(0.0) # Off-diagonal elements only.
) / (1 + torch.linalg.matrix_norm(test_factor_matrix))

# Compute the norm of the approximate eigenvectors and their off-diagonal.
approximate_eigenvalues = (
test_factor_matrix_eigenvectors.T
@ test_factor_matrix
@ test_factor_matrix_eigenvectors
)
norm = torch.linalg.matrix_norm(approximate_eigenvalues)
off_diagonal_norm = torch.linalg.matrix_norm(
approximate_eigenvalues.fill_diagonal_(0.0) # Off-diagonal elements only.
)
# off_diagonal_norm / norm = 1.8125589917872276e-07.

# Above tolerance.
test_tolerance1 = 1e-1
test_tolerance1 = 1e-7
test_criterion1 = EigenvalueCorrectedShampooPreconditionerList._adaptive_amortized_computation_frequency_criterion_below_or_equal_tolerance(
test_factor_matrix, test_factor_matrix_eigenvectors, test_tolerance1
)
assert (criterion <= test_tolerance1) == test_criterion1 # criterion=False.
assert not test_criterion1
assert not (off_diagonal_norm <= test_tolerance1 * norm)

# Below tolerance.
test_tolerance2 = 1
test_tolerance2 = 2e-7
test_criterion2 = EigenvalueCorrectedShampooPreconditionerList._adaptive_amortized_computation_frequency_criterion_below_or_equal_tolerance(
test_factor_matrix, test_factor_matrix_eigenvectors, test_tolerance2
)
assert (criterion <= test_tolerance2) == test_criterion2 # criterion=True.
assert test_criterion2
assert off_diagonal_norm <= test_tolerance2 * norm

def test_adaptive_amortized_computation_frequency(self):
# Setup the preconditioner list with the adaptive amortized computation frequency.
Expand Down

0 comments on commit 8411ed6

Please sign in to comment.