From dac3ac1e467bde41a2f1547190fac6d1128a7238 Mon Sep 17 00:00:00 2001
From: runame
Date: Fri, 6 Dec 2024 15:35:14 +0000
Subject: [PATCH] Adjust UI and docs

---
 distributed_shampoo/README.md   | 15 +++++++------
 distributed_shampoo/__init__.py | 38 +++++++++++++++++++++------------
 2 files changed, 33 insertions(+), 20 deletions(-)

diff --git a/distributed_shampoo/README.md b/distributed_shampoo/README.md
index 0d35aa2..9b1e48d 100644
--- a/distributed_shampoo/README.md
+++ b/distributed_shampoo/README.md
@@ -64,7 +64,7 @@ A few notes on hyperparameters:
 
 - We allow for decoupled and coupled weight decay. If one sets `use_decoupled_weight_decay=True`, then you are enabling AdamW-style weight decay, while `use_decoupled_weight_decay=False` corresponds to the normal L2-regularization style weight decay.
 
-- When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectionConfig`, there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial.
+- When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectedShampooPreconditionerConfig` (see Example 5), there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial.
 
 ### Example 1: [SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html) with Momentum
 
@@ -221,7 +221,7 @@ optimizer = DistributedShampoo(
 )
 ```
 
-### Example 5: eigenvalue-corrected Shampoo (SOAP)
+### Example 5: eigenvalue-corrected Shampoo/SOAP
 
 If we previously used the optimizer:
 ```python
@@ -241,7 +241,10 @@ optimizer = AdamW(
 we would instead use:
 ```python
 import torch
-from distributed_shampoo import DistributedShampoo, EighEigenvalueCorrectionConfig
+from distributed_shampoo import (
+    DistributedShampoo,
+    DefaultEigenvalueCorrectedShampooConfig,
+)
 
 model = instantiate_model()
 
@@ -254,9 +257,9 @@ optimizer = DistributedShampoo(
     max_preconditioner_dim=8192,
     precondition_frequency=100,
     use_decoupled_weight_decay=True,
-    # This can also be set to `QREigenvalueCorrectionConfig` which is less expensive
-    # and might therefore allow for a smaller `precondition_frequency`.
-    preconditioner_computation_config=EighEigenvalueCorrectionConfig(),
+    # This can also be set to `DefaultSOAPConfig` which uses QR decompositions, hence is
+    # less expensive and might thereby allow for a smaller `precondition_frequency`.
+    preconditioner_computation_config=DefaultEigenvalueCorrectedShampooConfig,
 )
 ```
 
diff --git a/distributed_shampoo/__init__.py b/distributed_shampoo/__init__.py
index 2b2c45f..6ba8c2a 100644
--- a/distributed_shampoo/__init__.py
+++ b/distributed_shampoo/__init__.py
@@ -13,14 +13,20 @@
     AdamGraftingConfig,
     CommunicationDType,
     DDPShampooConfig,
+    DefaultEigenvalueCorrectedShampooConfig,
+    DefaultShampooConfig,
+    DefaultSOAPConfig,
     DistributedConfig,
+    EigenvalueCorrectedShampooPreconditionerConfig,
     FSDPShampooConfig,
     FullyShardShampooConfig,
     GraftingConfig,
     HSDPShampooConfig,
     PrecisionConfig,
+    PreconditionerComputationConfig,
     RMSpropGraftingConfig,
     SGDGraftingConfig,
+    ShampooPreconditionerConfig,
     ShampooPT2CompileConfig,
 )
 from distributed_shampoo.utils.shampoo_fsdp_utils import compile_fsdp_parameter_metadata
@@ -28,12 +34,9 @@
     CoupledHigherOrderConfig,
     CoupledNewtonConfig,
     DefaultEigenConfig,
-    DefaultEighEigenvalueCorrectionConfig,
     EigenConfig,
-    EigenvalueCorrectionConfig,
-    EighEigenvalueCorrectionConfig,
-    PreconditionerComputationConfig,
-    QREigenvalueCorrectionConfig,
+    EigenvectorConfig,
+    MatrixFunctionConfig,
     RootInvConfig,
 )
 
@@ -58,15 +61,22 @@
     "PrecisionConfig",
     # `preconditioner_computation_config` options.
     "PreconditionerComputationConfig",  # Abstract base class.
-    "RootInvConfig",  # Abstract base class (based on `PreconditionerComputationConfig`).
-    "EigenConfig",
-    "DefaultEigenConfig",  # Default `RootInvConfig`.
-    "CoupledNewtonConfig",
-    "CoupledHigherOrderConfig",
-    "EigenvalueCorrectionConfig",  # Abstract base class (based on `PreconditionerComputationConfig`).
-    "EighEigenvalueCorrectionConfig",
-    "DefaultEighEigenvalueCorrectionConfig",  # Default `EigenvalueCorrectionConfig`.
-    "QREigenvalueCorrectionConfig",
+    "ShampooPreconditionerConfig",  # Based on `PreconditionerComputationConfig`.
+    "DefaultShampooConfig",  # Default `ShampooPreconditionerConfig` using `EigenConfig`.
+    "EigenvalueCorrectedShampooPreconditionerConfig",  # Based on `PreconditionerComputationConfig`.
+    "DefaultEigenvalueCorrectedShampooConfig",  # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighConfig`.
+    "DefaultSOAPConfig",  # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `QRConfig`.
+    # matrix functions configs.
+    "MatrixFunctionConfig",  # Abstract base class.
+    "RootInvConfig",  # Abstract base class (based on `MatrixFunctionConfig`).
+    "EigenConfig",  # Based on `RootInvConfig`.
+    "DefaultEigenConfig",  # Default `RootInvConfig` using `EigenConfig`.
+    "CoupledNewtonConfig",  # Based on `RootInvConfig`.
+    "CoupledHigherOrderConfig",  # Based on `RootInvConfig`.
+    "EigenvectorConfig",  # Abstract base class (based on `MatrixFunctionConfig`).
+    "EighConfig",  # Based on `EigenvectorConfig`.
+    "DefaultEighConfig",  # Default `EigenvectorConfig` using `EighConfig`.
+    "QRConfig",  # Based on `EigenvectorConfig`.
     # Other utilities.
     "compile_fsdp_parameter_metadata",  # For `FSDPShampooConfig` and `HSDPShampooConfig`.
     "CommunicationDType",  # For `DDPShampooConfig` and `HSDPShampooConfig`.
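
The updated Example 5 comment points to `DefaultSOAPConfig` as the cheaper, QR-based alternative but does not show it in use. Below is a minimal sketch of that variant, assuming only the exports added by this patch; the hyperparameter values and the `instantiate_model()` placeholder are illustrative (mirroring the README examples), not prescribed by the patch.

```python
from distributed_shampoo import DistributedShampoo, DefaultSOAPConfig

model = instantiate_model()  # placeholder helper, as in the README examples

optimizer = DistributedShampoo(
    model.parameters(),
    lr=3e-4,                    # illustrative values, not taken from the patch
    betas=(0.9, 0.999),
    weight_decay=1e-05,
    max_preconditioner_dim=8192,
    precondition_frequency=50,  # the QR-based default is cheaper per update, so a
                                # smaller `precondition_frequency` may be affordable
    use_decoupled_weight_decay=True,
    # QR-based eigenvalue-corrected Shampoo; swap in
    # `DefaultEigenvalueCorrectedShampooConfig` for the eigendecomposition-based default.
    preconditioner_computation_config=DefaultSOAPConfig,
)
```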