Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
Refined transpose kernel configurations for CPU target (#485)
Browse files (browse the repository at this point in the history)
* Down-sized some copy-op benchmarks
  • Loading branch information
OuadiElfarouki authored Dec 19, 2023
1 parent 8b220fc commit 39e3747
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 48 deletions.
76 changes: 32 additions & 44 deletions common/include/common/common_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1242,15 +1242,13 @@ static inline std::vector<matcopy_param_t<scalar_t>> get_matcopy_params(
std::vector<matcopy_param_t<scalar_t>> matcopy_default;
constexpr index_t dmin = 64, dmax = 8192;
constexpr scalar_t alpha{2};
constexpr index_t lda_mul = 1;
constexpr index_t ldb_mul = 1;
for (char trans : {'n', 't'}) {
for (index_t m = dmin; m <= dmax; m *= 2) {
for (index_t n = dmin; n <= dmax; n *= 2) {
for (index_t lda_mul = 1; lda_mul < 2; ++lda_mul) {
for (index_t ldb_mul = 1; ldb_mul < 2; ++ldb_mul) {
matcopy_default.push_back(
std::make_tuple(trans, m, n, alpha, lda_mul, ldb_mul));
}
}
matcopy_default.push_back(
std::make_tuple(trans, m, n, alpha, lda_mul, ldb_mul));
}
}
}
Expand Down Expand Up @@ -1287,17 +1285,15 @@ static inline std::vector<omatcopy2_param_t<scalar_t>> get_omatcopy2_params(
std::vector<omatcopy2_param_t<scalar_t>> omatcopy2_default;
constexpr index_t dmin = 1024, dmax = 8192;
constexpr scalar_t alpha{2};
constexpr index_t lda_mul = 1;
constexpr index_t ldb_mul = 1;
for (char trans : {'n', 't'}) {
for (index_t m = dmin; m <= dmax; m *= 2) {
for (index_t n = dmin; n <= dmax; n *= 2) {
for (index_t lda_mul = 1; lda_mul < 2; ++lda_mul) {
for (index_t inc_a = 1; inc_a < 3; ++inc_a) {
for (index_t ldb_mul = 1; ldb_mul < 2; ++ldb_mul) {
for (index_t inc_b = 1; inc_b < 3; ++inc_b) {
omatcopy2_default.push_back(std::make_tuple(
trans, m, n, alpha, lda_mul, ldb_mul, inc_a, inc_b));
}
}
for (index_t inc_a = 1; inc_a < 3; ++inc_a) {
for (index_t inc_b = 1; inc_b < 3; ++inc_b) {
omatcopy2_default.push_back(std::make_tuple(
trans, m, n, alpha, lda_mul, ldb_mul, inc_a, inc_b));
}
}
}
Expand Down Expand Up @@ -1336,21 +1332,20 @@ get_matcopy_batch_params(Args& args) {
if (args.csv_param.empty()) {
warning_no_csv();
std::vector<matcopy_batch_param_t<scalar_t>> matcopy_batch_default;
constexpr index_t dmin = 256, dmax = 8192;
constexpr index_t dmin = 256, dmax = 4096;
constexpr scalar_t alpha{2};
constexpr index_t batch_size{3};
constexpr index_t stride_a_mul{1};
constexpr index_t stride_b_mul{1};
constexpr index_t lda_mul = 1;
constexpr index_t ldb_mul = 1;
constexpr index_t ldc_mul = 1;
for (char trans : {'n', 't'}) {
for (index_t m = dmin; m <= dmax; m *= 2) {
for (index_t n = dmin; n <= dmax; n *= 2) {
for (index_t lda_mul = 1; lda_mul < 2; ++lda_mul) {
for (index_t ldb_mul = 1; ldb_mul < 2; ++ldb_mul) {
matcopy_batch_default.push_back(
std::make_tuple(trans, m, n, alpha, lda_mul, ldb_mul,
stride_a_mul, stride_b_mul, batch_size));
}
}
matcopy_batch_default.push_back(
std::make_tuple(trans, m, n, alpha, lda_mul, ldb_mul,
stride_a_mul, stride_b_mul, batch_size));
}
}
}
Expand Down Expand Up @@ -1386,22 +1381,19 @@ static inline std::vector<omatadd_param_t<scalar_t>> get_omatadd_params(
if (args.csv_param.empty()) {
warning_no_csv();
std::vector<omatadd_param_t<scalar_t>> omatadd_default;
constexpr index_t dmin = 64, dmax = 8192;
constexpr index_t dmin = 64, dmax = 4096;
constexpr scalar_t alpha{2};
constexpr scalar_t beta{2};
constexpr index_t lda_mul = 1;
constexpr index_t ldb_mul = 1;
constexpr index_t ldc_mul = 1;
for (char trans_a : {'n', 't'}) {
for (char trans_b : {'n', 't'}) {
for (index_t m = dmin; m <= dmax; m *= 2) {
for (index_t n = dmin; n <= dmax; n *= 2) {
for (index_t lda_mul = 1; lda_mul < 2; ++lda_mul) {
for (index_t ldb_mul = 1; ldb_mul < 2; ++ldb_mul) {
for (index_t ldc_mul = 1; ldc_mul < 2; ++ldc_mul) {
omatadd_default.push_back(
std::make_tuple(trans_a, trans_b, m, n, alpha, beta,
lda_mul, ldb_mul, ldc_mul));
}
}
}
omatadd_default.push_back(std::make_tuple(trans_a, trans_b, m, n,
alpha, beta, lda_mul,
ldb_mul, ldc_mul));
}
}
}
Expand Down Expand Up @@ -1439,27 +1431,23 @@ get_omatadd_batch_params(Args& args) {
if (args.csv_param.empty()) {
warning_no_csv();
std::vector<omatadd_batch_param_t<scalar_t>> omatadd_batch_default;
constexpr index_t dmin = 256, dmax = 8192;
constexpr index_t dmin = 1024, dmax = 4096;
constexpr scalar_t alpha{2};
constexpr scalar_t beta{2};
constexpr index_t batch_size{3};
constexpr index_t stride_a_mul{1};
constexpr index_t stride_b_mul{1};
constexpr index_t stride_c_mul{1};
constexpr index_t lda_mul = 1;
constexpr index_t ldb_mul = 1;
constexpr index_t ldc_mul = 1;
for (char trans_a : {'n', 't'}) {
for (char trans_b : {'n', 't'}) {
for (char trans_b : {'n'}) {
for (index_t m = dmin; m <= dmax; m *= 2) {
for (index_t n = dmin; n <= dmax; n *= 2) {
for (index_t lda_mul = 1; lda_mul < 2; ++lda_mul) {
for (index_t ldb_mul = 1; ldb_mul < 2; ++ldb_mul) {
for (index_t ldc_mul = 1; ldc_mul < 2; ++ldc_mul) {
omatadd_batch_default.push_back(
std::make_tuple(trans_a, trans_b, m, n, alpha, beta,
lda_mul, ldb_mul, ldc_mul, stride_a_mul,
stride_b_mul, stride_c_mul, batch_size));
}
}
}
omatadd_batch_default.push_back(std::make_tuple(
trans_a, trans_b, m, n, alpha, beta, lda_mul, ldb_mul, ldc_mul,
stride_a_mul, stride_b_mul, stride_c_mul, batch_size));
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/interface/extension/backend/default_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ typename sb_handle_t::event_t _transpose_outplace(
container_0_t in_, index_t _ld_in, index_t _inc_in, index_t _stride_in,
container_1_t out_, index_t _ld_out, index_t _inc_out, index_t _stride_out,
index_t _batch_size, const typename sb_handle_t::event_t& _dependencies) {
if (_M * _N < (1 << 20)) {
if (_M * _N < (1 << 16)) {
return blas::internal::_transpose_outplace_impl<16, 64, 64, false>(
sb_handle, _M, _N, _alpha, in_, _ld_in, _inc_in, _stride_in, out_,
_ld_out, _inc_out, _stride_out, _batch_size, _dependencies);
} else {
return blas::internal::_transpose_outplace_impl<32, 128, 64, false>(
return blas::internal::_transpose_outplace_impl<32, 32, 64, false>(
sb_handle, _M, _N, _alpha, in_, _ld_in, _inc_in, _stride_in, out_,
_ld_out, _inc_out, _stride_out, _batch_size, _dependencies);
}
Expand All @@ -58,13 +58,13 @@ typename sb_handle_t::event_t _transpose_add(
index_t _b_rows, index_t _b_cols, index_t _stride_b, container_2_t c_,
index_t _ld_c, index_t _stride_c, index_t _batch_size,
const typename sb_handle_t::event_t& _dependencies) {
if (_M * _N < (1 << 20)) {
if (_M * _N < (1 << 16)) {
return blas::internal::_transpose_add_impl<both_trans, 16, 64, 64, false>(
sb_handle, _M, _N, _alpha, a_, _ld_a, _a_rows, _a_cols, _stride_a,
_beta, b_, _ld_b, _b_rows, _b_cols, _stride_b, c_, _ld_c, _stride_c,
_batch_size, _dependencies);
} else {
return blas::internal::_transpose_add_impl<both_trans, 32, 128, 64, false>(
return blas::internal::_transpose_add_impl<both_trans, 32, 32, 64, false>(
sb_handle, _M, _N, _alpha, a_, _ld_a, _a_rows, _a_cols, _stride_a,
_beta, b_, _ld_b, _b_rows, _b_cols, _stride_b, c_, _ld_c, _stride_c,
_batch_size, _dependencies);
Expand Down

0 comments on commit 39e3747

Please sign in to comment.