diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index 6c9f775e05..308b9da475 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -302,15 +302,29 @@ __launch_bounds__(BlockSize) RAFT_KERNEL process_and_fill_codes_kernel( auto* out_label_ptr = reinterpret_cast(&out_codes(row_ix, 0)); if (lane_id == 0) { *out_label_ptr = vq_label; } - auto* out_codes_ptr = reinterpret_cast(out_label_ptr + 1); - cuvs::neighbors::ivf_pq::detail::bitfield_view_t code_view{out_codes_ptr}; + // The out_codes rows are aligned to (at least) the size of the label type, + // so we use this type for writing them (global memory). + LabelT staging_codes = 0; + LabelT* out_codes_ptr = out_label_ptr + 1; + constexpr uint32_t BitsPerLabel = sizeof(LabelT) * 8; + uint32_t filled_bits = 0; for (uint32_t j = 0; j < pq_dim; j++) { // find PQ label uint8_t code = compute_code(dataset, vq_centers, pq_centers, row_ix, j, vq_label); - // TODO: this writes in global memory one byte per warp, which is very slow. - // It's better to keep the codes in the shared memory or registers and dump them at once. - if (lane_id == 0) { code_view[j] = code; } + // stage the code and maybe write + if (lane_id == 0) { + staging_codes |= (static_cast(code) << filled_bits); + filled_bits += PqBits; + if (filled_bits >= BitsPerLabel) { + filled_bits -= BitsPerLabel; + // write the codes to global memory + *out_codes_ptr++ = staging_codes; + // stage the leftover (or zero the buffer if no bits are left) + staging_codes = (static_cast(code) >> (PqBits - filled_bits)); + } + } } + if (lane_id == 0 && filled_bits > 0) { *out_codes_ptr = staging_codes; } } template