Skip to content

Commit

Permalink
Fix #77, guard torch.distributed imports so it won't break on Windows
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Jun 24, 2022
1 parent e390448 commit 9078293
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/open_clip/loss.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import torch
import torch.distributed.nn
from torch import distributed as dist, nn as nn
import torch.nn as nn
from torch.nn import functional as F

try:
import torch.distributed.nn
from torch import distributed as dist
has_distributed = True
except ImportError:
has_distributed = False

try:
import horovod.torch as hvd
except ImportError:
Expand All @@ -18,6 +24,7 @@ def gather_features(
world_size=1,
use_horovod=False
):
assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
if use_horovod:
assert hvd is not None, 'Please install horovod'
if gather_with_grad:
Expand Down

0 comments on commit 9078293

Please sign in to comment.