From 273e0a1585bdfa6adb1fb454aed3963b7c0dc518 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 10 Dec 2024 18:07:15 +0000 Subject: [PATCH] Fix defaults with default_factory --- distributed_shampoo/shampoo_types.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index 42e57ab..bf5ccdd 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -8,7 +8,7 @@ """ import enum -from dataclasses import dataclass +from dataclasses import dataclass, field import torch @@ -101,7 +101,9 @@ class ShampooPreconditionerConfig(PreconditionerConfig): """ - amortized_computation_config: RootInvConfig = DefaultEigenConfig + amortized_computation_config: RootInvConfig = field( + default_factory=lambda: DefaultEigenConfig + ) DefaultShampooConfig = ShampooPreconditionerConfig() @@ -116,7 +118,9 @@ class EigenvalueCorrectedShampooPreconditionerConfig(PreconditionerConfig): """ - amortized_computation_config: EigenvectorConfig = DefaultEighConfig + amortized_computation_config: EigenvectorConfig = field( + default_factory=lambda: DefaultEighConfig + ) DefaultEigenvalueCorrectedShampooConfig = (