diff --git a/bitsandbytes/triton/triton_utils.py b/bitsandbytes/triton/triton_utils.py index 6bbdbf1c1..b706ff1ba 100644 --- a/bitsandbytes/triton/triton_utils.py +++ b/bitsandbytes/triton/triton_utils.py @@ -1,5 +1,14 @@ -import importlib +import functools +@functools.lru_cache(None) def is_triton_available(): - return importlib.util.find_spec("triton") is not None + try: + # torch>=2.2.0 + from torch.utils._triton import has_triton, has_triton_package + + return has_triton_package() and has_triton() + except ImportError: + from torch._inductor.utils import has_triton + + return has_triton()