Skip to content

Commit 0599ec3

Browse files
authored
[STF] Support dynamically added dependencies in host_launch (#8174)
* [STF] Use opaque struct pointers for C API handles Replace `typedef void*` with `typedef struct <name>_t* <name>` for all opaque handle types in the STF C API. This provides type safety at the C level, preventing accidental mixing of different handle types. The naming convention (e.g. stf_ctx_handle_t / stf_ctx_handle) matches the Cython declarations in the stf_c_api branch for consistency. Made-with: Cursor * Add add_deps(task_dep_untyped) to host_launch_scope Allow C/Python bindings to build dependencies incrementally without requiring compile-time type information. Cherry-picked C++ parts from afde3ce. Made-with: Cursor * Add untyped host_launch with host_launch_deps handle Introduce host_launch_deps, an opaque handle that provides indexed access to dependency data and optional user data inside host callbacks. Supports get<T>(index), size(), user_data() with proper lifetime management via a destructor callback. New host_launch_deps class with if-constexpr untyped dispatch path in operator->*, set_user_data() on host_launch_scope, graph-safe resource management. Tests cover stream_ctx and graph_ctx (12 tests). Cherry-picked C++ parts from 9db3b86. Made-with: Cursor * Fix build errors in host_launch untyped path - Add set_user_data forwarding to unified_scope in context.cuh - Use explicit host_launch_deps{} instead of {} for nvcc pair deduction Cherry-picked C++ parts from 6f6bf42. Made-with: Cursor * Fix untyped dispatch for generic lambdas in host_launch Use a private canary type to detect generic lambdas: if Fun also accepts the canary, it is generic and should use the typed path. Cherry-picked from af970ac. Made-with: Cursor * Fix untyped dispatch: use std::conjunction for nvcc compatibility nvcc eagerly instantiates generic-lambda bodies during is_invocable_v checks, causing hard errors when the lambda body uses members that don't exist on host_launch_deps (e.g. data_handle()). 
Use std::conjunction to short-circuit: is_invocable<Fun, host_launch_deps&> is only instantiated when sizeof...(Deps) == 0. Cherry-picked from 5356c1f. Made-with: Cursor * clang-format * Fix incomplete type error for logical_data_untyped in host_launch_scope host_launch_deps holds a std::vector<logical_data_untyped>, but logical_data_untyped was only forward-declared through transitive includes. GCC 9's new_allocator::deallocate applies alignof to the element type, which requires a complete type. Include logical_data.cuh to provide the full definition. Made-with: Cursor * Add comments to user_data and host_launch_deps APIs Made-with: Cursor * Add host_launch C bindings for stream and graph contexts Expose the untyped host_launch API through the C interface, allowing C and Python callers to schedule host callbacks as task graph nodes with full dependency tracking and optional user data. Made-with: Cursor * Use opaque struct pointers for host_launch handles Apply the same opaque struct pointer pattern to stf_host_launch_handle and stf_host_launch_deps_handle, and update the corresponding casts in the implementation. Made-with: Cursor * Use a less ambiguous variable name for clarity * Minor comment improvement
1 parent 35968c2 commit 0599ec3

File tree

7 files changed

+976
-44
lines changed

7 files changed

+976
-44
lines changed

c/experimental/stf/include/cccl/c/experimental/stf/stf.h

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,30 @@ typedef struct stf_task_handle_t* stf_task_handle;
440440

441441
typedef struct stf_cuda_kernel_handle_t* stf_cuda_kernel_handle;
442442

443+
//!
444+
//! \brief Opaque handle for a host launch scope
445+
//!
446+
//! A host launch scope schedules a user-provided C callback on the host
447+
//! as a proper task graph node, with full dependency tracking.
448+
//! Created with stf_host_launch_create() and destroyed with stf_host_launch_destroy().
449+
450+
typedef struct stf_host_launch_handle_t* stf_host_launch_handle;
451+
452+
//!
453+
//! \brief Opaque handle for host launch dependency data
454+
//!
455+
//! Passed to the host callback at invocation time. Provides indexed
456+
//! access to the data of each dependency and to optional user data.
457+
458+
typedef struct stf_host_launch_deps_handle_t* stf_host_launch_deps_handle;
459+
460+
//!
461+
//! \brief C callback type for host launch
462+
//!
463+
//! \param deps Opaque handle to dependency data and user data
464+
465+
typedef void (*stf_host_callback_fn)(stf_host_launch_deps_handle deps);
466+
443467
//! \}
444468

445469
//! \defgroup Context Context Management
@@ -1277,6 +1301,105 @@ void stf_cuda_kernel_destroy(stf_cuda_kernel_handle k);
12771301

12781302
//! \}
12791303

1304+
//! \defgroup HostLaunch Host Launch
1305+
//! \brief Schedule a host callback as a task graph node with dependency tracking
1306+
//!
1307+
//! \details
1308+
//! Host launch provides a way to run arbitrary host-side functions as part of
1309+
//! the task graph. Unlike generic tasks where the user manually launches work
1310+
//! on a stream, host launch automatically schedules a C callback via
1311+
//! `cudaLaunchHostFunc` (stream context) or `cudaGraphAddHostNode` (graph context).
1312+
//!
1313+
//! This is the untyped counterpart of the C++ `ctx.host_launch(deps...)->*lambda`
1314+
//! construct, designed for use from C and Python bindings.
1315+
//! \{
1316+
1317+
//! \brief Create a host launch scope on a regular context
1318+
//!
1319+
//! \param ctx Context handle
1320+
//! \param[out] h Pointer to receive host launch handle
1321+
//!
1322+
//! \see stf_host_launch_destroy()
1323+
void stf_host_launch_create(stf_ctx_handle ctx, stf_host_launch_handle* h);
1324+
1325+
//! \brief Add a dependency to a host launch scope
1326+
//!
1327+
//! \param h Host launch handle
1328+
//! \param ld Logical data handle
1329+
//! \param m Access mode (STF_READ, STF_WRITE, STF_RW)
1330+
//!
1331+
//! \see stf_task_add_dep()
1332+
void stf_host_launch_add_dep(stf_host_launch_handle h, stf_logical_data_handle ld, stf_access_mode m);
1333+
1334+
//! \brief Set the debug symbol for a host launch scope
1335+
//!
1336+
//! \param h Host launch handle
1337+
//! \param symbol Null-terminated string
1338+
void stf_host_launch_set_symbol(stf_host_launch_handle h, const char* symbol);
1339+
1340+
//! \brief Copy user data into the host launch scope
1341+
//!
1342+
//! The data is copied and later accessible via
1343+
//! stf_host_launch_deps_get_user_data() inside the callback.
1344+
//! An optional destructor is called on the copied buffer when the
1345+
//! dependency handle is destroyed.
1346+
//!
1347+
//! \param h Host launch handle
1348+
//! \param data Pointer to user data
1349+
//! \param size Size of user data in bytes
1350+
//! \param dtor Optional destructor for the copied data (may be NULL)
1351+
void stf_host_launch_set_user_data(stf_host_launch_handle h, const void* data, size_t size, void (*dtor)(void*));
1352+
1353+
//! \brief Submit the host callback and finalize the scope
1354+
//!
1355+
//! After this call, the callback will be invoked on the host when all
1356+
//! read/write dependencies are satisfied. The callback receives an
1357+
//! opaque deps handle for accessing dependency data and user data.
1358+
//!
1359+
//! \param h Host launch handle
1360+
//! \param callback Function pointer invoked on the host
1361+
//!
1362+
//! \see stf_host_launch_create()
1363+
void stf_host_launch_submit(stf_host_launch_handle h, stf_host_callback_fn callback);
1364+
1365+
//! \brief Destroy a host launch handle
1366+
//!
1367+
//! \param h Host launch handle
1368+
//!
1369+
//! \see stf_host_launch_create()
1370+
void stf_host_launch_destroy(stf_host_launch_handle h);
1371+
1372+
//! \brief Get the raw data pointer for a dependency
1373+
//!
1374+
//! Returns the host-side pointer to the data of the dependency at \p index.
1375+
//! The pointer is valid only during the callback execution.
1376+
//!
1377+
//! \param deps Dependency handle
1378+
//! \param index Zero-based dependency index (must be less than stf_host_launch_deps_size())
1379+
//! \return Pointer to the data (as `slice<char>` data handle)
1380+
void* stf_host_launch_deps_get(stf_host_launch_deps_handle deps, size_t index);
1381+
1382+
//! \brief Get the byte size of a dependency
1383+
//!
1384+
//! \param deps Dependency handle
1385+
//! \param index Zero-based dependency index (must be less than stf_host_launch_deps_size())
1386+
//! \return Size in bytes
1387+
size_t stf_host_launch_deps_get_size(stf_host_launch_deps_handle deps, size_t index);
1388+
1389+
//! \brief Get the number of dependencies
1390+
//!
1391+
//! \param deps Dependency handle
1392+
//! \return Number of dependencies
1393+
size_t stf_host_launch_deps_size(stf_host_launch_deps_handle deps);
1394+
1395+
//! \brief Get the user data pointer
1396+
//!
1397+
//! \param deps Dependency handle
1398+
//! \return Pointer to the copied user data, or NULL if none was set
1399+
void* stf_host_launch_deps_get_user_data(stf_host_launch_deps_handle deps);
1400+
1401+
//! \}
1402+
12801403
#ifdef __cplusplus
12811404
}
12821405
#endif

c/experimental/stf/src/stf.cu

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,99 @@ void stf_cuda_kernel_destroy(stf_cuda_kernel_handle t)
411411
delete kernel_ptr;
412412
}
413413

414+
// -----------------------------------------------------------------------------
415+
// Host launch
416+
// -----------------------------------------------------------------------------
417+
418+
// Concrete C++ type produced by calling context::host_launch() with no
// arguments. The opaque C handle stf_host_launch_handle is a type-erased
// pointer to a heap-allocated object of this type.
using host_launch_type = decltype(::std::declval<context>().host_launch());
419+
420+
// Create a host launch scope on the context behind `ctx` and return it
// through `h` as an opaque handle.
//
// The scope is heap-allocated here; stf_host_launch_destroy() deletes it.
void stf_host_launch_create(stf_ctx_handle ctx, stf_host_launch_handle* h)
{
  _CCCL_ASSERT(ctx != nullptr, "context handle must not be null");
  _CCCL_ASSERT(h != nullptr, "host launch handle output pointer must not be null");

  // The C handle is a type-erased pointer to the C++ context object.
  context& stf_ctx = *reinterpret_cast<context*>(ctx);

  auto* new_scope = new host_launch_type{stf_ctx.host_launch()};
  *h = reinterpret_cast<stf_host_launch_handle>(new_scope);
}
428+
429+
// Append one dependency to a host launch scope: the logical data behind `ld`,
// accessed with mode `m`.
void stf_host_launch_add_dep(stf_host_launch_handle h, stf_logical_data_handle ld, stf_access_mode m)
{
  _CCCL_ASSERT(h != nullptr, "host launch handle must not be null");
  _CCCL_ASSERT(ld != nullptr, "logical data handle must not be null");

  // Both handles are type-erased pointers to the underlying C++ objects.
  auto& launch_scope = *reinterpret_cast<host_launch_type*>(h);
  auto& ld_ref       = *reinterpret_cast<logical_data_untyped*>(ld);

  launch_scope.add_deps(task_dep_untyped(ld_ref, access_mode(m)));
}
438+
439+
// Forward a debug symbol (null-terminated string) to the host launch scope.
void stf_host_launch_set_symbol(stf_host_launch_handle h, const char* symbol)
{
  _CCCL_ASSERT(h != nullptr, "host launch handle must not be null");
  _CCCL_ASSERT(symbol != nullptr, "symbol must not be null");

  auto& launch_scope = *reinterpret_cast<host_launch_type*>(h);
  launch_scope.set_symbol(symbol);
}
447+
448+
// Attach user data to a host launch scope.
//
// `data`/`size` describe the buffer handed to the scope's set_user_data();
// `dtor` is an optional destructor forwarded alongside it (may be null).
// A null `data` is only accepted together with a zero `size`: there is nothing
// valid to read from a null pointer.
void stf_host_launch_set_user_data(stf_host_launch_handle h, const void* data, size_t size, void (*dtor)(void*))
{
  _CCCL_ASSERT(h != nullptr, "host launch handle must not be null");
  _CCCL_ASSERT(data != nullptr || size == 0, "user data must not be null when size is non-zero");

  auto* scope_ptr = reinterpret_cast<host_launch_type*>(h);
  scope_ptr->set_user_data(data, size, dtor);
}
455+
456+
// Submit the C callback on the host launch scope.
//
// The scope's operator->* receives a small lambda that takes the C++
// host_launch_deps object by reference and passes its address to `callback`
// as an opaque stf_host_launch_deps_handle.
void stf_host_launch_submit(stf_host_launch_handle h, stf_host_callback_fn callback)
{
  _CCCL_ASSERT(h != nullptr, "host launch handle must not be null");
  _CCCL_ASSERT(callback != nullptr, "callback must not be null");

  auto& launch_scope = *reinterpret_cast<host_launch_type*>(h);
  launch_scope->*[callback](reserved::host_launch_deps& launch_deps) {
    callback(reinterpret_cast<stf_host_launch_deps_handle>(&launch_deps));
  };
}
466+
467+
// Delete the scope object that stf_host_launch_create() allocated for `h`.
void stf_host_launch_destroy(stf_host_launch_handle h)
{
  _CCCL_ASSERT(h != nullptr, "host launch handle must not be null");

  delete reinterpret_cast<host_launch_type*>(h);
}
474+
475+
// Return the data pointer of the dependency at `index`.
//
// The dependency is viewed as a slice<char> and its data handle is returned.
// `index` must be smaller than the number of dependencies reported by
// stf_host_launch_deps_size(); this is asserted at the C boundary so a bad
// index is reported here rather than inside the C++ accessor.
void* stf_host_launch_deps_get(stf_host_launch_deps_handle deps, size_t index)
{
  _CCCL_ASSERT(deps != nullptr, "deps handle must not be null");

  auto* d = reinterpret_cast<reserved::host_launch_deps*>(deps);
  _CCCL_ASSERT(index < d->size(), "dependency index out of range");

  return d->get<slice<char>>(index).data_handle();
}
482+
483+
// Return the size of the dependency at `index`.
//
// The dependency is viewed as a slice<char>, so its extent along dimension 0
// is its size in bytes. `index` must be smaller than the number of
// dependencies reported by stf_host_launch_deps_size(); this is asserted at
// the C boundary so a bad index is reported here rather than inside the C++
// accessor.
size_t stf_host_launch_deps_get_size(stf_host_launch_deps_handle deps, size_t index)
{
  _CCCL_ASSERT(deps != nullptr, "deps handle must not be null");

  auto* d = reinterpret_cast<reserved::host_launch_deps*>(deps);
  _CCCL_ASSERT(index < d->size(), "dependency index out of range");

  return d->get<slice<char>>(index).extent(0);
}
490+
491+
// Return the number of dependencies held by the deps object behind `deps`.
size_t stf_host_launch_deps_size(stf_host_launch_deps_handle deps)
{
  _CCCL_ASSERT(deps != nullptr, "deps handle must not be null");

  auto& launch_deps = *reinterpret_cast<reserved::host_launch_deps*>(deps);
  return launch_deps.size();
}
498+
499+
// Return the user data pointer stored in the deps object behind `deps`.
void* stf_host_launch_deps_get_user_data(stf_host_launch_deps_handle deps)
{
  _CCCL_ASSERT(deps != nullptr, "deps handle must not be null");

  auto& launch_deps = *reinterpret_cast<reserved::host_launch_deps*>(deps);
  return launch_deps.user_data();
}
506+
414507
// -----------------------------------------------------------------------------
415508
// Composite data place and execution place grid (for Python/cuTile multi-stream)
416509
// -----------------------------------------------------------------------------
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of CUDA Experimental in CUDA C++ Core Libraries,
4+
// under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
#include <cuda_runtime.h>
12+
13+
#include <c2h/catch2_test_helper.h>
14+
#include <cccl/c/experimental/stf/stf.h>
15+
16+
// Write data[i] = value + i for every i in [0, cnt).
//
// Each thread starts at its global 1D index and advances by the total number
// of launched threads, so any 1D launch configuration covers the whole range.
__global__ void fill_kernel(int cnt, double* data, double value)
{
  const int stride = gridDim.x * blockDim.x;
  int idx          = blockIdx.x * blockDim.x + threadIdx.x;

  while (idx < cnt)
  {
    data[idx] = value + idx;
    idx += stride;
  }
}
26+
27+
// User data passed to verify_callback through stf_host_launch_set_user_data().
struct verify_args
28+
{
29+
// Number of double elements verify_callback expects in dependency 0.
size_t N;
30+
// Location where verify_callback stores the verification result.
bool* passed;
31+
};
32+
33+
// Host callback used by the tests below.
//
// Reads a verify_args from the user data and checks that there is exactly one
// dependency, that it holds N doubles, and that element i equals 42.0 + i
// (the pattern the tests fill with fill_kernel). The outcome is stored in
// *verify_args::passed.
static void verify_callback(stf_host_launch_deps_handle deps)
{
  auto* v = static_cast<verify_args*>(stf_host_launch_deps_get_user_data(deps));

  // The user data pointer is NULL when none was set; without it there is
  // nowhere to report a result, so bail out instead of dereferencing it.
  if (v == nullptr || v->passed == nullptr)
  {
    return;
  }

  // Assume failure until every check below has succeeded.
  *v->passed = false;

  if (stf_host_launch_deps_size(deps) != 1)
  {
    return;
  }

  if (stf_host_launch_deps_get_size(deps, 0) != v->N * sizeof(double))
  {
    return;
  }

  auto* data = static_cast<double*>(stf_host_launch_deps_get(deps, 0));
  if (data == nullptr)
  {
    return;
  }

  for (size_t i = 0; i < v->N; i++)
  {
    if (fabs(data[i] - (42.0 + i)) > 1e-10)
    {
      return;
    }
  }

  *v->passed = true;
}
60+
61+
// Stream-context test: fill a logical data with a kernel task, then verify the
// contents on the host through a host launch scope.
C2H_TEST("host_launch with stream context", "[host_launch]")
{
  const size_t N = 1024;

  stf_ctx_handle ctx;
  stf_ctx_create(&ctx);

  // Pinned host buffer backing the logical data; fail early if it cannot be
  // allocated instead of writing through an invalid pointer.
  double* host_data = nullptr;
  REQUIRE(cudaMallocHost(&host_data, N * sizeof(double)) == cudaSuccess);
  for (size_t i = 0; i < N; i++)
  {
    host_data[i] = 0.0;
  }

  stf_logical_data_handle lData;
  stf_logical_data(ctx, &lData, host_data, N * sizeof(double));
  stf_logical_data_set_symbol(lData, "data");

  // Fill data via a kernel task
  stf_task_handle t;
  stf_task_create(ctx, &t);
  stf_task_set_symbol(t, "fill");
  stf_task_add_dep(t, lData, STF_WRITE);
  stf_task_start(t);
  double* dData = (double*) stf_task_get(t, 0);
  fill_kernel<<<2, 128, 0, (cudaStream_t) stf_task_get_custream(t)>>>((int) N, dData, 42.0);
  // Record the launch status now but assert on it only once the task is
  // closed, so a failing REQUIRE does not unwind out of an open task.
  const cudaError_t launch_status = cudaGetLastError();
  stf_task_end(t);
  stf_task_destroy(t);
  REQUIRE(launch_status == cudaSuccess);

  // Use host_launch to verify data on the host
  bool passed = false;
  verify_args vargs{N, &passed};

  stf_host_launch_handle h;
  stf_host_launch_create(ctx, &h);
  stf_host_launch_set_symbol(h, "verify");
  stf_host_launch_add_dep(h, lData, STF_READ);
  stf_host_launch_set_user_data(h, &vargs, sizeof(vargs), nullptr);
  stf_host_launch_submit(h, verify_callback);
  stf_host_launch_destroy(h);

  stf_logical_data_destroy(lData);
  stf_ctx_finalize(ctx);

  // Free the pinned buffer before asserting so that a failing REQUIRE does
  // not leak it.
  const cudaError_t free_status = cudaFreeHost(host_data);

  REQUIRE(passed);
  REQUIRE(free_status == cudaSuccess);
}
109+
110+
// Graph-context test: fill a logical data with a captured kernel task, then
// verify the contents on the host through a host launch scope.
C2H_TEST("host_launch with graph context", "[host_launch]")
{
  const size_t N = 1024;

  stf_ctx_handle ctx;
  stf_ctx_create_graph(&ctx);

  // Pinned host buffer backing the logical data; fail early if it cannot be
  // allocated instead of writing through an invalid pointer.
  double* host_data = nullptr;
  REQUIRE(cudaMallocHost(&host_data, N * sizeof(double)) == cudaSuccess);
  for (size_t i = 0; i < N; i++)
  {
    host_data[i] = 0.0;
  }

  stf_logical_data_handle lData;
  stf_logical_data(ctx, &lData, host_data, N * sizeof(double));
  stf_logical_data_set_symbol(lData, "data");

  // Fill data via a generic task with stream capture
  stf_task_handle t;
  stf_task_create(ctx, &t);
  stf_task_set_symbol(t, "fill");
  stf_task_add_dep(t, lData, STF_WRITE);
  stf_task_enable_capture(t);
  stf_task_start(t);
  double* dData       = (double*) stf_task_get(t, 0);
  cudaStream_t stream = (cudaStream_t) stf_task_get_custream(t);
  fill_kernel<<<2, 128, 0, stream>>>((int) N, dData, 42.0);
  // Record the launch status now but assert on it only once the task is
  // closed, so a failing REQUIRE does not unwind out of an open task.
  const cudaError_t launch_status = cudaGetLastError();
  stf_task_end(t);
  stf_task_destroy(t);
  REQUIRE(launch_status == cudaSuccess);

  // Use host_launch to verify data on the host
  bool passed = false;
  verify_args vargs{N, &passed};

  stf_host_launch_handle h;
  stf_host_launch_create(ctx, &h);
  stf_host_launch_set_symbol(h, "verify");
  stf_host_launch_add_dep(h, lData, STF_READ);
  stf_host_launch_set_user_data(h, &vargs, sizeof(vargs), nullptr);
  stf_host_launch_submit(h, verify_callback);
  stf_host_launch_destroy(h);

  stf_logical_data_destroy(lData);
  stf_ctx_finalize(ctx);

  // Free the pinned buffer before asserting so that a failing REQUIRE does
  // not leak it.
  const cudaError_t free_status = cudaFreeHost(host_data);

  REQUIRE(passed);
  REQUIRE(free_status == cudaSuccess);
}

0 commit comments

Comments
 (0)