diff --git a/distributed_shampoo/utils/shampoo_ddp_distributor.py b/distributed_shampoo/utils/shampoo_ddp_distributor.py index ad9408a..ca6bea3 100644 --- a/distributed_shampoo/utils/shampoo_ddp_distributor.py +++ b/distributed_shampoo/utils/shampoo_ddp_distributor.py @@ -78,9 +78,9 @@ def __init__( # Determine communication type. if distributed_config.communication_dtype == CommunicationDType.BF16: - communication_dtype: torch.dtype = torch.bfloat16 + communication_dtype = torch.bfloat16 elif distributed_config.communication_dtype == CommunicationDType.FP16: - communication_dtype: torch.dtype = torch.float16 + communication_dtype = torch.float16 else: assert distributed_config.communication_dtype in [ CommunicationDType.FP32, diff --git a/distributed_shampoo/utils/shampoo_hsdp_distributor.py b/distributed_shampoo/utils/shampoo_hsdp_distributor.py index c2b48e7..433dcfc 100644 --- a/distributed_shampoo/utils/shampoo_hsdp_distributor.py +++ b/distributed_shampoo/utils/shampoo_hsdp_distributor.py @@ -152,9 +152,9 @@ def __init__( # Determine communication type. if distributed_config.communication_dtype == CommunicationDType.BF16: - communication_dtype: torch.dtype = torch.bfloat16 + communication_dtype = torch.bfloat16 elif distributed_config.communication_dtype == CommunicationDType.FP16: - communication_dtype: torch.dtype = torch.float16 + communication_dtype = torch.float16 else: assert distributed_config.communication_dtype in [ CommunicationDType.FP32,