Hi lucidrains,
Try this and it will NaN within 100 steps (latest GitHub code). The loss looks fine before the NaN.
import torch
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

import random
import numpy as np

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

num_text_tokens = 10000
batch_sz = 12
text_seq_len = 256
visual_image_size = 256

# mock data
data_sz = 1000
all_text = torch.randint(0, num_text_tokens, (data_sz, text_seq_len)).cuda()
all_images = torch.randn(data_sz, 3, visual_image_size, visual_image_size).cuda()

text = torch.zeros((batch_sz, text_seq_len), dtype=torch.long).cuda()
images = torch.zeros((batch_sz, 3, visual_image_size, visual_image_size)).cuda()

##########################################################################################

import wandb
import datetime
wandb.init(project="Test", name=datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), save_code=False)

from x_clip import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = num_text_tokens,
    text_enc_depth = 6,
    text_seq_len = text_seq_len,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = visual_image_size,
    visual_patch_size = 32,
    visual_heads = 8,
    use_all_token_embeds = False,           # whether to use fine-grained contrastive learning (FILIP)
    decoupled_contrastive_learning = True,  # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
    extra_latent_projection = True,         # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
    use_visual_ssl = True,                  # whether to do self-supervised learning on images
    visual_ssl_type = 'simclr',             # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
    use_mlm = False,                        # use masked language learning (MLM) on text (DeCLIP)
    text_ssl_loss_weight = 0.05,            # weight for text MLM loss
    image_ssl_loss_weight = 0.05            # weight for image self-supervised learning loss
).cuda()

optimizer = torch.optim.Adam(clip.parameters(), lr=1e-4, betas=(0.9, 0.99))

for step in range(999999):
    for i in range(batch_sz):
        data_id = random.randrange(0, data_sz - 1)
        text[i] = all_text[data_id]
        images[i] = all_images[data_id]

    loss = clip(
        text,
        images,
        freeze_image_encoder = False,  # whether to freeze image encoder if using a pretrained image net, proposed by LiT paper
        return_loss = True             # needs to be set to True to return contrastive loss
    )

    clip.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(clip.parameters(), 1.0)
    optimizer.step()

    now_loss = loss.item()
    wandb.log({"loss": now_loss}, step = step)
    print(step, now_loss)
    if 'nan' in str(now_loss):
        break
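Not part of the original report, but a minimal debugging sketch for narrowing down where the non-finite values first appear, assuming the same clip, optimizer, text, and images defined in the script above. PyTorch's anomaly detection reports the operation that produced a NaN/Inf during backward (at the cost of much slower steps), and checking the loss with torch.isfinite stops the run as soon as it degenerates.

import torch

# Diagnostic run only: anomaly detection slows training considerably.
torch.autograd.set_detect_anomaly(True)

for step in range(200):
    # assumes `clip`, `optimizer`, `text`, `images` from the reproduction script above
    loss = clip(text, images, return_loss = True)

    if not torch.isfinite(loss):
        # stop at the first non-finite loss so the offending step can be inspected
        print(f"non-finite loss at step {step}: {loss.item()}")
        break

    clip.zero_grad()
    loss.backward()  # with anomaly detection on, this raises at the op that produced NaN/Inf
    torch.nn.utils.clip_grad_norm_(clip.parameters(), 1.0)
    optimizer.step()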
@BlinkDL Hey Peng Bo! So I quickly checked the script and indeed it NaNs, but not if the visual_ssl is turned off.
I suspect it has something to do with augmenting the randomly created images in the visual SSL, but I'm not completely sure.
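For reference, a hedged sketch of the workaround implied above: the same CLIP configuration as the reproduction script, but with the visual SSL branch disabled. This only mirrors the observation in this thread that the NaN did not occur with visual SSL off; it is not a fix for the underlying issue.

from x_clip import CLIP

# Same hyperparameters as the reproduction script, with image self-supervised learning disabled,
# which is the condition under which the NaN reportedly does not occur.
clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8,
    use_all_token_embeds = False,
    decoupled_contrastive_learning = True,
    extra_latent_projection = True,
    use_visual_ssl = False,        # visual SSL turned off; visual_ssl_type is then irrelevant
    use_mlm = False,
    text_ssl_loss_weight = 0.05,
    image_ssl_loss_weight = 0.05
).cuda()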