@@ -1079,6 +1079,33 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
1079
1079
use_onnx = kwargs .pop ("use_onnx" , None )
1080
1080
load_connected_pipeline = kwargs .pop ("load_connected_pipeline" , False )
1081
1081
1082
+ if low_cpu_mem_usage and not is_accelerate_available ():
1083
+ low_cpu_mem_usage = False
1084
+ logger .warning (
1085
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
1086
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
1087
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n ```\n pip"
1088
+ " install accelerate\n ```\n ."
1089
+ )
1090
+
1091
+ if device_map is not None and not is_torch_version (">=" , "1.9.0" ):
1092
+ raise NotImplementedError (
1093
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1094
+ " `device_map=None`."
1095
+ )
1096
+
1097
+ if low_cpu_mem_usage is True and not is_torch_version (">=" , "1.9.0" ):
1098
+ raise NotImplementedError (
1099
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
1100
+ " `low_cpu_mem_usage=False`."
1101
+ )
1102
+
1103
+ if low_cpu_mem_usage is False and device_map is not None :
1104
+ raise ValueError (
1105
+ f"You cannot set `low_cpu_mem_usage` to False while using device_map={ device_map } for loading and"
1106
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
1107
+ )
1108
+
1082
1109
# 1. Download the checkpoints and configs
1083
1110
# use snapshot download here to get it working from from_pretrained
1084
1111
if not os .path .isdir (pretrained_model_name_or_path ):
@@ -1211,33 +1238,6 @@ def load_module(name, value):
1211
1238
f"Keyword arguments { unused_kwargs } are not expected by { pipeline_class .__name__ } and will be ignored."
1212
1239
)
1213
1240
1214
- if low_cpu_mem_usage and not is_accelerate_available ():
1215
- low_cpu_mem_usage = False
1216
- logger .warning (
1217
- "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
1218
- " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
1219
- " `accelerate` for faster and less memory-intense model loading. You can do so with: \n ```\n pip"
1220
- " install accelerate\n ```\n ."
1221
- )
1222
-
1223
- if device_map is not None and not is_torch_version (">=" , "1.9.0" ):
1224
- raise NotImplementedError (
1225
- "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1226
- " `device_map=None`."
1227
- )
1228
-
1229
- if low_cpu_mem_usage is True and not is_torch_version (">=" , "1.9.0" ):
1230
- raise NotImplementedError (
1231
- "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
1232
- " `low_cpu_mem_usage=False`."
1233
- )
1234
-
1235
- if low_cpu_mem_usage is False and device_map is not None :
1236
- raise ValueError (
1237
- f"You cannot set `low_cpu_mem_usage` to False while using device_map={ device_map } for loading and"
1238
- " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
1239
- )
1240
-
1241
1241
# import it here to avoid circular import
1242
1242
from diffusers import pipelines
1243
1243
0 commit comments