@@ -73,6 +73,13 @@ class TestHistUpdater : public HistUpdater<GradientSumT> {
73
73
const USMVector<GradientPair, MemoryType::on_device> &gpair) {
74
74
HistUpdater<GradientSumT>::ExpandWithLossGuide (gmat, p_tree, gpair);
75
75
}
76
+
77
  // Test-only shim: exposes the protected depth-wise tree-expansion routine of
  // HistUpdater<GradientSumT> so unit tests can invoke it directly.
  // NOTE(review): p_fmat is accepted (mirroring the updater's public call
  // signature) but is not forwarded -- the underlying ExpandWithDepthWise only
  // takes (gmat, p_tree, gpair); confirm the unused parameter is intentional.
  auto TestExpandWithDepthWise(const common::GHistIndexMatrix& gmat,
                               DMatrix *p_fmat,
                               RegTree* p_tree,
                               const USMVector<GradientPair, MemoryType::on_device> &gpair) {
    HistUpdater<GradientSumT>::ExpandWithDepthWise(gmat, p_tree, gpair);
  }
76
83
};
77
84
78
85
void GenerateRandomGPairs (::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) {
@@ -532,6 +539,53 @@ void TestHistUpdaterExpandWithLossGuide(const xgboost::tree::TrainParam& param)
532
539
ASSERT_NEAR (ans[2 ], -0.15 , 1e-6 );
533
540
}
534
541
542
+ template <typename GradientSumT>
543
+ void TestHistUpdaterExpandWithDepthWise (const xgboost::tree::TrainParam& param) {
544
+ const size_t num_rows = 3 ;
545
+ const size_t num_columns = 1 ;
546
+ const size_t n_bins = 16 ;
547
+
548
+ Context ctx;
549
+ ctx.UpdateAllowUnknown (Args{{" device" , " sycl" }});
550
+
551
+ DeviceManager device_manager;
552
+ auto qu = device_manager.GetQueue (ctx.Device ());
553
+
554
+ std::vector<float > data = {7 , 3 , 15 };
555
+ auto p_fmat = GetDMatrixFromData (data, num_rows, num_columns);
556
+
557
+ DeviceMatrix dmat;
558
+ dmat.Init (qu, p_fmat.get ());
559
+ common::GHistIndexMatrix gmat;
560
+ gmat.Init (qu, &ctx, dmat, n_bins);
561
+
562
+ std::vector<GradientPair> gpair_host = {{1 , 2 }, {3 , 1 }, {1 , 1 }};
563
+ USMVector<GradientPair, MemoryType::on_device> gpair (&qu, gpair_host);
564
+
565
+ RegTree tree;
566
+ FeatureInteractionConstraintHost int_constraints;
567
+ TestHistUpdater<GradientSumT> updater (&ctx, qu, param, int_constraints, p_fmat.get ());
568
+ updater.SetHistSynchronizer (new BatchHistSynchronizer<GradientSumT>());
569
+ updater.SetHistRowsAdder (new BatchHistRowsAdder<GradientSumT>());
570
+ auto * row_set_collection = updater.TestInitData (gmat, gpair, *p_fmat, tree);
571
+
572
+ updater.TestExpandWithDepthWise (gmat, p_fmat.get (), &tree, gpair);
573
+
574
+ const auto & nodes = tree.GetNodes ();
575
+ std::vector<float > ans (data.size ());
576
+ for (size_t data_idx = 0 ; data_idx < data.size (); ++data_idx) {
577
+ size_t node_idx = 0 ;
578
+ while (!nodes[node_idx].IsLeaf ()) {
579
+ node_idx = data[data_idx] < nodes[node_idx].SplitCond () ? nodes[node_idx].LeftChild () : nodes[node_idx].RightChild ();
580
+ }
581
+ ans[data_idx] = nodes[node_idx].LeafValue ();
582
+ }
583
+
584
+ ASSERT_NEAR (ans[0 ], -0.15 , 1e-6 );
585
+ ASSERT_NEAR (ans[1 ], -0.45 , 1e-6 );
586
+ ASSERT_NEAR (ans[2 ], -0.15 , 1e-6 );
587
+ }
588
+
535
589
TEST (SyclHistUpdater, Sampling) {
536
590
xgboost::tree::TrainParam param;
537
591
param.UpdateAllowUnknown (Args{{" subsample" , " 0.7" }});
@@ -608,4 +662,12 @@ TEST(SyclHistUpdater, ExpandWithLossGuide) {
608
662
TestHistUpdaterExpandWithLossGuide<double >(param);
609
663
}
610
664
665
+ TEST (SyclHistUpdater, ExpandWithDepthWise) {
666
+ xgboost::tree::TrainParam param;
667
+ param.UpdateAllowUnknown (Args{{" max_depth" , " 2" }});
668
+
669
+ TestHistUpdaterExpandWithDepthWise<float >(param);
670
+ TestHistUpdaterExpandWithDepthWise<double >(param);
671
+ }
672
+
611
673
} // namespace xgboost::sycl::tree
0 commit comments