Skip to content

Commit

Permalink
add representation similarity regularization from Shi et al. 2023 (arXiv:2309.08773)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 21, 2023
1 parent 95b722f commit 867ab60
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 3 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,15 @@ loss.backward()
volume = {abs/2208.07220}
}
```

```bibtex
@misc{shi2023enhance,
title = {Enhance audio generation controllability through representation similarity regularization},
author = {Yangyang Shi and Gael Le Lan and Varun Nagaraja and Zhaoheng Ni and Xinhao Mei and Ernie Chang and Forrest Iandola and Yang Liu and Vikas Chandra},
year = {2023},
eprint = {2309.08773},
archivePrefix = {arXiv},
primaryClass = {cs.SD}
}
```

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'x-clip',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.12.9',
version = '0.14.1',
license='MIT',
description = 'X-CLIP',
author = 'Phil Wang',
Expand Down
31 changes: 29 additions & 2 deletions x_clip/x_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ def __init__(
image_ssl_loss_weight = 0.05,
multiview_loss_weight = 0.1,
checkpoint_during_training = False,
sim_reg_loss_weight = 0.,
**kwargs
):
super().__init__()
Expand Down Expand Up @@ -589,6 +590,10 @@ def __init__(
# is distributed or not
self.requires_all_gather = distributed.is_initialized() and distributed.get_world_size() > 1

# use the similarity regularization proposed in https://arxiv.org/abs/2309.08773
self.sim_reg_loss_weight = sim_reg_loss_weight
self.has_sim_reg_loss = sim_reg_loss_weight > 0.

def forward(
self,
text,
Expand All @@ -602,7 +607,7 @@ def forward(
aug_text = None, # augmented text (for multiview)
aug_image = None # augmented image (for multiview)
):
b, device = text.shape[0], text.device
batch, device = text.shape[0], text.device

# derive text mask

Expand Down Expand Up @@ -756,11 +761,28 @@ def forward(
latents, sizes = all_gather(latents, 2, None)
text_latents, image_latents = latents

batch = sizes.sum().item()

if self.extra_latent_projection:
latents_extra = torch.stack((text_latents_extra, image_latents_extra))
latents_extra, _ = all_gather(latents_extra, 2, sizes)
text_latents_extra, image_latents_extra = latents_extra

# maybe similarity regularize

sim_reg_loss = 0.

if self.has_sim_reg_loss:
diag_mask = torch.eye(batch, device = device, dtype = torch.bool)
off_diag_mask = rearrange(~diag_mask, '... -> 1 ...')

text_sim, image_sim, text_extra_sim, image_extra_sim = map(lambda t: einsum('m i ... d, m j ... d -> m ... i j', t, t)[off_diag_mask], (text_latents, image_latents, text_latents_extra, image_latents_extra))

sim_reg_loss = (
F.mse_loss(text_sim, image_sim) +
F.mse_loss(text_extra_sim, image_extra_sim)
) / 2

# contrastive loss

"""
Expand Down Expand Up @@ -810,7 +832,7 @@ def forward(
# denominator

if self.decoupled_contrastive_learning:
pos_mask = torch.eye(b, device = device, dtype = torch.bool)
pos_mask = torch.eye(batch, device = device, dtype = torch.bool)
text_to_image_exp, image_to_text_exp = map(lambda t: t.masked_fill(pos_mask, 0.), (text_to_image_exp, image_to_text_exp))

text_to_image_denom, image_to_text_denom = map(lambda t: t.sum(dim = -1), (text_to_image_exp, image_to_text_exp))
Expand Down Expand Up @@ -845,4 +867,9 @@ def forward(
if is_multiview:
loss = loss + multiview_cl_loss.mean() * multiview_loss_weight

# add similarity regularization loss with weight if needed

if self.has_sim_reg_loss:
loss = loss + sim_reg_loss * self.sim_reg_loss_weight

return loss

0 comments on commit 867ab60

Please sign in to comment.