automatically take care of distributed setting
lucidrains committed Sep 4, 2023
1 parent 8c2b865 commit aac5d53
Showing 2 changed files with 71 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
  name = 'x-clip',
  packages = find_packages(exclude=[]),
  include_package_data = True,
  version = '0.12.2',
  version = '0.12.4',
  license='MIT',
  description = 'X-CLIP',
  author = 'Phil Wang',
70 changes: 70 additions & 0 deletions x_clip/x_clip.py
@@ -7,6 +7,10 @@
import torch.nn.functional as F
from torch import nn, einsum
from torch.utils.checkpoint import checkpoint

from torch.autograd import Function
import torch.distributed as distributed

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange, Reduce

@@ -40,6 +44,11 @@ def masked_mean(t, mask, dim = 1, eps = 1e-6):
    denom = mask.sum(dim = dim).clamp(min = eps)
    return numer / denom

def pad_dim_to(t, length, dim = 0):
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

def log(t, eps = 1e-20):
    return torch.log(t + eps)

@@ -56,6 +65,58 @@ def matrix_diag(t):
    diag_el = t.masked_select(diag_mask)
    return rearrange(diag_el, '(b d) -> b d', d = num_diag_el)

# distributed helpers

def all_gather_variable_batch(t):
    device, rank, world_size = t.device, distributed.get_rank(), distributed.get_world_size()

    # gather the (possibly different) batch sizes from all ranks
    batch_size = torch.tensor([t.shape[0]], device = device, dtype = torch.long)
    batch_sizes = [torch.empty(1, device = device, dtype = torch.long) for i in range(world_size)]
    distributed.all_gather(batch_sizes, batch_size)

    batch_sizes = torch.cat(batch_sizes)
    max_batch_size = batch_sizes.amax().item()

    # pad every rank's tensor to the largest batch size, then all-gather
    padded_t = pad_dim_to(t, max_batch_size)

    gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
    distributed.all_gather(gathered_tensors, padded_t)

    gathered_tensor = torch.cat(gathered_tensors, dim = 0)
    seq = torch.arange(max_batch_size, device = device)

    # mask out the rows that were only padding
    mask = rearrange(seq, 'j -> 1 j') < rearrange(batch_sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')

    gathered_tensor = gathered_tensor[mask]
    batch_sizes = batch_sizes.tolist()

    return gathered_tensor, batch_sizes

class MaybeAllGather(Function):
    @staticmethod
    def forward(ctx, x):
        is_distributed = distributed.is_initialized() and distributed.get_world_size() > 1
        ctx.is_distributed = is_distributed

        if not is_distributed:
            return x

        x = rearrange(x, 'm b ... -> b m ...')
        x, batch_sizes = all_gather_variable_batch(x)
        ctx.batch_sizes = batch_sizes
        return rearrange(x, 'b m ... -> m b ...')

    @staticmethod
    def backward(ctx, grads):
        if not ctx.is_distributed:
            return grads

        batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()

        # gradients arrive with shape (m, total batch, ...) - split along the gathered
        # batch dimension and return only this rank's slice
        grads_by_rank = grads.split(batch_sizes, dim = 1)
        return grads_by_rank[rank]

maybe_all_gather = MaybeAllGather.apply

# checkpointing helper function

def make_checkpointable(fn):
@@ -738,6 +799,15 @@ def forward(
        text_latents_extra = rearrange(text_latents_extra, '(m b) ... -> m b ...', m = num_batch_texts)
        image_latents_extra = rearrange(image_latents_extra, '(m b) ... -> m b ...', m = num_batch_images)

        # maybe distributed all gather

        text_latents = maybe_all_gather(text_latents)
        image_latents = maybe_all_gather(image_latents)

        if self.extra_latent_projection:
            text_latents_extra = maybe_all_gather(text_latents_extra)
            image_latents_extra = maybe_all_gather(image_latents_extra)

        # contrastive loss

        """
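For context, here is a minimal sketch (not part of the commit) of how the new maybe_all_gather helper is meant to behave in a distributed run: each rank contributes a possibly different batch size, the latents are gathered along the batch dimension before the contrastive loss, and on the backward pass each rank receives only the gradient slice for its own samples. The process-group setup below (gloo backend, localhost port, two CPU processes) is illustrative only:

# minimal usage sketch, assuming two CPU processes and the gloo backend (illustrative, not from the repo)

import torch
import torch.multiprocessing as mp
import torch.distributed as distributed

from x_clip.x_clip import maybe_all_gather

def run(rank, world_size):
    distributed.init_process_group('gloo', init_method = 'tcp://localhost:29500', rank = rank, world_size = world_size)

    # each rank holds latents of shape (m, batch, dim) - batch sizes may differ across ranks
    batch = 2 if rank == 0 else 3
    latents = torch.randn(1, batch, 512, requires_grad = True)

    gathered = maybe_all_gather(latents)    # shape (1, 5, 512) on every rank

    # backward routes only this rank's gradient slice back to its local latents
    gathered.sum().backward()
    assert latents.grad.shape == latents.shape

    distributed.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(run, args = (2,), nprocs = 2)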
