generalize the distributed all gather to any dimension
lucidrains committed Sep 4, 2023
1 parent aac5d53 commit 67aa6e6
Showing 2 changed files with 27 additions and 25 deletions.
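In short, the variable-length all gather helper and its autograd wrapper now take the dimension to gather along instead of assuming the first (batch) dimension. A minimal sketch of the new call pattern, assuming the package is installed at this version; the tensor shape is illustrative, and outside an initialized `torch.distributed` run the wrapper simply returns its input:

```python
import torch
from x_clip.x_clip import maybe_all_gather  # autograd-aware wrapper changed in this commit

latents = torch.randn(2, 8, 512)  # illustrative (num latent types, local batch, feature dim)

# the second argument is the new gather dimension; with no process group
# initialized this is a no-op and the input comes back unchanged
gathered = maybe_all_gather(latents, 1)
print(gathered.shape)  # torch.Size([2, 8, 512]) in a single process
```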
setup.py: 2 changes (1 addition, 1 deletion)
@@ -4,7 +4,7 @@
   name = 'x-clip',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.12.4',
+  version = '0.12.6',
   license='MIT',
   description = 'X-CLIP',
   author = 'Phil Wang',
x_clip/x_clip.py: 50 changes (26 additions, 24 deletions)
@@ -67,53 +67,55 @@ def matrix_diag(t):

 # distributed helpers

-def all_gather_variable_batch(t):
+def all_gather_variable_dim(t, dim = 0):
     device, rank, world_size = t.device, distributed.get_rank(), distributed.get_world_size()

-    batch_size = torch.tensor([t.shape[0]], device = device, dtype = torch.long)
-    batch_sizes = [torch.empty(1, device = device, dtype = torch.long) for i in range(world_size)]
-    distributed.all_gather(batch_sizes, batch_size)
+    size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
+    sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
+    distributed.all_gather(sizes, size)

-    batch_sizes = torch.cat(batch_sizes)
-    max_batch_size = batch_sizes.amax().item()
-    padded_t = pad_dim_to(t, max_batch_size)
+    sizes = torch.stack(sizes)
+    max_size = sizes.amax().item()
+    padded_t = pad_dim_to(t, max_size, dim = dim)

     gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
     distributed.all_gather(gathered_tensors, padded_t)

-    gathered_tensor = torch.cat(gathered_tensors, dim = 0)
-    seq = torch.arange(max_batch_size, device = device)
+    gathered_tensor = torch.cat(gathered_tensors, dim = dim)
+    seq = torch.arange(max_size, device = device)

-    mask = rearrange(seq, 'j -> 1 j') < rearrange(batch_sizes, 'i -> i 1')
+    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
     mask = rearrange(mask, 'i j -> (i j)')
     seq = torch.arange(mask.shape[-1], device = device)
     indices = seq[mask]

-    gathered_tensor = gathered_tensor[mask]
-    batch_sizes = batch_sizes.tolist()
+    gathered_tensor = gathered_tensor.index_select(dim, indices)
+    sizes = sizes.tolist()

-    return gathered_tensor, batch_sizes
+    return gathered_tensor, sizes

 class MaybeAllGather(Function):
     @staticmethod
-    def forward(ctx, x):
+    def forward(ctx, x, dim):
         is_distributed = distributed.is_initialized() and distributed.get_world_size() > 1
         ctx.is_distributed = is_distributed
+        ctx.dim = dim

         if not is_distributed:
             return x

-        x = rearrange(x, 'm b ... -> b m ...')
-        x, batch_sizes = all_gather_variable_batch(x)
+        x, batch_sizes = all_gather_variable_dim(x, dim = dim)
         ctx.batch_sizes = batch_sizes
-        return rearrange(x, 'b m ... -> m b ...')
+        return x

     @staticmethod
     def backward(ctx, grads):
         if not ctx.is_distributed:
-            return grads
+            return grads, None

         batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
-        grads_by_rank = grads.split(batch_sizes, dim = 0)
-        return grads_by_rank[rank]
+        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
+        return grads_by_rank[rank], None

 maybe_all_gather = MaybeAllGather.apply

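To see the generalized helper at work across processes, here is a hypothetical two-process demo on one machine with the gloo backend; the address, port, world size, and tensor shapes are invented for illustration, and `all_gather_variable_dim` is assumed importable from `x_clip.x_clip` as defined above:

```python
import os
import torch
import torch.distributed as distributed
import torch.multiprocessing as mp

from x_clip.x_clip import all_gather_variable_dim  # helper shown in the diff above

def run(rank, world_size):
    # single-machine rendezvous; the address and port are arbitrary choices
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    distributed.init_process_group('gloo', rank = rank, world_size = world_size)

    # each rank holds a different number of elements along dim 1
    t = torch.randn(2, 3 + rank, 4)

    gathered, sizes = all_gather_variable_dim(t, dim = 1)

    # every rank's tensor is padded to the largest size, gathered, then trimmed,
    # so the result is the concatenation of the un-padded tensors along dim 1
    print(rank, tuple(gathered.shape), sizes)  # (2, 7, 4) and [3, 4] with 2 processes

    distributed.destroy_process_group()

if __name__ == '__main__':
    world_size = 2
    mp.spawn(run, args = (world_size,), nprocs = world_size)
```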
@@ -801,12 +803,12 @@ def forward(

         # maybe distributed all gather

-        text_latents = maybe_all_gather(text_latents)
-        image_latents = maybe_all_gather(image_latents)
+        text_latents = maybe_all_gather(text_latents, 1)
+        image_latents = maybe_all_gather(image_latents, 1)

         if self.extra_latent_projection:
-            text_latents_extra = maybe_all_gather(text_latents_extra)
-            image_latents_extra = maybe_all_gather(image_latents_extra)
+            text_latents_extra = maybe_all_gather(text_latents_extra, 1)
+            image_latents_extra = maybe_all_gather(image_latents_extra, 1)

         # contrastive loss

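The call sites above pass `1` because the latents keep a leading non-batch dimension (the old code had to rearrange `'m b ... -> b m ...'` before gathering), so the per-rank batch is now gathered in place at index 1. The backward pass splits the incoming gradient along that same dimension and hands each rank its own slice; a small sketch of that splitting step with made-up sizes:

```python
import torch

batch_sizes = [3, 4]                           # per-rank sizes returned by the gather (illustrative)
grads = torch.randn(2, sum(batch_sizes), 512)  # gradient w.r.t. the gathered latents
rank = 1                                       # this process's rank (illustrative)

# mirrors grads.split(batch_sizes, dim = ctx.dim) in MaybeAllGather.backward
grads_by_rank = grads.split(batch_sizes, dim = 1)
print(grads_by_rank[rank].shape)               # torch.Size([2, 4, 512])
```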
