Skip to content

Commit

Permalink
Open-sourced update on 12/19/2024 (facebookresearch#63)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#63

1. Add `HybridShardDistributor` (i.e., per-parameter HSDP or HSDP2) implemented by wz337 and hjmshi into `DistributedShampoo`.
2. Disable quantization functionality for now.

Differential Revision: D67398314
  • Loading branch information
tsunghsienlee authored and facebook-github-bot committed Dec 18, 2024
1 parent b5dd2f2 commit 45d0d95
Show file tree
Hide file tree
Showing 25 changed files with 1,605 additions and 1,198 deletions.
2 changes: 0 additions & 2 deletions distributed_shampoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
FullyShardShampooConfig,
GraftingConfig,
HSDPShampooConfig,
PrecisionConfig,
PreconditionerConfig,
RMSpropGraftingConfig,
SGDGraftingConfig,
Expand Down Expand Up @@ -58,7 +57,6 @@
"FullyShardShampooConfig",
"HSDPShampooConfig",
# `precision_config`.
"PrecisionConfig",
# `preconditioner_config` options.
"PreconditionerConfig", # Abstract base class.
"ShampooPreconditionerConfig", # Based on `PreconditionerConfig`.
Expand Down
367 changes: 146 additions & 221 deletions distributed_shampoo/distributed_shampoo.py

Large diffs are not rendered by default.

13 changes: 2 additions & 11 deletions distributed_shampoo/examples/ddp_cifar10_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch.distributed as dist
import torch.distributed.checkpoint as dist_checkpoint

from distributed_shampoo import DDPShampooConfig, DistributedShampoo, PrecisionConfig
from distributed_shampoo import DDPShampooConfig, DistributedShampoo
from distributed_shampoo.examples.trainer_utils import (
get_data_loader_and_sampler,
get_model_and_loss_fn,
Expand Down Expand Up @@ -123,16 +123,7 @@
num_trainers_per_group=args.num_trainers_per_group,
communicate_params=args.communicate_params,
),
precision_config=PrecisionConfig(
computation_dtype=args.computation_dtype.value,
factor_matrix_dtype=args.factor_matrix_dtype.value,
inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value,
corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value,
factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value,
filtered_grad_dtype=args.filtered_grad_dtype.value,
momentum_dtype=args.momentum_dtype.value,
grafting_state_dtype=args.grafting_state_dtype.value,
),
preconditioner_dtype=args.preconditioner_dtype,
use_protected_eigh=args.use_protected_eigh,
track_root_inv_residuals=args.track_root_inv_residuals,
preconditioner_computation_type=args.preconditioner_computation_type,
Expand Down
12 changes: 1 addition & 11 deletions distributed_shampoo/examples/default_cifar10_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import os

import torch
from distributed_shampoo import PrecisionConfig

from distributed_shampoo.examples.trainer_utils import (
get_data_loader_and_sampler,
Expand Down Expand Up @@ -134,16 +133,7 @@ def train_default_model(
use_merge_dims=args.use_merge_dims,
use_pytorch_compile=args.use_pytorch_compile,
distributed_config=None,
precision_config=PrecisionConfig(
computation_dtype=args.computation_dtype.value,
factor_matrix_dtype=args.factor_matrix_dtype.value,
inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value,
corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value,
factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value,
filtered_grad_dtype=args.filtered_grad_dtype.value,
momentum_dtype=args.momentum_dtype.value,
grafting_state_dtype=args.grafting_state_dtype.value,
),
preconditioner_dtype=args.preconditioner_dtype,
use_protected_eigh=args.use_protected_eigh,
track_root_inv_residuals=args.track_root_inv_residuals,
preconditioner_computation_type=args.preconditioner_computation_type,
Expand Down
17 changes: 2 additions & 15 deletions distributed_shampoo/examples/fsdp_cifar10_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@

import torch.distributed as dist

from distributed_shampoo import (
compile_fsdp_parameter_metadata,
FSDPShampooConfig,
PrecisionConfig,
)
from distributed_shampoo import compile_fsdp_parameter_metadata, FSDPShampooConfig
from distributed_shampoo.examples.trainer_utils import (
get_data_loader_and_sampler,
get_model_and_loss_fn,
Expand Down Expand Up @@ -119,16 +115,7 @@
distributed_config=FSDPShampooConfig(
param_to_metadata=compile_fsdp_parameter_metadata(model),
),
precision_config=PrecisionConfig(
computation_dtype=args.computation_dtype.value,
factor_matrix_dtype=args.factor_matrix_dtype.value,
inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value,
corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value,
factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value,
filtered_grad_dtype=args.filtered_grad_dtype.value,
momentum_dtype=args.momentum_dtype.value,
grafting_state_dtype=args.grafting_state_dtype.value,
),
preconditioner_dtype=args.preconditioner_dtype,
use_protected_eigh=args.use_protected_eigh,
track_root_inv_residuals=args.track_root_inv_residuals,
preconditioner_computation_type=args.preconditioner_computation_type,
Expand Down
17 changes: 2 additions & 15 deletions distributed_shampoo/examples/fully_shard_cifar10_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@
import torch.distributed as dist
import torch.distributed.checkpoint as dist_checkpoint

from distributed_shampoo import (
DistributedShampoo,
FullyShardShampooConfig,
PrecisionConfig,
)
from distributed_shampoo import DistributedShampoo, FullyShardShampooConfig
from distributed_shampoo.examples.trainer_utils import (
get_data_loader_and_sampler,
get_model_and_loss_fn,
Expand Down Expand Up @@ -139,16 +135,7 @@ def create_model_and_optimizer_and_loss_fn(args, device):
use_merge_dims=args.use_merge_dims,
use_pytorch_compile=args.use_pytorch_compile,
distributed_config=FullyShardShampooConfig(),
precision_config=PrecisionConfig(
computation_dtype=args.computation_dtype.value,
factor_matrix_dtype=args.factor_matrix_dtype.value,
inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value,
corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value,
factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value,
filtered_grad_dtype=args.filtered_grad_dtype.value,
momentum_dtype=args.momentum_dtype.value,
grafting_state_dtype=args.grafting_state_dtype.value,
),
preconditioner_dtype=args.preconditioner_dtype,
use_protected_eigh=args.use_protected_eigh,
track_root_inv_residuals=args.track_root_inv_residuals,
preconditioner_computation_type=args.preconditioner_computation_type,
Expand Down
17 changes: 2 additions & 15 deletions distributed_shampoo/examples/hsdp_cifar10_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@

import torch.distributed as dist

from distributed_shampoo import (
compile_fsdp_parameter_metadata,
HSDPShampooConfig,
PrecisionConfig,
)
from distributed_shampoo import compile_fsdp_parameter_metadata, HSDPShampooConfig

from distributed_shampoo.examples.trainer_utils import (
get_data_loader_and_sampler,
Expand Down Expand Up @@ -134,16 +130,7 @@
device_mesh=device_mesh,
num_trainers_per_group=args.num_trainers_per_group,
),
precision_config=PrecisionConfig(
computation_dtype=args.computation_dtype.value,
factor_matrix_dtype=args.factor_matrix_dtype.value,
inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value,
corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value,
factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value,
filtered_grad_dtype=args.filtered_grad_dtype.value,
momentum_dtype=args.momentum_dtype.value,
grafting_state_dtype=args.grafting_state_dtype.value,
),
preconditioner_dtype=args.preconditioner_dtype,
use_protected_eigh=args.use_protected_eigh,
track_root_inv_residuals=args.track_root_inv_residuals,
preconditioner_computation_type=args.preconditioner_computation_type,
Expand Down
51 changes: 4 additions & 47 deletions distributed_shampoo/examples/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
DistributedConfig,
DistributedShampoo,
GraftingConfig,
PrecisionConfig,
PreconditionerConfig,
RMSpropGraftingConfig,
SGDGraftingConfig,
Expand Down Expand Up @@ -239,52 +238,10 @@ def get_args():

# Arguments for mixed-precision.
parser.add_argument(
"--computation-dtype",
"--preconditioner-dtype",
type=lambda t: enum_type_parse(t, DType),
default=DType.FP32,
help="Data type for all computation in Shampoo.",
)
parser.add_argument(
"--factor-matrix-dtype",
type=lambda t: enum_type_parse(t, DType),
default=DType.FP32,
help="Data type for storing Shampoo factor matrices.",
)
parser.add_argument(
"--inv-factor-matrix-dtype",
type=lambda t: enum_type_parse(t, DType),
default=DType.FP32,
help="Data type for storing Shampoo inverse factor matrices.",
)
parser.add_argument(
"--corrected-eigenvalues-dtype",
type=lambda t: enum_type_parse(t, DType),
default=DType.FP32,
help="Data type for storing corrected eigenvalues of Shampoo preconditioner.",
)
parser.add_argument(
"--factor-matrix-eigenvectors-dtype",
type=lambda t: enum_type_parse(t, DType),
default=DType.FP32,
help="Data type for storing Shampoo factor matrices eigenvectors.",
)
parser.add_argument(
"--filtered-grad-dtype",
type=lambda t: enum_type_parse(t, DType),
default=DType.FP32,
help="Data type for storing filtered gradients.",
)
parser.add_argument(
"--momentum-dtype",
type=lambda t: enum_type_parse(t, DType),
default=DType.FP32,
help="Data type for storing momentum states.",
)
parser.add_argument(
"--grafting-state-dtype",
type=lambda t: enum_type_parse(t, DType),
default=DType.FP32,
help="Data type for storing grafting preconditioners.",
help="Preconditioner dtype for Shampoo.",
)

# Arguments for DDP Shampoo.
Expand Down Expand Up @@ -440,7 +397,7 @@ def instantiate_optimizer(
use_merge_dims: bool,
use_pytorch_compile: bool,
distributed_config: DistributedConfig | None,
precision_config: PrecisionConfig | None,
preconditioner_dtype: DType,
use_protected_eigh: bool,
track_root_inv_residuals: bool,
preconditioner_computation_type: PreconditionerComputationType,
Expand Down Expand Up @@ -495,7 +452,7 @@ def instantiate_optimizer(
use_merge_dims=use_merge_dims,
use_pytorch_compile=use_pytorch_compile,
distributed_config=distributed_config,
precision_config=precision_config,
preconditioner_dtype=preconditioner_dtype.value,
use_protected_eigh=use_protected_eigh,
track_root_inv_residuals=track_root_inv_residuals,
preconditioner_config=instantiate_preconditioner_config(
Expand Down
3 changes: 2 additions & 1 deletion distributed_shampoo/gpu_tests/shampoo_pt2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def _shampoo_optim_factory(
parameters,
lr=0.01,
betas=betas,
beta3=betas[0] * betas[0],
# TODO: comment out beta3 to unblock quantization changes; need to fix PT2 FMA changes for this test
# beta3=betas[0] * betas[0],
epsilon=1e-10,
momentum=0.9,
dampening=0.9,
Expand Down
72 changes: 28 additions & 44 deletions distributed_shampoo/shampoo_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
LR = "lr"
MAX_PRECONDITIONER_DIM = "max_preconditioner_dim"
PARAMS = "params" # While this is stored in groups by default, we do not checkpoint this quantity.
PRECISION_CONFIG = "precision_config"
PRECONDITION_FREQUENCY = "precondition_frequency"
PRECONDITIONER_DTYPE = "preconditioner_dtype"
PRECONDITIONER_CONFIG = "preconditioner_config"
Expand Down Expand Up @@ -89,7 +88,7 @@ class PreconditionerConfig(AbstractDataclass):
"""

amortized_computation_config: MatrixFunctionConfig
amortized_computation_config: MatrixFunctionConfig # type: ignore


@dataclass(kw_only=True)
Expand Down Expand Up @@ -154,42 +153,6 @@ class FSDPParameterMetadata:
sharding_strategy: ShardingStrategy


@dataclass
class PrecisionConfig:
"""Configuration for precision of each optimizer state.
TODO: allow more specific computation dtypes that only apply to some computations
Args:
computation_dtype (torch.dtype): Data type that all computation is performed in, except factor matrices (see factor_matrix_computation_dtype). (Default: torch.float32)
factor_matrix_dtype (torch.dtype): Data type for storing Shampoo factor matrices. (Default: torch.float32)
inv_factor_matrix_dtype (torch.dtype): Data type for storing Shampoo inverse factor matrices. (Default: torch.float32)
factor_matrix_computation_dtype (torch.dtype): Data type for accumulating factor matrices and computing their inverses. (Default: torch.float32)
corrected_eigenvalues_dtype (torch.dtype): Data type for storing the corrected eigenvalues of Shampoo preconditioner (EMA). (Default: torch.float32)
factor_matrix_eigenvectors_dtype (torch.dtype): Data type for storing the eigenvectors of Shampoo factor matrices. (Default: torch.float32)
filtered_grad_dtype (torch.dtype): Data type for storing filtered gradients (EMA). (Default: torch.float32)
momentum_dtype (torch.dtype): Data type for storing momentum states. (Default: torch.float32)
grafting_state_dtype (torch.dtype): Data type for storing grafting preconditioners, if applicable. (Default: torch.float32)
Current applicable grafting configs:
- AdaGradGraftingConfig
- RMSpropGraftingConfig
- AdamGraftingConfig
NOT applicable configs:
- SGDGraftingConfig
- None (i.e. no grafting)
"""

computation_dtype: torch.dtype = torch.float32
factor_matrix_dtype: torch.dtype = torch.float32
inv_factor_matrix_dtype: torch.dtype = torch.float32
corrected_eigenvalues_dtype: torch.dtype = torch.float32
factor_matrix_eigenvectors_dtype: torch.dtype = torch.float32
factor_matrix_computation_dtype: torch.dtype = torch.float32
filtered_grad_dtype: torch.dtype = torch.float32
momentum_dtype: torch.dtype = torch.float32
grafting_state_dtype: torch.dtype = torch.float32


@dataclass(init=False)
class DistributedConfig(AbstractDataclass):
"""Abstract dataclass for distributed configs in Shampoo."""
Expand Down Expand Up @@ -229,6 +192,29 @@ class FSDPShampooConfig(DistributedConfig):
param_to_metadata: dict[Parameter, FSDPParameterMetadata]


@dataclass
class HSDPShampooConfig(FSDPShampooConfig, DDPShampooConfig):
"""Configuration for HSDP Shampoo.
Enables distributed computation and optimizer states (like ZeRO-1) via DTensor for Shampoo across ranks with shared
parameters between different HSDP process groups.
Args:
device_mesh (torch.distributed.device_mesh.DeviceMesh): A 2D device mesh that specifies the layout of the numbers of
shard and replicate dimensions.
param_to_metadata (dict[Parameter, FSDPParameterMetadata]): Dictionary mapping parameter to its metadata from HSDP.
communication_dtype (CommunicationDType): Data type for communication between ranks. (Default: DEFAULT)
num_trainers_per_group (int): Number of GPUs per distributed process group for distributed computation/memory.
If num_trainers_per_group = -1 is used, then defaults to using the number of workers in each replicated HSDP
group. (Default: -1)
communicate_params (bool): Flag for all-gathering updated params across multiple workers.
If False, all-gathers parameter updates across multiple workers. (Default: False)
"""

device_mesh: DeviceMesh


@dataclass(kw_only=True)
class FullyShardShampooConfig(DistributedConfig):
"""Configuration for FullyShard (per-parameter FSDP) Shampoo.
Expand All @@ -238,16 +224,14 @@ class FullyShardShampooConfig(DistributedConfig):


@dataclass
class HSDPShampooConfig(FSDPShampooConfig, DDPShampooConfig):
"""Configuration for HSDP Shampoo.
class HybridShardShampooConfig(FullyShardShampooConfig, DDPShampooConfig):
"""Configuration for HybridShard (per-parameter FSDP) Shampoo.
Enables distributed computation and optimizer states (like ZeRO-1) via DTensor for Shampoo across ranks with shared
parameters between different HSDP process groups.
parameters between different Hybrid Shard process groups.
Args:
device_mesh (torch.distributed.device_mesh.DeviceMesh): A 2D device mesh that specifies the layout of the numbers of
shard and replicate dimensions.
param_to_metadata (dict[Parameter, FSDPParameterMetadata]): Dictionary mapping parameter to its metadata from HSDP.
device_mesh (torch.distributed.device_mesh.DeviceMesh): Device mesh for Hybrid Shard.
communication_dtype (CommunicationDType): Data type for communication between ranks. (Default: DEFAULT)
num_trainers_per_group (int): Number of GPUs per distributed process group for distributed computation/memory.
If num_trainers_per_group = -1 is used, then defaults to using the number of workers in each replicated HSDP
Expand Down
Loading

0 comments on commit 45d0d95

Please sign in to comment.