Commit dd7038d

Apply Ruff auto-fixes

github-actions[bot] and astroC86 authored and committed
1 parent ad03093 · commit dd7038d

File tree: 1 file changed, +39 −39 lines changed


tests/examples/test_load_latency.py

Lines changed: 39 additions & 39 deletions
@@ -9,35 +9,28 @@
 import numpy as np
 import iris
 from iris._mpi_helpers import mpi_allgather
-# from examples.common.utils import read_realtime
-
-
-@triton.jit
-def read_realtime():
-    tmp = tl.inline_asm_elementwise(
-        asm="mov.u64 $0, %globaltimer;",
-        constraints=("=l"),
-        args=[],
-        dtype=tl.int64,
-        is_pure=False,
-        pack=1,
-    )
-    return tmp
+from examples.common.utils import read_realtime
+
+

 @triton.jit()
 def gather_latencies(
-    local_latency,
-    global_latency,
-    curr_rank,
-    num_ranks ,
-    BLOCK_SIZE: tl.constexpr,
-    heap_bases: tl.tensor
+    local_latency, global_latency, curr_rank, num_ranks, BLOCK_SIZE: tl.constexpr, heap_bases: tl.tensor
 ):
     pid = tl.program_id(0)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)

     latency_mask = offsets < num_ranks
-    iris.put(local_latency + offsets, global_latency + curr_rank * num_ranks + offsets, curr_rank, 0, heap_bases, mask=latency_mask)
+    iris.put(
+        local_latency + offsets,
+        global_latency + curr_rank * num_ranks + offsets,
+        curr_rank,
+        0,
+        heap_bases,
+        mask=latency_mask,
+    )
+

 @triton.jit()
 def ping_pong(
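
Aside: the deleted inline timer is presumably identical to the helper now imported from examples.common.utils. As a reference, here is a minimal sketch of such a helper, reconstructed from the removed lines (it reads the GPU's 64-bit %globaltimer register through inline PTX, so it is NVIDIA-targeted as written):

import triton
import triton.language as tl


@triton.jit
def read_realtime():
    # Read the nanosecond-resolution global timer via inline PTX.
    # is_pure=False keeps the compiler from caching or reordering the read;
    # pack=1 emits one asm instance per element.
    tmp = tl.inline_asm_elementwise(
        asm="mov.u64 $0, %globaltimer;",
        constraints=("=l"),
        args=[],
        dtype=tl.int64,
        is_pure=False,
        pack=1,
    )
    return tmp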
@@ -66,7 +59,7 @@ def ping_pong(
         start = read_realtime()
         tl.atomic_xchg(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
         first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
-        token_first_done = i + 1
+        token_first_done = i + 1
         token_second_done = i + 2
         if curr_rank == first_rank:
             iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask)
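
The hunk above is the heart of the handshake: on even iterations the lower-numbered rank initiates, on odd iterations the higher-numbered one, and the shared flag advances through two tokens per iteration (i + 1 once the initiator's leg lands, i + 2 once the reply does). A host-side Python sketch of just that scheduling logic, with hypothetical names, to make the alternation explicit:

def ping_pong_schedule(curr_rank, peer_rank, iterations):
    # Mirrors the first_rank selection in ping_pong: even iteration ->
    # lower rank initiates, odd iteration -> higher rank initiates.
    for i in range(iterations):
        first = min(curr_rank, peer_rank) if i % 2 == 0 else max(curr_rank, peer_rank)
        token_first_done = i + 1   # flag value after the initiator's put
        token_second_done = i + 2  # flag value after the responder's reply
        role = "initiator" if curr_rank == first else "responder"
        yield i, role, token_first_done, token_second_done

Alternating the initiator keeps both directions of the link exercised evenly across iterations.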
@@ -82,8 +75,9 @@ def ping_pong(
         stop = read_realtime()
         tl.atomic_xchg(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)
 
+
 if __name__ == "__main__":
-    dtype = torch.int32
+    dtype = torch.int32
     heap_size = 1 << 32
     shmem = iris.iris(heap_size)
     num_ranks = shmem.get_num_ranks()
@@ -96,42 +90,48 @@ def ping_pong(
     iter = 200
     skip = 1
     mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
-    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
+    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
 
-    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
+    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
 
     source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
     result_buffer = shmem.zeros_like(source_buffer)
-    flag = shmem.ones(1, dtype=dtype)
+    flag = shmem.ones(1, dtype=dtype)
 
     grid = lambda meta: (1,)
     for source_rank in range(num_ranks):
         for destination_rank in range(num_ranks):
             if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
                 peer_for_me = destination_rank if cur_rank == source_rank else source_rank
-                ping_pong[grid](source_buffer,
-                    BUFFER_LEN,
-                    skip, iter,
-                    flag,
-                    cur_rank, peer_for_me,
-                    BLOCK_SIZE,
-                    heap_bases,
-                    mm_begin_timestamp,
-                    mm_end_timestamp)
+                ping_pong[grid](
+                    source_buffer,
+                    BUFFER_LEN,
+                    skip,
+                    iter,
+                    flag,
+                    cur_rank,
+                    peer_for_me,
+                    BLOCK_SIZE,
+                    heap_bases,
+                    mm_begin_timestamp,
+                    mm_end_timestamp,
+                )
            shmem.barrier()
-
+
     for destination_rank in range(num_ranks):
-        local_latency[destination_rank] = (mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]) / iter
-
+        local_latency[destination_rank] = (
+            mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]
+        ) / iter
+
     latency_matrix = mpi_allgather(local_latency.cpu())
 
     if cur_rank == 0:
-        with open(f"latency.txt", "w") as f:
+        with open("latency.txt", "w") as f:
             f.write(" ," + ", ".join(f"R{j}" for j in range(num_ranks)) + "\n")
             for i in range(num_ranks):
                 row_entries = []
                 for j in range(num_ranks):
                     val = float(latency_matrix[i, j])
                     row_entries.append(f"{val:0.6f}")
                 line = f"R{i}," + ", ".join(row_entries) + "\n"
-                f.write(line)
+                f.write(line)
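
For reference, the writer above produces a comma-separated matrix: a header row of rank labels, then one row per source rank with six-decimal latencies. An illustrative latency.txt for num_ranks = 2 (values are placeholders, not measurements):

 ,R0, R1
R0,0.000000, 1.234567
R1,1.234567, 0.000000

Only rank 0 writes the file, after mpi_allgather has assembled every rank's local row into latency_matrix; untested pairs (the diagonal) stay at zero.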
