Skip to content

Commit 13001ee

Browse files
authored
Bugfix in IPAdapterFaceID (#6835)
1 parent 65329ae commit 13001ee

File tree

1 file changed

+37
-44
lines changed

1 file changed

+37
-44
lines changed

examples/community/ip_adapter_face_id.py

Lines changed: 37 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,22 @@ def __call__(
104104
):
105105
residual = hidden_states
106106

107+
# separate ip_hidden_states from encoder_hidden_states
108+
if encoder_hidden_states is not None:
109+
if isinstance(encoder_hidden_states, tuple):
110+
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
111+
else:
112+
deprecation_message = (
113+
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
114+
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
115+
)
116+
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
117+
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
118+
encoder_hidden_states, ip_hidden_states = (
119+
encoder_hidden_states[:, :end_pos, :],
120+
[encoder_hidden_states[:, end_pos:, :]],
121+
)
122+
107123
if attn.spatial_norm is not None:
108124
hidden_states = attn.spatial_norm(hidden_states, temb)
109125

@@ -125,15 +141,8 @@ def __call__(
125141

126142
if encoder_hidden_states is None:
127143
encoder_hidden_states = hidden_states
128-
else:
129-
# get encoder_hidden_states, ip_hidden_states
130-
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
131-
encoder_hidden_states, ip_hidden_states = (
132-
encoder_hidden_states[:, :end_pos, :],
133-
encoder_hidden_states[:, end_pos:, :],
134-
)
135-
if attn.norm_cross:
136-
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
144+
elif attn.norm_cross:
145+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
137146

138147
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
139148
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
@@ -233,6 +242,22 @@ def __call__(
233242
):
234243
residual = hidden_states
235244

245+
# separate ip_hidden_states from encoder_hidden_states
246+
if encoder_hidden_states is not None:
247+
if isinstance(encoder_hidden_states, tuple):
248+
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
249+
else:
250+
deprecation_message = (
251+
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
252+
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
253+
)
254+
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
255+
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
256+
encoder_hidden_states, ip_hidden_states = (
257+
encoder_hidden_states[:, :end_pos, :],
258+
[encoder_hidden_states[:, end_pos:, :]],
259+
)
260+
236261
if attn.spatial_norm is not None:
237262
hidden_states = attn.spatial_norm(hidden_states, temb)
238263

@@ -259,15 +284,8 @@ def __call__(
259284

260285
if encoder_hidden_states is None:
261286
encoder_hidden_states = hidden_states
262-
else:
263-
# get encoder_hidden_states, ip_hidden_states
264-
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
265-
encoder_hidden_states, ip_hidden_states = (
266-
encoder_hidden_states[:, :end_pos, :],
267-
encoder_hidden_states[:, end_pos:, :],
268-
)
269-
if attn.norm_cross:
270-
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
287+
elif attn.norm_cross:
288+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
271289

272290
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
273291
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
@@ -951,30 +969,6 @@ def encode_prompt(
951969

952970
return prompt_embeds, negative_prompt_embeds
953971

954-
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
955-
dtype = next(self.image_encoder.parameters()).dtype
956-
957-
if not isinstance(image, torch.Tensor):
958-
image = self.feature_extractor(image, return_tensors="pt").pixel_values
959-
960-
image = image.to(device=device, dtype=dtype)
961-
if output_hidden_states:
962-
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
963-
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
964-
uncond_image_enc_hidden_states = self.image_encoder(
965-
torch.zeros_like(image), output_hidden_states=True
966-
).hidden_states[-2]
967-
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
968-
num_images_per_prompt, dim=0
969-
)
970-
return image_enc_hidden_states, uncond_image_enc_hidden_states
971-
else:
972-
image_embeds = self.image_encoder(image).image_embeds
973-
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
974-
uncond_image_embeds = torch.zeros_like(image_embeds)
975-
976-
return image_embeds, uncond_image_embeds
977-
978972
def run_safety_checker(self, image, device, dtype):
979973
if self.safety_checker is None:
980974
has_nsfw_concept = None
@@ -1302,7 +1296,6 @@ def __call__(
13021296
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
13031297
image_embeds (`torch.FloatTensor`, *optional*):
13041298
Pre-generated image embeddings.
1305-
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
13061299
output_type (`str`, *optional*, defaults to `"pil"`):
13071300
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
13081301
return_dict (`bool`, *optional*, defaults to `True`):
@@ -1411,7 +1404,7 @@ def __call__(
14111404
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
14121405

14131406
if image_embeds is not None:
1414-
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0).to(
1407+
image_embeds = torch.stack([image_embeds] * num_images_per_prompt, dim=0).to(
14151408
device=device, dtype=prompt_embeds.dtype
14161409
)
14171410
negative_image_embeds = torch.zeros_like(image_embeds)

0 commit comments

Comments (0)