
Commit 888ff62

Fixes for v.2.1.0 (#55)

* change build
* modify build
* fix distributed
* fix ci
* return sycl python tests
* set g++ as xgboost compiler for python tests
* fix distributed
* add simple dask test
* linting
* fix distibuted
* fix
* linting
* fix prediction bug
* batch processing for Transform
* lint and fix
* linting

---------

Co-authored-by: Dmitry Razdoburdin <>

1 parent 0088a2f

38 files changed, +914 -306 lines

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion

@@ -95,7 +95,7 @@ jobs:
       run: |
         mkdir build
         cd build
-        cmake .. -DGOOGLE_TEST=ON -DUSE_DMLC_GTEST=ON -DPLUGIN_SYCL=ON -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX
+        cmake .. -DGOOGLE_TEST=ON -DUSE_DMLC_GTEST=ON -DPLUGIN_SYCL=ON -DCMAKE_CXX_COMPILER=g++ -DCMAKE_C_COMPILER=gcc -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX
         make -j$(nproc)
     - name: Run gtest binary for SYCL
       run: |

.github/workflows/python_tests.yml

Lines changed: 1 addition & 1 deletion

@@ -294,7 +294,7 @@ jobs:
       run: |
         mkdir build
         cd build
-        cmake .. -DPLUGIN_SYCL=ON -DCMAKE_PREFIX_PATH=$CONDA_PREFIX
+        cmake .. -DPLUGIN_SYCL=ON -DCMAKE_PREFIX_PATH=$CONDA_PREFIX -DCMAKE_CXX_COMPILER=g++ -DCMAKE_C_COMPILER=gcc
         make -j$(nproc)
     - name: Install Python package
       run: |

CMakeLists.txt

Lines changed: 0 additions & 2 deletions

@@ -1,8 +1,6 @@
 cmake_minimum_required(VERSION 3.18 FATAL_ERROR)

 if(PLUGIN_SYCL)
-  set(CMAKE_CXX_COMPILER "g++")
-  set(CMAKE_C_COMPILER "gcc")
   string(REPLACE " -isystem ${CONDA_PREFIX}/include" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
 endif()


include/xgboost/linalg.h

Lines changed: 3 additions & 3 deletions

@@ -664,13 +664,13 @@ auto MakeVec(T *ptr, size_t s, DeviceOrd device = DeviceOrd::CPU()) {

 template <typename T>
 auto MakeVec(HostDeviceVector<T> *data) {
-  return MakeVec(data->Device().IsCPU() ? data->HostPointer() : data->DevicePointer(), data->Size(),
-                 data->Device());
+  return MakeVec(data->Device().IsCUDA() ? data->DevicePointer() : data->HostPointer(),
+                 data->Size(), data->Device());
 }

 template <typename T>
 auto MakeVec(HostDeviceVector<T> const *data) {
-  return MakeVec(data->Device().IsCPU() ? data->ConstHostPointer() : data->ConstDevicePointer(),
+  return MakeVec(data->Device().IsCUDA() ? data->ConstDevicePointer() : data->ConstHostPointer(),
                  data->Size(), data->Device());
 }

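Note on the MakeVec change: a SYCL device ordinal is neither CPU nor CUDA, so the old IsCPU() test routed SYCL-backed vectors to DevicePointer(), which holds no valid allocation for them; testing IsCUDA() instead makes everything that is not CUDA fall back to host memory. A minimal sketch of the behavioral difference, using simplified stand-in types (Kind and FakeVector are illustrative, not the xgboost API):

#include <cassert>

// Simplified stand-ins: only the dispatch logic mirrors the diff above.
enum class Kind { kCPU, kCUDA, kSyclDefault };

struct FakeVector {
  Kind kind;
  const float* host_ptr;
  const float* device_ptr;

  // Old dispatch: anything that is not CPU got the device pointer.
  const float* OldChoice() const {
    return kind == Kind::kCPU ? host_ptr : device_ptr;
  }
  // New dispatch: only CUDA takes the device pointer.
  const float* NewChoice() const {
    return kind == Kind::kCUDA ? device_ptr : host_ptr;
  }
};

int main() {
  float h = 1.0f;
  FakeVector sycl_vec{Kind::kSyclDefault, &h, nullptr};
  assert(sycl_vec.OldChoice() == nullptr);  // old path: null device pointer
  assert(sycl_vec.NewChoice() == &h);       // new path: valid host pointer
  return 0;
}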

plugin/CMakeLists.txt

Lines changed: 2 additions & 2 deletions

@@ -10,14 +10,14 @@ if(PLUGIN_SYCL)
   target_compile_definitions(plugin_sycl PUBLIC -DXGBOOST_USE_SYCL=1)
   target_link_libraries(plugin_sycl PUBLIC -fsycl)
   set_target_properties(plugin_sycl PROPERTIES
-    COMPILE_FLAGS -fsycl
+    COMPILE_FLAGS "-fsycl -fno-sycl-id-queries-fit-in-int"
     CXX_STANDARD 17
     CXX_STANDARD_REQUIRED ON
     POSITION_INDEPENDENT_CODE ON)
   if(USE_OPENMP)
     find_package(OpenMP REQUIRED)
     set_target_properties(plugin_sycl PROPERTIES
-      COMPILE_FLAGS "-fsycl -qopenmp")
+      COMPILE_FLAGS "-fsycl -qopenmp -fno-sycl-id-queries-fit-in-int")
   endif()
   # Get compilation and link flags of plugin_sycl and propagate to objxgboost
   target_link_libraries(objxgboost PUBLIC plugin_sycl)

plugin/sycl/common/hist_util.cc

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@ namespace common {
 template<typename GradientSumT>
 void InitHist(::sycl::queue qu, GHistRow<GradientSumT, MemoryType::on_device>* hist,
               size_t size, ::sycl::event* event) {
-  *event = qu.fill(hist->Begin(),
+  *event = qu.fill(hist->Data(),
                    xgboost::detail::GradientPairInternal<GradientSumT>(), size, *event);
 }
 template void InitHist(::sycl::queue qu,
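For context, qu.fill expects a raw USM pointer, which is why Data() replaces the iterator-returning Begin() here. A minimal self-contained sketch of the same pattern, assuming a DPC++-style USM allocation (the buffer below is illustrative, not the plugin's GHistRow):

#include <CL/sycl.hpp>

int main() {
  ::sycl::queue qu;
  constexpr size_t kSize = 1024;
  // Raw USM device allocation standing in for GHistRow's storage.
  float* hist = ::sycl::malloc_device<float>(kSize, qu);
  // fill() takes the raw pointer and returns an event, as in InitHist above.
  ::sycl::event event = qu.fill(hist, 0.0f, kSize);
  event.wait();
  ::sycl::free(hist, qu);
  return 0;
}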

plugin/sycl/common/linalg_op.h

Lines changed: 246 additions & 0 deletions

@@ -0,0 +1,246 @@
+/**
+ * Copyright 2021-2024, XGBoost Contributors
+ * \file linalg_op.h
+ */
+#ifndef PLUGIN_SYCL_COMMON_LINALG_OP_H_
+#define PLUGIN_SYCL_COMMON_LINALG_OP_H_
+
+#include <vector>
+#include <utility>
+
+#include "../data.h"
+
+#include <CL/sycl.hpp>
+
+namespace xgboost {
+namespace sycl {
+namespace linalg {
+
+struct WorkGroupsParams {
+  size_t n_workgroups;
+  size_t workgroup_size;
+};
+
+template <typename Fn>
+::sycl::event GroupWiseKernel(::sycl::queue* qu, int* flag_ptr,
+                              const std::vector<::sycl::event>& events,
+                              const WorkGroupsParams& wg, Fn &&fn) {
+  ::sycl::buffer<int, 1> flag_buf(flag_ptr, 1);
+  auto event = qu->submit([&](::sycl::handler& cgh) {
+    cgh.depends_on(events);
+    auto flag = flag_buf.get_access<::sycl::access::mode::write>(cgh);
+    cgh.parallel_for_work_group<>(::sycl::range<1>(wg.n_workgroups),
+                                  ::sycl::range<1>(wg.workgroup_size),
+                                  [=](::sycl::group<1> group) {
+      group.parallel_for_work_item([&](::sycl::h_item<1> item) {
+        const size_t idx = item.get_global_id()[0];
+        fn(idx, flag);
+      });
+    });
+  });
+  return event;
+}
+
+struct Argument {
+  template <typename T>
+  operator T&&() const;
+};
+
+template <typename Fn, typename Is, typename = void>
+struct ArgumentsPassedImpl
+  : std::false_type {};
+
+template <typename Fn, size_t ...Is>
+struct ArgumentsPassedImpl<Fn, std::index_sequence<Is...>,
+                           decltype(std::declval<Fn>()(((void)Is, Argument{})...), void())>
+  : std::true_type {};
+
+template <typename Fn, size_t N>
+struct ArgumentsPassed : ArgumentsPassedImpl<Fn, std::make_index_sequence<N>> {};
+
+template <typename OutputDType, typename InputDType,
+          size_t BatchSize, size_t MaxNumInputs>
+class BatchProcessingHelper {
+ public:
+  static constexpr size_t kBatchSize = BatchSize;
+  using InputType = HostDeviceVector<InputDType>;
+  using OutputType = HostDeviceVector<OutputDType>;
+
+  using ConstInputIteratorT =
+      typename USMVector<InputDType, MemoryType::on_device>::ConstIterator;
+  using InputIteratorT = typename USMVector<InputDType, MemoryType::on_device>::Iterator;
+  using OutputIteratorT = typename USMVector<OutputDType, MemoryType::on_device>::Iterator;
+
+ private:
+  template <size_t NumInput = 0>
+  void Host2Buffers(InputDType* in_buffer_ptr, const InputType& input) {
+    /*
+     * Some inputs may have less than 1 sample per output symbol.
+     */
+    const size_t sub_sample_rate = ndata_ * sample_rates_[NumInput+1] / input.Size();
+    const size_t n_samples = batch_size_ * sample_rates_[NumInput+1] / sub_sample_rate;
+
+    const InputDType* in_host_ptr = input.HostPointer() +
+                                    batch_begin_ * sample_rates_[NumInput+1] / sub_sample_rate;
+
+    events_[NumInput] =
+        qu_->memcpy(in_buffer_ptr, in_host_ptr, n_samples * sizeof(InputDType),
+                    events_[MaxNumInputs - 2]);
+  }
+
+  template <size_t NumInput = 0, class... InputTypes>
+  void Host2Buffers(InputDType* in_buffer_ptr, const InputType& input,
+                    const InputTypes&... other_inputs) {
+    // Make the copy for the first input in the list
+    Host2Buffers<NumInput>(in_buffer_ptr, input);
+    // Recurrent call for the remaining inputs
+    InputDType* next_input = in_buffer_.Data() + in_buff_offsets_[NumInput + 1];
+    Host2Buffers<NumInput+1>(next_input, other_inputs...);
+  }
+
+  void Buffers2Host(OutputType* output) {
+    const size_t n_samples = batch_size_ * sample_rates_[0];
+    OutputDType* out_host_ptr = output->HostPointer() + batch_begin_ * sample_rates_[0];
+    events_[MaxNumInputs - 1] =
+        qu_->memcpy(out_host_ptr, out_buffer_.DataConst(), n_samples * sizeof(OutputDType),
+                    events_[MaxNumInputs - 2]);
+  }
+
+  void Buffers2Host(InputType* output) {
+    const size_t n_samples = batch_size_ * sample_rates_[1];
+    InputDType* out_host_ptr = output->HostPointer() + batch_begin_ * sample_rates_[1];
+    events_[MaxNumInputs - 1] =
+        qu_->memcpy(out_host_ptr, in_buffer_.DataConst(), n_samples * sizeof(InputDType),
+                    events_[MaxNumInputs - 2]);
+  }
+
+  template <size_t NumInputs = 1, typename Fn, class... InputTypes>
+  void Call(Fn &&fn, ConstInputIteratorT input, const InputTypes... other_inputs) {
+    static_assert(NumInputs <= MaxNumInputs,
+                  "Too many arguments in the passed function");
+    /* The passed lambda may have fewer inputs than MaxNumInputs,
+     * so only the required number of arguments is passed.
+     */
+    // 1 for events, 1 for batch_size, 1 for output
+    if constexpr (ArgumentsPassed<Fn, NumInputs + 1 + 1 + 1>::value) {
+      events_[MaxNumInputs - 2] = fn(events_, batch_size_,
+                                     out_buffer_.Begin(), input, other_inputs...);
+    } else {
+      ConstInputIteratorT next_input = in_buffer_.Cbegin() +
+                                       in_buff_offsets_[MaxNumInputs - 1 - NumInputs];
+      Call<NumInputs+1>(std::forward<Fn>(fn), next_input, input, other_inputs...);
+    }
+  }
+
+  template <size_t NumInputs = 1, typename Fn, class... InputTypes>
+  void Call(Fn &&fn, InputIteratorT io, ConstInputIteratorT input,
+            const InputTypes... other_inputs) {
+    static_assert(NumInputs <= MaxNumInputs,
+                  "Too many arguments in the passed function");
+    if constexpr (ArgumentsPassed<Fn, NumInputs + 1 + 1>::value) {
+      events_[MaxNumInputs - 2] = fn(events_, batch_size_,
+                                     io, input, other_inputs...);
+    } else {
+      const ConstInputIteratorT next_input = in_buffer_.Cbegin() +
+                                             in_buff_offsets_[MaxNumInputs - NumInputs];
+      Call<NumInputs+1>(std::forward<Fn>(fn), io, next_input, input, other_inputs...);
+    }
+  }
+
+  template <size_t NumInputs = 1, typename Fn>
+  void Call(Fn &&fn, InputIteratorT io) {
+    static_assert(NumInputs <= MaxNumInputs,
+                  "Too many arguments in the passed function");
+    if constexpr (ArgumentsPassed<Fn, NumInputs + 1 + 1>::value) {
+      events_[MaxNumInputs - 2] = fn(events_, batch_size_, io);
+    } else {
+      const ConstInputIteratorT next_input = in_buffer_.Cbegin() +
+                                             in_buff_offsets_[MaxNumInputs - 1];
+      Call<NumInputs+1>(std::forward<Fn>(fn), io, next_input);
+    }
+  }
+
+ public:
+  BatchProcessingHelper() = default;
+
+  // The first element of sample_rate always corresponds to the output sample rate
+  void InitBuffers(::sycl::queue* qu, const std::vector<int>& sample_rate) {
+    assert(sample_rate.size() == MaxNumInputs + 1);
+    sample_rates_ = sample_rate;
+    qu_ = qu;
+    events_.resize(MaxNumInputs + 2);
+    out_buffer_.Resize(qu, kBatchSize * sample_rate.front());
+
+    in_buff_offsets_[0] = 0;
+    for (size_t i = 1; i < MaxNumInputs; ++i) {
+      in_buff_offsets_[i] = in_buff_offsets_[i - 1] + kBatchSize * sample_rate[i];
+    }
+    const size_t in_buff_size = in_buff_offsets_.back() + kBatchSize * sample_rate.back();
+    in_buffer_.Resize(qu, in_buff_size);
+  }
+
+  /*
+   * Batch-wise processing on a SYCL device:
+   * output = fn(inputs)
+   */
+  template <typename Fn, class... InputTypes>
+  void Calculate(Fn &&fn, OutputType* output, const InputTypes&... inputs) {
+    ndata_ = output->Size() / sample_rates_.front();
+    const size_t nBatch = ndata_ / kBatchSize + (ndata_ % kBatchSize > 0);
+    for (size_t batch = 0; batch < nBatch; ++batch) {
+      batch_begin_ = batch * kBatchSize;
+      batch_end_ = (batch == nBatch - 1) ? ndata_ : batch_begin_ + kBatchSize;
+      batch_size_ = batch_end_ - batch_begin_;
+
+      // Iteratively copy all inputs to device buffers
+      Host2Buffers(in_buffer_.Data(), inputs...);
+      // Pack buffers and call the function;
+      // the input pointer is shifted to keep the same order of inputs after packing
+      Call(std::forward<Fn>(fn), in_buffer_.Cbegin() + in_buff_offsets_.back());
+      // Copy results to host
+      Buffers2Host(output);
+    }
+  }
+
+  /*
+   * Batch-wise processing on a SYCL device:
+   * input = fn(input, other_inputs)
+   */
+  template <typename Fn, class... InputTypes>
+  void Calculate(Fn &&fn, InputType* input, const InputTypes&... other_inputs) {
+    ndata_ = input->Size() / sample_rates_[1];
+    const size_t nBatch = ndata_ / kBatchSize + (ndata_ % kBatchSize > 0);
+    for (size_t batch = 0; batch < nBatch; ++batch) {
+      batch_begin_ = batch * kBatchSize;
+      batch_end_ = (batch == nBatch - 1) ? ndata_ : batch_begin_ + kBatchSize;
+      batch_size_ = batch_end_ - batch_begin_;
+
+      // Iteratively copy all inputs to device buffers;
+      // inputs are passed by const reference
+      Host2Buffers(in_buffer_.Data(), *(input), other_inputs...);
+      // Pack buffers and call the function;
+      // the input pointer is shifted to keep the same order of inputs after packing
+      Call(std::forward<Fn>(fn), in_buffer_.Begin());
+      // Copy results to host
+      Buffers2Host(input);
+    }
+  }
+
+ private:
+  std::array<int, MaxNumInputs> in_buff_offsets_;
+  std::vector<int> sample_rates_;
+  size_t ndata_;
+  size_t batch_begin_;
+  size_t batch_end_;
+  // Not equal to kBatchSize for the last batch
+  size_t batch_size_;
+  ::sycl::queue* qu_;
+  std::vector<::sycl::event> events_;
+  USMVector<InputDType, MemoryType::on_device> in_buffer_;
+  USMVector<OutputDType, MemoryType::on_device> out_buffer_;
+};
+
+}  // namespace linalg
+}  // namespace sycl
+}  // namespace xgboost
+#endif  // PLUGIN_SYCL_COMMON_LINALG_OP_H_
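The heart of BatchProcessingHelper is the loop in Calculate: fixed-size batches with a short final batch, device copies before the kernel call and host copies after it. A minimal self-contained sketch of just that batching arithmetic, with the SYCL copies and kernel replaced by a plain loop (CalculateBatched is a hypothetical stand-in, not the plugin API):

#include <cstddef>
#include <iostream>
#include <vector>

// Processes output->size() elements in batches of kBatchSize;
// the last batch may be shorter, exactly as in Calculate above.
template <size_t kBatchSize, typename Fn>
void CalculateBatched(Fn&& fn, std::vector<float>* output,
                      const std::vector<float>& input) {
  const size_t ndata = output->size();
  const size_t n_batch = ndata / kBatchSize + (ndata % kBatchSize > 0);
  for (size_t batch = 0; batch < n_batch; ++batch) {
    const size_t begin = batch * kBatchSize;
    const size_t end = (batch == n_batch - 1) ? ndata : begin + kBatchSize;
    // In the real helper this window is memcpy'd into a USM buffer and
    // handed to a SYCL kernel; here it is an ordinary loop.
    for (size_t i = begin; i < end; ++i) {
      (*output)[i] = fn(input[i]);
    }
  }
}

int main() {
  std::vector<float> in(10, 2.0f), out(10, 0.0f);
  CalculateBatched<4>([](float x) { return x * x; }, &out, in);
  std::cout << out.front() << " " << out.back() << std::endl;  // 4 4
  return 0;
}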

plugin/sycl/common/row_set.h

Lines changed: 3 additions & 3 deletions

@@ -71,8 +71,8 @@ class RowSetCollection {
   inline void Init() {
     CHECK_EQ(elem_of_each_node_.size(), 0U);

-    const size_t* begin = row_indices_.Begin();
-    const size_t* end = row_indices_.End();
+    const size_t* begin = row_indices_.Data();
+    const size_t* end = begin + row_indices_.Size();
     elem_of_each_node_.emplace_back(Elem(begin, end, 0));
   }

@@ -86,7 +86,7 @@ class RowSetCollection {
                  size_t n_right) {
     const Elem e = elem_of_each_node_[node_id];
     CHECK(e.begin != nullptr);
-    size_t* all_begin = row_indices_.Begin();
+    size_t* all_begin = row_indices_.Data();
     size_t* begin = all_begin + (e.begin - all_begin);

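Like the hist_util.cc change, this swaps the iterator-returning Begin()/End() accessors for the raw-pointer Data(), deriving the end pointer from Data() + Size(). A small sketch of the pattern with a hypothetical container (FakeUSMVector is illustrative only):

#include <cstddef>
#include <vector>

// Stand-in exposing the two accessors the diff above relies on.
template <typename T>
class FakeUSMVector {
 public:
  explicit FakeUSMVector(size_t n) : storage_(n) {}
  T* Data() { return storage_.data(); }
  size_t Size() const { return storage_.size(); }

 private:
  std::vector<T> storage_;
};

int main() {
  FakeUSMVector<size_t> row_indices(8);
  // As in Init(): raw begin/end pointers without a separate End() accessor.
  const size_t* begin = row_indices.Data();
  const size_t* end = begin + row_indices.Size();
  return static_cast<int>(end - begin) - 8;  // 0 on success
}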