
Commit

Adjust UI and docs
runame committed Dec 6, 2024
1 parent 9291d15 commit dac3ac1
Showing 2 changed files with 33 additions and 20 deletions.
15 changes: 9 additions & 6 deletions distributed_shampoo/README.md
@@ -64,7 +64,7 @@ A few notes on hyperparameters:

- We allow for decoupled and coupled weight decay. Setting `use_decoupled_weight_decay=True` enables AdamW-style (decoupled) weight decay, while `use_decoupled_weight_decay=False` corresponds to the standard L2-regularization-style weight decay; see the sketch after this list.

- When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectionConfig`, there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial.
- When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectedShampooPreconditionerConfig` (see Example 5), there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial.
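
The weight-decay behavior described in the first note can be illustrated with a minimal sketch (assuming the `instantiate_model()` helper from the examples below; the hyperparameter values are placeholders, not recommendations):

```python
from distributed_shampoo import DistributedShampoo

model = instantiate_model()

# Decoupled (AdamW-style) weight decay; set `use_decoupled_weight_decay=False`
# to recover coupled L2-regularization-style weight decay instead.
optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    weight_decay=1e-05,
    max_preconditioner_dim=8192,
    precondition_frequency=100,
    use_decoupled_weight_decay=True,
)
```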

### Example 1: [SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html) with Momentum

@@ -221,7 +221,7 @@ optimizer = DistributedShampoo(
)
```

### Example 5: eigenvalue-corrected Shampoo (SOAP)
### Example 5: eigenvalue-corrected Shampoo/SOAP

If we previously used the optimizer:
```python
@@ -241,7 +241,10 @@ optimizer = AdamW(
we would instead use:
```python
import torch
from distributed_shampoo import DistributedShampoo, EighEigenvalueCorrectionConfig
from distributed_shampoo import (
DistributedShampoo,
DefaultEigenvalueCorrectedShampooConfig,
)

model = instantiate_model()

@@ -254,9 +257,9 @@ optimizer = DistributedShampoo(
max_preconditioner_dim=8192,
precondition_frequency=100,
use_decoupled_weight_decay=True,
# This can also be set to `QREigenvalueCorrectionConfig` which is less expensive
# and might therefore allow for a smaller `precondition_frequency`.
preconditioner_computation_config=EighEigenvalueCorrectionConfig(),
# This can also be set to `DefaultSOAPConfig` which uses QR decompositions, hence is
# less expensive and might thereby allow for a smaller `precondition_frequency`.
preconditioner_computation_config=DefaultEigenvalueCorrectedShampooConfig,
)
```
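
As noted in the inline comment, the QR-based `DefaultSOAPConfig` can be swapped in for a cheaper preconditioner computation. A minimal sketch of that change (all other arguments as in the example above; the smaller `precondition_frequency` is only an illustrative placeholder, not a recommendation):

```python
from distributed_shampoo import DefaultSOAPConfig, DistributedShampoo

optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    weight_decay=1e-05,
    max_preconditioner_dim=8192,
    # QR-based eigenvector updates are cheaper, so a smaller frequency may be affordable.
    precondition_frequency=25,
    use_decoupled_weight_decay=True,
    preconditioner_computation_config=DefaultSOAPConfig,
)
```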

38 changes: 24 additions & 14 deletions distributed_shampoo/__init__.py
@@ -13,27 +13,30 @@
AdamGraftingConfig,
CommunicationDType,
DDPShampooConfig,
DefaultEigenvalueCorrectedShampooConfig,
DefaultShampooConfig,
DefaultSOAPConfig,
DistributedConfig,
EigenvalueCorrectedShampooPreconditionerConfig,
FSDPShampooConfig,
FullyShardShampooConfig,
GraftingConfig,
HSDPShampooConfig,
PrecisionConfig,
PreconditionerComputationConfig,
RMSpropGraftingConfig,
SGDGraftingConfig,
ShampooPreconditionerConfig,
ShampooPT2CompileConfig,
)
from distributed_shampoo.utils.shampoo_fsdp_utils import compile_fsdp_parameter_metadata
from matrix_functions_types import (
CoupledHigherOrderConfig,
CoupledNewtonConfig,
DefaultEigenConfig,
DefaultEighEigenvalueCorrectionConfig,
EigenConfig,
EigenvalueCorrectionConfig,
EighEigenvalueCorrectionConfig,
PreconditionerComputationConfig,
QREigenvalueCorrectionConfig,
EigenvectorConfig,
MatrixFunctionConfig,
RootInvConfig,
)

@@ -58,15 +61,22 @@
"PrecisionConfig",
# `preconditioner_computation_config` options.
"PreconditionerComputationConfig", # Abstract base class.
"RootInvConfig", # Abstract base class (based on `PreconditionerComputationConfig`).
"EigenConfig",
"DefaultEigenConfig", # Default `RootInvConfig`.
"CoupledNewtonConfig",
"CoupledHigherOrderConfig",
"EigenvalueCorrectionConfig", # Abstract base class (based on `PreconditionerComputationConfig`).
"EighEigenvalueCorrectionConfig",
"DefaultEighEigenvalueCorrectionConfig", # Default `EigenvalueCorrectionConfig`.
"QREigenvalueCorrectionConfig",
"ShampooPreconditionerConfig", # Based on `PreconditionerComputationConfig`.
"DefaultShampooConfig", # Default `ShampooPreconditionerConfig` using `EigenConfig`.
"EigenvalueCorrectedShampooPreconditionerConfig", # Based on `PreconditionerComputationConfig`.
"DefaultEigenvalueCorrectedShampooConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighConfig`.
"DefaultSOAPConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `QRConfig`.
# matrix functions configs.
"MatrixFunctionConfig", # Abstract base class.
"RootInvConfig", # Abstract base class (based on `MatrixFunctionConfig`).
"EigenConfig", # Based on `RootInvConfig`.
"DefaultEigenConfig", # Default `RootInvConfig` using `EigenConfig`.
"CoupledNewtonConfig", # Based on `RootInvConfig`.
"CoupledHigherOrderConfig", # Based on `RootInvConfig`.
"EigenvectorConfig", # Abstract base class (based on `MatrixFunctionConfig`).
"EighConfig", # Based on `EigenvectorConfig`.
"DefaultEighConfig", # Default `EigenvectorConfig` using `EighConfig`.
"QRConfig", # Based on `EigenvectorConfig`.
# Other utilities.
"compile_fsdp_parameter_metadata", # For `FSDPShampooConfig` and `HSDPShampooConfig`.
"CommunicationDType", # For `DDPShampooConfig` and `HSDPShampooConfig`.
