From bb5ac70bec4b036a76d811383a29f6873a9ff493 Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Fri, 8 Nov 2024 04:25:05 +0000 Subject: [PATCH 01/11] [midend/lib/Conversion/ConvVectorization] add conv2dnhwcfhwc vectorization pass and add relevant examples and tests. --- .../conv2d-nhwc-fhwc-seq.mlir | 150 +++++++++ examples/BuddyConvolution/makefile | 31 ++ .../ConvVectorization/CMakeLists.txt | 1 + .../Conv2dNhwcFhwcVectorization.cpp | 312 ++++++++++++++++++ .../conv2d-nhwc-fhwc-max-vectorization.mlir | 52 +++ tools/buddy-opt/buddy-opt.cpp | 4 + 6 files changed, 550 insertions(+) create mode 100644 examples/BuddyConvolution/conv2d-nhwc-fhwc-seq.mlir create mode 100644 midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp create mode 100644 tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir diff --git a/examples/BuddyConvolution/conv2d-nhwc-fhwc-seq.mlir b/examples/BuddyConvolution/conv2d-nhwc-fhwc-seq.mlir new file mode 100644 index 0000000000..2947f8c173 --- /dev/null +++ b/examples/BuddyConvolution/conv2d-nhwc-fhwc-seq.mlir @@ -0,0 +1,150 @@ +// RUN: buddy-opt %s \ +// RUN: -convert-vector-to-scf \ +// RUN: -lower-affine \ +// RUN: -arith-bufferize \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm \ +// RUN: -convert-arith-to-llvm \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -O3 -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +// Using `8` as the vector size. +module { + func.func private @printMemrefF32(memref<*xf32>) + func.func private @rtclock() -> f64 + + func.func @conv_2d_nhwc_fhwc(%arg0: memref, %arg1: memref, %arg2: memref) { + %f0 = arith.constant 0. 
: f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %vl_step = arith.constant 8 : index + %vec0 = vector.splat %f0 : vector<8xf32> + %n = memref.dim %arg0, %c0 : memref + %c = memref.dim %arg0, %c3 : memref + %f = memref.dim %arg1, %c0 : memref + %h_k = memref.dim %arg1, %c1 : memref + %w_k = memref.dim %arg1, %c2 : memref + %h_o = memref.dim %arg2, %c1 : memref + %w_o = memref.dim %arg2, %c2 : memref + + // Calculate the upper bound for vectorized processing + // - Subtract `vl_step` is to avoid overflow at the vectorization tail. + // - Add 1 to ensure the final loop runs when the workload length + // is divisible by the vector size. + %upbound_tmp = arith.subi %c, %vl_step : index + %upbound = arith.addi %upbound_tmp, %c1 : index + + // Output is NHoWoF + affine.for %idx_n = %c0 to %n { + affine.for %idx_h_o = %c0 to %h_o { + affine.for %idx_w_o = %c0 to %w_o { + affine.for %idx_f = %c0 to %f { + %iter_idx, %iter_value = scf.for %idx_c = %c0 to %upbound step %vl_step iter_args(%iter_init = %c0, %iter_value0 = %f0) -> (index, f32) { + %tmp8 = affine.for %idx_h_k = %c0 to %h_k iter_args(%tmp9 = %iter_value0) -> (f32) { + %tmp6 = affine.for %idx_w_k = %c0 to %w_k iter_args(%tmp7 = %tmp9) -> (f32) { + %in_iter_h = arith.addi %idx_h_k, %idx_h_o : index + %in_iter_w = arith.addi %idx_w_k, %idx_w_o : index + %input_vec = vector.load %arg0[%idx_n, %in_iter_h, %in_iter_w, %idx_c] : memref, vector<8xf32> + %kernel_vec = vector.load %arg1[%idx_f, %idx_h_k, %idx_w_k, %idx_c] : memref, vector<8xf32> + %tmp_vec0 = arith.mulf %kernel_vec, %input_vec : vector<8xf32> + %tmp_val = vector.reduction , %tmp_vec0 : vector<8xf32> into f32 + %tmp4 = arith.addf %tmp7, %tmp_val : f32 + affine.yield %tmp4 : f32 + } + affine.yield %tmp6 : f32 + } + %tmp11 = arith.addi %iter_init, %vl_step : index + scf.yield %tmp11, %tmp8 : index, f32 + } + // Compute the tail size and Process the remaining elements + 
// using masked vector operations. + %tail_size = arith.subi %c, %iter_idx : index + %mask = vector.create_mask %tail_size : vector<8xi1> + %tmp8 = affine.for %idx_h_k = %c0 to %h_k iter_args(%tmp9 = %iter_value) -> (f32) { + %tmp6 = affine.for %idx_w_k = %c0 to %w_k iter_args(%tmp7 = %tmp9) -> (f32) { + %in_iter_h = arith.addi %idx_h_k, %idx_h_o : index + %in_iter_w = arith.addi %idx_w_k, %idx_w_o : index + %input_vec = vector.maskedload %arg0[%idx_n, %in_iter_h, %in_iter_w, %iter_idx], %mask, %vec0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> + %kernel_vec = vector.maskedload %arg1[%idx_f, %idx_h_k, %idx_w_k, %iter_idx], %mask, %vec0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> + %tmp_vec0 = arith.mulf %kernel_vec, %input_vec : vector<8xf32> + %tmp_val = vector.reduction , %tmp_vec0 : vector<8xf32> into f32 + %tmp4 = arith.addf %tmp7, %tmp_val : f32 + affine.yield %tmp4 : f32 + } + affine.yield %tmp6 : f32 + } + memref.store %tmp8, %arg2[%idx_n, %idx_h_o, %idx_w_o, %idx_f] : memref + } + } + } + } + return + } + + func.func @alloc_f32(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: f32) -> memref { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref + scf.for %idx0 = %c0 to %arg0 step %c1 { + scf.for %idx1 = %c0 to %arg1 step %c1 { + scf.for %idx2 = %c0 to %arg2 step %c1 { + scf.for %idx3 = %c0 to %arg3 step %c1 { + memref.store %arg4, %0[%idx0, %idx1, %idx2, %idx3] : memref + } + } + } + } + return %0 : memref + } + + func.func @main() { + %f0 = arith.constant 0.000000e+00 : f32 + %f2 = arith.constant 2.000000e+00 : f32 + %f3 = arith.constant 3.000000e+00 : f32 + + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + %c8 = arith.constant 8 : index + %c12 = arith.constant 12 : index + %c16 = arith.constant 15 : index + %c24 = arith.constant 24 : 
index + %c28 = arith.constant 28 : index + + %v0 = call @alloc_f32(%c1, %c28, %c28, %c1, %f2) : (index, index, index, index, f32) -> memref + %v1 = call @alloc_f32(%c6, %c5, %c5, %c1, %f3) : (index, index, index, index, f32) -> memref + %v2 = call @alloc_f32(%c1, %c24, %c24, %c6, %f0) : (index, index, index, index, f32) -> memref + + %t_start = call @rtclock() : () -> f64 + call @conv_2d_nhwc_fhwc(%v0, %v1, %v2) : (memref, memref, memref) -> () + %t_end = call @rtclock() : () -> f64 + + // All the elements of the MemRef are the same, + // only check the first line to verify the correctness. + // CHECK: Unranked Memref + // CHECK: [ + // CHECK: [ + // CHECK: [ + // CHECK: [150{{(, 150)*}}], + %print_v2 = memref.cast %v2 : memref to memref<*xf32> + call @printMemrefF32(%print_v2) : (memref<*xf32>) -> () + + %time = arith.subf %t_end, %t_start : f64 + vector.print %time : f64 + + memref.dealloc %v0 : memref + memref.dealloc %v1 : memref + memref.dealloc %v2 : memref + + return + } +} diff --git a/examples/BuddyConvolution/makefile b/examples/BuddyConvolution/makefile index 1962643766..6594de4a22 100644 --- a/examples/BuddyConvolution/makefile +++ b/examples/BuddyConvolution/makefile @@ -125,3 +125,34 @@ conv2d-nhwc-fhwc-opt-aot: -L${MLIR_LIB} -lmlir_runner_utils -lmlir_c_runner_utils \ -o a.out @LD_LIBRARY_PATH=${MLIR_LIB} ./a.out + +conv2d-nhwc-fhwc-seq-run: + @${BUDDY_OPT} ./conv2d-nhwc-fhwc-seq.mlir \ + -convert-vector-to-scf \ + -lower-affine \ + -arith-bufferize \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +conv2d-nhwc-fhwc-seq-aot: + @${BUDDY_OPT} ./conv2d-nhwc-fhwc-seq.mlir \ + -convert-vector-to-scf \ + -lower-affine \ + -arith-bufferize \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + 
-convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} -mlir-to-llvmir -o log.ll + ${CLANG} log.ll -O3 \ + -L${MLIR_LIB} -lmlir_runner_utils -lmlir_c_runner_utils \ + -o a.out + @LD_LIBRARY_PATH=${MLIR_LIB} ./a.out diff --git a/midend/lib/Conversion/ConvVectorization/CMakeLists.txt b/midend/lib/Conversion/ConvVectorization/CMakeLists.txt index fce89520b6..4b9e527228 100644 --- a/midend/lib/Conversion/ConvVectorization/CMakeLists.txt +++ b/midend/lib/Conversion/ConvVectorization/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(CBConvVectorization CBConvVectorization.cpp GEMMPointwiseConv2DNhwcHwcf.cpp PoolingVectorization.cpp + Conv2dNhwcFhwcVectorization.cpp LINK_LIBS PUBLIC BuddyUtils diff --git a/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp new file mode 100644 index 0000000000..3e06833849 --- /dev/null +++ b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp @@ -0,0 +1,312 @@ +//===--------Conv2dNhwcFhwcVectorization.cpp-------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Conv2d Nhwc Fhwc Vectorization.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/ValueRange.h" +#include "llvm/ADT/ArrayRef.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Utils/Utils.h" + +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { +public: + explicit Conv2dNhwcFhwcVectorizationPattern(MLIRContext *context, + int64_t vecsizeParam) + : ConversionPattern(linalg::Conv2DNhwcFhwcOp::getOperationName(), 1, + context) { + vecsize = vecsizeParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto ctx = op->getContext(); + + // Get input, kernel and output. + Value input = op->getOperand(0); + Value kernel = op->getOperand(1); + Value output = op->getOperand(2); + + // Get i1 as the element type for mask vector. + IntegerType i1 = IntegerType::get(ctx, 1); + VectorType vectorMaskTy = mlir::VectorType::get({vecsize}, i1); + // Get ElementType of input and create pass through vector. + Type elementTy = input.getType().cast().getElementType(); + VectorType vectorTy = mlir::VectorType::get({vecsize}, elementTy); + + // Get Constants. 
+ const Value c0 = rewriter.create(loc, 0); + const Value c1 = rewriter.create(loc, 1); + const Value c2 = rewriter.create(loc, 2); + const Value c3 = rewriter.create(loc, 3); + const Value vl_step = rewriter.create(loc, vecsize); + const Value zero = + buddy::insertZeroConstantOp(ctx, rewriter, loc, elementTy); + + // Create pass through vector. + Value passThroughVec = rewriter.create(loc, vectorTy, zero); + + // Get Dimensions of Input. + Value batch = rewriter.create(loc, input, c0); + Value channels = rewriter.create(loc, input, c3); + + // Get Dimensions of Kernel. + Value f_o = rewriter.create(loc, kernel, c0); + Value height_k = rewriter.create(loc, kernel, c1); + Value width_k = rewriter.create(loc, kernel, c2); + + // Get Dimensions of Outputs. + Value height_o = rewriter.create(loc, output, c1); + Value width_o = rewriter.create(loc, output, c2); + + // Calculate the upper bound for vectorized processing + // - Subtract `vl_step` is to avoid overflow at the vectorization tail. + // - Add 1 to ensure the final loop runs when the workload length + // is divisible by the vector size. + Value upperBound_tmp = + rewriter.create(loc, channels, vl_step); + Value upperBound = rewriter.create(loc, upperBound_tmp, c1); + + SmallVector lowerBounds(4, c0); + SmallVector uperBounds{batch, height_o, width_o, f_o}; + SmallVector steps(4, /*Value=*/1); + affine::buildAffineLoopNest( + rewriter, loc, lowerBounds, uperBounds, steps, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Create vecsize mining loop. 
+ auto iter_val = builder.create( + loc, c0, upperBound, /*Step=*/vl_step, ValueRange{c0, zero}, + [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, + ValueRange itrArgs) { + auto tmp0 = nestedBuilder.create( + nestedLoc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{height_k}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{itrArgs[1]}, + [&](OpBuilder &builder, Location loc, Value iv0, + ValueRange itrArgs0) { + auto tmp1 = builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{width_k}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{itrArgs0[0]}, + [&](OpBuilder &builder, Location loc, Value iv1, + ValueRange itrArgs1) { + Value inputHeight = + builder.create(loc, ivs[1], iv0); + Value inputWidth = + builder.create(loc, ivs[2], iv1); + Value inputVector = builder.create( + loc, vectorTy, input, + ValueRange{ivs[0], inputHeight, inputWidth, + iv}); + Value kernelVector = builder.create( + loc, vectorTy, kernel, + ValueRange{ivs[3], iv0, iv1, iv}); + // FMA + Value resultVal; + if (auto ty = + llvm::dyn_cast(elementTy)) { + Value tmpVec = builder.create( + loc, inputVector, kernelVector); + Value tmpVal = + builder.create( + loc, vector::CombiningKind::ADD, tmpVec, + ::mlir::arith::FastMathFlags::none); + resultVal = builder.create( + loc, tmpVal, itrArgs1[0]); + } else { + Value tmpVec = builder.create( + loc, inputVector, kernelVector); + Value tmpVal = + builder.create( + loc, vector::CombiningKind::ADD, tmpVec, + ::mlir::arith::FastMathFlags::none); + resultVal = builder.create( + loc, tmpVal, itrArgs1[0]); + } + builder.create(loc, + resultVal); + }); + nestedBuilder.create( + nestedLoc, tmp1.getResult(0)); + }); + Value idx = builder.create(loc, iv, vl_step); + builder.create( + loc, ValueRange{idx, tmp0.getResult(0)}); + }); + // Compute the tail size and Process the remaining elements + // using masked vector operations. 
+ Value idx = iter_val.getResult(0); + Value tailSize = builder.create(loc, channels, idx); + Value tailMask = + builder.create(loc, vectorMaskTy, tailSize); + auto tmp0 = builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{height_k}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{iter_val.getResult(1)}, + [&](OpBuilder &builder, Location loc, Value iv0, + ValueRange itrArgs0) { + auto tmp1 = builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{width_k}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{itrArgs0[0]}, + [&](OpBuilder &builder, Location loc, Value iv1, + ValueRange itrArgs1) { + Value inputHeight = + builder.create(loc, ivs[1], iv0); + Value inputWidth = + builder.create(loc, ivs[2], iv1); + Value inputVec = builder.create( + loc, vectorTy, input, + ValueRange{ivs[0], inputHeight, inputWidth, idx}, + tailMask, passThroughVec); + Value kernelVec = builder.create( + loc, vectorTy, kernel, + ValueRange{ivs[3], iv0, iv1, idx}, tailMask, + passThroughVec); + // FMA + Value resultVal; + if (auto ty = llvm::dyn_cast(elementTy)) { + Value tmpVec = builder.create( + loc, inputVec, kernelVec); + Value tmpVal = builder.create( + loc, vector::CombiningKind::ADD, tmpVec, + ::mlir::arith::FastMathFlags::none); + resultVal = builder.create(loc, tmpVal, + itrArgs1[0]); + } else { + Value tmpVec = builder.create( + loc, inputVec, kernelVec); + Value tmpVal = builder.create( + loc, vector::CombiningKind::ADD, tmpVec, + ::mlir::arith::FastMathFlags::none); + resultVal = builder.create(loc, tmpVal, + itrArgs1[0]); + } + builder.create(loc, resultVal); + }); + builder.create(loc, tmp1.getResult(0)); + }); + builder.create( + loc, tmp0.getResult(0), output, + ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]}); + }); + // Remove the origin convolution operation. 
+ rewriter.eraseOp(op); + return success(); + } + +private: + int64_t vecsize; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// Conv2dNhwcFhwcVectorizationPass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering of linalg conv_2d_nhwc_fhwc operations to a +/// mixture of Arith + Vector operations. +namespace { +class Conv2dNhwcFhwcVectorizationPass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(Conv2dNhwcFhwcVectorizationPass) + StringRef getArgument() const final { + return "conv2d-nhwc-fhwc-vectorization"; + } + StringRef getDescription() const final { + return "Conv2d_Nhwc_Fhwc Vectorization."; + } + Conv2dNhwcFhwcVectorizationPass() = default; + Conv2dNhwcFhwcVectorizationPass(const Conv2dNhwcFhwcVectorizationPass &) {} + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + Option vecsize{*this, "vec-size", + llvm::cl::desc("Specify vector type size."), + llvm::cl::init(8)}; +}; +} // end anonymous namespace.
+ +void Conv2dNhwcFhwcVectorizationPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, vecsize); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +namespace mlir { +namespace buddy { +void registerConv2dNhwcFhwcVectorizationPass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir b/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir new file mode 100644 index 0000000000..d5e42cdece --- /dev/null +++ b/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir @@ -0,0 +1,52 @@ +// RUN: buddy-opt -conv2d-nhwc-fhwc-vectorization %s | FileCheck %s + +// CHECK: #map = affine_map<(d0) -> (d0)> +// CHECK: module { +// CHECK: affine.for %arg3 = #map(%c0) to #map(%dim) { +// CHECK-NEXT: affine.for %arg4 = #map(%c0) to #map(%dim_4) { +// CHECK-NEXT: affine.for %arg5 = #map(%c0) to #map(%dim_5) { +// CHECK-NEXT: affine.for %arg6 = #map(%c0) to #map(%dim_1) { +// CHECK-NEXT: %3:2 = scf.for %arg7 = %c0 to %2 step %c8 iter_args(%arg8 = %c0, %arg9 = %cst) -> (index, f32) { +// CHECK-NEXT: %7 = affine.for %arg10 = #map(%c0) to #map(%dim_2) iter_args(%arg11 = %arg9) -> (f32) { +// CHECK-NEXT: %9 = affine.for %arg12 = #map(%c0) to #map(%dim_3) iter_args(%arg13 = %arg11) -> (f32) { +// CHECK-NEXT: %10 = arith.addi %arg4, %arg10 : index +// CHECK-NEXT: %11 = arith.addi %arg5, %arg12 : index +// CHECK-NEXT: %12 = vector.load %arg0[%arg3, %10, %11, %arg7] : memref, vector<8xf32> +// CHECK-NEXT: %13 = vector.load %arg1[%arg6, %arg10, %arg12, %arg7] : memref, vector<8xf32> +// CHECK-NEXT: %14 = arith.mulf %12, %13 : vector<8xf32> +// CHECK-NEXT: %15 = vector.reduction , %14 : vector<8xf32> into f32 +// CHECK-NEXT: %16 = 
arith.addf %15, %arg13 : f32 +// CHECK-NEXT: affine.yield %16 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: affine.yield %9 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: %8 = arith.addi %arg7, %c8 : index +// CHECK-NEXT: scf.yield %8, %7 : index, f32 +// CHECK-NEXT: } +// CHECK-NEXT: %4 = arith.subi %dim_0, %3#0 : index +// CHECK-NEXT: %5 = vector.create_mask %4 : vector<8xi1> +// CHECK-NEXT: %6 = affine.for %arg7 = #map(%c0) to #map(%dim_2) iter_args(%arg8 = %3#1) -> (f32) { +// CHECK-NEXT: %7 = affine.for %arg9 = #map(%c0) to #map(%dim_3) iter_args(%arg10 = %arg8) -> (f32) { +// CHECK-NEXT: %8 = arith.addi %arg4, %arg7 : index +// CHECK-NEXT: %9 = arith.addi %arg5, %arg9 : index +// CHECK-NEXT: %10 = vector.maskedload %arg0[%arg3, %8, %9, %3#0], %5, %0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK-NEXT: %11 = vector.maskedload %arg1[%arg6, %arg7, %arg9, %3#0], %5, %0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK-NEXT: %12 = arith.mulf %10, %11 : vector<8xf32> +// CHECK-NEXT: %13 = vector.reduction , %12 : vector<8xf32> into f32 +// CHECK-NEXT: %14 = arith.addf %13, %arg10 : f32 +// CHECK-NEXT: affine.yield %14 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: affine.yield %7 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: memref.store %6, %arg2[%arg3, %arg4, %arg5, %arg6] : memref +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +func.func @conv_2d_nhwc_fhwc(%arg0: memref, %arg1: memref, %arg2: memref) { + linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%arg0, %arg1: memref, memref) + outs (%arg2: memref) + return +} diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp index 08e172f8bc..4d98c6574d 100644 --- a/tools/buddy-opt/buddy-opt.cpp +++ b/tools/buddy-opt/buddy-opt.cpp @@ -55,6 +55,7 @@ void registerConvVectorizationPass(); void registerPointwiseConvToGemmPass(); void registerPointwiseConvToGemmForNhwcFhwcPass(); void 
registerPoolingVectorizationPass(); +void registerConv2dNhwcFhwcVectorizationPass(); void registerLowerBudPass(); void registerLowerDIPPass(); void registerBatchMatMulOptimizePass(); @@ -94,6 +95,8 @@ int main(int argc, char **argv) { mlir::buddy::registerConvVectorizationPass(); // Register Vectorization of Pooling. mlir::buddy::registerPoolingVectorizationPass(); + // Register Vectorization of Conv2D Nhwc Fhwc. + mlir::buddy::registerConv2dNhwcFhwcVectorizationPass(); mlir::buddy::registerLowerBudPass(); mlir::buddy::registerLowerDIPPass(); mlir::buddy::registerLowerDAPPass(); @@ -116,6 +119,7 @@ int main(int argc, char **argv) { mlir::buddy::registerConvOptimizePass(); mlir::buddy::registerConvNhwcFhwcOptimizePass(); mlir::buddy::registerConvNhwcFhwcTileOptimizePass(); + mlir::buddy::registerDepthwiseConv2DNhwcHwcOptimizePass(); mlir::buddy::registerDeviceSchedulePass(); mlir::buddy::registerLowerSchePass(); From 8ccc2c597a0bfd1747d66aefc89329400b5bf12d Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Thu, 21 Nov 2024 04:25:40 +0000 Subject: [PATCH 02/11] [examples/BuddyConvolution] update convolution-related tests --- ...hwc-seq.mlir => conv2d-nhwc-fhwc-vec.mlir} | 0 .../BuddyConvolution/conv2d-nhwc-fhwc.mlir | 1 + examples/BuddyConvolution/makefile | 37 ++++++++++++++++++- 3 files changed, 36 insertions(+), 2 deletions(-) rename examples/BuddyConvolution/{conv2d-nhwc-fhwc-seq.mlir => conv2d-nhwc-fhwc-vec.mlir} (100%) diff --git a/examples/BuddyConvolution/conv2d-nhwc-fhwc-seq.mlir b/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir similarity index 100% rename from examples/BuddyConvolution/conv2d-nhwc-fhwc-seq.mlir rename to examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir diff --git a/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir b/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir index 90759355e9..8928d027a3 100644 --- a/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir +++ b/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir @@ 
-1,4 +1,5 @@ // RUN: buddy-opt %s \ +// RUN: -conv2d-nhwc-fhwc-vectorization \ // RUN: -convert-linalg-to-loops \ // RUN: -lower-affine \ // RUN: -arith-bufferize \ diff --git a/examples/BuddyConvolution/makefile b/examples/BuddyConvolution/makefile index 6594de4a22..4e47d3368b 100644 --- a/examples/BuddyConvolution/makefile +++ b/examples/BuddyConvolution/makefile @@ -126,7 +126,40 @@ conv2d-nhwc-fhwc-opt-aot: -o a.out @LD_LIBRARY_PATH=${MLIR_LIB} ./a.out -conv2d-nhwc-fhwc-seq-run: +conv2d-nhwc-fhwc-vectorization-run: + @${BUDDY_OPT} ./conv2d-nhwc-fhwc.mlir \ + -conv2d-nhwc-fhwc-vectorization \ + -convert-linalg-to-loops \ + -lower-affine \ + -arith-bufferize \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +conv2d-nhwc-fhwc-vectorization-aot: + @${BUDDY_OPT} ./conv2d-nhwc-fhwc.mlir \ + -conv2d-nhwc-fhwc-vectorization \ + -convert-linalg-to-loops \ + -lower-affine \ + -arith-bufferize \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} -mlir-to-llvmir -o log.ll + ${CLANG} log.ll ${OPT_FLAG} \ + -L${MLIR_LIB} -lmlir_runner_utils -lmlir_c_runner_utils \ + -o a.out + @LD_LIBRARY_PATH=${MLIR_LIB} ./a.out + +conv2d-nhwc-fhwc-vec-run: @${BUDDY_OPT} ./conv2d-nhwc-fhwc-seq.mlir \ -convert-vector-to-scf \ -lower-affine \ @@ -140,7 +173,7 @@ conv2d-nhwc-fhwc-seq-run: ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} -conv2d-nhwc-fhwc-seq-aot: +conv2d-nhwc-fhwc-vec-aot: @${BUDDY_OPT} ./conv2d-nhwc-fhwc-seq.mlir \ -convert-vector-to-scf \ -lower-affine \ From b5010c7ebddd53189cf5a4bfe6b26bbac39365e7 
Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Fri, 22 Nov 2024 11:38:26 +0000 Subject: [PATCH 03/11] [midend/lib/Conversion/ConvVectorization] Fix cond2dnhwcfhwc vectorization pass and examples. --- .../conv2d-nhwc-fhwc-vec.mlir | 42 +++--- examples/BuddyConvolution/makefile | 4 +- .../Conv2dNhwcFhwcVectorization.cpp | 131 ++++++++++-------- .../conv2d-nhwc-fhwc-max-vectorization.mlir | 58 ++++---- 4 files changed, 132 insertions(+), 103 deletions(-) diff --git a/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir b/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir index 2947f8c173..5031d814eb 100644 --- a/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir +++ b/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir @@ -46,7 +46,8 @@ module { affine.for %idx_h_o = %c0 to %h_o { affine.for %idx_w_o = %c0 to %w_o { affine.for %idx_f = %c0 to %f { - %iter_idx, %iter_value = scf.for %idx_c = %c0 to %upbound step %vl_step iter_args(%iter_init = %c0, %iter_value0 = %f0) -> (index, f32) { + %iter_idx, %iter_value = scf.for %idx_c = %c0 to %upbound step %vl_step + iter_args(%iter_init = %c0, %iter_value0 = %f0) -> (index, f32) { %tmp8 = affine.for %idx_h_k = %c0 to %h_k iter_args(%tmp9 = %iter_value0) -> (f32) { %tmp6 = affine.for %idx_w_k = %c0 to %w_k iter_args(%tmp7 = %tmp9) -> (f32) { %in_iter_h = arith.addi %idx_h_k, %idx_h_o : index @@ -65,22 +66,26 @@ module { } // Compute the tail size and Process the remaining elements // using masked vector operations. 
- %tail_size = arith.subi %c, %iter_idx : index - %mask = vector.create_mask %tail_size : vector<8xi1> - %tmp8 = affine.for %idx_h_k = %c0 to %h_k iter_args(%tmp9 = %iter_value) -> (f32) { - %tmp6 = affine.for %idx_w_k = %c0 to %w_k iter_args(%tmp7 = %tmp9) -> (f32) { - %in_iter_h = arith.addi %idx_h_k, %idx_h_o : index - %in_iter_w = arith.addi %idx_w_k, %idx_w_o : index - %input_vec = vector.maskedload %arg0[%idx_n, %in_iter_h, %in_iter_w, %iter_idx], %mask, %vec0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> - %kernel_vec = vector.maskedload %arg1[%idx_f, %idx_h_k, %idx_w_k, %iter_idx], %mask, %vec0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> - %tmp_vec0 = arith.mulf %kernel_vec, %input_vec : vector<8xf32> - %tmp_val = vector.reduction , %tmp_vec0 : vector<8xf32> into f32 - %tmp4 = arith.addf %tmp7, %tmp_val : f32 - affine.yield %tmp4 : f32 + %result = scf.for %idx_c = %iter_idx to %c step %vl_step + iter_args(%tmp10 = %iter_value) -> (f32){ + %tail_size = arith.subi %c, %iter_idx : index + %mask = vector.create_mask %tail_size : vector<8xi1> + %tmp8 = affine.for %idx_h_k = %c0 to %h_k iter_args(%tmp9 = %tmp10) -> (f32) { + %tmp6 = affine.for %idx_w_k = %c0 to %w_k iter_args(%tmp7 = %tmp9) -> (f32) { + %in_iter_h = arith.addi %idx_h_k, %idx_h_o : index + %in_iter_w = arith.addi %idx_w_k, %idx_w_o : index + %input_vec = vector.maskedload %arg0[%idx_n, %in_iter_h, %in_iter_w, %iter_idx], %mask, %vec0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> + %kernel_vec = vector.maskedload %arg1[%idx_f, %idx_h_k, %idx_w_k, %iter_idx], %mask, %vec0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> + %tmp_vec0 = arith.mulf %kernel_vec, %input_vec : vector<8xf32> + %tmp_val = vector.reduction , %tmp_vec0 : vector<8xf32> into f32 + %tmp4 = arith.addf %tmp7, %tmp_val : f32 + affine.yield %tmp4 : f32 + } + affine.yield %tmp6 : f32 } - affine.yield %tmp6 : f32 + scf.yield %tmp8 : f32 } - memref.store %tmp8, %arg2[%idx_n, %idx_h_o, %idx_w_o, 
%idx_f] : memref + memref.store %result, %arg2[%idx_n, %idx_h_o, %idx_w_o, %idx_f] : memref } } } @@ -116,9 +121,13 @@ module { %c6 = arith.constant 6 : index %c8 = arith.constant 8 : index %c12 = arith.constant 12 : index - %c16 = arith.constant 15 : index + %c16 = arith.constant 16 : index %c24 = arith.constant 24 : index %c28 = arith.constant 28 : index + + // %v0 = call @alloc_f32(%c1, %c12, %c12, %c6, %f2) : (index, index, index, index, f32) -> memref + // %v1 = call @alloc_f32(%c16, %c5, %c5, %c6, %f3) : (index, index, index, index, f32) -> memref + // %v2 = call @alloc_f32(%c1, %c8, %c8, %c16, %f0) : (index, index, index, index, f32) -> memref %v0 = call @alloc_f32(%c1, %c28, %c28, %c1, %f2) : (index, index, index, index, f32) -> memref %v1 = call @alloc_f32(%c6, %c5, %c5, %c1, %f3) : (index, index, index, index, f32) -> memref @@ -144,7 +153,6 @@ module { memref.dealloc %v0 : memref memref.dealloc %v1 : memref memref.dealloc %v2 : memref - return } } diff --git a/examples/BuddyConvolution/makefile b/examples/BuddyConvolution/makefile index 4e47d3368b..42cf9d6313 100644 --- a/examples/BuddyConvolution/makefile +++ b/examples/BuddyConvolution/makefile @@ -160,7 +160,7 @@ conv2d-nhwc-fhwc-vectorization-aot: @LD_LIBRARY_PATH=${MLIR_LIB} ./a.out conv2d-nhwc-fhwc-vec-run: - @${BUDDY_OPT} ./conv2d-nhwc-fhwc-seq.mlir \ + @${BUDDY_OPT} ./conv2d-nhwc-fhwc-vec.mlir \ -convert-vector-to-scf \ -lower-affine \ -arith-bufferize \ @@ -174,7 +174,7 @@ conv2d-nhwc-fhwc-vec-run: -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} conv2d-nhwc-fhwc-vec-aot: - @${BUDDY_OPT} ./conv2d-nhwc-fhwc-seq.mlir \ + @${BUDDY_OPT} ./conv2d-nhwc-fhwc-vec.mlir \ -convert-vector-to-scf \ -lower-affine \ -arith-bufferize \ diff --git a/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp index 3e06833849..8bb94695cd 100644 --- 
a/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp +++ b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp @@ -142,58 +142,68 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { [&](OpBuilder &builder, Location loc, Value iv1, ValueRange itrArgs1) { Value inputHeight = - builder.create(loc, ivs[1], iv0); + builder.create(loc, ivs[1], + iv0); Value inputWidth = - builder.create(loc, ivs[2], iv1); - Value inputVector = builder.create( - loc, vectorTy, input, - ValueRange{ivs[0], inputHeight, inputWidth, - iv}); - Value kernelVector = builder.create( - loc, vectorTy, kernel, - ValueRange{ivs[3], iv0, iv1, iv}); + builder.create(loc, ivs[2], + iv1); + Value inputVector = + builder.create( + loc, vectorTy, input, + ValueRange{ivs[0], inputHeight, inputWidth, + iv}); + Value kernelVector = + builder.create( + loc, vectorTy, kernel, + ValueRange{ivs[3], iv0, iv1, iv}); // FMA Value resultVal; if (auto ty = llvm::dyn_cast(elementTy)) { - Value tmpVec = builder.create( - loc, inputVector, kernelVector); + Value tmpVec = + builder.create( + loc, inputVector, kernelVector); Value tmpVal = builder.create( - loc, vector::CombiningKind::ADD, tmpVec, - ::mlir::arith::FastMathFlags::none); - resultVal = builder.create( - loc, tmpVal, itrArgs1[0]); + loc, vector::CombiningKind::ADD, tmpVec, ::mlir::arith::FastMathFlags::none); + resultVal = + builder.create( + loc, tmpVal, itrArgs1[0]); } else { - Value tmpVec = builder.create( - loc, inputVector, kernelVector); + Value tmpVec = + builder.create( + loc, inputVector, kernelVector); Value tmpVal = builder.create( - loc, vector::CombiningKind::ADD, tmpVec, - ::mlir::arith::FastMathFlags::none); - resultVal = builder.create( - loc, tmpVal, itrArgs1[0]); + loc, vector::CombiningKind::ADD, tmpVec, ::mlir::arith::FastMathFlags::none); + resultVal = + builder.create( + loc, tmpVal, itrArgs1[0]); } - builder.create(loc, - resultVal); + builder.create( + loc, resultVal); }); 
nestedBuilder.create( nestedLoc, tmp1.getResult(0)); }); Value idx = builder.create(loc, iv, vl_step); - builder.create( - loc, ValueRange{idx, tmp0.getResult(0)}); + builder.create(loc, ValueRange{idx, tmp0.getResult(0)}); }); // Compute the tail size and Process the remaining elements // using masked vector operations. - Value idx = iter_val.getResult(0); - Value tailSize = builder.create(loc, channels, idx); + auto result = builder.create( + loc, c0, channels, /*Step=*/vl_step, ValueRange{iter_val.getResult(1)}, + [&](OpBuilder &builder, Location loc, Value iv, + ValueRange itrArgs) { + Value idx = iter_val.getResult(0); + Value tailSize = + builder.create(loc, channels, idx); Value tailMask = builder.create(loc, vectorMaskTy, tailSize); auto tmp0 = builder.create( loc, ValueRange{c0}, builder.getDimIdentityMap(), ValueRange{height_k}, builder.getDimIdentityMap(), - /*Step=*/1, ValueRange{iter_val.getResult(1)}, + /*Step=*/1, ValueRange{itrArgs[0]}, [&](OpBuilder &builder, Location loc, Value iv0, ValueRange itrArgs0) { auto tmp1 = builder.create( @@ -203,43 +213,52 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { [&](OpBuilder &builder, Location loc, Value iv1, ValueRange itrArgs1) { Value inputHeight = - builder.create(loc, ivs[1], iv0); + builder.create(loc, ivs[1], + iv0); Value inputWidth = - builder.create(loc, ivs[2], iv1); + builder.create(loc, ivs[2], + iv1); Value inputVec = builder.create( - loc, vectorTy, input, - ValueRange{ivs[0], inputHeight, inputWidth, idx}, - tailMask, passThroughVec); + loc, vectorTy, input, + ValueRange{ivs[0], inputHeight, inputWidth, idx}, + tailMask, passThroughVec); Value kernelVec = builder.create( - loc, vectorTy, kernel, - ValueRange{ivs[3], iv0, iv1, idx}, tailMask, - passThroughVec); + loc, vectorTy, kernel, + ValueRange{ivs[3], iv0, iv1, idx}, + tailMask, passThroughVec); // FMA Value resultVal; - if (auto ty = llvm::dyn_cast(elementTy)) { - Value tmpVec = builder.create( - loc, inputVec, 
kernelVec); - Value tmpVal = builder.create( - loc, vector::CombiningKind::ADD, tmpVec, - ::mlir::arith::FastMathFlags::none); - resultVal = builder.create(loc, tmpVal, - itrArgs1[0]); + if (auto ty = + llvm::dyn_cast(elementTy)) { + Value tmpVec = + builder.create( + loc, inputVec, kernelVec); + Value tmpVal = + builder.create( + loc, vector::CombiningKind::ADD, tmpVec, ::mlir::arith::FastMathFlags::none); + resultVal = + builder.create( + loc, tmpVal, itrArgs1[0]); } else { - Value tmpVec = builder.create( - loc, inputVec, kernelVec); - Value tmpVal = builder.create( - loc, vector::CombiningKind::ADD, tmpVec, - ::mlir::arith::FastMathFlags::none); - resultVal = builder.create(loc, tmpVal, - itrArgs1[0]); + Value tmpVec = + builder.create( + loc, inputVec, kernelVec); + Value tmpVal = + builder.create( + loc, vector::CombiningKind::ADD, tmpVec, ::mlir::arith::FastMathFlags::none); + resultVal = + builder.create( + loc, tmpVal, itrArgs1[0]); } - builder.create(loc, resultVal); + builder.create( + loc, resultVal); }); - builder.create(loc, tmp1.getResult(0)); + builder.create( + loc, tmp1.getResult(0)); }); - builder.create( - loc, tmp0.getResult(0), output, - ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]}); + builder.create(loc, tmp0.getResult(0)); + }); + builder.create(loc, result.getResult(0), output, ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]}); }); // Remove the origin convolution operation. 
rewriter.eraseOp(op); diff --git a/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir b/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir index d5e42cdece..7f35332611 100644 --- a/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir +++ b/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir @@ -7,43 +7,45 @@ // CHECK-NEXT: affine.for %arg5 = #map(%c0) to #map(%dim_5) { // CHECK-NEXT: affine.for %arg6 = #map(%c0) to #map(%dim_1) { // CHECK-NEXT: %3:2 = scf.for %arg7 = %c0 to %2 step %c8 iter_args(%arg8 = %c0, %arg9 = %cst) -> (index, f32) { -// CHECK-NEXT: %7 = affine.for %arg10 = #map(%c0) to #map(%dim_2) iter_args(%arg11 = %arg9) -> (f32) { -// CHECK-NEXT: %9 = affine.for %arg12 = #map(%c0) to #map(%dim_3) iter_args(%arg13 = %arg11) -> (f32) { -// CHECK-NEXT: %10 = arith.addi %arg4, %arg10 : index -// CHECK-NEXT: %11 = arith.addi %arg5, %arg12 : index -// CHECK-NEXT: %12 = vector.load %arg0[%arg3, %10, %11, %arg7] : memref, vector<8xf32> -// CHECK-NEXT: %13 = vector.load %arg1[%arg6, %arg10, %arg12, %arg7] : memref, vector<8xf32> -// CHECK-NEXT: %14 = arith.mulf %12, %13 : vector<8xf32> -// CHECK-NEXT: %15 = vector.reduction , %14 : vector<8xf32> into f32 -// CHECK-NEXT: %16 = arith.addf %15, %arg13 : f32 -// CHECK-NEXT: affine.yield %16 : f32 +// CHECK-NEXT: %5 = affine.for %arg10 = #map(%c0) to #map(%dim_2) iter_args(%arg11 = %arg9) -> (f32) { +// CHECK-NEXT: %7 = affine.for %arg12 = #map(%c0) to #map(%dim_3) iter_args(%arg13 = %arg11) -> (f32) { +// CHECK-NEXT: %8 = arith.addi %arg4, %arg10 : index +// CHECK-NEXT: %9 = arith.addi %arg5, %arg12 : index +// CHECK-NEXT: %10 = vector.load %arg0[%arg3, %8, %9, %arg7] : memref, vector<8xf32> +// CHECK-NEXT: %11 = vector.load %arg1[%arg6, %arg10, %arg12, %arg7] : memref, vector<8xf32> +// CHECK-NEXT: %12 = arith.mulf %10, %11 : vector<8xf32> +// CHECK-NEXT: %13 = vector.reduction , %12 : vector<8xf32> into f32 +// CHECK-NEXT: %14 = arith.addf %13, %arg13 : f32 +// CHECK-NEXT: affine.yield %14 : 
f32 // CHECK-NEXT: } -// CHECK-NEXT: affine.yield %9 : f32 +// CHECK-NEXT: affine.yield %7 : f32 // CHECK-NEXT: } -// CHECK-NEXT: %8 = arith.addi %arg7, %c8 : index -// CHECK-NEXT: scf.yield %8, %7 : index, f32 +// CHECK-NEXT: %6 = arith.addi %arg7, %c8 : index +// CHECK-NEXT: scf.yield %6, %5 : index, f32 // CHECK-NEXT: } -// CHECK-NEXT: %4 = arith.subi %dim_0, %3#0 : index -// CHECK-NEXT: %5 = vector.create_mask %4 : vector<8xi1> -// CHECK-NEXT: %6 = affine.for %arg7 = #map(%c0) to #map(%dim_2) iter_args(%arg8 = %3#1) -> (f32) { -// CHECK-NEXT: %7 = affine.for %arg9 = #map(%c0) to #map(%dim_3) iter_args(%arg10 = %arg8) -> (f32) { -// CHECK-NEXT: %8 = arith.addi %arg4, %arg7 : index -// CHECK-NEXT: %9 = arith.addi %arg5, %arg9 : index -// CHECK-NEXT: %10 = vector.maskedload %arg0[%arg3, %8, %9, %3#0], %5, %0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> -// CHECK-NEXT: %11 = vector.maskedload %arg1[%arg6, %arg7, %arg9, %3#0], %5, %0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> -// CHECK-NEXT: %12 = arith.mulf %10, %11 : vector<8xf32> -// CHECK-NEXT: %13 = vector.reduction , %12 : vector<8xf32> into f32 -// CHECK-NEXT: %14 = arith.addf %13, %arg10 : f32 -// CHECK-NEXT: affine.yield %14 : f32 +// CHECK-NEXT: %4 = scf.for %arg7 = %c0 to %dim_0 step %c8 iter_args(%arg8 = %3#1) -> (f32) { +// CHECK-NEXT: %5 = arith.subi %dim_0, %3#0 : index +// CHECK-NEXT: %6 = vector.create_mask %5 : vector<8xi1> +// CHECK-NEXT: %7 = affine.for %arg9 = #map(%c0) to #map(%dim_2) iter_args(%arg10 = %arg8) -> (f32) { +// CHECK-NEXT: %8 = affine.for %arg11 = #map(%c0) to #map(%dim_3) iter_args(%arg12 = %arg10) -> (f32) { +// CHECK-NEXT: %9 = arith.addi %arg4, %arg9 : index +// CHECK-NEXT: %10 = arith.addi %arg5, %arg11 : index +// CHECK-NEXT: %11 = vector.maskedload %arg0[%arg3, %9, %10, %3#0], %6, %0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK-NEXT: %12 = vector.maskedload %arg1[%arg6, %arg9, %arg11, %3#0], %6, %0 : memref, vector<8xi1>, 
vector<8xf32> into vector<8xf32> +// CHECK-NEXT: %13 = arith.mulf %11, %12 : vector<8xf32> +// CHECK-NEXT: %14 = vector.reduction , %13 : vector<8xf32> into f32 +// CHECK-NEXT: %15 = arith.addf %14, %arg12 : f32 +// CHECK-NEXT: affine.yield %15 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: affine.yield %8 : f32 // CHECK-NEXT: } -// CHECK-NEXT: affine.yield %7 : f32 +// CHECK-NEXT: scf.yield %7 : f32 // CHECK-NEXT: } -// CHECK-NEXT: memref.store %6, %arg2[%arg3, %arg4, %arg5, %arg6] : memref +// CHECK-NEXT: memref.store %4, %arg2[%arg3, %arg4, %arg5, %arg6] : memref // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } - func.func @conv_2d_nhwc_fhwc(%arg0: memref, %arg1: memref, %arg2: memref) { linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins (%arg0, %arg1: memref, memref) From e3217eb1287989edf0277d5d77f4d3cc07a0cd5a Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Sun, 24 Nov 2024 05:26:37 +0000 Subject: [PATCH 04/11] [midend/lib/Conversion/ConvVectorization] Fix conv2dnhwcfhwc vectorization pass and examples.
--- .../conv2d-nhwc-fhwc-vec.mlir | 3 +- .../Conv2dNhwcFhwcVectorization.cpp | 187 +++++++++--------- .../conv2d-nhwc-fhwc-max-vectorization.mlir | 60 +++--- 3 files changed, 126 insertions(+), 124 deletions(-) diff --git a/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir b/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir index 5031d814eb..4f31ac9ec0 100644 --- a/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir +++ b/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir @@ -46,8 +46,9 @@ module { affine.for %idx_h_o = %c0 to %h_o { affine.for %idx_w_o = %c0 to %w_o { affine.for %idx_f = %c0 to %f { + %tmp_result = memref.load %arg2[%idx_n, %idx_h_o, %idx_w_o, %idx_f] : memref %iter_idx, %iter_value = scf.for %idx_c = %c0 to %upbound step %vl_step - iter_args(%iter_init = %c0, %iter_value0 = %f0) -> (index, f32) { + iter_args(%iter_init = %c0, %iter_value0 = %tmp_result) -> (index, f32) { %tmp8 = affine.for %idx_h_k = %c0 to %h_k iter_args(%tmp9 = %iter_value0) -> (f32) { %tmp6 = affine.for %idx_w_k = %c0 to %w_k iter_args(%tmp7 = %tmp9) -> (f32) { %in_iter_h = arith.addi %idx_h_k, %idx_h_o : index diff --git a/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp index 8bb94695cd..9c62dd29ba 100644 --- a/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp +++ b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp @@ -124,9 +124,12 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { affine::buildAffineLoopNest( rewriter, loc, lowerBounds, uperBounds, steps, [&](OpBuilder &builder, Location loc, ValueRange ivs) { + Value tmp_result = builder.create( + loc, elementTy, output, + ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]}); // Create vecsize mining loop. 
auto iter_val = builder.create( - loc, c0, upperBound, /*Step=*/vl_step, ValueRange{c0, zero}, + loc, c0, upperBound, /*Step=*/vl_step, ValueRange{c0, tmp_result}, [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, ValueRange itrArgs) { auto tmp0 = nestedBuilder.create( @@ -142,123 +145,119 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { [&](OpBuilder &builder, Location loc, Value iv1, ValueRange itrArgs1) { Value inputHeight = - builder.create(loc, ivs[1], - iv0); + builder.create(loc, ivs[1], iv0); Value inputWidth = - builder.create(loc, ivs[2], - iv1); - Value inputVector = - builder.create( - loc, vectorTy, input, - ValueRange{ivs[0], inputHeight, inputWidth, - iv}); - Value kernelVector = - builder.create( - loc, vectorTy, kernel, - ValueRange{ivs[3], iv0, iv1, iv}); + builder.create(loc, ivs[2], iv1); + Value inputVector = builder.create( + loc, vectorTy, input, + ValueRange{ivs[0], inputHeight, inputWidth, + iv}); + Value kernelVector = builder.create( + loc, vectorTy, kernel, + ValueRange{ivs[3], iv0, iv1, iv}); // FMA Value resultVal; if (auto ty = llvm::dyn_cast(elementTy)) { - Value tmpVec = - builder.create( - loc, inputVector, kernelVector); + Value tmpVec = builder.create( + loc, inputVector, kernelVector); Value tmpVal = builder.create( - loc, vector::CombiningKind::ADD, tmpVec, ::mlir::arith::FastMathFlags::none); - resultVal = - builder.create( - loc, tmpVal, itrArgs1[0]); + loc, vector::CombiningKind::ADD, tmpVec, + ::mlir::arith::FastMathFlags::none); + resultVal = builder.create( + loc, tmpVal, itrArgs1[0]); } else { - Value tmpVec = - builder.create( - loc, inputVector, kernelVector); + Value tmpVec = builder.create( + loc, inputVector, kernelVector); Value tmpVal = builder.create( - loc, vector::CombiningKind::ADD, tmpVec, ::mlir::arith::FastMathFlags::none); - resultVal = - builder.create( - loc, tmpVal, itrArgs1[0]); + loc, vector::CombiningKind::ADD, tmpVec, + ::mlir::arith::FastMathFlags::none); + 
resultVal = builder.create( + loc, tmpVal, itrArgs1[0]); } - builder.create( - loc, resultVal); + builder.create(loc, + resultVal); }); nestedBuilder.create( nestedLoc, tmp1.getResult(0)); }); - Value idx = builder.create(loc, iv, vl_step); - builder.create(loc, ValueRange{idx, tmp0.getResult(0)}); + Value idx = + builder.create(loc, itrArgs[0], vl_step); + builder.create( + loc, ValueRange{idx, tmp0.getResult(0)}); }); // Compute the tail size and Process the remaining elements // using masked vector operations. auto result = builder.create( - loc, c0, channels, /*Step=*/vl_step, ValueRange{iter_val.getResult(1)}, + loc, iter_val.getResult(0), channels, /*Step=*/vl_step, + ValueRange{iter_val.getResult(1)}, [&](OpBuilder &builder, Location loc, Value iv, ValueRange itrArgs) { - Value idx = iter_val.getResult(0); - Value tailSize = - builder.create(loc, channels, idx); - Value tailMask = - builder.create(loc, vectorMaskTy, tailSize); - auto tmp0 = builder.create( - loc, ValueRange{c0}, builder.getDimIdentityMap(), - ValueRange{height_k}, builder.getDimIdentityMap(), - /*Step=*/1, ValueRange{itrArgs[0]}, - [&](OpBuilder &builder, Location loc, Value iv0, - ValueRange itrArgs0) { - auto tmp1 = builder.create( + Value idx = iter_val.getResult(0); + Value tailSize = + builder.create(loc, channels, idx); + Value tailMask = + builder.create(loc, vectorMaskTy, tailSize); + auto tmp0 = builder.create( loc, ValueRange{c0}, builder.getDimIdentityMap(), - ValueRange{width_k}, builder.getDimIdentityMap(), - /*Step=*/1, ValueRange{itrArgs0[0]}, - [&](OpBuilder &builder, Location loc, Value iv1, - ValueRange itrArgs1) { - Value inputHeight = - builder.create(loc, ivs[1], - iv0); - Value inputWidth = - builder.create(loc, ivs[2], - iv1); - Value inputVec = builder.create( - loc, vectorTy, input, - ValueRange{ivs[0], inputHeight, inputWidth, idx}, - tailMask, passThroughVec); - Value kernelVec = builder.create( - loc, vectorTy, kernel, - ValueRange{ivs[3], iv0, iv1, idx}, - 
tailMask, passThroughVec); - // FMA - Value resultVal; - if (auto ty = - llvm::dyn_cast(elementTy)) { - Value tmpVec = - builder.create( - loc, inputVec, kernelVec); - Value tmpVal = - builder.create( - loc, vector::CombiningKind::ADD, tmpVec, ::mlir::arith::FastMathFlags::none); - resultVal = - builder.create( - loc, tmpVal, itrArgs1[0]); - } else { - Value tmpVec = - builder.create( - loc, inputVec, kernelVec); - Value tmpVal = - builder.create( - loc, vector::CombiningKind::ADD, tmpVec, ::mlir::arith::FastMathFlags::none); - resultVal = - builder.create( - loc, tmpVal, itrArgs1[0]); - } - builder.create( - loc, resultVal); + ValueRange{height_k}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{itrArgs[0]}, + [&](OpBuilder &builder, Location loc, Value iv0, + ValueRange itrArgs0) { + auto tmp1 = builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{width_k}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{itrArgs0[0]}, + [&](OpBuilder &builder, Location loc, Value iv1, + ValueRange itrArgs1) { + Value inputHeight = + builder.create(loc, ivs[1], iv0); + Value inputWidth = + builder.create(loc, ivs[2], iv1); + Value inputVec = builder.create( + loc, vectorTy, input, + ValueRange{ivs[0], inputHeight, inputWidth, + idx}, + tailMask, passThroughVec); + Value kernelVec = builder.create( + loc, vectorTy, kernel, + ValueRange{ivs[3], iv0, iv1, idx}, tailMask, + passThroughVec); + // FMA + Value resultVal; + if (auto ty = + llvm::dyn_cast(elementTy)) { + Value tmpVec = builder.create( + loc, inputVec, kernelVec); + Value tmpVal = + builder.create( + loc, vector::CombiningKind::ADD, tmpVec, + ::mlir::arith::FastMathFlags::none); + resultVal = builder.create( + loc, tmpVal, itrArgs1[0]); + } else { + Value tmpVec = builder.create( + loc, inputVec, kernelVec); + Value tmpVal = + builder.create( + loc, vector::CombiningKind::ADD, tmpVec, + ::mlir::arith::FastMathFlags::none); + resultVal = builder.create( + loc, tmpVal, itrArgs1[0]); + } 
+ builder.create(loc, + resultVal); + }); + builder.create(loc, + tmp1.getResult(0)); }); - builder.create( - loc, tmp1.getResult(0)); + builder.create(loc, tmp0.getResult(0)); }); - builder.create(loc, tmp0.getResult(0)); - }); - builder.create(loc, result.getResult(0), output, ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]}); + builder.create( + loc, result.getResult(0), output, + ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]}); }); // Remove the origin convolution operation. rewriter.eraseOp(op); diff --git a/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir b/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir index 7f35332611..d9def90002 100644 --- a/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir +++ b/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir @@ -6,46 +6,48 @@ // CHECK-NEXT: affine.for %arg4 = #map(%c0) to #map(%dim_4) { // CHECK-NEXT: affine.for %arg5 = #map(%c0) to #map(%dim_5) { // CHECK-NEXT: affine.for %arg6 = #map(%c0) to #map(%dim_1) { -// CHECK-NEXT: %3:2 = scf.for %arg7 = %c0 to %2 step %c8 iter_args(%arg8 = %c0, %arg9 = %cst) -> (index, f32) { -// CHECK-NEXT: %5 = affine.for %arg10 = #map(%c0) to #map(%dim_2) iter_args(%arg11 = %arg9) -> (f32) { -// CHECK-NEXT: %7 = affine.for %arg12 = #map(%c0) to #map(%dim_3) iter_args(%arg13 = %arg11) -> (f32) { -// CHECK-NEXT: %8 = arith.addi %arg4, %arg10 : index -// CHECK-NEXT: %9 = arith.addi %arg5, %arg12 : index -// CHECK-NEXT: %10 = vector.load %arg0[%arg3, %8, %9, %arg7] : memref, vector<8xf32> -// CHECK-NEXT: %11 = vector.load %arg1[%arg6, %arg10, %arg12, %arg7] : memref, vector<8xf32> -// CHECK-NEXT: %12 = arith.mulf %10, %11 : vector<8xf32> -// CHECK-NEXT: %13 = vector.reduction , %12 : vector<8xf32> into f32 -// CHECK-NEXT: %14 = arith.addf %13, %arg13 : f32 -// CHECK-NEXT: affine.yield %14 : f32 -// CHECK-NEXT: } -// CHECK-NEXT: affine.yield %7 : f32 -// CHECK-NEXT: } -// CHECK-NEXT: %6 = arith.addi %arg7, %c8 : index -// CHECK-NEXT: scf.yield %6, %5 : index, f32 -// 
CHECK-NEXT: } -// CHECK-NEXT: %4 = scf.for %arg7 = %c0 to %dim_0 step %c8 iter_args(%arg8 = %3#1) -> (f32) { -// CHECK-NEXT: %5 = arith.subi %dim_0, %3#0 : index -// CHECK-NEXT: %6 = vector.create_mask %5 : vector<8xi1> -// CHECK-NEXT: %7 = affine.for %arg9 = #map(%c0) to #map(%dim_2) iter_args(%arg10 = %arg8) -> (f32) { -// CHECK-NEXT: %8 = affine.for %arg11 = #map(%c0) to #map(%dim_3) iter_args(%arg12 = %arg10) -> (f32) { -// CHECK-NEXT: %9 = arith.addi %arg4, %arg9 : index -// CHECK-NEXT: %10 = arith.addi %arg5, %arg11 : index -// CHECK-NEXT: %11 = vector.maskedload %arg0[%arg3, %9, %10, %3#0], %6, %0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> -// CHECK-NEXT: %12 = vector.maskedload %arg1[%arg6, %arg9, %arg11, %3#0], %6, %0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK-NEXT: %3 = memref.load %arg2[%arg3, %arg4, %arg5, %arg6] : memref +// CHECK-NEXT: %4:2 = scf.for %arg7 = %c0 to %2 step %c8 iter_args(%arg8 = %c0, %arg9 = %3) -> (index, f32) { +// CHECK-NEXT: %6 = affine.for %arg10 = #map(%c0) to #map(%dim_2) iter_args(%arg11 = %arg9) -> (f32) { +// CHECK-NEXT: %8 = affine.for %arg12 = #map(%c0) to #map(%dim_3) iter_args(%arg13 = %arg11) -> (f32) { +// CHECK-NEXT: %9 = arith.addi %arg4, %arg10 : index +// CHECK-NEXT: %10 = arith.addi %arg5, %arg12 : index +// CHECK-NEXT: %11 = vector.load %arg0[%arg3, %9, %10, %arg7] : memref, vector<8xf32> +// CHECK-NEXT: %12 = vector.load %arg1[%arg6, %arg10, %arg12, %arg7] : memref, vector<8xf32> // CHECK-NEXT: %13 = arith.mulf %11, %12 : vector<8xf32> // CHECK-NEXT: %14 = vector.reduction , %13 : vector<8xf32> into f32 -// CHECK-NEXT: %15 = arith.addf %14, %arg12 : f32 +// CHECK-NEXT: %15 = arith.addf %14, %arg13 : f32 // CHECK-NEXT: affine.yield %15 : f32 // CHECK-NEXT: } // CHECK-NEXT: affine.yield %8 : f32 // CHECK-NEXT: } -// CHECK-NEXT: scf.yield %7 : f32 +// CHECK-NEXT: %7 = arith.addi %arg8, %c8 : index +// CHECK-NEXT: scf.yield %7, %6 : index, f32 // CHECK-NEXT: } -// CHECK-NEXT: 
memref.store %4, %arg2[%arg3, %arg4, %arg5, %arg6] : memref +// CHECK-NEXT: %5 = scf.for %arg7 = %4#0 to %dim_0 step %c8 iter_args(%arg8 = %4#1) -> (f32) { +// CHECK-NEXT: %6 = arith.subi %dim_0, %4#0 : index +// CHECK-NEXT: %7 = vector.create_mask %6 : vector<8xi1> +// CHECK-NEXT: %8 = affine.for %arg9 = #map(%c0) to #map(%dim_2) iter_args(%arg10 = %arg8) -> (f32) { +// CHECK-NEXT: %9 = affine.for %arg11 = #map(%c0) to #map(%dim_3) iter_args(%arg12 = %arg10) -> (f32) { +// CHECK-NEXT: %10 = arith.addi %arg4, %arg9 : index +// CHECK-NEXT: %11 = arith.addi %arg5, %arg11 : index +// CHECK-NEXT: %12 = vector.maskedload %arg0[%arg3, %10, %11, %4#0], %7, %0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK-NEXT: %13 = vector.maskedload %arg1[%arg6, %arg9, %arg11, %4#0], %7, %0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK-NEXT: %14 = arith.mulf %12, %13 : vector<8xf32> +// CHECK-NEXT: %15 = vector.reduction , %14 : vector<8xf32> into f32 +// CHECK-NEXT: %16 = arith.addf %15, %arg12 : f32 +// CHECK-NEXT: affine.yield %16 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: affine.yield %9 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %8 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: memref.store %5, %arg2[%arg3, %arg4, %arg5, %arg6] : memref // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } + func.func @conv_2d_nhwc_fhwc(%arg0: memref, %arg1: memref, %arg2: memref) { linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins (%arg0, %arg1: memref, memref) From fa939c65512770322229e95b6e6fff36c80fdeab Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Mon, 25 Nov 2024 06:02:43 +0000 Subject: [PATCH 05/11] [midend/lib/Conversion/ConvVectorization] Fix conv2dnhwcfhwc vectorization pass and examples.
--- .../conv2d-nhwc-fhwc-vec.mlir | 13 ++--- .../Conv2dNhwcFhwcVectorization.cpp | 33 +++++++----- .../conv2d-nhwc-fhwc-max-vectorization.mlir | 50 ++++++++++--------- 3 files changed, 53 insertions(+), 43 deletions(-) diff --git a/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir b/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir index 4f31ac9ec0..de33973c41 100644 --- a/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir +++ b/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir @@ -67,11 +67,11 @@ module { } // Compute the tail size and Process the remaining elements // using masked vector operations. - %result = scf.for %idx_c = %iter_idx to %c step %vl_step - iter_args(%tmp10 = %iter_value) -> (f32){ - %tail_size = arith.subi %c, %iter_idx : index + %tail_size = arith.subi %c, %iter_idx : index + %3 = arith.cmpi sgt, %tail_size, %c0 : index + scf.if %3 { %mask = vector.create_mask %tail_size : vector<8xi1> - %tmp8 = affine.for %idx_h_k = %c0 to %h_k iter_args(%tmp9 = %tmp10) -> (f32) { + %tmp8 = affine.for %idx_h_k = %c0 to %h_k iter_args(%tmp9 = %iter_value) -> (f32) { %tmp6 = affine.for %idx_w_k = %c0 to %w_k iter_args(%tmp7 = %tmp9) -> (f32) { %in_iter_h = arith.addi %idx_h_k, %idx_h_o : index %in_iter_w = arith.addi %idx_w_k, %idx_w_o : index @@ -84,9 +84,10 @@ module { } affine.yield %tmp6 : f32 } - scf.yield %tmp8 : f32 + memref.store %tmp8, %arg2[%idx_n, %idx_h_o, %idx_w_o, %idx_f] : memref + } else { + memref.store %iter_value, %arg2[%idx_n, %idx_h_o, %idx_w_o, %idx_f] : memref } - memref.store %result, %arg2[%idx_n, %idx_h_o, %idx_w_o, %idx_f] : memref } } } diff --git a/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp index 9c62dd29ba..c7cef4e24a 100644 --- a/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp +++ b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp @@ -190,20 +190,21 @@ class 
Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { }); // Compute the tail size and Process the remaining elements // using masked vector operations. - auto result = builder.create( - loc, iter_val.getResult(0), channels, /*Step=*/vl_step, - ValueRange{iter_val.getResult(1)}, - [&](OpBuilder &builder, Location loc, Value iv, - ValueRange itrArgs) { - Value idx = iter_val.getResult(0); - Value tailSize = - builder.create(loc, channels, idx); + + Value idx = iter_val.getResult(0); + Value tailSize = builder.create(loc, channels, idx); + Value tailCond = rewriter.create( + loc, arith::CmpIPredicate::sge, tailSize, c0); + // If the current column does not reach the tail. + builder.create( + loc, tailCond, + [&](OpBuilder &builder, Location loc) { Value tailMask = builder.create(loc, vectorMaskTy, tailSize); auto tmp0 = builder.create( loc, ValueRange{c0}, builder.getDimIdentityMap(), ValueRange{height_k}, builder.getDimIdentityMap(), - /*Step=*/1, ValueRange{itrArgs[0]}, + /*Step=*/1, ValueRange{iter_val.getResult(1)}, [&](OpBuilder &builder, Location loc, Value iv0, ValueRange itrArgs0) { auto tmp1 = builder.create( @@ -253,11 +254,17 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { builder.create(loc, tmp1.getResult(0)); }); - builder.create(loc, tmp0.getResult(0)); + builder.create( + loc, tmp0.getResult(0), output, + ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]}); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, iter_val.getResult(1), output, + ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]}); + builder.create(loc); }); - builder.create( - loc, result.getResult(0), output, - ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]}); }); // Remove the origin convolution operation. 
rewriter.eraseOp(op); diff --git a/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir b/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir index d9def90002..a73ac1df47 100644 --- a/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir +++ b/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir @@ -8,41 +8,43 @@ // CHECK-NEXT: affine.for %arg6 = #map(%c0) to #map(%dim_1) { // CHECK-NEXT: %3 = memref.load %arg2[%arg3, %arg4, %arg5, %arg6] : memref // CHECK-NEXT: %4:2 = scf.for %arg7 = %c0 to %2 step %c8 iter_args(%arg8 = %c0, %arg9 = %3) -> (index, f32) { -// CHECK-NEXT: %6 = affine.for %arg10 = #map(%c0) to #map(%dim_2) iter_args(%arg11 = %arg9) -> (f32) { -// CHECK-NEXT: %8 = affine.for %arg12 = #map(%c0) to #map(%dim_3) iter_args(%arg13 = %arg11) -> (f32) { -// CHECK-NEXT: %9 = arith.addi %arg4, %arg10 : index -// CHECK-NEXT: %10 = arith.addi %arg5, %arg12 : index -// CHECK-NEXT: %11 = vector.load %arg0[%arg3, %9, %10, %arg7] : memref, vector<8xf32> -// CHECK-NEXT: %12 = vector.load %arg1[%arg6, %arg10, %arg12, %arg7] : memref, vector<8xf32> -// CHECK-NEXT: %13 = arith.mulf %11, %12 : vector<8xf32> -// CHECK-NEXT: %14 = vector.reduction , %13 : vector<8xf32> into f32 -// CHECK-NEXT: %15 = arith.addf %14, %arg13 : f32 -// CHECK-NEXT: affine.yield %15 : f32 +// CHECK-NEXT: %7 = affine.for %arg10 = #map(%c0) to #map(%dim_2) iter_args(%arg11 = %arg9) -> (f32) { +// CHECK-NEXT: %9 = affine.for %arg12 = #map(%c0) to #map(%dim_3) iter_args(%arg13 = %arg11) -> (f32) { +// CHECK-NEXT: %10 = arith.addi %arg4, %arg10 : index +// CHECK-NEXT: %11 = arith.addi %arg5, %arg12 : index +// CHECK-NEXT: %12 = vector.load %arg0[%arg3, %10, %11, %arg7] : memref, vector<8xf32> +// CHECK-NEXT: %13 = vector.load %arg1[%arg6, %arg10, %arg12, %arg7] : memref, vector<8xf32> +// CHECK-NEXT: %14 = arith.mulf %12, %13 : vector<8xf32> +// CHECK-NEXT: %15 = vector.reduction , %14 : vector<8xf32> into f32 +// CHECK-NEXT: %16 = arith.addf %15, %arg13 : f32 +// CHECK-NEXT: 
affine.yield %16 : f32 // CHECK-NEXT: } -// CHECK-NEXT: affine.yield %8 : f32 +// CHECK-NEXT: affine.yield %9 : f32 // CHECK-NEXT: } -// CHECK-NEXT: %7 = arith.addi %arg8, %c8 : index -// CHECK-NEXT: scf.yield %7, %6 : index, f32 +// CHECK-NEXT: %8 = arith.addi %arg8, %c8 : index +// CHECK-NEXT: scf.yield %8, %7 : index, f32 // CHECK-NEXT: } -// CHECK-NEXT: %5 = scf.for %arg7 = %4#0 to %dim_0 step %c8 iter_args(%arg8 = %4#1) -> (f32) { -// CHECK-NEXT: %6 = arith.subi %dim_0, %4#0 : index -// CHECK-NEXT: %7 = vector.create_mask %6 : vector<8xi1> -// CHECK-NEXT: %8 = affine.for %arg9 = #map(%c0) to #map(%dim_2) iter_args(%arg10 = %arg8) -> (f32) { -// CHECK-NEXT: %9 = affine.for %arg11 = #map(%c0) to #map(%dim_3) iter_args(%arg12 = %arg10) -> (f32) { -// CHECK-NEXT: %10 = arith.addi %arg4, %arg9 : index -// CHECK-NEXT: %11 = arith.addi %arg5, %arg11 : index +// CHECK-NEXT: %5 = arith.subi %dim_0, %4#0 : index +// CHECK-NEXT: %6 = arith.cmpi sge, %5, %c0 : index +// CHECK-NEXT: scf.if %6 { +// CHECK-NEXT: %7 = vector.create_mask %5 : vector<8xi1> +// CHECK-NEXT: %8 = affine.for %arg7 = #map(%c0) to #map(%dim_2) iter_args(%arg8 = %4#1) -> (f32) { +// CHECK-NEXT: %9 = affine.for %arg9 = #map(%c0) to #map(%dim_3) iter_args(%arg10 = %arg8) -> (f32) { +// CHECK-NEXT: %10 = arith.addi %arg4, %arg7 : index +// CHECK-NEXT: %11 = arith.addi %arg5, %arg9 : index // CHECK-NEXT: %12 = vector.maskedload %arg0[%arg3, %10, %11, %4#0], %7, %0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> -// CHECK-NEXT: %13 = vector.maskedload %arg1[%arg6, %arg9, %arg11, %4#0], %7, %0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK-NEXT: %13 = vector.maskedload %arg1[%arg6, %arg7, %arg9, %4#0], %7, %0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> // CHECK-NEXT: %14 = arith.mulf %12, %13 : vector<8xf32> // CHECK-NEXT: %15 = vector.reduction , %14 : vector<8xf32> into f32 -// CHECK-NEXT: %16 = arith.addf %15, %arg12 : f32 +// CHECK-NEXT: %16 = arith.addf %15, 
%arg10 : f32 // CHECK-NEXT: affine.yield %16 : f32 // CHECK-NEXT: } // CHECK-NEXT: affine.yield %9 : f32 // CHECK-NEXT: } -// CHECK-NEXT: scf.yield %8 : f32 +// CHECK-NEXT: memref.store %8, %arg2[%arg3, %arg4, %arg5, %arg6] : memref +// CHECK-NEXT: } else { +// CHECK-NEXT: memref.store %4#1, %arg2[%arg3, %arg4, %arg5, %arg6] : memref // CHECK-NEXT: } -// CHECK-NEXT: memref.store %5, %arg2[%arg3, %arg4, %arg5, %arg6] : memref // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } From 9b77b3f35eec90bea43556e679941f082bbf57f5 Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Tue, 26 Nov 2024 09:08:08 +0000 Subject: [PATCH 06/11] [midend/lib/Conversion/ConvVectorization] Fix conv2dnhwcfhwc vectorization pass and examples. --- .../ConvVectorization/Conv2dNhwcFhwcVectorization.cpp | 2 +- tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp index c7cef4e24a..990dc4d700 100644 --- a/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp +++ b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp @@ -194,7 +194,7 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { Value idx = iter_val.getResult(0); Value tailSize = builder.create(loc, channels, idx); Value tailCond = rewriter.create( - loc, arith::CmpIPredicate::sge, tailSize, c0); + loc, arith::CmpIPredicate::sgt, tailSize, c0); // If the current column does not reach the tail. 
builder.create( loc, tailCond, diff --git a/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir b/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir index a73ac1df47..0fe1eca30e 100644 --- a/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir +++ b/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir @@ -25,7 +25,7 @@ // CHECK-NEXT: scf.yield %8, %7 : index, f32 // CHECK-NEXT: } // CHECK-NEXT: %5 = arith.subi %dim_0, %4#0 : index -// CHECK-NEXT: %6 = arith.cmpi sge, %5, %c0 : index +// CHECK-NEXT: %6 = arith.cmpi sgt, %5, %c0 : index // CHECK-NEXT: scf.if %6 { // CHECK-NEXT: %7 = vector.create_mask %5 : vector<8xi1> // CHECK-NEXT: %8 = affine.for %arg7 = #map(%c0) to #map(%dim_2) iter_args(%arg8 = %4#1) -> (f32) { From a1534a0cf705f9d2c8a9d47ab651d69334624aa7 Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Mon, 2 Dec 2024 06:27:07 +0000 Subject: [PATCH 07/11] [midend/lib/Conversion/ConvVectorization] Fix conv2dnhwcfhwc vectorization pass and examples. --- .../Conv2dNhwcFhwcVectorization.cpp | 23 +++--- .../conv2d-nhwc-fhwc-max-vectorization.mlir | 72 ++++++++++--------- 2 files changed, 52 insertions(+), 43 deletions(-) diff --git a/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp index 990dc4d700..8e293b0bc8 100644 --- a/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp +++ b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp @@ -77,6 +77,10 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { Value input = op->getOperand(0); Value kernel = op->getOperand(1); Value output = op->getOperand(2); + auto strides = op->getAttrOfType("strides") + .getValues(); + Value strHeight = rewriter.create(loc, strides[0]); + Value strWidth = rewriter.create(loc, strides[1]); // Get i1 as the element type for mask vector. 
IntegerType i1 = IntegerType::get(ctx, 1); @@ -124,6 +128,9 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { affine::buildAffineLoopNest( rewriter, loc, lowerBounds, uperBounds, steps, [&](OpBuilder &builder, Location loc, ValueRange ivs) { + Value tmp_ivs1 = + builder.create(loc, ivs[1], strHeight); + Value tmp_ivs2 = builder.create(loc, ivs[2], strWidth); Value tmp_result = builder.create( loc, elementTy, output, ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]}); @@ -144,10 +151,10 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { /*Step=*/1, ValueRange{itrArgs0[0]}, [&](OpBuilder &builder, Location loc, Value iv1, ValueRange itrArgs1) { - Value inputHeight = - builder.create(loc, ivs[1], iv0); - Value inputWidth = - builder.create(loc, ivs[2], iv1); + Value inputHeight = builder.create( + loc, tmp_ivs1, iv0); + Value inputWidth = builder.create( + loc, tmp_ivs2, iv1); Value inputVector = builder.create( loc, vectorTy, input, ValueRange{ivs[0], inputHeight, inputWidth, @@ -213,10 +220,10 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { /*Step=*/1, ValueRange{itrArgs0[0]}, [&](OpBuilder &builder, Location loc, Value iv1, ValueRange itrArgs1) { - Value inputHeight = - builder.create(loc, ivs[1], iv0); - Value inputWidth = - builder.create(loc, ivs[2], iv1); + Value inputHeight = builder.create( + loc, tmp_ivs1, iv0); + Value inputWidth = builder.create( + loc, tmp_ivs2, iv1); Value inputVec = builder.create( loc, vectorTy, input, ValueRange{ivs[0], inputHeight, inputWidth, diff --git a/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir b/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir index 0fe1eca30e..38ee086e2b 100644 --- a/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir +++ b/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir @@ -3,47 +3,49 @@ // CHECK: #map = affine_map<(d0) -> (d0)> // CHECK: module { // CHECK: affine.for %arg3 = #map(%c0) to #map(%dim) { -// 
CHECK-NEXT: affine.for %arg4 = #map(%c0) to #map(%dim_4) { -// CHECK-NEXT: affine.for %arg5 = #map(%c0) to #map(%dim_5) { -// CHECK-NEXT: affine.for %arg6 = #map(%c0) to #map(%dim_1) { -// CHECK-NEXT: %3 = memref.load %arg2[%arg3, %arg4, %arg5, %arg6] : memref -// CHECK-NEXT: %4:2 = scf.for %arg7 = %c0 to %2 step %c8 iter_args(%arg8 = %c0, %arg9 = %3) -> (index, f32) { -// CHECK-NEXT: %7 = affine.for %arg10 = #map(%c0) to #map(%dim_2) iter_args(%arg11 = %arg9) -> (f32) { -// CHECK-NEXT: %9 = affine.for %arg12 = #map(%c0) to #map(%dim_3) iter_args(%arg13 = %arg11) -> (f32) { -// CHECK-NEXT: %10 = arith.addi %arg4, %arg10 : index -// CHECK-NEXT: %11 = arith.addi %arg5, %arg12 : index -// CHECK-NEXT: %12 = vector.load %arg0[%arg3, %10, %11, %arg7] : memref, vector<8xf32> -// CHECK-NEXT: %13 = vector.load %arg1[%arg6, %arg10, %arg12, %arg7] : memref, vector<8xf32> -// CHECK-NEXT: %14 = arith.mulf %12, %13 : vector<8xf32> -// CHECK-NEXT: %15 = vector.reduction , %14 : vector<8xf32> into f32 -// CHECK-NEXT: %16 = arith.addf %15, %arg13 : f32 -// CHECK-NEXT: affine.yield %16 : f32 +// CHECK-NEXT: affine.for %arg4 = #map(%c0) to #map(%dim_6) { +// CHECK-NEXT: affine.for %arg5 = #map(%c0) to #map(%dim_7) { +// CHECK-NEXT: affine.for %arg6 = #map(%c0) to #map(%dim_3) { +// CHECK-NEXT: %3 = arith.muli %arg4, %c1 : index +// CHECK-NEXT: %4 = arith.muli %arg5, %c1_0 : index +// CHECK-NEXT: %5 = memref.load %arg2[%arg3, %arg4, %arg5, %arg6] : memref +// CHECK-NEXT: %6:2 = scf.for %arg7 = %c0 to %2 step %c8 iter_args(%arg8 = %c0, %arg9 = %5) -> (index, f32) { +// CHECK-NEXT: %9 = affine.for %arg10 = #map(%c0) to #map(%dim_4) iter_args(%arg11 = %arg9) -> (f32) { +// CHECK-NEXT: %11 = affine.for %arg12 = #map(%c0) to #map(%dim_5) iter_args(%arg13 = %arg11) -> (f32) { +// CHECK-NEXT: %12 = arith.addi %3, %arg10 : index +// CHECK-NEXT: %13 = arith.addi %4, %arg12 : index +// CHECK-NEXT: %14 = vector.load %arg0[%arg3, %12, %13, %arg7] : memref, vector<8xf32> +// CHECK-NEXT: %15 = 
vector.load %arg1[%arg6, %arg10, %arg12, %arg7] : memref, vector<8xf32> +// CHECK-NEXT: %16 = arith.mulf %14, %15 : vector<8xf32> +// CHECK-NEXT: %17 = vector.reduction , %16 : vector<8xf32> into f32 +// CHECK-NEXT: %18 = arith.addf %17, %arg13 : f32 +// CHECK-NEXT: affine.yield %18 : f32 // CHECK-NEXT: } -// CHECK-NEXT: affine.yield %9 : f32 +// CHECK-NEXT: affine.yield %11 : f32 // CHECK-NEXT: } -// CHECK-NEXT: %8 = arith.addi %arg8, %c8 : index -// CHECK-NEXT: scf.yield %8, %7 : index, f32 +// CHECK-NEXT: %10 = arith.addi %arg8, %c8 : index +// CHECK-NEXT: scf.yield %10, %9 : index, f32 // CHECK-NEXT: } -// CHECK-NEXT: %5 = arith.subi %dim_0, %4#0 : index -// CHECK-NEXT: %6 = arith.cmpi sgt, %5, %c0 : index -// CHECK-NEXT: scf.if %6 { -// CHECK-NEXT: %7 = vector.create_mask %5 : vector<8xi1> -// CHECK-NEXT: %8 = affine.for %arg7 = #map(%c0) to #map(%dim_2) iter_args(%arg8 = %4#1) -> (f32) { -// CHECK-NEXT: %9 = affine.for %arg9 = #map(%c0) to #map(%dim_3) iter_args(%arg10 = %arg8) -> (f32) { -// CHECK-NEXT: %10 = arith.addi %arg4, %arg7 : index -// CHECK-NEXT: %11 = arith.addi %arg5, %arg9 : index -// CHECK-NEXT: %12 = vector.maskedload %arg0[%arg3, %10, %11, %4#0], %7, %0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> -// CHECK-NEXT: %13 = vector.maskedload %arg1[%arg6, %arg7, %arg9, %4#0], %7, %0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> -// CHECK-NEXT: %14 = arith.mulf %12, %13 : vector<8xf32> -// CHECK-NEXT: %15 = vector.reduction , %14 : vector<8xf32> into f32 -// CHECK-NEXT: %16 = arith.addf %15, %arg10 : f32 -// CHECK-NEXT: affine.yield %16 : f32 +// CHECK-NEXT: %7 = arith.subi %dim_2, %6#0 : index +// CHECK-NEXT: %8 = arith.cmpi sgt, %7, %c0 : index +// CHECK-NEXT: scf.if %8 { +// CHECK-NEXT: %9 = vector.create_mask %7 : vector<8xi1> +// CHECK-NEXT: %10 = affine.for %arg7 = #map(%c0) to #map(%dim_4) iter_args(%arg8 = %6#1) -> (f32) { +// CHECK-NEXT: %11 = affine.for %arg9 = #map(%c0) to #map(%dim_5) iter_args(%arg10 = %arg8) -> 
(f32) { +// CHECK-NEXT: %12 = arith.addi %3, %arg7 : index +// CHECK-NEXT: %13 = arith.addi %4, %arg9 : index +// CHECK-NEXT: %14 = vector.maskedload %arg0[%arg3, %12, %13, %6#0], %9, %0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK-NEXT: %15 = vector.maskedload %arg1[%arg6, %arg7, %arg9, %6#0], %9, %0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK-NEXT: %16 = arith.mulf %14, %15 : vector<8xf32> +// CHECK-NEXT: %17 = vector.reduction , %16 : vector<8xf32> into f32 +// CHECK-NEXT: %18 = arith.addf %17, %arg10 : f32 +// CHECK-NEXT: affine.yield %18 : f32 // CHECK-NEXT: } -// CHECK-NEXT: affine.yield %9 : f32 +// CHECK-NEXT: affine.yield %11 : f32 // CHECK-NEXT: } -// CHECK-NEXT: memref.store %8, %arg2[%arg3, %arg4, %arg5, %arg6] : memref +// CHECK-NEXT: memref.store %10, %arg2[%arg3, %arg4, %arg5, %arg6] : memref // CHECK-NEXT: } else { -// CHECK-NEXT: memref.store %4#1, %arg2[%arg3, %arg4, %arg5, %arg6] : memref +// CHECK-NEXT: memref.store %6#1, %arg2[%arg3, %arg4, %arg5, %arg6] : memref // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } From c36ea9dceb6fe94b4a0dbe19482cefec396273cb Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Thu, 19 Dec 2024 06:26:29 +0000 Subject: [PATCH 08/11] [midend/lib/Conversion/ConvVectorization] add conv2dnhwcfhwc vectorization pass and examples about attributes. 
--- .../BuddyConvolution/conv2d-nhwc-fhwc.mlir | 5 +- examples/BuddyConvolution/makefile | 5 ++ .../Conv2dNhwcFhwcVectorization.cpp | 71 +++++++++++++++---- .../conv2d-nhwc-fhwc-max-vectorization.mlir | 63 +++++----------- 4 files changed, 84 insertions(+), 60 deletions(-) diff --git a/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir b/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir index 8928d027a3..dcf88e1a55 100644 --- a/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir +++ b/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir @@ -19,8 +19,9 @@ module { func.func private @rtclock() -> f64 func.func @conv_2d_nhwc_fhwc(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.conv_2d_nhwc_fhwc ins (%arg0, %arg1: memref, memref) - outs (%arg2: memref) + linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%arg0, %arg1: memref, memref) + outs (%arg2: memref) return } diff --git a/examples/BuddyConvolution/makefile b/examples/BuddyConvolution/makefile index 42cf9d6313..4c022ae30a 100644 --- a/examples/BuddyConvolution/makefile +++ b/examples/BuddyConvolution/makefile @@ -126,6 +126,11 @@ conv2d-nhwc-fhwc-opt-aot: -o a.out @LD_LIBRARY_PATH=${MLIR_LIB} ./a.out +conv2d-nhwc-fhwc-vectorization-lower: + @${BUDDY_OPT} ./conv2d-nhwc-fhwc.mlir \ + -conv2d-nhwc-fhwc-vectorization \ + -o log.mlir + conv2d-nhwc-fhwc-vectorization-run: @${BUDDY_OPT} ./conv2d-nhwc-fhwc.mlir \ -conv2d-nhwc-fhwc-vectorization \ diff --git a/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp index 8e293b0bc8..77fb0c931c 100644 --- a/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp +++ b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp @@ -77,11 +77,32 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { Value input = op->getOperand(0); Value kernel = op->getOperand(1); Value output = 
op->getOperand(2); - auto strides = op->getAttrOfType("strides") - .getValues(); + // Get Strides. + SmallVector strides = {1, 1}; + if (op->hasAttr("strides")) { + strides.clear(); + for (auto value : op->getAttrOfType("strides").getValues()) { + strides.push_back(value); + } + } + bool stride1 = strides[0] != 1; + bool stride2 = strides[1] != 1; Value strHeight = rewriter.create(loc, strides[0]); Value strWidth = rewriter.create(loc, strides[1]); + // Get Dilations. + SmallVector dilations = {1, 1}; + if (op->hasAttr("dilations")) { + dilations.clear(); + for (auto value : op->getAttrOfType("dilations").getValues()) { + dilations.push_back(value); + } + } + bool dilated1 = dilations[0] != 1; + bool dilated2 = dilations[1] != 1; + Value dilHeight = rewriter.create(loc, dilations[0]); + Value dilWidth = rewriter.create(loc, dilations[1]); + // Get i1 as the element type for mask vector. IntegerType i1 = IntegerType::get(ctx, 1); VectorType vectorMaskTy = mlir::VectorType::get({vecsize}, i1); @@ -128,9 +149,15 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { affine::buildAffineLoopNest( rewriter, loc, lowerBounds, uperBounds, steps, [&](OpBuilder &builder, Location loc, ValueRange ivs) { - Value tmp_ivs1 = - builder.create(loc, ivs[1], strHeight); - Value tmp_ivs2 = builder.create(loc, ivs[2], strWidth); + // Create strides variables. + Value tmp_ivs1 = ivs[1]; + if(stride1){ + tmp_ivs1 = builder.create(loc, ivs[1], strHeight); + } + Value tmp_ivs2 = ivs[2]; + if(stride2){ + tmp_ivs2 = builder.create(loc, ivs[2], strWidth); + } Value tmp_result = builder.create( loc, elementTy, output, ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]}); @@ -145,16 +172,24 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { /*Step=*/1, ValueRange{itrArgs[1]}, [&](OpBuilder &builder, Location loc, Value iv0, ValueRange itrArgs0) { + // Create dilated[0] variables. 
+ Value tmp_ivs3 = iv0; + if(dilated1){ + tmp_ivs3 = builder.create(loc, iv0, dilHeight); + } + Value inputHeight = builder.create(loc, tmp_ivs1, tmp_ivs3); auto tmp1 = builder.create( loc, ValueRange{c0}, builder.getDimIdentityMap(), ValueRange{width_k}, builder.getDimIdentityMap(), /*Step=*/1, ValueRange{itrArgs0[0]}, [&](OpBuilder &builder, Location loc, Value iv1, ValueRange itrArgs1) { - Value inputHeight = builder.create( - loc, tmp_ivs1, iv0); - Value inputWidth = builder.create( - loc, tmp_ivs2, iv1); + // Create dilated[1] variables. + Value tmp_ivs4 = iv1; + if(dilated2){ + tmp_ivs4 = builder.create(loc, iv1, dilWidth); + } + Value inputWidth = builder.create(loc, tmp_ivs2, tmp_ivs4); Value inputVector = builder.create( loc, vectorTy, input, ValueRange{ivs[0], inputHeight, inputWidth, @@ -214,16 +249,26 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { /*Step=*/1, ValueRange{iter_val.getResult(1)}, [&](OpBuilder &builder, Location loc, Value iv0, ValueRange itrArgs0) { + // Create dilated[0] variables. + Value tmp_ivs3 = iv0; + if(dilated1){ + tmp_ivs3 = builder.create(loc, iv0, dilHeight); + } + Value inputHeight = + builder.create(loc, tmp_ivs1, tmp_ivs3); auto tmp1 = builder.create( loc, ValueRange{c0}, builder.getDimIdentityMap(), ValueRange{width_k}, builder.getDimIdentityMap(), /*Step=*/1, ValueRange{itrArgs0[0]}, [&](OpBuilder &builder, Location loc, Value iv1, ValueRange itrArgs1) { - Value inputHeight = builder.create( - loc, tmp_ivs1, iv0); - Value inputWidth = builder.create( - loc, tmp_ivs2, iv1); + // Create dilated[1] variables. 
+ Value tmp_ivs4 = iv1; + if(dilated2){ + tmp_ivs4 = builder.create(loc, iv1, dilWidth); + } + Value inputWidth = + builder.create(loc, tmp_ivs2, tmp_ivs4); Value inputVec = builder.create( loc, vectorTy, input, ValueRange{ivs[0], inputHeight, inputWidth, diff --git a/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir b/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir index 38ee086e2b..0713ac0207 100644 --- a/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir +++ b/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir @@ -3,54 +3,27 @@ // CHECK: #map = affine_map<(d0) -> (d0)> // CHECK: module { // CHECK: affine.for %arg3 = #map(%c0) to #map(%dim) { -// CHECK-NEXT: affine.for %arg4 = #map(%c0) to #map(%dim_6) { -// CHECK-NEXT: affine.for %arg5 = #map(%c0) to #map(%dim_7) { -// CHECK-NEXT: affine.for %arg6 = #map(%c0) to #map(%dim_3) { -// CHECK-NEXT: %3 = arith.muli %arg4, %c1 : index -// CHECK-NEXT: %4 = arith.muli %arg5, %c1_0 : index -// CHECK-NEXT: %5 = memref.load %arg2[%arg3, %arg4, %arg5, %arg6] : memref -// CHECK-NEXT: %6:2 = scf.for %arg7 = %c0 to %2 step %c8 iter_args(%arg8 = %c0, %arg9 = %5) -> (index, f32) { -// CHECK-NEXT: %9 = affine.for %arg10 = #map(%c0) to #map(%dim_4) iter_args(%arg11 = %arg9) -> (f32) { -// CHECK-NEXT: %11 = affine.for %arg12 = #map(%c0) to #map(%dim_5) iter_args(%arg13 = %arg11) -> (f32) { -// CHECK-NEXT: %12 = arith.addi %3, %arg10 : index -// CHECK-NEXT: %13 = arith.addi %4, %arg12 : index -// CHECK-NEXT: %14 = vector.load %arg0[%arg3, %12, %13, %arg7] : memref, vector<8xf32> -// CHECK-NEXT: %15 = vector.load %arg1[%arg6, %arg10, %arg12, %arg7] : memref, vector<8xf32> -// CHECK-NEXT: %16 = arith.mulf %14, %15 : vector<8xf32> -// CHECK-NEXT: %17 = vector.reduction , %16 : vector<8xf32> into f32 -// CHECK-NEXT: %18 = arith.addf %17, %arg13 : f32 -// CHECK-NEXT: affine.yield %18 : f32 +// CHECK-NEXT: affine.for %arg4 = #map(%c0) to #map(%dim_8) { +// CHECK-NEXT: affine.for %arg5 = #map(%c0) to #map(%dim_9) 
{ +// CHECK-NEXT: affine.for %arg6 = #map(%c0) to #map(%dim_5) { +// CHECK-NEXT: %3 = memref.load %arg2[%arg3, %arg4, %arg5, %arg6] : memref +// CHECK-NEXT: %4:2 = scf.for %arg7 = %c0 to %2 step %c8 iter_args(%arg8 = %c0, %arg9 = %3) -> (index, f32) { +// CHECK-NEXT: %7 = affine.for %arg10 = #map(%c0) to #map(%dim_6) iter_args(%arg11 = %arg9) -> (f32) { +// CHECK-NEXT: %9 = arith.addi %arg4, %arg10 : index +// CHECK-NEXT: %10 = affine.for %arg12 = #map(%c0) to #map(%dim_7) iter_args(%arg13 = %arg11) -> (f32) { +// CHECK-NEXT: %11 = arith.addi %arg5, %arg12 : index +// CHECK-NEXT: %12 = vector.load %arg0[%arg3, %9, %11, %arg7] : memref, vector<8xf32> +// CHECK-NEXT: %13 = vector.load %arg1[%arg6, %arg10, %arg12, %arg7] : memref, vector<8xf32> +// CHECK-NEXT: %14 = arith.mulf %12, %13 : vector<8xf32> +// CHECK-NEXT: %15 = vector.reduction , %14 : vector<8xf32> into f32 +// CHECK-NEXT: %16 = arith.addf %15, %arg13 : f32 +// CHECK-NEXT: affine.yield %16 : f32 // CHECK-NEXT: } -// CHECK-NEXT: affine.yield %11 : f32 +// CHECK-NEXT: affine.yield %10 : f32 // CHECK-NEXT: } -// CHECK-NEXT: %10 = arith.addi %arg8, %c8 : index -// CHECK-NEXT: scf.yield %10, %9 : index, f32 +// CHECK-NEXT: %8 = arith.addi %arg8, %c8 : index +// CHECK-NEXT: scf.yield %8, %7 : index, f32 // CHECK-NEXT: } -// CHECK-NEXT: %7 = arith.subi %dim_2, %6#0 : index -// CHECK-NEXT: %8 = arith.cmpi sgt, %7, %c0 : index -// CHECK-NEXT: scf.if %8 { -// CHECK-NEXT: %9 = vector.create_mask %7 : vector<8xi1> -// CHECK-NEXT: %10 = affine.for %arg7 = #map(%c0) to #map(%dim_4) iter_args(%arg8 = %6#1) -> (f32) { -// CHECK-NEXT: %11 = affine.for %arg9 = #map(%c0) to #map(%dim_5) iter_args(%arg10 = %arg8) -> (f32) { -// CHECK-NEXT: %12 = arith.addi %3, %arg7 : index -// CHECK-NEXT: %13 = arith.addi %4, %arg9 : index -// CHECK-NEXT: %14 = vector.maskedload %arg0[%arg3, %12, %13, %6#0], %9, %0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> -// CHECK-NEXT: %15 = vector.maskedload %arg1[%arg6, %arg7, %arg9, 
%6#0], %9, %0 : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> -// CHECK-NEXT: %16 = arith.mulf %14, %15 : vector<8xf32> -// CHECK-NEXT: %17 = vector.reduction , %16 : vector<8xf32> into f32 -// CHECK-NEXT: %18 = arith.addf %17, %arg10 : f32 -// CHECK-NEXT: affine.yield %18 : f32 -// CHECK-NEXT: } -// CHECK-NEXT: affine.yield %11 : f32 -// CHECK-NEXT: } -// CHECK-NEXT: memref.store %10, %arg2[%arg3, %arg4, %arg5, %arg6] : memref -// CHECK-NEXT: } else { -// CHECK-NEXT: memref.store %6#1, %arg2[%arg3, %arg4, %arg5, %arg6] : memref -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } func.func @conv_2d_nhwc_fhwc(%arg0: memref, %arg1: memref, %arg2: memref) { linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} From 5266ebc432c03c525806ab2559d3a2e16d35cf6e Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Wed, 25 Dec 2024 09:20:05 +0000 Subject: [PATCH 09/11] examples/BuddyConvolution/: conv2d-nhwc-fhwc.mlir, conv2d-nhwc-fhwc-opt.mlir: adjust the code structure; conv2d-nhwc-fhwc-vec.mlir: handwritten MLIR file; makefile: update related commands. midend/lib/Conversion/ConvVectorization/: Conv2dNhwcFhwcVectorization.cpp implements the vectorization. tests/Conversion/: conv2d-nhwc-fhwc-max-vectorization.mlir: a test file. Timing before vectorization: 0.000247002, after vectorization: 0.000134945 (vector size 2), i.e. run time reduced by about 45.4% (roughly a 1.83x speedup).
--- .../conv2d-nhwc-fhwc-opt.mlir | 21 ++++++++--------- .../conv2d-nhwc-fhwc-vec.mlir | 21 +++++++++-------- .../BuddyConvolution/conv2d-nhwc-fhwc.mlir | 23 +++++++++++-------- examples/BuddyConvolution/makefile | 2 +- 4 files changed, 35 insertions(+), 32 deletions(-) diff --git a/examples/BuddyConvolution/conv2d-nhwc-fhwc-opt.mlir b/examples/BuddyConvolution/conv2d-nhwc-fhwc-opt.mlir index 76d5e4d932..7c30df096e 100644 --- a/examples/BuddyConvolution/conv2d-nhwc-fhwc-opt.mlir +++ b/examples/BuddyConvolution/conv2d-nhwc-fhwc-opt.mlir @@ -40,6 +40,7 @@ module { %h_o = memref.dim %arg2, %c1 : memref %w_o = memref.dim %arg2, %c2 : memref + %t_start = call @rtclock() : () -> f64 // Output is NHoWoF affine.for %idx_n = %c0 to %n { affine.for %idx_f = %c0 to %f { @@ -67,7 +68,14 @@ module { } } } + %t_end = call @rtclock() : () -> f64 + %time = arith.subf %t_end, %t_start : f64 + + %printed_output = memref.cast %arg2 : memref to memref<*xf32> + call @printMemrefF32(%printed_output) : (memref<*xf32>) -> () + // Print timings. + vector.print %time : f64 return } @@ -111,10 +119,6 @@ module { %v1 = call @alloc_f32(%c6, %c5, %c5, %c1, %f3) : (index, index, index, index, f32) -> memref %v2 = call @alloc_f32(%c1, %c24, %c24, %c6, %f0) : (index, index, index, index, f32) -> memref - %t_start = call @rtclock() : () -> f64 - call @conv_2d_nhwc_fhwc(%v0, %v1, %v2) : (memref, memref, memref) -> () - %t_end = call @rtclock() : () -> f64 - // All the elements of the MemRef are the same, // only check the first line to verify the correctness. 
// CHECK: Unranked Memref @@ -122,16 +126,11 @@ module { // CHECK: [ // CHECK: [ // CHECK: [150{{(, 150)*}}], - %print_v2 = memref.cast %v2 : memref to memref<*xf32> - call @printMemrefF32(%print_v2) : (memref<*xf32>) -> () - - %time = arith.subf %t_end, %t_start : f64 - vector.print %time : f64 - + call @conv_2d_nhwc_fhwc(%v0, %v1, %v2) : (memref, memref, memref) -> () + memref.dealloc %v0 : memref memref.dealloc %v1 : memref memref.dealloc %v2 : memref - return } } diff --git a/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir b/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir index de33973c41..d93f41dc8e 100644 --- a/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir +++ b/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir @@ -41,6 +41,7 @@ module { %upbound_tmp = arith.subi %c, %vl_step : index %upbound = arith.addi %upbound_tmp, %c1 : index + %t_start = call @rtclock() : () -> f64 // Output is NHoWoF affine.for %idx_n = %c0 to %n { affine.for %idx_h_o = %c0 to %h_o { @@ -92,6 +93,14 @@ module { } } } + %t_end = call @rtclock() : () -> f64 + %time = arith.subf %t_end, %t_start : f64 + + %printed_output = memref.cast %arg2 : memref to memref<*xf32> + call @printMemrefF32(%printed_output) : (memref<*xf32>) -> () + + // Print timings. + vector.print %time : f64 return } @@ -135,10 +144,6 @@ module { %v1 = call @alloc_f32(%c6, %c5, %c5, %c1, %f3) : (index, index, index, index, f32) -> memref %v2 = call @alloc_f32(%c1, %c24, %c24, %c6, %f0) : (index, index, index, index, f32) -> memref - %t_start = call @rtclock() : () -> f64 - call @conv_2d_nhwc_fhwc(%v0, %v1, %v2) : (memref, memref, memref) -> () - %t_end = call @rtclock() : () -> f64 - // All the elements of the MemRef are the same, // only check the first line to verify the correctness. 
// CHECK: Unranked Memref @@ -146,12 +151,8 @@ module { // CHECK: [ // CHECK: [ // CHECK: [150{{(, 150)*}}], - %print_v2 = memref.cast %v2 : memref to memref<*xf32> - call @printMemrefF32(%print_v2) : (memref<*xf32>) -> () - - %time = arith.subf %t_end, %t_start : f64 - vector.print %time : f64 - + call @conv_2d_nhwc_fhwc(%v0, %v1, %v2) : (memref, memref, memref) -> () + memref.dealloc %v0 : memref memref.dealloc %v1 : memref memref.dealloc %v2 : memref diff --git a/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir b/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir index dcf88e1a55..15d7284266 100644 --- a/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir +++ b/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir @@ -19,9 +19,20 @@ module { func.func private @rtclock() -> f64 func.func @conv_2d_nhwc_fhwc(%arg0: memref, %arg1: memref, %arg2: memref) { + %t_start = call @rtclock() : () -> f64 + linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins (%arg0, %arg1: memref, memref) outs (%arg2: memref) + + %t_end = call @rtclock() : () -> f64 + %time = arith.subf %t_end, %t_start : f64 + + %printed_output = memref.cast %arg2 : memref to memref<*xf32> + call @printMemrefF32(%printed_output) : (memref<*xf32>) -> () + + // Print timings. + vector.print %time : f64 return } @@ -65,10 +76,6 @@ module { %v1 = call @alloc_f32(%c6, %c5, %c5, %c1, %f3) : (index, index, index, index, f32) -> memref %v2 = call @alloc_f32(%c1, %c24, %c24, %c6, %f0) : (index, index, index, index, f32) -> memref - %t_start = call @rtclock() : () -> f64 - call @conv_2d_nhwc_fhwc(%v0, %v1, %v2) : (memref, memref, memref) -> () - %t_end = call @rtclock() : () -> f64 - // All the elements of the MemRef are the same, // only check the first line to verify the correctness. 
// CHECK: Unranked Memref @@ -76,12 +83,8 @@ module { // CHECK: [ // CHECK: [ // CHECK: [150{{(, 150)*}}], - %print_v2 = memref.cast %v2 : memref to memref<*xf32> - call @printMemrefF32(%print_v2) : (memref<*xf32>) -> () - - %time = arith.subf %t_end, %t_start : f64 - vector.print %time : f64 - + call @conv_2d_nhwc_fhwc(%v0, %v1, %v2) : (memref, memref, memref) -> () + memref.dealloc %v0 : memref memref.dealloc %v1 : memref memref.dealloc %v2 : memref diff --git a/examples/BuddyConvolution/makefile b/examples/BuddyConvolution/makefile index 4c022ae30a..20c1de2c12 100644 --- a/examples/BuddyConvolution/makefile +++ b/examples/BuddyConvolution/makefile @@ -133,7 +133,7 @@ conv2d-nhwc-fhwc-vectorization-lower: conv2d-nhwc-fhwc-vectorization-run: @${BUDDY_OPT} ./conv2d-nhwc-fhwc.mlir \ - -conv2d-nhwc-fhwc-vectorization \ + -conv2d-nhwc-fhwc-vectorization="vec-size=2" \ -convert-linalg-to-loops \ -lower-affine \ -arith-bufferize \ From fee59fd9c9f9c20da065965d2a077ff341876b93 Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Fri, 27 Dec 2024 09:14:22 +0000 Subject: [PATCH 10/11] [midend/lib/Conversion/ConvVectorization] fix some code style --- .../Conv2dNhwcFhwcVectorization.cpp | 87 ++++++++++--------- 1 file changed, 48 insertions(+), 39 deletions(-) diff --git a/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp index 77fb0c931c..760795ff3f 100644 --- a/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp +++ b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp @@ -80,10 +80,11 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { // Get Strides. 
SmallVector strides = {1, 1}; if (op->hasAttr("strides")) { - strides.clear(); - for (auto value : op->getAttrOfType("strides").getValues()) { - strides.push_back(value); - } + strides.clear(); + for (auto value : op->getAttrOfType("strides") + .getValues()) { + strides.push_back(value); + } } bool stride1 = strides[0] != 1; bool stride2 = strides[1] != 1; @@ -93,14 +94,17 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { // Get Dilations. SmallVector dilations = {1, 1}; if (op->hasAttr("dilations")) { - dilations.clear(); - for (auto value : op->getAttrOfType("dilations").getValues()) { - dilations.push_back(value); - } + dilations.clear(); + for (auto value : + op->getAttrOfType("dilations") + .getValues()) { + dilations.push_back(value); + } } bool dilated1 = dilations[0] != 1; bool dilated2 = dilations[1] != 1; - Value dilHeight = rewriter.create(loc, dilations[0]); + Value dilHeight = + rewriter.create(loc, dilations[0]); Value dilWidth = rewriter.create(loc, dilations[1]); // Get i1 as the element type for mask vector. @@ -115,7 +119,7 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { const Value c1 = rewriter.create(loc, 1); const Value c2 = rewriter.create(loc, 2); const Value c3 = rewriter.create(loc, 3); - const Value vl_step = rewriter.create(loc, vecsize); + const Value vlStep = rewriter.create(loc, vecsize); const Value zero = buddy::insertZeroConstantOp(ctx, rewriter, loc, elementTy); @@ -136,12 +140,11 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { Value width_o = rewriter.create(loc, output, c2); // Calculate the upper bound for vectorized processing - // - Subtract `vl_step` is to avoid overflow at the vectorization tail. + // - Subtract `vlStep` is to avoid overflow at the vectorization tail. // - Add 1 to ensure the final loop runs when the workload length // is divisible by the vector size. 
- Value upperBound_tmp = - rewriter.create(loc, channels, vl_step); - Value upperBound = rewriter.create(loc, upperBound_tmp, c1); + Value upperBoundTmp = rewriter.create(loc, channels, vlStep); + Value upperBound = rewriter.create(loc, upperBoundTmp, c1); SmallVector lowerBounds(4, c0); SmallVector uperBounds{batch, height_o, width_o, f_o}; @@ -150,20 +153,20 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { rewriter, loc, lowerBounds, uperBounds, steps, [&](OpBuilder &builder, Location loc, ValueRange ivs) { // Create strides variables. - Value tmp_ivs1 = ivs[1]; - if(stride1){ - tmp_ivs1 = builder.create(loc, ivs[1], strHeight); + Value tmpIvs1 = ivs[1]; + if (stride1) { + tmpIvs1 = builder.create(loc, ivs[1], strHeight); } - Value tmp_ivs2 = ivs[2]; - if(stride2){ - tmp_ivs2 = builder.create(loc, ivs[2], strWidth); + Value tmpIvs2 = ivs[2]; + if (stride2) { + tmpIvs2 = builder.create(loc, ivs[2], strWidth); } Value tmp_result = builder.create( loc, elementTy, output, ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]}); // Create vecsize mining loop. auto iter_val = builder.create( - loc, c0, upperBound, /*Step=*/vl_step, ValueRange{c0, tmp_result}, + loc, c0, upperBound, /*Step=*/vlStep, ValueRange{c0, tmp_result}, [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, ValueRange itrArgs) { auto tmp0 = nestedBuilder.create( @@ -173,11 +176,13 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { [&](OpBuilder &builder, Location loc, Value iv0, ValueRange itrArgs0) { // Create dilated[0] variables. 
- Value tmp_ivs3 = iv0; - if(dilated1){ - tmp_ivs3 = builder.create(loc, iv0, dilHeight); + Value tmpIvs3 = iv0; + if (dilated1) { + tmpIvs3 = + builder.create(loc, iv0, dilHeight); } - Value inputHeight = builder.create(loc, tmp_ivs1, tmp_ivs3); + Value inputHeight = + builder.create(loc, tmpIvs1, tmpIvs3); auto tmp1 = builder.create( loc, ValueRange{c0}, builder.getDimIdentityMap(), ValueRange{width_k}, builder.getDimIdentityMap(), @@ -185,11 +190,13 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { [&](OpBuilder &builder, Location loc, Value iv1, ValueRange itrArgs1) { // Create dilated[1] variables. - Value tmp_ivs4 = iv1; - if(dilated2){ - tmp_ivs4 = builder.create(loc, iv1, dilWidth); + Value tmpIvs4 = iv1; + if (dilated2) { + tmpIvs4 = builder.create(loc, iv1, + dilWidth); } - Value inputWidth = builder.create(loc, tmp_ivs2, tmp_ivs4); + Value inputWidth = builder.create( + loc, tmpIvs2, tmpIvs4); Value inputVector = builder.create( loc, vectorTy, input, ValueRange{ivs[0], inputHeight, inputWidth, @@ -226,7 +233,7 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { nestedLoc, tmp1.getResult(0)); }); Value idx = - builder.create(loc, itrArgs[0], vl_step); + builder.create(loc, itrArgs[0], vlStep); builder.create( loc, ValueRange{idx, tmp0.getResult(0)}); }); @@ -250,12 +257,13 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { [&](OpBuilder &builder, Location loc, Value iv0, ValueRange itrArgs0) { // Create dilated[0] variables. 
- Value tmp_ivs3 = iv0; - if(dilated1){ - tmp_ivs3 = builder.create(loc, iv0, dilHeight); + Value tmpIvs3 = iv0; + if (dilated1) { + tmpIvs3 = + builder.create(loc, iv0, dilHeight); } Value inputHeight = - builder.create(loc, tmp_ivs1, tmp_ivs3); + builder.create(loc, tmpIvs1, tmpIvs3); auto tmp1 = builder.create( loc, ValueRange{c0}, builder.getDimIdentityMap(), ValueRange{width_k}, builder.getDimIdentityMap(), @@ -263,12 +271,13 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { [&](OpBuilder &builder, Location loc, Value iv1, ValueRange itrArgs1) { // Create dilated[1] variables. - Value tmp_ivs4 = iv1; - if(dilated2){ - tmp_ivs4 = builder.create(loc, iv1, dilWidth); + Value tmpIvs4 = iv1; + if (dilated2) { + tmpIvs4 = builder.create(loc, iv1, + dilWidth); } - Value inputWidth = - builder.create(loc, tmp_ivs2, tmp_ivs4); + Value inputWidth = builder.create( + loc, tmpIvs2, tmpIvs4); Value inputVec = builder.create( loc, vectorTy, input, ValueRange{ivs[0], inputHeight, inputWidth, From 4a8a8e4ca13a147fc7d2736079cafb35bfd5fd06 Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Tue, 7 Jan 2025 16:24:39 +0000 Subject: [PATCH 11/11] Fix conv2d-nhwc-fhwc vectorization pass and update related files conv2d-nhwc-fhwc-vec.mlir: examples/BuddyConvolution Conv2dNhwcFhwcVectorization.cpp: midend/lib/Conversion/ConvVectorization conv2d-nhwc-fhwc-max-vectorization.mlir: tests/Conversion --- examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir | 2 +- .../ConvVectorization/Conv2dNhwcFhwcVectorization.cpp | 2 +- tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir b/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir index d93f41dc8e..340f5109fc 100644 --- a/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir +++ b/examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir @@ -63,7 +63,7 @@ module { } affine.yield %tmp6 :
f32 } - %tmp11 = arith.addi %iter_init, %vl_step : index + %tmp11 = arith.addi %idx_c, %vl_step : index scf.yield %tmp11, %tmp8 : index, f32 } // Compute the tail size and Process the remaining elements diff --git a/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp index 760795ff3f..a4d491e31f 100644 --- a/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp +++ b/midend/lib/Conversion/ConvVectorization/Conv2dNhwcFhwcVectorization.cpp @@ -233,7 +233,7 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern { nestedLoc, tmp1.getResult(0)); }); Value idx = - builder.create(loc, itrArgs[0], vlStep); + builder.create(loc, iv, vlStep); builder.create( loc, ValueRange{idx, tmp0.getResult(0)}); }); diff --git a/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir b/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir index 0713ac0207..37b955453b 100644 --- a/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir +++ b/tests/Conversion/conv2d-nhwc-fhwc-max-vectorization.mlir @@ -21,7 +21,7 @@ // CHECK-NEXT: } // CHECK-NEXT: affine.yield %10 : f32 // CHECK-NEXT: } -// CHECK-NEXT: %8 = arith.addi %arg8, %c8 : index +// CHECK-NEXT: %8 = arith.addi %arg7, %c8 : index // CHECK-NEXT: scf.yield %8, %7 : index, f32 // CHECK-NEXT: }