Skip to content
This repository was archived by the owner on Jan 13, 2025. It is now read-only.

Refactored blas1 dot & sdsdot operators #471

Merged
Merged
11 changes: 7 additions & 4 deletions benchmark/portblas/blas1/dot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@ void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, index_t size,
scalar_t vr_temp = 0;
{
auto vr_temp_gpu = blas::helper::allocate<mem_alloc, scalar_t>(1, q);
auto copyToD =
blas::helper::copy_to_device<scalar_t>(q, &vr_temp, vr_temp_gpu, 1);
auto dot_event = _dot(sb_handle, size, inx, static_cast<index_t>(1), iny,
static_cast<index_t>(1), vr_temp_gpu);
static_cast<index_t>(1), vr_temp_gpu, {copyToD});
sb_handle.wait(dot_event);
auto copy_output = blas::helper::copy_to_host(q, vr_temp_gpu, &vr_temp, 1);
sb_handle.wait(copy_output);
Expand Down Expand Up @@ -128,8 +130,8 @@ void register_benchmark(blas::SB_Handle* sb_handle_ptr, bool* success,
};

benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
size, mem_type).c_str(),
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(size, mem_type)
.c_str(),
BM_lambda, sb_handle_ptr, size, success)
->UseRealTime();
}
Expand All @@ -141,7 +143,8 @@ void register_benchmark(blas_benchmark::Args& args,
auto dot_params = blas_benchmark::utils::get_blas1_params(args);

register_benchmark<scalar_t, blas::helper::AllocType::buffer>(
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, dot_params);
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER,
dot_params);
#ifdef SB_ENABLE_USM
register_benchmark<scalar_t, blas::helper::AllocType::usm>(
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, dot_params);
Expand Down
14 changes: 9 additions & 5 deletions benchmark/portblas/blas1/sdsdot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,11 @@ void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, index_t size,
scalar_t vr_temp = 0;
{
auto vr_temp_gpu = blas::helper::allocate<mem_alloc, scalar_t>(1, q);
auto copyToD =
blas::helper::copy_to_device<scalar_t>(q, &vr_temp, vr_temp_gpu, 1);
auto sdsdot_event =
_sdsdot(sb_handle, size, sb, inx, static_cast<index_t>(1), iny,
static_cast<index_t>(1), vr_temp_gpu);
static_cast<index_t>(1), vr_temp_gpu, {copyToD});
sb_handle.wait(sdsdot_event);
auto event = blas::helper::copy_to_host(q, vr_temp_gpu, &vr_temp, 1);
sb_handle.wait(event);
Expand Down Expand Up @@ -126,8 +128,8 @@ void register_benchmark(blas::SB_Handle* sb_handle_ptr, bool* success,
run<scalar_t, mem_alloc>(st, sb_handle_ptr, size, success);
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
size, mem_type).c_str(),
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(size, mem_type)
.c_str(),
BM_lambda, sb_handle_ptr, size, success)
->UseRealTime();
}
Expand All @@ -139,10 +141,12 @@ void register_benchmark(blas_benchmark::Args& args,
auto sdsdot_params = blas_benchmark::utils::get_blas1_params(args);

register_benchmark<scalar_t, blas::helper::AllocType::buffer>(
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, sdsdot_params);
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER,
sdsdot_params);
#ifdef SB_ENABLE_USM
register_benchmark<scalar_t, blas::helper::AllocType::usm>(
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, sdsdot_params);
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM,
sdsdot_params);
#endif
}

Expand Down
51 changes: 33 additions & 18 deletions include/interface/blas1_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,19 @@ typename sb_handle_t::event_t _nrm2_impl(
container_1_t _rs, const index_t number_WG,
const typename sb_handle_t::event_t &_dependencies);

/*!
* \brief Prototype for the internal implementation of the Dot operator. See
* documentation in the blas1_interface.hpp file for details.
*/
template <int localSize, int localMemSize, typename sb_handle_t,
typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t, typename increment_t>
typename sb_handle_t::event_t _dot_impl(
sb_handle_t &sb_handle, index_t _N, container_0_t _vx, increment_t _incx,
container_1_t _vy, increment_t _incy, container_2_t _rs,
const index_t _number_wg,
const typename sb_handle_t::event_t &_dependencies);

/**
* @brief _rot constructor given plane rotation
* @param sb_handle SB_Handle
Expand Down Expand Up @@ -306,12 +319,12 @@ typename sb_handle_t::event_t _rotm(
* @tparam container_3_t Buffer Iterator or USM pointer
* @tparam container_4_t Buffer Iterator or USM pointer
* @param sb_handle SB_Handle
* @param _d1[in,out] On entry, memory object holding the scaling factor for the
* x-coordinate. On exit, the re-scaled _d1.
* @param _d2[in,out] On entry, memory object holding the scaling factor for the
* y-coordinate. On exit, the re-scaled _d2.
* @param _x1[in,out] On entry, memory object holding the x-coordinate. On exit,
* the re-scaled _x1
* @param _d1[in,out] On entry, memory object holding the scaling factor for
* the x-coordinate. On exit, the re-scaled _d1.
* @param _d2[in,out] On entry, memory object holding the scaling factor for
* the y-coordinate. On exit, the re-scaled _d2.
* @param _x1[in,out] On entry, memory object holding the x-coordinate. On
* exit, the re-scaled _x1
* @param _y1[in] Memory object holding the y-coordinate of the point.
* @param _param[out] Buffer with the following layout: [flag, h11, h21, h12,
* h22].
Expand Down Expand Up @@ -359,8 +372,10 @@ typename sb_handle_t::event_t _rotg(
* @tparam sb_handle_t SB_Handle type
* @tparam scalar_t Scalar type
* @param sb_handle SB_Handle
* @param a[in, out] On entry, x-coordinate of the point. On exit, the scalar z.
* @param b[in, out] On entry, y-coordinate of the point. On exit, the scalar r.
* @param a[in, out] On entry, x-coordinate of the point. On exit, the scalar
* z.
* @param b[in, out] On entry, y-coordinate of the point. On exit, the scalar
* r.
* @param c[out] scalar representing the output c.
* @param s[out] scalar representing the output s.
* @param _dependencies Vector of events
Expand All @@ -377,7 +392,6 @@ void _rotg(sb_handle_t &sb_handle, scalar_t &a, scalar_t &b, scalar_t &c,
* @tparam sb_handle_t SB_Handle type
* @tparam container_0_t Buffer Iterator or USM pointer
* @tparam container_1_t Buffer Iterator or USM pointer
* @tparam container_2_t Buffer Iterator or USM pointer
* @tparam index_t Index type
* @tparam increment_t Increment type
* @param sb_handle SB_Handle
Expand All @@ -404,7 +418,6 @@ typename ValueType<container_0_t>::type _dot(
* @tparam sb_handle_t SB_Handle type
* @tparam container_0_t Buffer Iterator or USM pointer
* @tparam container_1_t Buffer Iterator or USM pointer
* @tparam container_2_t Buffer Iterator or USM pointer
* @tparam index_t Index type
* @tparam increment_t Increment type
* @param sb_handle SB_Handle
Expand Down Expand Up @@ -754,12 +767,12 @@ typename sb_handle_t::event_t _rotm(
* @tparam container_3_t Buffer Iterator or USM pointer
* @tparam container_4_t Buffer Iterator or USM pointer
* @param sb_handle SB_Handle
* @param _d1[in,out] On entry, memory object holding the scaling factor for the
* x-coordinate. On exit, the re-scaled _d1.
* @param _d2[in,out] On entry, memory object holding the scaling factor for the
* y-coordinate. On exit, the re-scaled _d2.
* @param _x1[in,out] On entry, memory object holding the x-coordinate. On exit,
* the re-scaled _x1
* @param _d1[in,out] On entry, memory object holding the scaling factor for
* the x-coordinate. On exit, the re-scaled _d1.
* @param _d2[in,out] On entry, memory object holding the scaling factor for
* the y-coordinate. On exit, the re-scaled _d2.
* @param _x1[in,out] On entry, memory object holding the x-coordinate. On
* exit, the re-scaled _x1
* @param _y1[in] Memory object holding the y-coordinate of the point.
* @param _param[out] Buffer with the following layout: [flag, h11, h21, h12,
* h22].
Expand Down Expand Up @@ -811,8 +824,10 @@ typename sb_handle_t::event_t _rotg(
* @tparam sb_handle_t SB_Handle type
* @tparam scalar_t Scalar type
* @param sb_handle SB_Handle
* @param a[in, out] On entry, x-coordinate of the point. On exit, the scalar z.
* @param b[in, out] On entry, y-coordinate of the point. On exit, the scalar r.
* @param a[in, out] On entry, x-coordinate of the point. On exit, the scalar
* z.
* @param b[in, out] On entry, y-coordinate of the point. On exit, the scalar
* r.
* @param c[out] scalar representing the output c.
* @param s[out] scalar representing the output s.
* @param _dependencies Vector of events
Expand Down
18 changes: 18 additions & 0 deletions include/operations/blas1_trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,24 @@ struct BinaryOp {
void adjust_access_displacement();
};

/*! BinaryOpConst.
* @brief Implements a const Binary Operation (x OP z) with x and z vectors.
*/
template <typename operator_t, typename lhs_t, typename rhs_t>
struct BinaryOpConst {
using index_t = typename rhs_t::index_t;
using value_t = typename ResolveReturnType<operator_t, rhs_t>::type::value_t;
lhs_t lhs_;
rhs_t rhs_;
BinaryOpConst(lhs_t &_l, rhs_t &_r);
index_t get_size() const;
bool valid_thread(cl::sycl::nd_item<1> ndItem) const;
value_t eval(index_t i) const;
value_t eval(cl::sycl::nd_item<1> ndItem) const;
void bind(cl::sycl::handler &h);
void adjust_access_displacement();
};

/*! TupleOp.
* @brief Implements a Tuple Operation (map (\x -> [i, x]) vector).
*/
Expand Down
23 changes: 23 additions & 0 deletions src/interface/blas1/backend/amd_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,29 @@ typename sb_handle_t::event_t _nrm2(
}
} // namespace backend
} // namespace nrm2

namespace dot {
namespace backend {
template <typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t, typename increment_t>
typename sb_handle_t::event_t _dot(
sb_handle_t& sb_handle, index_t _N, container_0_t _vx, increment_t _incx,
container_1_t _vy, increment_t _incy, container_2_t _rs,
const typename sb_handle_t::event_t& _dependencies) {
if (_N < (1 << 18)) {
constexpr index_t localSize = 1024;
const index_t number_WG = (_N + localSize - 1) / localSize;
return blas::internal::_dot_impl<static_cast<int>(localSize), 32>(
sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies);
} else {
constexpr int localSize = 512;
constexpr index_t number_WG = 512;
return blas::internal::_dot_impl<localSize, 32>(
sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies);
}
}
} // namespace backend
} // namespace dot
} // namespace blas

#endif
16 changes: 16 additions & 0 deletions src/interface/blas1/backend/default_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,22 @@ typename sb_handle_t::event_t _nrm2(
}
} // namespace backend
} // namespace nrm2

namespace dot {
namespace backend {
template <typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t, typename increment_t>
typename sb_handle_t::event_t _dot(
sb_handle_t& sb_handle, index_t _N, container_0_t _vx, increment_t _incx,
container_1_t _vy, increment_t _incy, container_2_t _rs,
const typename sb_handle_t::event_t& _dependencies) {
constexpr int localSize = 8;
constexpr index_t number_WG = 16;
return blas::internal::_dot_impl<localSize, 0>(
sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies);
}
} // namespace backend
} // namespace dot
} // namespace blas

#endif
17 changes: 17 additions & 0 deletions src/interface/blas1/backend/intel_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,23 @@ typename sb_handle_t::event_t _nrm2(
} // namespace backend
} // namespace nrm2

namespace dot {
namespace backend {
template <typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t, typename increment_t>
typename sb_handle_t::event_t _dot(
sb_handle_t& sb_handle, index_t _N, container_0_t _vx, increment_t _incx,
container_1_t _vy, increment_t _incy, container_2_t _rs,
const typename sb_handle_t::event_t& _dependencies) {
constexpr index_t localSize = 128;
const index_t number_WG =
std::min((_N + localSize - 1) / localSize, static_cast<index_t>(512));
return blas::internal::_dot_impl<static_cast<int>(localSize), 32>(
sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies);
}
} // namespace backend
} // namespace dot

} // namespace blas

#endif
26 changes: 26 additions & 0 deletions src/interface/blas1/backend/nvidia_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,32 @@ typename sb_handle_t::event_t _nrm2(
} // namespace backend
} // namespace nrm2

namespace dot {
namespace backend {
template <typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t, typename increment_t>
typename sb_handle_t::event_t _dot(
sb_handle_t& sb_handle, index_t _N, container_0_t _vx, increment_t _incx,
container_1_t _vy, increment_t _incy, container_2_t _rs,
const typename sb_handle_t::event_t& _dependencies) {
if (_N < (1 << 23)) {
constexpr index_t localSize = 512;
const index_t number_WG = (_N < (1 << 18))
? (_N + localSize - 1) / localSize
: static_cast<index_t>(256);

return blas::internal::_dot_impl<static_cast<int>(localSize), 32>(
sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies);
} else {
constexpr int localSize = 512;
constexpr index_t number_WG = 1024;
return blas::internal::_dot_impl<localSize, 32>(
sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies);
}
}
} // namespace backend
} // namespace dot

} // namespace blas

#endif
6 changes: 5 additions & 1 deletion src/interface/blas1/dot.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ namespace internal {
* @tparam sb_handle_t SB_Handle type
* @tparam container_0_t Buffer Iterator or USM Pointer
* @tparam container_1_t Buffer Iterator or USM Pointer
* @tparam container_2_t Buffer Iterator or USM Pointer
* @tparam index_t Index type
* @tparam increment_t Increment type
* @param sb_handle SB_Handle
Expand All @@ -62,6 +61,11 @@ template typename SB_Handle::event_t _dot(
SB_Handle& sb_handle, ${INDEX_TYPE} _N, ${DATA_TYPE} * _vx,
${INCREMENT_TYPE} _incx, ${DATA_TYPE} * _vy, ${INCREMENT_TYPE} _incy,
${DATA_TYPE} * _rs, const typename SB_Handle::event_t& dependencies);

template typename SB_Handle::event_t _dot(
SB_Handle& sb_handle, ${INDEX_TYPE} _N, const ${DATA_TYPE} * _vx,
${INCREMENT_TYPE} _incx, const ${DATA_TYPE} * _vy, ${INCREMENT_TYPE} _incy,
${DATA_TYPE} * _rs, const typename SB_Handle::event_t& dependencies);
#endif

} // namespace internal
Expand Down
6 changes: 5 additions & 1 deletion src/interface/blas1/dot_return.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ namespace internal {
* @tparam sb_handle_t SB_Handle type
* @tparam container_0_t Buffer Iterator or USM Pointer
* @tparam container_1_t Buffer Iterator or USM Pointer
* @tparam container_2_t Buffer Iterator or USM Pointer
* @tparam index_t Index type
* @tparam increment_t Increment type
* @param sb_handle SB_Handle
Expand All @@ -61,6 +60,11 @@ template typename ValueType<${DATA_TYPE}>::type _dot(
SB_Handle& sb_handle, ${INDEX_TYPE} _N, ${DATA_TYPE} * _vx,
${INCREMENT_TYPE} _incx, ${DATA_TYPE} * _vy, ${INCREMENT_TYPE} _incy,
const typename SB_Handle::event_t& dependencies);

template typename ValueType<${DATA_TYPE}>::type _dot(
SB_Handle& sb_handle, ${INDEX_TYPE} _N, const ${DATA_TYPE} * _vx,
${INCREMENT_TYPE} _incx, const ${DATA_TYPE} * _vy, ${INCREMENT_TYPE} _incy,
const typename SB_Handle::event_t& dependencies);
#endif

} // namespace internal
Expand Down
5 changes: 5 additions & 0 deletions src/interface/blas1/sdsdot.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ template typename SB_Handle::event_t _sdsdot(
SB_Handle& sb_handle, ${INDEX_TYPE} _N, float sb, ${DATA_TYPE} * _vx,
${INCREMENT_TYPE} _incx, ${DATA_TYPE} * _vy, ${INCREMENT_TYPE} _incy,
${DATA_TYPE} * _rs, const typename SB_Handle::event_t& dependencies);

template typename SB_Handle::event_t _sdsdot(
SB_Handle& sb_handle, ${INDEX_TYPE} _N, float sb, const ${DATA_TYPE} * _vx,
${INCREMENT_TYPE} _incx, const ${DATA_TYPE} * _vy, ${INCREMENT_TYPE} _incy,
${DATA_TYPE} * _rs, const typename SB_Handle::event_t& dependencies);
#endif

} // namespace internal
Expand Down
5 changes: 5 additions & 0 deletions src/interface/blas1/sdsdot_return.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ template typename ValueType<${DATA_TYPE}>::type _sdsdot(
SB_Handle& sb_handle, ${INDEX_TYPE} _N, float sb, ${DATA_TYPE} * _vx,
${INCREMENT_TYPE} _incx, ${DATA_TYPE} * _vy, ${INCREMENT_TYPE} _incy,
const typename SB_Handle::event_t& dependencies);

template typename ValueType<${DATA_TYPE}>::type _sdsdot(
SB_Handle& sb_handle, ${INDEX_TYPE} _N, float sb, const ${DATA_TYPE} * _vx,
${INCREMENT_TYPE} _incx, const ${DATA_TYPE} * _vy, ${INCREMENT_TYPE} _incy,
const typename SB_Handle::event_t& dependencies);
#endif

} // namespace internal
Expand Down
Loading