diff --git a/examples/BuddyConvolution/conv2d-nhwc-fhwc-opt.mlir b/examples/BuddyConvolution/conv2d-nhwc-fhwc-opt.mlir index 76d5e4d932..7c30df096e 100644 --- a/examples/BuddyConvolution/conv2d-nhwc-fhwc-opt.mlir +++ b/examples/BuddyConvolution/conv2d-nhwc-fhwc-opt.mlir @@ -40,6 +40,7 @@ module { %h_o = memref.dim %arg2, %c1 : memref %w_o = memref.dim %arg2, %c2 : memref + %t_start = call @rtclock() : () -> f64 // Output is NHoWoF affine.for %idx_n = %c0 to %n { affine.for %idx_f = %c0 to %f { @@ -67,7 +68,14 @@ module { } } } + %t_end = call @rtclock() : () -> f64 + %time = arith.subf %t_end, %t_start : f64 + + %printed_output = memref.cast %arg2 : memref to memref<*xf32> + call @printMemrefF32(%printed_output) : (memref<*xf32>) -> () + // Print timings. + vector.print %time : f64 return } @@ -111,10 +119,6 @@ module { %v1 = call @alloc_f32(%c6, %c5, %c5, %c1, %f3) : (index, index, index, index, f32) -> memref %v2 = call @alloc_f32(%c1, %c24, %c24, %c6, %f0) : (index, index, index, index, f32) -> memref - %t_start = call @rtclock() : () -> f64 - call @conv_2d_nhwc_fhwc(%v0, %v1, %v2) : (memref, memref, memref) -> () - %t_end = call @rtclock() : () -> f64 - // All the elements of the MemRef are the same, // only check the first line to verify the correctness. // CHECK: Unranked Memref @@ -122,16 +126,11 @@ module { // CHECK: [ // CHECK: [ // CHECK: [150{{(, 150)*}}], - %print_v2 = memref.cast %v2 : memref to memref<*xf32> - call @printMemrefF32(%print_v2) : (memref<*xf32>) -> () - - %time = arith.subf %t_end, %t_start : f64 - vector.print %time : f64 - + call @conv_2d_nhwc_fhwc(%v0, %v1, %v2) : (memref, memref, memref) -> () + memref.dealloc %v0 : memref memref.dealloc %v1 : memref memref.dealloc %v2 : memref - return } } diff --git a/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir b/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir new file mode 100644 index 0000000000..340f5109fc --- /dev/null +++ b/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir @@ -0,0 +1,161 @@ +// RUN: buddy-opt %s \ +// RUN: -convert-vector-to-scf \ +// RUN: -lower-affine \ +// RUN: -arith-bufferize \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm \ +// RUN: -convert-arith-to-llvm \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -O3 -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +// Using `8` as the vector size. +module { + func.func private @printMemrefF32(memref<*xf32>) + func.func private @rtclock() -> f64 + + func.func @conv_2d_nhwc_fhwc(%arg0: memref, %arg1: memref, %arg2: memref) { + %f0 = arith.constant 0. : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %vl_step = arith.constant 8 : index + %vec0 = vector.splat %f0 : vector<8xf32> + %n = memref.dim %arg0, %c0 : memref + %c = memref.dim %arg0, %c3 : memref + %f = memref.dim %arg1, %c0 : memref + %h_k = memref.dim %arg1, %c1 : memref + %w_k = memref.dim %arg1, %c2 : memref + %h_o = memref.dim %arg2, %c1 : memref + %w_o = memref.dim %arg2, %c2 : memref + + // Calculate the upper bound for vectorized processing + // - Subtract `vl_step` is to avoid overflow at the vectorization tail. + // - Add 1 to ensure the final loop runs when the workload length + // is divisible by the vector size. 
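+    // For example, if the channel count were 20: upbound = 20 - 8 + 1 = 13,
+    // so the vector loop below covers channels [0, 8) and [8, 16), and the
+    // masked tail handles the remaining 4 channels; with a channel count of
+    // 16: upbound = 9, both vector iterations still run and the tail is empty.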
+ %upbound_tmp = arith.subi %c, %vl_step : index + %upbound = arith.addi %upbound_tmp, %c1 : index + + %t_start = call @rtclock() : () -> f64 + // Output is NHoWoF + affine.for %idx_n = %c0 to %n { + affine.for %idx_h_o = %c0 to %h_o { + affine.for %idx_w_o = %c0 to %w_o { + affine.for %idx_f = %c0 to %f { + %tmp_result = memref.load %arg2[%idx_n, %idx_h_o, %idx_w_o, %idx_f] : memref + %iter_idx, %iter_value = scf.for %idx_c = %c0 to %upbound step %vl_step + iter_args(%iter_init = %c0, %iter_value0 = %tmp_result) -> (index, f32) { + %tmp8 = affine.for %idx_h_k = %c0 to %h_k iter_args(%tmp9 = %iter_value0) -> (f32) { + %tmp6 = affine.for %idx_w_k = %c0 to %w_k iter_args(%tmp7 = %tmp9) -> (f32) { + %in_iter_h = arith.addi %idx_h_k, %idx_h_o : index + %in_iter_w = arith.addi %idx_w_k, %idx_w_o : index + %input_vec = vector.load %arg0[%idx_n, %in_iter_h, %in_iter_w, %idx_c] : memref, vector<8xf32> + %kernel_vec = vector.load %arg1[%idx_f, %idx_h_k, %idx_w_k, %idx_c] : memref, vector<8xf32> + %tmp_vec0 = arith.mulf %kernel_vec, %input_vec : vector<8xf32> + %tmp_val = vector.reduction , %tmp_vec0 : vector<8xf32> into f32 + %tmp4 = arith.addf %tmp7, %tmp_val : f32 + affine.yield %tmp4 : f32 + } + affine.yield %tmp6 : f32 + } + %tmp11 = arith.addi %idx_c, %vl_step : index + scf.yield %tmp11, %tmp8 : index, f32 + } + // Compute the tail size and Process the remaining elements + // using masked vector operations. + %tail_size = arith.subi %c, %iter_idx : index + %3 = arith.cmpi sgt, %tail_size, %c0 : index + scf.if %3 { + %mask = vector.create_mask %tail_size : vector<8xi1> + %tmp8 = affine.for %idx_h_k = %c0 to %h_k iter_args(%tmp9 = %iter_value) -> (f32) { + %tmp6 = affine.for %idx_w_k = %c0 to %w_k iter_args(%tmp7 = %tmp9) -> (f32) { + %in_iter_h = arith.addi %idx_h_k, %idx_h_o : index + %in_iter_w = arith.addi %idx_w_k, %idx_w_o : index + %input_vec = vector.maskedload %arg0[%idx_n, %in_iter_h, %in_iter_w, %iter_idx], %mask, %vec0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> + %kernel_vec = vector.maskedload %arg1[%idx_f, %idx_h_k, %idx_w_k, %iter_idx], %mask, %vec0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> + %tmp_vec0 = arith.mulf %kernel_vec, %input_vec : vector<8xf32> + %tmp_val = vector.reduction , %tmp_vec0 : vector<8xf32> into f32 + %tmp4 = arith.addf %tmp7, %tmp_val : f32 + affine.yield %tmp4 : f32 + } + affine.yield %tmp6 : f32 + } + memref.store %tmp8, %arg2[%idx_n, %idx_h_o, %idx_w_o, %idx_f] : memref + } else { + memref.store %iter_value, %arg2[%idx_n, %idx_h_o, %idx_w_o, %idx_f] : memref + } + } + } + } + } + %t_end = call @rtclock() : () -> f64 + %time = arith.subf %t_end, %t_start : f64 + + %printed_output = memref.cast %arg2 : memref to memref<*xf32> + call @printMemrefF32(%printed_output) : (memref<*xf32>) -> () + + // Print timings. 
+ vector.print %time : f64 + return + } + + func.func @alloc_f32(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: f32) -> memref { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref + scf.for %idx0 = %c0 to %arg0 step %c1 { + scf.for %idx1 = %c0 to %arg1 step %c1 { + scf.for %idx2 = %c0 to %arg2 step %c1 { + scf.for %idx3 = %c0 to %arg3 step %c1 { + memref.store %arg4, %0[%idx0, %idx1, %idx2, %idx3] : memref + } + } + } + } + return %0 : memref + } + + func.func @main() { + %f0 = arith.constant 0.000000e+00 : f32 + %f2 = arith.constant 2.000000e+00 : f32 + %f3 = arith.constant 3.000000e+00 : f32 + + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + %c8 = arith.constant 8 : index + %c12 = arith.constant 12 : index + %c16 = arith.constant 16 : index + %c24 = arith.constant 24 : index + %c28 = arith.constant 28 : index + + // %v0 = call @alloc_f32(%c1, %c12, %c12, %c6, %f2) : (index, index, index, index, f32) -> memref + // %v1 = call @alloc_f32(%c16, %c5, %c5, %c6, %f3) : (index, index, index, index, f32) -> memref + // %v2 = call @alloc_f32(%c1, %c8, %c8, %c16, %f0) : (index, index, index, index, f32) -> memref + + %v0 = call @alloc_f32(%c1, %c28, %c28, %c1, %f2) : (index, index, index, index, f32) -> memref + %v1 = call @alloc_f32(%c6, %c5, %c5, %c1, %f3) : (index, index, index, index, f32) -> memref + %v2 = call @alloc_f32(%c1, %c24, %c24, %c6, %f0) : (index, index, index, index, f32) -> memref + + // All the elements of the MemRef are the same, + // only check the first line to verify the correctness. + // CHECK: Unranked Memref + // CHECK: [ + // CHECK: [ + // CHECK: [ + // CHECK: [150{{(, 150)*}}], + call @conv_2d_nhwc_fhwc(%v0, %v1, %v2) : (memref, memref, memref) -> () + + memref.dealloc %v0 : memref + memref.dealloc %v1 : memref + memref.dealloc %v2 : memref + return + } +} diff --git a/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir b/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir index 90759355e9..15d7284266 100644 --- a/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir +++ b/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir @@ -1,4 +1,5 @@ // RUN: buddy-opt %s \ +// RUN: -conv2d-nhwc-fhwc-vectorization \ // RUN: -convert-linalg-to-loops \ // RUN: -lower-affine \ // RUN: -arith-bufferize \ @@ -18,8 +19,20 @@ module { func.func private @rtclock() -> f64 func.func @conv_2d_nhwc_fhwc(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.conv_2d_nhwc_fhwc ins (%arg0, %arg1: memref, memref) - outs (%arg2: memref) + %t_start = call @rtclock() : () -> f64 + + linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%arg0, %arg1: memref, memref) + outs (%arg2: memref) + + %t_end = call @rtclock() : () -> f64 + %time = arith.subf %t_end, %t_start : f64 + + %printed_output = memref.cast %arg2 : memref to memref<*xf32> + call @printMemrefF32(%printed_output) : (memref<*xf32>) -> () + + // Print timings. 
+ vector.print %time : f64 return } @@ -63,10 +76,6 @@ module { %v1 = call @alloc_f32(%c6, %c5, %c5, %c1, %f3) : (index, index, index, index, f32) -> memref %v2 = call @alloc_f32(%c1, %c24, %c24, %c6, %f0) : (index, index, index, index, f32) -> memref - %t_start = call @rtclock() : () -> f64 - call @conv_2d_nhwc_fhwc(%v0, %v1, %v2) : (memref, memref, memref) -> () - %t_end = call @rtclock() : () -> f64 - // All the elements of the MemRef are the same, // only check the first line to verify the correctness. // CHECK: Unranked Memref @@ -74,12 +83,8 @@ module { // CHECK: [ // CHECK: [ // CHECK: [150{{(, 150)*}}], - %print_v2 = memref.cast %v2 : memref to memref<*xf32> - call @printMemrefF32(%print_v2) : (memref<*xf32>) -> () - - %time = arith.subf %t_end, %t_start : f64 - vector.print %time : f64 - + call @conv_2d_nhwc_fhwc(%v0, %v1, %v2) : (memref, memref, memref) -> () + memref.dealloc %v0 : memref memref.dealloc %v1 : memref memref.dealloc %v2 : memref diff --git a/examples/BuddyConvolution/makefile b/examples/BuddyConvolution/makefile index 1962643766..20c1de2c12 100644 --- a/examples/BuddyConvolution/makefile +++ b/examples/BuddyConvolution/makefile @@ -125,3 +125,72 @@ conv2d-nhwc-fhwc-opt-aot: -L${MLIR_LIB} -lmlir_runner_utils -lmlir_c_runner_utils \ -o a.out @LD_LIBRARY_PATH=${MLIR_LIB} ./a.out + +conv2d-nhwc-fhwc-vectorization-lower: + @${BUDDY_OPT} ./conv2d-nhwc-fhwc.mlir \ + -conv2d-nhwc-fhwc-vectorization \ + -o log.mlir + +conv2d-nhwc-fhwc-vectorization-run: + @${BUDDY_OPT} ./conv2d-nhwc-fhwc.mlir \ + -conv2d-nhwc-fhwc-vectorization="vec-size=2" \ + -convert-linalg-to-loops \ + -lower-affine \ + -arith-bufferize \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +conv2d-nhwc-fhwc-vectorization-aot: + @${BUDDY_OPT} ./conv2d-nhwc-fhwc.mlir \ + -conv2d-nhwc-fhwc-vectorization \ + -convert-linalg-to-loops \ + -lower-affine \ + -arith-bufferize \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} -mlir-to-llvmir -o log.ll + ${CLANG} log.ll ${OPT_FLAG} \ + -L${MLIR_LIB} -lmlir_runner_utils -lmlir_c_runner_utils \ + -o a.out + @LD_LIBRARY_PATH=${MLIR_LIB} ./a.out + +conv2d-nhwc-fhwc-vec-run: + @${BUDDY_OPT} ./conv2d-nhwc-fhwc-vec.mlir \ + -convert-vector-to-scf \ + -lower-affine \ + -arith-bufferize \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +conv2d-nhwc-fhwc-vec-aot: + @${BUDDY_OPT} ./conv2d-nhwc-fhwc-vec.mlir \ + -convert-vector-to-scf \ + -lower-affine \ + -arith-bufferize \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} -mlir-to-llvmir -o log.ll + ${CLANG} log.ll -O3 \ + -L${MLIR_LIB} -lmlir_runner_utils -lmlir_c_runner_utils \ + -o a.out + @LD_LIBRARY_PATH=${MLIR_LIB} ./a.out diff --git a/midend/lib/Conversion/ConvVectorization/CMakeLists.txt 
b/midend/lib/Conversion/ConvVectorization/CMakeLists.txt index d4cc3ec987..6306049a38 100644 --- a/midend/lib/Conversion/ConvVectorization/CMakeLists.txt +++ b/midend/lib/Conversion/ConvVectorization/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_library(CBConvVectorization GEMMPointwiseConv2DNhwcHwcf.cpp PoolingVectorization.cpp PoolingNhwcMaxVectorization.cpp + Conv2dNhwcFhwcVectorization.cpp LINK_LIBS PUBLIC BuddyUtils diff --git a/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp new file mode 100644 index 0000000000..a4d491e31f --- /dev/null +++ b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp @@ -0,0 +1,398 @@ +//===--------Conv2dNhwcFhwcVectorization.cpp-------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Pooling Nhwc Max Vectorization. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/ValueRange.h" +#include "llvm/ADT/ArrayRef.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Utils/Utils.h" + +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { +public: + explicit Conv2dNhwcFhwcVectorizationPattern(MLIRContext *context, + int64_t vecsizeParam) + : ConversionPattern(linalg::Conv2DNhwcFhwcOp::getOperationName(), 1, + context) { + vecsize = vecsizeParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto ctx = op->getContext(); + + // Get input, kernel and output. + Value input = op->getOperand(0); + Value kernel = op->getOperand(1); + Value output = op->getOperand(2); + // Get Strides. + SmallVector strides = {1, 1}; + if (op->hasAttr("strides")) { + strides.clear(); + for (auto value : op->getAttrOfType("strides") + .getValues()) { + strides.push_back(value); + } + } + bool stride1 = strides[0] != 1; + bool stride2 = strides[1] != 1; + Value strHeight = rewriter.create(loc, strides[0]); + Value strWidth = rewriter.create(loc, strides[1]); + + // Get Dilations. 
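+    // Like the strides above, dilations default to {1, 1} when the op does
+    // not carry the attribute; non-unit values scale the kernel window
+    // indices inside the loop nest below.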
+ SmallVector dilations = {1, 1}; + if (op->hasAttr("dilations")) { + dilations.clear(); + for (auto value : + op->getAttrOfType("dilations") + .getValues()) { + dilations.push_back(value); + } + } + bool dilated1 = dilations[0] != 1; + bool dilated2 = dilations[1] != 1; + Value dilHeight = + rewriter.create(loc, dilations[0]); + Value dilWidth = rewriter.create(loc, dilations[1]); + + // Get i1 as the element type for mask vector. + IntegerType i1 = IntegerType::get(ctx, 1); + VectorType vectorMaskTy = mlir::VectorType::get({vecsize}, i1); + // Get ElementType of input and create pass through vector. + Type elementTy = input.getType().cast().getElementType(); + VectorType vectorTy = mlir::VectorType::get({vecsize}, elementTy); + + // Get Constants. + const Value c0 = rewriter.create(loc, 0); + const Value c1 = rewriter.create(loc, 1); + const Value c2 = rewriter.create(loc, 2); + const Value c3 = rewriter.create(loc, 3); + const Value vlStep = rewriter.create(loc, vecsize); + const Value zero = + buddy::insertZeroConstantOp(ctx, rewriter, loc, elementTy); + + // Create pass through vector. + Value passThroughVec = rewriter.create(loc, vectorTy, zero); + + // Get Dimensions of Input. + Value batch = rewriter.create(loc, input, c0); + Value channels = rewriter.create(loc, input, c3); + + // Get Dimensions of Kernel. + Value f_o = rewriter.create(loc, kernel, c0); + Value height_k = rewriter.create(loc, kernel, c1); + Value width_k = rewriter.create(loc, kernel, c2); + + // Get Dimensions of Outputs. + Value height_o = rewriter.create(loc, output, c1); + Value width_o = rewriter.create(loc, output, c2); + + // Calculate the upper bound for vectorized processing + // - Subtract `vlStep` is to avoid overflow at the vectorization tail. + // - Add 1 to ensure the final loop runs when the workload length + // is divisible by the vector size. + Value upperBoundTmp = rewriter.create(loc, channels, vlStep); + Value upperBound = rewriter.create(loc, upperBoundTmp, c1); + + SmallVector lowerBounds(4, c0); + SmallVector uperBounds{batch, height_o, width_o, f_o}; + SmallVector steps(4, /*Value=*/1); + affine::buildAffineLoopNest( + rewriter, loc, lowerBounds, uperBounds, steps, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Create strides variables. + Value tmpIvs1 = ivs[1]; + if (stride1) { + tmpIvs1 = builder.create(loc, ivs[1], strHeight); + } + Value tmpIvs2 = ivs[2]; + if (stride2) { + tmpIvs2 = builder.create(loc, ivs[2], strWidth); + } + Value tmp_result = builder.create( + loc, elementTy, output, + ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]}); + // Create vecsize mining loop. + auto iter_val = builder.create( + loc, c0, upperBound, /*Step=*/vlStep, ValueRange{c0, tmp_result}, + [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, + ValueRange itrArgs) { + auto tmp0 = nestedBuilder.create( + nestedLoc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{height_k}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{itrArgs[1]}, + [&](OpBuilder &builder, Location loc, Value iv0, + ValueRange itrArgs0) { + // Create dilated[0] variables. 
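+                      // The effective input row is tmpIvs1 + tmpIvs3, i.e.
+                      // out_h * stride_h + k_h * dilation_h; the multiplies
+                      // are only emitted for non-unit strides/dilations.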
+ Value tmpIvs3 = iv0; + if (dilated1) { + tmpIvs3 = + builder.create(loc, iv0, dilHeight); + } + Value inputHeight = + builder.create(loc, tmpIvs1, tmpIvs3); + auto tmp1 = builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{width_k}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{itrArgs0[0]}, + [&](OpBuilder &builder, Location loc, Value iv1, + ValueRange itrArgs1) { + // Create dilated[1] variables. + Value tmpIvs4 = iv1; + if (dilated2) { + tmpIvs4 = builder.create(loc, iv1, + dilWidth); + } + Value inputWidth = builder.create( + loc, tmpIvs2, tmpIvs4); + Value inputVector = builder.create( + loc, vectorTy, input, + ValueRange{ivs[0], inputHeight, inputWidth, + iv}); + Value kernelVector = builder.create( + loc, vectorTy, kernel, + ValueRange{ivs[3], iv0, iv1, iv}); + // FMA + Value resultVal; + if (auto ty = + llvm::dyn_cast(elementTy)) { + Value tmpVec = builder.create( + loc, inputVector, kernelVector); + Value tmpVal = + builder.create( + loc, vector::CombiningKind::ADD, tmpVec, + ::mlir::arith::FastMathFlags::none); + resultVal = builder.create( + loc, tmpVal, itrArgs1[0]); + } else { + Value tmpVec = builder.create( + loc, inputVector, kernelVector); + Value tmpVal = + builder.create( + loc, vector::CombiningKind::ADD, tmpVec, + ::mlir::arith::FastMathFlags::none); + resultVal = builder.create( + loc, tmpVal, itrArgs1[0]); + } + builder.create(loc, + resultVal); + }); + nestedBuilder.create( + nestedLoc, tmp1.getResult(0)); + }); + Value idx = + builder.create(loc, iv, vlStep); + builder.create( + loc, ValueRange{idx, tmp0.getResult(0)}); + }); + // Compute the tail size and Process the remaining elements + // using masked vector operations. + + Value idx = iter_val.getResult(0); + Value tailSize = builder.create(loc, channels, idx); + Value tailCond = rewriter.create( + loc, arith::CmpIPredicate::sgt, tailSize, c0); + // If the current column does not reach the tail. + builder.create( + loc, tailCond, + [&](OpBuilder &builder, Location loc) { + Value tailMask = + builder.create(loc, vectorMaskTy, tailSize); + auto tmp0 = builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{height_k}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{iter_val.getResult(1)}, + [&](OpBuilder &builder, Location loc, Value iv0, + ValueRange itrArgs0) { + // Create dilated[0] variables. + Value tmpIvs3 = iv0; + if (dilated1) { + tmpIvs3 = + builder.create(loc, iv0, dilHeight); + } + Value inputHeight = + builder.create(loc, tmpIvs1, tmpIvs3); + auto tmp1 = builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{width_k}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{itrArgs0[0]}, + [&](OpBuilder &builder, Location loc, Value iv1, + ValueRange itrArgs1) { + // Create dilated[1] variables. 
+ Value tmpIvs4 = iv1; + if (dilated2) { + tmpIvs4 = builder.create(loc, iv1, + dilWidth); + } + Value inputWidth = builder.create( + loc, tmpIvs2, tmpIvs4); + Value inputVec = builder.create( + loc, vectorTy, input, + ValueRange{ivs[0], inputHeight, inputWidth, + idx}, + tailMask, passThroughVec); + Value kernelVec = builder.create( + loc, vectorTy, kernel, + ValueRange{ivs[3], iv0, iv1, idx}, tailMask, + passThroughVec); + // FMA + Value resultVal; + if (auto ty = + llvm::dyn_cast(elementTy)) { + Value tmpVec = builder.create( + loc, inputVec, kernelVec); + Value tmpVal = + builder.create( + loc, vector::CombiningKind::ADD, tmpVec, + ::mlir::arith::FastMathFlags::none); + resultVal = builder.create( + loc, tmpVal, itrArgs1[0]); + } else { + Value tmpVec = builder.create( + loc, inputVec, kernelVec); + Value tmpVal = + builder.create( + loc, vector::CombiningKind::ADD, tmpVec, + ::mlir::arith::FastMathFlags::none); + resultVal = builder.create( + loc, tmpVal, itrArgs1[0]); + } + builder.create(loc, + resultVal); + }); + builder.create(loc, + tmp1.getResult(0)); + }); + builder.create( + loc, tmp0.getResult(0), output, + ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]}); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, iter_val.getResult(1), output, + ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]}); + builder.create(loc); + }); + }); + // Remove the origin convolution operation. + rewriter.eraseOp(op); + return success(); + } + +private: + int64_t vecsize; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// Conv2dNhwcFhwcVectorizationPass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering linalg pooling max operations to mixture of +/// Arith + Vector operations. +namespace { +class Conv2dNhwcFhwcVectorizationPass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(Conv2dNhwcFhwcVectorizationPass) + StringRef getArgument() const final { + return "conv2d-nhwc-fhwc-vectorization"; + } + StringRef getDescription() const final { + return "Conv2d_Nhwc_Fhwc Vectorization."; + } + Conv2dNhwcFhwcVectorizationPass() = default; + Conv2dNhwcFhwcVectorizationPass(const Conv2dNhwcFhwcVectorizationPass &) {} + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + Option vecsize{*this, "vec-size", + llvm::cl::desc("Specify vector type size."), + llvm::cl::init(8)}; +}; +} // end anonymous namespace. 
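+// Example invocation (see the conv2d-nhwc-fhwc-vectorization-run target in
+// examples/BuddyConvolution/makefile for the full pipeline):
+//
+//   buddy-opt conv2d-nhwc-fhwc.mlir \
+//     -conv2d-nhwc-fhwc-vectorization="vec-size=2" \
+//     -convert-linalg-to-loops -lower-affine -arith-bufferize \
+//     -convert-scf-to-cf -convert-vector-to-llvm -convert-arith-to-llvm \
+//     -finalize-memref-to-llvm -convert-func-to-llvm \
+//     -reconcile-unrealized-casts
+//
+// When vec-size is omitted, the pass uses the default vector size of 8.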
+ +void Conv2dNhwcFhwcVectorizationPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, vecsize); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +namespace mlir { +namespace buddy { +void registerConv2dNhwcFhwcVectorizationPass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir b/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir new file mode 100644 index 0000000000..37b955453b --- /dev/null +++ b/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir @@ -0,0 +1,33 @@ +// RUN: buddy-opt -conv2d-nhwc-fhwc-vectorization %s | FileCheck %s + +// CHECK: #map = affine_map<(d0) -> (d0)> +// CHECK: module { +// CHECK: affine.for %arg3 = #map(%c0) to #map(%dim) { +// CHECK-NEXT: affine.for %arg4 = #map(%c0) to #map(%dim_8) { +// CHECK-NEXT: affine.for %arg5 = #map(%c0) to #map(%dim_9) { +// CHECK-NEXT: affine.for %arg6 = #map(%c0) to #map(%dim_5) { +// CHECK-NEXT: %3 = memref.load %arg2[%arg3, %arg4, %arg5, %arg6] : memref +// CHECK-NEXT: %4:2 = scf.for %arg7 = %c0 to %2 step %c8 iter_args(%arg8 = %c0, %arg9 = %3) -> (index, f32) { +// CHECK-NEXT: %7 = affine.for %arg10 = #map(%c0) to #map(%dim_6) iter_args(%arg11 = %arg9) -> (f32) { +// CHECK-NEXT: %9 = arith.addi %arg4, %arg10 : index +// CHECK-NEXT: %10 = affine.for %arg12 = #map(%c0) to #map(%dim_7) iter_args(%arg13 = %arg11) -> (f32) { +// CHECK-NEXT: %11 = arith.addi %arg5, %arg12 : index +// CHECK-NEXT: %12 = vector.load %arg0[%arg3, %9, %11, %arg7] : memref, vector<8xf32> +// CHECK-NEXT: %13 = vector.load %arg1[%arg6, %arg10, %arg12, %arg7] : memref, vector<8xf32> +// CHECK-NEXT: %14 = arith.mulf %12, %13 : vector<8xf32> +// CHECK-NEXT: %15 = vector.reduction , %14 : vector<8xf32> into f32 +// CHECK-NEXT: %16 = arith.addf %15, %arg13 : f32 +// CHECK-NEXT: affine.yield %16 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: affine.yield %10 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: %8 = arith.addi %arg7, %c8 : index +// CHECK-NEXT: scf.yield %8, %7 : index, f32 +// CHECK-NEXT: } + +func.func @conv_2d_nhwc_fhwc(%arg0: memref, %arg1: memref, %arg2: memref) { + linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%arg0, %arg1: memref, memref) + outs (%arg2: memref) + return +} diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp index 2c7de877c6..3ac33506ae 100644 --- a/tools/buddy-opt/buddy-opt.cpp +++ b/tools/buddy-opt/buddy-opt.cpp @@ -56,6 +56,7 @@ void registerPointwiseConvToGemmPass(); void registerPointwiseConvToGemmForNhwcFhwcPass(); void registerPoolingVectorizationPass(); void registerPoolingNhwcMaxVectorizationPass(); +void registerConv2dNhwcFhwcVectorizationPass(); void registerLowerBudPass(); void registerLowerDIPPass(); void registerBatchMatMulOptimizePass(); @@ -97,6 +98,8 @@ int main(int argc, char **argv) { mlir::buddy::registerPoolingVectorizationPass(); // Register Vectorization of Pooling Nhwc Max. mlir::buddy::registerPoolingNhwcMaxVectorizationPass(); + // Register Vectorization of Conv2D Nhwc Fhwc. 
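+  // This exposes the pass to buddy-opt as -conv2d-nhwc-fhwc-vectorization.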
+ mlir::buddy::registerConv2dNhwcFhwcVectorizationPass(); mlir::buddy::registerLowerBudPass(); mlir::buddy::registerLowerDIPPass(); mlir::buddy::registerLowerDAPPass(); @@ -119,6 +122,7 @@ int main(int argc, char **argv) { mlir::buddy::registerConvOptimizePass(); mlir::buddy::registerConvNhwcFhwcOptimizePass(); mlir::buddy::registerConvNhwcFhwcTileOptimizePass(); + mlir::buddy::registerDepthwiseConv2DNhwcHwcOptimizePass(); mlir::buddy::registerDeviceSchedulePass(); mlir::buddy::registerLowerSchePass();