Make clip always output dict in open_clip training (#396)
* always output_dict

* remove is_clip

* fix comment

* comment

* test train with jit

* add faster jit hack

* rewrite test

* get tests right

* can lock with jit and renaming

* rename to output_dict

* annotate
gpucce authored Jan 30, 2023
1 parent 434db0e commit 009f06d
Showing 4 changed files with 32 additions and 19 deletions.
7 changes: 7 additions & 0 deletions src/open_clip/factory.py
@@ -116,6 +116,7 @@ def create_model(
         pretrained_image: bool = False,
         pretrained_hf: bool = True,
         cache_dir: Optional[str] = None,
+        output_dict: Optional[bool] = None,
 ):
     has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
     if has_hf_hub_prefix:
@@ -215,6 +216,10 @@ def create_model(
         model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
         model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD

+    # to always output dict even if it is clip
+    if output_dict and hasattr(model, "output_dict"):
+        model.output_dict = True
+
     if jit:
         model = torch.jit.script(model)

@@ -259,6 +264,7 @@ def create_model_and_transforms(
         image_std: Optional[Tuple[float, ...]] = None,
         aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
         cache_dir: Optional[str] = None,
+        output_dict: Optional[bool] = None,
 ):
     model = create_model(
         model_name,
@@ -273,6 +279,7 @@ def create_model_and_transforms(
         pretrained_image=pretrained_image,
         pretrained_hf=pretrained_hf,
         cache_dir=cache_dir,
+        output_dict=output_dict,
     )

     image_mean = image_mean or getattr(model.visual, 'image_mean', None)
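
For reference, a hedged usage sketch of the new parameter (the model name is illustrative; the dict keys match postprocess_clip_output in train.py below):

import open_clip

# output_dict=True makes create_model set model.output_dict before the
# optional torch.jit.script call, so forward returns a dict rather than a tuple
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    'ViT-B-32',
    output_dict=True,
)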
1 change: 1 addition & 0 deletions src/training/main.py
@@ -221,6 +221,7 @@ def main(args):
         image_mean=args.image_mean,
         image_std=args.image_std,
         aug_cfg=args.aug_cfg,
+        output_dict=True,
     )
     random_seed(args.seed, args.rank)

19 changes: 0 additions & 19 deletions src/training/train.py
@@ -38,9 +38,6 @@ def update(self, val, n=1):
         self.count += n
         self.avg = self.sum / self.count

-def is_clip(model):
-    return type(model) in [CLIP, CustomTextCLIP]
-
 def postprocess_clip_output(model_out):
     return {
         "image_features": model_out[0],
@@ -98,10 +95,6 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args
         if args.accum_freq == 1:
             with autocast():
                 model_out = model(images, texts)
-                # for clip if it does not output_dict
-                module = model.module if type(model) == DistributedDataParallel else model
-                if is_clip(module) and not module.output_dict:
-                    model_out = postprocess_clip_output(model_out)
                 logit_scale = model_out["logit_scale"]
                 losses = loss(**model_out, output_dict=True)

@@ -114,10 +107,6 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args
             with torch.no_grad():
                 with autocast():
                     model_out = model(images, texts)
-                    # for clip if it does not output_dict
-                    module = model.module if type(model) == DistributedDataParallel else model
-                    if is_clip(module) and not module.output_dict:
-                        model_out = postprocess_clip_output(model_out)
                     model_out.pop("logit_scale")
                     for key, val in model_out.items():
                         if key in accum_features:
@@ -142,10 +131,6 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args
                 texts = accum_texts[j]
                 with autocast():
                     model_out = model(images, texts, output_dict=True)
-                    # for clip if it does not output_dict
-                    module = model.module if type(model) == DistributedDataParallel else model
-                    if is_clip(module) and not model.output_dict:
-                        model_out = postprocess_clip_output(model_out)
                     logit_scale = model_out.pop("logit_scale")
                     for key, val in accum_features:
                         accumulated = accum_features[key]
@@ -267,10 +252,6 @@ def evaluate(model, data, epoch, args, tb_writer=None):

                 with autocast():
                     model_out = model(images, texts, output_dict=True)
-                    # for clip if it does not output_dict
-                    module = model.module if type(model) == DistributedDataParallel else model
-                    if is_clip(module) and not module.output_dict:
-                        model_out = postprocess_clip_output(model_out)
                     image_features = model_out["image_features"]
                     text_features = model_out["text_features"]
                     logit_scale = model_out["logit_scale"]
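
With is_clip and the postprocess branches removed, the per-batch step reduces to roughly the following. A minimal sketch, assuming the names used in the diff above (loss is e.g. open_clip.ClipLoss; autocast stands in for the training code's precision helper); not the full train_one_epoch:

import torch

def train_step(model, loss, images, texts, autocast=torch.cuda.amp.autocast):
    # model(...) now always returns a dict with image_features, text_features
    # and logit_scale, so no is_clip()/postprocess_clip_output branch is needed
    with autocast():
        model_out = model(images, texts)
        losses = loss(**model_out, output_dict=True)  # dict of named losses
    return sum(losses.values())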
24 changes: 24 additions & 0 deletions tests/test_training_simple.py
@@ -8,6 +8,12 @@

 os.environ["CUDA_VISIBLE_DEVICES"] = ""

+if hasattr(torch._C, '_jit_set_profiling_executor'):
+    # legacy executor is too slow to compile large models for unit tests
+    # no need for the fusion performance here
+    torch._C._jit_set_profiling_executor(True)
+    torch._C._jit_set_profiling_mode(False)
+
 @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
 def test_training():
     main([
@@ -77,3 +83,21 @@ def test_training_unfreezing_vit():
         '--lock-image',
         '--lock-image-unlocked-groups', '5'
     ])
+
+
+@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
+def test_training_clip_with_jit():
+    main([
+        '--save-frequency', '1',
+        '--zeroshot-frequency', '1',
+        '--dataset-type', "synthetic",
+        '--train-num-samples', '16',
+        '--warmup', '1',
+        '--batch-size', '4',
+        '--lr', '1e-3',
+        '--wd', '0.1',
+        '--epochs', '1',
+        '--workers', '2',
+        '--model', 'ViT-B-32',
+        '--torchscript'
+    ])
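
And a hedged sketch (not part of the commit) of what the new JIT test exercises: output_dict is applied in create_model before torch.jit.script, so a scripted CLIP still returns the dict keys the training loop reads. The tokenize call and random images here are illustrative only:

import torch
import open_clip

model = open_clip.create_model('ViT-B-32', output_dict=True)
scripted = torch.jit.script(model)  # roughly what jit=True / --torchscript does

images = torch.randn(2, 3, 224, 224)
texts = open_clip.tokenize(["a diagram", "a photo of a dog"])
out = scripted(images, texts)
assert set(out) == {"image_features", "text_features", "logit_scale"}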
