|
22 | 22 | from comfy.cli_args import args, PerformanceFeature |
23 | 23 | import torch |
24 | 24 | import sys |
| 25 | +import importlib.util |
25 | 26 | import platform |
26 | 27 | import weakref |
27 | 28 | import gc |
@@ -336,12 +337,13 @@ def amd_min_version(device=None, min_rdna_version=0): |
336 | 337 | logging.info("AMD arch: {}".format(arch)) |
337 | 338 | logging.info("ROCm version: {}".format(rocm_version)) |
338 | 339 | if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: |
339 | | - if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much |
340 | | - if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 |
341 | | - ENABLE_PYTORCH_ATTENTION = True |
342 | | -# if torch_version_numeric >= (2, 8): |
343 | | -# if any((a in arch) for a in ["gfx1201"]): |
344 | | -# ENABLE_PYTORCH_ATTENTION = True |
| 340 | + if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not. |
| 341 | + if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much |
| 342 | + if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 |
| 343 | + ENABLE_PYTORCH_ATTENTION = True |
| 344 | +# if torch_version_numeric >= (2, 8): |
| 345 | +# if any((a in arch) for a in ["gfx1201"]): |
| 346 | +# ENABLE_PYTORCH_ATTENTION = True |
345 | 347 | if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): |
346 | 348 | if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches |
347 | 349 | SUPPORT_FP8_OPS = True |
|
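For context, a minimal standalone sketch of the availability check this commit adds: `importlib.util.find_spec` returns a module spec (or `None`) without actually importing the package, which is how triton is probed before PyTorch attention is enabled on the listed gfx arches. The helper name `pytorch_attention_supported` and its parameters below are illustrative, not part of the file.

```python
import importlib.util

# Minimal sketch (not from the file): how the triton probe above works.
def pytorch_attention_supported(arch: str, torch_version=(2, 7)) -> bool:
    # find_spec returns None when the package is not installed; it does not
    # import triton, so the probe is cheap and side-effect free.
    if importlib.util.find_spec("triton") is None:
        return False
    if torch_version < (2, 7):  # same version gate as the diff
        return False
    # substring match against the gfx allow-list used in the diff
    return any(a in arch for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"])

print(pytorch_attention_supported("gfx1100"))  # True only if triton is installed
```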