Skip to content

Commit 1b6e6f9

Browse files
committed
fix test_cache.py
1 parent bd6e18d commit 1b6e6f9

1 file changed

Lines changed: 14 additions & 14 deletions

File tree

tests/python/test_cache.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class StructStaticConfig:
1111
flag_1: bool = True
1212

1313
@ti.kernel
14-
def fun(static_args: ti.template(), flag_2: ti.template(), value: ti.types.ndarray()):
14+
def fun(static_args: ti.Template, flag_2: ti.Template, value: ti.types.NDArray):
1515
if ti.static(static_args.flag_1):
1616
if ti.static(flag_2):
1717
value[None] = value[None] + 1
@@ -22,8 +22,8 @@ def fun(static_args: ti.template(), flag_2: ti.template(), value: ti.types.ndarr
2222

2323
assert len(fun._primal.mapper._mapping_cache) == 0
2424
assert len(fun._primal.mapper._mapping_cache_tracker) == 0
25-
assert len(fun._primal._launch_ctx_cache) == 0
26-
assert len(fun._primal._launch_ctx_cache_tracker) == 0
25+
assert len(fun._primal.launch_context_buffer_cache._launch_ctx_cache) == 0
26+
assert len(fun._primal.launch_context_buffer_cache._launch_ctx_cache_tracker) == 0
2727

2828
static_args = StructStaticConfig()
2929
flag_2 = True
@@ -34,30 +34,30 @@ def fun(static_args: ti.template(), flag_2: ti.template(), value: ti.types.ndarr
3434
assert value[None] == 2
3535
assert len(fun._primal.mapper._mapping_cache) == 1
3636
assert len(fun._primal.mapper._mapping_cache_tracker) == 1
37-
assert len(fun._primal._launch_ctx_cache) == 1
38-
assert len(fun._primal._launch_ctx_cache_tracker) == 1
37+
assert len(fun._primal.launch_context_buffer_cache._launch_ctx_cache) == 1
38+
assert len(fun._primal.launch_context_buffer_cache._launch_ctx_cache_tracker) == 1
3939

4040
fun(static_args, flag_2, value)
4141
assert value[None] == 3
4242
assert len(fun._primal.mapper._mapping_cache) == 1
4343
assert len(fun._primal.mapper._mapping_cache_tracker) == 1
44-
assert len(fun._primal._launch_ctx_cache) == 1
45-
assert len(fun._primal._launch_ctx_cache_tracker) == 1
44+
assert len(fun._primal.launch_context_buffer_cache._launch_ctx_cache) == 1
45+
assert len(fun._primal.launch_context_buffer_cache._launch_ctx_cache_tracker) == 1
4646

4747

4848
@test_utils.test(arch=get_host_arch_list())
4949
def test_cache_fields_only():
5050
@ti.kernel
51-
def fun(flag: ti.template(), value: ti.template()):
51+
def fun(flag: ti.Template, value: ti.Template):
5252
if ti.static(flag):
5353
value[None] = value[None] + 1
5454
else:
5555
assert "Invalid 'static_args.flag_1' branch"
5656

5757
assert len(fun._primal.mapper._mapping_cache) == 0
5858
assert len(fun._primal.mapper._mapping_cache_tracker) == 0
59-
assert len(fun._primal._launch_ctx_cache) == 0
60-
assert len(fun._primal._launch_ctx_cache_tracker) == 0
59+
assert len(fun._primal.launch_context_buffer_cache._launch_ctx_cache) == 0
60+
assert len(fun._primal.launch_context_buffer_cache._launch_ctx_cache_tracker) == 0
6161

6262
flag = True
6363
value = ti.field(ti.i32, shape=())
@@ -67,12 +67,12 @@ def fun(flag: ti.template(), value: ti.template()):
6767
assert value[None] == 2
6868
assert len(fun._primal.mapper._mapping_cache) == 1
6969
assert len(fun._primal.mapper._mapping_cache_tracker) == 1
70-
assert len(fun._primal._launch_ctx_cache) == 1
71-
assert len(fun._primal._launch_ctx_cache_tracker) == 1
70+
assert len(fun._primal.launch_context_buffer_cache._launch_ctx_cache) == 1
71+
assert len(fun._primal.launch_context_buffer_cache._launch_ctx_cache_tracker) == 1
7272

7373
fun(flag, value)
7474
assert value[None] == 3
7575
assert len(fun._primal.mapper._mapping_cache) == 1
7676
assert len(fun._primal.mapper._mapping_cache_tracker) == 1
77-
assert len(fun._primal._launch_ctx_cache) == 1
78-
assert len(fun._primal._launch_ctx_cache_tracker) == 1
77+
assert len(fun._primal.launch_context_buffer_cache._launch_ctx_cache) == 1
78+
assert len(fun._primal.launch_context_buffer_cache._launch_ctx_cache_tracker) == 1

0 commit comments

Comments
 (0)