@@ -447,23 +447,49 @@ void HistUpdater<GradientSumT>::InitSampling(
447
447
::sycl::buffer<uint64_t , 1 > flag_buf (&num_samples, 1 );
448
448
uint64_t seed = seed_;
449
449
seed_ += num_rows;
450
- event = qu_.submit ([&](::sycl::handler& cgh) {
451
- auto flag_buf_acc = flag_buf.get_access <::sycl::access ::mode::read_write>(cgh);
452
- cgh.parallel_for <>(::sycl::range<1 >(::sycl::range<1 >(num_rows)),
453
- [=](::sycl::item<1 > pid) {
454
- uint64_t i = pid.get_id (0 );
455
-
456
- // Create minstd_rand engine
457
- oneapi::dpl::minstd_rand engine (seed, i);
458
- oneapi::dpl::bernoulli_distribution coin_flip (subsample);
459
-
460
- auto rnd = coin_flip (engine);
461
- if (gpair_ptr[i].GetHess () >= 0 .0f && rnd) {
462
- AtomicRef<uint64_t > num_samples_ref (flag_buf_acc[0 ]);
463
- row_idx[num_samples_ref++] = i;
464
- }
450
+
451
+ /*
452
+ * oneDPL bernoulli_distribution implicitly uses double.
453
+ * In case the device doesn't have fp64 support,
454
+ * we generate bernoulli distributed random values from uniform distribution
455
+ */
456
+ if (has_fp64_support_) {
457
+ // Use oneDPL bernoulli_distribution for better perf
458
+ event = qu_.submit ([&](::sycl::handler& cgh) {
459
+ auto flag_buf_acc = flag_buf.get_access <::sycl::access ::mode::read_write>(cgh);
460
+ cgh.parallel_for <>(::sycl::range<1 >(::sycl::range<1 >(num_rows)),
461
+ [=](::sycl::item<1 > pid) {
462
+ uint64_t i = pid.get_id (0 );
463
+ // Create minstd_rand engine
464
+ oneapi::dpl::minstd_rand engine (seed, i);
465
+ oneapi::dpl::bernoulli_distribution coin_flip (subsample);
466
+ auto bernoulli_rnd = coin_flip (engine);
467
+
468
+ if (gpair_ptr[i].GetHess () >= 0 .0f && bernoulli_rnd) {
469
+ AtomicRef<uint64_t > num_samples_ref (flag_buf_acc[0 ]);
470
+ row_idx[num_samples_ref++] = i;
471
+ }
472
+ });
465
473
});
466
- });
474
+ } else {
475
+ // Use oneDPL uniform_real_distribution<float> instead, because bernoulli_distribution uses fp64
476
+ event = qu_.submit ([&](::sycl::handler& cgh) {
477
+ auto flag_buf_acc = flag_buf.get_access <::sycl::access ::mode::read_write>(cgh);
478
+ cgh.parallel_for <>(::sycl::range<1 >(::sycl::range<1 >(num_rows)),
479
+ [=](::sycl::item<1 > pid) {
480
+ uint64_t i = pid.get_id (0 );
481
+ oneapi::dpl::minstd_rand engine (seed, i);
482
+ oneapi::dpl::uniform_real_distribution<float > distr;
483
+ const float rnd = distr (engine);
484
+ const bool bernoulli_rnd = rnd < subsample ? 1 : 0 ;
485
+
486
+ if (gpair_ptr[i].GetHess () >= 0 .0f && bernoulli_rnd) {
487
+ AtomicRef<uint64_t > num_samples_ref (flag_buf_acc[0 ]);
488
+ row_idx[num_samples_ref++] = i;
489
+ }
490
+ });
491
+ });
492
+ }
467
493
/* After calling a destructor for flag_buf, content will be copied to num_samples */
468
494
}
469
495
0 commit comments