Skip to content

Commit 966ab5e

Browse files
committed
Merge branch 'hp/func-base-refactorization' into hp/func-base-refactorization-factorize-cache
2 parents aa1dcbf + 2985894 commit 966ab5e

8 files changed

Lines changed: 116 additions & 27 deletions

File tree

gstaichi/runtime/cuda/jit_cuda.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ JITSessionCUDA::JITSessionCUDA(GsTaichiLLVMContext *tlctx,
7878
: JITSession(tlctx, config),
7979
data_layout(data_layout),
8080
program_impl_(program_impl),
81-
config(config) {
81+
config_(config) {
8282
PtxCache::Config ptx_cache_config;
8383
ptx_cache_config.offline_cache_path = config.offline_cache_file_path;
8484
ptx_cache_ = std::make_unique<PtxCache>(ptx_cache_config, config);
@@ -91,7 +91,7 @@ JITModule *JITSessionCUDA::add_module(std::unique_ptr<llvm::Module> M,
9191
int max_reg) {
9292
const char *dump_ir_env = std::getenv(DUMP_IR_ENV.data());
9393
if (dump_ir_env != nullptr && std::string(dump_ir_env) == "1") {
94-
std::filesystem::path ir_dump_dir = config.debug_dump_path;
94+
std::filesystem::path ir_dump_dir = config_.debug_dump_path;
9595
std::filesystem::create_directories(ir_dump_dir);
9696
std::string dumpName = moduleToDumpName(M.get());
9797
std::filesystem::path filename =

gstaichi/runtime/cuda/jit_cuda.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class JITSessionCUDA : public JITSession {
8585
std::unique_ptr<PtxCache> ptx_cache_;
8686
ProgramImpl *program_impl_;
8787
std::unique_ptr<Finalizer> finalizer_;
88-
const CompileConfig &config;
88+
const CompileConfig &config_;
8989
};
9090

9191
#endif

python/gstaichi/lang/_func_base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,11 @@ def get_tree_and_ctx(
260260
return tree, ctx
261261

262262
def process_args(self, is_pyfunc: bool, is_func: bool, args: tuple[Any, ...], kwargs) -> tuple[Any, ...]:
263+
"""
264+
- expand dataclass args
265+
- fuse args and kwargs into a single list of args
266+
"""
263267
if is_func and not is_pyfunc:
264-
# if typing.TYPE_CHECKING:
265-
# assert isinstance(self, Func)
266268
current_kernel = self.current_kernel
267269
if typing.TYPE_CHECKING:
268270
assert current_kernel is not None
@@ -315,7 +317,6 @@ def process_args(self, is_pyfunc: bool, is_func: bool, args: tuple[Any, ...], kw
315317
raise GsTaichiSyntaxError(f"Unexpected argument '{key}'.")
316318
elif num_missing_args:
317319
for i in range(num_args, num_arg_metas):
318-
arg = fused_args[i]
319320
if fused_args[i] is _ARG_EMPTY:
320321
arg_meta = self.arg_metas_expanded[i]
321322
raise GsTaichiSyntaxError(f"Missing argument '{arg_meta.name}'.")

python/gstaichi/lang/_template_mapper.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from gstaichi.lang.impl import Program
77
from gstaichi.lang.kernel_arguments import ArgMetadata
88

9-
from ._template_mapper_hotpath import _extract_arg
9+
from .._test_tools import warnings_helper
10+
from ._template_mapper_hotpath import _extract_arg, _primitive_types
1011

1112
ArgsHash: TypeAlias = tuple[int, ...]
1213
Key: TypeAlias = tuple[Any, ...]
@@ -38,7 +39,7 @@ def __init__(self, arguments: list[ArgMetadata], template_slot_locations: list[i
3839
self.template_slot_locations: list[int] = template_slot_locations
3940
self.mapping: dict[Key, int] = {}
4041
self._mapping_cache: dict[ArgsHash, tuple[int, Key]] = {}
41-
self._mapping_cache_tracker: dict[ArgsHash, list[ReferenceType]] = {}
42+
self._mapping_cache_tracker: dict[ArgsHash, list[ReferenceType | None]] = {}
4243
self._prog_weakref: ReferenceType[Program] | None = None
4344

4445
def extract(self, raise_on_templated_floats: bool, args: tuple[Any, ...]) -> Key:
@@ -64,8 +65,12 @@ def lookup(self, raise_on_templated_floats: bool, args: tuple[Any, ...]) -> tupl
6465
prog = self._prog_weakref()
6566
assert prog is not None
6667

67-
mapping_cache_tracker: list[ReferenceType] | None = None
68-
args_hash: ArgsHash = tuple([id(arg) for arg in args])
68+
# Note that it is necessary to handle primitive types separately. First, using their address as cache key must
69+
# be avoided, because even though it is theoretically possible, it is overly restrictive. Second, it does not
70+
# make sense to use these arguments to track the lifetime of the corresponding cache entry and taking weakref
71+
# of primitive types is forbidden anyway.
72+
mapping_cache_tracker: list[ReferenceType | None] | None = None
73+
args_hash: ArgsHash = tuple([id(arg) if type(arg) not in _primitive_types else arg for arg in args])
6974
try:
7075
mapping_cache_tracker = self._mapping_cache_tracker[args_hash]
7176
except KeyError:
@@ -79,13 +84,17 @@ def lookup(self, raise_on_templated_floats: bool, args: tuple[Any, ...]) -> tupl
7984
except KeyError:
8085
count = self.mapping[key] = len(self.mapping)
8186

82-
mapping_cache_tracker_: list[ReferenceType] = []
87+
# Note that it is important to prepend the cache tracker with 'None', so that a call with no tracked argument is
88+
# not mistaken for an expired cache entry caused by a deallocated argument.
89+
mapping_cache_tracker_: list[ReferenceType | None] = [None]
8390
clear_callback = lambda ref: mapping_cache_tracker_.clear()
8491
try:
85-
mapping_cache_tracker_ += [ReferenceType(arg, clear_callback) for arg in args]
92+
mapping_cache_tracker_ += [
93+
ReferenceType(arg, clear_callback) for arg in args if type(arg) not in _primitive_types
94+
]
8695
self._mapping_cache_tracker[args_hash] = mapping_cache_tracker_
8796
self._mapping_cache[args_hash] = (count, key)
88-
except TypeError:
89-
pass
97+
except TypeError as e:
98+
warnings_helper.warn_once(f"{e}. Template mapper caching disabled.")
9099

91100
return (count, key)

python/gstaichi/lang/_template_mapper_hotpath.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,10 @@ def _extract_arg(raise_on_templated_floats: bool, arg: Any, annotation: Annotati
8989
# TODO(zhanlue): replacing "tuple(args)" with "hash of argument values"
9090
# This can resolve the following issues:
9191
# 1. Invalid weak-ref will leave a dead(dangling) entry in both caches: "self.mapping" and "self.compiled_functions"
92-
# 2. Different argument instances with same type and same value, will get templatized into seperate kernels.
92+
# 2. Different argument instances with same type and same value, will get templatized into separate kernels.
9393
return weakref.ref(arg)
9494

95-
# [Primitive arguments] Return the value
95+
# Return value directly for other types, i.e. primitive types and all ti.Field-derived classes
9696
if raise_on_templated_floats and arg_type is float:
9797
raise ValueError("Floats not allowed as templated types.")
9898
return arg

python/gstaichi/lang/impl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -519,10 +519,10 @@ def materialize(self):
519519
self._check_gradient_field_not_placed("dual")
520520
self._check_matrix_field_member_shape()
521521
self._calc_matrix_field_dynamic_index_stride()
522-
self.global_vars = []
523-
self.grad_vars = []
524-
self.dual_vars = []
525-
self.matrix_fields = []
522+
self.global_vars.clear()
523+
self.grad_vars.clear()
524+
self.dual_vars.clear()
525+
self.matrix_fields.clear()
526526

527527
def _register_signal_handlers(self):
528528
if self._signal_handler_registry is None:

python/gstaichi/lang/kernel.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def __init__(self) -> None:
9595
# the launch context being stored in cache.
9696
# See 'launch_kernel' for details regarding the intended use of caching.
9797
self._launch_ctx_cache: dict[ArgsHash, KernelLaunchContext] = {}
98-
self._launch_ctx_cache_tracker: dict[ArgsHash, list[ReferenceType]] = {}
98+
self._launch_ctx_cache_tracker: dict[ArgsHash, list[ReferenceType | None]] = {}
9999
self._prog_weakref: ReferenceType[Program] | None = None
100100

101101
def _destroy_callback(self, ref: ReferenceType):
@@ -113,11 +113,12 @@ def cache(
113113
self._launch_ctx_cache[args_hash] = cached_launch_ctx
114114

115115
# Note that the clearing callback will only be called once despite being registered for each tracked
116-
# objects, because all the weakrefs get deallocated right away, and their respective callback
117-
# vanishes with them, without even getting a chance to get called. This means that registring the
118-
# clearing callback systematically does not incur any cumulative runtime penalty yet ensures full
119-
# memory safety.
120-
launch_ctx_cache_tracker_: list[ReferenceType] = []
116+
# object, because all the weakrefs get deallocated right away, and their respective callback vanishes
117+
# with them, without even getting a chance to get called. This means that registering the clearing
118+
# callback systematically does not incur any cumulative runtime penalty yet ensures full memory safety.
119+
# Note that it is important to prepend the cache tracker with 'None', so that a call with no tracked argument
120+
# is not mistaken for an expired cache entry caused by a deallocated argument.
121+
launch_ctx_cache_tracker_: list[ReferenceType | None] = []
121122
clear_callback = lambda ref: launch_ctx_cache_tracker_.clear()
122123
if launch_ctx_args := launch_ctx_buffer.get(_TI_ARRAY):
123124
_, arrs = zip(*launch_ctx_args)
@@ -400,7 +401,7 @@ def launch_kernel(self, t_kernel: KernelCxx, compiled_kernel_data: CompiledKerne
400401
assert len(args) == len(self.arg_metas), f"{len(self.arg_metas)} arguments needed but {len(args)} provided"
401402

402403
callbacks: list[Callable[[], None]] = []
403-
args_hash: ArgsHash = tuple(map(id, args))
404+
args_hash: ArgsHash = (id(t_kernel), *[id(arg) for arg in args if type(arg) is not template])
404405
launch_ctx = t_kernel.make_launch_context()
405406
prog, _populated_launch_ctx = self.launch_context_buffer_cache.populate_launch_ctx_from_cache(
406407
launch_ctx, args_hash

tests/python/test_cache.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import gstaichi as ti
2+
from gstaichi.lang.misc import get_host_arch_list
3+
4+
from tests import test_utils
5+
6+
7+
@test_utils.test(arch=get_host_arch_list())
8+
def test_cache_primitive_args():
9+
@ti.data_oriented
10+
class StructStaticConfig:
11+
flag_1: bool = True
12+
13+
@ti.kernel
14+
def fun(static_args: ti.template(), flag_2: ti.template(), value: ti.types.ndarray()):
15+
if ti.static(static_args.flag_1):
16+
if ti.static(flag_2):
17+
value[None] = value[None] + 1
18+
else:
19+
assert "Invalid 'flag_2' branch"
20+
else:
21+
assert "Invalid 'static_args.flag_1' branch"
22+
23+
assert len(fun._primal.mapper._mapping_cache) == 0
24+
assert len(fun._primal.mapper._mapping_cache_tracker) == 0
25+
assert len(fun._primal._launch_ctx_cache) == 0
26+
assert len(fun._primal._launch_ctx_cache_tracker) == 0
27+
28+
static_args = StructStaticConfig()
29+
flag_2 = True
30+
value = ti.ndarray(ti.i32, shape=())
31+
value[None] = 1
32+
33+
fun(static_args, flag_2, value)
34+
assert value[None] == 2
35+
assert len(fun._primal.mapper._mapping_cache) == 1
36+
assert len(fun._primal.mapper._mapping_cache_tracker) == 1
37+
assert len(fun._primal._launch_ctx_cache) == 1
38+
assert len(fun._primal._launch_ctx_cache_tracker) == 1
39+
40+
fun(static_args, flag_2, value)
41+
assert value[None] == 3
42+
assert len(fun._primal.mapper._mapping_cache) == 1
43+
assert len(fun._primal.mapper._mapping_cache_tracker) == 1
44+
assert len(fun._primal._launch_ctx_cache) == 1
45+
assert len(fun._primal._launch_ctx_cache_tracker) == 1
46+
47+
48+
@test_utils.test(arch=get_host_arch_list())
49+
def test_cache_fields_only():
50+
@ti.kernel
51+
def fun(flag: ti.template(), value: ti.template()):
52+
if ti.static(flag):
53+
value[None] = value[None] + 1
54+
else:
55+
assert "Invalid 'static_args.flag_1' branch"
56+
57+
assert len(fun._primal.mapper._mapping_cache) == 0
58+
assert len(fun._primal.mapper._mapping_cache_tracker) == 0
59+
assert len(fun._primal._launch_ctx_cache) == 0
60+
assert len(fun._primal._launch_ctx_cache_tracker) == 0
61+
62+
flag = True
63+
value = ti.field(ti.i32, shape=())
64+
value[None] = 1
65+
66+
fun(flag, value)
67+
assert value[None] == 2
68+
assert len(fun._primal.mapper._mapping_cache) == 1
69+
assert len(fun._primal.mapper._mapping_cache_tracker) == 1
70+
assert len(fun._primal._launch_ctx_cache) == 1
71+
assert len(fun._primal._launch_ctx_cache_tracker) == 1
72+
73+
fun(flag, value)
74+
assert value[None] == 3
75+
assert len(fun._primal.mapper._mapping_cache) == 1
76+
assert len(fun._primal.mapper._mapping_cache_tracker) == 1
77+
assert len(fun._primal._launch_ctx_cache) == 1
78+
assert len(fun._primal._launch_ctx_cache_tracker) == 1

0 commit comments

Comments
 (0)