Enable use of pseudoinverse in computation of preconditioners #11

Open · wants to merge 13 commits into base: main
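A minimal usage sketch of the option this PR adds (import paths are inferred from the file layout in the diff below; the other constructor arguments and any distributed-setup requirements follow the existing DistributedShampoo API and are assumed, not shown here):

```python
import torch
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.matrix_functions import RootInvMethod

model = torch.nn.Linear(16, 4)
optimizer = DistributedShampoo(
    model.parameters(),
    lr=1e-2,
    precondition_frequency=20,
    # 0 is now accepted: allow an initial preconditioned update with the first estimate.
    start_preconditioning_step=0,
    # New argument from this PR; PSEUDO_EIGEN becomes the default.
    root_inv_method=RootInvMethod.PSEUDO_EIGEN,
)

# Used like any other torch.optim optimizer afterwards:
loss = model(torch.randn(8, 16)).square().mean()
loss.backward()
optimizer.step()
```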
21 changes: 15 additions & 6 deletions distributed_shampoo/distributed_shampoo.py
@@ -31,6 +31,7 @@
GraftingType,
LargeDimMethod,
ShampooPreconditioner,
RootInvMethod,
)
from torch.nn import Parameter

@@ -62,7 +63,7 @@ class DistributedShampoo(torch.optim.Optimizer):
Hao-Jun Michael Shi (Meta Platforms, Inc.)
Tsung-Hsien Lee
Shintaro Iwasaki (Meta Platforms, Inc.)
Jose Gallego-Posada (MILA / Meta Platforms, Inc.)
Jose Gallego-Posada (Mila / Meta Platforms, Inc.)

with contributions and support from:

@@ -181,6 +182,8 @@ class DistributedShampoo(torch.optim.Optimizer):
precondition_frequency (int): frequency for computing root inverse preconditioner (Default: 1)
start_preconditioning_step (int): iteration to start computing inverse preconditioner. If -1, uses
the same value as precondition_frequency. (Default: -1)
root_inv_method (RootInvMethod): Strategy for computing (inverse) preconditioner roots. (Default:
RootInvMethod.PSEUDO_EIGEN)
exponent_override (int, List[int]): inverse root to use in Shampoo. If a list [l1, l2, ..., lp], then we will
use -1 / l1 for 1-D tensor (vectors), -1 / l2 for 2-D tensors (matrices), and so on. If the order of the
tensor exceeds the length of the list, reverts to the default value. If 0 is used, uses the default inverse
@@ -221,6 +224,7 @@ def __init__(
max_preconditioner_dim: int = 1024,
precondition_frequency: int = 1,
start_preconditioning_step: int = -1,
root_inv_method: RootInvMethod = RootInvMethod.PSEUDO_EIGEN,
exponent_override: Union[int, List[int]] = 0,
exponent_multiplier: float = 1.0,
use_nesterov: bool = False,
@@ -338,6 +342,7 @@ def __init__(
# Initialize algorithm-related fields.
self._max_preconditioner_dim = max_preconditioner_dim
self._precondition_frequency = precondition_frequency
self._root_inv_method = root_inv_method
self._exponent_override = exponent_override
self._exponent_multiplier = exponent_multiplier
self._num_trainers_per_group = num_trainers_per_group
@@ -366,7 +371,9 @@ def __init__(
logger.warning(
f"start_preconditioning_step set to -1. Setting start_preconditioning_step equal to precondition frequency {precondition_frequency} by default."
)
elif start_preconditioning_step < precondition_frequency:
elif 0 < start_preconditioning_step < precondition_frequency:
# If start_preconditioning_step==0, we allow an initial preconditioned update
# with the first estimate of the preconditioner.
raise ValueError(
f"Invalid start_preconditioning_step value: {start_preconditioning_step}. Must be >= {precondition_frequency = }."
)
@@ -443,6 +450,7 @@ def _initialize_preconditioners_and_steps(
use_protected_eigh=self._use_protected_eigh,
use_dtensor=self._use_dtensor,
communication_dtype=self._communication_dtype,
root_inv_method=self._root_inv_method,
)
preconditioner_count += len(
state[PRECONDITIONERS].get_split_dist_buffers()
@@ -469,6 +477,7 @@ def _initialize_preconditioners_and_steps(
dist_buffer=dist_buffer,
use_dtensor=self._use_dtensor,
communication_dtype=self._communication_dtype,
root_inv_method=self._root_inv_method,
)
if torch.any(dims > self._max_preconditioner_dim)
else ShampooPreconditioner(
@@ -491,6 +500,7 @@ def _initialize_preconditioners_and_steps(
use_protected_eigh=self._use_protected_eigh,
use_dtensor=self._use_dtensor,
communication_dtype=self._communication_dtype,
root_inv_method=self._root_inv_method,
)
)

@@ -524,6 +534,7 @@ def _initialize_preconditioners_and_steps(
use_protected_eigh=self._use_protected_eigh,
use_dtensor=self._use_dtensor,
communication_dtype=self._communication_dtype,
root_inv_method=self._root_inv_method,
)

else:
@@ -814,10 +825,8 @@ def step(self, closure=None):

# Computes root inverse of all preconditioners every self._precondition_frequency
# after the self._start_preconditioning_step iteration.
if (
iteration % self._precondition_frequency == 0
and iteration >= self._start_preconditioning_step
):
is_precond_step = (iteration == 1) or (iteration % self._precondition_frequency == 0)
if is_precond_step and iteration >= self._start_preconditioning_step:
self._compute_root_inverse()

if self._debug_mode:
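For reference, a standalone sketch of the scheduling change above (the helper `is_root_inverse_step` is hypothetical, not part of the PR): root inverses are now also recomputed at iteration 1, and `start_preconditioning_step=0` permits an initial preconditioned update.

```python
def is_root_inverse_step(
    iteration: int, precondition_frequency: int, start_preconditioning_step: int
) -> bool:
    # Mirrors the condition in DistributedShampoo.step(): recompute at iteration 1
    # or whenever the iteration is a multiple of the preconditioning frequency,
    # but never before start_preconditioning_step.
    is_precond_step = (iteration == 1) or (iteration % precondition_frequency == 0)
    return is_precond_step and iteration >= start_preconditioning_step

# With frequency=5 and start=0, root inverses are recomputed at iterations 1, 5, 10, 15, ...
assert [i for i in range(1, 16) if is_root_inverse_step(i, 5, 0)] == [1, 5, 10, 15]
```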
46 changes: 31 additions & 15 deletions distributed_shampoo/matrix_functions.py
@@ -25,6 +25,7 @@ class NewtonConvergenceFlag(enum.Enum):
class RootInvMethod(enum.Enum):
EIGEN = 0
NEWTON = 1
PSEUDO_EIGEN = 2


def check_diagonal(A: Tensor) -> Tensor:
@@ -46,7 +47,7 @@ def matrix_inverse_root(
root: int,
epsilon: float = 0.0,
exponent_multiplier: float = 1.0,
root_inv_method: RootInvMethod = RootInvMethod.EIGEN,
root_inv_method: RootInvMethod = RootInvMethod.PSEUDO_EIGEN,
max_iterations: int = 1000,
tolerance: float = 1e-6,
is_diagonal: Union[Tensor, bool] = False,
@@ -91,15 +92,17 @@
inverse=True,
exponent_multiplier=exponent_multiplier,
return_full_matrix=True,
use_pseudo_inverse=(root_inv_method == RootInvMethod.PSEUDO_EIGEN),
)
elif root_inv_method == RootInvMethod.EIGEN:
elif root_inv_method in [RootInvMethod.EIGEN, RootInvMethod.PSEUDO_EIGEN]:
X, _, _ = _matrix_root_eigen(
A=A,
root=root,
epsilon=epsilon,
inverse=True,
exponent_multiplier=exponent_multiplier,
retry_double_precision=retry_double_precision,
use_pseudo_inverse=(root_inv_method == RootInvMethod.PSEUDO_EIGEN),
)
elif root_inv_method == RootInvMethod.NEWTON:
if exponent_multiplier != 1.0:
@@ -135,6 +138,7 @@ def matrix_root_diagonal(
inverse: bool = True,
exponent_multiplier: float = 1.0,
return_full_matrix: bool = False,
use_pseudo_inverse: bool = True,
) -> Tensor:
"""Computes matrix inverse root for a diagonal matrix by taking inverse square root of diagonal entries.

@@ -143,7 +147,9 @@
root (int): Root of interest. Any natural number.
epsilon (float): Adds epsilon * I to matrix before taking matrix root. (Default: 0.0)
inverse (bool): Returns inverse root matrix. (Default: True)
return_full_matrix (bool): Returns full matrix by taking torch.diag of diagonal entries. (bool: False)
return_full_matrix (bool): Returns full matrix by taking torch.diag of diagonal entries. (Default: False)
use_pseudo_inverse (bool): Computes the matrix pseudo-inverse root by discarding non-positive diagonal
entries. To use this option, `inverse` must be True. (Default: True)

Returns:
X (Tensor): Inverse root of diagonal entries.
@@ -157,7 +163,6 @@
elif order > 2:
raise ValueError("Matrix is not 2-dimensional!")

# check if root is positive integer
if root <= 0:
raise ValueError(f"Root {root} should be positive!")

@@ -166,7 +171,11 @@
if inverse:
alpha = -alpha

X = (A + epsilon).pow(alpha)
if use_pseudo_inverse:
X = torch.where(A <= 0.0, torch.zeros_like(A), A.pow(alpha))
else:
X = (A + epsilon).pow(alpha)

return torch.diag(X) if return_full_matrix else X


@@ -178,6 +187,7 @@ def _matrix_root_eigen(
exponent_multiplier: float = 1.0,
make_positive_semidefinite: bool = True,
retry_double_precision: bool = True,
use_pseudo_inverse: bool = True,
) -> Tuple[Tensor, Tensor, Tensor]:
"""Compute matrix (inverse) root using eigendecomposition of symmetric positive (semi-)definite matrix.

@@ -194,6 +204,8 @@
make_positive_semidefinite (bool): Perturbs matrix eigenvalues to ensure it is numerically positive semi-definite. (Default: True)
retry_double_precision (bool): Flag for re-trying eigendecomposition with higher precision if lower precision fails due
to CuSOLVER failure. (Default: True)
use_pseudo_inverse (bool): Computes the matrix pseudo-inverse root by discarding eigenvalues below a
numerical-rank cutoff. To use this option, `inverse` must be True. (Default: True)

Returns:
X (Tensor): (Inverse) root of matrix. Same dimensions as A.
@@ -202,7 +214,6 @@

"""

# check if root is positive integer
if root <= 0:
raise ValueError(f"Root {root} should be positive!")

@@ -224,17 +235,22 @@
else:
raise exception

lambda_min = torch.min(L)

# make eigenvalues >= 0 (if necessary)
if make_positive_semidefinite:
L += -torch.minimum(lambda_min, torch.as_tensor(0.0))

# add epsilon
L += epsilon
if use_pseudo_inverse:
# Filter the eigenvalues based on the numerical rank of the matrix
# The procedure below mimics the steps described in the documentation of
# https://pytorch.org/docs/stable/generated/torch.linalg.matrix_rank.html
rtol = L.numel() * torch.finfo(A.dtype).eps
spectrum_cutoff = rtol * L.max().relu()
power_L = torch.where(L <= spectrum_cutoff, torch.zeros_like(L), L.pow(alpha))
else:
lambda_min = torch.min(L)
# make eigenvalues >= 0 (if necessary)
if make_positive_semidefinite:
L += -torch.minimum(lambda_min, torch.as_tensor(0.0))
power_L = (L + epsilon).pow(alpha)

# compute inverse preconditioner
X = Q * L.pow(alpha).unsqueeze(0) @ Q.T
X = Q * power_L.unsqueeze(0) @ Q.T

return X, L, Q

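A self-contained sketch of the eigenvalue filtering used by the PSEUDO_EIGEN path above; the cutoff mirrors the rtol heuristic documented for torch.linalg.matrix_rank. The function name `pseudo_inverse_root_from_eigh` is hypothetical and the sketch omits the epsilon/positive-semidefinite handling of the non-pseudo branch.

```python
import torch

def pseudo_inverse_root_from_eigh(A: torch.Tensor, root: int) -> torch.Tensor:
    """Inverse `root`-th root of a symmetric PSD matrix, zeroing eigenvalues below a numerical-rank cutoff."""
    L, Q = torch.linalg.eigh(A)
    # rtol follows torch.linalg.matrix_rank's documented heuristic: dim * machine epsilon.
    rtol = L.numel() * torch.finfo(A.dtype).eps
    cutoff = rtol * L.max().relu()
    # Eigenvalues at or below the cutoff are treated as zero (pseudo-inverse behavior).
    inv_root_L = torch.where(L <= cutoff, torch.zeros_like(L), L.pow(-1.0 / root))
    return Q * inv_root_L.unsqueeze(0) @ Q.T

# Rank-deficient example: the zero eigenvalue is dropped instead of blown up via 1/epsilon.
A = torch.diag(torch.tensor([4.0, 1.0, 0.0]))
X = pseudo_inverse_root_from_eigh(A, root=2)
# X is approximately diag([0.5, 1.0, 0.0]).
```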
54 changes: 41 additions & 13 deletions distributed_shampoo/shampoo_utils.py
@@ -20,6 +20,7 @@
check_diagonal,
compute_matrix_root_inverse_residuals,
matrix_inverse_root,
RootInvMethod,
)
from distributed_shampoo.optimizer_modules import OptimizerModule
from distributed_shampoo.shampoo_dist_utils import (
@@ -288,7 +289,8 @@ class AdagradPreconditioner(DistributedPreconditioner):
dist_buffer (Optional[Tensor]): Buffer for distributed computation. (Default: None)
use_dtensor (bool): Flag for using DTensor. Requires PyTorch 2 nightly. Otherwise, uses Tensor. (Default: True)
communication_dtype (CommunicationDType): Datatype for communication between ranks. (Default: DEFAULT)

root_inv_method (RootInvMethod): Strategy for computing (inverse) preconditioner roots. (Default:
RootInvMethod.PSEUDO_EIGEN)
"""

def __init__(
@@ -303,9 +305,10 @@ def __init__(
dist_buffer: Optional[Tensor] = None,
use_dtensor: bool = True,
communication_dtype: CommunicationDType = CommunicationDType.DEFAULT,
root_inv_method: RootInvMethod = RootInvMethod.PSEUDO_EIGEN,
):
super(AdagradPreconditioner, self).__init__(
param, group, group_source_rank, dist_buffer, communication_dtype
param, group, group_source_rank, dist_buffer, communication_dtype,
)
self._beta2 = beta2
self._epsilon = epsilon
@@ -321,6 +324,11 @@ def __init__(
self._bias_correction2 = 1.0
self._parameter_count += self._preconditioner.numel()

if root_inv_method not in [RootInvMethod.EIGEN, RootInvMethod.PSEUDO_EIGEN]:
raise ValueError(f"Adagrad preconditioner only supports eigendecomposition variants for `root_inv_method` \
but received {root_inv_method}.")
self._root_inv_method = root_inv_method

if self._idx is not None:
self._preconditioner_idx = str(self._idx) + "." + str(0)
logger.info(
@@ -345,23 +353,30 @@ def precondition(self, grad: Tensor, iteration: int) -> Tensor:
if not self._on_source_rank:
return grad

denom = (
(use_local_tensor(self._preconditioner) / self._bias_correction2)
.sqrt()
.add_(self._epsilon)
)
denom = (use_local_tensor(self._preconditioner) / self._bias_correction2).sqrt()
if self._root_inv_method == RootInvMethod.PSEUDO_EIGEN:
# Put ones in the locations of zero entries so that the later division has no effect
torch.where(denom == 0, torch.ones_like(denom), denom, out=denom)
else:
# If not using pseudoinverse, use epsilon to improve numerical stability
denom.add_(self._epsilon)

grad.div_(denom)

return grad

def compute_norm(self, grad: Tensor, iteration: int):
if not self._on_source_rank:
return torch.as_tensor(1.0) # return cheap tensor

denom = (
(use_local_tensor(self._preconditioner) / self._bias_correction2)
.sqrt()
.add_(self._epsilon)
)
denom = (use_local_tensor(self._preconditioner) / self._bias_correction2).sqrt()
if self._root_inv_method == RootInvMethod.PSEUDO_EIGEN:
# Put ones in the locations of zero entries so that the later division has no effect
torch.where(denom == 0, torch.ones_like(denom), denom, out=denom)
else:
# If not using pseudoinverse, use epsilon to improve numerical stability
denom.add_(self._epsilon)

adagrad_nrm = torch.linalg.norm(grad / denom)
return adagrad_nrm

@@ -418,6 +433,8 @@ class ShampooPreconditioner(DistributedPreconditioner):
3. Otherwise, re-uses previous inverse factor matrix when both root inverse computations fail.
use_dtensor (bool): Flag for using DTensor. Requires PyTorch 2 nightly. Otherwise, uses Tensor. (Default: True)
communication_dtype (CommunicationDType): Datatype for communication between ranks. (Default: DEFAULT)
root_inv_method (RootInvMethod): Strategy for computing (inverse) preconditioner roots. (Default:
RootInvMethod.PSEUDO_EIGEN)

"""

@@ -442,6 +459,7 @@ def __init__(
use_protected_eigh: bool = True,
use_dtensor: bool = True,
communication_dtype: CommunicationDType = CommunicationDType.DEFAULT,
root_inv_method: RootInvMethod = RootInvMethod.PSEUDO_EIGEN,
):

super(ShampooPreconditioner, self).__init__(
@@ -463,6 +481,7 @@
self._start_preconditioning_step = start_preconditioning_step
self._use_protected_eigh = use_protected_eigh
self._communication_dtype = communication_dtype
self._root_inv_method = root_inv_method

# Compute root.
self._root = self._get_root_from_exponent_override(
@@ -785,6 +804,7 @@ def compute_root_inverse(self) -> None:
exponent_multiplier=self._exponent_multiplier,
is_diagonal=preconditioner.is_diagonal,
retry_double_precision=self._use_protected_eigh,
root_inv_method=self._root_inv_method,
).to(dtype=self._dtype)

# check if we encounter NaN or inf values in computed inverse matrix.
@@ -888,6 +908,8 @@ class BlockShampooPreconditioner(DistributedPreconditioner):
3. Otherwise, re-uses previous inverse factor matrix when both root inverse computations fail.
use_dtensor (bool): Flag for using DTensor. Requires PyTorch 2 nightly. Otherwise, uses Tensor. (Default: True)
communication_dtype (CommunicationDType): Datatype for communication between ranks. (Default: DEFAULT)
root_inv_method (RootInvMethod): Strategy for computing (inverse) preconditioner roots. (Default:
RootInvMethod.PSEUDO_EIGEN)

"""

@@ -914,6 +936,7 @@ def __init__(
use_protected_eigh: bool = True,
use_dtensor: bool = True,
communication_dtype: CommunicationDType = CommunicationDType.DEFAULT,
root_inv_method: RootInvMethod = RootInvMethod.PSEUDO_EIGEN,
):
super(BlockShampooPreconditioner, self).__init__(
param,
@@ -975,6 +998,7 @@ def __init__(
use_protected_eigh=use_protected_eigh,
use_dtensor=use_dtensor,
communication_dtype=communication_dtype,
root_inv_method=root_inv_method,
)
self._split_preconditioners.append(preconditioner)
self._parameter_count += preconditioner.parameter_count
@@ -1182,7 +1206,8 @@ class AdagradGrafting(Grafting):
dist_buffer (Optional[Tensor]): Buffer for distributed computation. (Default: None)
use_dtensor (bool): Flag for using DTensor. Requires PyTorch 2 nightly. Otherwise, uses Tensor. (Default: True)
communication_dtype (CommunicationDType): Datatype for communication between ranks. (Default: DEFAULT)

root_inv_method (RootInvMethod): Strategy for computing (inverse) preconditioner roots. (Default:
RootInvMethod.PSEUDO_EIGEN)
"""

def __init__(
Expand All @@ -1197,6 +1222,7 @@ def __init__(
dist_buffer: Optional[Tensor] = None,
use_dtensor: bool = True,
communication_dtype: CommunicationDType = CommunicationDType.DEFAULT,
root_inv_method: RootInvMethod = RootInvMethod.PSEUDO_EIGEN,
):
super(AdagradGrafting, self).__init__(param)
self._preconditioner = AdagradPreconditioner(
@@ -1207,7 +1233,9 @@ def __init__(
group=group,
group_source_rank=group_source_rank,
dist_buffer=dist_buffer,
use_dtensor=use_dtensor,
communication_dtype=communication_dtype,
root_inv_method=root_inv_method,
)
self.normalize_gradient = normalize_gradient
self._parameter_count += self._preconditioner.parameter_count
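Finally, a short standalone sketch of the denominator handling introduced for the Adagrad/grafting preconditioner above (the helper `adagrad_denominator` and the tensor values are illustrative only): zero accumulator entries get a denominator of one instead of epsilon, so those coordinates are divided by one rather than by a tiny epsilon.

```python
import torch

def adagrad_denominator(
    preconditioner: torch.Tensor,
    bias_correction2: float,
    epsilon: float,
    use_pseudo_inverse: bool,
) -> torch.Tensor:
    denom = (preconditioner / bias_correction2).sqrt()
    if use_pseudo_inverse:
        # Replace exact zeros with ones so the subsequent division is a no-op there.
        return torch.where(denom == 0, torch.ones_like(denom), denom)
    # Classical Adagrad: add epsilon for numerical stability.
    return denom.add(epsilon)

grad = torch.tensor([0.3, 0.0, -0.2])
accum = torch.tensor([0.09, 0.0, 0.04])  # sum of squared gradients; second coordinate never updated
print(grad / adagrad_denominator(accum, 1.0, 1e-8, use_pseudo_inverse=True))
# tensor([ 1.,  0., -1.])
```

Assuming the accumulator is updated with the current gradient before preconditioning, as in standard Adagrad, a zero accumulator entry implies the current gradient coordinate is also zero, so dividing by one there leaves an already-zero update untouched.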