Skip to content

Commit

Permalink
hf_model.py: small refactoring (#242)
Browse files Browse the repository at this point in the history
* hf_model.py: encoder-decoder fix partial locking

* refactor ifs, less ifs

* try roms method

* fix

* back to this

* update

* revert
  • Loading branch information
iejMac authored Nov 23, 2022
1 parent c4b6dc9 commit bb6e834
Showing 1 changed file with 6 additions and 12 deletions.
18 changes: 6 additions & 12 deletions src/open_clip/hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,13 @@ def __init__(
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
if config is None:
self.config = AutoConfig.from_pretrained(model_name_or_path)
if pretrained:
# TODO: do all model configs have this attribute? PretrainedConfig does so yes??
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
self.transformer = AutoModel.from_pretrained(model_name_or_path)
self.transformer = self.transformer.encoder
else:
self.transformer = AutoModel.from_pretrained(model_name_or_path, add_pooling_layer=uses_transformer_pooler)
create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (AutoModel.from_config, self.config)
# TODO: do all model configs have this attribute? PretrainedConfig does so yes??
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
self.transformer = create_func(model_args)
self.transformer = self.transformer.encoder
else:
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
self.transformer = AutoModel.from_config(self.config)
self.transformer = self.transformer.encoder
else:
self.transformer = AutoModel.from_config(self.config, add_pooling_layer=uses_transformer_pooler)
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
else:
self.config = config
self.transformer = AutoModel.from_config(config)
Expand Down

0 comments on commit bb6e834

Please sign in to comment.