Skip to content

Commit

Permalink
Fix type hints (#45)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #45

Reviewed By: anana10c

Differential Revision: D65882392

Pulled By: tsunghsienlee

fbshipit-source-id: f57a74be3c4bdaf1fdda869f30606b1f80d2a76d
  • Loading branch information
runame authored and facebook-github-bot committed Nov 13, 2024
1 parent 13f68b7 commit 7ffbf1b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions distributed_shampoo/utils/shampoo_ddp_distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def __init__(

# Determine communication type.
if distributed_config.communication_dtype == CommunicationDType.BF16:
communication_dtype: torch.dtype = torch.bfloat16
communication_dtype = torch.bfloat16
elif distributed_config.communication_dtype == CommunicationDType.FP16:
communication_dtype: torch.dtype = torch.float16
communication_dtype = torch.float16
else:
assert distributed_config.communication_dtype in [
CommunicationDType.FP32,
Expand Down
4 changes: 2 additions & 2 deletions distributed_shampoo/utils/shampoo_hsdp_distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ def __init__(

# Determine communication type.
if distributed_config.communication_dtype == CommunicationDType.BF16:
communication_dtype: torch.dtype = torch.bfloat16
communication_dtype = torch.bfloat16
elif distributed_config.communication_dtype == CommunicationDType.FP16:
communication_dtype: torch.dtype = torch.float16
communication_dtype = torch.float16
else:
assert distributed_config.communication_dtype in [
CommunicationDType.FP32,
Expand Down

0 comments on commit 7ffbf1b

Please sign in to comment.