From a0dfbd16bbaa967a71f51773037ea6bf32b10497 Mon Sep 17 00:00:00 2001 From: Taylor Brown Date: Wed, 27 Jul 2022 15:44:11 -0500 Subject: [PATCH 1/2] Added ability to load pre-computed vectors from Pickle file The interface matches the csv, but with a pickled data frame. --- src/open_clip/model.py | 12 ++++- src/open_clip/model_configs/precomputed.json | 13 ++++++ src/training/data.py | 46 ++++++++++++++++++++ 3 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 src/open_clip/model_configs/precomputed.json diff --git a/src/open_clip/model.py b/src/open_clip/model.py index 9d427b9ca..72e2e9201 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -251,6 +251,13 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): return x +class Precomputed(nn.Module): + def __init__(self,image_size=1): + super().__init__() + self.image_size=image_size + def forward(self, x: torch.Tensor): + return x + class VisualTransformer(nn.Module): def __init__( self, image_size: int, patch_size: int, width: int, layers: int, heads: int, mlp_ratio: float, @@ -315,6 +322,7 @@ class CLIPVisionCfg: timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') + precomputed: bool = False @dataclass @@ -366,6 +374,8 @@ def __init__( image_size=vision_cfg.image_size, width=vision_cfg.width ) + elif vision_cfg.precomputed: + self.visual = Precomputed() else: vision_heads = vision_cfg.width // vision_cfg.head_width self.visual = VisualTransformer( @@ -378,7 +388,7 @@ def __init__( output_dim=embed_dim, act_layer=act_layer, ) - + self.transformer = Transformer( width=text_cfg.width, layers=text_cfg.layers, diff --git a/src/open_clip/model_configs/precomputed.json b/src/open_clip/model_configs/precomputed.json new file mode 100644 index 000000000..439eed167 --- /dev/null +++ b/src/open_clip/model_configs/precomputed.json @@ -0,0 +1,13 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "precomputed": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/training/data.py b/src/training/data.py index 23ff21400..478d12c01 100644 --- a/src/training/data.py +++ b/src/training/data.py @@ -47,6 +47,24 @@ def __getitem__(self, idx): texts = tokenize([str(self.captions[idx])])[0] return images, texts +class PrecomputedDataset(Dataset): + def __init__(self, input_filename, transforms, emb_key, caption_key): + logging.debug(f'Loading pickle data from {input_filename}.') + df = pd.read_pickle(input_filename) + + self.embeddings = df[emb_key].tolist() + self.captions = df[caption_key].tolist() + self.transforms = transforms + logging.debug('Done loading data.') + + def __len__(self): + return len(self.captions) + + def __getitem__(self, idx): + embeddings = self.embeddings[idx].astype("float32") + texts = tokenize([str(self.captions[idx])])[0] + return embeddings, texts + class SharedEpoch: def __init__(self, epoch: int = 0): @@ -418,6 +436,32 @@ def get_csv_dataset(args, preprocess_fn, is_train, epoch=0): return DataInfo(dataloader, sampler) +def get_precomputed_dataset(args, preprocess_fn, is_train, epoch=0): + input_filename = args.train_data if is_train else args.val_data + assert input_filename + dataset = PrecomputedDataset( + input_filename, + preprocess_fn, + emb_key=args.csv_img_key, + caption_key=args.csv_caption_key) + num_samples = len(dataset) + sampler = DistributedSampler(dataset) if args.distributed and is_train else None + shuffle = is_train and sampler is None + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=shuffle, + num_workers=args.workers, + pin_memory=True, + sampler=sampler, + drop_last=is_train, + ) + dataloader.num_samples = num_samples + dataloader.num_batches = len(dataloader) + + return DataInfo(dataloader, sampler) + def get_dataset_fn(data_path, dataset_type): if dataset_type == "webdataset": @@ -430,6 +474,8 @@ def get_dataset_fn(data_path, dataset_type): return get_csv_dataset elif ext in ['tar']: return get_wds_dataset + elif ext in ['pkl']: + return get_precomputed_dataset else: raise ValueError( f"Tried to figure out dataset type, but failed for extention {ext}.") From baa8013915ad5214dc463b71484adb38d408af29 Mon Sep 17 00:00:00 2001 From: Taylor Brown Date: Wed, 27 Jul 2022 18:37:23 -0500 Subject: [PATCH 2/2] added description of precomputed embeddings --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 010b2236d..3fbf7a584 100644 --- a/README.md +++ b/README.md @@ -147,6 +147,10 @@ python -m training.main \ Note: `imagenet-val` is the path to the *validation* set of ImageNet for zero-shot evaluation, not the training set! You can remove this argument if you do not want to perform zero-shot evaluation on ImageNet throughout training. Note that the `val` folder should contain subfolders. If it doest not, please use [this script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh). +### Pre-computed Embeddedings + +Pre-computed Embeddedings are in both the model and the dataset. You must set the model to precomputed. Adjust the dimension accordingly. The data input is a pickled pandas dataframe containing both the embeddings and the caption. The CSV parameters are used for the column names. + ### Multi-GPU and Beyond This code has been battle tested up to 1024 A100s and offers a variety of solutions