Skip to content

Commit 5fabfef

Browse files
committed
inital latency test
1 parent 7564882 commit 5fabfef

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
#!/usr/bin/env python3
2+
# SPDX-License-Identifier: MIT
3+
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
4+
5+
import pytest
6+
import torch
7+
import triton
8+
import triton.language as tl
9+
import numpy as np
10+
import iris
11+
from examples.common.utils import read_realtime
12+
13+
14+
@triton.jit()
15+
def ping_pong(
16+
data,
17+
result,
18+
len,
19+
iter,
20+
skip,
21+
flag: tl.tensor,
22+
curr_rank,
23+
BLOCK_SIZE: tl.constexpr,
24+
heap_bases: tl.tensor,
25+
mm_begin_timestamp_ptr: tl.tensor = None,
26+
mm_end_timestamp_ptr: tl.tensor = None,
27+
):
28+
peer = (curr_rank + 1) % 2
29+
pid = tl.program_id(0)
30+
block_start = pid * BLOCK_SIZE
31+
offsets = block_start + tl.arange(0, BLOCK_SIZE)
32+
33+
data_mask = offsets < len
34+
flag_mask = offsets < 1
35+
time_stmp_mask = offsets < 1
36+
37+
for i in range(iter + skip):
38+
if (i == skip):
39+
start = read_realtime();
40+
tl.atomic_xchg(mm_begin_timestamp_ptr + offsets, start, time_stmp_mask)
41+
if curr_rank == (i + 1) % 2:
42+
while tl.load(flag, cache_modifier=".cv", volatile=True) != i + 1:
43+
pass
44+
iris.put(data + offsets, result + offsets, curr_rank, peer, heap_bases, mask=data_mask)
45+
tl.store(flag + offsets, i + 1, mask=flag_mask)
46+
iris.put(flag + offsets, flag + offsets, curr_rank, peer, heap_bases, flag_mask)
47+
else:
48+
iris.put(data + offsets, result + offsets, curr_rank, peer, heap_bases, mask=data_mask)
49+
tl.store(flag + offsets, i + 1, mask=flag_mask)
50+
iris.put(flag + offsets, flag + offsets, curr_rank, peer, heap_bases, flag_mask)
51+
while tl.load(flag, cache_modifier=".cv", volatile=True) != i + 1:
52+
pass
53+
stop = read_realtime();
54+
tl.atomic_xchg(mm_end_timestamp_ptr + offsets, stop, time_stmp_mask)
55+
56+
57+
@pytest.mark.parametrize(
58+
"dtype",
59+
[
60+
torch.int32,
61+
# torch.float16,
62+
# torch.bfloat16,
63+
# torch.float32,
64+
],
65+
)
66+
@pytest.mark.parametrize(
67+
"heap_size",
68+
[
69+
(1 << 33),
70+
],
71+
)
72+
def test_load_bench(dtype, heap_size):
73+
shmem = iris.iris(heap_size)
74+
num_ranks = shmem.get_num_ranks()
75+
heap_bases = shmem.get_heap_bases()
76+
cur_rank = shmem.get_rank()
77+
assert num_ranks == 2
78+
79+
BLOCK_SIZE = 1
80+
BUFFER_LEN = 64*1024
81+
82+
iter = 200
83+
skip = 20
84+
mm_begin_timestamp = torch.zeros(BLOCK_SIZE, dtype=torch.int64, device="cuda")
85+
mm_end_timestamp = torch.zeros(BLOCK_SIZE, dtype=torch.int64, device="cuda")
86+
87+
source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
88+
result_buffer = shmem.zeros_like(source_buffer)
89+
flag = shmem.ones(1, dtype=dtype)
90+
91+
grid = lambda meta: (1,)
92+
ping_pong[grid](source_buffer, result_buffer, BUFFER_LEN, skip, iter, flag, cur_rank, BLOCK_SIZE, heap_bases,mm_begin_timestamp, mm_end_timestamp)
93+
shmem.barrier()
94+
begin_val = mm_begin_timestamp.cpu().item()
95+
end_val = mm_end_timestamp.cpu().item()
96+
with open(f'timestamps_{cur_rank}.txt', 'w') as f:
97+
f.write(f"mm_begin_timestamp: {begin_val}\n")
98+
f.write(f"mm_end_timestamp: {end_val}\n")

0 commit comments

Comments
 (0)