Skip to content

Commit 4f95dbd

Browse files
committed
Pass stream_ref directly and replace cudaStream_t with cuda::stream
1 parent 1568074 commit 4f95dbd

File tree

2 files changed

+23
-36
lines changed

2 files changed

+23
-36
lines changed

cub/test/catch2_test_device_segmented_sort_pairs_env.cu

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ struct stream_registry_factory_t;
1111

1212
#include <thrust/device_vector.h>
1313

14+
#include <cuda/devices>
15+
#include <cuda/stream>
16+
1417
#include "catch2_test_env_launch_helper.h"
1518

1619
DECLARE_LAUNCH_WRAPPER(cub::DeviceSegmentedSort::StableSortPairs, stable_sort_pairs);
@@ -204,8 +207,7 @@ TEST_CASE("DeviceSegmentedSort::SortPairs nonstable uses custom stream", "[segme
204207
auto values_out = c2h::device_vector<int>(7);
205208
auto offsets = c2h::device_vector<int>{0, 3, 7};
206209

207-
cudaStream_t custom_stream;
208-
REQUIRE(cudaSuccess == cudaStreamCreate(&custom_stream));
210+
cuda::stream stream{cuda::devices[0]};
209211

210212
size_t expected_bytes_allocated{};
211213
REQUIRE(
@@ -222,7 +224,7 @@ TEST_CASE("DeviceSegmentedSort::SortPairs nonstable uses custom stream", "[segme
222224
thrust::raw_pointer_cast(offsets.data()),
223225
thrust::raw_pointer_cast(offsets.data()) + 1));
224226

225-
auto stream_prop = stdexec::prop{cuda::get_stream_t{}, cuda::stream_ref{custom_stream}};
227+
auto stream_prop = stdexec::prop{cuda::get_stream_t{}, cuda::stream_ref{stream}};
226228
auto env = stdexec::env{stream_prop, expected_allocation_size(expected_bytes_allocated)};
227229

228230
sort_pairs(
@@ -236,13 +238,12 @@ TEST_CASE("DeviceSegmentedSort::SortPairs nonstable uses custom stream", "[segme
236238
thrust::raw_pointer_cast(offsets.data()) + 1,
237239
env);
238240

239-
REQUIRE(cudaSuccess == cudaStreamSynchronize(custom_stream));
241+
stream.sync();
240242

241243
c2h::device_vector<int> expected_keys{6, 7, 8, 0, 3, 5, 9};
242244
c2h::device_vector<int> expected_values{1, 2, 0, 5, 4, 3, 6};
243245
REQUIRE(keys_out == expected_keys);
244246
REQUIRE(values_out == expected_values);
245-
REQUIRE(cudaSuccess == cudaStreamDestroy(custom_stream));
246247
}
247248

248249
TEST_CASE("DeviceSegmentedSort::SortPairsDescending nonstable uses custom stream", "[segmented_sort][pairs][device]")
@@ -253,8 +254,7 @@ TEST_CASE("DeviceSegmentedSort::SortPairsDescending nonstable uses custom stream
253254
auto values_out = c2h::device_vector<int>(7);
254255
auto offsets = c2h::device_vector<int>{0, 3, 7};
255256

256-
cudaStream_t custom_stream;
257-
REQUIRE(cudaSuccess == cudaStreamCreate(&custom_stream));
257+
cuda::stream stream{cuda::devices[0]};
258258

259259
size_t expected_bytes_allocated{};
260260
REQUIRE(
@@ -271,7 +271,7 @@ TEST_CASE("DeviceSegmentedSort::SortPairsDescending nonstable uses custom stream
271271
thrust::raw_pointer_cast(offsets.data()),
272272
thrust::raw_pointer_cast(offsets.data()) + 1));
273273

274-
auto stream_prop = stdexec::prop{cuda::get_stream_t{}, cuda::stream_ref{custom_stream}};
274+
auto stream_prop = stdexec::prop{cuda::get_stream_t{}, cuda::stream_ref{stream}};
275275
auto env = stdexec::env{stream_prop, expected_allocation_size(expected_bytes_allocated)};
276276

277277
sort_pairs_descending(
@@ -285,13 +285,12 @@ TEST_CASE("DeviceSegmentedSort::SortPairsDescending nonstable uses custom stream
285285
thrust::raw_pointer_cast(offsets.data()) + 1,
286286
env);
287287

288-
REQUIRE(cudaSuccess == cudaStreamSynchronize(custom_stream));
288+
stream.sync();
289289

290290
c2h::device_vector<int> expected_keys{8, 7, 6, 9, 5, 3, 0};
291291
c2h::device_vector<int> expected_values{0, 2, 1, 6, 3, 4, 5};
292292
REQUIRE(keys_out == expected_keys);
293293
REQUIRE(values_out == expected_values);
294-
REQUIRE(cudaSuccess == cudaStreamDestroy(custom_stream));
295294
}
296295

297296
TEST_CASE("DeviceSegmentedSort::SortPairs nonstable DoubleBuffer uses custom stream", "[segmented_sort][pairs][device]")
@@ -306,8 +305,7 @@ TEST_CASE("DeviceSegmentedSort::SortPairs nonstable DoubleBuffer uses custom str
306305
cub::DoubleBuffer<int> d_values(
307306
thrust::raw_pointer_cast(values_buf0.data()), thrust::raw_pointer_cast(values_buf1.data()));
308307

309-
cudaStream_t custom_stream;
310-
REQUIRE(cudaSuccess == cudaStreamCreate(&custom_stream));
308+
cuda::stream stream{cuda::devices[0]};
311309

312310
size_t expected_bytes_allocated{};
313311
REQUIRE(
@@ -322,7 +320,7 @@ TEST_CASE("DeviceSegmentedSort::SortPairs nonstable DoubleBuffer uses custom str
322320
thrust::raw_pointer_cast(offsets.data()),
323321
thrust::raw_pointer_cast(offsets.data()) + 1));
324322

325-
auto stream_prop = stdexec::prop{cuda::get_stream_t{}, cuda::stream_ref{custom_stream}};
323+
auto stream_prop = stdexec::prop{cuda::get_stream_t{}, cuda::stream_ref{stream}};
326324
auto env = stdexec::env{stream_prop, expected_allocation_size(expected_bytes_allocated)};
327325

328326
sort_pairs(d_keys,
@@ -333,15 +331,14 @@ TEST_CASE("DeviceSegmentedSort::SortPairs nonstable DoubleBuffer uses custom str
333331
thrust::raw_pointer_cast(offsets.data()) + 1,
334332
env);
335333

336-
REQUIRE(cudaSuccess == cudaStreamSynchronize(custom_stream));
334+
stream.sync();
337335

338336
c2h::device_vector<int> expected_keys{6, 7, 8, 0, 3, 5, 9};
339337
c2h::device_vector<int> expected_values{1, 2, 0, 5, 4, 3, 6};
340338
c2h::device_vector<int> result_keys(d_keys.Current(), d_keys.Current() + 7);
341339
c2h::device_vector<int> result_values(d_values.Current(), d_values.Current() + 7);
342340
REQUIRE(result_keys == expected_keys);
343341
REQUIRE(result_values == expected_values);
344-
REQUIRE(cudaSuccess == cudaStreamDestroy(custom_stream));
345342
}
346343

347344
TEST_CASE("DeviceSegmentedSort::SortPairsDescending nonstable DoubleBuffer uses custom stream",
@@ -357,8 +354,7 @@ TEST_CASE("DeviceSegmentedSort::SortPairsDescending nonstable DoubleBuffer uses
357354
cub::DoubleBuffer<int> d_values(
358355
thrust::raw_pointer_cast(values_buf0.data()), thrust::raw_pointer_cast(values_buf1.data()));
359356

360-
cudaStream_t custom_stream;
361-
REQUIRE(cudaSuccess == cudaStreamCreate(&custom_stream));
357+
cuda::stream stream{cuda::devices[0]};
362358

363359
size_t expected_bytes_allocated{};
364360
REQUIRE(
@@ -373,7 +369,7 @@ TEST_CASE("DeviceSegmentedSort::SortPairsDescending nonstable DoubleBuffer uses
373369
thrust::raw_pointer_cast(offsets.data()),
374370
thrust::raw_pointer_cast(offsets.data()) + 1));
375371

376-
auto stream_prop = stdexec::prop{cuda::get_stream_t{}, cuda::stream_ref{custom_stream}};
372+
auto stream_prop = stdexec::prop{cuda::get_stream_t{}, cuda::stream_ref{stream}};
377373
auto env = stdexec::env{stream_prop, expected_allocation_size(expected_bytes_allocated)};
378374

379375
sort_pairs_descending(
@@ -385,15 +381,14 @@ TEST_CASE("DeviceSegmentedSort::SortPairsDescending nonstable DoubleBuffer uses
385381
thrust::raw_pointer_cast(offsets.data()) + 1,
386382
env);
387383

388-
REQUIRE(cudaSuccess == cudaStreamSynchronize(custom_stream));
384+
stream.sync();
389385

390386
c2h::device_vector<int> expected_keys{8, 7, 6, 9, 5, 3, 0};
391387
c2h::device_vector<int> expected_values{0, 2, 1, 6, 3, 4, 5};
392388
c2h::device_vector<int> result_keys(d_keys.Current(), d_keys.Current() + 7);
393389
c2h::device_vector<int> result_values(d_values.Current(), d_values.Current() + 7);
394390
REQUIRE(result_keys == expected_keys);
395391
REQUIRE(result_values == expected_values);
396-
REQUIRE(cudaSuccess == cudaStreamDestroy(custom_stream));
397392
}
398393

399394
C2H_TEST("DeviceSegmentedSort::StableSortPairs uses environment", "[segmented_sort][pairs][device]")

cub/test/catch2_test_device_segmented_sort_pairs_env_api.cu

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ C2H_TEST("cub::DeviceSegmentedSort::StableSortPairs env-based API", "[segmented_
2626

2727
cuda::stream stream{cuda::devices[0]};
2828
cuda::stream_ref stream_ref{stream};
29-
auto env = cuda::std::execution::env{stream_ref};
3029

3130
auto error = cub::DeviceSegmentedSort::StableSortPairs(
3231
thrust::raw_pointer_cast(keys_in.data()),
@@ -37,7 +36,7 @@ C2H_TEST("cub::DeviceSegmentedSort::StableSortPairs env-based API", "[segmented_
3736
2,
3837
thrust::raw_pointer_cast(offsets_begin.data()),
3938
thrust::raw_pointer_cast(offsets_end.data()),
40-
env);
39+
stream_ref);
4140
if (error != cudaSuccess)
4241
{
4342
std::cerr << "cub::DeviceSegmentedSort::StableSortPairs failed with status: " << error << std::endl;
@@ -65,7 +64,6 @@ C2H_TEST("cub::DeviceSegmentedSort::StableSortPairsDescending env-based API", "[
6564

6665
cuda::stream stream{cuda::devices[0]};
6766
cuda::stream_ref stream_ref{stream};
68-
auto env = cuda::std::execution::env{stream_ref};
6967

7068
auto error = cub::DeviceSegmentedSort::StableSortPairsDescending(
7169
thrust::raw_pointer_cast(keys_in.data()),
@@ -76,7 +74,7 @@ C2H_TEST("cub::DeviceSegmentedSort::StableSortPairsDescending env-based API", "[
7674
2,
7775
thrust::raw_pointer_cast(offsets_begin.data()),
7876
thrust::raw_pointer_cast(offsets_end.data()),
79-
env);
77+
stream_ref);
8078
if (error != cudaSuccess)
8179
{
8280
std::cerr << "cub::DeviceSegmentedSort::StableSortPairsDescending failed with status: " << error << std::endl;
@@ -108,7 +106,6 @@ C2H_TEST("cub::DeviceSegmentedSort::StableSortPairs DoubleBuffer env-based API",
108106

109107
cuda::stream stream{cuda::devices[0]};
110108
cuda::stream_ref stream_ref{stream};
111-
auto env = cuda::std::execution::env{stream_ref};
112109

113110
auto error = cub::DeviceSegmentedSort::StableSortPairs(
114111
d_keys,
@@ -117,7 +114,7 @@ C2H_TEST("cub::DeviceSegmentedSort::StableSortPairs DoubleBuffer env-based API",
117114
2,
118115
thrust::raw_pointer_cast(offsets_begin.data()),
119116
thrust::raw_pointer_cast(offsets_end.data()),
120-
env);
117+
stream_ref);
121118
if (error != cudaSuccess)
122119
{
123120
std::cerr << "cub::DeviceSegmentedSort::StableSortPairs (DoubleBuffer) failed with status: " << error << std::endl;
@@ -152,7 +149,6 @@ C2H_TEST("cub::DeviceSegmentedSort::StableSortPairsDescending DoubleBuffer env-b
152149

153150
cuda::stream stream{cuda::devices[0]};
154151
cuda::stream_ref stream_ref{stream};
155-
auto env = cuda::std::execution::env{stream_ref};
156152

157153
auto error = cub::DeviceSegmentedSort::StableSortPairsDescending(
158154
d_keys,
@@ -161,7 +157,7 @@ C2H_TEST("cub::DeviceSegmentedSort::StableSortPairsDescending DoubleBuffer env-b
161157
2,
162158
thrust::raw_pointer_cast(offsets_begin.data()),
163159
thrust::raw_pointer_cast(offsets_end.data()),
164-
env);
160+
stream_ref);
165161
if (error != cudaSuccess)
166162
{
167163
std::cerr << "cub::DeviceSegmentedSort::StableSortPairsDescending (DoubleBuffer) failed with status: " << error
@@ -192,7 +188,6 @@ C2H_TEST("cub::DeviceSegmentedSort::SortPairs nonstable env-based API", "[segmen
192188

193189
cuda::stream stream{cuda::devices[0]};
194190
cuda::stream_ref stream_ref{stream};
195-
auto env = cuda::std::execution::env{stream_ref};
196191

197192
auto error = cub::DeviceSegmentedSort::SortPairs(
198193
thrust::raw_pointer_cast(keys_in.data()),
@@ -203,7 +198,7 @@ C2H_TEST("cub::DeviceSegmentedSort::SortPairs nonstable env-based API", "[segmen
203198
2,
204199
thrust::raw_pointer_cast(offsets_begin.data()),
205200
thrust::raw_pointer_cast(offsets_end.data()),
206-
env);
201+
stream_ref);
207202
if (error != cudaSuccess)
208203
{
209204
std::cerr << "cub::DeviceSegmentedSort::SortPairs failed with status: " << error << std::endl;
@@ -231,7 +226,6 @@ C2H_TEST("cub::DeviceSegmentedSort::SortPairsDescending nonstable env-based API"
231226

232227
cuda::stream stream{cuda::devices[0]};
233228
cuda::stream_ref stream_ref{stream};
234-
auto env = cuda::std::execution::env{stream_ref};
235229

236230
auto error = cub::DeviceSegmentedSort::SortPairsDescending(
237231
thrust::raw_pointer_cast(keys_in.data()),
@@ -242,7 +236,7 @@ C2H_TEST("cub::DeviceSegmentedSort::SortPairsDescending nonstable env-based API"
242236
2,
243237
thrust::raw_pointer_cast(offsets_begin.data()),
244238
thrust::raw_pointer_cast(offsets_end.data()),
245-
env);
239+
stream_ref);
246240
if (error != cudaSuccess)
247241
{
248242
std::cerr << "cub::DeviceSegmentedSort::SortPairsDescending failed with status: " << error << std::endl;
@@ -274,7 +268,6 @@ C2H_TEST("cub::DeviceSegmentedSort::SortPairs nonstable DoubleBuffer env-based A
274268

275269
cuda::stream stream{cuda::devices[0]};
276270
cuda::stream_ref stream_ref{stream};
277-
auto env = cuda::std::execution::env{stream_ref};
278271

279272
auto error = cub::DeviceSegmentedSort::SortPairs(
280273
d_keys,
@@ -283,7 +276,7 @@ C2H_TEST("cub::DeviceSegmentedSort::SortPairs nonstable DoubleBuffer env-based A
283276
2,
284277
thrust::raw_pointer_cast(offsets_begin.data()),
285278
thrust::raw_pointer_cast(offsets_end.data()),
286-
env);
279+
stream_ref);
287280
if (error != cudaSuccess)
288281
{
289282
std::cerr << "cub::DeviceSegmentedSort::SortPairs (DoubleBuffer) failed with status: " << error << std::endl;
@@ -318,7 +311,6 @@ C2H_TEST("cub::DeviceSegmentedSort::SortPairsDescending nonstable DoubleBuffer e
318311

319312
cuda::stream stream{cuda::devices[0]};
320313
cuda::stream_ref stream_ref{stream};
321-
auto env = cuda::std::execution::env{stream_ref};
322314

323315
auto error = cub::DeviceSegmentedSort::SortPairsDescending(
324316
d_keys,
@@ -327,7 +319,7 @@ C2H_TEST("cub::DeviceSegmentedSort::SortPairsDescending nonstable DoubleBuffer e
327319
2,
328320
thrust::raw_pointer_cast(offsets_begin.data()),
329321
thrust::raw_pointer_cast(offsets_end.data()),
330-
env);
322+
stream_ref);
331323
if (error != cudaSuccess)
332324
{
333325
std::cerr

0 commit comments

Comments
 (0)