66// ----------------------------------------------------------------------------
77
88#include " open3d/core/Dispatch.h"
9- #include " open3d/core/Indexer.h"
109#include " open3d/core/SYCLContext.h"
1110#include " open3d/core/Tensor.h"
1211#include " open3d/utility/Logging.h"
@@ -15,45 +14,147 @@ namespace open3d {
1514namespace core {
1615namespace kernel {
1716
17+ namespace {
18+
19+ template <typename scalar_t >
20+ // Launches contiguous index_add over dim0:
21+ // dst[index[i], ...] += src[i, ...]
22+ // Contract:
23+ // - `index_ptr`, `src_ptr`, and `dst_ptr` point to contiguous buffers.
24+ // - `index_length` is the length of `index_ptr` and the leading dimension of
25+ // `src_ptr`.
26+ // - `broadcasting_elems` is the flattened product of non-reduction dimensions.
27+ // - `dst_ptr` has enough rows to address all index values.
28+ void LaunchIndexAddContiguousSYCLKernel (sycl::queue& queue,
29+ const int64_t * index_ptr,
30+ const scalar_t * src_ptr,
31+ scalar_t * dst_ptr,
32+ int64_t index_length,
33+ int64_t broadcasting_elems) {
34+ if (index_length <= 0 || broadcasting_elems <= 0 ) {
35+ return ;
36+ }
37+
38+ auto ceil_div = [](int64_t a, int64_t b) -> int64_t {
39+ return (a + b - 1 ) / b;
40+ };
41+ auto round_up = [](int64_t x, int64_t m) -> int64_t {
42+ return ((x + m - 1 ) / m) * m;
43+ };
44+
45+ // 2D launch configuration:
46+ // - X dimension tiles columns (broadcasting_elems).
47+ // - Y dimension tiles reduction rows (index_length).
48+ //
49+ // Each work-group processes TILE_ROWS rows and WG_X columns. Within a row
50+ // tile, consecutive runs of identical destination indices are reduced into
51+ // one atomic add per (column, run), reducing atomic pressure while
52+ // preserving index_add semantics.
53+ constexpr int WG_X = 256 ;
54+ constexpr int TILE_ROWS = 8 ;
55+ const int64_t num_row_tiles = ceil_div (index_length, int64_t (TILE_ROWS));
56+ const int64_t global_x = round_up (broadcasting_elems, int64_t (WG_X));
57+ sycl::nd_range<2 > launch (sycl::range<2 >(num_row_tiles, global_x),
58+ sycl::range<2 >(1 , WG_X));
59+
60+ queue.submit ([&](sycl::handler& cgh) {
61+ sycl::local_accessor<int64_t , 1 > l_idx (sycl::range<1 >(TILE_ROWS),
62+ cgh);
63+
64+ cgh.parallel_for (
65+ launch,
66+ [=](sycl::nd_item<2 > it) [[sycl::reqd_sub_group_size (
67+ 16 )]] {
68+ const int lid_x = int (it.get_local_id (1 ));
69+ const int64_t group_y = it.get_group (0 );
70+ const int64_t col = it.get_global_id (1 );
71+ if (col >= broadcasting_elems) {
72+ return ;
73+ }
74+
75+ const int64_t row_base = group_y * int64_t (TILE_ROWS);
76+
77+ if (lid_x < TILE_ROWS) {
78+ const int64_t r = row_base + lid_x;
79+ l_idx[lid_x] = (r < index_length) ? index_ptr[r]
80+ : int64_t (-1 );
81+ }
82+ it.barrier (sycl::access::fence_space::local_space);
83+
84+ int run_start = 0 ;
85+ while (run_start < TILE_ROWS) {
86+ const int64_t dst_row = l_idx[run_start];
87+ if (dst_row < 0 ) {
88+ break ;
89+ }
90+
91+ int run_end = run_start + 1 ;
92+ while (run_end < TILE_ROWS &&
93+ l_idx[run_end] == dst_row) {
94+ ++run_end;
95+ }
96+
97+ scalar_t sum = scalar_t (0 );
98+ for (int rr = run_start; rr < run_end; ++rr) {
99+ const int64_t src_row = row_base + int64_t (rr);
100+ if (src_row < index_length) {
101+ const int64_t workload_idx =
102+ src_row * broadcasting_elems + col;
103+ sum += src_ptr[workload_idx];
104+ }
105+ }
106+
107+ const int64_t dst_idx =
108+ dst_row * broadcasting_elems + col;
109+ sycl::atomic_ref<scalar_t ,
110+ sycl::memory_order::relaxed,
111+ sycl::memory_scope::device>
112+ aref (dst_ptr[dst_idx]);
113+ aref += sum;
114+
115+ run_start = run_end;
116+ }
117+ });
118+ }).wait_and_throw ();
119+ }
120+
121+ } // namespace
122+
18123void IndexAddSYCL_ (int64_t dim,
19124 const Tensor& index,
20125 const Tensor& src,
21126 Tensor& dst) {
22127 // index: [N,], src: [N, D], dst: [M, D]
23- // In Indexer, output shape defines the actual primary strides.
24- // However, in IndexAdd_, input dominates the iterations.
25- // So put dst (output) at indexer's input, and src (input) at output.
26- Indexer indexer ({dst}, src, DtypePolicy::NONE);
128+ // This kernel assumes contiguous layout for fast linear indexing.
129+ // Non-contiguous tensors are materialized as contiguous before launch.
130+ const Tensor index_contiguous = index.Contiguous ();
131+ const Tensor src_contiguous = src.Contiguous ();
132+ Tensor dst_contiguous = dst.Contiguous ();
27133
28- // Index is simply a 1D contiguous tensor, with a different stride
29- // behavior to src. So use raw pointer for simplicity.
30- auto index_ptr = index.GetDataPtr <int64_t >();
134+ // Index is simply a 1D contiguous tensor.
135+ auto index_ptr = index_contiguous.GetDataPtr <int64_t >();
31136
32137 int64_t broadcasting_elems = 1 ;
33- for (int64_t d = 1 ; d < src .NumDims (); ++d) {
34- broadcasting_elems *= src .GetShape (d);
138+ for (int64_t d = 1 ; d < src_contiguous .NumDims (); ++d) {
139+ broadcasting_elems *= src_contiguous .GetShape (d);
35140 }
141+
142+ const int64_t index_length = index_contiguous.GetLength ();
143+
36144 sycl::queue queue =
37145 sy::SYCLContext::GetInstance ().GetDefaultQueue (src.GetDevice ());
38146
39- // TODO: Replace with SYCL reduction API
40147 DISPATCH_FLOAT_DTYPE_TO_TEMPLATE (src.GetDtype (), [&]() {
41- queue.parallel_for (index.GetLength (), [=](int64_t workload_idx) {
42- int64_t reduction_idx = workload_idx / broadcasting_elems;
43- int64_t broadcasting_idx = workload_idx % broadcasting_elems;
44-
45- const int64_t idx = index_ptr[reduction_idx];
46- int64_t dst_idx = idx * broadcasting_elems + broadcasting_idx;
47-
48- // Note input and output is switched here to adapt to the
49- // indexer
50- scalar_t * src_ptr = indexer.GetOutputPtr <scalar_t >(0 , idx);
51- scalar_t * dst_ptr = indexer.GetInputPtr <scalar_t >(0 , dst_idx);
52- sycl::atomic_ref<scalar_t , sycl::memory_order::acq_rel,
53- sycl::memory_scope::device>(*dst_ptr) +=
54- *src_ptr;
55- }).wait_and_throw ();
148+ LaunchIndexAddContiguousSYCLKernel<scalar_t >(
149+ queue, index_ptr, src_contiguous.GetDataPtr <scalar_t >(),
150+ dst_contiguous.GetDataPtr <scalar_t >(), index_length,
151+ broadcasting_elems);
56152 });
153+
154+ // If dst is non-contiguous, write back from the contiguous temporary.
155+ if (!dst.IsContiguous ()) {
156+ dst.CopyFrom (dst_contiguous);
157+ }
57158}
58159
59160} // namespace kernel