cpu: aarch64: Support for Bf16 LUT Eltwise Operations#4827

Open
nikhil-arm wants to merge 7 commits into uxlfoundation:main from nikhil-arm:gelu_lut

Conversation

@nikhil-arm
Contributor

Description

Please include a summary of the change. Please also include relevant motivation and context. See contribution guidelines for more details. If the change fixes an issue not documented in the project's Github issue tracker, please document all steps necessary to reproduce it.

Fixes # (github issue)

Checklist

General

  • Do all unit and benchdnn tests (make test and make test_benchdnn_*) pass locally for each commit?
  • Have you formatted the code using clang-format?

Performance improvements

  • Have you submitted performance data that demonstrates performance improvements?

New features

  • Have you published an RFC for the new feature?
  • Was the RFC approved?
  • Have you added relevant tests?

Bug fixes

  • Have you included information on how to reproduce the issue (either in a github issue or in this PR)?
  • Have you added relevant regression tests?

RFC PR

  • Does RFC document follow the template?
  • Have you added a link to the rendered document?

nikhil-arm and others added 4 commits March 12, 2026 23:42
Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
@nikhil-arm nikhil-arm requested review from a team as code owners March 13, 2026 08:02
@github-actions github-actions bot added platform:cpu-aarch64 Codeowner: @oneapi-src/onednn-cpu-aarch64 component:common labels Mar 13, 2026
@nikhil-arm
Contributor Author

Performance improvements on a machine with 128-bit vectors

BenchDNN commands

  • OMP_NUM_THREADS=1 ./build/tests/benchdnn/benchdnn --eltwise --mode=P --dt=bf16 --alg=gelu_erf 6432x5120
  • OMP_NUM_THREADS=1 ./build/tests/benchdnn/benchdnn --eltwise --mode=P --dt=bf16 --alg=swish --alpha=1 6432x5120
  • OMP_NUM_THREADS=1 ./build/tests/benchdnn/benchdnn --eltwise --mode=P --dt=bf16 --alg=gelu_tanh 6432x5120
  • OMP_NUM_THREADS=1 ./build/tests/benchdnn/benchdnn --eltwise --mode=P --dt=bf16 --alg=exp 6432x5120
  • OMP_NUM_THREADS=1 ./build/tests/benchdnn/benchdnn --eltwise --mode=P --dt=bf16 --alg=log 6432x5120
  • OMP_NUM_THREADS=1 ./build/tests/benchdnn/benchdnn --eltwise --mode=P --dt=bf16 --alg=sqrt 6432x5120
| threads | op        | lut avg (ms) | base avg (ms) | speedup   |
|--------:|-----------|-------------:|--------------:|----------:|
| 1       | gelu_erf  | 12.7663      | 105.556       | 8.268331  |
| 1       | swish     | 12.9223      | 108.851       | 8.423500  |
| 1       | gelu_tanh | 12.8577      | 136.556       | 10.620562 |
| 1       | exp       | 12.7922      | 59.2863       | 4.634566  |
| 1       | log       | 12.8866      | 80.5249       | 6.248731  |
| 1       | sqrt      | 12.8670      | 20.4943       | 1.592780  |
| 8       | gelu_erf  | 1.62322      | 13.2287       | 8.149665  |
| 8       | swish     | 1.62900      | 13.6354       | 8.370411  |
| 8       | gelu_tanh | 1.63109      | 17.0969       | 10.481886 |
| 8       | exp       | 1.62035      | 7.43281       | 4.587163  |
| 8       | log       | 1.62276      | 10.0804       | 6.211886  |
| 8       | sqrt      | 1.62573      | 2.56726       | 1.579143  |
| 64      | gelu_erf  | 0.208485     | 1.66726       | 7.997026  |
| 64      | swish     | 0.208498     | 1.71613       | 8.230918  |
| 64      | gelu_tanh | 0.208785     | 2.15408       | 10.317216 |
| 64      | exp       | 0.207944     | 0.940454      | 4.522631  |
| 64      | log       | 0.208260     | 1.26797       | 6.088399  |
| 64      | sqrt      | 0.207303     | 0.328519      | 1.584729  |

@nikhil-arm nikhil-arm force-pushed the gelu_lut branch 2 times, most recently from 1707aa2 to 2e02dd6 Compare March 13, 2026 08:22
This reverts commit 67d089324e046072bc76b5244f1572e9ec1f0393.
Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
Contributor

@jondea jondea left a comment


Lovely speedups and such an elegant idea, thank you!

VDISPATCH_ELTWISE(data_type == ::dnnl::impl::data_type::bf16,
VERBOSE_UNSUPPORTED_DT);

const auto *spec = get_bf16_fwd_lut_spec_(desc()->alg_kind);

Do we need all this spec machinery? Can we replace it all with

if (!swish){
    alpha = 0;
}
beta = 0;

const float alpha = spec->ignore_alpha_beta ? 0.f : desc()->alpha;
const float beta = spec->ignore_alpha_beta ? 0.f : desc()->beta;
bf16_lut_.resize(1u << 16);
for (uint32_t raw = 0; raw < (1u << 16); ++raw) {

raw is a bit vague, consider raw -> x_u16

CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t<avx512_core>)
CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t<avx2>)
CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t<sse41>)
CPU_INSTANCE_AARCH64(ref_eltwise_lut_fwd_t<bf16>)

I'm not sure I'd call it a ref; when we see ref in a verbose output, it usually implies it is unoptimized. I'd just call it eltwise_lut_fwd_t. I don't think we even need the template; it's not really doing anything. If we come to extend this to other data types, I think we can reconsider how.

const auto *src = reinterpret_cast<const data_t *>(src_u8);
auto *dst = reinterpret_cast<data_t *>(dst_u8);

static_assert(sizeof(data_t) == sizeof(bfloat16_t),

I don't think this is needed; validation should be kept in pd_t::init() if possible.


const auto offset_bytes = types::elements_to_bytes(
pd()->src_md()->data_type, data_d.offset0());
src_u8 += offset_bytes;

I don't think we need to cast back and forward to u8. We can just use bf16 and do

src += data_d.offset0();

etc

VDISPATCH_ELTWISE(everyone_is(data_type, src_md()->data_type,
dst_md()->data_type),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_ELTWISE(platform::has_data_type_support(data_type),

I don't think we need to enforce this. I think this check is for bfdot, but your code doesn't need to use any bf16 instructions. The calculation of the LUT should automatically fall back, and the LUT access doesn't use bf16.

VDISPATCH_ELTWISE(
set_default_formats_common(), VERBOSE_UNSUPPORTED_TAG);
VDISPATCH_ELTWISE(src_d.is_dense(true), VERBOSE_NONTRIVIAL_STRIDE);
VDISPATCH_ELTWISE(

Redundant due to line 54

VDISPATCH_ELTWISE(
attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR);

VDISPATCH_ELTWISE(data_type == ::dnnl::impl::data_type::bf16,

Redundant due to line 54

return status::success;
}

std::vector<bfloat16_t> bf16_lut_;

I think that the LUT should be stored in the primitive rather than primitive::pd_t, similar to how the JIT primitives do codegen.

