Move ClipLoss to clip package, enable label caching for train
rwightman committed Mar 30, 2022
1 parent d8084a6 commit 2e8f3a4
Showing 3 changed files with 123 additions and 100 deletions.
1 change: 1 addition & 0 deletions src/clip/__init__.py
@@ -1,4 +1,5 @@
 from .factory import create_model_and_transforms
+from .loss import ClipLoss
 from .model import CLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_fp16
 from .openai_clip import load_openai
 from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\
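With the re-export above, the loss can now be pulled from the package root rather than the submodule. A trivial check (assuming src/ is on PYTHONPATH so the clip package resolves):

    from clip import ClipLoss  # equivalent to: from clip.loss import ClipLoss
    loss_fn = ClipLoss()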
114 changes: 114 additions & 0 deletions src/clip/loss.py
@@ -0,0 +1,114 @@
+import torch
+import torch.distributed.nn
+from torch import distributed as dist, nn as nn
+from torch.nn import functional as F
+
+try:
+    import horovod.torch as hvd
+except ImportError:
+    hvd = None
+
+
+def gather_features(
+        image_features,
+        text_features,
+        local_loss=False,
+        gather_with_grad=False,
+        rank=0,
+        world_size=1,
+        use_horovod=False
+):
+    if use_horovod:
+        assert hvd is not None, 'Please install horovod'
+        if gather_with_grad:
+            all_image_features = hvd.allgather(image_features)
+            all_text_features = hvd.allgather(text_features)
+        else:
+            with torch.no_grad():
+                all_image_features = hvd.allgather(image_features)
+                all_text_features = hvd.allgather(text_features)
+            if not local_loss:
+                # ensure grads for local rank when all_* features don't have a gradient
+                gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
+                gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
+                gathered_image_features[rank] = image_features
+                gathered_text_features[rank] = text_features
+                all_image_features = torch.cat(gathered_image_features, dim=0)
+                all_text_features = torch.cat(gathered_text_features, dim=0)
+    else:
+        # We gather tensors from all gpus
+        if gather_with_grad:
+            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
+            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
+        else:
+            gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
+            gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
+            dist.all_gather(gathered_image_features, image_features)
+            dist.all_gather(gathered_text_features, text_features)
+            if not local_loss:
+                # ensure grads for local rank when all_* features don't have a gradient
+                gathered_image_features[rank] = image_features
+                gathered_text_features[rank] = text_features
+            all_image_features = torch.cat(gathered_image_features, dim=0)
+            all_text_features = torch.cat(gathered_text_features, dim=0)
+
+    return all_image_features, all_text_features
+
+
+class ClipLoss(nn.Module):
+
+    def __init__(
+            self,
+            local_loss=False,
+            gather_with_grad=False,
+            cache_labels=False,
+            rank=0,
+            world_size=1,
+            use_horovod=False,
+    ):
+        super().__init__()
+        self.local_loss = local_loss
+        self.gather_with_grad = gather_with_grad
+        self.cache_labels = cache_labels
+        self.rank = rank
+        self.world_size = world_size
+        self.use_horovod = use_horovod
+
+        # cache state
+        self.prev_num_logits = 0
+        self.labels = {}
+
+    def forward(self, image_features, text_features, logit_scale):
+        device = image_features.device
+        if self.world_size > 1:
+            all_image_features, all_text_features = gather_features(
+                image_features, text_features,
+                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
+
+            if self.local_loss:
+                logits_per_image = logit_scale * image_features @ all_text_features.T
+                logits_per_text = logit_scale * text_features @ all_image_features.T
+            else:
+                logits_per_image = logit_scale * all_image_features @ all_text_features.T
+                logits_per_text = logits_per_image.T
+        else:
+            logits_per_image = logit_scale * image_features @ text_features.T
+            logits_per_text = logit_scale * text_features @ image_features.T
+
+        # calculate ground-truth and cache if enabled
+        num_logits = logits_per_image.shape[0]
+        if self.prev_num_logits != num_logits or device not in self.labels:
+            labels = torch.arange(num_logits, device=device, dtype=torch.long)
+            if self.world_size > 1 and self.local_loss:
+                labels = labels + num_logits * self.rank
+            if self.cache_labels:
+                self.labels[device] = labels
+                self.prev_num_logits = num_logits
+        else:
+            labels = self.labels[device]
+
+        total_loss = (
+            F.cross_entropy(logits_per_image, labels) +
+            F.cross_entropy(logits_per_text, labels)
+        ) / 2
+        return total_loss
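
For reference, here is a minimal single-process sketch of the relocated loss in action; the batch size, embedding width, and logit scale below are illustrative choices, not values from the commit (and src/ is assumed to be on PYTHONPATH):

    import torch
    import torch.nn.functional as F

    from clip.loss import ClipLoss

    # illustrative shapes: batch of 8 image/text pairs, 512-dim embeddings
    image_features = F.normalize(torch.randn(8, 512), dim=-1)
    text_features = F.normalize(torch.randn(8, 512), dim=-1)
    logit_scale = torch.tensor(100.0)  # i.e. exp() of a learned temperature

    # world_size=1 (default), so no gather path; labels are a simple arange
    loss_fn = ClipLoss(cache_labels=True)
    total_loss = loss_fn(image_features, text_features, logit_scale)
    print(total_loss.item())  # symmetric cross-entropy over the 8x8 logit matrix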
108 changes: 8 additions & 100 deletions src/training/train.py
@@ -7,112 +7,14 @@

 import numpy as np
 import torch
-import torch.distributed as dist
-import torch.distributed.nn
-import torch.nn as nn
-import torch.nn.functional as F
 import wandb
 
-try:
-    import horovod.torch as hvd
-except ImportError:
-    hvd = None
 
+from clip.loss import ClipLoss
 from .distributed import is_master
 from .zero_shot import zero_shot_eval
-
-
-def gather_features(
-        image_features, text_features,
-        local_loss=False, gather_with_grad=False, rank=0, world_size=1, horovod=False):
-    if horovod:
-        assert hvd is not None, 'Please install horovod'
-        if gather_with_grad:
-            all_image_features = hvd.allgather(image_features)
-            all_text_features = hvd.allgather(text_features)
-        else:
-            with torch.no_grad():
-                all_image_features = hvd.allgather(image_features)
-                all_text_features = hvd.allgather(text_features)
-            if not local_loss:
-                # ensure grads for local rank when all_* features don't have a gradient
-                gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
-                gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
-                gathered_image_features[rank] = image_features
-                gathered_text_features[rank] = text_features
-                all_image_features = torch.cat(gathered_image_features, dim=0)
-                all_text_features = torch.cat(gathered_text_features, dim=0)
-    else:
-        # We gather tensors from all gpus
-        if gather_with_grad:
-            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
-            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
-        else:
-            gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
-            gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
-            dist.all_gather(gathered_image_features, image_features)
-            dist.all_gather(gathered_text_features, text_features)
-            if not local_loss:
-                # ensure grads for local rank when all_* features don't have a gradient
-                gathered_image_features[rank] = image_features
-                gathered_text_features[rank] = text_features
-            all_image_features = torch.cat(gathered_image_features, dim=0)
-            all_text_features = torch.cat(gathered_text_features, dim=0)
-    return all_image_features, all_text_features
-
-
-class ClipLoss(nn.Module):
-
-    def __init__(
-            self,
-            local_loss=False,
-            gather_with_grad=False,
-            rank=0,
-            world_size=1,
-            horovod=False,
-    ):
-        super().__init__()
-        self.local_loss = local_loss
-        self.gather_with_grad = gather_with_grad
-        self.rank = rank
-        self.world_size = world_size
-        self.horovod = horovod
-        self.prev_num_logits = 0
-        self.labels = {}
-
-    def forward(self, image_features, text_features, logit_scale):
-        device = image_features.device
-        if self.world_size > 1:
-            all_image_features, all_text_features = gather_features(
-                image_features, text_features,
-                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.horovod)
-
-            if self.local_loss:
-                logits_per_image = logit_scale * image_features @ all_text_features.T
-                logits_per_text = logit_scale * text_features @ all_image_features.T
-            else:
-                logits_per_image = logit_scale * all_image_features @ all_text_features.T
-                logits_per_text = logits_per_image.T
-        else:
-            logits_per_image = logit_scale * image_features @ text_features.T
-            logits_per_text = logit_scale * text_features @ image_features.T
-
-        num_logits = logits_per_image.shape[0]
-        if self.prev_num_logits != num_logits or device not in self.labels:
-            labels = torch.arange(num_logits, device=device, dtype=torch.long)
-            if self.world_size > 1 and self.local_loss:
-                labels = labels + num_logits * self.rank
-            self.labels[device] = labels
-        else:
-            labels = self.labels[device]
-
-        total_loss = (
-            F.cross_entropy(logits_per_image, labels) +
-            F.cross_entropy(logits_per_text, labels)
-        ) / 2
-        return total_loss
 
 
 class AverageMeter(object):
     """Computes and stores the average and current value"""
     def __init__(self):
@@ -143,7 +45,13 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w
     autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress
 
     model.train()
-    loss = ClipLoss(args.local_loss, args.gather_with_grad, args.rank, args.world_size, args.horovod)
+    loss = ClipLoss(
+        local_loss=args.local_loss,
+        gather_with_grad=args.gather_with_grad,
+        cache_labels=True,
+        rank=args.rank,
+        world_size=args.world_size,
+        use_horovod=args.horovod)
 
     dataloader, sampler = data['train'].dataloader, data['train'].sampler
     if args.distributed and sampler is not None:
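
And a small sketch of what cache_labels=True changes at train time: on repeated steps with the same device and batch size, the arange targets are reused from the per-device cache instead of being rebuilt each forward. The shapes here are again illustrative, and the snippet peeks at internal attributes (labels, prev_num_logits) purely to show the mechanism:

    import torch
    import torch.nn.functional as F

    from clip.loss import ClipLoss

    loss_fn = ClipLoss(cache_labels=True)
    img = F.normalize(torch.randn(4, 64), dim=-1)
    txt = F.normalize(torch.randn(4, 64), dim=-1)
    scale = torch.tensor(50.0)

    loss_fn(img, txt, scale)             # step 1: labels built via torch.arange, then cached
    cached = loss_fn.labels[img.device]  # the per-device cache now holds the targets
    assert loss_fn.prev_num_logits == 4
    loss_fn(img, txt, scale)             # step 2: same num_logits, cached tensor reused
    assert cached is loss_fn.labels[img.device]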
