From aac5d53ebd1db188b0e5c696308b2b97029a02d6 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Mon, 4 Sep 2023 10:14:40 -0700
Subject: [PATCH] automatically take care of distributed setting

---
 setup.py         |  2 +-
 x_clip/x_clip.py | 70 ++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 71 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 668d5ee..914140a 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
   name = 'x-clip',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.12.2',
+  version = '0.12.4',
   license='MIT',
   description = 'X-CLIP',
   author = 'Phil Wang',
diff --git a/x_clip/x_clip.py b/x_clip/x_clip.py
index ac2da70..d94f7c9 100644
--- a/x_clip/x_clip.py
+++ b/x_clip/x_clip.py
@@ -7,6 +7,10 @@
 import torch.nn.functional as F
 from torch import nn, einsum
 from torch.utils.checkpoint import checkpoint
+
+from torch.autograd import Function
+import torch.distributed as distributed
+
 
 from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange, Reduce
@@ -40,6 +44,11 @@ def masked_mean(t, mask, dim = 1, eps = 1e-6):
     denom = mask.sum(dim = dim).clamp(min = eps)
     return numer / denom
 
+def pad_dim_to(t, length, dim = 0):
+    pad_length = length - t.shape[dim]
+    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))
+
 def log(t, eps = 1e-20):
     return torch.log(t + eps)
 
@@ -56,6 +65,58 @@ def matrix_diag(t):
     diag_el = t.masked_select(diag_mask)
     return rearrange(diag_el, '(b d) -> b d', d = num_diag_el)
 
+# distributed helpers
+
+def all_gather_variable_batch(t):
+    device, rank, world_size = t.device, distributed.get_rank(), distributed.get_world_size()
+
+    batch_size = torch.tensor([t.shape[0]], device = device, dtype = torch.long)
+    batch_sizes = [torch.empty(1, device = device, dtype = torch.long) for i in range(world_size)]
+    distributed.all_gather(batch_sizes, batch_size)
+
+    batch_sizes = torch.cat(batch_sizes)
+    max_batch_size = batch_sizes.amax().item()
+    padded_t = pad_dim_to(t, max_batch_size)
+
+    gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
+    distributed.all_gather(gathered_tensors, padded_t)
+
+    gathered_tensor = torch.cat(gathered_tensors, dim = 0)
+    seq = torch.arange(max_batch_size, device = device)
+
+    mask = rearrange(seq, 'j -> 1 j') < rearrange(batch_sizes, 'i -> i 1')
+    mask = rearrange(mask, 'i j -> (i j)')
+
+    gathered_tensor = gathered_tensor[mask]
+    batch_sizes = batch_sizes.tolist()
+
+    return gathered_tensor, batch_sizes
+
+class MaybeAllGather(Function):
+    @staticmethod
+    def forward(ctx, x):
+        is_distributed = distributed.is_initialized() and distributed.get_world_size() > 1
+        ctx.is_distributed = is_distributed
+
+        if not is_distributed:
+            return x
+
+        x = rearrange(x, 'm b ... -> b m ...')
+        x, batch_sizes = all_gather_variable_batch(x)
+        ctx.batch_sizes = batch_sizes
+        return rearrange(x, 'b m ... -> m b ...')
+
+    @staticmethod
+    def backward(ctx, grads):
+        if not ctx.is_distributed:
+            return grads
+
+        batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
+        grads_by_rank = grads.split(batch_sizes, dim = 1)
+        return grads_by_rank[rank]
+
+maybe_all_gather = MaybeAllGather.apply
+
 # checkpointing helper function
 
 def make_checkpointable(fn):
@@ -738,6 +799,15 @@ def forward(
             text_latents_extra = rearrange(text_latents_extra, '(m b) ... -> m b ...', m = num_batch_texts)
             image_latents_extra = rearrange(image_latents_extra, '(m b) ... -> m b ...', m = num_batch_images)
 
+        # maybe distributed all gather
+
+        text_latents = maybe_all_gather(text_latents)
+        image_latents = maybe_all_gather(image_latents)
+
+        if self.extra_latent_projection:
+            text_latents_extra = maybe_all_gather(text_latents_extra)
+            image_latents_extra = maybe_all_gather(image_latents_extra)
+
         # contrastive loss
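A minimal usage sketch of the maybe_all_gather helper introduced above, not part of the patch itself. It assumes the patched x_clip/x_clip.py is importable; the batch sizes, latent dimension, and the torchrun launch mentioned in the comments are illustrative only. In a single process (no process group initialized) the call is a no-op; with multiple ranks each rank contributes its own, possibly different, batch and receives the concatenation across ranks, while backward returns only that rank's slice of the gradient.

# illustrative sketch, assuming the patched module is installed
# single process: maybe_all_gather falls through the no-op branch
# multi process (e.g. launched via torchrun with init_process_group called):
#   latents of shape (m, b, dim) gather to (m, sum of per-rank b, dim)

import torch
import torch.distributed as distributed

from x_clip.x_clip import maybe_all_gather

def demo():
    is_dist = distributed.is_available() and distributed.is_initialized()
    rank = distributed.get_rank() if is_dist else 0

    m, dim = 1, 512
    batch = 3 + rank                            # per-rank batch sizes may differ
    latents = torch.randn(m, batch, dim, requires_grad = True)

    gathered = maybe_all_gather(latents)        # no-op when not distributed
    gathered.sum().backward()

    # gradients are routed back to the rank that produced each slice of the batch
    assert latents.grad.shape == latents.shape

if __name__ == '__main__':
    demo()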