From c47a6ce70b80d5ca83e851d6ddfeab12af3e0941 Mon Sep 17 00:00:00 2001
From: Scott McKay
Date: Tue, 23 Apr 2024 11:50:31 +1000
Subject: [PATCH] XNNPACK: Support 1D input for Conv and ConvTranspose (#20349)

### Description
Support 1D input to the XNNPACK Conv and ConvTranspose kernels by faking a height of 1 to convert the input to 2D.

### Motivation and Context
Enable a speech model with 1D input to use XNNPACK. There is no CPU EP quantized ConvTranspose, so this fills that gap.
---
 .../core/providers/cpu/nn/conv_attributes.h   |   3 +-
 .../cpu/nn/conv_transpose_attributes.h        |  20 +-
 .../providers/cuda/cuda_execution_provider.cc |  15 +-
 .../core/providers/cuda/nn/conv_transpose.cc  | 138 +++++++---
 .../OperatorAuthorHelper/OperatorHelper.cpp   |   8 +-
 onnxruntime/core/providers/xnnpack/nn/conv.cc |  40 ++-
 .../core/providers/xnnpack/nn/conv_base.cc    |  80 ++++--
 .../providers/xnnpack/nn/conv_transpose.cc    | 107 +++++---
 .../cpu/nn/conv_transpose_op_test.cc          | 250 +++++++++---------
 .../cuda/nhwc/conv_transpose_test.cc          |  48 +++-
 .../providers/cuda/nhwc/nhwc_cuda_helper.h    |  59 +++--
 11 files changed, 504 insertions(+), 264 deletions(-)

diff --git a/onnxruntime/core/providers/cpu/nn/conv_attributes.h b/onnxruntime/core/providers/cpu/nn/conv_attributes.h
index 9e57f4632a9c3..170f313c8fe80 100644
--- a/onnxruntime/core/providers/cpu/nn/conv_attributes.h
+++ b/onnxruntime/core/providers/cpu/nn/conv_attributes.h
@@ -73,7 +73,8 @@ struct ConvAttributes {

   ~ConvAttributes() = default;

-  Status ComputeKernelShape(const TensorShape& weight_shape, TensorShapeVector& kernel_shape, bool weight_channels_last = false) const {
+  Status ComputeKernelShape(const TensorShape& weight_shape, TensorShapeVector& kernel_shape,
+                            bool weight_channels_last = false) const {
     if (kernel_shape_specified) {
       kernel_shape = kernel_shape_;
       if (kernel_shape.size() + 2 != weight_shape.NumDimensions()) {
diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h b/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h
index 4b3b934834ac8..7bd37f28283a0 100644
--- a/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h
+++ b/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h
@@ -44,9 +44,15 @@ struct ConvTransposeAttributes : public ConvAttributes {
     TensorShapeVector strides;
   };

+  // Viewing dim 1 of the X input as 'input channels' (C) and dim 1 of the Y output as 'output channels' (M),
+  // if is_nhwc is true, the input channels (dim 0) or output channels (dim 1) of the W input (the filter with
+  // shape {C, M/group, ...}) could be transposed to be last. transposed_input_channels indicates whether dim 0 or
+  // dim 1 was moved.
+  //
+  // e.g. XNNPACK moves the input channels dim to the end. CUDA moves the output channels dim to the end.
   Status PrepareForCompute(OpKernelContext* context, bool has_bias, Prepare& p, bool dynamic_padding = false,
                            const TensorShape* filter_shape = nullptr,
-                           bool is_nhwc = false) const {
+                           bool is_nhwc = false, bool transposed_input_channels = true) const {
     const Tensor* X = context->Input<Tensor>(0);
     const Tensor* F = (filter_shape != nullptr) ? nullptr : context->Input<Tensor>(1);
     const TensorShape& F_Shape = (filter_shape != nullptr) ? *filter_shape : F->Shape();
@@ -57,7 +63,12 @@ struct ConvTransposeAttributes : public ConvAttributes {
     TensorShape input_shape = X->Shape().Slice(is_nhwc ? 1 : 2, is_nhwc ? rank - 1 : rank);
     const int64_t num_input_channels = is_nhwc ? X->Shape()[rank - 1] : X->Shape()[1];
     const int64_t N = X->Shape()[0];
-    const int64_t num_output_channels_multiplier = is_nhwc ?
F_Shape[3] : F_Shape[1]; + + // W is {C, M/group, ....}. adjust for NHWC and transposed_input_channels + // If we transposed the input channels, {C, M/group, ...} becomes {M/group, ..., C} + // If we transposed the output channels, {C, M/group, ...} becomes {C, ..., M/group} + const auto M_div_group_dim = is_nhwc ? (transposed_input_channels ? 0 : F_Shape.NumDimensions() - 1) : 1; + const int64_t num_output_channels_multiplier = F_Shape[M_div_group_dim]; const int64_t num_output_channels = num_output_channels_multiplier * group; // input validations @@ -72,9 +83,10 @@ struct ConvTransposeAttributes : public ConvAttributes { " W: ", F_Shape.ToString().c_str()); } - if (F_Shape[0] != num_input_channels) { + const auto F_channels_dim = is_nhwc && transposed_input_channels ? F_Shape.NumDimensions() - 1 : 0; + if (F_Shape[F_channels_dim] != num_input_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "filter number not equal to input channel number.", - " filter_number: ", F_Shape[0], + " filter_number: ", F_Shape[F_channels_dim], " num_input_channels: ", num_input_channels); } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 05d9f3b5a1e8f..3b1698773b85b 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2394,7 +2394,9 @@ static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node, return false; } -static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const logging::Logger& logger) { +static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const logging::Logger& logger, + [[maybe_unused]] const GraphViewer& graph_viewer, + [[maybe_unused]] const bool prefer_nhwc) { const auto& node_attributes = node.GetAttributes(); // Check attributes for (auto& attr : node_attributes) { @@ -2442,6 +2444,15 @@ static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const } } +#ifdef ENABLE_CUDA_NHWC_OPS + if (prefer_nhwc) { + // NHWC implementation doesn't handle transpose of W if it's not an initializer + if (!graph_viewer.IsConstantInitializer(node.InputDefs()[1]->Name(), true)) { + return true; + } + } +#endif + return false; } @@ -2510,7 +2521,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, not_supported = RNNNeedFallbackToCPU(node, activations_supported, node.OpType()); force_inside = !not_supported; } else if ("ConvTranspose" == node.OpType()) { - not_supported = ConvTransposeNeedFallbackToCPU(node, logger); + not_supported = ConvTransposeNeedFallbackToCPU(node, logger, graph, IsNHWCPreferred()); force_inside = !not_supported; } else if ("Cast" == node.OpType()) { not_supported = CastNeedFallbackToCPU(node); diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 939b9959af818..bac99d6a81ed2 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -45,28 +45,47 @@ Status ConvTranspose::ComputeInternal(OpKernelContext* context) const { template Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, - [[maybe_unused]] PrePackedWeights* prepacked_weights) { + PrePackedWeights* prepacked_weights) { is_packed = false; // only layout of weight input is adjusted via PrePack - if (NHWC) { // InputTensors::IN_W + if constexpr (NHWC) { // 
InputTensors::IN_W if (input_idx == 1) { - // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} auto orig_shape = tensor.Shape(); + const auto rank = orig_shape.NumDimensions(); + + InlinedVector perm; + TensorShapeVector new_dims; + + // Input is { N, C, ...}. Output is { N, M, ...}. 'input channels' is C. 'output channels' is M. + // Transpose the output channels related dimension (M/group) to be last. Leave the input channels as-is. + if (rank == 3) { + // Transpose from {C, M/group, k1} to {C, k1, M/group} + perm = {0, 2, 1}; + new_dims = TensorShapeVector{orig_shape[0], orig_shape[2], orig_shape[1]}; + } else if (rank == 4) { + // Transpose from {C, M/group, kH, kW} to {C, kH, kW, M/group} + perm = {0, 2, 3, 1}; + new_dims = TensorShapeVector{orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1]}; + } else if (rank == 5) { + // Transpose from {C, M/group, k1, k2, k3} to {C, k1, k2, k3, M/group} + perm = {0, 2, 3, 4, 1}; + new_dims = TensorShapeVector{orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[4], orig_shape[1]}; + } - InlinedVector perm{0, 2, 3, 1}; - gsl::span permutation(perm.data(), 4); - TensorShapeVector new_dims{orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1]}; + gsl::span permutation(perm.data(), rank); W_ = Tensor::Create(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); - auto status = cuda::Transpose::DoTranspose(GetDeviceProp(), DefaultCudaStream(), DefaultCublasHandle(), - permutation, tensor, *W_); + ORT_RETURN_IF_ERROR(cuda::Transpose::DoTranspose(GetDeviceProp(), DefaultCudaStream(), DefaultCublasHandle(), + permutation, tensor, *W_)); - if (!status.IsOK()) { - return status; - } CUDA_CALL_THROW(cudaStreamSynchronize(DefaultCudaStream())); is_packed = true; } + } else { + ORT_UNUSED_PARAMETER(tensor); + ORT_UNUSED_PARAMETER(input_idx); + ORT_UNUSED_PARAMETER(alloc); + ORT_UNUSED_PARAMETER(prepacked_weights); } return Status::OK(); @@ -87,12 +106,10 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3-, 4- or 5-dimensional.", " X: ", X->Shape().ToString().c_str()); } - const Tensor* W; - if (!W_) { - W = context->Input(1); - } else { - W = W_.get(); - } + + // use pre-packed W if available + const Tensor* W = W_ ? W_.get() : context->Input(1); + const TensorShape& w_shape = W->Shape(); TensorShapeVector w_dims = w_shape.AsShapeVector(); auto w_data = reinterpret_cast(W->Data()); @@ -101,9 +118,38 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3; CudaT* y_data = nullptr; + + const auto* cuda_ep = static_cast(Info().GetExecutionProvider()); + + // convert 1D to 2D if (x_dimensions == 3) { - x_dims.insert(x_dims.begin() + 2, 1); - w_dims.insert(w_dims.begin() + 2, 1); + // we can either add a fake H or W dimension with a value of 1. to be consistent with the Conv behavior we use + // GetCudnnConv1dPadToNc1d to determine which is added. + // see Conv::UpdateState in /onnxruntime/core/providers/cuda/nn/conv.cc for more details. + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + // add fake H dimension + const auto insert_at = NHWC ? 
1 : 2; + + // NCHW: N, C, d1 -> N, C, 1, d1 + // NHWC: N, d1, C -> N, 1, d1, C + x_dims.insert(x_dims.begin() + insert_at, 1); + + // 'M' is channels dim in CUDA implementation + // NCHW: C, M/g, k1 -> C, M/g, 1, k1 + // NHWC: C, k1, M/g -> C, 1, k1, M/g + w_dims.insert(w_dims.begin() + insert_at, 1); + } else { + // add fake W dimension + const auto insert_at = NHWC ? 2 : 3; + + // NCHW: N, C, d1 -> N, C, d1, 1 + // NHWC: N, d1, C -> N, d1, 1, C + x_dims.insert(x_dims.begin() + insert_at, 1); + + // NCHW: C, M/g, k1 -> C, M/g, k1, 1 + // NHWC: C, k1, M/g -> C, k1, 1, M/g + w_dims.insert(w_dims.begin() + insert_at, 1); + } } { @@ -113,7 +159,9 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy bool input_dims_changed = (s_.last_x_dims.AsShapeVector() != x_dims); bool w_dims_changed = (s_.last_w_dims.AsShapeVector() != w_dims); if (input_dims_changed || w_dims_changed) { - if (input_dims_changed) s_.last_x_dims = gsl::make_span(x_dims); + if (input_dims_changed) { + s_.last_x_dims = gsl::make_span(x_dims); + } if (w_dims_changed) { s_.last_w_dims = gsl::make_span(w_dims); @@ -121,22 +169,40 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy } ConvTransposeAttributes::Prepare p; + // PrePack moves the M/group dimension of W to the end, with 'M' being interpreted as 'output channels' + const bool transposed_input_channels = false; ORT_RETURN_IF_ERROR( - conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding, &w_shape, NHWC)); + conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding, &w_shape, NHWC, transposed_input_channels)); auto y_dims = p.Y->Shape().AsShapeVector(); if (x_dimensions == 3) { - y_dims.insert(y_dims.begin() + 2, 1); - p.kernel_shape.insert(p.kernel_shape.begin(), 1); - p.pads.insert(p.pads.begin(), 0); - p.pads.insert(p.pads.begin() + 2, 0); - p.strides.insert(p.strides.begin(), 1); - p.dilations.insert(p.dilations.begin(), 1); + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + // add fake H dimension of 1 + // NCHW: N, M, d1 -> N, M, 1, d1 or + // NHWC: N, d1, M -> N, 1, d1, M + y_dims.insert(y_dims.begin() + (NHWC ? 1 : 2), 1); + p.kernel_shape.insert(p.kernel_shape.begin(), 1); + p.pads.insert(p.pads.begin(), 0); + p.pads.insert(p.pads.begin() + 2, 0); + p.strides.insert(p.strides.begin(), 1); + p.dilations.insert(p.dilations.begin(), 1); + } else { + // add fake W dimension of 1 + // NCHW: N, M, d1 -> N, M, d1, 1 or + // NHWC: N, d1, M -> N, d1, 1, M + y_dims.insert(y_dims.begin() + (NHWC ? 2 : 3), 1); + p.kernel_shape.push_back(1); + p.pads.insert(p.pads.begin() + 1, 0); + p.pads.push_back(0); + p.strides.push_back(1); + p.dilations.push_back(1); + } } + s_.y_dims = gsl::make_span(y_dims); if (w_dims_changed) { - if (NHWC) { + if constexpr (NHWC) { ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), static_cast(w_dims[0]), static_cast(w_dims[3]), static_cast(w_dims[1]), static_cast(w_dims[2]))); @@ -152,7 +218,8 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy if (p.Y->Shape().Size() == 0) { return Status::OK(); } - if (NHWC) { + + if constexpr (NHWC) { ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), static_cast(x_dims[0]), static_cast(x_dims[3]), static_cast(x_dims[1]), static_cast(x_dims[2]))); @@ -176,7 +243,9 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy TensorShapeVector b_dims(2 + p.kernel_shape.size()); b_dims[0] = 1; // N b_dims[NHWC ? 
3 : 1] = b_shape[0];  // C
-      for (size_t i = 0; i < p.kernel_shape.size(); i++) b_dims[(NHWC ? 1 : 2) + i] = 1;
+      for (size_t i = 0; i < p.kernel_shape.size(); i++) {
+        b_dims[(NHWC ? 1 : 2) + i] = 1;
+      }

       ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType(), NHWC));
     }
@@ -215,8 +284,15 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy
   if (!y_data) {
     auto y_dims = s_.y_dims.AsShapeVector();
     if (x_dimensions == 3) {
-      y_dims.erase(y_dims.begin() + 2);
+      if (cuda_ep->GetCudnnConv1dPadToNc1d()) {
+        // erase the fake H dimension
+        y_dims.erase(y_dims.begin() + (NHWC ? 1 : 2));
+      } else {
+        // erase the fake W dimension
+        y_dims.erase(y_dims.begin() + (NHWC ? 2 : 3));
+      }
     }
+
     Tensor* Y = context->Output(0, TensorShape(y_dims));
     y_data = reinterpret_cast<CudaT*>(Y->MutableData<T>());
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
index e308f76d4b827..44c63089564d3 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
@@ -839,13 +839,15 @@ namespace OperatorHelper

         if (outputShape.size() > 2)
         {
-            ML_CHECK_VALID_ARGUMENT(outputShape[outputShape.size() - 3] == gsl::narrow_cast<uint32_t>(m_outputShapes[0].GetShape()[C]), "Output channel must be equivalent to filter channel.");
-        }
+            ML_CHECK_VALID_ARGUMENT(outputShape[C] == gsl::narrow_cast<uint32_t>(m_outputShapes[0].GetShape()[C]),
+                                    "Output channel must be equivalent to filter channel.");
+        }

         for (size_t i = 0; i < m_kernel.spatialDimensionCount; ++i)
         {
             size_t outputIndex = outputShape.size() - m_kernel.spatialDimensionCount + i;
-            ML_CHECK_VALID_ARGUMENT(outputShape[outputIndex] >= gsl::narrow_cast<uint32_t>(inputDimensions[H + i]), "Output dimension cannot be smaller than input dimension.");
+            ML_CHECK_VALID_ARGUMENT(outputShape[outputIndex] >= gsl::narrow_cast<uint32_t>(inputDimensions[H + i]),
+                                    "Output dimension cannot be smaller than input dimension.");

             m_outputShapes[0].GetShape()[H + i] = outputShape[outputIndex];
         }
diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc
index 0cdb9c840aa2d..0366d9f893f7e 100644
--- a/onnxruntime/core/providers/xnnpack/nn/conv.cc
+++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc
@@ -3,6 +3,8 @@

 #include "conv.h"

+#include <cassert>
+
 #include "core/common/gsl.h"
 #include "core/common/inlined_containers_fwd.h"
 #include "core/framework/tensorprotoutils.h"
@@ -24,16 +26,30 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
       (conv_type_ != OpComputeType::op_compute_type_fp32 && input_idx == 3)) {  // InputTensors::IN_W
     // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group}
     auto orig_shape = tensor.Shape();
+    const auto rank = orig_shape.NumDimensions();
+
+    if (rank == 4) {
+      InlinedVector<size_t> perm{0, 2, 3, 1};
+      TensorShapeVector new_dims{orig_shape[0],
+                                 orig_shape[2],
+                                 orig_shape[3],
+                                 orig_shape[1]};
+
+      packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc));

-    InlinedVector<size_t> perm{0, 2, 3, 1};
-    TensorShapeVector new_dims{orig_shape[0],
-                               orig_shape[2],
-                               orig_shape[3],
-                               orig_shape[1]};
+      SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 1, /*to*/ 3);
+    } else {
+      assert(rank == 3);  // ConvBase::IsOnnxNodeSupported validates this

-    packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc));
+      InlinedVector<size_t> perm{0, 2, 1};
+      TensorShapeVector new_dims{orig_shape[0],
+                                 orig_shape[2],
+ orig_shape[1]}; - SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 1, /*to*/ 3); + packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + + SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 1, /*to*/ 2); + } is_packed = true; @@ -47,9 +63,13 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, Status Conv::Compute(OpKernelContext* context) const { const Tensor& X = *context->Input(0); // this is in NHWC format const auto& X_shape = X.Shape(); - const int64_t N = X_shape[0]; // input is NHWC - const int64_t H = X_shape[1]; - const int64_t W = X_shape[2]; + const auto rank = X_shape.NumDimensions(); + const auto is_1D = rank == 3; + const int64_t N = X_shape[0]; // input is NHWC or NWC + + // we support 1D or 2D. if 1D we convert to 2D by setting H to 1 + const int64_t H = is_1D ? 1 : X_shape[1]; + const int64_t W = X_shape[rank - 2]; // We don't need to call ValidateInputShape as we checked validity in ConvChecker. // We also can't use ValidateInputShape as-is as the weight tensor was pre-packed and the layout was changed there. diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc index d21014569234e..2aafc9be7ffd0 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc @@ -17,7 +17,8 @@ namespace onnxruntime { namespace xnnpack { namespace { -Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr, + +Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, int64_t C, int64_t M, const TensorShapeVector& kernel_shape, const std::optional>& clip_min_max, @@ -30,19 +31,21 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr, bool is_transpose = false) { struct xnn_operator* p = nullptr; - const uint32_t kernel_height = gsl::narrow(kernel_shape[0]); - const uint32_t kernel_width = gsl::narrow(kernel_shape[1]); + // if this is 1D input, we fake all the height related dims being 1 to make it 2D. so {W} -> {1, W} + const auto is_1D = kernel_shape.size() == 1; + + const uint32_t kernel_height = is_1D ? 1 : gsl::narrow(kernel_shape[0]); + const uint32_t kernel_width = gsl::narrow(kernel_shape[is_1D ? 0 : 1]); - const auto& conv_attrs = *conv_attrs_ptr; - const uint32_t input_padding_top = gsl::narrow(conv_attrs.pads[0]); - const uint32_t input_padding_left = gsl::narrow(conv_attrs.pads[1]); - const uint32_t input_padding_bottom = gsl::narrow(conv_attrs.pads[2]); - const uint32_t input_padding_right = gsl::narrow(conv_attrs.pads[3]); + const uint32_t input_padding_top = is_1D ? 0 : gsl::narrow(conv_attrs.pads[0]); + const uint32_t input_padding_left = gsl::narrow(conv_attrs.pads[is_1D ? 0 : 1]); + const uint32_t input_padding_bottom = is_1D ? 0 : gsl::narrow(conv_attrs.pads[2]); + const uint32_t input_padding_right = gsl::narrow(conv_attrs.pads[is_1D ? 1 : 3]); - const uint32_t subsampling_height = gsl::narrow(conv_attrs.strides[0]); - const uint32_t subsampling_width = gsl::narrow(conv_attrs.strides[1]); - const uint32_t dilation_height = gsl::narrow(conv_attrs.dilations[0]); - const uint32_t dilation_width = gsl::narrow(conv_attrs.dilations[1]); + const uint32_t subsampling_height = is_1D ? 1 : gsl::narrow(conv_attrs.strides[0]); + const uint32_t subsampling_width = gsl::narrow(conv_attrs.strides[is_1D ? 0 : 1]); + const uint32_t dilation_height = is_1D ? 1 : gsl::narrow(conv_attrs.dilations[0]); + const uint32_t dilation_width = gsl::narrow(conv_attrs.dilations[is_1D ? 
0 : 1]); uint32_t flags = 0; if (conv_attrs.auto_pad == AutoPadType::SAME_UPPER) { @@ -299,7 +302,7 @@ bool ConvBase::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& const onnxruntime::Node& node = node_unit.GetNode(); // use do {} while(false) so it's easier to set a breakpoint on the return do { - // Internal NHWC domain starts at opset 11 + // We have only implemented support for opset 11 Conv and above if (node_unit.SinceVersion() < 11) { break; } @@ -309,11 +312,18 @@ bool ConvBase::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& const auto& x_arg = inputs[0].node_arg; const auto& weight_arg = inputs[1].node_arg; - // we only support 2D (4 dims with batch and channel) const auto* x_shape = x_arg.Shape(); - if (!x_shape || x_shape->dim_size() != 4) { + if (!x_shape) { break; } + + // xnnpack only supports 2D. we support 2D (4 dims with batch and channel) or 1D (3 dims). + // if 1D we can fake the data being 4D by pretending the height dims are 1 + const auto rank = x_shape->dim_size(); + if (rank != 4 && rank != 3) { + break; + } + // we only support float and u8 currently const auto* x_type = x_arg.TypeAsProto(); if (x_type == nullptr || @@ -325,13 +335,13 @@ bool ConvBase::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& // require C, H, W to be known so we can construct the xnnpack kernel prior to Compute if (!x_shape->dim(1).has_dim_value() || !x_shape->dim(2).has_dim_value() || - !x_shape->dim(3).has_dim_value()) { + (rank == 4 && !x_shape->dim(3).has_dim_value())) { break; } - // weight must be constant and also rank 4 + // weight must be constant const auto* weight = graph.GetConstantInitializer(weight_arg.Name(), true); - if (weight == nullptr || weight->dims_size() != 4) { + if (weight == nullptr) { break; } @@ -396,7 +406,9 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose) const auto& input_defs = node.InputDefs(); const NodeArg& X = *input_defs[0]; auto X_shape = utils::GetTensorShapeFromTensorShapeProto(*X.Shape()); - C_ = X_shape[3]; // input is NHWC. op support checker made sure C dim was known + const auto rank = X_shape.NumDimensions(); + + C_ = X_shape[rank - 1]; // input is NHWC or NWC. op support checker made sure C dim was known // as the weight input is a constant initializer we can calculate all the sizes here instead of in Compute const Tensor* Weight = nullptr; @@ -416,10 +428,9 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose) ORT_ENFORCE(info.TryGetConstantInput(weight_index, &Weight), "Weight input was not constant initializer. XNNPACK EP should not have asked for the node. Node name:", node.Name()); - M_ = Weight->Shape()[0]; - // this happens before PrePack, so the weight input is still in the ONNX spec format - ORT_THROW_IF_ERROR(convbase_attrs_ref_.ComputeKernelShape(Weight->Shape(), kernel_shape_)); + const auto& weight_shape = Weight->Shape(); + ORT_THROW_IF_ERROR(convbase_attrs_ref_.ComputeKernelShape(weight_shape, kernel_shape_)); if (convbase_attrs_ref_.pads.empty()) { convbase_attrs_ref_.pads.resize(kernel_shape_.size() * 2, 0); @@ -440,16 +451,19 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose) "Invalid Node with non-constant Bias input. XNNPACK EP should not have asked for the node. 
Node name:", node.Name()); } else { - has_bias = input_defs.size() == (8 + 1) && input_defs[8]->Exists(); + has_bias = input_defs.size() == 9 && input_defs[8]->Exists(); ORT_ENFORCE(has_bias == false || info.TryGetConstantInput(8, &B_), "Invalid Node with non-constant Bias input. XNNPACK EP should not have asked for the node. Node name:", node.Name()); } - const TensorShape input_shape{X_shape[1], X_shape[2]}; + + // HW from NHWC or W from NWC + const TensorShape input_shape = rank == 4 ? TensorShape{X_shape[1], X_shape[2]} + : TensorShape{X_shape[1]}; if (is_transpose) { // Group_num group_size - M_ = Weight->Shape()[1] * convbase_attrs_ref_.group; + M_ = weight_shape[1] * convbase_attrs_ref_.group; if (conv_transpose_attrs_.output_padding.empty()) { conv_transpose_attrs_.output_padding.resize(kernel_shape_.size(), 0); } @@ -458,23 +472,31 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose) input_shape, M_, kernel_shape_, conv_transpose_attrs_.strides, conv_transpose_attrs_.dilations, conv_transpose_attrs_.output_padding, 1, &conv_transpose_attrs_.pads, &output_shape_); + output_shape_[1] = output_shape_[2]; - output_shape_[2] = output_shape_[3]; - output_shape_[3] = M_; + if (rank == 4) { + output_shape_[2] = output_shape_[3]; + } + + output_shape_[rank - 1] = M_; + } else { + M_ = weight_shape[0]; + ConvAttributes::ConvPadVector pads(conv_attrs_.pads); - output_shape_.push_back(1); + output_shape_.push_back(1); // N ORT_THROW_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(input_shape, kernel_shape_, conv_attrs_.strides, conv_attrs_.dilations, pads, output_shape_)); output_shape_.push_back(M_); } + // have to delay creating the xnnpack kernel until after the weights are pre-packed. } Status ConvBase::CreateKernel() { - auto ret = CreateXnnpackKernel(&convbase_attrs_ref_, C_, M_, kernel_shape_, clip_min_max_, packed_w_, + auto ret = CreateXnnpackKernel(convbase_attrs_ref_, C_, M_, kernel_shape_, clip_min_max_, packed_w_, B_, op0_, GetCodeCache(), GetWeightsCache(), quant_param_, conv_type_, is_transpose_); diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc index 8698c0739509d..c136385f12476 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc @@ -20,41 +20,75 @@ Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr // only layout of weight input is adjusted via PrePack if ((conv_type_ == OpComputeType::op_compute_type_fp32 && input_idx == 1) || (conv_type_ != OpComputeType::op_compute_type_fp32 && input_idx == 3)) { // InputTensors::IN_W + auto orig_shape = tensor.Shape(); + const auto rank = orig_shape.NumDimensions(); + if (conv_transpose_attrs_.group > 1) { // Xnnpack [G, Oc, H, W Ic/G] // (ref: https://github.com/google/XNNPACK/blob/ecd8311c8fd3d9ab47edbc3df5f2b5de7dabe75f/test/deconvolution-operator-tester.h#L678) - TensorShape orig_shape = - {conv_transpose_attrs_.group, - tensor.Shape()[0] / conv_transpose_attrs_.group, - tensor.Shape()[1], - tensor.Shape()[2], - tensor.Shape()[3]}; - - InlinedVector perm{0, 2, 3, 4, 1}; - TensorShapeVector new_dims{ - orig_shape[0], - orig_shape[2], - orig_shape[3], - orig_shape[4], - orig_shape[1]}; - - packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); - // g I/g O H W --> g O H W I/g - SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 1, /*to*/ 4, &orig_shape); + if (rank == 4) { + // split C (dim 0) into {group, C/group} 
+ TensorShape w_reshaped = {conv_transpose_attrs_.group, + orig_shape[0] / conv_transpose_attrs_.group, + orig_shape[1], + orig_shape[2], + orig_shape[3]}; + + InlinedVector perm{0, 2, 3, 4, 1}; + TensorShapeVector new_dims{w_reshaped[0], + w_reshaped[2], + w_reshaped[3], + w_reshaped[4], + w_reshaped[1]}; + + packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + // g I/g O H W --> g O H W I/g + SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 1, /*to*/ 4, &w_reshaped); + } else { + assert(rank == 3); + + TensorShape w_reshaped = {conv_transpose_attrs_.group, + orig_shape[0] / conv_transpose_attrs_.group, + orig_shape[1], + orig_shape[2]}; + + InlinedVector perm{0, 2, 3, 1}; + TensorShapeVector new_dims{w_reshaped[0], + w_reshaped[2], + w_reshaped[3], + w_reshaped[1]}; + + packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + // g I/g O W --> g O W I/g + SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 1, /*to*/ 3, &w_reshaped); + } } else { - // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} - - auto orig_shape = tensor.Shape(); - InlinedVector perm{1, 2, 3, 0}; - TensorShapeVector new_dims{orig_shape[1], - orig_shape[2], - orig_shape[3], - orig_shape[0]}; - - packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); - // I O H W --> O H W I - SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 0, /*to*/ 3); + if (rank == 4) { + // Transpose from {C, M/group, kH, kW} to {M/group, kH, kW, C} + InlinedVector perm{1, 2, 3, 0}; + TensorShapeVector new_dims{orig_shape[1], + orig_shape[2], + orig_shape[3], + orig_shape[0]}; + + packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + // I O H W --> O H W I + SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 0, /*to*/ 3); + } else { + // Transpose from {C, M/group, kW} to {M/group, kW, C} + assert(rank == 3); + + InlinedVector perm{1, 2, 0}; + TensorShapeVector new_dims{orig_shape[1], + orig_shape[2], + orig_shape[0]}; + + packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + // I O W --> O W I + SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 0, /*to*/ 2); + } } + is_packed = true; // we can create the kernel now @@ -68,9 +102,11 @@ Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr Status ConvTranspose::Compute(OpKernelContext* context) const { const Tensor& X = *context->Input(0); // this is in NHWC format const auto& X_shape = X.Shape(); - const int64_t N = X_shape[0]; // input is NHWC - const int64_t H = X_shape[1]; - const int64_t W = X_shape[2]; + const int64_t N = X_shape[0]; // input is NHWC or NWC + const auto rank = X_shape.NumDimensions(); + const bool is_1D = rank == 3; + const int64_t H = is_1D ? 1 : X_shape[1]; + const int64_t W = X_shape[rank - 2]; TensorShapeVector Y_dims(output_shape_); Y_dims[0] = N; @@ -83,8 +119,9 @@ Status ConvTranspose::Compute(OpKernelContext* context) const { } pthreadpool_t threadpool = GetThreadPool(); - auto output_pad_0 = gsl::narrow_cast(conv_transpose_attrs_.output_padding[0]); - auto output_pad_1 = gsl::narrow_cast(conv_transpose_attrs_.output_padding[1]); + auto output_pad_0 = is_1D ? 0 : gsl::narrow_cast(conv_transpose_attrs_.output_padding[0]); + auto output_pad_1 = gsl::narrow_cast(conv_transpose_attrs_.output_padding[is_1D ? 
0 : 1]); + xnn_status status = xnn_status_invalid_state; auto reshape_fn = xnn_reshape_deconvolution2d_nhwc_f32; diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index ec93dc249eeb2..81191e9b48c3c 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -98,17 +98,23 @@ TEST(ConvTransposeTest, ConvTranspose_1D) { 1, // group "NOTSET" // auto_pad }; - vector X = {0.0f, 1.0f, 2.0f}; - vector X_shape = {1, 1, 3}; - vector W = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; - vector W_shape = {1, 2, 3}; + + vector X_shape = {1, 2, 3}; + vector X = {0.1f, 1.0f, 2.0f, + 3.0f, 4.0f, 5.0f}; + vector W_shape = {2, 2, 3}; + vector W = {1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f, + 6.0f, 5.0f, 4.0f, + 3.0f, 2.0f, 1.0f}; vector Y_shape = {1, 2, 5}; - auto expected_vals = {0.0f, 1.0f, 3.0f, 3.0f, 2.0f, 0.0f, 1.0f, 3.0f, 3.0f, 2.0f}; + auto expected_vals = {18.1f, 40.2f, 66.3f, 48.f, 26.f, + 9.4f, 22.5f, 39.6f, 30.f, 17.f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } -TEST(ConvTransposeTest, ConvTranspose_2D) { +TEST(ConvTransposeTest, ConvTranspose_2D_outputpadding_strides2) { ConvTransposeOpAttributes attrs = { vector{3, 3}, // kernel_shape vector{1, 1}, // output_padding @@ -119,14 +125,17 @@ TEST(ConvTransposeTest, ConvTranspose_2D) { 1, // group "NOTSET" // auto_pad }; + + vector X_shape = {1, 1, 3, 3}; vector X = {0.16857791f, -0.15161794f, 0.08540368f, 0.1820628f, -0.21746576f, 0.08245695f, 0.1431433f, -0.43156421f, 0.30591947f}; - vector X_shape = {1, 1, 3, 3}; + + vector W_shape = {1, 1, 3, 3}; vector W = {-0.06230065f, 0.37932432f, -0.25388849f, 0.33878803f, 0.43709868f, -0.22477469f, 0.04118127f, -0.44696793f, 0.06373066f}; - vector W_shape = {1, 1, 3, 3}; + vector Y_shape = {1, 1, 6, 6}; auto expected_vals = {0.07368518f, -0.08925839f, -0.06627201f, 0.06301362f, 0.03732984f, -0.01919658f, -0.00628807f, -0.02817563f, -0.01472169f, 0.04392925f, -0.00689478f, -0.01549204f, @@ -137,6 +146,45 @@ TEST(ConvTransposeTest, ConvTranspose_2D) { TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } +// 2D input with C > 1 +TEST(ConvTransposeTest, ConvTranspose_2D_C2) { + ConvTransposeOpAttributes attrs = { + vector{2, 2}, // kernel_shape + {}, // output_padding + {}, // output_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + vector{1, 1}, // dilations + 1, // group + "NOTSET" // auto_pad + }; + + vector X_shape = {1, 2, 3, 3}; + vector X = {0.43f, 0.42871707f, 0.29552766f, + 0.17258859f, 0.68087016f, 0.7090254f, + 0.60937387f, 0.58646585f, 0.84525721f, + + 0.47011843f, 0.95854213f, 0.3972888f, + 0.0585452f, 0.1206734f, 0.76727852f, + 0.46040912f, 0.83495316f, 0.02409773f}; + + vector W_shape = {2, 1, 2, 2}; + vector W = {0.25616416f, 0.10246604f, + 0.08771133f, 0.30770606f, + + 0.84369617f, 0.3010619f, + 0.44524362f, 0.6056068f}; + + vector Y_shape = {1, 1, 4, 4}; + auto expected_vals = { + 0.50678771f, 1.10413539f, 0.74340409f, 0.14989006f, + 0.34063845f, 1.19294512f, 1.85030293f, 0.63518577f, + 0.58575004f, 1.25774109f, 1.23472511f, 0.77670550f, + 0.25844323f, 0.88953220f, 0.77098041f, 0.27468451f}; + + TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); +} + TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { ConvTransposeOpAttributes attrs = { vector{3, 3}, // kernel_shape @@ -261,7 +309,48 @@ TEST(ConvTransposeTest, 
ConvTranspose_2D_OutputShape_1) { {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}); } -TEST(ConvTransposeTest, ConvTranspose_2D_OutputShape_1_group_2_for_tranpose_path) { +TEST(ConvTransposeTest, ConvTranspose_1D_OutputShape_1_group_2_for_transpose_path) { + ConvTransposeOpAttributes attrs = { + vector{3}, // kernel_shape + {}, // output_padding + vector{1, 6, 4}, // output_shape + vector{0, 0}, // pads + vector{1}, // strides + vector{1}, // dilations + 2, // group + "NOTSET" // auto_pad + }; + int image_size = 4; + int input_channels = 3 * 2; + int output_channels = 3; + std::vector X; + for (int i = 0; i < input_channels * image_size; i++) { + X.push_back(1.0f); + } + + std::vector W; + int kernel_size = output_channels * input_channels * 3; + for (int i = 0; i < kernel_size; i++) { + W.push_back(1.0f); + } + + vector X_shape = {1, 6, 4}; + vector W_shape = {6, 3, 3}; + vector Y_shape = {1, 6, 4}; + + auto expected_vals = {6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f}; + + TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, + OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}); +} + +TEST(ConvTransposeTest, ConvTranspose_2D_OutputShape_1_group_2_for_transpose_path) { ConvTransposeOpAttributes attrs = { vector{3, 3}, // kernel_shape {}, // output_padding @@ -287,104 +376,31 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShape_1_group_2_for_tranpose_path vector W_shape = {6, 3, 3, 3}; vector Y_shape = {1, 6, 4, 4}; - auto expected_vals = { - 12.0f, - 18.0f, - 18.0f, - 12.0f, - 18.0f, - 27.0f, - 27.0f, - 18.0f, - 18.0f, - 27.0f, - 27.0f, - 18.0f, - 12.0f, - 18.0f, - 18.0f, - 12.0f, - 12.0f, - 18.0f, - 18.0f, - 12.0f, - 18.0f, - 27.0f, - 27.0f, - 18.0f, - 18.0f, - 27.0f, - 27.0f, - 18.0f, - 12.0f, - 18.0f, - 18.0f, - 12.0f, - 12.0f, - 18.0f, - 18.0f, - 12.0f, - 18.0f, - 27.0f, - 27.0f, - 18.0f, - 18.0f, - 27.0f, - 27.0f, - 18.0f, - 12.0f, - 18.0f, - 18.0f, - 12.0f, // duplicate below - 12.0f, - 18.0f, - 18.0f, - 12.0f, - 18.0f, - 27.0f, - 27.0f, - 18.0f, - 18.0f, - 27.0f, - 27.0f, - 18.0f, - 12.0f, - 18.0f, - 18.0f, - 12.0f, - 12.0f, - 18.0f, - 18.0f, - 12.0f, - 18.0f, - 27.0f, - 27.0f, - 18.0f, - 18.0f, - 27.0f, - 27.0f, - 18.0f, - 12.0f, - 18.0f, - 18.0f, - 12.0f, - 12.0f, - 18.0f, - 18.0f, - 12.0f, - 18.0f, - 27.0f, - 27.0f, - 18.0f, - 18.0f, - 27.0f, - 27.0f, - 18.0f, - 12.0f, - 18.0f, - 18.0f, - 12.0f, - }; + auto expected_vals = {12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, // duplicate below + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f}; + TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, 
kQnnExecutionProvider}); @@ -441,11 +457,6 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShapeWithBatchSize) { } TEST(ConvTransposeTest, ConvTranspose_InvalidKernelShape) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: provider_test_utils.cc(866): error: Value of: expect_result == ExpectResult::kExpectSuccess"; - } - ConvTransposeOpAttributes attrs = { vector{1, 1, 1, 5}, // invalid kernel_shape, should be [1, 5] {}, // output_padding @@ -468,7 +479,11 @@ TEST(ConvTransposeTest, ConvTranspose_InvalidKernelShape) { 11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectFailure, - "kernel_shape num_dims is not compatible with W num_dims. kernel_shape: {1,1,1,5} W: {1,1,1,5}"); + // error message will end in "W: {1,1,1,5}" or "W: {1,1,5,1} depending on whether NCHW or NHWC, + // so drop the part that differs from the expected string + "kernel_shape num_dims is not compatible with W num_dims. kernel_shape: {1,1,1,5} W: {1,1,", + {kTensorrtExecutionProvider, kQnnExecutionProvider, + kDmlExecutionProvider}); // TODO: Unskip when fixed #41968513 } TEST(ConvTransposeTest, ConvTranspose_onnx) { @@ -1085,11 +1100,6 @@ TEST(ConvTransposeTest, ConvTranspose_1D_AutoPad_SameLower) { } TEST(ConvTransposeTest, ConvTranspose_AutoPad_with_non_default_strides) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect."; - } - ConvTransposeOpAttributes attrs = { vector{3, 3}, // kernel_shape {}, // output_padding @@ -1132,7 +1142,9 @@ TEST(ConvTransposeTest, ConvTranspose_AutoPad_with_non_default_strides) { TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", - {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}); // Accuracy Mismatch on OpenVINO-EP + {kTensorrtExecutionProvider, kQnnExecutionProvider, + kOpenVINOExecutionProvider, // Accuracy Mismatch on OpenVINO-EP + kDmlExecutionProvider}); // TODO: Unskip when fixed #41968513 } #ifndef ENABLE_TRAINING diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc index 786b2cb4cedc4..9837dca398ff8 100644 --- a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc @@ -20,12 +20,18 @@ struct ConvTransposeOp { std::vector dilations = {1, 1}; std::unique_ptr get_test() { - RandomValueGenerator random{}; + RandomValueGenerator random{T(123.f)}; // use seed so output is deterministic to aid in debugging failures auto test = std::make_unique("ConvTranspose", 14); std::vector input_data = random.Uniform(input_dims, 0.0f, 1.0f); - std::vector weight_dims{input_dims[1], channels / group, kernel_shape[0], kernel_shape[1]}; + // 1D or 2D input is supported + const bool is_1D = input_dims.size() == 3; + std::vector weight_dims{input_dims[1], channels / group, kernel_shape[0]}; + if (!is_1D) { + weight_dims.push_back(kernel_shape[1]); + } + std::vector weight_data = random.Uniform(weight_dims, -0.4f, 0.4f); test->AddInput("X", input_dims, input_data); @@ -46,12 +52,17 @@ struct ConvTransposeOp { 
output_padding = {0, 0, 0, 0}; } - std::vector output_dims = { - input_dims[0], channels, - (kernel_shape[1] - 1) * dilations[1] + (input_dims[2] - 1) * strides[1] - (padding[1] + padding[0]) + 1 + - output_padding[2], - (kernel_shape[0] - 1) * dilations[0] + (input_dims[3] - 1) * strides[0] - (padding[3] + padding[2]) + 1 + - output_padding[3]}; + // the test input is NCHW so calculate output based on that. conversion to/from NHWC is internal to execution. + std::vector output_dims = {input_dims[0], channels}; + + for (size_t i = 0, end = is_1D ? 1 : 2; i < end; ++i) { + // formula from https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose + const size_t start_pad = i * 2; + output_dims.push_back( + strides[i] * (input_dims[i + 2] - 1) + output_padding[i] + + ((kernel_shape[i] - 1) * dilations[i] + 1) - padding[start_pad] - padding[start_pad + 1]); + } + std::vector output_data = FillZeros(output_dims); test->AddOutput("Y", output_dims, output_data); @@ -83,6 +94,27 @@ TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcBias) { } } +TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcBias1D) { + auto op = ConvTransposeOp{}; + op.input_dims = {1, 8, 80}; + op.kernel_shape = {5}; + op.channels = 16; + op.bias = true; + op.padding = {0, 0}; + op.strides = {1}; + op.dilations = {1}; + op.output_padding = {}; + + // test with adding fake W and H dimensions + if (HasCudaEnvironment(800)) { + MAKE_PROVIDERS_EPS_EXT(1e-2, true) // add fake H dimension of 1 to convert to 2D + MAKE_PROVIDERS_EPS_EXT(1e-2, false) // add fake W dimension of 1 to convert to 2D + } else { + MAKE_PROVIDERS_EPS_TYPE_EXT(TypeParam, true) + MAKE_PROVIDERS_EPS_TYPE_EXT(TypeParam, false) + } +} + TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcPad) { auto op = ConvTransposeOp{}; op.input_dims = {1, 16, 8, 8}; diff --git a/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h b/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h index 82b6a286409cd..1aea58c8d7a10 100644 --- a/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h +++ b/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h @@ -6,37 +6,52 @@ #include #include -#include "core/providers/cuda/cuda_provider_options.h" +#include "gtest/gtest.h" + #include "core/providers/common.h" +#include "core/providers/cuda/cuda_provider_options.h" -#include "test/providers/compare_provider_test_utils.h" #include "test/common/cuda_op_test_utils.h" +#include "test/providers/compare_provider_test_utils.h" +#include "test/util/include/default_providers.h" -#include "gtest/gtest.h" +// extended cuda provider args. compare NHWC implementation vs CUDA NCHW and CPU EP. 
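+// e.g. MAKE_PROVIDERS_EPS_EXT(1e-2, /*pad_to_nc1d*/ true) builds the test from `op`, runs it with the NHWC CUDA EP,
+// and compares the output against both the NCHW CUDA EP and the CPU EP using a tolerance of 1e-2.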
+#define MAKE_PROVIDERS_EPS_EXT(eps, pad_to_nc1d) \ + { \ + std::vector> execution_providers; \ + OrtCUDAProviderOptionsV2 nhwc{}; \ + nhwc.prefer_nhwc = true; \ + nhwc.cudnn_conv1d_pad_to_nc1d = pad_to_nc1d; \ + execution_providers.push_back(CudaExecutionProviderWithOptions(&nhwc)); \ + \ + double error_tolerance = eps; \ + OrtCUDAProviderOptionsV2 nchw{}; \ + nchw.prefer_nhwc = false; \ + nchw.cudnn_conv1d_pad_to_nc1d = pad_to_nc1d; \ + auto nchw_ep = CudaExecutionProviderWithOptions(&nchw); \ + auto test = op.get_test(); \ + test->CompareEPs(std::move(nchw_ep), execution_providers, error_tolerance); \ + auto cpu_ep = DefaultCpuExecutionProvider(); \ + test->CompareEPs(std::move(cpu_ep), execution_providers, error_tolerance); \ + } -#define MAKE_PROVIDERS_EPS(eps) \ - std::vector> execution_providers; \ - OrtCUDAProviderOptionsV2 nhwc{}; \ - nhwc.prefer_nhwc = true; \ - execution_providers.push_back(CudaExecutionProviderWithOptions(&nhwc)); \ - \ - double error_tolerance = eps; \ - OrtCUDAProviderOptionsV2 nchw{}; \ - nchw.prefer_nhwc = false; \ - auto source_ep = CudaExecutionProviderWithOptions(&nchw); \ - auto test = op.get_test(); \ - test->CompareEPs(std::move(source_ep), execution_providers, error_tolerance); +#define MAKE_PROVIDERS_EPS(eps) \ + MAKE_PROVIDERS_EPS_EXT(eps, false) #define MAKE_PROVIDERS() MAKE_PROVIDERS_EPS(1e-3) -#define MAKE_PROVIDERS_EPS_TYPE(T) \ - if (std::is_same::value) { \ - MAKE_PROVIDERS_EPS(2e-2) \ - } else if (std::is_same::value) { \ - MAKE_PROVIDERS_EPS(2e-4) \ - } else { \ - MAKE_PROVIDERS_EPS(2e-3) \ +#define MAKE_PROVIDERS_EPS_TYPE_EXT(T, pad_to_nc1d) \ + if (std::is_same::value) { \ + MAKE_PROVIDERS_EPS_EXT(2e-2, pad_to_nc1d) \ + } else if (std::is_same::value) { \ + MAKE_PROVIDERS_EPS_EXT(2e-4, pad_to_nc1d) \ + } else { \ + MAKE_PROVIDERS_EPS_EXT(2e-3, pad_to_nc1d) \ } + +#define MAKE_PROVIDERS_EPS_TYPE(T) \ + MAKE_PROVIDERS_EPS_TYPE_EXT(T, false) + namespace onnxruntime { namespace test {
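
For readers unfamiliar with the trick this patch relies on: a 1D convolution over a channels-last input {N, W, C} can be run by a 2D kernel by inserting a size-1 spatial dimension, and the resulting output extent follows the standard ONNX ConvTranspose formula used in the test above. The following is a minimal standalone sketch of both pieces; it is illustrative only (not ONNX Runtime code, and the function names are invented):

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Insert a fake spatial dim of size 1 into a channels-last 1D shape {N, W, C}.
// fake_height == true  -> {N, 1, W, C} (analogous to cudnn_conv1d_pad_to_nc1d == true above)
// fake_height == false -> {N, W, 1, C}
std::vector<int64_t> FakeDimTo2D(std::vector<int64_t> nwc_dims, bool fake_height) {
  assert(nwc_dims.size() == 3);
  nwc_dims.insert(nwc_dims.begin() + (fake_height ? 1 : 2), 1);
  return nwc_dims;
}

// Output extent of one ConvTranspose spatial dim, per the ONNX formula referenced in the test above:
// out = stride * (in - 1) + output_padding + ((kernel - 1) * dilation + 1) - pad_begin - pad_end
int64_t ConvTransposeOutSize(int64_t in, int64_t kernel, int64_t stride, int64_t dilation,
                             int64_t pad_begin, int64_t pad_end, int64_t output_padding) {
  return stride * (in - 1) + output_padding + ((kernel - 1) * dilation + 1) - pad_begin - pad_end;
}

int main() {
  // 1D input comparable to the ConvTransposeNhwcBias1D test: length 80, kernel 5,
  // stride/dilation 1, no padding.
  const std::vector<int64_t> x_nwc{1, 80, 8};
  const auto x_n1wc = FakeDimTo2D(x_nwc, /*fake_height*/ true);   // {1, 1, 80, 8}
  const auto x_nw1c = FakeDimTo2D(x_nwc, /*fake_height*/ false);  // {1, 80, 1, 8}

  // The fake dim contributes an output extent of 1; the real dim grows from 80 to 84.
  std::cout << ConvTransposeOutSize(1, 1, 1, 1, 0, 0, 0) << "\n";   // 1
  std::cout << ConvTransposeOutSize(80, 5, 1, 1, 0, 0, 0) << "\n";  // 84
  (void)x_n1wc;
  (void)x_nw1c;
  return 0;
}
```

Because both the fake-H and fake-W variants produce the same flat data layout for a size-1 dimension, either choice is numerically equivalent; the patch exposes both on CUDA only to match the existing `cudnn_conv1d_pad_to_nc1d` behavior of Conv.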