Skip to content

Commit

Permalink
Merge branch 'configs-refactor' into fail-counter
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Dec 10, 2024
2 parents 709ab1c + 273e0a1 commit d53ffc6
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions distributed_shampoo/shampoo_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

import enum
from dataclasses import dataclass
from dataclasses import dataclass, field

import torch

Expand Down Expand Up @@ -109,7 +109,9 @@ class ShampooPreconditionerConfig(PreconditionerConfig):
"""

amortized_computation_config: RootInvConfig = DefaultEigenConfig
amortized_computation_config: RootInvConfig = field(
default_factory=lambda: DefaultEigenConfig
)


DefaultShampooConfig = ShampooPreconditionerConfig()
Expand All @@ -124,7 +126,9 @@ class EigenvalueCorrectedShampooPreconditionerConfig(PreconditionerConfig):
"""

amortized_computation_config: EigenvectorConfig = DefaultEighConfig
amortized_computation_config: EigenvectorConfig = field(
default_factory=lambda: DefaultEighConfig
)


DefaultEigenvalueCorrectedShampooConfig = (
Expand Down

0 comments on commit d53ffc6

Please sign in to comment.