import numpy as np
import iris
from iris._mpi_helpers import mpi_allgather
-# from examples.common.utils import read_realtime
-
-@triton.jit
-def read_realtime():
-    tmp = tl.inline_asm_elementwise(
-        asm="mov.u64 $0, %globaltimer;",
-        constraints=("=l"),
-        args=[],
-        dtype=tl.int64,
-        is_pure=False,
-        pack=1,
-    )
-    return tmp
+from examples.common.utils import read_realtime

-@triton.jit()
-def gather_latencies(
-    local_latency,
-    global_latency,
-    curr_rank,
-    num_ranks,
-    BLOCK_SIZE: tl.constexpr,
-    heap_bases: tl.tensor
-):
-    pid = tl.program_id(0)
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-
-    latency_mask = offsets < num_ranks
-    iris.put(local_latency + offsets, global_latency + curr_rank * num_ranks + offsets, curr_rank, 0, heap_bases, mask=latency_mask)

@triton.jit()
def ping_pong(
@@ -66,7 +39,7 @@ def ping_pong(
        start = read_realtime()
        tl.atomic_xchg(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
        first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
-        token_first_done = i + 1
+        token_first_done = i + 1
        token_second_done = i + 2
        if curr_rank == first_rank:
            iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask)
@@ -82,8 +55,9 @@ def ping_pong(
        stop = read_realtime()
        tl.atomic_xchg(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)

+
if __name__ == "__main__":
-    dtype = torch.int32
+    dtype = torch.int32
    heap_size = 1 << 32
    shmem = iris.iris(heap_size)
    num_ranks = shmem.get_num_ranks()
@@ -96,42 +70,48 @@ def ping_pong(
    iter = 200
    skip = 1
    mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
-    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
+    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")

-    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
+    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")

    source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
    result_buffer = shmem.zeros_like(source_buffer)
-    flag = shmem.ones(1, dtype=dtype)
+    flag = shmem.ones(1, dtype=dtype)

    grid = lambda meta: (1,)
    for source_rank in range(num_ranks):
        for destination_rank in range(num_ranks):
            if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
                peer_for_me = destination_rank if cur_rank == source_rank else source_rank
-                ping_pong[grid](source_buffer,
-                                BUFFER_LEN,
-                                skip, iter,
-                                flag,
-                                cur_rank, peer_for_me,
-                                BLOCK_SIZE,
-                                heap_bases,
-                                mm_begin_timestamp,
-                                mm_end_timestamp)
+                ping_pong[grid](
+                    source_buffer,
+                    BUFFER_LEN,
+                    skip,
+                    iter,
+                    flag,
+                    cur_rank,
+                    peer_for_me,
+                    BLOCK_SIZE,
+                    heap_bases,
+                    mm_begin_timestamp,
+                    mm_end_timestamp,
+                )
            shmem.barrier()
-
+
    for destination_rank in range(num_ranks):
-        local_latency[destination_rank] = (mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]) / iter
-
+        local_latency[destination_rank] = (
+            mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]
+        ) / iter
+
    latency_matrix = mpi_allgather(local_latency.cpu())

    if cur_rank == 0:
-        with open(f"latency.txt", "w") as f:
+        with open("latency.txt", "w") as f:
            f.write(" ," + ", ".join(f"R{j}" for j in range(num_ranks)) + "\n")
            for i in range(num_ranks):
                row_entries = []
                for j in range(num_ranks):
                    val = float(latency_matrix[i, j])
                    row_entries.append(f"{val:0.6f}")
                line = f"R{i}," + ", ".join(row_entries) + "\n"
-                f.write(line)
+                f.write(line)
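
A minimal sketch (separate from the diff above) for loading latency.txt back for inspection, assuming the format the script writes: a header row of "R0, R1, ..." rank labels, then one comma-separated row per source rank prefixed with its "Ri" label.

import numpy as np

# Read latency.txt produced by rank 0: drop the header row, and drop the
# leading "Ri" label from each data row, keeping only the latency values.
with open("latency.txt") as f:
    rows = [line.strip().split(",")[1:] for line in f.readlines()[1:]]
latency = np.array([[float(v) for v in row] for row in rows])
print(latency.shape)  # (num_ranks, num_ranks)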