@@ -63,7 +63,7 @@ def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelTy
6363 try :
6464 config = self ._load_diffusers_config (model_path , config_name = "config.json" )
6565 if class_name := config .get ("_class_name" ):
66- result = self ._hf_definition_to_type (module = "diffusers" , class_name = class_name )
66+ result = self ._hf_definition_to_type (module = "diffusers" , class_name = class_name , model_name = model_path . name )
6767 elif class_name := config .get ("architectures" ):
6868 result = self ._hf_definition_to_type (module = "transformers" , class_name = class_name [0 ])
6969 else :
@@ -74,19 +74,19 @@ def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelTy
7474 return result
7575
7676 # TO DO: Add exception handling
77- def _hf_definition_to_type (self , module : str , class_name : str ) -> ModelMixin : # fix with correct type
77+ def _hf_definition_to_type (self , module : str , class_name : str , model_name : Optional [ str ] = None ) -> ModelMixin : # fix with correct type
7878 if module in [
7979 "diffusers" ,
8080 "transformers" ,
8181 "invokeai.backend.quantization.fast_quantized_transformers_model" ,
8282 "invokeai.backend.quantization.fast_quantized_diffusion_model" ,
8383 "transformer_bria" ,
8484 ]:
85- if module == "transformer_bria" :
86- module = "invokeai.backend.bria.transformer_bria"
87- elif class_name == "BriaTransformer2DModel" :
85+ if model_name == "BRIA-3.2-ControlNet-Union" :
8886 class_name = "BriaControlNetModel"
8987 module = "invokeai.backend.bria.controlnet_bria"
88+ elif module == "transformer_bria" or class_name == "BriaTransformer2DModel" :
89+ module = "invokeai.backend.bria.transformer_bria"
9090 res_type = sys .modules [module ]
9191 else :
9292 res_type = sys .modules ["diffusers" ].pipelines
0 commit comments