From 4bf0dbd0b726c352888a1ff7baaed511a331b110 Mon Sep 17 00:00:00 2001 From: runame Date: Wed, 13 Nov 2024 11:16:08 +0000 Subject: [PATCH] Fix type hints --- distributed_shampoo/utils/shampoo_ddp_distributor.py | 4 ++-- distributed_shampoo/utils/shampoo_hsdp_distributor.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/distributed_shampoo/utils/shampoo_ddp_distributor.py b/distributed_shampoo/utils/shampoo_ddp_distributor.py index ad9408a..ca6bea3 100644 --- a/distributed_shampoo/utils/shampoo_ddp_distributor.py +++ b/distributed_shampoo/utils/shampoo_ddp_distributor.py @@ -78,9 +78,9 @@ def __init__( # Determine communication type. if distributed_config.communication_dtype == CommunicationDType.BF16: - communication_dtype: torch.dtype = torch.bfloat16 + communication_dtype = torch.bfloat16 elif distributed_config.communication_dtype == CommunicationDType.FP16: - communication_dtype: torch.dtype = torch.float16 + communication_dtype = torch.float16 else: assert distributed_config.communication_dtype in [ CommunicationDType.FP32, diff --git a/distributed_shampoo/utils/shampoo_hsdp_distributor.py b/distributed_shampoo/utils/shampoo_hsdp_distributor.py index c2b48e7..433dcfc 100644 --- a/distributed_shampoo/utils/shampoo_hsdp_distributor.py +++ b/distributed_shampoo/utils/shampoo_hsdp_distributor.py @@ -152,9 +152,9 @@ def __init__( # Determine communication type. if distributed_config.communication_dtype == CommunicationDType.BF16: - communication_dtype: torch.dtype = torch.bfloat16 + communication_dtype = torch.bfloat16 elif distributed_config.communication_dtype == CommunicationDType.FP16: - communication_dtype: torch.dtype = torch.float16 + communication_dtype = torch.float16 else: assert distributed_config.communication_dtype in [ CommunicationDType.FP32,