Add TextTextCLIP #323
base: main
Conversation
src/training/params.py
Outdated
@@ -326,13 +326,54 @@ def parse_args(args):
    action='store_true',
    help="Freeze BatchNorm running stats in image tower for any locked layers.",
)
parser.add_argument(
    "--lock-doc",
what about a "lock-tower-1" and "lock-tower-2" param instead? query/doc is not the only use case
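A minimal sketch of what the suggested generic flags could look like, assuming an argparse setup similar to `params.py`; the flag names, defaults, and help strings here are illustrative, not the PR's actual interface:

```python
# Hypothetical generic tower-locking flags, modality-agnostic: "tower 1"
# and "tower 2" instead of image/text or query/doc specific names.
import argparse

parser = argparse.ArgumentParser()
for tower in ("1", "2"):
    parser.add_argument(
        f"--lock-tower-{tower}",
        action="store_true",
        help=f"Freeze tower {tower} (e.g. the query or doc encoder) during training.",
    )
    parser.add_argument(
        f"--lock-tower-{tower}-unlocked-layers",
        type=int,
        default=0,
        help=f"Leave the last n layers of tower {tower} unfrozen.",
    )

args = parser.parse_args(["--lock-tower-1"])
```

Because the flags are numbered rather than named after a modality, the same pair works for image/text, text/text, or any other two-tower setup.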
Sure, I will rename these variables immediately. Do you think we should also rename variables like `image_features` in `train.py` and `main.py`? Will renaming variables cause confusion for the original text-image model? Currently `train.py` is fully compatible with TextTextCLIP; no modification other than renaming variables is needed.
I have tested TextTextCLIP locally and it seems to work smoothly. But the script fails some tests here. I will make adjustments.
@@ -326,13 +338,54 @@ def parse_args(args):
    action='store_true',
    help="Freeze BatchNorm running stats in image tower for any locked layers.",
)
parser.add_argument(
wonder if this can be reconciled with the normal locking params
src/training/data.py
Outdated
class TextPairDataset(Dataset):
    def __init__(self, input_filename, text_a_key, text_b_key, tokenizer=None):
        logging.debug(f'Loading parquet data from {input_filename}.')
        df = pd.read_parquet(input_filename)
this is unlikely to scale
Right, I will update this. Do you have any suggestions? Do you think we should use the `pyarrow` package to load `.parquet` files?
src/open_clip/model.py
Outdated
@@ -248,6 +248,47 @@ def forward(self, image, text):
        return image_features, text_features, self.logit_scale.exp()


class TextTextCLIP(nn.Module):
this is completely duplicated from above class, I wonder if we could reconcile it
Maybe we could use this one as the general model? It's not specific about modality. We could refer to `image_features` as `features_a` and `text_features` as `features_b`.
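To make the "general model" idea concrete, a modality-agnostic dual-tower module might look like the following. This is a minimal sketch under the naming proposed above (`features_a`/`features_b`); the class name, towers, and logit-scale init are hypothetical and not open_clip's actual implementation:

```python
# A dual-tower contrastive model whose towers can be any encoders
# (image/text, text/text, ...), avoiding duplicated CLIP classes.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class DualTowerCLIP(nn.Module):
    def __init__(self, tower_a: nn.Module, tower_b: nn.Module,
                 init_logit_scale: float = math.log(1 / 0.07)):
        super().__init__()
        self.tower_a = tower_a  # e.g. a visual transformer or a text encoder
        self.tower_b = tower_b  # e.g. a text encoder
        self.logit_scale = nn.Parameter(torch.tensor(init_logit_scale))

    def forward(self, input_a, input_b):
        # L2-normalize so the downstream loss works with cosine similarities.
        features_a = F.normalize(self.tower_a(input_a), dim=-1)
        features_b = F.normalize(self.tower_b(input_b), dim=-1)
        return features_a, features_b, self.logit_scale.exp()
```

With this shape, `TextTextCLIP` and the image-text model differ only in which towers are passed to the constructor.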
src/open_clip/factory.py
Outdated
    mean=image_mean,
    std=image_std
)
if not text_to_text:
what about checking if model.visual exists instead ?
If we merge `TextTextCLIP` into `CustomCLIP` to get a more general model, then `model.visual` might not exist at all?
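Both concerns can be covered by duck typing: rather than branching on a `text_to_text` flag, check for the attribute defensively, which works whether or not the merged model defines `visual`. A small sketch (the helper name is hypothetical):

```python
# Branch on the model's actual shape instead of a task flag: if the model
# exposes a visual tower, do image-specific setup; otherwise skip it.
def has_visual_tower(model) -> bool:
    """True if the model exposes a non-None `visual` attribute."""
    return getattr(model, "visual", None) is not None
```

In `factory.py` the `if not text_to_text:` branches could then become `if has_visual_tower(model):`, and the check degrades gracefully when `visual` is absent.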
src/open_clip/factory.py
Outdated
@@ -179,9 +193,10 @@ def create_model(
if precision in ("fp16", "bf16"):
    convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)

if not text_to_text:
what about checking if model.visual exists instead ?
Can you rebase on master?
Done merging, but still working on some minor inconsistencies that need to be fixed.
@rom1504 Here is the latest code for TextTextCLIP. It is tested on the stability cluster, and the evaluation code is also included. I have added an example script at
This would really benefit from reviews @rwightman @mitchellnw @iejMac
@rom1504 k, will try and look at it soon
        return logits_per_feature_a, logits_per_feature_b

    def forward(self, image_features=None, text_features=None, logit_scale=None, text_a_features=None, text_b_features=None, output_dict=False):
This section feels a little awkward, maybe it should just always take features_a and features_b, not optional, like the prior code?
Same with logit_scale, why is that optional now?
hmm, yeah, wondering for clarity if it'd make more sense to have a specific text-text (CLLP?) loss with appropriate naming, possibly sharing the gather. Multiple sets of args with different naming schemes, erasing the task-specific names (which aid comprehension a bit), seems less desirable than a bit of duplication...
Thank you for your comments! This is an artifact of trying to reconcile the original CLIP and the text model, but I agree with you that this is a bit awkward. I could add a loss wrapper for the text model. What do you think?
By the way, do you have any ideas about the naming of the text model? TextTextCLIP sounds a bit redundant to me. How about CTTP? CLLP? LATTE (contrastive LAnguage-To-TExt pretraining)? CTP (contrastive text pretraining)? Do you have a better suggestion?
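One possible shape for the loss-wrapper idea: keep a generic contrastive loss with neutral names and add a thin task-specific subclass that carries the text-text naming. A hedged single-GPU sketch; the real open_clip `ClipLoss` (with its distributed gather) has a different signature:

```python
# Generic symmetric InfoNCE loss plus a thin text-text wrapper, so the core
# implementation is shared but each task keeps readable argument names.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    """Symmetric InfoNCE over two normalized feature sets (single GPU)."""
    def forward(self, features_a, features_b, logit_scale):
        logits = logit_scale * features_a @ features_b.T
        labels = torch.arange(logits.shape[0], device=logits.device)
        # Average the a->b and b->a cross-entropy terms.
        return (F.cross_entropy(logits, labels) +
                F.cross_entropy(logits.T, labels)) / 2

class TextTextLoss(ContrastiveLoss):
    """Task-specific naming, shared implementation."""
    def forward(self, text_a_features, text_b_features, logit_scale):
        return super().forward(text_a_features, text_b_features, logit_scale)
```

This keeps `image_features`/`text_features` in the image-text path untouched (the comprehension concern above) while avoiding a second copy of the loss math.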
Overall things look pretty good. I'm trying to get over a mental block re the loss naming; I realize why the feature_a/b changes were made to the loss, but I feel it harms ease of understanding re the most common use case, especially for newcomers to the core loss fn... hmm
This pull request adds TextTextCLIP (CLIP-like text-to-text contrastive retrieval model) to the main branch.
It is still a work in progress.
Tasks
- Add TextTextCLIP in `model.py`
- Modify `factory.py` to load model
- Modify `data.py` to load text data