Fix typing error complaints due to upstream torch.optimizer decorator typing changes

Summary:
Due to pytorch/pytorch#144161, the typing of the currently disabled `state_dict()` and `load_state_dict()` methods starts triggering complaints because of the incompatibility between `typing.NoReturn` and `None` and `StateDict`.

Although `typing.NoReturn` is a better type for the disabled checkpointing methods, unfortunately this is not feasible due to the `torch.optimizer` typing requirements.
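The pattern adopted by this patch can be sketched as follows. This is a minimal, self-contained illustration (the `BaseOptimizer` and `ShampooLike` names are hypothetical, standing in for `torch.optim.Optimizer` and `DistributedShampoo`): the overriding methods keep the base class's return annotations (`StateDict` and `None`) to satisfy the upstream typing, while their bodies still raise at runtime to keep checkpointing via these methods disabled.

```python
from typing import Any

StateDict = dict[str, Any]


class BaseOptimizer:
    """Stand-in for the upstream optimizer base class with typed checkpoint methods."""

    def state_dict(self) -> StateDict:
        return {"state": {}, "param_groups": []}

    def load_state_dict(self, state_dict: StateDict) -> None:
        pass


class ShampooLike(BaseOptimizer):
    """Stand-in for an optimizer that disables the standard checkpointing methods."""

    # Annotated to match the base signatures (StateDict / None) rather than
    # NoReturn; the bodies still raise, so callers never get a value at runtime.
    def state_dict(self) -> StateDict:
        raise NotImplementedError(
            "This optimizer does not support the standard state_dict() method!"
        )

    def load_state_dict(self, state_dict: StateDict) -> None:
        raise NotImplementedError(
            "This optimizer does not support the standard load_state_dict() method!"
        )
```

At runtime nothing changes for callers: both methods raise `NotImplementedError` exactly as before; only the static annotations differ.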

Reviewed By: aorenste

Differential Revision: D67951563

fbshipit-source-id: aeafb8f71ab5a9e077760d8b884af9584b017d7e
tsunghsienlee authored and facebook-github-bot committed Jan 9, 2025
1 parent c51e4e6 commit 166dcbf
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions distributed_shampoo/distributed_shampoo.py
@@ -12,7 +12,7 @@
from collections.abc import Callable, Iterator, Sequence
from copy import deepcopy
from functools import partial
- from typing import Any, NoReturn
+ from typing import Any

import torch

@@ -1098,12 +1098,12 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:  # t

return loss

-     def state_dict(self) -> NoReturn:
+     def state_dict(self) -> StateDict:
raise NotImplementedError(
"Distributed Shampoo does not support the standard state_dict() method for checkpointing!"
)

-     def load_state_dict(self, state_dict: StateDict) -> NoReturn:
+     def load_state_dict(self, state_dict: StateDict) -> None:
raise NotImplementedError(
"Distributed Shampoo does not support the standard load_state_dict() method for checkpointing!"
)
