From 7b6d880b28db5e52ddaecb0f09533027daab14f2 Mon Sep 17 00:00:00 2001 From: liqun Fu Date: Mon, 23 Jan 2023 16:42:18 -0800 Subject: [PATCH] cpu to support bitwise ops (#14197) --- docs/OperatorKernels.md | 4 + .../providers/cpu/cpu_execution_provider.cc | 64 ++++++++ .../providers/cpu/math/element_wise_ops.cc | 153 +++++++++++++++++- .../providers/cpu/math/element_wise_ops.h | 36 +++++ .../onnx_backend_test_series_filters.jsonc | 1 - 5 files changed, 256 insertions(+), 2 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 964799c3a0234..27d511c55dc90 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -44,6 +44,10 @@ Do not modify directly.* |||[9, 13]|**T** = tensor(double), tensor(float)| |||[7, 8]|**T** = tensor(double), tensor(float)| |BitShift|*in* X:**T**
<br> *in* Y:**T**<br> *out* Z:**T**|11+|**T** = tensor(uint32), tensor(uint64), tensor(uint8)|
+|BitwiseAnd|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|BitwiseNot|*in* X:**T**<br> *out* Y:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|BitwiseOr|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|BitwiseXor|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |BlackmanWindow|*in* size:**T1**<br> *out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)<br/> **T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |Cast|*in* input:**T1**<br> *out* output:**T2**|13+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/>
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index b7c369e173c49..3bcef3d9ff7a3 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -830,6 +830,38 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseAnd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseAnd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseAnd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseAnd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseAnd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseAnd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseAnd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseAnd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseNot); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 
18, int16_t, BitwiseNot); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseNot); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseNot); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseNot); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseNot); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseNot); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseNot); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseOr); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseOr); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseOr); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseOr); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseOr); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseOr); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseOr); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseOr); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseXor); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseXor); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseXor); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseXor); +class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseXor); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseXor); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseXor); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseXor); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterElements); @@ -2131,6 +2163,38 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { ReduceSumSquare)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 08e15251ed2c1..99eb09bbbd2ec 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ 
-374,6 +374,42 @@ REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint8_t, BitShift); REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint32_t, BitShift); REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint64_t, BitShift); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, int8_t, BitwiseAnd); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, int16_t, BitwiseAnd); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, int32_t, BitwiseAnd); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, int64_t, BitwiseAnd); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, uint8_t, BitwiseAnd); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, uint16_t, BitwiseAnd); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, uint32_t, BitwiseAnd); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, uint64_t, BitwiseAnd); + +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, int8_t, BitwiseNot); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, int16_t, BitwiseNot); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, int32_t, BitwiseNot); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, int64_t, BitwiseNot); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, uint8_t, BitwiseNot); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, uint16_t, BitwiseNot); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, uint32_t, BitwiseNot); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, uint64_t, BitwiseNot); + +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, int8_t, BitwiseOr); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, int16_t, BitwiseOr); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, int32_t, BitwiseOr); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, int64_t, BitwiseOr); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, uint8_t, BitwiseOr); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, uint16_t, BitwiseOr); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, uint32_t, BitwiseOr); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, uint64_t, BitwiseOr); + +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, int8_t, BitwiseXor); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, int16_t, BitwiseXor); 
+REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, int32_t, BitwiseXor); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, int64_t, BitwiseXor); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, uint8_t, BitwiseXor); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, uint16_t, BitwiseXor); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, uint32_t, BitwiseXor); +REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, uint64_t, BitwiseXor); + REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Erf, 9, 12, float, Erf); // Supposed to add BFloat16 but we are not supporting now, however, separate registration REG_ELEMENTWISE_TYPED_KERNEL(Erf, 13, float, Erf); @@ -1155,7 +1191,122 @@ Status BitShift::Compute(OpKernelContext* context) const { } template -class Sin final : public OpKernel { +Status BitwiseAnd::Compute(OpKernelContext* context) const { + ProcessBroadcastSpanFuncs funcs { + [](BroadcastHelper& per_iter_bh) { + const T X = per_iter_bh.ScalarInput0(); + auto Y = per_iter_bh.SpanInput1(); + auto output = per_iter_bh.OutputSpan(); + + std::transform(Y.begin(), Y.end(), output.begin(), + [X](T y) { + return std::bit_and()(X, y); + }); + }, + [](BroadcastHelper& per_iter_bh) { + auto X = per_iter_bh.SpanInput0(); + const T Y = per_iter_bh.ScalarInput1(); + auto output = per_iter_bh.OutputSpan(); + + std::transform(X.begin(), X.end(), output.begin(), + [Y](T x) { + return static_cast(std::bit_and()(x, Y)); + }); + }, + [](BroadcastHelper& per_iter_bh) { + auto X = per_iter_bh.SpanInput0(); + auto Y = per_iter_bh.SpanInput1(); + auto output = per_iter_bh.OutputSpan(); + + std::transform(X.begin(), X.end(), Y.begin(), output.begin(), std::bit_and()); + }}; + + UntypedBroadcastTwo(*context, funcs, 1.0f); + return Status::OK(); +} + +template +Status BitwiseNot::Compute(OpKernelContext* context) const { + auto& input = *context->Input(0); + auto& output = *context->Output(0, input.Shape()); + + std::transform(EigenMap(input).array().begin(), EigenMap(input).array().end(), EigenMap(output).array().begin(), 
std::bit_not()); + + return Status::OK(); +} + +template +Status BitwiseOr::Compute(OpKernelContext* context) const { + ProcessBroadcastSpanFuncs funcs{ + [](BroadcastHelper& per_iter_bh) { + const T X = per_iter_bh.ScalarInput0(); + auto Y = per_iter_bh.SpanInput1(); + auto output = per_iter_bh.OutputSpan(); + + std::transform(Y.begin(), Y.end(), output.begin(), + [X](T y) { + return std::bit_or()(X, y); + }); + }, + [](BroadcastHelper& per_iter_bh) { + auto X = per_iter_bh.SpanInput0(); + const T Y = per_iter_bh.ScalarInput1(); + auto output = per_iter_bh.OutputSpan(); + + std::transform(X.begin(), X.end(), output.begin(), + [Y](T x) { + return static_cast(std::bit_or()(x, Y)); + }); + }, + [](BroadcastHelper& per_iter_bh) { + auto X = per_iter_bh.SpanInput0(); + auto Y = per_iter_bh.SpanInput1(); + auto output = per_iter_bh.OutputSpan(); + + std::transform(X.begin(), X.end(), Y.begin(), output.begin(), std::bit_or()); + }}; + + UntypedBroadcastTwo(*context, funcs, 1.0f); + return Status::OK(); +} + +template +Status BitwiseXor::Compute(OpKernelContext* context) const { + ProcessBroadcastSpanFuncs funcs{ + [](BroadcastHelper& per_iter_bh) { + const T X = per_iter_bh.ScalarInput0(); + auto Y = per_iter_bh.SpanInput1(); + auto output = per_iter_bh.OutputSpan(); + + std::transform(Y.begin(), Y.end(), output.begin(), + [X](T y) { + return std::bit_xor()(X, y); + }); + }, + [](BroadcastHelper& per_iter_bh) { + auto X = per_iter_bh.SpanInput0(); + const T Y = per_iter_bh.ScalarInput1(); + auto output = per_iter_bh.OutputSpan(); + + std::transform(X.begin(), X.end(), output.begin(), + [Y](T x) { + return static_cast(std::bit_xor()(x, Y)); + }); + }, + [](BroadcastHelper& per_iter_bh) { + auto X = per_iter_bh.SpanInput0(); + auto Y = per_iter_bh.SpanInput1(); + auto output = per_iter_bh.OutputSpan(); + + std::transform(X.begin(), X.end(), Y.begin(), output.begin(), std::bit_xor()); + }}; + + UntypedBroadcastTwo(*context, funcs, 1.0f); + return Status::OK(); +} + 
+template + class Sin final : public OpKernel { public: Sin(const OpKernelInfo& info) : OpKernel(info) { } diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.h b/onnxruntime/core/providers/cpu/math/element_wise_ops.h index b7deea9625634..6b0254ff53926 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.h +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.h @@ -427,6 +427,42 @@ class BitShift final : public OpKernel { bool shift_left_; }; +template +class BitwiseAnd final : public OpKernel { + public: + explicit BitwiseAnd(const OpKernelInfo& info) : OpKernel(info) { + } + + Status Compute(OpKernelContext* context) const override; +}; + +template +class BitwiseNot final : public OpKernel { + public: + explicit BitwiseNot(const OpKernelInfo& info) : OpKernel(info) { + } + + Status Compute(OpKernelContext* context) const override; +}; + +template +class BitwiseOr final : public OpKernel { + public: + explicit BitwiseOr(const OpKernelInfo& info) : OpKernel(info) { + } + + Status Compute(OpKernelContext* context) const override; +}; + +template +class BitwiseXor final : public OpKernel { + public: + explicit BitwiseXor(const OpKernelInfo& info) : OpKernel(info) { + } + + Status Compute(OpKernelContext* context) const override; +}; + // PRelu is activation function, but it's closer to binary elementwise ops in implementation template class PRelu final : public OpKernel { diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 57a2eda7df33f..6bdfe58cc2f9a 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -116,7 +116,6 @@ "^test_div_uint8_cuda", "^test_add_uint8_cuda", "^test_roialign_aligned_*", - "^test_bitwise_*", "^test_clip_default_int8_max_expanded_cpu", "^test_clip_default_int8_min_expanded_cpu", "^test_col2im_*",