@@ -49,9 +49,8 @@ class TestHistUpdater : public HistUpdater<GradientSumT> {
   auto TestInitNewNode(int nid,
                        const common::GHistIndexMatrix& gmat,
                        const USMVector<GradientPair, MemoryType::on_device> &gpair,
-                       const DMatrix& fmat,
                        const RegTree& tree) {
-    HistUpdater<GradientSumT>::InitNewNode(nid, gmat, gpair, fmat, tree);
+    HistUpdater<GradientSumT>::InitNewNode(nid, gmat, gpair, tree);
     return HistUpdater<GradientSumT>::snode_host_[nid];
   }
 
@@ -67,6 +66,13 @@ class TestHistUpdater : public HistUpdater<GradientSumT> {
                       RegTree* p_tree) {
     HistUpdater<GradientSumT>::ApplySplit(nodes, gmat, p_tree);
   }
+
+  auto TestExpandWithLossGuide(const common::GHistIndexMatrix& gmat,
+                               DMatrix *p_fmat,
+                               RegTree* p_tree,
+                               const USMVector<GradientPair, MemoryType::on_device> &gpair) {
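+    // Note: p_fmat is accepted by this wrapper but not used by
+    // ExpandWithLossGuide itself.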
+    HistUpdater<GradientSumT>::ExpandWithLossGuide(gmat, p_tree, gpair);
+  }
 };
 
 void GenerateRandomGPairs(::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) {
@@ -295,7 +301,7 @@ void TestHistUpdaterInitNewNode(const xgboost::tree::TrainParam& param, float sp
   auto& row_idxs = row_set_collection->Data();
   const size_t* row_idxs_ptr = row_idxs.DataConst();
   updater.TestBuildHistogramsLossGuide(node, gmat, &tree, gpair);
-  const auto snode = updater.TestInitNewNode(ExpandEntry::kRootNid, gmat, gpair, *p_fmat, tree);
+  const auto snode = updater.TestInitNewNode(ExpandEntry::kRootNid, gmat, gpair, tree);
 
   GradStats<GradientSumT> grad_stat;
   {
@@ -354,7 +360,7 @@ void TestHistUpdaterEvaluateSplits(const xgboost::tree::TrainParam& param) {
   auto& row_idxs = row_set_collection->Data();
   const size_t* row_idxs_ptr = row_idxs.DataConst();
   const auto* hist = updater.TestBuildHistogramsLossGuide(node, gmat, &tree, gpair);
-  const auto snode_init = updater.TestInitNewNode(ExpandEntry::kRootNid, gmat, gpair, *p_fmat, tree);
+  const auto snode_init = updater.TestInitNewNode(ExpandEntry::kRootNid, gmat, gpair, tree);
 
   const auto snode_updated = updater.TestEvaluateSplits({node}, gmat, tree);
   auto best_loss_chg = snode_updated[0].best.loss_chg;
@@ -479,6 +485,53 @@ void TestHistUpdaterApplySplit(const xgboost::tree::TrainParam& param, float spa
 
 }
 
+template <typename GradientSumT>
+void TestHistUpdaterExpandWithLossGuide(const xgboost::tree::TrainParam& param) {
+  const size_t num_rows = 3;
+  const size_t num_columns = 1;
+  const size_t n_bins = 16;
+
+  Context ctx;
+  ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
+
+  DeviceManager device_manager;
+  auto qu = device_manager.GetQueue(ctx.Device());
+
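+  // Single-feature, three-row dataset; the values are chosen so that one
+  // split can separate row 1 (value 3) from rows 0 and 2 (values 7 and 15).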
+  std::vector<float> data = {7, 3, 15};
+  auto p_fmat = GetDMatrixFromData(data, num_rows, num_columns);
+
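+  // Quantize the input into an on-device histogram index matrix.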
+  DeviceMatrix dmat;
+  dmat.Init(qu, p_fmat.get());
+  common::GHistIndexMatrix gmat;
+  gmat.Init(qu, &ctx, dmat, n_bins);
+
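+  // Fixed {grad, hess} pairs for the three rows, uploaded to device memory.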
+  std::vector<GradientPair> gpair_host = {{1, 2}, {3, 1}, {1, 1}};
+  USMVector<GradientPair, MemoryType::on_device> gpair(&qu, gpair_host);
+
+  RegTree tree;
+  FeatureInteractionConstraintHost int_constraints;
+  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
+  updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
+  updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
+  auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree);
+
+  updater.TestExpandWithLossGuide(gmat, p_fmat.get(), &tree, gpair);
+
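+  // Route each row down the grown tree to its leaf and record the leaf value.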
+  const auto& nodes = tree.GetNodes();
+  std::vector<float> ans(data.size());
+  for (size_t data_idx = 0; data_idx < data.size(); ++data_idx) {
+    size_t node_idx = 0;
+    while (!nodes[node_idx].IsLeaf()) {
+      node_idx = data[data_idx] < nodes[node_idx].SplitCond() ? nodes[node_idx].LeftChild() : nodes[node_idx].RightChild();
+    }
+    ans[data_idx] = nodes[node_idx].LeafValue();
+  }
+
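+  // Expected leaf weights, assuming the XGBoost defaults eta = 0.3 and
+  // lambda = 1 (the test's TrainParam leaves both unset), via
+  // weight = -eta * G / (H + lambda):
+  //   row 1 alone:  -0.3 * 3 / (1 + 1) = -0.45
+  //   rows 0 and 2: -0.3 * (1 + 1) / ((2 + 1) + 1) = -0.15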
+  ASSERT_NEAR(ans[0], -0.15, 1e-6);
+  ASSERT_NEAR(ans[1], -0.45, 1e-6);
+  ASSERT_NEAR(ans[2], -0.15, 1e-6);
+}
+
 TEST(SyclHistUpdater, Sampling) {
   xgboost::tree::TrainParam param;
   param.UpdateAllowUnknown(Args{{"subsample", "0.7"}});
@@ -546,4 +599,13 @@ TEST(SyclHistUpdater, ApplySplitDence) {
   TestHistUpdaterApplySplit<double>(param, 0.0, (1u << 16) + 1);
 }
 
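+// Loss-guided expansion should yield the same tree for both float and
+// double histogram precision.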
+TEST(SyclHistUpdater, ExpandWithLossGuide) {
+  xgboost::tree::TrainParam param;
+  param.UpdateAllowUnknown(Args{{"max_depth", "2"},
+                                {"grow_policy", "lossguide"}});
+
+  TestHistUpdaterExpandWithLossGuide<float>(param);
+  TestHistUpdaterExpandWithLossGuide<double>(param);
+}
+
 } // namespace xgboost::sycl::tree