@@ -104,6 +104,22 @@ def __call__(
104
104
):
105
105
residual = hidden_states
106
106
107
+ # separate ip_hidden_states from encoder_hidden_states
108
+ if encoder_hidden_states is not None :
109
+ if isinstance (encoder_hidden_states , tuple ):
110
+ encoder_hidden_states , ip_hidden_states = encoder_hidden_states
111
+ else :
112
+ deprecation_message = (
113
+ "You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
114
+ " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
115
+ )
116
+ deprecate ("encoder_hidden_states not a tuple" , "1.0.0" , deprecation_message , standard_warn = False )
117
+ end_pos = encoder_hidden_states .shape [1 ] - self .num_tokens [0 ]
118
+ encoder_hidden_states , ip_hidden_states = (
119
+ encoder_hidden_states [:, :end_pos , :],
120
+ [encoder_hidden_states [:, end_pos :, :]],
121
+ )
122
+
107
123
if attn .spatial_norm is not None :
108
124
hidden_states = attn .spatial_norm (hidden_states , temb )
109
125
@@ -125,15 +141,8 @@ def __call__(
125
141
126
142
if encoder_hidden_states is None :
127
143
encoder_hidden_states = hidden_states
128
- else :
129
- # get encoder_hidden_states, ip_hidden_states
130
- end_pos = encoder_hidden_states .shape [1 ] - self .num_tokens
131
- encoder_hidden_states , ip_hidden_states = (
132
- encoder_hidden_states [:, :end_pos , :],
133
- encoder_hidden_states [:, end_pos :, :],
134
- )
135
- if attn .norm_cross :
136
- encoder_hidden_states = attn .norm_encoder_hidden_states (encoder_hidden_states )
144
+ elif attn .norm_cross :
145
+ encoder_hidden_states = attn .norm_encoder_hidden_states (encoder_hidden_states )
137
146
138
147
key = attn .to_k (encoder_hidden_states ) + self .lora_scale * self .to_k_lora (encoder_hidden_states )
139
148
value = attn .to_v (encoder_hidden_states ) + self .lora_scale * self .to_v_lora (encoder_hidden_states )
@@ -233,6 +242,22 @@ def __call__(
233
242
):
234
243
residual = hidden_states
235
244
245
+ # separate ip_hidden_states from encoder_hidden_states
246
+ if encoder_hidden_states is not None :
247
+ if isinstance (encoder_hidden_states , tuple ):
248
+ encoder_hidden_states , ip_hidden_states = encoder_hidden_states
249
+ else :
250
+ deprecation_message = (
251
+ "You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
252
+ " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
253
+ )
254
+ deprecate ("encoder_hidden_states not a tuple" , "1.0.0" , deprecation_message , standard_warn = False )
255
+ end_pos = encoder_hidden_states .shape [1 ] - self .num_tokens [0 ]
256
+ encoder_hidden_states , ip_hidden_states = (
257
+ encoder_hidden_states [:, :end_pos , :],
258
+ [encoder_hidden_states [:, end_pos :, :]],
259
+ )
260
+
236
261
if attn .spatial_norm is not None :
237
262
hidden_states = attn .spatial_norm (hidden_states , temb )
238
263
@@ -259,15 +284,8 @@ def __call__(
259
284
260
285
if encoder_hidden_states is None :
261
286
encoder_hidden_states = hidden_states
262
- else :
263
- # get encoder_hidden_states, ip_hidden_states
264
- end_pos = encoder_hidden_states .shape [1 ] - self .num_tokens
265
- encoder_hidden_states , ip_hidden_states = (
266
- encoder_hidden_states [:, :end_pos , :],
267
- encoder_hidden_states [:, end_pos :, :],
268
- )
269
- if attn .norm_cross :
270
- encoder_hidden_states = attn .norm_encoder_hidden_states (encoder_hidden_states )
287
+ elif attn .norm_cross :
288
+ encoder_hidden_states = attn .norm_encoder_hidden_states (encoder_hidden_states )
271
289
272
290
key = attn .to_k (encoder_hidden_states ) + self .lora_scale * self .to_k_lora (encoder_hidden_states )
273
291
value = attn .to_v (encoder_hidden_states ) + self .lora_scale * self .to_v_lora (encoder_hidden_states )
@@ -951,30 +969,6 @@ def encode_prompt(
951
969
952
970
return prompt_embeds , negative_prompt_embeds
953
971
954
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """Encode ``image`` with ``self.image_encoder`` and build its unconditional twin.

    Args:
        image: Either a ``torch.Tensor`` (used directly) or any other input,
            which is first converted via ``self.feature_extractor(...,
            return_tensors="pt").pixel_values``.
        device: Device the pixel tensor is moved to before encoding.
        num_images_per_prompt: Repeat count applied along dim 0 with
            ``repeat_interleave``.
        output_hidden_states: When truthy, return the second-to-last entry of
            the encoder's ``hidden_states`` instead of ``image_embeds``.

    Returns:
        A ``(conditional, unconditional)`` tuple. In hidden-state mode the
        unconditional half is the encoder output for an all-zero image; in
        embedding mode it is a zero tensor shaped like the conditional
        embeddings.
    """
    # The pixel tensor is cast to the dtype of the encoder's first parameter.
    encoder_dtype = next(self.image_encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    image = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        # Embedding mode: the unconditional embeddings are plain zeros.
        cond_embeds = self.image_encoder(image).image_embeds.repeat_interleave(
            num_images_per_prompt, dim=0
        )
        return cond_embeds, torch.zeros_like(cond_embeds)

    def _repeated_penultimate_states(pixels):
        # Second-to-last hidden state, repeated once per requested image.
        states = self.image_encoder(pixels, output_hidden_states=True).hidden_states[-2]
        return states.repeat_interleave(num_images_per_prompt, dim=0)

    # Tuple elements evaluate left to right: real image first, then the zero image.
    return (
        _repeated_penultimate_states(image),
        _repeated_penultimate_states(torch.zeros_like(image)),
    )
977
-
978
972
def run_safety_checker (self , image , device , dtype ):
979
973
if self .safety_checker is None :
980
974
has_nsfw_concept = None
@@ -1302,7 +1296,6 @@ def __call__(
1302
1296
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
1303
1297
image_embeds (`torch.FloatTensor`, *optional*):
1304
1298
Pre-generated image embeddings.
1305
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1306
1299
output_type (`str`, *optional*, defaults to `"pil"`):
1307
1300
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
1308
1301
return_dict (`bool`, *optional*, defaults to `True`):
@@ -1411,7 +1404,7 @@ def __call__(
1411
1404
prompt_embeds = torch .cat ([negative_prompt_embeds , prompt_embeds ])
1412
1405
1413
1406
if image_embeds is not None :
1414
- image_embeds = image_embeds . repeat_interleave ( num_images_per_prompt , dim = 0 ).to (
1407
+ image_embeds = torch . stack ([ image_embeds ] * num_images_per_prompt , dim = 0 ).to (
1415
1408
device = device , dtype = prompt_embeds .dtype
1416
1409
)
1417
1410
negative_image_embeds = torch .zeros_like (image_embeds )
0 commit comments