Skip to content

Commit de64a1b

Browse files
authored
[STF] Add graph_ctx constructors for explicit graph with user stream (#8094)
The impl(cudaGraph_t) constructor was missing base class initialization with the async_resources_handle. Fix this and add a new constructor taking (cudaGraph_t, cudaStream_t, async_resources_handle) that launches the graph on a user-provided stream with non-blocking finalize. Also add explicit_graph_async test exercising the new constructor.
1 parent 5faed19 commit de64a1b

File tree

3 files changed

+81
-4
lines changed

3 files changed

+81
-4
lines changed

cudax/include/cuda/experimental/__stf/graph/graph_ctx.cuh

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,13 +187,27 @@ class graph_ctx : public backend_ctx<graph_ctx>
187187
}
188188

189189
// Constructor for a context implementation built on a graph supplied by the caller.
//  - g: the user's graph; it is wrapped with wrap_cuda_graph and stored in _graph.
//  - _async_resources: moved into the backend_ctx<graph_ctx>::impl base; when the
//    argument is omitted, a handle constructed from nullptr is passed instead.
// Sets explicit_graph to true and sets up the allocators with uncached_graph_allocator.
// Note that graph contexts with an explicit graph passed by the user cannot use stages
190-
impl(cudaGraph_t g)
191-
: _graph(wrap_cuda_graph(g))
190+
impl(cudaGraph_t g, async_resources_handle _async_resources = async_resources_handle(nullptr))
191+
: backend_ctx<graph_ctx>::impl(mv(_async_resources))
192+
, _graph(wrap_cuda_graph(g))
192193
, explicit_graph(true)
193194
{
194195
reserved::backend_ctx_setup_allocators<impl, uncached_graph_allocator>(*this);
195196
}
196197

198+
// Constructor with explicit graph, user stream, and async resources.
//  - g: the user's graph; it is wrapped with wrap_cuda_graph and stored in _graph.
//  - user_stream: stored in submitted_stream.
//  - _async_resources: moved into the backend_ctx<graph_ctx>::impl base; when the
//    argument is omitted, a handle constructed from nullptr is passed instead.
// Compared with the (cudaGraph_t, async_resources_handle) constructor, this one
// additionally initializes submitted_stream and sets blocking_finalize to false.
// NOTE(review): the initializer list is written as submitted_stream, _graph,
// explicit_graph, blocking_finalize; confirm this matches the member declaration order.
199+
impl(cudaGraph_t g,
200+
cudaStream_t user_stream,
201+
async_resources_handle _async_resources = async_resources_handle(nullptr))
202+
: backend_ctx<graph_ctx>::impl(mv(_async_resources))
203+
, submitted_stream(user_stream)
204+
, _graph(wrap_cuda_graph(g))
205+
, explicit_graph(true)
206+
, blocking_finalize(false)
207+
{
208+
reserved::backend_ctx_setup_allocators<impl, uncached_graph_allocator>(*this);
209+
}
210+
197211
~impl() override {}
198212

199213
::std::string to_string() const override
@@ -285,8 +299,13 @@ public:
285299
}
286300

287301
/// @brief Constructor taking a user-provided graph. User code is not supposed to destroy the graph later.
///
/// @param g User-provided graph, forwarded to the implementation constructor
/// @param handle Async resources handle moved into the implementation; when omitted,
///        a handle constructed from nullptr is used
288-
graph_ctx(cudaGraph_t g)
289-
: backend_ctx<graph_ctx>(::std::make_shared<impl>(g))
302+
graph_ctx(cudaGraph_t g, async_resources_handle handle = async_resources_handle(nullptr))
303+
: backend_ctx<graph_ctx>(::std::make_shared<impl>(g, mv(handle)))
304+
{}
305+
306+
/// @brief Constructor taking a user-provided graph, a user-provided stream, and async resources
///
/// The implementation constructor selected here stores the stream as its submitted
/// stream and turns off blocking finalize.
///
/// @param g User-provided graph, forwarded to the implementation constructor
/// @param user_stream User-provided stream, forwarded to the implementation constructor
/// @param handle Async resources handle moved into the implementation; when omitted,
///        a handle constructed from nullptr is used
307+
graph_ctx(cudaGraph_t g, cudaStream_t user_stream, async_resources_handle handle = async_resources_handle(nullptr))
308+
: backend_ctx<graph_ctx>(::std::make_shared<impl>(g, user_stream, mv(handle)))
290309
{}
291310
///@}
292311

cudax/test/stf/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ set(
3333
gnu/include_only.cpp
3434
graph/concurrency_test.cu
3535
graph/explicit_graph.cu
36+
graph/explicit_graph_async.cu
3637
graph/explicit_graph_while.cu
3738
graph/explicit_graph_while-kernels.cu
3839
graph/get_cache_stats.cu
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of CUDASTF in CUDA C++ Core Libraries,
4+
// under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
// SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
//! @file
12+
//! @brief Add tasks to a user-provided graph and launch on a user-provided stream.
13+
//! Exercises graph_ctx(cudaGraph_t, cudaStream_t): finalize() submits the
14+
//! graph on the given stream and does not block; the caller synchronizes.
15+
16+
#include <cuda/experimental/__stf/graph/graph_ctx.cuh>
17+
18+
using namespace cuda::experimental::stf;
19+
20+
// Empty kernel: performs no work. It is launched with a <<<1, 1>>> configuration
// from every task body in main(), so that each task has a kernel to enqueue.
__global__ void dummy() {}
21+
22+
// Test entry point: creates a stream and an empty graph, builds a graph_ctx on top
// of both, adds four tasks ordered through three tokens, finalizes the context, and
// then synchronizes and destroys the stream.
int main()
23+
{
24+
cudaGraph_t graph;
25+
cudaStream_t stream;
26+
27+
// The stream is created with the cudaStreamNonBlocking flag.
cuda_safe_call(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
28+
cuda_safe_call(cudaGraphCreate(&graph, 0));
29+
30+
// Uses the (cudaGraph_t, cudaStream_t) constructor; the async resources handle is defaulted.
graph_ctx ctx(graph, stream);
31+
32+
// Three tokens used only to express dependencies between the tasks below.
auto lX = ctx.token();
33+
auto lY = ctx.token();
34+
auto lZ = ctx.token();
35+
36+
// Task 1: writes lX.
ctx.task(lX.write())->*[](cudaStream_t s) {
37+
dummy<<<1, 1, 0, s>>>();
38+
};
39+
40+
// Task 2: reads lX, writes lY.
ctx.task(lX.read(), lY.write())->*[](cudaStream_t s) {
41+
dummy<<<1, 1, 0, s>>>();
42+
};
43+
44+
// Task 3: reads lX, writes lZ.
ctx.task(lX.read(), lZ.write())->*[](cudaStream_t s) {
45+
dummy<<<1, 1, 0, s>>>();
46+
};
47+
48+
// Task 4: accesses both lY and lZ in rw mode.
ctx.task(lY.rw(), lZ.rw())->*[](cudaStream_t s) {
49+
dummy<<<1, 1, 0, s>>>();
50+
};
51+
52+
// Non-blocking: submits the graph on the user-provided stream
53+
ctx.finalize();
54+
55+
// finalize() did not wait, so the caller synchronizes the stream before destroying it.
// NOTE(review): there is no cudaGraphDestroy(graph) here; the graph_ctx constructor
// documentation says user code is not supposed to destroy the graph - confirm ownership.
cuda_safe_call(cudaStreamSynchronize(stream));
56+
cuda_safe_call(cudaStreamDestroy(stream));
57+
}

0 commit comments

Comments
 (0)