Fix partial unfreezing for encoder-decoder HF models (#243)
* Fix freezing for T5

* Improve by using the arch dict
rom1504 authored Nov 23, 2022
1 parent 82cc506 commit c4b6dc9
Showing 3 changed files with 30 additions and 2 deletions.
6 changes: 6 additions & 0 deletions src/open_clip/hf_configs.py
@@ -8,6 +8,8 @@
             "width": "hidden_size",
             "heads": "num_attention_heads",
             "layers": "num_hidden_layers",
+            "layer_attr": "layer",
+            "token_embeddings_attr": "embeddings"
         },
         "pooler": "mean_pooler",
     },
@@ -19,6 +21,8 @@
             "width": "hidden_size",
             "heads": "num_attention_heads",
             "layers": "num_hidden_layers",
+            "layer_attr": "layer",
+            "token_embeddings_attr": "embeddings"
         },
         "pooler": "mean_pooler",
     },
@@ -33,6 +37,8 @@
             "width": "d_model",
             "heads": "num_heads",
             "layers": "num_layers",
+            "layer_attr": "block",
+            "token_embeddings_attr": "embed_tokens"
         },
         "pooler": "mean_pooler",
     },
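The two new keys are needed because Hugging Face architectures expose the same modules under different attribute names: BERT/RoBERTa-style encoders keep their transformer blocks in encoder.layer and their token embeddings in embeddings, while T5/mT5 keep them in block and embed_tokens. A hedged illustration (checkpoints chosen only for brevity, not part of this commit) of what each config name resolves to:

```python
# Hedged illustration (not part of this commit): what the new "layer_attr" and
# "token_embeddings_attr" values point at on real Hugging Face models.
from transformers import AutoModel

roberta = AutoModel.from_pretrained("roberta-base")
print(type(roberta.encoder.layer))    # ModuleList        -> "layer_attr": "layer"
print(type(roberta.embeddings))       # RobertaEmbeddings -> "token_embeddings_attr": "embeddings"

t5_encoder = AutoModel.from_pretrained("t5-small").encoder
print(type(t5_encoder.block))         # ModuleList        -> "layer_attr": "block"
print(type(t5_encoder.embed_tokens))  # Embedding         -> "token_embeddings_attr": "embed_tokens"
```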
7 changes: 5 additions & 2 deletions src/open_clip/hf_model.py
@@ -133,8 +133,11 @@ def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
                 p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
             return
 
-        n_layers = len(self.transformer.encoder.layer) - unlocked_layers - 1 # -1 for embeddings
-        modules = [self.transformer.embeddings, self.transformer.encoder.layer[:n_layers]]
+        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
+        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
+        n_layers = len(layer_list) - unlocked_layers - 1 # -1 for embeddings
+        embeddings = getattr(self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
+        modules = [embeddings, layer_list[:n_layers]]
         for module in modules:
             for n, p in module.named_parameters():
                 p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
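A hedged usage sketch of the patched method follows. The constructor arguments (checkpoint name and output_dim) are assumptions for illustration; lock() and its keyword arguments are taken from the hunk above.

```python
# Hedged usage sketch: constructor arguments are assumptions; lock() is the
# method patched above.
from open_clip.hf_model import HFTextEncoder

text_tower = HFTextEncoder("google/mt5-base", output_dim=512)

# The old code indexed self.transformer.encoder.layer and self.transformer.embeddings,
# attributes that T5-style towers do not expose; the arch_dict lookup resolves
# .block and .embed_tokens instead, so partial unlocking now works for mT5.
text_tower.lock(unlocked_layers=2, freeze_layer_norm=True)

n_trainable = sum(p.numel() for p in text_tower.parameters() if p.requires_grad)
print(f"trainable parameters after partial unlock: {n_trainable}")
```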
19 changes: 19 additions & 0 deletions tests/test_training_simple.py
@@ -24,3 +24,22 @@ def test_training():
         '--model', 'RN50'
     ])
 
+@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
+def test_training_mt5():
+    main([
+        '--save-frequency', '1',
+        '--zeroshot-frequency', '1',
+        '--dataset-type', "synthetic",
+        '--train-num-samples', '16',
+        '--warmup', '1',
+        '--batch-size', '4',
+        '--lr', '1e-3',
+        '--wd', '0.1',
+        '--epochs', '1',
+        '--workers', '2',
+        '--model', 'mt5-base-ViT-B-32',
+        '--lock-text',
+        '--lock-text-unlocked-layers', '2'
+    ])
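The two new flags drive exactly the code path fixed above. As a sanity check, here is a standalone hedged sketch, simplified relative to lock() and not part of this commit, of the same kind of partial freeze written directly against transformers:

```python
# Standalone hedged sketch (simplified, not the repo's code): freeze an mT5
# encoder except its last two blocks, mirroring --lock-text-unlocked-layers 2.
from transformers import MT5EncoderModel

model = MT5EncoderModel.from_pretrained("google/mt5-small")
blocks = model.encoder.block              # mT5 layers live in .block, not .layer
unlocked = 2

for p in model.parameters():              # freeze everything, embeddings included
    p.requires_grad = False
for p in blocks[len(blocks) - unlocked:].parameters():
    p.requires_grad = True                # re-enable only the last two blocks

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {trainable}")
```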

