diff --git a/comfy_aimdo/control.py b/comfy_aimdo/control.py index 237f19f..0b9d412 100644 --- a/comfy_aimdo/control.py +++ b/comfy_aimdo/control.py @@ -10,7 +10,6 @@ def detect_vendor(): - version = "" try: torch_spec = importlib.util.find_spec("torch") for folder in torch_spec.submodule_search_locations: @@ -19,15 +18,13 @@ def detect_vendor(): spec = importlib.util.spec_from_file_location("torch_version_import", ver_file) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - version = module.__version__ + if module.cuda != None: + return "cuda" + if module.rocm != None: + return "rocm" except Exception as e: logging.warning("Failed to detect Torch version") pass - - if '+cu' in version: - return "cuda" - if '+rocm' in version: - return "rocm" return None