diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py
index ec1a938..990a5f9 100644
--- a/distributed_shampoo/distributed_shampoo.py
+++ b/distributed_shampoo/distributed_shampoo.py
@@ -31,6 +31,7 @@
     GraftingType,
     LargeDimMethod,
     ShampooPreconditioner,
+    RootInvMethod
 )
 from torch.nn import Parameter
@@ -62,7 +63,7 @@ class DistributedShampoo(torch.optim.Optimizer):
         Hao-Jun Michael Shi (Meta Platforms, Inc.)
         Tsung-Hsien Lee
         Shintaro Iwasaki (Meta Platforms, Inc.)
-        Jose Gallego-Posada (MILA / Meta Platforms, Inc.)
+        Jose Gallego-Posada (Mila / Meta Platforms, Inc.)

     with contributions and support from:

@@ -181,6 +182,8 @@
         precondition_frequency (int): frequency for computing root inverse preconditioner (Default: 1)
         start_preconditioning_step (int): iteration to start computing inverse preconditioner. If -1, uses the same
             value as precondition_frequency. (Default: -1)
+        root_inv_method (RootInvMethod): Strategy for computing (inverse) preconditioner roots. (Default:
+            RootInvMethod.PSEUDO_EIGEN)
         exponent_override (int, List[int]): inverse root to use in Shampoo. If a list [l1, l2, ..., lp], then we will
             use -1 / l1 for 1-D tensor (vectors), -1 / l2 for 2-D tensors (matrices), and so on. If the order of the
             tensor exceeds the order of the tensor, reverts to the default value. If 0 is used, uses the default inverse
@@ -221,6 +224,7 @@ def __init__(
         max_preconditioner_dim: int = 1024,
         precondition_frequency: int = 1,
         start_preconditioning_step: int = -1,
+        root_inv_method: RootInvMethod = RootInvMethod.PSEUDO_EIGEN,
         exponent_override: Union[int, List[int]] = 0,
         exponent_multiplier: float = 1.0,
         use_nesterov: bool = False,
@@ -338,6 +342,7 @@ def __init__(
         # Initialize algorithm-related fields.
         self._max_preconditioner_dim = max_preconditioner_dim
         self._precondition_frequency = precondition_frequency
+        self._root_inv_method = root_inv_method
         self._exponent_override = exponent_override
         self._exponent_multiplier = exponent_multiplier
         self._num_trainers_per_group = num_trainers_per_group
@@ -366,7 +371,9 @@ def __init__(
             logger.warning(
                 f"start_preconditioning_step set to -1. Setting start_preconditioning_step equal to precondition frequency {precondition_frequency} by default."
             )
-        elif start_preconditioning_step < precondition_frequency:
+        elif 0 < start_preconditioning_step < precondition_frequency:
+            # If start_preconditioning_step==0, we allow an initial preconditioned update
+            # with the first estimate of the preconditioner.
             raise ValueError(
                 f"Invalid start_preconditioning_step value: {start_preconditioning_step}. Must be >= {precondition_frequency = }."
             )
@@ -443,6 +450,7 @@ def _initialize_preconditioners_and_steps(
                     use_protected_eigh=self._use_protected_eigh,
                     use_dtensor=self._use_dtensor,
                     communication_dtype=self._communication_dtype,
+                    root_inv_method=self._root_inv_method,
                 )
                 preconditioner_count += len(
                     state[PRECONDITIONERS].get_split_dist_buffers()
@@ -469,6 +477,7 @@ def _initialize_preconditioners_and_steps(
                         dist_buffer=dist_buffer,
                         use_dtensor=self._use_dtensor,
                         communication_dtype=self._communication_dtype,
+                        root_inv_method=self._root_inv_method,
                     )
                     if torch.any(dims > self._max_preconditioner_dim)
                     else ShampooPreconditioner(
@@ -491,6 +500,7 @@ def _initialize_preconditioners_and_steps(
                         use_protected_eigh=self._use_protected_eigh,
                         use_dtensor=self._use_dtensor,
                         communication_dtype=self._communication_dtype,
+                        root_inv_method=self._root_inv_method,
                     )
                 )

@@ -524,6 +534,7 @@ def _initialize_preconditioners_and_steps(
                 use_protected_eigh=self._use_protected_eigh,
                 use_dtensor=self._use_dtensor,
                 communication_dtype=self._communication_dtype,
+                root_inv_method=self._root_inv_method,
             )

         else:
@@ -814,10 +825,8 @@ def step(self, closure=None):

         # Computes root inverse of all preconditioners every self._precondition_frequency
         # after the self._start_preconditioning_step iteration.
-        if (
-            iteration % self._precondition_frequency == 0
-            and iteration >= self._start_preconditioning_step
-        ):
+        is_precond_step = (iteration == 1) or (iteration % self._precondition_frequency == 0)
+        if (is_precond_step and iteration >= self._start_preconditioning_step):
             self._compute_root_inverse()

             if self._debug_mode:
diff --git a/distributed_shampoo/matrix_functions.py b/distributed_shampoo/matrix_functions.py
index f3aaf0c..beaaacb 100644
--- a/distributed_shampoo/matrix_functions.py
+++ b/distributed_shampoo/matrix_functions.py
@@ -25,6 +25,7 @@ class NewtonConvergenceFlag(enum.Enum):
 class RootInvMethod(enum.Enum):
     EIGEN = 0
     NEWTON = 1
+    PSEUDO_EIGEN = 2


 def check_diagonal(A: Tensor) -> Tensor:
@@ -46,7 +47,7 @@ def matrix_inverse_root(
     root: int,
     epsilon: float = 0.0,
     exponent_multiplier: float = 1.0,
-    root_inv_method: RootInvMethod = RootInvMethod.EIGEN,
+    root_inv_method: RootInvMethod = RootInvMethod.PSEUDO_EIGEN,
     max_iterations: int = 1000,
     tolerance: float = 1e-6,
     is_diagonal: Union[Tensor, bool] = False,
@@ -91,8 +92,9 @@ def matrix_inverse_root(
             inverse=True,
             exponent_multiplier=exponent_multiplier,
             return_full_matrix=True,
+            use_pseudo_inverse=(root_inv_method == RootInvMethod.PSEUDO_EIGEN),
         )
-    elif root_inv_method == RootInvMethod.EIGEN:
+    elif root_inv_method in [RootInvMethod.EIGEN, RootInvMethod.PSEUDO_EIGEN]:
         X, _, _ = _matrix_root_eigen(
             A=A,
             root=root,
@@ -100,6 +102,7 @@ def matrix_inverse_root(
             inverse=True,
             exponent_multiplier=exponent_multiplier,
             retry_double_precision=retry_double_precision,
+            use_pseudo_inverse=(root_inv_method == RootInvMethod.PSEUDO_EIGEN),
         )
     elif root_inv_method == RootInvMethod.NEWTON:
         if exponent_multiplier != 1.0:
@@ -135,6 +138,7 @@ def matrix_root_diagonal(
     inverse: bool = True,
     exponent_multiplier: float = 1.0,
     return_full_matrix: bool = False,
+    use_pseudo_inverse: bool = True,
 ) -> Tensor:
     """Computes matrix inverse root for a diagonal matrix by taking inverse square root of diagonal entries.

     Args:
         root (int): Root of interest. Any natural number.
         epsilon (float): Adds epsilon * I to matrix before taking matrix root. (Default: 0.0)
         inverse (bool): Returns inverse root matrix. (Default: True)
-        return_full_matrix (bool): Returns full matrix by taking torch.diag of diagonal entries. (bool: False)
+        return_full_matrix (bool): Returns full matrix by taking torch.diag of diagonal entries. (Default: False)
+        use_pseudo_inverse (bool): Computes the matrix pseudo-inverse root by discarding negative eigenvalues/diagonal
+            elements. To use this option, `inverse` must be True. (Default: True)

     Returns:
         X (Tensor): Inverse root of diagonal entries.

@@ -157,7 +163,6 @@
     elif order > 2:
         raise ValueError("Matrix is not 2-dimensional!")

-    # check if root is positive integer
     if root <= 0:
         raise ValueError(f"Root {root} should be positive!")

@@ -166,7 +171,11 @@
     if inverse:
         alpha = -alpha

-    X = (A + epsilon).pow(alpha)
+    if use_pseudo_inverse:
+        X = torch.where(A <= 0.0, torch.zeros_like(A), A.pow(alpha))
+    else:
+        X = (A + epsilon).pow(alpha)
+
     return torch.diag(X) if return_full_matrix else X


@@ -178,6 +187,7 @@ def _matrix_root_eigen(
     exponent_multiplier: float = 1.0,
     make_positive_semidefinite: bool = True,
     retry_double_precision: bool = True,
+    use_pseudo_inverse: bool = True,
 ) -> Tuple[Tensor, Tensor, Tensor]:
     """Compute matrix (inverse) root using eigendecomposition of symmetric positive (semi-)definite matrix.

@@ -194,6 +204,8 @@ def _matrix_root_eigen(
         make_positive_semidefinite (bool): Perturbs matrix eigenvalues to ensure it is numerically positive
             semi-definite. (Default: True)
         retry_double_precision (bool): Flag for re-trying eigendecomposition with higher precision if lower precision fails due to CuSOLVER failure. (Default: True)
+        use_pseudo_inverse (bool): Computes the matrix pseudo-inverse root by discarding negative eigenvalues.
+            To use this option, `inverse` must be True. (Default: True)

     Returns:
         X (Tensor): (Inverse) root of matrix. Same dimensions as A.
@@ -202,7 +214,6 @@

     """

-    # check if root is positive integer
     if root <= 0:
         raise ValueError(f"Root {root} should be positive!")

@@ -224,17 +235,22 @@
         else:
             raise exception

-    lambda_min = torch.min(L)
-
-    # make eigenvalues >= 0 (if necessary)
-    if make_positive_semidefinite:
-        L += -torch.minimum(lambda_min, torch.as_tensor(0.0))
-
-    # add epsilon
-    L += epsilon
+    if use_pseudo_inverse:
+        # Filter the eigenvalues based on the numerical rank of the matrix.
+        # The procedure below mimics the steps described in the documentation of
+        # https://pytorch.org/docs/stable/generated/torch.linalg.matrix_rank.html
+        rtol = L.numel() * torch.finfo(A.dtype).eps
+        spectrum_cutoff = rtol * L.max().relu()
+        power_L = torch.where(L <= spectrum_cutoff, torch.zeros_like(L), L.pow(alpha))
+    else:
+        lambda_min = torch.min(L)
+        # make eigenvalues >= 0 (if necessary)
+        if make_positive_semidefinite:
+            L += -torch.minimum(lambda_min, torch.as_tensor(0.0))
+        power_L = (L + epsilon).pow(alpha)

     # compute inverse preconditioner
-    X = Q * L.pow(alpha).unsqueeze(0) @ Q.T
+    X = Q * power_L.unsqueeze(0) @ Q.T

     return X, L, Q

diff --git a/distributed_shampoo/shampoo_utils.py b/distributed_shampoo/shampoo_utils.py
index 1cd1d89..cffd98c 100644
--- a/distributed_shampoo/shampoo_utils.py
+++ b/distributed_shampoo/shampoo_utils.py
@@ -20,6 +20,7 @@
     check_diagonal,
     compute_matrix_root_inverse_residuals,
     matrix_inverse_root,
+    RootInvMethod
 )
 from distributed_shampoo.optimizer_modules import OptimizerModule
 from distributed_shampoo.shampoo_dist_utils import (
@@ -288,7 +289,8 @@ class AdagradPreconditioner(DistributedPreconditioner):
         dist_buffer (Optional[Tensor]): Buffer for distributed computation. (Default: None)
         use_dtensor (bool): Flag for using DTensor. Requires PyTorch 2 nightly. Otherwise, uses Tensor. (Default: True)
         communication_dtype (CommunicationDType): Datatype for communication between ranks. (Default: DEFAULT)
-
+        root_inv_method (RootInvMethod): Strategy for computing (inverse) preconditioner roots. (Default:
+            RootInvMethod.PSEUDO_EIGEN)
     """

     def __init__(
@@ -303,9 +305,10 @@ def __init__(
         dist_buffer: Optional[Tensor] = None,
         use_dtensor: bool = True,
         communication_dtype: CommunicationDType = CommunicationDType.DEFAULT,
+        root_inv_method: RootInvMethod = RootInvMethod.PSEUDO_EIGEN,
     ):
         super(AdagradPreconditioner, self).__init__(
-            param, group, group_source_rank, dist_buffer, communication_dtype
+            param, group, group_source_rank, dist_buffer, communication_dtype,
         )
         self._beta2 = beta2
         self._epsilon = epsilon
@@ -321,6 +324,11 @@ def __init__(
         self._bias_correction2 = 1.0
         self._parameter_count += self._preconditioner.numel()

+        if root_inv_method not in [RootInvMethod.EIGEN, RootInvMethod.PSEUDO_EIGEN]:
+            raise ValueError("Adagrad preconditioner only supports eigendecomposition variants for "
+                             f"`root_inv_method` but received {root_inv_method}.")
+        self._root_inv_method = root_inv_method
+
         if self._idx is not None:
             self._preconditioner_idx = str(self._idx) + "." + str(0)
             logger.info(
@@ -345,23 +353,30 @@ def precondition(self, grad: Tensor, iteration: int) -> Tensor:
         if not self._on_source_rank:
             return grad

-        denom = (
-            (use_local_tensor(self._preconditioner) / self._bias_correction2)
-            .sqrt()
-            .add_(self._epsilon)
-        )
+        denom = (use_local_tensor(self._preconditioner) / self._bias_correction2).sqrt()
+        if self._root_inv_method == RootInvMethod.PSEUDO_EIGEN:
+            # Put ones in the locations of zero entries so that the later division has no effect
+            torch.where(denom == 0, torch.ones_like(denom), denom, out=denom)
+        else:
+            # If not using pseudoinverse, use epsilon to improve numerical stability
+            denom.add_(self._epsilon)
+
         grad.div_(denom)
+
         return grad

     def compute_norm(self, grad: Tensor, iteration: int):
         if not self._on_source_rank:
             return torch.as_tensor(1.0)  # return cheap tensor

-        denom = (
-            (use_local_tensor(self._preconditioner) / self._bias_correction2)
-            .sqrt()
-            .add_(self._epsilon)
-        )
+        denom = (use_local_tensor(self._preconditioner) / self._bias_correction2).sqrt()
+        if self._root_inv_method == RootInvMethod.PSEUDO_EIGEN:
+            # Put ones in the locations of zero entries so that the later division has no effect
+            torch.where(denom == 0, torch.ones_like(denom), denom, out=denom)
+        else:
+            # If not using pseudoinverse, use epsilon to improve numerical stability
+            denom.add_(self._epsilon)
+
         adagrad_nrm = torch.linalg.norm(grad / denom)
         return adagrad_nrm

@@ -418,6 +433,8 @@ class ShampooPreconditioner(DistributedPreconditioner):
             3. Otherwise, re-uses previous inverse factor matrix when both root inverse computations fail.
         use_dtensor (bool): Flag for using DTensor. Requires PyTorch 2 nightly. Otherwise, uses Tensor. (Default: True)
         communication_dtype (CommunicationDType): Datatype for communication between ranks. (Default: DEFAULT)
+        root_inv_method (RootInvMethod): Strategy for computing (inverse) preconditioner roots. (Default:
+            RootInvMethod.PSEUDO_EIGEN)

     """

@@ -442,6 +459,7 @@ def __init__(
         use_protected_eigh: bool = True,
         use_dtensor: bool = True,
         communication_dtype: CommunicationDType = CommunicationDType.DEFAULT,
+        root_inv_method: RootInvMethod = RootInvMethod.PSEUDO_EIGEN,
     ):
         super(ShampooPreconditioner, self).__init__(
@@ -463,6 +481,7 @@ def __init__(
         self._start_preconditioning_step = start_preconditioning_step
         self._use_protected_eigh = use_protected_eigh
         self._communication_dtype = communication_dtype
+        self._root_inv_method = root_inv_method

         # Compute root.
         self._root = self._get_root_from_exponent_override(
@@ -785,6 +804,7 @@ def compute_root_inverse(self) -> None:
                     exponent_multiplier=self._exponent_multiplier,
                     is_diagonal=preconditioner.is_diagonal,
                     retry_double_precision=self._use_protected_eigh,
+                    root_inv_method=self._root_inv_method,
                 ).to(dtype=self._dtype)

                 # check if we encounter NaN or inf values in computed inverse matrix.
@@ -888,6 +908,8 @@ class BlockShampooPreconditioner(DistributedPreconditioner):
             3. Otherwise, re-uses previous inverse factor matrix when both root inverse computations fail.
         use_dtensor (bool): Flag for using DTensor. Requires PyTorch 2 nightly. Otherwise, uses Tensor. (Default: True)
         communication_dtype (CommunicationDType): Datatype for communication between ranks. (Default: DEFAULT)
+        root_inv_method (RootInvMethod): Strategy for computing (inverse) preconditioner roots. (Default:
+            RootInvMethod.PSEUDO_EIGEN)

     """

@@ -914,6 +936,7 @@ def __init__(
         use_protected_eigh: bool = True,
         use_dtensor: bool = True,
         communication_dtype: CommunicationDType = CommunicationDType.DEFAULT,
+        root_inv_method: RootInvMethod = RootInvMethod.PSEUDO_EIGEN,
     ):
         super(BlockShampooPreconditioner, self).__init__(
             param,
@@ -975,6 +998,7 @@ def __init__(
                 use_protected_eigh=use_protected_eigh,
                 use_dtensor=use_dtensor,
                 communication_dtype=communication_dtype,
+                root_inv_method=root_inv_method,
             )
             self._split_preconditioners.append(preconditioner)
             self._parameter_count += preconditioner.parameter_count
@@ -1182,7 +1206,8 @@ class AdagradGrafting(Grafting):
         dist_buffer (Optional[Tensor]): Buffer for distributed computation. (Default: None)
         use_dtensor (bool): Flag for using DTensor. Requires PyTorch 2 nightly. Otherwise, uses Tensor. (Default: True)
         communication_dtype (CommunicationDType): Datatype for communication between ranks. (Default: DEFAULT)
-
+        root_inv_method (RootInvMethod): Strategy for computing (inverse) preconditioner roots. (Default:
+            RootInvMethod.PSEUDO_EIGEN)
     """

     def __init__(
@@ -1197,6 +1222,7 @@ def __init__(
         dist_buffer: Optional[Tensor] = None,
         use_dtensor: bool = True,
         communication_dtype: CommunicationDType = CommunicationDType.DEFAULT,
+        root_inv_method: RootInvMethod = RootInvMethod.PSEUDO_EIGEN,
     ):
         super(AdagradGrafting, self).__init__(param)
         self._preconditioner = AdagradPreconditioner(
@@ -1207,7 +1233,9 @@ def __init__(
             group=group,
             group_source_rank=group_source_rank,
             dist_buffer=dist_buffer,
+            use_dtensor=use_dtensor,
             communication_dtype=communication_dtype,
+            root_inv_method=root_inv_method,
         )
         self.normalize_gradient = normalize_gradient
         self._parameter_count += self._preconditioner.parameter_count
diff --git a/distributed_shampoo/tests/matrix_functions_test.py b/distributed_shampoo/tests/matrix_functions_test.py
index 49c3d58..3f2b047 100644
--- a/distributed_shampoo/tests/matrix_functions_test.py
+++ b/distributed_shampoo/tests/matrix_functions_test.py
@@ -335,6 +335,19 @@ def A(n):
                 eig_sols,
             )

+    def test_matrix_root_pseudo_eigen(self):
+        A = torch.tensor([[2.0, 0.0], [0.0, 0.0]])
+        X, L, Q = _matrix_root_eigen(
+            A=A,
+            root=1,
+            epsilon=0.0,
+            make_positive_semidefinite=False,
+            use_pseudo_inverse=True,
+            inverse=True,
+        )
+        target_pseudo_inv = torch.tensor([[0.5, 0.0], [0.0, 0.0]])
+        torch.testing.assert_close(X, target_pseudo_inv)
+
     def test_matrix_root_eigen_nonpositive_root(self):
         A = torch.tensor([[-1.0, 0.0], [0.0, 2.0]])
         root = -1
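
Note (illustrative sketch, not part of the patch): the new PSEUDO_EIGEN path replaces the usual `L += epsilon` regularization with a rank-style cutoff on the eigenvalue spectrum, mirroring the default rtol used by torch.linalg.matrix_rank, and zeroes out the discarded directions instead of inverting them. The standalone Python sketch below uses only public PyTorch APIs and a hypothetical helper name (pseudo_inverse_root_sketch); it reproduces that cutoff for the root = 1 case and checks it against torch.linalg.pinv on a rank-deficient PSD matrix.

import torch


def pseudo_inverse_root_sketch(A: torch.Tensor, root: int) -> torch.Tensor:
    # Eigendecomposition of the symmetric positive semi-definite input.
    L, Q = torch.linalg.eigh(A)
    # Rank-style cutoff, mirroring the default rtol of torch.linalg.matrix_rank.
    rtol = L.numel() * torch.finfo(A.dtype).eps
    cutoff = rtol * L.max().clamp(min=0.0)
    # Discard (near-)zero and negative eigenvalues instead of adding epsilon;
    # surviving eigenvalues are raised to the -1/root power.
    power_L = torch.where(L <= cutoff, torch.zeros_like(L), L.pow(-1.0 / root))
    return Q * power_L.unsqueeze(0) @ Q.T


if __name__ == "__main__":
    G = torch.randn(5, 3, dtype=torch.float64)
    A = G @ G.T  # 5x5 PSD matrix of rank 3
    X = pseudo_inverse_root_sketch(A, root=1)
    # For root = 1 the result should coincide with the Moore-Penrose pseudo-inverse.
    torch.testing.assert_close(X, torch.linalg.pinv(A))
    print("pseudo-inverse root matches torch.linalg.pinv")

For root > 1 the same cutoff applies; only the exponent on the surviving eigenvalues changes.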