Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ability to load pre-computed vectors from Pickle file #134

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ python -m training.main \
Note: `imagenet-val` is the path to the *validation* set of ImageNet for zero-shot evaluation, not the training set!
You can remove this argument if you do not want to perform zero-shot evaluation on ImageNet throughout training. Note that the `val` folder should contain subfolders. If it does not, please use [this script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh).

### Pre-computed Embeddings

Pre-computed embeddings must be configured in both the model and the dataset. Set the model to `precomputed` and adjust `embed_dim` in its config to match the dimension of your embeddings. The data input is a pickled pandas DataFrame (a `.pkl` file) containing both the embeddings and the captions. The CSV parameters (`--csv-img-key` and `--csv-caption-key`) are used for the embedding and caption column names.

### Multi-GPU and Beyond

This code has been battle tested up to 1024 A100s and offers a variety of solutions
Expand Down
12 changes: 11 additions & 1 deletion src/open_clip/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,13 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
return x


class Precomputed(nn.Module):
    """Pass-through "visual tower" for inputs that are already embeddings.

    The module has no parameters: ``forward`` hands its input back unchanged.

    Args:
        image_size: Value stored on the instance as ``image_size``.
            Defaults to 1.
    """

    def __init__(self, image_size=1):
        super().__init__()
        self.image_size = image_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` itself, without copying or modifying it."""
        return x

class VisualTransformer(nn.Module):
def __init__(
self, image_size: int, patch_size: int, width: int, layers: int, heads: int, mlp_ratio: float,
Expand Down Expand Up @@ -315,6 +322,7 @@ class CLIPVisionCfg:
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
precomputed: bool = False


@dataclass
Expand Down Expand Up @@ -366,6 +374,8 @@ def __init__(
image_size=vision_cfg.image_size,
width=vision_cfg.width
)
elif vision_cfg.precomputed:
self.visual = Precomputed()
else:
vision_heads = vision_cfg.width // vision_cfg.head_width
self.visual = VisualTransformer(
Expand All @@ -378,7 +388,7 @@ def __init__(
output_dim=embed_dim,
act_layer=act_layer,
)

self.transformer = Transformer(
width=text_cfg.width,
layers=text_cfg.layers,
Expand Down
13 changes: 13 additions & 0 deletions src/open_clip/model_configs/precomputed.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"embed_dim": 512,
"vision_cfg": {
"precomputed": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
46 changes: 46 additions & 0 deletions src/training/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,24 @@ def __getitem__(self, idx):
texts = tokenize([str(self.captions[idx])])[0]
return images, texts

class PrecomputedDataset(Dataset):
    """Dataset of pre-computed embeddings paired with text captions.

    The data is read from a pickled pandas DataFrame. One column holds the
    embedding of each sample and another holds its caption.

    Args:
        input_filename: Path of the pickled DataFrame to load.
        transforms: Stored on the instance as ``transforms``; this class
            does not apply it to the embeddings.
        emb_key: Name of the DataFrame column holding the embeddings.
        caption_key: Name of the DataFrame column holding the captions.

    Raises:
        KeyError: If ``emb_key`` or ``caption_key`` is not a column of the
            loaded DataFrame.
    """

    def __init__(self, input_filename, transforms, emb_key, caption_key):
        logging.debug(f'Loading pickle data from {input_filename}.')
        # NOTE(security): unpickling can execute arbitrary code. Only load
        # pickle files that come from a trusted source.
        df = pd.read_pickle(input_filename)

        self.embeddings = df[emb_key].tolist()
        self.captions = df[caption_key].tolist()
        self.transforms = transforms
        logging.debug('Done loading data.')

    @staticmethod
    def _to_float32(embedding):
        """Return ``embedding`` as a new float32 NumPy array.

        Accepts anything ``numpy.array`` can convert (ndarray, list, tuple),
        so DataFrame cells stored as plain Python sequences work too, not
        only cells that already are ndarrays.
        """
        # Function-scope import keeps this block self-contained regardless
        # of what the module imports at the top of the file.
        import numpy as np

        return np.array(embedding, dtype=np.float32)

    def __len__(self):
        """Return the number of (embedding, caption) samples."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return the float32 embedding and tokenized caption at ``idx``."""
        embeddings = self._to_float32(self.embeddings[idx])
        # tokenize() takes a list of strings; take the single resulting row.
        texts = tokenize([str(self.captions[idx])])[0]
        return embeddings, texts


class SharedEpoch:
def __init__(self, epoch: int = 0):
Expand Down Expand Up @@ -418,6 +436,32 @@ def get_csv_dataset(args, preprocess_fn, is_train, epoch=0):

return DataInfo(dataloader, sampler)

def get_precomputed_dataset(args, preprocess_fn, is_train, epoch=0):
    """Build a ``DataInfo`` wrapping a DataLoader over precomputed embeddings.

    Args:
        args: Namespace providing ``train_data``, ``val_data``,
            ``csv_img_key``, ``csv_caption_key``, ``distributed``,
            ``batch_size`` and ``workers``.
        preprocess_fn: Passed to ``PrecomputedDataset`` as its transforms.
        is_train: Selects ``args.train_data`` (True) or ``args.val_data``
            (False); also controls shuffling and ``drop_last``.
        epoch: Accepted but not used by this function.

    Returns:
        ``DataInfo(dataloader, sampler)``; the dataloader additionally
        carries ``num_samples`` and ``num_batches`` attributes.
    """
    if is_train:
        input_filename = args.train_data
    else:
        input_filename = args.val_data
    assert input_filename

    dataset = PrecomputedDataset(
        input_filename,
        preprocess_fn,
        emb_key=args.csv_img_key,
        caption_key=args.csv_caption_key,
    )
    num_samples = len(dataset)

    # A distributed sampler is only used for distributed training; when one
    # is present it takes over the ordering, so the loader must not shuffle.
    sampler = None
    if args.distributed and is_train:
        sampler = DistributedSampler(dataset)
    shuffle = is_train and sampler is None

    loader_kwargs = dict(
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
    )
    dataloader = DataLoader(dataset, **loader_kwargs)

    # Attach bookkeeping attributes directly on the loader object.
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)


def get_dataset_fn(data_path, dataset_type):
if dataset_type == "webdataset":
Expand All @@ -430,6 +474,8 @@ def get_dataset_fn(data_path, dataset_type):
return get_csv_dataset
elif ext in ['tar']:
return get_wds_dataset
elif ext in ['pkl']:
return get_precomputed_dataset
else:
raise ValueError(
f"Tried to figure out dataset type, but failed for extention {ext}.")
Expand Down