Add support for specifying a PreconditionerComputationType in the examples
runame committed Nov 4, 2024
1 parent f3451cd commit bd5dc3a
Showing 6 changed files with 78 additions and 1 deletion.
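
Every example script now accepts the new --preconditioner-computation-type flag added in trainer_utils.py. As an illustration only (the module invocation path is an assumption, and values are assumed to be parsed by enum member name via enum_type_parse), selecting the coupled Newton root inverse on the single-GPU example might look like:

    python -m distributed_shampoo.examples.default_cifar10_example \
        --preconditioner-computation-type COUPLED_NEWTON_ROOT_INV

The accepted values mirror the new PreconditionerComputationType enum: EIGEN_ROOT_INV (the default), COUPLED_NEWTON_ROOT_INV, COUPLED_HIGHER_ORDER_ROOT_INV, and EIGH_EIGENVALUE_CORRECTION.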
3 changes: 3 additions & 0 deletions distributed_shampoo/examples/ddp_cifar10_example.py
@@ -125,12 +125,15 @@
computation_dtype=args.computation_dtype.value,
factor_matrix_dtype=args.factor_matrix_dtype.value,
inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value,
+corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value,
+factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value,
filtered_grad_dtype=args.filtered_grad_dtype.value,
momentum_dtype=args.momentum_dtype.value,
grafting_state_dtype=args.grafting_state_dtype.value,
),
use_protected_eigh=args.use_protected_eigh,
track_root_inv_residuals=args.track_root_inv_residuals,
+preconditioner_computation_type=args.preconditioner_computation_type,
)

# checks for checkpointing
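The same three-line change is applied verbatim to all five example scripts: two new PrecisionConfig fields plus pass-through of the parsed flag. For reference, here is a minimal sketch of building the expanded PrecisionConfig directly; the field names come from the diff, while the import path and the use of torch dtypes as field values are assumptions:

    import torch

    from distributed_shampoo import PrecisionConfig  # assumed import path

    precision_config = PrecisionConfig(
        computation_dtype=torch.float32,
        factor_matrix_dtype=torch.float32,
        inv_factor_matrix_dtype=torch.float32,
        corrected_eigenvalues_dtype=torch.float32,  # new in this commit
        factor_matrix_eigenvectors_dtype=torch.float32,  # new in this commit
        filtered_grad_dtype=torch.float32,
        momentum_dtype=torch.float32,
        grafting_state_dtype=torch.float32,
    )

The two new fields presumably only take effect on the eigenvalue-corrected path (EIGH_EIGENVALUE_CORRECTION), which stores corrected eigenvalues and factor-matrix eigenvectors rather than inverse factor matrices.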
5 changes: 4 additions & 1 deletion distributed_shampoo/examples/default_cifar10_example.py
@@ -104,7 +104,7 @@ def train_default_model(

# instantiate data loader. Note that this is a single GPU training example,
# so we do not need to instantiate a sampler.
-data_loader, _ = get_data_loader_and_sampler(args.data_path, 1, 1, args.batch_size)
+data_loader, _ = get_data_loader_and_sampler(args.data_path, 1, 0, args.batch_size)

# instantiate optimizer (SGD, Adam, DistributedShampoo)
optimizer = instantiate_optimizer(
@@ -135,12 +135,15 @@ def train_default_model(
computation_dtype=args.computation_dtype.value,
factor_matrix_dtype=args.factor_matrix_dtype.value,
inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value,
+corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value,
+factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value,
filtered_grad_dtype=args.filtered_grad_dtype.value,
momentum_dtype=args.momentum_dtype.value,
grafting_state_dtype=args.grafting_state_dtype.value,
),
use_protected_eigh=args.use_protected_eigh,
track_root_inv_residuals=args.track_root_inv_residuals,
+preconditioner_computation_type=args.preconditioner_computation_type,
)

# train model
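One unrelated fix rides along in this file: get_data_loader_and_sampler takes (data_path, world_size, rank, local_batch_size) (its signature appears at the end of trainer_utils.py below), and ranks are zero-indexed, so a single-process run should pass rank 0 with world_size 1:

    # world_size=1, rank=0: the only valid rank in a single-process run
    data_loader, _ = get_data_loader_and_sampler(args.data_path, 1, 0, args.batch_size)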
3 changes: 3 additions & 0 deletions distributed_shampoo/examples/fsdp_cifar10_example.py
@@ -120,12 +120,15 @@
computation_dtype=args.computation_dtype.value,
factor_matrix_dtype=args.factor_matrix_dtype.value,
inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value,
+corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value,
+factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value,
filtered_grad_dtype=args.filtered_grad_dtype.value,
momentum_dtype=args.momentum_dtype.value,
grafting_state_dtype=args.grafting_state_dtype.value,
),
use_protected_eigh=args.use_protected_eigh,
track_root_inv_residuals=args.track_root_inv_residuals,
+preconditioner_computation_type=args.preconditioner_computation_type,
)

# train model
3 changes: 3 additions & 0 deletions distributed_shampoo/examples/fully_shard_cifar10_example.py
@@ -137,12 +137,15 @@ def create_model_and_optimizer_and_loss_fn(args, device):
computation_dtype=args.computation_dtype.value,
factor_matrix_dtype=args.factor_matrix_dtype.value,
inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value,
+corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value,
+factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value,
filtered_grad_dtype=args.filtered_grad_dtype.value,
momentum_dtype=args.momentum_dtype.value,
grafting_state_dtype=args.grafting_state_dtype.value,
),
use_protected_eigh=args.use_protected_eigh,
track_root_inv_residuals=args.track_root_inv_residuals,
+preconditioner_computation_type=args.preconditioner_computation_type,
)
return model, optimizer, loss_function

3 changes: 3 additions & 0 deletions distributed_shampoo/examples/hsdp_cifar10_example.py
@@ -135,12 +135,15 @@
computation_dtype=args.computation_dtype.value,
factor_matrix_dtype=args.factor_matrix_dtype.value,
inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value,
+corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value,
+factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value,
filtered_grad_dtype=args.filtered_grad_dtype.value,
momentum_dtype=args.momentum_dtype.value,
grafting_state_dtype=args.grafting_state_dtype.value,
),
use_protected_eigh=args.use_protected_eigh,
track_root_inv_residuals=args.track_root_inv_residuals,
+preconditioner_computation_type=args.preconditioner_computation_type,
)

# train model
62 changes: 62 additions & 0 deletions distributed_shampoo/examples/trainer_utils.py
@@ -31,6 +31,13 @@
RMSpropGraftingConfig,
SGDGraftingConfig,
)
+from matrix_functions_types import (
+    CoupledHigherOrderConfig,
+    CoupledNewtonConfig,
+    EigenConfig,
+    EighEigenvalueCorrectionConfig,
+    PreconditionerComputationConfig,
+)
from torch import nn
from torchvision import datasets, transforms

@@ -62,6 +69,13 @@ class GraftingType(enum.Enum):
ADAM = 4


+class PreconditionerComputationType(enum.Enum):
+    EIGEN_ROOT_INV = 0
+    COUPLED_NEWTON_ROOT_INV = 1
+    COUPLED_HIGHER_ORDER_ROOT_INV = 2
+    EIGH_EIGENVALUE_CORRECTION = 3


###### ARGPARSER ######
def enum_type_parse(s: str, enum_type: enum.Enum):
try:
@@ -195,6 +209,12 @@ def get_args():
action="store_true",
help="Use debug mode for examining root inverse residuals.",
)
+parser.add_argument(
+    "--preconditioner-computation-type",
+    type=lambda t: enum_type_parse(t, PreconditionerComputationType),
+    default=PreconditionerComputationType.EIGEN_ROOT_INV,
+    help="Preconditioner computation method for Shampoo.",
+)

# Arguments for grafting.
parser.add_argument(
@@ -235,6 +255,18 @@
default=DType.FP32,
help="Data type for storing Shampoo inverse factor matrices.",
)
+parser.add_argument(
+    "--corrected-eigenvalues-dtype",
+    type=lambda t: enum_type_parse(t, DType),
+    default=DType.FP32,
+    help="Data type for storing corrected eigenvalues of Shampoo preconditioner.",
+)
+parser.add_argument(
+    "--factor-matrix-eigenvectors-dtype",
+    type=lambda t: enum_type_parse(t, DType),
+    default=DType.FP32,
+    help="Data type for storing Shampoo factor matrix eigenvectors.",
+)
parser.add_argument(
"--filtered-grad-dtype",
type=lambda t: enum_type_parse(t, DType),
@@ -410,6 +442,7 @@ def instantiate_optimizer(
precision_config: Optional[PrecisionConfig],
use_protected_eigh: bool,
track_root_inv_residuals: bool,
+preconditioner_computation_type: PreconditionerComputationType,
) -> torch.optim.Optimizer:
if optimizer_type == OptimizerType.SGD:
optimizer = torch.optim.SGD(
@@ -464,6 +497,9 @@
precision_config=precision_config,
use_protected_eigh=use_protected_eigh,
track_root_inv_residuals=track_root_inv_residuals,
+preconditioner_computation_config=instantiate_preconditioner_computation_config(
+    preconditioner_computation_type
+),
)
else:
raise ValueError(f"Invalid OptimizerType {optimizer_type}!")
@@ -501,6 +537,32 @@ def instantiate_grafting_config(
raise ValueError(f"Invalid GraftingType {grafting_type}!")


+def instantiate_preconditioner_computation_config(
+    preconditioner_computation_type: PreconditionerComputationType,
+) -> PreconditionerComputationConfig:
+    if preconditioner_computation_type == PreconditionerComputationType.EIGEN_ROOT_INV:
+        return EigenConfig()
+    elif (
+        preconditioner_computation_type
+        == PreconditionerComputationType.COUPLED_NEWTON_ROOT_INV
+    ):
+        return CoupledNewtonConfig()
+    elif (
+        preconditioner_computation_type
+        == PreconditionerComputationType.COUPLED_HIGHER_ORDER_ROOT_INV
+    ):
+        return CoupledHigherOrderConfig()
+    elif (
+        preconditioner_computation_type
+        == PreconditionerComputationType.EIGH_EIGENVALUE_CORRECTION
+    ):
+        return EighEigenvalueCorrectionConfig()
+    else:
+        raise ValueError(
+            f"Invalid PreconditionerComputationType {preconditioner_computation_type}!"
+        )


###### DATA LOADER ######
def get_data_loader_and_sampler(
data_path: str, world_size: int, rank: int, local_batch_size: int
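The new dispatch function is a one-to-one mapping from enum member to config class. As a design note, the same mapping can be written as a table lookup; the sketch below is an illustration of that alternative using the names defined above, not the committed implementation:

    from matrix_functions_types import (
        CoupledHigherOrderConfig,
        CoupledNewtonConfig,
        EigenConfig,
        EighEigenvalueCorrectionConfig,
        PreconditionerComputationConfig,
    )

    # One config class per enum member, mirroring the if/elif chain above.
    _CONFIG_BY_TYPE = {
        PreconditionerComputationType.EIGEN_ROOT_INV: EigenConfig,
        PreconditionerComputationType.COUPLED_NEWTON_ROOT_INV: CoupledNewtonConfig,
        PreconditionerComputationType.COUPLED_HIGHER_ORDER_ROOT_INV: CoupledHigherOrderConfig,
        PreconditionerComputationType.EIGH_EIGENVALUE_CORRECTION: EighEigenvalueCorrectionConfig,
    }

    def instantiate_preconditioner_computation_config(
        preconditioner_computation_type: PreconditionerComputationType,
    ) -> PreconditionerComputationConfig:
        try:
            return _CONFIG_BY_TYPE[preconditioner_computation_type]()
        except KeyError:
            raise ValueError(
                f"Invalid PreconditionerComputationType {preconditioner_computation_type}!"
            )

Either way, instantiate_preconditioner_computation_config(PreconditionerComputationType.EIGH_EIGENVALUE_CORRECTION) returns an EighEigenvalueCorrectionConfig instance, which instantiate_optimizer forwards to DistributedShampoo as preconditioner_computation_config.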
