diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index 42e57ab..bf5ccdd 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -8,7 +8,7 @@ """ import enum -from dataclasses import dataclass +from dataclasses import dataclass, field import torch @@ -101,7 +101,9 @@ class ShampooPreconditionerConfig(PreconditionerConfig): """ - amortized_computation_config: RootInvConfig = DefaultEigenConfig + amortized_computation_config: RootInvConfig = field( + default_factory=lambda: DefaultEigenConfig + ) DefaultShampooConfig = ShampooPreconditionerConfig() @@ -116,7 +118,9 @@ class EigenvalueCorrectedShampooPreconditionerConfig(PreconditionerConfig): """ - amortized_computation_config: EigenvectorConfig = DefaultEighConfig + amortized_computation_config: EigenvectorConfig = field( + default_factory=lambda: DefaultEighConfig + ) DefaultEigenvalueCorrectedShampooConfig = (