From b4c7540fbca27cf28c4b40319b280a400751cad6 Mon Sep 17 00:00:00 2001 From: "Chowdhury, Hisham" Date: Mon, 20 Jan 2025 16:22:09 -0800 Subject: [PATCH 1/3] Fix for running via DirectML Fix DirectML empty image generation issue with Flux1. add CPU fallback for unsupported path. Verified the model works on AMD GPUs --- comfy/clip_model.py | 6 +++++- comfy/ldm/flux/math.py | 2 +- comfy/model_management.py | 7 +++++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 23ddea9c0299..d176b3c7a4ab 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -104,7 +104,11 @@ def forward(self, input_tokens, attention_mask=None, intermediate_output=None, f mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) - causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + if comfy.model_management.is_directml_enabled(): + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).triu_(1) + else: + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + if mask is not None: mask += causal_mask else: diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index b5960ffd3073..36b67931c401 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -22,7 +22,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 - if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu(): + if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled(): device = torch.device("cpu") else: device = pos.device diff --git a/comfy/model_management.py b/comfy/model_management.py index f6dfc18b02b6..95fd39f6a11c 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -993,6 +993,13 @@ def is_device_mps(device): def is_device_cuda(device): return is_device_type(device, 'cuda') +def is_directml_enabled(): + global directml_enabled + if directml_enabled: + return True + + return False + def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): global directml_enabled From 75ec851324f5e171fb93b36371e0a4bbbc6fa202 Mon Sep 17 00:00:00 2001 From: "Chowdhury, Hisham" Date: Thu, 6 Feb 2025 12:23:33 -0800 Subject: [PATCH 2/3] fix formating --- comfy/clip_model.py | 2 +- comfy/model_management.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index d176b3c7a4ab..466d5887ca6f 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -108,7 +108,7 @@ def forward(self, input_tokens, attention_mask=None, intermediate_output=None, f causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).triu_(1) else: causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) - + if mask is not None: mask += causal_mask else: diff --git a/comfy/model_management.py b/comfy/model_management.py index 95fd39f6a11c..d571e7bff662 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -997,7 +997,7 @@ def is_directml_enabled(): global directml_enabled if directml_enabled: return True - + return False def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): From f594ea41f5f42405a87601117c07c31e68ca9a1c Mon Sep 17 00:00:00 2001 From: "Chowdhury, Hisham" Date: Tue, 11 Feb 2025 11:41:17 -0800 Subject: [PATCH 3/3] update casual mask calculation --- comfy/clip_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index e1d525caa7c8..0163c6fe7d5d 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -105,7 +105,7 @@ def forward(self, input_tokens, attention_mask=None, intermediate_output=None, f mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max) if comfy.model_management.is_directml_enabled(): - causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).triu_(1) + causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1) else: causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)