|
| 1 | +import warnings |
| 2 | +import weakref |
| 3 | +from collections.abc import Callable |
| 4 | +from functools import singledispatch, wraps |
| 5 | +from hashlib import sha256 |
| 6 | +from pathlib import Path |
| 7 | +from pickle import dumps |
| 8 | +from tempfile import NamedTemporaryFile |
| 9 | +from typing import Any |
| 10 | + |
| 11 | +from numba.core.caching import CacheImpl, _CacheLocator |
| 12 | + |
| 13 | +from pytensor import config |
| 14 | +from pytensor.link.numba.compile import numba_funcify, numba_njit |
| 15 | + |
| 16 | + |
# Global switch for the PyTensor-managed numba cache locator (checked in
# NumbaPyTensorCacheLocator.from_function).
NUMBA_PYTENSOR_CACHE_ENABLED = True
# On-disk location for numba's cache files, inside PyTensor's compiledir.
NUMBA_CACHE_PATH = config.base_compiledir / "numba"
NUMBA_CACHE_PATH.mkdir(exist_ok=True)
# Maps generated functions to their cache-key string (populated by
# compile_numba_function_src). Weak keys so entries vanish once the
# generated function is garbage collected.
CACHED_SRC_FUNCTIONS = weakref.WeakKeyDictionary()
| 21 | + |
| 22 | + |
class NumbaPyTensorCacheLocator(_CacheLocator):
    """Numba cache locator for functions generated by PyTensor.

    Redirects the on-disk cache of PyTensor-generated functions to
    ``NUMBA_CACHE_PATH`` and disambiguates them by the hash registered in
    ``CACHED_SRC_FUNCTIONS`` instead of by source-file identity.
    """

    def __init__(self, py_func, py_file, hash):
        self._py_func = py_func
        self._py_file = py_file
        self._hash = hash

    def ensure_cache_path(self):
        # NUMBA_CACHE_PATH is created at module import time; nothing to do.
        pass

    def get_cache_path(self):
        """
        Return the directory the function is cached in.
        """
        return NUMBA_CACHE_PATH

    def get_source_stamp(self):
        """
        Get a timestamp representing the source code's freshness.
        Can return any picklable Python object.

        A constant is returned here: the disambiguator (the registered hash)
        already encodes the identity of the generated source.
        """
        return 0

    def get_disambiguator(self):
        """
        Get a string disambiguator for this locator's function.
        It should allow disambiguating different but similarly-named functions.
        """
        return self._hash

    @classmethod
    def from_function(cls, py_func, py_file):
        """
        Create a locator instance for the given function located in the given file.

        Returns None (so numba falls through to the next locator class) unless
        PyTensor caching is enabled and `py_func` was registered in
        ``CACHED_SRC_FUNCTIONS``.
        """
        if NUMBA_PYTENSOR_CACHE_ENABLED and py_func in CACHED_SRC_FUNCTIONS:
            return cls(py_func, Path(py_file).parent, CACHED_SRC_FUNCTIONS[py_func])
| 62 | + |
| 63 | + |
| 64 | +CacheImpl._locator_classes.insert(0, NumbaPyTensorCacheLocator) |
| 65 | + |
| 66 | + |
| 67 | +@singledispatch |
| 68 | +def numba_funcify_default_op_cache_key( |
| 69 | + op, node=None, **kwargs |
| 70 | +) -> Callable | tuple[Callable, Any]: |
| 71 | + """Funcify an Op and implement a default cache key. |
| 72 | +
|
| 73 | + The default cache key is based on the op class and its properties. |
| 74 | + It does not take into account the node inputs or other context. |
| 75 | + Note that numba will use the array dtypes, rank and layout as part of the cache key, |
| 76 | + but not the static shape or constant values. |
| 77 | + If the funcify implementation exploits this information, then this method should not be used. |
| 78 | + Instead dispatch directly on `numba_funcify_and_cache_key` (or just numba_funcify) |
| 79 | + which won't use any cache key. |
| 80 | + """ |
| 81 | + # Default cache key of None which means "don't try to do directly cache this function" |
| 82 | + raise NotImplementedError() |
| 83 | + |
| 84 | + |
def register_funcify_default_op_cache_key(op_type):
    """Register a funcify implementation for both cache and non-cache versions."""

    def decorator(dispatch_func):
        # The cache-aware dispatcher receives the implementation unchanged.
        numba_funcify_default_op_cache_key.register(op_type)(dispatch_func)

        @wraps(dispatch_func)
        def without_key(*args, **kwargs):
            # The plain funcify path has no use for the cache key.
            return dispatch_func(*args, **kwargs)[0]

        numba_funcify.register(op_type)(without_key)

        return dispatch_func

    return decorator
| 105 | + |
| 106 | + |
@singledispatch
def numba_funcify_and_cache_key(op, node=None, **kwargs) -> tuple[Callable, str | None]:
    """Funcify `op` and derive a content-based cache key for it.

    Returns a ``(callable, key)`` pair where `key` is a hex sha256 digest when a
    default cache-key implementation exists for this op, and None otherwise.
    A None key means "don't try to directly cache this function".
    """
    if hasattr(op, "_props"):
        try:
            func_and_salt = numba_funcify_default_op_cache_key(op, node=node, **kwargs)
        except NotImplementedError:
            # No default cache key registered for this op type; fall through
            # to the uncached numba_funcify path below.
            pass
        else:
            # Implementations may return just the callable or (callable, salt);
            # "0" is the default salt for the bare-callable form.
            if isinstance(func_and_salt, tuple):
                func, salt = func_and_salt
            else:
                func, salt = func_and_salt, "0"
            props_dict = op._props_dict()
            if not props_dict:
                # Simple op, just use the type string as key
                key_bytes = str((type(op), salt)).encode()
            else:
                # Simple props, can use string representation of props as key
                simple_types = (str, bool, int, type(None), float)
                container_types = (tuple, frozenset)
                if all(
                    isinstance(v, simple_types)
                    or (
                        isinstance(v, container_types)
                        and all(isinstance(i, simple_types) for i in v)
                    )
                    for v in props_dict.values()
                ):
                    key_bytes = str(
                        (type(op), tuple(props_dict.items()), salt)
                    ).encode()
                else:
                    # Complex props, use pickle to serialize them
                    key_bytes = dumps((str(type(op)), tuple(props_dict.items()), salt))
            return func, sha256(key_bytes).hexdigest()

    # Fallback: no props or no default key implementation — funcify without a key.
    return numba_funcify(op, node=node, **kwargs), None
| 146 | + |
| 147 | + |
def register_funcify_and_cache_key(op_type):
    """Register a funcify implementation for both cache and non-cache versions."""

    def decorator(dispatch_func):
        # The cache-aware dispatcher receives the implementation unchanged.
        numba_funcify_and_cache_key.register(op_type)(dispatch_func)

        @wraps(dispatch_func)
        def key_discarding(*args, **kwargs):
            # The plain funcify path has no use for the cache key.
            return dispatch_func(*args, **kwargs)[0]

        numba_funcify.register(op_type)(key_discarding)

        # NOTE(review): unlike register_funcify_default_op_cache_key, this
        # decorator returns the key-discarding wrapper rather than the
        # original implementation — confirm the asymmetry is intended.
        return key_discarding

    return decorator
| 168 | + |
| 169 | + |
def numba_njit_and_cache(op, *args, **kwargs):
    """Funcify `op`, jit it, and make it numba-cacheable when a key exists.

    Returns a ``(jitted callable, key)`` pair; `key` is None when the op has
    no registered cache key, in which case numba caching is disabled.
    """
    jitable_func, key = numba_funcify_and_cache_key(op, *args, **kwargs)

    if key is None:
        if config.numba__cache and config.compiler_verbose:
            warnings.warn(
                f"Custom numba cache disabled for {op} of type {type(op)}. "
                f"Even if the function is cached by numba, larger graphs using this function cannot be cached.\n"
                "To enable custom caching, register a numba_funcify_and_cache_key implementation for this Op, with a proper cache key."
            )
        return numba_njit(
            lambda *args: jitable_func(*args), final_function=True, cache=False
        ), None

    # To force numba to use our cache, we must compile the function so that any closure
    # becomes a global variable...
    fn_name = op.__class__.__name__
    compiled_env = globals() | {"jitable_func": jitable_func}
    cached_func = compile_numba_function_src(
        src=f"def {fn_name}(*args): return jitable_func(*args)",
        function_name=fn_name,
        global_env=compiled_env,
        cache_key=key,
    )
    return numba_njit(cached_func, final_function=True, cache=True), key
| 195 | + |
| 196 | + |
| 197 | +def compile_numba_function_src( |
| 198 | + src: str, |
| 199 | + function_name: str, |
| 200 | + global_env: dict[Any, Any] | None = None, |
| 201 | + local_env: dict[Any, Any] | None = None, |
| 202 | + store_to_disk: bool = False, |
| 203 | + cache_key: str | None = None, |
| 204 | +) -> Callable: |
| 205 | + if store_to_disk: |
| 206 | + with NamedTemporaryFile(delete=False) as f: |
| 207 | + filename = f.name |
| 208 | + f.write(src.encode()) |
| 209 | + else: |
| 210 | + filename = "<string>" |
| 211 | + |
| 212 | + if global_env is None: |
| 213 | + global_env = {} |
| 214 | + |
| 215 | + if local_env is None: |
| 216 | + local_env = {} |
| 217 | + |
| 218 | + mod_code = compile(src, filename, mode="exec") |
| 219 | + exec(mod_code, global_env, local_env) |
| 220 | + |
| 221 | + res = local_env[function_name] |
| 222 | + res.__source__ = src # type: ignore |
| 223 | + |
| 224 | + if cache_key is not None: |
| 225 | + CACHED_SRC_FUNCTIONS[res] = cache_key |
| 226 | + return res |
0 commit comments