Skip to content

Commit ba88551

Browse files
authored
Fix sampling for fp32 devices (#58)
* replace bernoulli by uniform * dispatch different approaches * optimize dispatching * fix missprint * linting --------- Co-authored-by: Dmitry Razdoburdin <>
1 parent 888ff62 commit ba88551

File tree

2 files changed

+44
-16
lines changed

2 files changed

+44
-16
lines changed

plugin/sycl/tree/hist_updater.cc

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -447,23 +447,49 @@ void HistUpdater<GradientSumT>::InitSampling(
447447
::sycl::buffer<uint64_t, 1> flag_buf(&num_samples, 1);
448448
uint64_t seed = seed_;
449449
seed_ += num_rows;
450-
event = qu_.submit([&](::sycl::handler& cgh) {
451-
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh);
452-
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
453-
[=](::sycl::item<1> pid) {
454-
uint64_t i = pid.get_id(0);
455-
456-
// Create minstd_rand engine
457-
oneapi::dpl::minstd_rand engine(seed, i);
458-
oneapi::dpl::bernoulli_distribution coin_flip(subsample);
459-
460-
auto rnd = coin_flip(engine);
461-
if (gpair_ptr[i].GetHess() >= 0.0f && rnd) {
462-
AtomicRef<uint64_t> num_samples_ref(flag_buf_acc[0]);
463-
row_idx[num_samples_ref++] = i;
464-
}
450+
451+
/*
452+
* oneDLP bernoulli_distribution implicitly uses double.
453+
* In this case the device doesn't have fp64 support,
454+
* we generate bernoulli distributed random values from uniform distribution
455+
*/
456+
if (has_fp64_support_) {
457+
// Use oneDPL bernoulli_distribution for better perf
458+
event = qu_.submit([&](::sycl::handler& cgh) {
459+
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh);
460+
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
461+
[=](::sycl::item<1> pid) {
462+
uint64_t i = pid.get_id(0);
463+
// Create minstd_rand engine
464+
oneapi::dpl::minstd_rand engine(seed, i);
465+
oneapi::dpl::bernoulli_distribution coin_flip(subsample);
466+
auto bernoulli_rnd = coin_flip(engine);
467+
468+
if (gpair_ptr[i].GetHess() >= 0.0f && bernoulli_rnd) {
469+
AtomicRef<uint64_t> num_samples_ref(flag_buf_acc[0]);
470+
row_idx[num_samples_ref++] = i;
471+
}
472+
});
465473
});
466-
});
474+
} else {
475+
// Use oneDPL uniform for better perf, as far as bernoulli_distribution uses fp64
476+
event = qu_.submit([&](::sycl::handler& cgh) {
477+
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh);
478+
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
479+
[=](::sycl::item<1> pid) {
480+
uint64_t i = pid.get_id(0);
481+
oneapi::dpl::minstd_rand engine(seed, i);
482+
oneapi::dpl::uniform_real_distribution<float> distr;
483+
const float rnd = distr(engine);
484+
const bool bernoulli_rnd = rnd < subsample ? 1 : 0;
485+
486+
if (gpair_ptr[i].GetHess() >= 0.0f && bernoulli_rnd) {
487+
AtomicRef<uint64_t> num_samples_ref(flag_buf_acc[0]);
488+
row_idx[num_samples_ref++] = i;
489+
}
490+
});
491+
});
492+
}
467493
/* After calling a destructor for flag_buf, content will be copyed to num_samples */
468494
}
469495

plugin/sycl/tree/hist_updater.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class HistUpdater {
6666
if (param.max_depth > 0) {
6767
snode_device_.Resize(&qu, 1u << (param.max_depth + 1));
6868
}
69+
has_fp64_support_ = qu_.get_device().has(::sycl::aspect::fp64);
6970
const auto sub_group_sizes =
7071
qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>();
7172
sub_group_size_ = sub_group_sizes.back();
@@ -211,6 +212,7 @@ class HistUpdater {
211212

212213
// --data fields--
213214
const Context* ctx_;
215+
bool has_fp64_support_;
214216
size_t sub_group_size_;
215217
const xgboost::tree::TrainParam& param_;
216218
std::shared_ptr<xgboost::common::ColumnSampler> column_sampler_;

0 commit comments

Comments
 (0)