99import numpy as np
1010import iris
1111from iris ._mpi_helpers import mpi_allgather
12- # from examples.common.utils import read_realtime
13-
14- @triton .jit
15- def read_realtime ():
16- tmp = tl .inline_asm_elementwise (
17- asm = "mov.u64 $0, %globaltimer;" ,
18- constraints = ("=l" ),
19- args = [],
20- dtype = tl .int64 ,
21- is_pure = False ,
22- pack = 1 ,
23- )
24- return tmp
12+ from examples .common .utils import read_realtime
13+
14+
2515
@triton.jit()
def gather_latencies(
    local_latency, global_latency, curr_rank, num_ranks, BLOCK_SIZE: tl.constexpr, heap_bases: tl.tensor
):
    """Publish this rank's per-peer latency vector into the global latency table.

    One program instance handles a BLOCK_SIZE-wide slice of ``local_latency``
    and writes it, via ``iris.put``, into row ``curr_rank`` of the flattened
    ``global_latency`` buffer held on rank 0. Entries at or beyond
    ``num_ranks`` are masked out so only real measurements are transferred.

    Args:
        local_latency: pointer to this rank's latency measurements (one per peer).
        global_latency: pointer to the num_ranks * num_ranks result matrix on rank 0.
        curr_rank: index of the calling rank (selects the destination row).
        num_ranks: total number of ranks; bounds the valid entries.
        BLOCK_SIZE: compile-time tile width for this kernel.
        heap_bases: iris symmetric-heap base addresses used for remote addressing.
    """
    # Element offsets covered by this program instance.
    block_id = tl.program_id(0)
    lane = tl.arange(0, BLOCK_SIZE)
    offsets = block_id * BLOCK_SIZE + lane

    # Only the first num_ranks entries hold real latency values.
    valid = offsets < num_ranks

    # Remote write of local_latency[offsets] into row curr_rank of the
    # flattened matrix on rank 0. NOTE(review): argument order assumed to be
    # (src, dst, from_rank, to_rank, heap_bases) per iris.put — confirm.
    iris.put(
        local_latency + offsets,
        global_latency + curr_rank * num_ranks + offsets,
        curr_rank,
        0,
        heap_bases,
        mask=valid,
    )
33+
4134
4235@triton .jit ()
4336def ping_pong (
@@ -66,7 +59,7 @@ def ping_pong(
6659 start = read_realtime ()
6760 tl .atomic_xchg (mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets , start , time_stmp_mask )
6861 first_rank = tl .minimum (curr_rank , peer_rank ) if (i % 2 ) == 0 else tl .maximum (curr_rank , peer_rank )
69- token_first_done = i + 1
62+ token_first_done = i + 1
7063 token_second_done = i + 2
7164 if curr_rank == first_rank :
7265 iris .put (data + offsets , data + offsets , curr_rank , peer_rank , heap_bases , mask = data_mask )
@@ -82,8 +75,9 @@ def ping_pong(
8275 stop = read_realtime ()
8376 tl .atomic_xchg (mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets , stop , time_stmp_mask )
8477
78+
8579if __name__ == "__main__" :
86- dtype = torch .int32
80+ dtype = torch .int32
8781 heap_size = 1 << 32
8882 shmem = iris .iris (heap_size )
8983 num_ranks = shmem .get_num_ranks ()
@@ -96,42 +90,48 @@ def ping_pong(
9690 iter = 200
9791 skip = 1
9892 mm_begin_timestamp = torch .zeros ((num_ranks , BLOCK_SIZE ), dtype = torch .int64 , device = "cuda" )
99- mm_end_timestamp = torch .zeros ((num_ranks , BLOCK_SIZE ), dtype = torch .int64 , device = "cuda" )
93+ mm_end_timestamp = torch .zeros ((num_ranks , BLOCK_SIZE ), dtype = torch .int64 , device = "cuda" )
10094
101- local_latency = torch .zeros ((num_ranks ), dtype = torch .float32 , device = "cuda" )
95+ local_latency = torch .zeros ((num_ranks ), dtype = torch .float32 , device = "cuda" )
10296
10397 source_buffer = shmem .ones (BUFFER_LEN , dtype = dtype )
10498 result_buffer = shmem .zeros_like (source_buffer )
105- flag = shmem .ones (1 , dtype = dtype )
99+ flag = shmem .ones (1 , dtype = dtype )
106100
107101 grid = lambda meta : (1 ,)
108102 for source_rank in range (num_ranks ):
109103 for destination_rank in range (num_ranks ):
110104 if source_rank != destination_rank and cur_rank in [source_rank , destination_rank ]:
111105 peer_for_me = destination_rank if cur_rank == source_rank else source_rank
112- ping_pong [grid ](source_buffer ,
113- BUFFER_LEN ,
114- skip , iter ,
115- flag ,
116- cur_rank , peer_for_me ,
117- BLOCK_SIZE ,
118- heap_bases ,
119- mm_begin_timestamp ,
120- mm_end_timestamp )
106+ ping_pong [grid ](
107+ source_buffer ,
108+ BUFFER_LEN ,
109+ skip ,
110+ iter ,
111+ flag ,
112+ cur_rank ,
113+ peer_for_me ,
114+ BLOCK_SIZE ,
115+ heap_bases ,
116+ mm_begin_timestamp ,
117+ mm_end_timestamp ,
118+ )
121119 shmem .barrier ()
122-
120+
123121 for destination_rank in range (num_ranks ):
124- local_latency [destination_rank ] = (mm_end_timestamp .cpu ()[destination_rank ] - mm_begin_timestamp .cpu ()[destination_rank ]) / iter
125-
122+ local_latency [destination_rank ] = (
123+ mm_end_timestamp .cpu ()[destination_rank ] - mm_begin_timestamp .cpu ()[destination_rank ]
124+ ) / iter
125+
126126 latency_matrix = mpi_allgather (local_latency .cpu ())
127127
128128 if cur_rank == 0 :
129- with open (f "latency.txt" , "w" ) as f :
129+ with open ("latency.txt" , "w" ) as f :
130130 f .write (" ," + ", " .join (f"R{ j } " for j in range (num_ranks )) + "\n " )
131131 for i in range (num_ranks ):
132132 row_entries = []
133133 for j in range (num_ranks ):
134134 val = float (latency_matrix [i , j ])
135135 row_entries .append (f"{ val :0.6f} " )
136136 line = f"R{ i } ," + ", " .join (row_entries ) + "\n "
137- f .write (line )
137+ f .write (line )
0 commit comments