Hi,

Thank you for this excellent work!

I'm using this repository for vision-language pretraining (VLP) on our medical dataset, where the false-negative problem is particularly severe. As a first step toward addressing it, I want to modify the CLIP supervision signal so that a single row or column can contain multiple positive pairs. Since the current cross-entropy-based CLIP loss works on index targets and does not support this, I have implemented an alternative CLIP loss that uses an identity matrix as the supervision signal; this leaves room to later adjust the target matrix to correct false negatives (a sketch of that extension follows the implementation below).

Could the experts @rwightman @mitchellnw @rom1504 review my implementation and confirm its correctness? I have tried to follow the existing implementation as closely as possible. Do you also have any suggestions for mitigating false negatives more effectively?

I truly appreciate your help!

Here is my implementation:
import torch
import torch.nn.functional as F

from open_clip.loss import ClipLoss


class CustomClipLoss(ClipLoss):
    def __init__(
        self,
        local_loss=False,
        gather_with_grad=False,
        cache_labels=False,
        rank=0,
        world_size=1,
        use_horovod=False,
    ):
        super().__init__(
            local_loss=local_loss,
            gather_with_grad=gather_with_grad,
            cache_labels=cache_labels,
            rank=rank,
            world_size=world_size,
            use_horovod=use_horovod,
        )

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        """
        Returns a binary target matrix for the alternative loss formulation.

        If world_size == 1 or local_loss is False:
            - Returns an identity matrix of shape [num_logits, num_logits].
        If world_size > 1 and local_loss is True:
            - The logits are expected to have shape [num_logits, num_logits * world_size],
              so we construct a target matrix of that shape. Only the columns belonging
              to the current rank (i.e. columns [rank * num_logits : (rank + 1) * num_logits])
              form an identity matrix (the positives); all other positions are zeros.
        """
        if self.prev_num_logits != num_logits or device not in self.labels:
            if self.world_size > 1 and self.local_loss:
                target_shape = (num_logits, num_logits * self.world_size)
                labels = torch.zeros(target_shape, device=device)
                start = self.rank * num_logits
                end = (self.rank + 1) * num_logits
                labels[:, start:end] = torch.eye(num_logits, device=device)
            else:
                labels = torch.eye(num_logits, device=device)
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        """
        Alternative forward that still assumes one positive per sample.

        Instead of F.cross_entropy with index targets, we take the log_softmax of the
        logits and multiply it elementwise by a one-hot target matrix. For a one-hot
        target this is equivalent to the standard cross entropy, but the target matrix
        can later be replaced by one with multiple positives per row.
        """
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)

        # One-hot targets (identity matrix, or its sharded version under local_loss).
        sim_targets = self.get_ground_truth(device, logits_per_image.shape[0])

        # Sum the log-probabilities at the positive positions, then average over the batch.
        loss_i2t = -torch.sum(F.log_softmax(logits_per_image, dim=1) * sim_targets, dim=1).mean()
        loss_t2i = -torch.sum(F.log_softmax(logits_per_text, dim=1) * sim_targets, dim=1).mean()
        total_loss = (loss_i2t + loss_t2i) / 2

        return {"contrastive_loss": total_loss} if output_dict else total_loss
I haven't tested the local_loss and the gather_with_grad features.
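For the later false-negative correction, the rough idea is to replace the identity target with a binary matrix that marks all texts from the same group (e.g. the same patient or study) as positives, normalized per row. This is only a sketch of what I have in mind; group_ids is a hypothetical per-sample label tensor that would have to come from my dataset, not something open_clip provides:

import torch
import torch.nn.functional as F

def build_multi_positive_targets(group_ids: torch.Tensor) -> torch.Tensor:
    # Entry (i, j) is positive whenever samples i and j share a group id.
    positives = (group_ids.unsqueeze(0) == group_ids.unsqueeze(1)).float()
    # Normalize each row to sum to 1 so the loss scale stays comparable to the one-hot case.
    return positives / positives.sum(dim=1, keepdim=True)

# Example: samples 0 and 2 come from the same study, so they are mutual positives.
group_ids = torch.tensor([0, 1, 0, 2])
sim_targets = build_multi_positive_targets(group_ids)
logits_per_image = torch.randn(4, 4)  # stand-in for the real logits
loss_i2t = -torch.sum(F.log_softmax(logits_per_image, dim=1) * sim_targets, dim=1).mean()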