Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions cpp/src/neighbors/detail/vpq_dataset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -302,15 +302,29 @@ __launch_bounds__(BlockSize) RAFT_KERNEL process_and_fill_codes_kernel(
auto* out_label_ptr = reinterpret_cast<LabelT*>(&out_codes(row_ix, 0));
if (lane_id == 0) { *out_label_ptr = vq_label; }

auto* out_codes_ptr = reinterpret_cast<uint8_t*>(out_label_ptr + 1);
cuvs::neighbors::ivf_pq::detail::bitfield_view_t<PqBits> code_view{out_codes_ptr};
// The out_codes rows are aligned to (at least) the size of the label type,
// so we use this type for writing them (global memory).
LabelT staging_codes = 0;
LabelT* out_codes_ptr = out_label_ptr + 1;
constexpr uint32_t BitsPerLabel = sizeof(LabelT) * 8;
uint32_t filled_bits = 0;
for (uint32_t j = 0; j < pq_dim; j++) {
// find PQ label
uint8_t code = compute_code<kSubWarpSize>(dataset, vq_centers, pq_centers, row_ix, j, vq_label);
// TODO: this writes in global memory one byte per warp, which is very slow.
// It's better to keep the codes in the shared memory or registers and dump them at once.
if (lane_id == 0) { code_view[j] = code; }
// stage the code and maybe write
if (lane_id == 0) {
staging_codes |= (static_cast<LabelT>(code) << filled_bits);
filled_bits += PqBits;
if (filled_bits >= BitsPerLabel) {
filled_bits -= BitsPerLabel;
// write the codes to global memory
*out_codes_ptr++ = staging_codes;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving the lane condition if (lane_id == 0) to only this line can improve warp parallelism

// stage the leftover (or zero the buffer if no bits are left)
staging_codes = (static_cast<LabelT>(code) >> (PqBits - filled_bits));
}
}
}
if (lane_id == 0 && filled_bits > 0) { *out_codes_ptr = staging_codes; }
}

template <typename MathT, typename IdxT, typename DatasetT>
Expand Down
Loading