
Commit

maybe more efficient all gather
lucidrains committed Sep 5, 2023
1 parent a846b2f · commit 95b722f
Showing 3 changed files with 66 additions and 64 deletions.
4 changes: 2 additions & 2 deletions setup.py
@@ -4,7 +4,7 @@
   name = 'x-clip',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.12.7',
+  version = '0.12.9',
   license='MIT',
   description = 'X-CLIP',
   author = 'Phil Wang',
@@ -19,7 +19,7 @@
   ],
   install_requires=[
     'beartype',
-    'einops>=0.3',
+    'einops>=0.6',
     'ftfy',
     'regex',
     'torch>=1.6',
51 changes: 51 additions & 0 deletions x_clip/distributed.py
@@ -0,0 +1,51 @@
import torch
from torch.autograd import Function
import torch.distributed as distributed

from einops import rearrange

# distributed helpers

def all_gather_variable_dim(t, dim = 0, sizes = None):
    device, rank, world_size = t.device, distributed.get_rank(), distributed.get_world_size()

    if not exists(sizes):
        size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
        sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
        distributed.all_gather(sizes, size)
        sizes = torch.stack(sizes)

    max_size = sizes.amax().item()
    padded_t = pad_dim_to(t, max_size, dim = dim)

    gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
    distributed.all_gather(gathered_tensors, padded_t)

    gathered_tensor = torch.cat(gathered_tensors, dim = dim)
    seq = torch.arange(max_size, device = device)

    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    gathered_tensor = gathered_tensor.index_select(dim, indices)

    return gathered_tensor, sizes

class AllGather(Function):
    @staticmethod
    def forward(ctx, x, dim, sizes):
        assert distributed.is_initialized() and distributed.get_world_size() > 1
        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
        ctx.batch_sizes = batch_sizes.tolist()
        ctx.dim = dim
        return x, batch_sizes

    @staticmethod
    def backward(ctx, grads, _):
        batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
        return grads_by_rank[rank], None, None

all_gather = AllGather.apply
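
As committed, x_clip/distributed.py calls `exists` and `pad_dim_to` without defining or importing them, so `all_gather_variable_dim` cannot run from this module alone. A minimal sketch of the two missing helpers, assuming the same semantics as the identically named helpers the old x_clip.py code relied on (treat `None` as "not provided"; right-pad with zeros along one dimension to a target length):

import torch.nn.functional as F

# assumed helpers, not part of this commit's diff

def exists(val):
    # a value counts as provided when it is not None
    return val is not None

def pad_dim_to(t, length, dim = 0):
    # right-pad tensor `t` with zeros along `dim` until t.shape[dim] == length
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

With those in scope, `all_gather(t, dim, None)` returns both the concatenated tensor and the per-rank sizes; passing those sizes back in on a second call (as the forward pass below does for the extra latents) skips the size exchange, and the saved sizes let `backward` split the incoming gradient and hand each rank its own slice.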
75 changes: 13 additions & 62 deletions x_clip/x_clip.py
@@ -5,17 +5,16 @@
 
 import torch
 import torch.nn.functional as F
+import torch.distributed as distributed
 from torch import nn, einsum
 from torch.utils.checkpoint import checkpoint
 
-from torch.autograd import Function
-import torch.distributed as distributed
-
 from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange, Reduce
 
 from x_clip.mlm import MLM
 from x_clip.visual_ssl import SimSiam, SimCLR
+from x_clip.distributed import all_gather
 
 # helper functions
 
@@ -65,60 +64,6 @@ def matrix_diag(t):
     diag_el = t.masked_select(diag_mask)
     return rearrange(diag_el, '(b d) -> b d', d = num_diag_el)
 
-# distributed helpers
-
-def all_gather_variable_dim(t, dim = 0):
-    device, rank, world_size = t.device, distributed.get_rank(), distributed.get_world_size()
-
-    size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
-    sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
-    distributed.all_gather(sizes, size)
-
-    sizes = torch.stack(sizes)
-    max_size = sizes.amax().item()
-    padded_t = pad_dim_to(t, max_size, dim = dim)
-
-    gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
-    distributed.all_gather(gathered_tensors, padded_t)
-
-    gathered_tensor = torch.cat(gathered_tensors, dim = dim)
-    seq = torch.arange(max_size, device = device)
-
-    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
-    mask = rearrange(mask, 'i j -> (i j)')
-    seq = torch.arange(mask.shape[-1], device = device)
-    indices = seq[mask]
-
-    gathered_tensor = gathered_tensor.index_select(dim, indices)
-    sizes = sizes.tolist()
-
-    return gathered_tensor, sizes
-
-class MaybeAllGather(Function):
-    @staticmethod
-    def forward(ctx, x, dim):
-        is_distributed = distributed.is_initialized() and distributed.get_world_size() > 1
-        ctx.is_distributed = is_distributed
-        ctx.dim = dim
-
-        if not is_distributed:
-            return x
-
-        x, batch_sizes = all_gather_variable_dim(x, dim = dim)
-        ctx.batch_sizes = batch_sizes
-        return x
-
-    @staticmethod
-    def backward(ctx, grads):
-        if not ctx.is_distributed:
-            return grads, None
-
-        batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
-        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
-        return grads_by_rank[rank], None
-
-maybe_all_gather = MaybeAllGather.apply
-
 # checkpointing helper function
 
 def make_checkpointable(fn):
@@ -641,6 +586,9 @@ def __init__(
 
         self.multiview_loss_weight = multiview_loss_weight
 
+        # is distributed or not
+        self.requires_all_gather = distributed.is_initialized() and distributed.get_world_size() > 1
+
     def forward(
         self,
         text,
@@ -803,12 +751,15 @@ def forward(
 
         # maybe distributed all gather
 
-        text_latents = maybe_all_gather(text_latents, 1)
-        image_latents = maybe_all_gather(image_latents, 1)
+        if self.requires_all_gather:
+            latents = torch.stack((text_latents, image_latents))
+            latents, sizes = all_gather(latents, 2, None)
+            text_latents, image_latents = latents
 
-        if self.extra_latent_projection:
-            text_latents_extra = maybe_all_gather(text_latents_extra, 1)
-            image_latents_extra = maybe_all_gather(image_latents_extra, 1)
+            if self.extra_latent_projection:
+                latents_extra = torch.stack((text_latents_extra, image_latents_extra))
+                latents_extra, _ = all_gather(latents_extra, 2, sizes)
+                text_latents_extra, image_latents_extra = latents_extra
 
         # contrastive loss
 
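The "maybe more efficient" in the commit message comes from collapsing two (or, with `extra_latent_projection`, four) per-tensor `maybe_all_gather` calls into a single `all_gather` over latents stacked along a new leading dimension, and from reusing the per-rank `sizes` from the first gather for the second. The is-distributed check also moves out of the autograd Function into a `requires_all_gather` flag computed once in `__init__`, so the process group has to be initialized before the CLIP module is constructed. A standalone sketch of the stacking trick under an already initialized `torch.distributed` process group, using a hypothetical `gather_stacked` helper and equal per-rank batch sizes (the committed `x_clip/distributed.py` additionally handles unequal sizes and gradient routing):

import torch
import torch.distributed as distributed

def gather_stacked(text_latents, image_latents, batch_dim = 0):
    # one collective for both modalities: stack, all_gather once, then split back out
    stacked = torch.stack((text_latents, image_latents))        # (2, b, d)

    world_size = distributed.get_world_size()
    buckets = [torch.empty_like(stacked) for _ in range(world_size)]
    distributed.all_gather(buckets, stacked)                    # a single collective call instead of two

    gathered = torch.cat(buckets, dim = batch_dim + 1)          # (2, world_size * b, d)
    return gathered[0], gathered[1]

Fewer, larger collectives generally amortize launch and synchronization overhead better than many small ones, which is the likely intent behind the change.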
