diff --git a/setup.py b/setup.py
index 24b7fe6..ee9de63 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
   name = 'x-clip',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.12.7',
+  version = '0.12.8',
   license='MIT',
   description = 'X-CLIP',
   author = 'Phil Wang',
@@ -19,7 +19,7 @@
   ],
   install_requires=[
     'beartype',
-    'einops>=0.3',
+    'einops>=0.6',
     'ftfy',
     'regex',
     'torch>=1.6',
diff --git a/x_clip/distributed.py b/x_clip/distributed.py
new file mode 100644
index 0000000..5ca9965
--- /dev/null
+++ b/x_clip/distributed.py
@@ -0,0 +1,48 @@
+from torch.autograd import Function
+import torch.distributed as distributed
+
+# distributed helpers
+
+def all_gather_variable_dim(t, dim = 0, sizes = None):
+    device, rank, world_size = t.device, distributed.get_rank(), distributed.get_world_size()
+
+    if not exists(sizes):
+        size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
+        sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
+        distributed.all_gather(sizes, size)
+        sizes = torch.stack(sizes)
+
+    max_size = sizes.amax().item()
+    padded_t = pad_dim_to(t, max_size, dim = dim)
+
+    gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
+    distributed.all_gather(gathered_tensors, padded_t)
+
+    gathered_tensor = torch.cat(gathered_tensors, dim = dim)
+    seq = torch.arange(max_size, device = device)
+
+    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
+    mask = rearrange(mask, 'i j -> (i j)')
+    seq = torch.arange(mask.shape[-1], device = device)
+    indices = seq[mask]
+
+    gathered_tensor = gathered_tensor.index_select(dim, indices)
+
+    return gathered_tensor, sizes
+
+class AllGather(Function):
+    @staticmethod
+    def forward(ctx, x, dim, sizes):
+        assert distributed.is_initialized() and distributed.get_world_size() > 1
+        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
+        ctx.batch_sizes = batch_sizes.tolist()
+        ctx.dim = dim
+        return x, batch_sizes
+
+    @staticmethod
+    def backward(ctx, grads, _):
+        batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
+        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
+        return grads_by_rank[rank], None, None
+
+all_gather = AllGather.apply
diff --git a/x_clip/x_clip.py b/x_clip/x_clip.py
index a9d19b8..e151098 100644
--- a/x_clip/x_clip.py
+++ b/x_clip/x_clip.py
@@ -5,17 +5,16 @@
 
 import torch
 import torch.nn.functional as F
+import torch.distributed as distributed
 from torch import nn, einsum
 from torch.utils.checkpoint import checkpoint
 
-from torch.autograd import Function
-import torch.distributed as distributed
-
 from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange, Reduce
 
 from x_clip.mlm import MLM
 from x_clip.visual_ssl import SimSiam, SimCLR
+from x_clip.distributed import all_gather
 
 # helper functions
 
@@ -65,60 +64,6 @@ def matrix_diag(t):
     diag_el = t.masked_select(diag_mask)
     return rearrange(diag_el, '(b d) -> b d', d = num_diag_el)
 
-# distributed helpers
-
-def all_gather_variable_dim(t, dim = 0):
-    device, rank, world_size = t.device, distributed.get_rank(), distributed.get_world_size()
-
-    size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
-    sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
-    distributed.all_gather(sizes, size)
-
-    sizes = torch.stack(sizes)
-    max_size = sizes.amax().item()
-    padded_t = pad_dim_to(t, max_size, dim = dim)
-
-    gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
-    distributed.all_gather(gathered_tensors, padded_t)
-
-    gathered_tensor = torch.cat(gathered_tensors, dim = dim)
-    seq = torch.arange(max_size, device = device)
-
-    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
-    mask = rearrange(mask, 'i j -> (i j)')
-    seq = torch.arange(mask.shape[-1], device = device)
-    indices = seq[mask]
-
-    gathered_tensor = gathered_tensor.index_select(dim, indices)
-    sizes = sizes.tolist()
-
-    return gathered_tensor, sizes
-
-class MaybeAllGather(Function):
-    @staticmethod
-    def forward(ctx, x, dim):
-        is_distributed = distributed.is_initialized() and distributed.get_world_size() > 1
-        ctx.is_distributed = is_distributed
-        ctx.dim = dim
-
-        if not is_distributed:
-            return x
-
-        x, batch_sizes = all_gather_variable_dim(x, dim = dim)
-        ctx.batch_sizes = batch_sizes
-        return x
-
-    @staticmethod
-    def backward(ctx, grads):
-        if not ctx.is_distributed:
-            return grads, None
-
-        batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
-        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
-        return grads_by_rank[rank], None
-
-maybe_all_gather = MaybeAllGather.apply
-
 # checkpointing helper function
 
 def make_checkpointable(fn):
@@ -641,6 +586,9 @@ def __init__(
 
         self.multiview_loss_weight = multiview_loss_weight
 
+        # is distributed or not
+        self.requires_all_gather = distributed.is_initialized() and distributed.get_world_size() > 1
+
     def forward(
         self,
         text,
@@ -803,12 +751,15 @@ def forward(
 
         # maybe distributed all gather
 
-        text_latents = maybe_all_gather(text_latents, 1)
-        image_latents = maybe_all_gather(image_latents, 1)
+        if self.requires_all_gather:
+            latents = torch.stack((text_latents, image_latents))
+            latents, sizes = all_gather(latents, 2, None)
+            text_latents, image_latents = latents
 
-        if self.extra_latent_projection:
-            text_latents_extra = maybe_all_gather(text_latents_extra, 1)
-            image_latents_extra = maybe_all_gather(image_latents_extra, 1)
+            if self.extra_latent_projection:
+                latents_extra = torch.stack((text_latents_extra, image_latents_extra))
+                latents_extra, _ = all_gather(latents_extra, 2, sizes)
+                text_latents_extra, image_latents_extra = latents_extra
 
         # contrastive loss
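Note on the new x_clip/distributed.py hunk: as shown, it only imports Function and torch.distributed, yet the body references torch, rearrange, exists, and pad_dim_to, which are imported or defined in x_clip/x_clip.py rather than in this file. Below is a minimal, self-contained sketch of the same variable-length gather with those dependencies filled in; the exists and pad_dim_to definitions here are assumptions mirroring the helpers in x_clip/x_clip.py, not lines from the commit.

    # standalone sketch of the variable-length all_gather used above
    # (exists / pad_dim_to are assumed helpers, matching their role in x_clip/x_clip.py)

    import torch
    import torch.nn.functional as F
    import torch.distributed as distributed
    from einops import rearrange

    def exists(val):
        return val is not None

    def pad_dim_to(t, length, dim = 0):
        # right-pad tensor t with zeros along `dim` up to `length`
        pad_length = length - t.shape[dim]
        zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
        return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

    def all_gather_variable_dim(t, dim = 0, sizes = None):
        # gather tensors whose size along `dim` may differ across ranks:
        # exchange per-rank sizes, pad to the max, all_gather, then drop the padding
        device, world_size = t.device, distributed.get_world_size()

        if not exists(sizes):
            size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
            sizes = [torch.empty_like(size) for _ in range(world_size)]
            distributed.all_gather(sizes, size)
            sizes = torch.stack(sizes)

        max_size = sizes.amax().item()
        padded_t = pad_dim_to(t, max_size, dim = dim)

        gathered = [torch.empty_like(padded_t) for _ in range(world_size)]
        distributed.all_gather(gathered, padded_t)
        gathered = torch.cat(gathered, dim = dim)

        # flat mask selecting only the valid (unpadded) positions contributed by each rank
        seq = torch.arange(max_size, device = device)
        mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
        indices = torch.arange(mask.numel(), device = device)[rearrange(mask, 'i j -> (i j)')]

        return gathered.index_select(dim, indices), sizes

Usage follows the change in x_clip.py's forward: stack the text and image latents, gather along the (post-stack) batch dimension with all_gather(latents, 2, None), and pass the returned sizes back in for the extra latents so both gathers split gradients identically per rank in backward.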