From 9078293a67501d9489b22b8ee5b6c10b061a6b54 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jun 2022 12:10:07 -0700 Subject: [PATCH] Fix #77, guard torch.distributed imports so it won't break on Windows --- src/open_clip/loss.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 191096644..de31426df 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -1,8 +1,14 @@ import torch -import torch.distributed.nn -from torch import distributed as dist, nn as nn +import torch.nn as nn from torch.nn import functional as F +try: + import torch.distributed.nn + from torch import distributed as dist + has_distributed = True +except ImportError: + has_distributed = False + try: import horovod.torch as hvd except ImportError: @@ -18,6 +24,7 @@ def gather_features( world_size=1, use_horovod=False ): + assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' if use_horovod: assert hvd is not None, 'Please install horovod' if gather_with_grad: