diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index 6d34f65..5ae7306 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -8,7 +8,7 @@ """ import enum -from dataclasses import dataclass +from dataclasses import dataclass, field import torch @@ -109,7 +109,9 @@ class ShampooPreconditionerConfig(PreconditionerConfig): """ - amortized_computation_config: RootInvConfig = DefaultEigenConfig + amortized_computation_config: RootInvConfig = field( + default_factory=lambda: DefaultEigenConfig + ) DefaultShampooConfig = ShampooPreconditionerConfig() @@ -124,7 +126,9 @@ class EigenvalueCorrectedShampooPreconditionerConfig(PreconditionerConfig): """ - amortized_computation_config: EigenvectorConfig = DefaultEighConfig + amortized_computation_config: EigenvectorConfig = field( + default_factory=lambda: DefaultEighConfig + ) DefaultEigenvalueCorrectedShampooConfig = (