@@ -93,9 +93,10 @@ def __call__(
9393 query_pool = device .create_query_pool (type = spy .QueryType .timestamp , count = iterations * 2 )
9494 for i in range (iterations ):
9595 command_encoder = device .create_command_encoder ()
96- command_encoder .write_timestamp (query_pool , i * 2 )
97- function (** kwargs , _append_to = command_encoder )
98- command_encoder .write_timestamp (query_pool , i * 2 + 1 )
96+ function .write_timestamps ((query_pool , i * 2 , i * 2 + 1 ))(
97+ ** kwargs ,
98+ _append_to = command_encoder ,
99+ )
99100 device .submit_command_buffer (command_encoder .finish ())
100101 device .wait ()
101102 queries = np .array (query_pool .get_results (0 , iterations * 2 ))
@@ -134,9 +135,14 @@ def __call__(
134135 query_pool = device .create_query_pool (type = spy .QueryType .timestamp , count = iterations * 2 )
135136 for i in range (iterations ):
136137 command_encoder = device .create_command_encoder ()
137- command_encoder .write_timestamp (query_pool , i * 2 )
138- kernel .dispatch (thread_count , command_encoder = command_encoder , ** kwargs )
139- command_encoder .write_timestamp (query_pool , i * 2 + 1 )
138+ kernel .dispatch (
139+ thread_count ,
140+ command_encoder = command_encoder ,
141+ query_pool = query_pool ,
142+ query_index_before = i * 2 ,
143+ query_index_after = i * 2 + 1 ,
144+ ** kwargs ,
145+ )
140146 device .submit_command_buffer (command_encoder .finish ())
141147 device .wait ()
142148 queries = np .array (query_pool .get_results (0 , iterations * 2 ))
0 commit comments