add forward trick to CustomCLIP
mehdidc committed Jan 31, 2023
commit ad44bf6 · 1 parent dada50a
Showing 2 changed files with 13 additions and 1 deletion.
src/open_clip/model.py (9 additions, 0 deletions)

@@ -279,6 +279,7 @@ def encode_text(self, text, normalize: bool = False):
         features = self.text(text)
         return F.normalize(features, dim=-1) if normalize else features
 
+    """
     def forward(self, image, text):
         image_features = self.encode_image(image, normalize=True)
         text_features = self.encode_text(text, normalize=True)
@@ -289,6 +290,14 @@ def forward(self, image, text):
                 "logit_scale": self.logit_scale.exp()
             }
         return image_features, text_features, self.logit_scale.exp()
+    """
+    def forward(self, image, text, clamp_logit_scale_to=None):
+        image_features = self.encode_image(image, normalize=True) if image is not None else None
+        text_features = self.encode_text(text, normalize=True) if text is not None else None
+        if clamp_logit_scale_to is not None:
+            with torch.no_grad():
+                self.logit_scale.data.clamp_(0, clamp_logit_scale_to)
+        return image_features, text_features, self.logit_scale.exp()
 
 
 def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
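
The new forward comments out the old one (via a docstring), makes both towers optional (image or text may be None), and folds the usual logit-scale clamp into the forward pass through the clamp_logit_scale_to argument, running it under torch.no_grad(). This is presumably the "trick" of the commit title: once the model is FSDP-wrapped, mutating logit_scale outside the wrapped forward is awkward. Below is a minimal sketch of how a training step might drive it; the clamp bound math.log(100) (the value used in the original CLIP paper) and all surrounding names are assumptions, not part of this commit:

import math
import torch
import torch.nn.functional as F

def training_step(model, images, texts):
    # One call computes both embeddings and clamps logit_scale in place,
    # so no separate parameter mutation is needed outside forward.
    # clamp_logit_scale_to=math.log(100) is an assumed choice of bound.
    image_features, text_features, logit_scale = model(
        images, texts, clamp_logit_scale_to=math.log(100)
    )
    # Standard symmetric CLIP contrastive loss over the in-batch pairs.
    logits = logit_scale * image_features @ text_features.T
    labels = torch.arange(logits.shape[0], device=logits.device)
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2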
src/training/main.py (4 additions, 1 deletion)

@@ -274,6 +274,9 @@ def main(args):
             wrap,
         )
         print(f"Before FSTP parameter num: {sum(p.numel() for p in model.parameters())}")
+        print(f"Before FSTP VISUAL parameter num: {sum(p.numel() for p in model.visual.parameters())}")
+        #print(f"Before FSTP TEXT parameter num: {sum(p.numel() for p in model.transformer.parameters())}")
+
         print(f"Before FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB")
         mp = MixedPrecision(
             #param_dtype=torch.bfloat16,
@@ -292,7 +295,7 @@ def main(args):
                     ResidualAttentionBlock,
                 },
             ),
-            device_id=None if args.fsdp_init_on_cpu else device,
+            device_id=device,
        )
 
        # avoid "RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory."
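
For context, here is a minimal, self-contained sketch of the FSDP wrapping pattern this hunk edits. ResidualAttentionBlock stands in for open_clip's transformer block; the MixedPrecision dtype and the wrap_with_fsdp helper are assumptions for illustration, not code from this commit:

import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

def wrap_with_fsdp(model, block_cls, device):
    # Wrap every transformer block of type block_cls in its own FSDP unit.
    mp = MixedPrecision(param_dtype=torch.float32)  # dtype choice is an assumption
    return FSDP(
        model,
        mixed_precision=mp,
        auto_wrap_policy=functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={block_cls},
        ),
        # The commit hard-codes device_id=device, so shards are materialized
        # directly on the local GPU rather than optionally initialized on CPU.
        device_id=device,
    )

Dropping the args.fsdp_init_on_cpu branch means initialization always happens on the device given by device_id; the added print lines report the visual tower's parameter count before wrapping, which makes sense because after sharding each rank's sum(p.numel()) reflects only its local shard.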
