zero-shot distributed eval: all reduce metrics all together
mehdidc committed Sep 24, 2022
1 parent ff36243 commit 84b2dc7
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions src/training/zero_shot.py
@@ -57,12 +57,9 @@ def run(model, classifier, dataloader, args):
             top5 += acc5
             n += images.size(0)
     if args.distributed_evaluation:
-        n = torch.Tensor([n]).to(args.device)
-        torch.distributed.all_reduce(n)
-        n = n.item()
-        top = torch.Tensor([top1, top5]).to(args.device)
-        torch.distributed.all_reduce(top)
-        top1, top5 = top.tolist()
+        vals = torch.Tensor([n, top1, top5]).to(args.device)
+        torch.distributed.all_reduce(vals)
+        n, top1, top5 = vals.tolist()
     top1 = (top1 / n)
     top5 = (top5 / n)
     return top1, top5
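
The change relies on a sum all-reduce being element-wise: reducing the packed vector [n, top1, top5] once should give the same totals as reducing n and [top1, top5] separately. The sketch below illustrates this with a pure-Python stand-in for the collective. It does not use torch or a real process group. The helper simulated_all_reduce_sum and the per-rank numbers are hypothetical, and top1/top5 are assumed to be counts of correct predictions, as the later division by n suggests.

```python
# Pure-Python stand-in for a SUM all-reduce; no torch or process group needed.
# The helper and all per-rank numbers are hypothetical, for illustration only.

def simulated_all_reduce_sum(per_rank_vectors):
    """Return the element-wise sum across ranks (what every rank would receive)."""
    return [sum(column) for column in zip(*per_rank_vectors)]

# Assumed local results on three ranks: (n, top1, top5) as raw counts.
local_results = [
    (100.0, 61.0, 85.0),
    (100.0, 58.0, 83.0),
    (50.0, 31.0, 44.0),
]

# Before the commit: two collectives, one for n and one for [top1, top5].
n_sep = simulated_all_reduce_sum([[n] for n, _, _ in local_results])[0]
top1_sep, top5_sep = simulated_all_reduce_sum(
    [[t1, t5] for _, t1, t5 in local_results]
)

# After the commit: one collective over the packed vector [n, top1, top5].
n, top1, top5 = simulated_all_reduce_sum([list(r) for r in local_results])

# Both routes yield identical global totals.
assert (n, top1, top5) == (n_sep, top1_sep, top5_sep)

# Global accuracies are computed from the global totals, as in run().
top1_acc = top1 / n
top5_acc = top5 / n
print(top1_acc, top5_acc)
```

Packing the values saves one collective call per evaluation. The results match only because both reductions use the same operation (sum) and the accuracies are divided by n after the reduction, not before.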