Skip to content

Commit

Permalink
Remove superfluous wrapper function
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Dec 16, 2024
1 parent 9dde421 commit 36e9131
Showing 1 changed file with 3 additions and 23 deletions.
26 changes: 3 additions & 23 deletions matrix_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,10 +639,10 @@ def matrix_eigenvectors(
)

if type(eigenvector_computation_config) is EighEigenvectorConfig:
return _compute_eigenvectors_eigh(
return matrix_eigenvalue_decomposition(
A,
retry_double_precision=eigenvector_computation_config.retry_double_precision,
)
)[1]
elif type(eigenvector_computation_config) is QRConfig:
assert (
eigenvectors_estimate is not None
Expand All @@ -659,26 +659,6 @@ def matrix_eigenvectors(
)


def _compute_eigenvectors_eigh(
A: Tensor, retry_double_precision: bool = True
) -> Tensor:
"""Compute the eigenvectors of a symmetric matrix using torch.linalg.eigh.
Args:
A (Tensor): The symmetric input matrix.
retry_double_precision (bool): Whether to retry the computation in double precision if it fails in the current precision.
(Default: True)
Returns:
Tensor: The eigenvectors of the input matrix A.
"""
return matrix_eigenvalue_decomposition(
A,
retry_double_precision=retry_double_precision,
)[1]


def _compute_orthogonal_iterations(
A: Tensor,
eigenvectors_estimate: Tensor,
Expand All @@ -705,7 +685,7 @@ def _compute_orthogonal_iterations(
"""
if not eigenvectors_estimate.any():
return _compute_eigenvectors_eigh(A)
return matrix_eigenvalue_decomposition(A)[1]

# Perform orthogonal/simultaneous iterations (QR algorithm).
Q = eigenvectors_estimate
Expand Down

0 comments on commit 36e9131

Please sign in to comment.