@@ -11,7 +11,7 @@ class StructStaticConfig:
1111 flag_1 : bool = True
1212
1313 @ti .kernel
14- def fun (static_args : ti .template () , flag_2 : ti .template () , value : ti .types .ndarray () ):
14+ def fun (static_args : ti .Template , flag_2 : ti .Template , value : ti .types .NDArray ):
1515 if ti .static (static_args .flag_1 ):
1616 if ti .static (flag_2 ):
1717 value [None ] = value [None ] + 1
@@ -22,8 +22,8 @@ def fun(static_args: ti.template(), flag_2: ti.template(), value: ti.types.ndarr
2222
2323 assert len (fun ._primal .mapper ._mapping_cache ) == 0
2424 assert len (fun ._primal .mapper ._mapping_cache_tracker ) == 0
25- assert len (fun ._primal ._launch_ctx_cache ) == 0
26- assert len (fun ._primal ._launch_ctx_cache_tracker ) == 0
25+ assert len (fun ._primal .launch_context_buffer_cache . _launch_ctx_cache ) == 0
26+ assert len (fun ._primal .launch_context_buffer_cache . _launch_ctx_cache_tracker ) == 0
2727
2828 static_args = StructStaticConfig ()
2929 flag_2 = True
@@ -34,30 +34,30 @@ def fun(static_args: ti.template(), flag_2: ti.template(), value: ti.types.ndarr
3434 assert value [None ] == 2
3535 assert len (fun ._primal .mapper ._mapping_cache ) == 1
3636 assert len (fun ._primal .mapper ._mapping_cache_tracker ) == 1
37- assert len (fun ._primal ._launch_ctx_cache ) == 1
38- assert len (fun ._primal ._launch_ctx_cache_tracker ) == 1
37+ assert len (fun ._primal .launch_context_buffer_cache . _launch_ctx_cache ) == 1
38+ assert len (fun ._primal .launch_context_buffer_cache . _launch_ctx_cache_tracker ) == 1
3939
4040 fun (static_args , flag_2 , value )
4141 assert value [None ] == 3
4242 assert len (fun ._primal .mapper ._mapping_cache ) == 1
4343 assert len (fun ._primal .mapper ._mapping_cache_tracker ) == 1
44- assert len (fun ._primal ._launch_ctx_cache ) == 1
45- assert len (fun ._primal ._launch_ctx_cache_tracker ) == 1
44+ assert len (fun ._primal .launch_context_buffer_cache . _launch_ctx_cache ) == 1
45+ assert len (fun ._primal .launch_context_buffer_cache . _launch_ctx_cache_tracker ) == 1
4646
4747
4848@test_utils .test (arch = get_host_arch_list ())
4949def test_cache_fields_only ():
5050 @ti .kernel
51- def fun (flag : ti .template () , value : ti .template () ):
51+ def fun (flag : ti .Template , value : ti .Template ):
5252 if ti .static (flag ):
5353 value [None ] = value [None ] + 1
5454 else :
5555 assert "Invalid 'static_args.flag_1' branch"
5656
5757 assert len (fun ._primal .mapper ._mapping_cache ) == 0
5858 assert len (fun ._primal .mapper ._mapping_cache_tracker ) == 0
59- assert len (fun ._primal ._launch_ctx_cache ) == 0
60- assert len (fun ._primal ._launch_ctx_cache_tracker ) == 0
59+ assert len (fun ._primal .launch_context_buffer_cache . _launch_ctx_cache ) == 0
60+ assert len (fun ._primal .launch_context_buffer_cache . _launch_ctx_cache_tracker ) == 0
6161
6262 flag = True
6363 value = ti .field (ti .i32 , shape = ())
@@ -67,12 +67,12 @@ def fun(flag: ti.template(), value: ti.template()):
6767 assert value [None ] == 2
6868 assert len (fun ._primal .mapper ._mapping_cache ) == 1
6969 assert len (fun ._primal .mapper ._mapping_cache_tracker ) == 1
70- assert len (fun ._primal ._launch_ctx_cache ) == 1
71- assert len (fun ._primal ._launch_ctx_cache_tracker ) == 1
70+ assert len (fun ._primal .launch_context_buffer_cache . _launch_ctx_cache ) == 1
71+ assert len (fun ._primal .launch_context_buffer_cache . _launch_ctx_cache_tracker ) == 1
7272
7373 fun (flag , value )
7474 assert value [None ] == 3
7575 assert len (fun ._primal .mapper ._mapping_cache ) == 1
7676 assert len (fun ._primal .mapper ._mapping_cache_tracker ) == 1
77- assert len (fun ._primal ._launch_ctx_cache ) == 1
78- assert len (fun ._primal ._launch_ctx_cache_tracker ) == 1
77+ assert len (fun ._primal .launch_context_buffer_cache . _launch_ctx_cache ) == 1
78+ assert len (fun ._primal .launch_context_buffer_cache . _launch_ctx_cache_tracker ) == 1
0 commit comments