Make clip always output dict in open_clip training #396

Merged · 12 commits · merged Jan 30, 2023
7 changes: 7 additions & 0 deletions src/open_clip/factory.py
@@ -116,6 +116,7 @@ def create_model(
pretrained_image: bool = False,
pretrained_hf: bool = True,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
):
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
if has_hf_hub_prefix:
@@ -215,6 +216,10 @@ def create_model(
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD

# to always output dict even if it is clip
if output_dict and hasattr(model, "output_dict"):
model.output_dict = True

if jit:
model = torch.jit.script(model)

@@ -259,6 +264,7 @@ def create_model_and_transforms(
image_std: Optional[Tuple[float, ...]] = None,
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
):
model = create_model(
model_name,
@@ -273,6 +279,7 @@ def create_model_and_transforms(
pretrained_image=pretrained_image,
pretrained_hf=pretrained_hf,
cache_dir=cache_dir,
output_dict=output_dict,
)

image_mean = image_mean or getattr(model.visual, 'image_mean', None)
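Taken together, the factory changes let a caller opt into dict outputs at model-creation time. A minimal usage sketch (random weights and dummy inputs; the output keys are taken from the training code further down, everything else is illustrative):

```python
import torch
import open_clip

# output_dict=True only takes effect if the model exposes an `output_dict`
# attribute (see the hasattr guard above); for CLIP / CustomTextCLIP it makes
# forward() return a dict instead of a tuple.
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-32', output_dict=True)

images = torch.randn(2, 3, 224, 224)                 # dummy image batch
texts = open_clip.tokenize(["a diagram", "a dog"])   # dummy captions

with torch.no_grad():
    out = model(images, texts)

print(out["image_features"].shape, out["text_features"].shape, out["logit_scale"])
```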
1 change: 1 addition & 0 deletions src/training/main.py
@@ -221,6 +221,7 @@ def main(args):
image_mean=args.image_mean,
image_std=args.image_std,
aug_cfg=args.aug_cfg,
output_dict=True,
)
random_seed(args.seed, args.rank)

19 changes: 0 additions & 19 deletions src/training/train.py
@@ -38,9 +38,6 @@ def update(self, val, n=1):
self.count += n
self.avg = self.sum / self.count

def is_clip(model):
return type(model) in [CLIP, CustomTextCLIP]

def postprocess_clip_output(model_out):
return {
"image_features": model_out[0],
@@ -98,10 +95,6 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args
if args.accum_freq == 1:
with autocast():
model_out = model(images, texts)
# for clip if it does not output_dict
module = model.module if type(model) == DistributedDataParallel else model
if is_clip(module) and not module.output_dict:
model_out = postprocess_clip_output(model_out)
logit_scale = model_out["logit_scale"]
losses = loss(**model_out, output_dict=True)

@@ -114,10 +107,6 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args
with torch.no_grad():
with autocast():
model_out = model(images, texts)
# for clip if it does not output_dict
module = model.module if type(model) == DistributedDataParallel else model
if is_clip(module) and not module.output_dict:
model_out = postprocess_clip_output(model_out)
model_out.pop("logit_scale")
for key, val in model_out.items():
if key in accum_features:
@@ -142,10 +131,6 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args
texts = accum_texts[j]
with autocast():
model_out = model(images, texts, output_dict=True)
# for clip if it does not output_dict
module = model.module if type(model) == DistributedDataParallel else model
if is_clip(module) and not model.output_dict:
model_out = postprocess_clip_output(model_out)
logit_scale = model_out.pop("logit_scale")
for key, val in accum_features:
accumulated = accum_features[key]
@@ -267,10 +252,6 @@ def evaluate(model, data, epoch, args, tb_writer=None):

with autocast():
model_out = model(images, texts, output_dict=True)
# for clip if it does not output_dict
module = model.module if type(model) == DistributedDataParallel else model
if is_clip(module) and not module.output_dict:
model_out = postprocess_clip_output(model_out)
image_features = model_out["image_features"]
text_features = model_out["text_features"]
logit_scale = model_out["logit_scale"]
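With every model created through the training path now returning a dict, the per-step logic reduces to the pattern below. A condensed sketch of the post-PR flow (autocast, gradient scaling, DDP, and gradient accumulation omitted; `train_step` is a hypothetical helper, and the `loss(**model_out, output_dict=True)` call follows the diff above):

```python
def train_step(model, loss, optimizer, images, texts):
    # forward() always returns a dict now, so no is_clip()/postprocess branch
    model_out = model(images, texts)
    logit_scale = model_out["logit_scale"]        # kept around for logging

    # the loss object (ClipLoss in open_clip) receives image_features /
    # text_features / logit_scale as kwargs and, with output_dict=True,
    # returns a dict of named loss terms
    losses = loss(**model_out, output_dict=True)
    total_loss = sum(losses.values())

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    return total_loss.item(), logit_scale.item()
```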
24 changes: 24 additions & 0 deletions tests/test_training_simple.py
@@ -8,6 +8,12 @@

os.environ["CUDA_VISIBLE_DEVICES"] = ""

if hasattr(torch._C, '_jit_set_profiling_executor'):
# legacy executor is too slow to compile large models for unit tests
# no need for the fusion performance here
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(False)

@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
def test_training():
main([
@@ -77,3 +83,21 @@ def test_training_unfreezing_vit():
'--lock-image',
'--lock-image-unlocked-groups', '5'
])


@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
def test_training_clip_with_jit():
main([
'--save-frequency', '1',
'--zeroshot-frequency', '1',
'--dataset-type', "synthetic",
'--train-num-samples', '16',
'--warmup', '1',
'--batch-size', '4',
'--lr', '1e-3',
'--wd', '0.1',
'--epochs', '1',
'--workers', '2',
'--model', 'ViT-B-32',
'--torchscript'
])
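
The new test exercises the CLIP plus `--torchscript` path end to end on a synthetic dataset. To run only this test, pytest can be invoked directly; a small sketch assuming the repository root as the working directory:

```python
import pytest

# select only the new torchscript training test by name, stop on first failure
pytest.main(["tests/test_training_simple.py", "-k", "test_training_clip_with_jit", "-x"])
```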