Skip to content

Commit

Permalink
Fix defaults with default_factory
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Dec 10, 2024
1 parent d3a10af commit 273e0a1
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions distributed_shampoo/shampoo_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

import enum
from dataclasses import dataclass
from dataclasses import dataclass, field

import torch

Expand Down Expand Up @@ -101,7 +101,9 @@ class ShampooPreconditionerConfig(PreconditionerConfig):
"""

amortized_computation_config: RootInvConfig = DefaultEigenConfig
amortized_computation_config: RootInvConfig = field(
default_factory=lambda: DefaultEigenConfig
)


DefaultShampooConfig = ShampooPreconditionerConfig()
Expand All @@ -116,7 +118,9 @@ class EigenvalueCorrectedShampooPreconditionerConfig(PreconditionerConfig):
"""

amortized_computation_config: EigenvectorConfig = DefaultEighConfig
amortized_computation_config: EigenvectorConfig = field(
default_factory=lambda: DefaultEighConfig
)


DefaultEigenvalueCorrectedShampooConfig = (
Expand Down

0 comments on commit 273e0a1

Please sign in to comment.