Skip to content

Commit

Permalink
cpu to support bitwise ops (microsoft#14197)
Browse files Browse the repository at this point in the history
  • Loading branch information
liqunfu authored Jan 24, 2023
1 parent 77b455b commit 7b6d880
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 2 deletions.
4 changes: 4 additions & 0 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ Do not modify directly.*
|||[9, 13]|**T** = tensor(double), tensor(float)|
|||[7, 8]|**T** = tensor(double), tensor(float)|
|BitShift|*in* X:**T**<br> *in* Y:**T**<br> *out* Z:**T**|11+|**T** = tensor(uint32), tensor(uint64), tensor(uint8)|
|BitwiseAnd|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BitwiseNot|*in* X:**T**<br> *out* Y:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BitwiseOr|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BitwiseXor|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BlackmanWindow|*in* size:**T1**<br> *out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)<br/> **T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Cast|*in* input:**T1**<br> *out* output:**T2**|13+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
Expand Down
64 changes: 64 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,38 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceSumSquare);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceSumSquare);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseAnd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseAnd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseAnd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseAnd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseAnd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseAnd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseAnd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseAnd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseNot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseNot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseNot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseNot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseNot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseNot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseNot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseNot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseOr);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseOr);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseOr);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseOr);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseOr);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseOr);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseOr);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseOr);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseXor);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseXor);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseXor);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseXor);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseXor);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseXor);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseXor);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseXor);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterElements);
Expand Down Expand Up @@ -2131,6 +2163,38 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
ReduceSumSquare)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double,
ReduceSumSquare)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseAnd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseAnd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseAnd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseAnd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseAnd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseAnd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseAnd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseAnd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseNot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseNot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseNot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseNot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseNot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseNot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseNot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseNot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseOr)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseOr)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseOr)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseOr)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseOr)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseOr)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseOr)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseOr)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseXor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseXor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseXor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseXor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseXor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseXor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseXor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseXor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterElements)>,
Expand Down
153 changes: 152 additions & 1 deletion onnxruntime/core/providers/cpu/math/element_wise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,42 @@ REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint8_t, BitShift);
REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint32_t, BitShift);
REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint64_t, BitShift);

REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, int8_t, BitwiseAnd);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, int16_t, BitwiseAnd);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, int32_t, BitwiseAnd);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, int64_t, BitwiseAnd);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, uint8_t, BitwiseAnd);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, uint16_t, BitwiseAnd);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, uint32_t, BitwiseAnd);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, uint64_t, BitwiseAnd);

REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, int8_t, BitwiseNot);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, int16_t, BitwiseNot);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, int32_t, BitwiseNot);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, int64_t, BitwiseNot);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, uint8_t, BitwiseNot);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, uint16_t, BitwiseNot);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, uint32_t, BitwiseNot);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, uint64_t, BitwiseNot);

REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, int8_t, BitwiseOr);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, int16_t, BitwiseOr);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, int32_t, BitwiseOr);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, int64_t, BitwiseOr);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, uint8_t, BitwiseOr);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, uint16_t, BitwiseOr);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, uint32_t, BitwiseOr);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, uint64_t, BitwiseOr);

REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, int8_t, BitwiseXor);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, int16_t, BitwiseXor);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, int32_t, BitwiseXor);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, int64_t, BitwiseXor);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, uint8_t, BitwiseXor);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, uint16_t, BitwiseXor);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, uint32_t, BitwiseXor);
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, uint64_t, BitwiseXor);

REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Erf, 9, 12, float, Erf);
// Supposed to add BFloat16 but we are not supporting now, however, separate registration
REG_ELEMENTWISE_TYPED_KERNEL(Erf, 13, float, Erf);
Expand Down Expand Up @@ -1155,7 +1191,122 @@ Status BitShift<T>::Compute(OpKernelContext* context) const {
}

template <typename T>
class Sin final : public OpKernel {
Status BitwiseAnd<T>::Compute(OpKernelContext* context) const {
ProcessBroadcastSpanFuncs funcs {
[](BroadcastHelper& per_iter_bh) {
const T X = per_iter_bh.ScalarInput0<T>();
auto Y = per_iter_bh.SpanInput1<T>();
auto output = per_iter_bh.OutputSpan<T>();

std::transform(Y.begin(), Y.end(), output.begin(),
[X](T y) {
return std::bit_and<T>()(X, y);
});
},
[](BroadcastHelper& per_iter_bh) {
auto X = per_iter_bh.SpanInput0<T>();
const T Y = per_iter_bh.ScalarInput1<T>();
auto output = per_iter_bh.OutputSpan<T>();

std::transform(X.begin(), X.end(), output.begin(),
[Y](T x) {
return static_cast<T>(std::bit_and<T>()(x, Y));
});
},
[](BroadcastHelper& per_iter_bh) {
auto X = per_iter_bh.SpanInput0<T>();
auto Y = per_iter_bh.SpanInput1<T>();
auto output = per_iter_bh.OutputSpan<T>();

std::transform(X.begin(), X.end(), Y.begin(), output.begin(), std::bit_and<T>());
}};

UntypedBroadcastTwo(*context, funcs, 1.0f);
return Status::OK();
}

template <typename T>
Status BitwiseNot<T>::Compute(OpKernelContext* context) const {
auto& input = *context->Input<Tensor>(0);
auto& output = *context->Output(0, input.Shape());

std::transform(EigenMap<T>(input).array().begin(), EigenMap<T>(input).array().end(), EigenMap<T>(output).array().begin(), std::bit_not<T>());

return Status::OK();
}

template <typename T>
Status BitwiseOr<T>::Compute(OpKernelContext* context) const {
ProcessBroadcastSpanFuncs funcs{
[](BroadcastHelper& per_iter_bh) {
const T X = per_iter_bh.ScalarInput0<T>();
auto Y = per_iter_bh.SpanInput1<T>();
auto output = per_iter_bh.OutputSpan<T>();

std::transform(Y.begin(), Y.end(), output.begin(),
[X](T y) {
return std::bit_or<T>()(X, y);
});
},
[](BroadcastHelper& per_iter_bh) {
auto X = per_iter_bh.SpanInput0<T>();
const T Y = per_iter_bh.ScalarInput1<T>();
auto output = per_iter_bh.OutputSpan<T>();

std::transform(X.begin(), X.end(), output.begin(),
[Y](T x) {
return static_cast<T>(std::bit_or<T>()(x, Y));
});
},
[](BroadcastHelper& per_iter_bh) {
auto X = per_iter_bh.SpanInput0<T>();
auto Y = per_iter_bh.SpanInput1<T>();
auto output = per_iter_bh.OutputSpan<T>();

std::transform(X.begin(), X.end(), Y.begin(), output.begin(), std::bit_or<T>());
}};

UntypedBroadcastTwo(*context, funcs, 1.0f);
return Status::OK();
}

template <typename T>
Status BitwiseXor<T>::Compute(OpKernelContext* context) const {
ProcessBroadcastSpanFuncs funcs{
[](BroadcastHelper& per_iter_bh) {
const T X = per_iter_bh.ScalarInput0<T>();
auto Y = per_iter_bh.SpanInput1<T>();
auto output = per_iter_bh.OutputSpan<T>();

std::transform(Y.begin(), Y.end(), output.begin(),
[X](T y) {
return std::bit_xor<T>()(X, y);
});
},
[](BroadcastHelper& per_iter_bh) {
auto X = per_iter_bh.SpanInput0<T>();
const T Y = per_iter_bh.ScalarInput1<T>();
auto output = per_iter_bh.OutputSpan<T>();

std::transform(X.begin(), X.end(), output.begin(),
[Y](T x) {
return static_cast<T>(std::bit_xor<T>()(x, Y));
});
},
[](BroadcastHelper& per_iter_bh) {
auto X = per_iter_bh.SpanInput0<T>();
auto Y = per_iter_bh.SpanInput1<T>();
auto output = per_iter_bh.OutputSpan<T>();

std::transform(X.begin(), X.end(), Y.begin(), output.begin(), std::bit_xor<T>());
}};

UntypedBroadcastTwo(*context, funcs, 1.0f);
return Status::OK();
}

template <typename T>
class Sin final : public OpKernel {
public:
Sin(const OpKernelInfo& info) : OpKernel(info) {
}
Expand Down
36 changes: 36 additions & 0 deletions onnxruntime/core/providers/cpu/math/element_wise_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,42 @@ class BitShift final : public OpKernel {
bool shift_left_;
};

template <typename T>
class BitwiseAnd final : public OpKernel {
public:
explicit BitwiseAnd(const OpKernelInfo& info) : OpKernel(info) {
}

Status Compute(OpKernelContext* context) const override;
};

template <typename T>
class BitwiseNot final : public OpKernel {
public:
explicit BitwiseNot(const OpKernelInfo& info) : OpKernel(info) {
}

Status Compute(OpKernelContext* context) const override;
};

template <typename T>
class BitwiseOr final : public OpKernel {
public:
explicit BitwiseOr(const OpKernelInfo& info) : OpKernel(info) {
}

Status Compute(OpKernelContext* context) const override;
};

template <typename T>
class BitwiseXor final : public OpKernel {
public:
explicit BitwiseXor(const OpKernelInfo& info) : OpKernel(info) {
}

Status Compute(OpKernelContext* context) const override;
};

// PRelu is activation function, but it's closer to binary elementwise ops in implementation
template <typename T>
class PRelu final : public OpKernel {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@
"^test_div_uint8_cuda",
"^test_add_uint8_cuda",
"^test_roialign_aligned_*",
"^test_bitwise_*",
"^test_clip_default_int8_max_expanded_cpu",
"^test_clip_default_int8_min_expanded_cpu",
"^test_col2im_*",
Expand Down

0 comments on commit 7b6d880

Please sign in to comment.