Commit

support attn_mask
lldacing committed Jan 17, 2025
1 parent fa1e370 commit 6079758
Showing 6 changed files with 35 additions and 37 deletions.
55 changes: 26 additions & 29 deletions PulidFluxHook.py
@@ -29,11 +29,25 @@ def set_model_dit_patch_replace(model, patch_kwargs, key):
else:
to["patches_replace"]["dit"][key].add(pulid_patch, **patch_kwargs)

def pulid_patch(img, pulid_model=None, ca_idx=None, weight=1.0, embedding=None, mask=None):
def pulid_patch(img, pulid_model=None, ca_idx=None, weight=1.0, embedding=None, mask=None, transformer_options={}):
pulid_img = weight * pulid_model.pulid_ca[ca_idx].to(img.device)(embedding, img)
if mask is not None:
pulid_temp_attrs = transformer_options.get(PatchKeys.pulid_patch_key_attrs, {})
latent_image_shape = pulid_temp_attrs.get("latent_image_shape", None)
if latent_image_shape is not None:
bs, c, h, w = latent_image_shape
mask = comfy.sampler_helpers.prepare_mask(mask, (bs, c, h, w), img.device)
patch_size = transformer_options[PatchKeys.running_net_model].patch_size
mask = comfy.ldm.common_dit.pad_to_patch_size(mask, (patch_size, patch_size))
mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
# (b, seq_len, _) =>(b, seq_len, seq_len)
mask = mask[..., 0].unsqueeze(-1).repeat(1, 1, mask.shape[1])
del patch_size, latent_image_shape

pulid_img = pulid_img * mask

del mask

return pulid_img

class DitDoubleBlockReplace:
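
For reference, a minimal self-contained sketch of the mask reshaping that the new `pulid_patch` body performs. The shapes and patch size are assumptions, and the ComfyUI helpers (`prepare_mask`, `pad_to_patch_size`) are replaced by plain tensors; it only illustrates the shape flow, not the exact runtime values.

```python
# Hypothetical standalone sketch of the mask reshaping added to pulid_patch.
# Assumed: latent_image_shape = (1, 16, 64, 64) and a Flux-style patch_size of 2.
import torch
from einops import rearrange

bs, c, h, w = 1, 16, 64, 64      # stand-in for pulid_temp_attrs["latent_image_shape"]
patch_size = 2                   # stand-in for transformer_options[...].patch_size
mask = torch.ones(bs, c, h, w)   # stand-in for comfy.sampler_helpers.prepare_mask(...)

# 64x64 is already a multiple of patch_size, so pad_to_patch_size would be a no-op here.
# Patchify exactly as in the new code: (b, c, h, w) -> (b, seq_len, c*ph*pw)
mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

# Keep one value per token and repeat it along the sequence:
# (b, seq_len, c*ph*pw) -> (b, seq_len, seq_len), matching the comment in the diff.
mask = mask[..., 0].unsqueeze(-1).repeat(1, 1, mask.shape[1])
print(mask.shape)  # torch.Size([1, 1024, 1024]); seq_len = (64/2) * (64/2)
```
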
@@ -51,33 +65,24 @@ def add(self, callback, **kwargs):
def __call__(self, input_args, extra_options):
transformer_options = extra_options["transformer_options"]
pulid_temp_attrs = transformer_options.get(PatchKeys.pulid_patch_key_attrs, {})
sigma = pulid_temp_attrs["timesteps"][0]
sigma = pulid_temp_attrs["timesteps"].detach().cpu()[0]
out = extra_options["original_block"](input_args)
img = out['img']
temp_img = img
for i, callback in enumerate(self.callback):
if sigma <= self.kwargs[i]["sigma_start"] and sigma >= self.kwargs[i]["sigma_end"]:
mask = self.kwargs[i]['mask']
if mask is not None:
latent_image_shape = pulid_temp_attrs.get("latent_image_shape", None)
if latent_image_shape is not None:
bs, c, h, w = latent_image_shape
mask = comfy.sampler_helpers.prepare_mask(mask, (bs, c, h, w), img.device)
flux_model = transformer_options[PatchKeys.running_net_model]
patch_size = flux_model.patch_size
mask = comfy.ldm.common_dit.pad_to_patch_size(mask, (patch_size, patch_size))
mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
mask = flux_model.img_in(mask)

img = img + callback(temp_img,
pulid_model=self.kwargs[i]['pulid_model'],
ca_idx=self.kwargs[i]['ca_idx'],
weight=self.kwargs[i]['weight'],
embedding=self.kwargs[i]['embedding'],
mask = mask,
mask = self.kwargs[i]['mask'],
transformer_options=transformer_options
)

out['img'] = img

del temp_img, pulid_temp_attrs, sigma, transformer_options, img

return out
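
The callback above only fires while the current sigma lies inside the configured window. A small illustrative check of that gating, including the `.detach().cpu()[0]` extraction used by the patched line; the threshold values here are invented.

```python
# Illustrative sigma-window gating; sigma_start/sigma_end values are made up.
import torch

def in_sigma_window(sigma: float, sigma_start: float, sigma_end: float) -> bool:
    # Sigmas shrink as sampling progresses, so "start" is the larger bound.
    return sigma_end <= sigma <= sigma_start

timesteps = torch.tensor([14.61])           # stand-in for pulid_temp_attrs["timesteps"]
sigma = timesteps.detach().cpu()[0].item()  # same extraction as the patched line

print(in_sigma_window(sigma, sigma_start=30.0, sigma_end=0.0))  # True  -> callback runs
print(in_sigma_window(sigma, sigma_start=10.0, sigma_end=0.0))  # False -> callback skipped
```
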


@@ -114,29 +119,21 @@ def __call__(self, input_args, extra_options):
temp_img = real_img
for i, callback in enumerate(self.callback):
if sigma <= self.kwargs[i]["sigma_start"] and sigma >= self.kwargs[i]["sigma_end"]:
mask = self.kwargs[i]['mask']
if mask is not None:
latent_image_shape = pulid_temp_attrs.get("latent_image_shape", None)
if latent_image_shape is not None:
bs, c, h, w = latent_image_shape
mask = comfy.sampler_helpers.prepare_mask(mask, (bs, c, h, w), img.device)
flux_model = transformer_options[PatchKeys.running_net_model]
patch_size = flux_model.patch_size
mask = comfy.ldm.common_dit.pad_to_patch_size(mask, (patch_size, patch_size))
mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
mask = flux_model.img_in(mask)

real_img = real_img + callback(temp_img,
pulid_model=self.kwargs[i]['pulid_model'],
ca_idx=self.kwargs[i]['ca_idx'],
weight=self.kwargs[i]['weight'],
embedding=self.kwargs[i]['embedding'],
mask=mask,
mask=self.kwargs[i]['mask'],
transformer_options = transformer_options,
)

img = torch.cat((txt, real_img), 1)

out['img'] = img

del temp_img, pulid_temp_attrs, sigma, transformer_options, real_img, img

return out

def pulid_forward_orig(
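
In the single-block hook, `img` carries text tokens and image tokens concatenated along the sequence dimension; only the image part receives the PuLID residual before the halves are joined again with `torch.cat((txt, real_img), 1)`. A hedged sketch of that split and re-concat, with assumed token counts and hidden size:

```python
# Assumed shapes: 512 text tokens + 1024 image tokens, hidden size 3072 (Flux-like).
import torch

txt_len, img_len, hidden = 512, 1024, 3072
img = torch.randn(1, txt_len + img_len, hidden)      # combined sequence in a single block

txt, real_img = img[:, :txt_len], img[:, txt_len:]   # split off the image tokens
pulid_delta = torch.zeros_like(real_img)             # stand-in for callback(temp_img, ...)
real_img = real_img + pulid_delta                    # residual is added to image tokens only

img = torch.cat((txt, real_img), 1)                  # re-join, as out['img'] is rebuilt
print(img.shape)                                     # torch.Size([1, 1536, 3072])
```
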
5 changes: 3 additions & 2 deletions README.md
@@ -9,6 +9,7 @@ Must uninstall or disable `ComfyUI-PuLID-Flux` and other PuLID-Flux nodes before

## Preview (Image with WorkFlow)
![save api extended](examples/PuLID_with_speedup.png)
![save api extended](examples/PuLID_with_attn_mask.png)

## Install

@@ -34,11 +35,11 @@ Please see [ComfyUI-PuLID-Flux](https://github.com/balazik/ComfyUI-PuLID-Flux)
- See [ComfyUI-PuLID-Flux](https://github.com/balazik/ComfyUI-PuLID-Flux)
- ApplyPulidFlux
- Solved the model pollution problem of the original plugin ComfyUI-PuLID-Flux
- `attn_mask` may not work correctly (I have no idea how to apply it; I have tried multiple methods and the results were not satisfactory)
- `attn_mask` ~~may not work correctly (I have no idea how to apply it; I have tried multiple methods and the results were not satisfactory)~~ works now.
- If you want to use it with [TeaCache](https://github.com/ali-vilab/TeaCache), it must be placed before the [`FluxForwardOverrider` and `ApplyTeaCachePatch`](https://github.com/lldacing/ComfyUI_Patches_ll) nodes.
- If you want to use it with [Comfy-WaveSpeed](https://github.com/chengzeyi/Comfy-WaveSpeed), it must be placed before the `ApplyFBCacheOnModel` node.
- FixPulidFluxPatch (Deprecated)
- If you want to use it with [TeaCache](https://github.com/ali-vilab/TeaCache), you must link it after the `ApplyPulidFlux` node, and link the [`FluxForwardOverrider` and `ApplyTeaCachePatch`](https://github.com/lldacing/ComfyUI_Patches_ll) nodes after it.
- If you want to use it with [TeaCache](https://github.com/ali-vilab/TeaCache), you must ~~link it after the `ApplyPulidFlux` node, and~~ link the [`FluxForwardOverrider` and `ApplyTeaCachePatch`](https://github.com/lldacing/ComfyUI_Patches_ll) nodes after it.

## Thanks

5 changes: 3 additions & 2 deletions README_CN.md
@@ -8,6 +8,7 @@

## Preview (images include the workflow)
![save api extended](examples/PuLID_with_speedup.png)
![save api extended](examples/PuLID_with_attn_mask.png)

## Install

@@ -31,11 +32,11 @@
- [ComfyUI-PuLID-Flux](https://github.com/balazik/ComfyUI-PuLID-Flux)
- ApplyPulidFlux
- Solved the model pollution problem of the original plugin
- `attn_mask` may not work correctly, because I do not know how to implement it; I tried multiple approaches and none achieved the expected results
- `attn_mask` ~~may not work correctly, because I did not know how to implement it; I tried multiple approaches and none achieved the expected results~~ now works correctly.
- To speed up with [TeaCache](https://github.com/ali-vilab/TeaCache), it must be placed before [`FluxForwardOverrider` and `ApplyTeaCachePatch`](https://github.com/lldacing/ComfyUI_Patches_ll).
- To speed up with [Comfy-WaveSpeed](https://github.com/chengzeyi/Comfy-WaveSpeed), it must be placed before [`ApplyFBCacheOnModel`](https://github.com/lldacing/ComfyUI_Patches_ll).
- FixPulidFluxPatch (Deprecated)
- If you want to speed up with [TeaCache](https://github.com/ali-vilab/TeaCache), you must link it after the `ApplyPulidFlux` node, and link the [`FluxForwardOverrider` and `ApplyTeaCachePatch`](https://github.com/lldacing/ComfyUI_Patches_ll) nodes after it.
- If you want to speed up with [TeaCache](https://github.com/ali-vilab/TeaCache), you must ~~link it after the `ApplyPulidFlux` node, and~~ link the [`FluxForwardOverrider` and `ApplyTeaCachePatch`](https://github.com/lldacing/ComfyUI_Patches_ll) nodes after it.

## Thanks

Binary file added examples/PuLID_with_attn_mask.png
3 changes: 1 addition & 2 deletions pulidflux.py
@@ -197,7 +197,6 @@ def apply_pulid_flux(self, model, pulid_flux, eva_clip, face_analysis, image, we
eva_clip.to(device, dtype=dtype)
pulid_flux.to(device, dtype=dtype)

# TODO: Add masking support!
if attn_mask is not None:
if attn_mask.dim() > 3:
attn_mask = attn_mask.squeeze(-1)
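
Only the `dim() > 3` squeeze is visible in this hunk. Below is a hedged sketch of how a mask-dimension normalization toward `(B, H, W)` might look; the `dim() < 3` branch is an assumption and is not shown in the diff.

```python
# Sketch of normalizing a ComfyUI MASK-like tensor toward (B, H, W).
# The squeeze branch mirrors the diff; the unsqueeze branch is an assumption.
import torch

def normalize_mask(attn_mask: torch.Tensor) -> torch.Tensor:
    if attn_mask.dim() > 3:       # e.g. (B, H, W, 1) from an image-derived mask
        attn_mask = attn_mask.squeeze(-1)
    elif attn_mask.dim() < 3:     # e.g. (H, W): add a batch dimension
        attn_mask = attn_mask.unsqueeze(0)
    return attn_mask

print(normalize_mask(torch.ones(1, 64, 64, 1)).shape)  # torch.Size([1, 64, 64])
print(normalize_mask(torch.ones(64, 64)).shape)        # torch.Size([1, 64, 64])
```
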
@@ -301,7 +300,6 @@ def apply_pulid_flux(self, model, pulid_flux, eva_clip, face_analysis, image, we
"embedding": cond,
"sigma_start": sigma_start,
"sigma_end": sigma_end,
# don't know how to apply mask
"mask": attn_mask
}

@@ -324,6 +322,7 @@ def apply_pulid_flux(self, model, pulid_flux, eva_clip, face_analysis, image, we
# Just add it once when connecting in series
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.APPLY_MODEL, wrappers_name, pulid_apply_model_wrappers)

del eva_clip, face_analysis, pulid_flux
return (model,)


4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,9 +1,9 @@
[project]
name = "comfyui_pulid_flux_ll"
description = "The implementation for PuLID-Flux, support use with TeaCache and WaveSpeed, no model pollution."
version = "1.0.3"
version = "1.0.4"
license = {file = "LICENSE"}
dependencies = ['facexlib', 'insightface', 'onnxruntime', 'onnxruntime-gpu', 'ftfy', 'timm']
dependencies = ['facexlib', 'insightface', 'onnxruntime', 'onnxruntime-gpu; sys_platform != "darwin" and platform_machine == "x86_64"', 'ftfy', 'timm']

[project.urls]
Repository = "https://github.com/lldacing/ComfyUI_PuLID_Flux_ll"
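
The dependency change constrains `onnxruntime-gpu` with a PEP 508 environment marker so it is skipped on macOS and on non-x86_64 machines. One way to sanity-check such a marker on a given machine, assuming the `packaging` library is available:

```python
# Evaluate the new onnxruntime-gpu environment marker against the current interpreter.
from packaging.markers import Marker

marker = Marker('sys_platform != "darwin" and platform_machine == "x86_64"')
print(marker.evaluate())  # True on x86_64 Linux/Windows, False on macOS or ARM builds
```
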
