Skip to content

Commit c6b9c11

Browse files
authored
Add oneAPI device selector for xpu and some other changes. (#6112)
* Add oneAPI device selector and some other minor changes.
* Fix device selector variable name.
* Flip minor version check sign.
* Undo changes to README.md.
1 parent e44d0ac commit c6b9c11

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

comfy/cli_args.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,8 @@ def __call__(self, parser, namespace, values, option_string=None):
8484

8585
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
8686

87-
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
87+
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
88+
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
8889

8990
class LatentPreviewMethod(enum.Enum):
9091
NoPreviews = "none"

comfy/model_management.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,7 @@ class CPUState(Enum):
7575
try:
7676
import intel_extension_for_pytorch as ipex
7777
_ = torch.xpu.device_count()
78-
xpu_available = torch.xpu.is_available()
78+
xpu_available = xpu_available or torch.xpu.is_available()
7979
except:
8080
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
8181

@@ -219,12 +219,14 @@ def is_nvidia():
219219
if args.cpu_vae:
220220
VAE_DTYPES = [torch.float32]
221221

222-
223222
if ENABLE_PYTORCH_ATTENTION:
224223
torch.backends.cuda.enable_math_sdp(True)
225224
torch.backends.cuda.enable_flash_sdp(True)
226225
torch.backends.cuda.enable_mem_efficient_sdp(True)
227226

227+
if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5:
228+
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
229+
228230
if args.lowvram:
229231
set_vram_to = VRAMState.LOW_VRAM
230232
lowvram_available = True

main.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -114,6 +114,10 @@ def execute_script(script_path):
114114
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
115115
logging.info("Set cuda device to: {}".format(args.cuda_device))
116116

117+
if args.oneapi_device_selector is not None:
118+
os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
119+
logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
120+
117121
if args.deterministic:
118122
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
119123
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"

0 commit comments

Comments (0)