diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index 1089f64..4f340b1 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -12,7 +12,7 @@ from collections.abc import Callable, Iterator, Sequence from copy import deepcopy from functools import partial -from typing import Any, NoReturn +from typing import Any import torch @@ -1098,12 +1098,12 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: # t return loss - def state_dict(self) -> NoReturn: + def state_dict(self) -> StateDict: raise NotImplementedError( "Distributed Shampoo does not support the standard state_dict() method for checkpointing!" ) - def load_state_dict(self, state_dict: StateDict) -> NoReturn: + def load_state_dict(self, state_dict: StateDict) -> None: raise NotImplementedError( "Distributed Shampoo does not support the standard load_state_dict() method for checkpointing!" )