From bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b Mon Sep 17 00:00:00 2001 From: Maciej Kilian Date: Tue, 22 Nov 2022 19:30:52 -0800 Subject: [PATCH] hf_model.py: small refactoring (#242) * hf_model.py: encoder-decoder fix partial locking * refactor ifs, less ifs * try roms method * fix * back to this * update * revert --- src/open_clip/hf_model.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/open_clip/hf_model.py b/src/open_clip/hf_model.py index 040926241..83b3deb82 100644 --- a/src/open_clip/hf_model.py +++ b/src/open_clip/hf_model.py @@ -85,19 +85,13 @@ def __init__( raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models") if config is None: self.config = AutoConfig.from_pretrained(model_name_or_path) - if pretrained: - # TODO: do all model configs have this attribute? PretrainedConfig does so yes?? - if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder: - self.transformer = AutoModel.from_pretrained(model_name_or_path) - self.transformer = self.transformer.encoder - else: - self.transformer = AutoModel.from_pretrained(model_name_or_path, add_pooling_layer=uses_transformer_pooler) + create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (AutoModel.from_config, self.config) + # TODO: do all model configs have this attribute? PretrainedConfig does so yes?? + if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder: + self.transformer = create_func(model_args) + self.transformer = self.transformer.encoder else: - if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder: - self.transformer = AutoModel.from_config(self.config) - self.transformer = self.transformer.encoder - else: - self.transformer = AutoModel.from_config(self.config, add_pooling_layer=uses_transformer_pooler) + self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler) else: self.config = config self.transformer = AutoModel.from_config(config)