diff --git a/src/open_clip/hf_model.py b/src/open_clip/hf_model.py
index 040926241..83b3deb82 100644
--- a/src/open_clip/hf_model.py
+++ b/src/open_clip/hf_model.py
@@ -85,19 +85,13 @@ def __init__(
             raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
         if config is None:
             self.config = AutoConfig.from_pretrained(model_name_or_path)
-            if pretrained:
-                # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
-                if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
-                    self.transformer = AutoModel.from_pretrained(model_name_or_path)
-                    self.transformer = self.transformer.encoder
-                else:
-                    self.transformer = AutoModel.from_pretrained(model_name_or_path, add_pooling_layer=uses_transformer_pooler)
+            create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (AutoModel.from_config, self.config)
+            # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
+            if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
+                self.transformer = create_func(model_args)
+                self.transformer = self.transformer.encoder
             else:
-                if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
-                    self.transformer = AutoModel.from_config(self.config)
-                    self.transformer = self.transformer.encoder
-                else:
-                    self.transformer = AutoModel.from_config(self.config, add_pooling_layer=uses_transformer_pooler)
+                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
         else:
             self.config = config
             self.transformer = AutoModel.from_config(config)
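For reference, a minimal standalone sketch of the unified creation path this diff introduces: one factory (AutoModel.from_pretrained or AutoModel.from_config) is selected up front, so the encoder-decoder branch no longer has to be duplicated for the pretrained and from-scratch cases. The helper name create_transformer is hypothetical and not part of the patch; the branch logic mirrors the patched __init__ in hf_model.py.

from transformers import AutoConfig, AutoModel

def create_transformer(model_name_or_path: str, pretrained: bool, uses_transformer_pooler: bool):
    # Hypothetical helper mirroring the refactored branch in the patched __init__.
    config = AutoConfig.from_pretrained(model_name_or_path)
    # Pick the factory and its argument once, instead of repeating the
    # encoder-decoder check in both the pretrained and from-config paths.
    create_func, model_args = (
        (AutoModel.from_pretrained, model_name_or_path)
        if pretrained
        else (AutoModel.from_config, config)
    )
    if getattr(config, "is_encoder_decoder", False):
        # Seq2seq models (e.g. T5) expose `.encoder` and do not accept
        # `add_pooling_layer`; only the bare encoder stack is kept.
        return create_func(model_args).encoder
    # Encoder-only (BERT-style) models: build the pooling layer only when
    # the configured pooler actually uses it.
    return create_func(model_args, add_pooling_layer=uses_transformer_pooler)

Both factories forward extra keyword arguments to the model constructor, which is why add_pooling_layer can be passed uniformly through create_func; this assumes a BERT-style model class that accepts that argument, which is what the non-encoder-decoder branch of the original code already assumed.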