66from gstaichi .lang .impl import Program
77from gstaichi .lang .kernel_arguments import ArgMetadata
88
9- from ._template_mapper_hotpath import _extract_arg
9+ from .._test_tools import warnings_helper
10+ from ._template_mapper_hotpath import _extract_arg , _primitive_types
1011
1112ArgsHash : TypeAlias = tuple [int , ...]
1213Key : TypeAlias = tuple [Any , ...]
@@ -38,7 +39,7 @@ def __init__(self, arguments: list[ArgMetadata], template_slot_locations: list[i
3839 self .template_slot_locations : list [int ] = template_slot_locations
3940 self .mapping : dict [Key , int ] = {}
4041 self ._mapping_cache : dict [ArgsHash , tuple [int , Key ]] = {}
41- self ._mapping_cache_tracker : dict [ArgsHash , list [ReferenceType ]] = {}
42+ self ._mapping_cache_tracker : dict [ArgsHash , list [ReferenceType | None ]] = {}
4243 self ._prog_weakref : ReferenceType [Program ] | None = None
4344
4445 def extract (self , raise_on_templated_floats : bool , args : tuple [Any , ...]) -> Key :
@@ -64,8 +65,12 @@ def lookup(self, raise_on_templated_floats: bool, args: tuple[Any, ...]) -> tupl
6465 prog = self ._prog_weakref ()
6566 assert prog is not None
6667
67- mapping_cache_tracker : list [ReferenceType ] | None = None
68- args_hash : ArgsHash = tuple ([id (arg ) for arg in args ])
68+ # Note that it is necessary to handle primitive types separately. First, using their address as cache key must
69+ # be avoided, because even though it is theoretically possible, it is overly restrictive. Second, it does not
70+ # make sense to use these arguments to track the lifetime of the corresponding cache entry and taking weakref
71+ # of primitive types if forbidden anyway.
72+ mapping_cache_tracker : list [ReferenceType | None ] | None = None
73+ args_hash : ArgsHash = tuple ([id (arg ) if type (arg ) not in _primitive_types else arg for arg in args ])
6974 try :
7075 mapping_cache_tracker = self ._mapping_cache_tracker [args_hash ]
7176 except KeyError :
@@ -79,13 +84,17 @@ def lookup(self, raise_on_templated_floats: bool, args: tuple[Any, ...]) -> tupl
7984 except KeyError :
8085 count = self .mapping [key ] = len (self .mapping )
8186
82- mapping_cache_tracker_ : list [ReferenceType ] = []
87+ # Note that it is important to prepend the cache tracker with 'None' to avoid misclassifying no argument with
88+ # expired cache entry caused by deallocated argument.
89+ mapping_cache_tracker_ : list [ReferenceType | None ] = [None ]
8390 clear_callback = lambda ref : mapping_cache_tracker_ .clear ()
8491 try :
85- mapping_cache_tracker_ += [ReferenceType (arg , clear_callback ) for arg in args ]
92+ mapping_cache_tracker_ += [
93+ ReferenceType (arg , clear_callback ) for arg in args if type (arg ) not in _primitive_types
94+ ]
8695 self ._mapping_cache_tracker [args_hash ] = mapping_cache_tracker_
8796 self ._mapping_cache [args_hash ] = (count , key )
88- except TypeError :
89- pass
97+ except TypeError as e :
98+ warnings_helper . warn_once ( f" { e } . Template mapper caching disabled." )
9099
91100 return (count , key )
0 commit comments