XNNPACK: Support 1D input for Conv and ConvTranspose (microsoft#20349)
### Description
Support 1D input to the XNNPACK Conv and ConvTranspose kernels by inserting a fake height of 1 to convert the input to 2D.

### Motivation and Context
Enables speech models with 1D input to use XNNPACK. There is no quantized ConvTranspose implementation in the CPU EP, so this fills that gap.
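
As an illustration of the approach (a standalone sketch with hypothetical names, not code from this commit): the trick is to insert a fake spatial dimension of size 1 so the existing 2D kernels can be reused, then drop that dimension from the output.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring what the kernels below do inline: insert a
// fake spatial dim of size 1 so a 1D conv input can take the 2D code path.
// For an NCW input, inserting at index 2 yields NC1W (fake height).
std::vector<int64_t> FakeTo2D(std::vector<int64_t> dims, size_t insert_at) {
  dims.insert(dims.begin() + insert_at, 1);
  return dims;
}

// e.g. FakeTo2D({1, 64, 128}, 2) -> {1, 64, 1, 128}; after running the 2D
// kernel, the fake dim is erased from the output shape.
```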
skottmckay authored Apr 23, 2024
1 parent 3270a00 commit c47a6ce
Showing 11 changed files with 504 additions and 264 deletions.
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cpu/nn/conv_attributes.h
@@ -73,7 +73,8 @@ struct ConvAttributes {
 
   ~ConvAttributes() = default;
 
-  Status ComputeKernelShape(const TensorShape& weight_shape, TensorShapeVector& kernel_shape, bool weight_channels_last = false) const {
+  Status ComputeKernelShape(const TensorShape& weight_shape, TensorShapeVector& kernel_shape,
+                            bool weight_channels_last = false) const {
     if (kernel_shape_specified) {
       kernel_shape = kernel_shape_;
       if (kernel_shape.size() + 2 != weight_shape.NumDimensions()) {
20 changes: 16 additions & 4 deletions onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h
@@ -44,9 +44,15 @@ struct ConvTransposeAttributes : public ConvAttributes {
     TensorShapeVector strides;
   };
 
+  // Viewing dim 1 of the X input as 'input channels' (C) and dim 1 of the Y output as 'output channels' (M),
+  // if is_nhwc is true, the input channels (dim 0) or output channels (dim 1) of the W input (the filter with
+  // shape {C, M/group, ...}) could be transposed to be last. transposed_input_channels indicates whether dim 0 or
+  // dim 1 was moved.
+  //
+  // e.g. XNNPACK moves the input channels dim to the end. CUDA moves the output channels dim to the end.
   Status PrepareForCompute(OpKernelContext* context, bool has_bias, Prepare& p,
                            bool dynamic_padding = false, const TensorShape* filter_shape = nullptr,
-                           bool is_nhwc = false) const {
+                           bool is_nhwc = false, bool transposed_input_channels = true) const {
     const Tensor* X = context->Input<Tensor>(0);
     const Tensor* F = (filter_shape != nullptr) ? nullptr : context->Input<Tensor>(1);
     const TensorShape& F_Shape = (filter_shape != nullptr) ? *filter_shape : F->Shape();
@@ -57,7 +63,12 @@ struct ConvTransposeAttributes : public ConvAttributes {
     TensorShape input_shape = X->Shape().Slice(is_nhwc ? 1 : 2, is_nhwc ? rank - 1 : rank);
     const int64_t num_input_channels = is_nhwc ? X->Shape()[rank - 1] : X->Shape()[1];
     const int64_t N = X->Shape()[0];
-    const int64_t num_output_channels_multiplier = is_nhwc ? F_Shape[3] : F_Shape[1];
+
+    // W is {C, M/group, ...}. adjust for NHWC and transposed_input_channels
+    // If we transposed the input channels, {C, M/group, ...} becomes {M/group, ..., C}
+    // If we transposed the output channels, {C, M/group, ...} becomes {C, ..., M/group}
+    const auto M_div_group_dim = is_nhwc ? (transposed_input_channels ? 0 : F_Shape.NumDimensions() - 1) : 1;
+    const int64_t num_output_channels_multiplier = F_Shape[M_div_group_dim];
     const int64_t num_output_channels = num_output_channels_multiplier * group;
 
     // input validations
@@ -72,9 +83,10 @@ struct ConvTransposeAttributes : public ConvAttributes {
                              " W: ", F_Shape.ToString().c_str());
     }
 
-    if (F_Shape[0] != num_input_channels) {
+    const auto F_channels_dim = is_nhwc && transposed_input_channels ? F_Shape.NumDimensions() - 1 : 0;
+    if (F_Shape[F_channels_dim] != num_input_channels) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "filter number not equal to input channel number.",
-                             " filter_number: ", F_Shape[0],
+                             " filter_number: ", F_Shape[F_channels_dim],
                              " num_input_channels: ", num_input_channels);
     }
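
To make the new transposed_input_channels flag concrete, here is a standalone sketch (a hypothetical helper, not part of the commit) of where the filter's C and M/group dimensions end up under each layout, matching the logic in PrepareForCompute above:

```cpp
#include <cstddef>

// Canonical ConvTranspose filter layout is {C, M/group, ...spatial}.
// Hypothetical helper returning the index of the input-channels dim (C)
// and the M/group dim for each supported layout.
struct FilterDims {
  size_t c_dim;
  size_t m_div_group_dim;
};

FilterDims LocateFilterDims(size_t rank, bool is_nhwc, bool transposed_input_channels) {
  if (!is_nhwc) {
    return {0, 1};  // {C, M/group, ...}
  }
  if (transposed_input_channels) {
    return {rank - 1, 0};  // input channels moved last, e.g. XNNPACK: {M/group, ..., C}
  }
  return {0, rank - 1};  // output channels moved last, e.g. CUDA: {C, ..., M/group}
}
```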
15 changes: 13 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -2394,7 +2394,9 @@ static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node,
   return false;
 }
 
-static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const logging::Logger& logger) {
+static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const logging::Logger& logger,
+                                           [[maybe_unused]] const GraphViewer& graph_viewer,
+                                           [[maybe_unused]] const bool prefer_nhwc) {
   const auto& node_attributes = node.GetAttributes();
   // Check attributes
   for (auto& attr : node_attributes) {
@@ -2442,6 +2444,15 @@ static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const
     }
   }
 
+#ifdef ENABLE_CUDA_NHWC_OPS
+  if (prefer_nhwc) {
+    // NHWC implementation doesn't handle transpose of W if it's not an initializer
+    if (!graph_viewer.IsConstantInitializer(node.InputDefs()[1]->Name(), true)) {
+      return true;
+    }
+  }
+#endif
+
   return false;
 }

@@ -2510,7 +2521,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
       not_supported = RNNNeedFallbackToCPU(node, activations_supported, node.OpType());
       force_inside = !not_supported;
     } else if ("ConvTranspose" == node.OpType()) {
-      not_supported = ConvTransposeNeedFallbackToCPU(node, logger);
+      not_supported = ConvTransposeNeedFallbackToCPU(node, logger, graph, IsNHWCPreferred());
       force_inside = !not_supported;
     } else if ("Cast" == node.OpType()) {
       not_supported = CastNeedFallbackToCPU(node);
138 changes: 107 additions & 31 deletions onnxruntime/core/providers/cuda/nn/conv_transpose.cc
@@ -45,28 +45,47 @@ Status ConvTranspose<T, NHWC>::ComputeInternal(OpKernelContext* context) const {
 
 template <typename T, bool NHWC>
 Status ConvTranspose<T, NHWC>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed,
-                                       [[maybe_unused]] PrePackedWeights* prepacked_weights) {
+                                       PrePackedWeights* prepacked_weights) {
   is_packed = false;
   // only layout of weight input is adjusted via PrePack
-  if (NHWC) {  // InputTensors::IN_W
+  if constexpr (NHWC) {  // InputTensors::IN_W
     if (input_idx == 1) {
-      // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group}
       auto orig_shape = tensor.Shape();
+      const auto rank = orig_shape.NumDimensions();
+
+      InlinedVector<size_t> perm;
+      TensorShapeVector new_dims;
+
+      // Input is { N, C, ...}. Output is { N, M, ...}. 'input channels' is C. 'output channels' is M.
+      // Transpose the output channels related dimension (M/group) to be last. Leave the input channels as-is.
+      if (rank == 3) {
+        // Transpose from {C, M/group, k1} to {C, k1, M/group}
+        perm = {0, 2, 1};
+        new_dims = TensorShapeVector{orig_shape[0], orig_shape[2], orig_shape[1]};
+      } else if (rank == 4) {
+        // Transpose from {C, M/group, kH, kW} to {C, kH, kW, M/group}
+        perm = {0, 2, 3, 1};
+        new_dims = TensorShapeVector{orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1]};
+      } else if (rank == 5) {
+        // Transpose from {C, M/group, k1, k2, k3} to {C, k1, k2, k3, M/group}
+        perm = {0, 2, 3, 4, 1};
+        new_dims = TensorShapeVector{orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[4], orig_shape[1]};
+      }
 
-      InlinedVector<size_t> perm{0, 2, 3, 1};
-      gsl::span<size_t> permutation(perm.data(), 4);
-      TensorShapeVector new_dims{orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1]};
+      gsl::span<size_t> permutation(perm.data(), rank);
       W_ = Tensor::Create(tensor.DataType(), TensorShape(new_dims), std::move(alloc));
 
-      auto status = cuda::Transpose::DoTranspose(GetDeviceProp(), DefaultCudaStream(), DefaultCublasHandle(),
-                                                 permutation, tensor, *W_);
+      ORT_RETURN_IF_ERROR(cuda::Transpose::DoTranspose(GetDeviceProp(), DefaultCudaStream(), DefaultCublasHandle(),
+                                                       permutation, tensor, *W_));
 
-      if (!status.IsOK()) {
-        return status;
-      }
       CUDA_CALL_THROW(cudaStreamSynchronize(DefaultCudaStream()));
       is_packed = true;
     }
+  } else {
+    ORT_UNUSED_PARAMETER(tensor);
+    ORT_UNUSED_PARAMETER(input_idx);
+    ORT_UNUSED_PARAMETER(alloc);
+    ORT_UNUSED_PARAMETER(prepacked_weights);
   }
 
   return Status::OK();
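
As an aside (a standalone sketch, not the commit's code): the hard-coded permutations above all say "move dim 1 to the end", which generalizes to any rank:

```cpp
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical generalization of perm = {0,2,1} / {0,2,3,1} / {0,2,3,4,1}:
// build a permutation that moves dim 1 (M/group) to the last position.
std::vector<size_t> MoveDim1ToEnd(size_t rank) {
  std::vector<size_t> perm(rank);
  std::iota(perm.begin(), perm.end(), 0);  // 0, 1, ..., rank-1
  perm.erase(perm.begin() + 1);            // remove dim 1
  perm.push_back(1);                       // append it last
  return perm;
}

// e.g. MoveDim1ToEnd(4) -> {0, 2, 3, 1}
```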
@@ -87,12 +106,10 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3-, 4- or 5-dimensional.",
                            " X: ", X->Shape().ToString().c_str());
   }
-  const Tensor* W;
-  if (!W_) {
-    W = context->Input<Tensor>(1);
-  } else {
-    W = W_.get();
-  }
+
+  // use pre-packed W if available
+  const Tensor* W = W_ ? W_.get() : context->Input<Tensor>(1);
+
   const TensorShape& w_shape = W->Shape();
   TensorShapeVector w_dims = w_shape.AsShapeVector();
   auto w_data = reinterpret_cast<const CudaT*>(W->Data<T>());
@@ -101,9 +118,38 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
   bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3;
 
   CudaT* y_data = nullptr;
+
+  const auto* cuda_ep = static_cast<const CUDAExecutionProvider*>(Info().GetExecutionProvider());
+
   // convert 1D to 2D
   if (x_dimensions == 3) {
-    x_dims.insert(x_dims.begin() + 2, 1);
-    w_dims.insert(w_dims.begin() + 2, 1);
+    // we can either add a fake H or W dimension with a value of 1. to be consistent with the Conv behavior we use
+    // GetCudnnConv1dPadToNc1d to determine which is added.
+    // see Conv<T, NHWC>::UpdateState in /onnxruntime/core/providers/cuda/nn/conv.cc for more details.
+    if (cuda_ep->GetCudnnConv1dPadToNc1d()) {
+      // add fake H dimension
+      const auto insert_at = NHWC ? 1 : 2;
+
+      // NCHW: N, C, d1 -> N, C, 1, d1
+      // NHWC: N, d1, C -> N, 1, d1, C
+      x_dims.insert(x_dims.begin() + insert_at, 1);
+
+      // 'M' is channels dim in CUDA implementation
+      // NCHW: C, M/g, k1 -> C, M/g, 1, k1
+      // NHWC: C, k1, M/g -> C, 1, k1, M/g
+      w_dims.insert(w_dims.begin() + insert_at, 1);
+    } else {
+      // add fake W dimension
+      const auto insert_at = NHWC ? 2 : 3;
+
+      // NCHW: N, C, d1 -> N, C, d1, 1
+      // NHWC: N, d1, C -> N, d1, 1, C
+      x_dims.insert(x_dims.begin() + insert_at, 1);
+
+      // NCHW: C, M/g, k1 -> C, M/g, k1, 1
+      // NHWC: C, k1, M/g -> C, k1, 1, M/g
+      w_dims.insert(w_dims.begin() + insert_at, 1);
+    }
   }
 
   {
@@ -113,30 +159,50 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
     bool input_dims_changed = (s_.last_x_dims.AsShapeVector() != x_dims);
     bool w_dims_changed = (s_.last_w_dims.AsShapeVector() != w_dims);
     if (input_dims_changed || w_dims_changed) {
-      if (input_dims_changed) s_.last_x_dims = gsl::make_span(x_dims);
+      if (input_dims_changed) {
+        s_.last_x_dims = gsl::make_span(x_dims);
+      }
 
       if (w_dims_changed) {
        s_.last_w_dims = gsl::make_span(w_dims);
        s_.cached_benchmark_results.clear();
      }
 
      ConvTransposeAttributes::Prepare p;
+      // PrePack moves the M/group dimension of W to the end, with 'M' being interpreted as 'output channels'
+      const bool transposed_input_channels = false;
      ORT_RETURN_IF_ERROR(
-          conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding, &w_shape, NHWC));
+          conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding, &w_shape, NHWC, transposed_input_channels));
 
      auto y_dims = p.Y->Shape().AsShapeVector();
      if (x_dimensions == 3) {
-        y_dims.insert(y_dims.begin() + 2, 1);
-        p.kernel_shape.insert(p.kernel_shape.begin(), 1);
-        p.pads.insert(p.pads.begin(), 0);
-        p.pads.insert(p.pads.begin() + 2, 0);
-        p.strides.insert(p.strides.begin(), 1);
-        p.dilations.insert(p.dilations.begin(), 1);
+        if (cuda_ep->GetCudnnConv1dPadToNc1d()) {
+          // add fake H dimension of 1
+          // NCHW: N, M, d1 -> N, M, 1, d1 or
+          // NHWC: N, d1, M -> N, 1, d1, M
+          y_dims.insert(y_dims.begin() + (NHWC ? 1 : 2), 1);
+          p.kernel_shape.insert(p.kernel_shape.begin(), 1);
+          p.pads.insert(p.pads.begin(), 0);
+          p.pads.insert(p.pads.begin() + 2, 0);
+          p.strides.insert(p.strides.begin(), 1);
+          p.dilations.insert(p.dilations.begin(), 1);
+        } else {
+          // add fake W dimension of 1
+          // NCHW: N, M, d1 -> N, M, d1, 1 or
+          // NHWC: N, d1, M -> N, d1, 1, M
+          y_dims.insert(y_dims.begin() + (NHWC ? 2 : 3), 1);
+          p.kernel_shape.push_back(1);
+          p.pads.insert(p.pads.begin() + 1, 0);
+          p.pads.push_back(0);
+          p.strides.push_back(1);
+          p.dilations.push_back(1);
+        }
      }
 
      s_.y_dims = gsl::make_span(y_dims);
 
      if (w_dims_changed) {
-        if (NHWC) {
+        if constexpr (NHWC) {
          ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType<CudaT>(),
                                            static_cast<int>(w_dims[0]), static_cast<int>(w_dims[3]),
                                            static_cast<int>(w_dims[1]), static_cast<int>(w_dims[2])));
@@ -152,7 +218,8 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
      if (p.Y->Shape().Size() == 0) {
        return Status::OK();
      }
-      if (NHWC) {
+
+      if constexpr (NHWC) {
        ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType<CudaT>(),
                                            static_cast<int>(x_dims[0]), static_cast<int>(x_dims[3]),
                                            static_cast<int>(x_dims[1]), static_cast<int>(x_dims[2])));
@@ -176,7 +243,9 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
        TensorShapeVector b_dims(2 + p.kernel_shape.size());
        b_dims[0] = 1;                      // N
        b_dims[NHWC ? 3 : 1] = b_shape[0];  // C
-        for (size_t i = 0; i < p.kernel_shape.size(); i++) b_dims[(NHWC ? 1 : 2) + i] = 1;
+        for (size_t i = 0; i < p.kernel_shape.size(); i++) {
+          b_dims[(NHWC ? 1 : 2) + i] = 1;
+        }
 
        ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType<CudaT>(), NHWC));
      }
@@ -215,8 +284,15 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
  if (!y_data) {
    auto y_dims = s_.y_dims.AsShapeVector();
    if (x_dimensions == 3) {
-      y_dims.erase(y_dims.begin() + 2);
+      if (cuda_ep->GetCudnnConv1dPadToNc1d()) {
+        // erase the fake H dimension
+        y_dims.erase(y_dims.begin() + (NHWC ? 1 : 2));
+      } else {
+        // erase the fake W dimension
+        y_dims.erase(y_dims.begin() + (NHWC ? 2 : 3));
+      }
    }
+
    Tensor* Y = context->Output(0, TensorShape(y_dims));
    y_data = reinterpret_cast<CudaT*>(Y->MutableData<T>());
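
A standalone sketch (hypothetical names, not part of the commit) of how the insert/erase position for the fake dimension falls out of the layout and the cudnn_conv1d_pad_to_nc1d option used above:

```cpp
#include <cstddef>

// Hypothetical helper: where the fake spatial dim of size 1 is inserted
// into (and later erased from) the dims of a 1D ConvTranspose.
//   pad_to_nc1d == true  -> fake H: NCW -> NC1W, or NWC -> N1WC for NHWC
//   pad_to_nc1d == false -> fake W: NCW -> NCW1, or NWC -> NW1C for NHWC
size_t FakeDimPos(bool nhwc, bool pad_to_nc1d) {
  if (pad_to_nc1d) {
    return nhwc ? 1 : 2;  // before the real spatial dim
  }
  return nhwc ? 2 : 3;  // after the real spatial dim
}
```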
@@ -839,13 +839,15 @@ namespace OperatorHelper
 
         if (outputShape.size() > 2)
         {
-            ML_CHECK_VALID_ARGUMENT(outputShape[outputShape.size() - 3] == gsl::narrow_cast<int>(m_outputShapes[0].GetShape()[C]), "Output channel must be equivalent to filter channel.");
+            ML_CHECK_VALID_ARGUMENT(outputShape[C] == gsl::narrow_cast<int>(m_outputShapes[0].GetShape()[C]),
+                                    "Output channel must be equivalent to filter channel.");
         }
 
         for (size_t i = 0; i < m_kernel.spatialDimensionCount; ++i)
         {
             size_t outputIndex = outputShape.size() - m_kernel.spatialDimensionCount + i;
-            ML_CHECK_VALID_ARGUMENT(outputShape[outputIndex] >= gsl::narrow_cast<int>(inputDimensions[H + i]), "Output dimension cannot be smaller than input dimension.");
+            ML_CHECK_VALID_ARGUMENT(outputShape[outputIndex] >= gsl::narrow_cast<int>(inputDimensions[H + i]),
+                                    "Output dimension cannot be smaller than input dimension.");
             m_outputShapes[0].GetShape()[H + i] = outputShape[outputIndex];
         }
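
A simplified standalone sketch (hypothetical, assuming input and output shapes share rank and layout) of the two validation rules applied above: the output channel count must match the filter's, and each spatial output dim must be at least the corresponding input dim:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical simplification of the DML checks above. 'c_dim' is the
// channel axis and 'first_spatial' the first spatial axis in both shapes.
bool OutputShapeIsValid(const std::vector<int64_t>& output_shape,
                        const std::vector<int64_t>& input_shape,
                        int64_t filter_channels, size_t c_dim, size_t first_spatial) {
  if (output_shape[c_dim] != filter_channels) {
    return false;  // output channel must be equivalent to filter channel
  }
  for (size_t i = first_spatial; i < output_shape.size(); ++i) {
    if (output_shape[i] < input_shape[i]) {
      return false;  // output dimension cannot be smaller than input dimension
    }
  }
  return true;
}
```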
40 changes: 30 additions & 10 deletions onnxruntime/core/providers/xnnpack/nn/conv.cc
@@ -3,6 +3,8 @@
 
 #include "conv.h"
 
+#include <cassert>
+
 #include "core/common/gsl.h"
 #include "core/common/inlined_containers_fwd.h"
 #include "core/framework/tensorprotoutils.h"
@@ -24,16 +26,30 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
      (conv_type_ != OpComputeType::op_compute_type_fp32 && input_idx == 3)) {  // InputTensors::IN_W
    // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group}
    auto orig_shape = tensor.Shape();
+    const auto rank = orig_shape.NumDimensions();
 
-    InlinedVector<size_t> perm{0, 2, 3, 1};
-    TensorShapeVector new_dims{orig_shape[0],
-                               orig_shape[2],
-                               orig_shape[3],
-                               orig_shape[1]};
+    if (rank == 4) {
+      InlinedVector<size_t> perm{0, 2, 3, 1};
+      TensorShapeVector new_dims{orig_shape[0],
+                                 orig_shape[2],
+                                 orig_shape[3],
+                                 orig_shape[1]};
 
-    packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc));
+      packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc));
 
-    SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 1, /*to*/ 3);
+      SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 1, /*to*/ 3);
+    } else {
+      assert(rank == 3);  // ConvBase::IsOnnxNodeSupported validates this
+
+      InlinedVector<size_t> perm{0, 2, 1};
+      TensorShapeVector new_dims{orig_shape[0],
+                                 orig_shape[2],
+                                 orig_shape[1]};
+
+      packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc));
+
+      SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 1, /*to*/ 2);
+    }
 
    is_packed = true;
@@ -47,9 +63,13 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
 Status Conv::Compute(OpKernelContext* context) const {
  const Tensor& X = *context->Input<Tensor>(0);  // this is in NHWC format
  const auto& X_shape = X.Shape();
-  const int64_t N = X_shape[0];  // input is NHWC
-  const int64_t H = X_shape[1];
-  const int64_t W = X_shape[2];
+  const auto rank = X_shape.NumDimensions();
+  const auto is_1D = rank == 3;
+  const int64_t N = X_shape[0];  // input is NHWC or NWC
+
+  // we support 1D or 2D. if 1D we convert to 2D by setting H to 1
+  const int64_t H = is_1D ? 1 : X_shape[1];
+  const int64_t W = X_shape[rank - 2];
 
  // We don't need to call ValidateInputShape as we checked validity in ConvChecker.
  // We also can't use ValidateInputShape as-is as the weight tensor was pre-packed and the layout was changed there.
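
Standalone sketch (a hypothetical helper, not part of the commit) of the shape handling above: for an NHWC-layout tensor, a rank-3 input is NWC, so H is taken as 1 and W is always the second-to-last dim:

```cpp
#include <cstdint>
#include <vector>

struct Nhwc {
  int64_t n, h, w;
};

// Hypothetical mirror of the logic in Conv::Compute: accept NHWC (rank 4)
// or NWC (rank 3) dims and normalize to N/H/W with H = 1 for the 1D case.
Nhwc GetSpatialDims(const std::vector<int64_t>& x_shape) {
  const auto rank = x_shape.size();
  const bool is_1d = rank == 3;
  return {x_shape[0],              // N
          is_1d ? 1 : x_shape[1],  // H (fake 1 for 1D input)
          x_shape[rank - 2]};      // W (second-to-last; last is C)
}
```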