diff --git a/src/open_clip/model.py b/src/open_clip/model.py
index 6b1abbf16..829c9b2fc 100644
--- a/src/open_clip/model.py
+++ b/src/open_clip/model.py
@@ -172,8 +172,9 @@ def __init__(
         self.ln_final = text.ln_final
         self.text_projection = text.text_projection
         self.register_buffer('attn_mask', text.attn_mask, persistent=False)
-
-        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+        # Num Nested Dims in MRL -- hard coding to 4
+        self.logit_scale = nn.Parameter(torch.ones([4]) * np.log(1 / 0.07))
 
     def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
         # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
diff --git a/src/training/train.py b/src/training/train.py
index 83f4f6fa7..afab836ca 100644
--- a/src/training/train.py
+++ b/src/training/train.py
@@ -95,7 +95,11 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w
         if args.accum_freq == 1:
             with autocast():
                 image_features, text_features, logit_scale = model(images, texts)
-                total_loss = loss(image_features, text_features, logit_scale)
+                # Hard coding 4 MRL dims -- [dim, dim/2, dim/4, dim/8]
+                rep_size = image_features.shape[1]
+                total_loss = loss(image_features, text_features, logit_scale[0])
+                for mrl_i in range(1, 4):
+                    total_loss += loss(image_features[:, :(rep_size//(2**mrl_i))], text_features[:, :(rep_size//(2**mrl_i))], logit_scale[mrl_i])
 
             backward(total_loss, scaler)
         else:
@@ -127,7 +131,7 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w
                         accum_image_features[:j] + [chunk_image_features] + accum_image_features[j + 1:])
                     text_features = torch.cat(
                         accum_text_features[:j] + [chunk_text_features] + accum_text_features[j + 1:])
-                    total_loss = loss(image_features, text_features, logit_scale)
+                    total_loss = loss(image_features, text_features, logit_scale[0])
                 backward(total_loss, scaler)
 
         if scaler is not None:
@@ -155,7 +159,9 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w
 
         # Note: we clamp to 4.6052 = ln(100), as in the original paper.
         with torch.no_grad():
-            unwrap_model(model).logit_scale.clamp_(0, math.log(100))
+            for mrl_i in range(0, 4):
+                unwrap_model(model).logit_scale[mrl_i].clamp_(0, math.log(100))
+
 
         batch_time_m.update(time.time() - end)
         end = time.time()
@@ -168,7 +174,7 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w
 
             # NOTE loss is coarsely sampled, just master node and per log update
            loss_m.update(total_loss.item(), batch_size)
-            logit_scale_scalar = logit_scale.item()
+            logit_scale_scalar = logit_scale[0].item()
             logging.info(
                 f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                 f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
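
For reference, the nested-dimension loss pattern in the first hunk of train.py can be written as a standalone function. The sketch below is not part of the patch: clip_contrastive_loss is a simplified single-GPU stand-in for open_clip's ClipLoss (no distributed feature gathering), it re-normalizes each truncated prefix (the patch itself passes the sliced features to the loss as-is), and it expects the already-exponentiated per-granularity temperatures that the model's forward returns.

import torch
import torch.nn.functional as F


def clip_contrastive_loss(image_features, text_features, logit_scale):
    # Symmetric InfoNCE on L2-normalized features. `logit_scale` is the
    # already-exponentiated temperature, matching what model(images, texts) returns.
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    logits_per_image = logit_scale * image_features @ text_features.t()
    labels = torch.arange(logits_per_image.shape[0], device=logits_per_image.device)
    return (F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_image.t(), labels)) / 2


def mrl_clip_loss(image_features, text_features, logit_scale, num_nested=4):
    # Sum the contrastive loss over nested prefixes [dim, dim/2, dim/4, dim/8],
    # one learned temperature per granularity, mirroring the hard-coded loop in the patch.
    rep_size = image_features.shape[1]
    total_loss = clip_contrastive_loss(image_features, text_features, logit_scale[0])
    for mrl_i in range(1, num_nested):
        d = rep_size // (2 ** mrl_i)
        total_loss += clip_contrastive_loss(
            image_features[:, :d], text_features[:, :d], logit_scale[mrl_i])
    return total_loss


if __name__ == "__main__":
    # Toy check: batch of 8, 512-d embeddings, 4 per-granularity scales.
    img, txt = torch.randn(8, 512), torch.randn(8, 512)
    logit_scale = (torch.ones(4) * torch.log(torch.tensor(1 / 0.07))).exp()
    print(mrl_clip_loss(img, txt, logit_scale))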