From 867ab60c4cc8c735b32eba1a384f31be28850349 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 21 Sep 2023 11:09:37 -0700
Subject: [PATCH] add an interesting regularization from another paper

---
 README.md        | 12 ++++++++++++
 setup.py         |  2 +-
 x_clip/x_clip.py | 31 +++++++++++++++++++++++++++++--
 3 files changed, 42 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index b1ea8ca..1fa91f1 100644
--- a/README.md
+++ b/README.md
@@ -417,3 +417,15 @@ loss.backward()
 volume  = {abs/2208.07220}
 }
 ```
+
+```bibtex
+@misc{shi2023enhance,
+    title   = {Enhance audio generation controllability through representation similarity regularization},
+    author  = {Yangyang Shi and Gael Le Lan and Varun Nagaraja and Zhaoheng Ni and Xinhao Mei and Ernie Chang and Forrest Iandola and Yang Liu and Vikas Chandra},
+    year    = {2023},
+    eprint  = {2309.08773},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.SD}
+}
+```
+
diff --git a/setup.py b/setup.py
index 8fe03d1..5834eed 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
   name = 'x-clip',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.12.9',
+  version = '0.14.1',
   license='MIT',
   description = 'X-CLIP',
   author = 'Phil Wang',
diff --git a/x_clip/x_clip.py b/x_clip/x_clip.py
index e151098..7bd6510 100644
--- a/x_clip/x_clip.py
+++ b/x_clip/x_clip.py
@@ -451,6 +451,7 @@
         image_ssl_loss_weight = 0.05,
         multiview_loss_weight = 0.1,
         checkpoint_during_training = False,
+        sim_reg_loss_weight = 0.,
         **kwargs
     ):
         super().__init__()
@@ -589,6 +590,10 @@
         # is distributed or not
 
         self.requires_all_gather = distributed.is_initialized() and distributed.get_world_size() > 1
+        # use the similarity regularization proposed in https://arxiv.org/abs/2309.08773
+        self.sim_reg_loss_weight = sim_reg_loss_weight
+        self.has_sim_reg_loss = sim_reg_loss_weight > 0.
+
     def forward(
         self,
         text,
@@ -602,7 +607,7 @@
         aug_text = None,   # augmented text (for multiview)
         aug_image = None   # augmented image (for multiview)
     ):
-        b, device = text.shape[0], text.device
+        batch, device = text.shape[0], text.device
 
         # derive text mask
 
@@ -756,11 +761,28 @@
             latents, sizes = all_gather(latents, 2, None)
             text_latents, image_latents = latents
 
+            batch = sizes.sum().item()
+
             if self.extra_latent_projection:
                 latents_extra = torch.stack((text_latents_extra, image_latents_extra))
                 latents_extra, _ = all_gather(latents_extra, 2, sizes)
                 text_latents_extra, image_latents_extra = latents_extra
 
+        # maybe similarity regularize
+
+        sim_reg_loss = 0.
+
+        if self.has_sim_reg_loss:
+            diag_mask = torch.eye(batch, device = device, dtype = torch.bool)
+            off_diag_mask = rearrange(~diag_mask, '... -> 1 ...')
+
+            text_sim, image_sim, text_extra_sim, image_extra_sim = map(lambda t: einsum('m i ... d, m j ... d -> m ... i j', t, t)[off_diag_mask], (text_latents, image_latents, text_latents_extra, image_latents_extra))
+
+            sim_reg_loss = (
+                F.mse_loss(text_sim, image_sim) +
+                F.mse_loss(text_extra_sim, image_extra_sim)
+            ) / 2
+
         # contrastive loss
 
         """
@@ -810,7 +832,7 @@
         # denominator
 
         if self.decoupled_contrastive_learning:
-            pos_mask = torch.eye(b, device = device, dtype = torch.bool)
+            pos_mask = torch.eye(batch, device = device, dtype = torch.bool)
             text_to_image_exp, image_to_text_exp = map(lambda t: t.masked_fill(pos_mask, 0.), (text_to_image_exp, image_to_text_exp))
 
         text_to_image_denom, image_to_text_denom = map(lambda t: t.sum(dim = -1), (text_to_image_exp, image_to_text_exp))
@@ -845,4 +867,9 @@
         if is_multiview:
             loss = loss + multiview_cl_loss.mean() * multiview_loss_weight
 
+        # add similarity regularization loss with weight if needed
+
+        if self.has_sim_reg_loss:
+            loss = loss + sim_reg_loss * self.sim_reg_loss_weight
+
        return loss
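
Usage note (not part of the patch above): a minimal sketch of how the new `sim_reg_loss_weight` option might be enabled. The constructor arguments follow the usage example already in the x-clip README; the specific hyperparameter values, including the regularization weight of 0.1, are illustrative assumptions rather than values taken from this commit.

```python
import torch
from x_clip import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8,
    sim_reg_loss_weight = 0.1   # any value > 0 enables the representation similarity regularization added in this patch (illustrative weight)
)

# mock data matching text_seq_len and visual_image_size above

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

# the returned contrastive loss now also includes the weighted similarity regularization term

loss = clip(text, images, return_loss = True)
loss.backward()
```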