
Commit f0af76d

Replace OS sleep with GPU nanosleep kernel in event timing test (#1285)
* Replace timing-based event test with deterministic elapsed-time check

  The previous test attempted to measure a real sleep delay between two event records, which introduced flakiness (especially on Windows/WDDM) and tested OS/driver timing behavior rather than the __sub__ implementation itself. This change replaces the test with a minimal, deterministic version that:

  * records two back-to-back events on the same stream
  * synchronizes on the second event to ensure both timestamps are valid
  * asserts that cuEventElapsedTime returns a finite, non-negative float

  This exercises the success path of Event.__sub__ without depending on actual GPU/OS timing characteristics or requiring artificial GPU work.

* cuda_core/tests/helpers/__init__.py: also use CUDA_HOME

* Revert "cuda_core/tests/helpers/__init__.py: also use CUDA_HOME"

  This reverts commit 605f1ef.

* Use nanosleep kernel in test_event_elapsed_time_basic for deterministic timing

  Replace the back-to-back event record test with a version that uses a __nanosleep kernel between events. This ensures a guaranteed positive elapsed time (delta_ms > 10) without depending on OS/driver timing characteristics or requiring artificial GPU work beyond the minimal nanosleep delay.

  The kernel sleeps for 20 ms (double the assertion threshold of 10 ms), providing a large safety margin above the ~0.5 microsecond resolution of cudaEventElapsedTime, making this test deterministic and non-flaky across platforms, including Windows/WDDM.

* Fix nanosleep kernel to use clock64() loop for guaranteed duration

  Replace the single __nanosleep() call with a clock64()-based loop to ensure the kernel actually waits for the full 20 ms duration. A single __nanosleep() call does not guarantee the full sleep duration, which caused measured times to be orders of magnitude less than expected (~0.2 ms instead of ~20 ms).

  The new implementation:
  - Uses clock64() to measure actual elapsed time
  - Loops until 20 ms worth of clock cycles have elapsed
  - Uses __nanosleep(1000000) inside the loop to yield and avoid a 100% spin

  This ensures the delta_ms > 10 assertion is reliable and the test passes deterministically.

* clock64() return type is documented as `long long int`: https://docs.nvidia.com/cuda/cuda-c-programming-guide/#time-function

* Use device.arch instead of joining device.compute_capability

* Cursor-generated cuda_core/tests/helpers/nanosleep_kernel.py

* Change NanosleepKernel API to sleep_duration_ms

* Rename back to test_timing_success

* Streamline a comment

* Polish comments. Make the code more similar to the existing code.

* Simplify nanosleep_kernel implementation.
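The cycle-budget arithmetic behind the clock64()-based loop described above can be sketched in pure Python. This is only an illustration of the commit message's reasoning; the 1 GHz SM clock rate in the example is an assumption, not a property of any particular GPU:

```python
def cycles_for_duration(sleep_duration_ms: int, clock_rate_hz: int) -> int:
    """Number of clock64() ticks corresponding to the requested sleep time."""
    # clock64() counts SM clock cycles, so the loop must spin until
    # sleep_duration_ms / 1000 seconds' worth of cycles have elapsed.
    return clock_rate_hz * sleep_duration_ms // 1000

# At an assumed 1 GHz SM clock, a 20 ms wait is 20 million cycles.
print(cycles_for_duration(20, 1_000_000_000))  # 20000000
```

Because the loop checks elapsed cycles rather than trusting __nanosleep()'s nominal argument, the full duration is guaranteed even when individual __nanosleep() calls return early.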
1 parent f52c71a commit f0af76d

File tree

2 files changed: +63 -18 lines changed

cuda_core/tests/helpers/nanosleep_kernel.py

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from cuda.core.experimental import (
    LaunchConfig,
    Program,
    ProgramOptions,
    launch,
)


class NanosleepKernel:
    """
    Manages a kernel that sleeps for a specified duration using __nanosleep().
    """

    def __init__(self, device, sleep_duration_ms: int = 20):
        """
        Initialize the nanosleep kernel.

        Args:
            device: CUDA device to compile the kernel for
            sleep_duration_ms: Duration to sleep in milliseconds (default: 20)
        """
        code = f"""
        extern "C"
        __global__ void nanosleep_kernel() {{
            // The maximum sleep duration of a single __nanosleep() call is
            // approximately 1 millisecond.
            unsigned int one_ms = 1000000U;
            for (unsigned int i = 0; i < {sleep_duration_ms}; ++i) {{
                __nanosleep(one_ms);
            }}
        }}
        """
        program_options = ProgramOptions(std="c++17", arch=f"sm_{device.arch}")
        prog = Program(code, code_type="c++", options=program_options)
        mod = prog.compile("cubin")
        self.kernel = mod.get_kernel("nanosleep_kernel")

    def launch(self, stream):
        """Launch the nanosleep kernel on the given stream."""
        config = LaunchConfig(grid=1, block=1)
        launch(stream, config, self.kernel)
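As a GPU-free sanity check, the f-string expansion used by NanosleepKernel can be mirrored on the host to confirm that sleep_duration_ms becomes the loop bound in the generated CUDA source. make_nanosleep_source is a hypothetical stand-in for the template in NanosleepKernel.__init__, shown only for illustration:

```python
def make_nanosleep_source(sleep_duration_ms: int) -> str:
    """Rebuild the CUDA source string that NanosleepKernel compiles."""
    return f"""
    extern "C"
    __global__ void nanosleep_kernel() {{
        // Each __nanosleep() call sleeps for at most ~1 millisecond.
        unsigned int one_ms = 1000000U;
        for (unsigned int i = 0; i < {sleep_duration_ms}; ++i) {{
            __nanosleep(one_ms);
        }}
    }}
    """

code = make_nanosleep_source(20)
assert "i < 20" in code            # loop bound matches the requested duration
assert "__nanosleep(one_ms)" in code
```

Because only the loop bound is templated, a 20 ms request expands to twenty ~1 ms __nanosleep() iterations, matching the 2x safety margin over the test's 10 ms threshold.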

cuda_core/tests/test_event.py

Lines changed: 20 additions & 18 deletions
@@ -1,8 +1,8 @@
 # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
-import os
-import time
+
+import math
 
 import cuda.core.experimental
 import pytest
@@ -12,8 +12,7 @@
     EventOptions,
 )
 from helpers.latch import LatchKernel
-
-from cuda_python_test_helpers import IS_WSL
+from helpers.nanosleep_kernel import NanosleepKernel
 
 
 def test_event_init_disabled():
@@ -23,25 +22,28 @@ def test_event_init_disabled():
 
 def test_timing_success(init_cuda):
     options = EventOptions(enable_timing=True)
-    stream = Device().create_stream()
-    delay_seconds = 0.5
+    device = Device()
+    stream = device.create_stream()
+
+    # Create a nanosleep kernel that sleeps for 20 ms to ensure a measurable delay.
+    # This guarantees elapsed_time_ms > 10 without depending on OS/driver timing characteristics.
+    nanosleep = NanosleepKernel(device, sleep_duration_ms=20)
+
     e1 = stream.record(options=options)
-    time.sleep(delay_seconds)
+    nanosleep.launch(stream)  # Insert a guaranteed delay
     e2 = stream.record(options=options)
     e2.sync()
     elapsed_time_ms = e2 - e1
     assert isinstance(elapsed_time_ms, float)
-    # Using a generous tolerance, to avoid flaky tests:
-    # We only want to exercise the __sub__ method, this test is not meant
-    # to stress-test the CUDA driver or time.sleep().
-    delay_ms = delay_seconds * 1000
-    if os.name == "nt" or IS_WSL:  # noqa: SIM108
-        # For Python <=3.10, the Windows timer resolution is typically limited to 15.6 ms by default.
-        generous_tolerance = 100
-    else:
-        # Most modern Linux kernels have a default timer resolution of 1 ms.
-        generous_tolerance = 20
-    assert delay_ms - generous_tolerance <= elapsed_time_ms < delay_ms + generous_tolerance
+    # Sanity check: cuEventElapsedTime should always return a finite float for two completed
+    # events. This guards against unexpected driver/HW anomalies (e.g. NaN or inf) or general
+    # undefined behavior, without asserting anything about the magnitude of the measured time.
+    assert math.isfinite(elapsed_time_ms)
+    # With the nanosleep kernel between events, the kernel sleeps for 20 ms using __nanosleep(),
+    # so elapsed_time_ms should definitely be larger than 10 ms. This provides a large safety
+    # margin above the ~0.5 microsecond resolution of cudaEventElapsedTime(), which should
+    # make this test deterministic and non-flaky.
+    assert elapsed_time_ms > 10
 
 
 def test_is_sync_busy_waited(init_cuda):
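The acceptance criteria the rewritten test enforces can be stated as a small pure-Python predicate. is_valid_elapsed_time is a hypothetical helper written only to make the checks explicit; it is not part of the test suite:

```python
import math

def is_valid_elapsed_time(elapsed_time_ms, threshold_ms: float = 10.0) -> bool:
    """Mirror the test's checks: a finite float above the safety threshold."""
    return (
        isinstance(elapsed_time_ms, float)
        and math.isfinite(elapsed_time_ms)
        and elapsed_time_ms > threshold_ms
    )

print(is_valid_elapsed_time(20.3))          # kernel slept ~20 ms -> True
print(is_valid_elapsed_time(float("nan")))  # driver/HW anomaly -> False
print(is_valid_elapsed_time(0.2))           # nanosleep returned early -> False
```

The last case is exactly the failure mode the clock64()/loop fix addressed: an early-returning __nanosleep() produced ~0.2 ms instead of ~20 ms.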
