diff --git a/setup.py b/setup.py
index 914140a..584f6c1 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
   name = 'x-clip',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.12.4',
+  version = '0.12.5',
   license='MIT',
   description = 'X-CLIP',
   author = 'Phil Wang',
diff --git a/x_clip/x_clip.py b/x_clip/x_clip.py
index d94f7c9..08b932d 100644
--- a/x_clip/x_clip.py
+++ b/x_clip/x_clip.py
@@ -67,53 +67,55 @@ def matrix_diag(t):
 
 # distributed helpers
 
-def all_gather_variable_batch(t):
+def all_gather_variable_batch(t, dim = 0):
     device, rank, world_size = t.device, distributed.get_rank(), distributed.get_world_size()
 
-    batch_size = torch.tensor([t.shape[0]], device = device, dtype = torch.long)
-    batch_sizes = [torch.empty(1, device = device, dtype = torch.long) for i in range(world_size)]
+    batch_size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
+    batch_sizes = [torch.empty_like(batch_size, device = device, dtype = torch.long) for i in range(world_size)]
     distributed.all_gather(batch_sizes, batch_size)
 
-    batch_sizes = torch.cat(batch_sizes)
+    batch_sizes = torch.stack(batch_sizes)
     max_batch_size = batch_sizes.amax().item()
 
-    padded_t = pad_dim_to(t, max_batch_size)
+    padded_t = pad_dim_to(t, max_batch_size, dim = dim)
     gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
     distributed.all_gather(gathered_tensors, padded_t)
 
-    gathered_tensor = torch.cat(gathered_tensors, dim = 0)
+    gathered_tensor = torch.cat(gathered_tensors, dim = dim)
     seq = torch.arange(max_batch_size, device = device)
 
     mask = rearrange(seq, 'j -> 1 j') < rearrange(batch_sizes, 'i -> i 1')
     mask = rearrange(mask, 'i j -> (i j)')
+    seq = torch.arange(mask.shape[-1], device = device)
+    indices = seq[mask]
 
-    gathered_tensor = gathered_tensor[mask]
+    gathered_tensor = gathered_tensor.index_select(dim, indices)
     batch_sizes = batch_sizes.tolist()
 
     return gathered_tensor, batch_sizes
 
 class MaybeAllGather(Function):
     @staticmethod
-    def forward(ctx, x):
+    def forward(ctx, x, dim):
         is_distributed = distributed.is_initialized() and distributed.get_world_size() > 1
         ctx.is_distributed = is_distributed
+        ctx.dim = dim
 
         if not is_distributed:
             return x
 
-        x = rearrange(x, 'm b ... -> b m ...')
-        x, batch_sizes = all_gather_variable_batch(x)
+        x, batch_sizes = all_gather_variable_batch(x, dim = dim)
         ctx.batch_sizes = batch_sizes
-        return rearrange(x, 'b m ... -> m b ...')
+        return x
 
     @staticmethod
     def backward(ctx, grads):
         if not ctx.is_distributed:
-            return grads
+            return grads, None
 
         batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
-        grads_by_rank = grads.split(batch_sizes, dim = 0)
-        return grads_by_rank[rank]
+        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
+        return grads_by_rank[rank], None
 
 maybe_all_gather = MaybeAllGather.apply
 
@@ -801,12 +803,12 @@ def forward(
 
         # maybe distributed all gather
 
-        text_latents = maybe_all_gather(text_latents)
-        image_latents = maybe_all_gather(image_latents)
+        text_latents = maybe_all_gather(text_latents, 1)
+        image_latents = maybe_all_gather(image_latents, 1)
 
         if self.extra_latent_projection:
-            text_latents_extra = maybe_all_gather(text_latents_extra)
-            image_latents_extra = maybe_all_gather(image_latents_extra)
+            text_latents_extra = maybe_all_gather(text_latents_extra, 1)
+            image_latents_extra = maybe_all_gather(image_latents_extra, 1)
 
         # contrastive loss
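
For context, below is a minimal single-process sketch of the pad-then-mask gather that the dim-aware all_gather_variable_batch performs across ranks, with the latents gathered along dim = 1 as in the patched CLIP.forward. The local pad_dim_to helper and the example shapes are assumptions for illustration only; no process group is created here.

# Single-process sketch (assumed shapes, no torch.distributed calls) of the
# variable-size gather along an arbitrary dim used by the patch above.

import torch
import torch.nn.functional as F

def pad_dim_to(t, length, dim = 0):
    # right-pad `t` with zeros along `dim` until it has size `length`
    pad_length = length - t.shape[dim]
    zero_pairs = t.ndim - dim - 1
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

# latents from two hypothetical ranks: (num_latent_sets, batch, dim), unequal batch sizes
rank_latents = [torch.randn(2, 3, 8), torch.randn(2, 5, 8)]
dim = 1

batch_sizes = torch.tensor([t.shape[dim] for t in rank_latents])
max_batch_size = batch_sizes.amax().item()

# each rank pads to the max batch size so the all_gather shapes match,
# then the padded tensors are concatenated along the gathered dim
padded = [pad_dim_to(t, max_batch_size, dim = dim) for t in rank_latents]
gathered = torch.cat(padded, dim = dim)

# mask marks the positions holding real (non-padding) entries per rank
seq = torch.arange(max_batch_size)
mask = seq[None, :] < batch_sizes[:, None]            # (world_size, max_batch_size)
indices = torch.arange(mask.numel())[mask.reshape(-1)]

# index_select along `dim` drops the padding, mirroring the new code path
gathered = gathered.index_select(dim, indices)
assert gathered.shape == (2, 3 + 5, 8)

In the distributed path, MaybeAllGather.backward then splits the incoming gradient along the same recorded dim by the saved batch_sizes and returns only this rank's slice (plus None for the dim argument).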