From 49e6482a43ccd72b9ae19689d026372af772ba15 Mon Sep 17 00:00:00 2001
From: EllisLambda
Date: Thu, 17 Aug 2023 03:07:18 +0800
Subject: [PATCH] [midend][tests] Add batch-level-parallel outer-product
 BatchMatMul optimization pass and tests.

---
 examples/MLIRLinalg/linalg-batch-matmul.mlir  |  29 ++
 examples/MLIRLinalg/makefile                  |  55 +++
 .../BatchMatMulOptimize.cpp                   | 312 ++++++++++++++++++
 .../MatMulOptimization/CMakeLists.txt         |   4 +
 tools/buddy-opt/CMakeLists.txt                |   1 +
 tools/buddy-opt/buddy-opt.cpp                 |   2 +
 6 files changed, 403 insertions(+)
 create mode 100644 examples/MLIRLinalg/linalg-batch-matmul.mlir
 create mode 100644 midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp

diff --git a/examples/MLIRLinalg/linalg-batch-matmul.mlir b/examples/MLIRLinalg/linalg-batch-matmul.mlir
new file mode 100644
index 0000000000..1ab0fe00ea
--- /dev/null
+++ b/examples/MLIRLinalg/linalg-batch-matmul.mlir
@@ -0,0 +1,29 @@
+// RUN: buddy-opt -batchmatmul-optimize -verify-diagnostics -expand-strided-metadata -lower-affine -convert-vector-to-llvm -finalize-memref-to-llvm -convert-scf-to-cf -convert-linalg-to-llvm -llvm-request-c-wrappers -convert-func-to-llvm -reconcile-unrealized-casts %s \
+// RUN: | mlir-cpu-runner -O0 -e buddy_batchmatmul_f32 \
+// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+memref.global "private" @A : memref<2x2x3xf32> = dense<[[[9., 4., 6.], [2., 4., 0.]], [[6., 3., 3.], [0., 4., 7.]]]>
+memref.global "private" @B : memref<2x3x4xf32> = dense<[[[1., 3., 8., 0.], [1., 8., 8., 7.], [6., 9., 7., 9.]], [[3., 8., 6., 8.], [2., 7., 0., 6.], [0., 4., 0., 4.]]]>
+memref.global "private" @C : memref<2x2x4xf32> = dense<[[[49., 113., 146., 82.], [6., 38., 48., 28.]], [[24., 81., 36., 78.], [8., 56., 0., 52.]]]>
+
+func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
+
+// linalg.batch_matmul accumulates into its output buffer, so the expected
+// values below are A * B + C, i.e. exactly twice the initial contents of @C.
+func.func @buddy_batchmatmul_f32() -> f32 {
+  %a = memref.get_global @A : memref<2x2x3xf32>
+  %b = memref.get_global @B : memref<2x3x4xf32>
+  %c = memref.get_global @C : memref<2x2x4xf32>
+
+  linalg.batch_matmul
+    ins(%a, %b : memref<2x2x3xf32>, memref<2x3x4xf32>)
+    outs(%c : memref<2x2x4xf32>)
+
+  %printed_c = memref.cast %c : memref<2x2x4xf32> to memref<*xf32>
+  call @printMemrefF32(%printed_c) : (memref<*xf32>) -> ()
+  // CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 3 offset = 0 sizes = \[2, 2, 4\] strides = \[8, 4, 1\] data =}}
+  // CHECK{LITERAL}: [[[98, 226, 292, 164],
+  // CHECK{LITERAL}: [12, 76, 96, 56]],
+  // CHECK{LITERAL}: [[48, 162, 72, 156],
+  // CHECK{LITERAL}: [16, 112, 0, 104]]]
+  %zero = arith.constant 0.0 : f32
+  return %zero : f32
+}
diff --git a/examples/MLIRLinalg/makefile b/examples/MLIRLinalg/makefile
index bfb165cfe9..6b377e577d 100644
--- a/examples/MLIRLinalg/makefile
+++ b/examples/MLIRLinalg/makefile
@@ -136,6 +136,61 @@ linalg-matmul-optimize-run:
 	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
 	-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
 
+linalg-batch-matmul-optimize-run:
+	@${BUDDY_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \
+	-batchmatmul-optimize="step-placeholder=64" \
+	-convert-linalg-to-loops \
+	-expand-strided-metadata \
+	-lower-affine \
+	-convert-scf-to-cf \
+	-convert-vector-to-llvm \
+	-finalize-memref-to-llvm \
+	-convert-arith-to-llvm \
+	-convert-func-to-llvm \
+	-reconcile-unrealized-casts | \
+	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e buddy_batchmatmul_f32 -entry-point-result=f32 \
+	-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
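+
+# "step-placeholder=64" restates the pass's default vector width (see the
+# Option declaration in BatchMatMulOptimize.cpp); other widths such as 32 are
+# equally valid and may suit narrower SIMD units.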
+linalg-batch-matmul-lower:
+	@${MLIR_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \
+	-convert-linalg-to-loops -lower-affine -convert-scf-to-cf \
+	-convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \
+	-convert-func-to-llvm -reconcile-unrealized-casts \
+	-o ./log.mlir
+
+linalg-batch-matmul-translate:
+	@${MLIR_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \
+	-convert-linalg-to-loops -lower-affine -convert-scf-to-cf \
+	-convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \
+	-convert-func-to-llvm -reconcile-unrealized-casts | \
+	${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll
+
+linalg-batch-matmul-run:
+	@${MLIR_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \
+	-convert-linalg-to-loops -lower-affine -convert-scf-to-cf \
+	-convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \
+	-convert-func-to-llvm -reconcile-unrealized-casts | \
+	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e buddy_batchmatmul_f32 -entry-point-result=f32 -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
+
+linalg-batch-matmul-optimize-lower:
+	@${BUDDY_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \
+	-batchmatmul-optimize="step-placeholder=64" \
+	-o ./log.mlir
+
+linalg-batch-matmul-optimize-translate:
+	@${BUDDY_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \
+	-batchmatmul-optimize="step-placeholder=64" \
+	-convert-linalg-to-loops \
+	-expand-strided-metadata \
+	-lower-affine \
+	-convert-scf-to-cf \
+	-convert-vector-to-llvm \
+	-finalize-memref-to-llvm \
+	-convert-arith-to-llvm \
+	-convert-func-to-llvm \
+	-reconcile-unrealized-casts | \
+	${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll
+
 linalg-conv2d_nchw_fchw-lower:
 	@${MLIR_OPT} ./linalg-conv2d_nchw_fchw.mlir \
 	-convert-linalg-to-loops -o ./log.mlir
diff --git a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp
new file mode 100644
index 0000000000..9b3924b7d8
--- /dev/null
+++ b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp
@@ -0,0 +1,312 @@
+//===- BatchMatMulOptimize.cpp --------------------------------------------===//
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the BatchMatMul optimization.
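+//
+// Strategy (summarizing the pattern below): the batch dimension is mapped to
+// an affine.parallel loop; within each batch an outer-product schedule is
+// used, where for every (k, m) pair the scalar A[b, m, k] is broadcast and
+// FMA'd against a vector slice of row k of B, accumulating into row m of C.
+// Columns are processed step-placeholder lanes at a time, and the tail is
+// handled with masked loads and stores.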
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/ValueRange.h"
+#include "llvm/ADT/ArrayRef.h"
+#include <mlir/Dialect/Affine/Analysis/AffineAnalysis.h>
+#include <mlir/Dialect/Affine/IR/AffineOps.h>
+#include <mlir/Dialect/Func/IR/FuncOps.h>
+#include <mlir/Dialect/Linalg/IR/Linalg.h>
+#include <mlir/Dialect/MemRef/IR/MemRef.h>
+#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/Dialect/Vector/IR/VectorOps.h>
+#include <mlir/IR/BuiltinTypes.h>
+#include <mlir/Pass/Pass.h>
+#include <mlir/Transforms/DialectConversion.h>
+
+using namespace mlir;
+using namespace vector;
+
+//===----------------------------------------------------------------------===//
+// Rewrite Pattern
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class BatchMatMulOptimizePattern : public ConversionPattern {
+public:
+  explicit BatchMatMulOptimizePattern(MLIRContext *context,
+                                      int64_t stepPlaceHolderParam)
+      : ConversionPattern(linalg::BatchMatmulOp::getOperationName(), 1,
+                          context) {
+    stepPlaceHolder = stepPlaceHolderParam;
+  }
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+
+    // Get input A, B, C.
+    Value A = op->getOperand(0);
+    Value B = op->getOperand(1);
+    Value C = op->getOperand(2);
+    // Get the element type of the inputs and output.
+    auto A_elementType = A.getType().cast<MemRefType>().getElementType();
+
+    // Some constants.
+    const Value c0 =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
+    const Value step = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIndexAttr(stepPlaceHolder));
+    const AffineExpr d0 = rewriter.getAffineDimExpr(0);
+    const AffineExpr d1 = rewriter.getAffineDimExpr(1);
+    const AffineExpr d2 = rewriter.getAffineDimExpr(2);
+    const AffineExpr c0_affine = rewriter.getAffineConstantExpr(0);
+
+    const Value c0_dynamicType = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(A_elementType));
+    const Value c0_dynamicType_vec = rewriter.create<vector::SplatOp>(
+        loc, VectorType::get({stepPlaceHolder}, A_elementType),
+        c0_dynamicType);
+
+    // Dims: A is BATCH x M x K, B is BATCH x K x N.
+    Value BATCH = rewriter.create<memref::DimOp>(loc, A, 0); // Batch size
+    Value M = rewriter.create<memref::DimOp>(loc, A, 1);     // A row
+    Value N = rewriter.create<memref::DimOp>(loc, B, 2);     // B col
+    Value K = rewriter.create<memref::DimOp>(loc, B, 1);     // B row
+
+    auto reducedValues = llvm::to_vector<4>(llvm::map_range(
+        ArrayRef<mlir::affine::LoopReduction>{},
+        [](const mlir::affine::LoopReduction &red) { return red.value; }));
+
+    // Build parallel loop body.
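+    // affine.parallel is constructed in its generic form, so bound groups,
+    // bound maps, reductions, and steps are supplied explicitly as
+    // attributes: a single IV over [0, BATCH) with step 1 and no reductions.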
+    auto parallelLoop = rewriter.create<affine::AffineParallelOp>(
+        loc, ValueRange(reducedValues).getTypes(), ValueRange{BATCH},
+        ArrayRef<NamedAttribute>{
+            rewriter.getNamedAttr(
+                "lowerBoundsGroups",
+                rewriter.getI32TensorAttr(ArrayRef<int32_t>{1})),
+            rewriter.getNamedAttr(
+                "upperBoundsGroups",
+                rewriter.getI32TensorAttr(ArrayRef<int32_t>{1})),
+            rewriter.getNamedAttr(
+                "lowerBoundsMap",
+                AffineMapAttr::get(
+                    AffineMap::get(0, 0, {c0_affine}, rewriter.getContext()))),
+            rewriter.getNamedAttr("upperBoundsMap",
+                                  AffineMapAttr::get(AffineMap::get(
+                                      1, 0, {d0}, rewriter.getContext()))),
+            rewriter.getNamedAttr("reductions", rewriter.getArrayAttr({})),
+            rewriter.getNamedAttr("steps", rewriter.getI64ArrayAttr(1))});
+
+    auto body = new Block();
+    rewriter.setInsertionPointToStart(body);
+    body->addArgument(rewriter.getIndexType(), loc);
+
+    Value ivBatch = body->getArguments()[0];
+
+    // Prefetch the current batch of A for reading.
+    rewriter.create<affine::AffinePrefetchOp>(
+        loc, A, AffineMap::get(3, 0, {d0, d1, d2}, rewriter.getContext()),
+        ArrayRef<Value>{ivBatch, c0, c0}, false, 3, true);
+    affine::buildAffineLoopNest(
+        rewriter, loc, {c0}, {K}, 1,
+        [&](OpBuilder &builder, Location loc, ValueRange ivRange) {
+          Value ivB_row = ivRange.front();
+          affine::buildAffineLoopNest(
+              builder, loc, {c0}, {M}, 1,
+              [&](OpBuilder &builder, Location loc, ValueRange ivRange) {
+                Value ivA_row = ivRange.front();
+                // Number of (full or partial) vector steps: ceil(N / step).
+                Value applied_n = builder.create<affine::AffineApplyOp>(
+                    loc, AffineMap::get(1, 0, d0.ceilDiv(stepPlaceHolder)),
+                    ValueRange{N});
+                affine::buildAffineLoopNest(
+                    builder, loc, {c0}, {applied_n}, 1,
+                    [&](OpBuilder &builder, Location loc, ValueRange ivRange) {
+                      Value ivB_col = ivRange.front();
+                      Value a_ele = builder.create<memref::LoadOp>(
+                          loc, A, ValueRange{ivBatch, ivA_row, ivB_row});
+                      Value a_vec = builder.create<vector::BroadcastOp>(
+                          loc,
+                          VectorType::get({stepPlaceHolder}, A_elementType),
+                          a_ele);
+                      Value b_col_cur =
+                          builder.create<arith::MulIOp>(loc, ivB_col, step);
+                      Value tail_len =
+                          builder.create<arith::SubIOp>(loc, N, b_col_cur);
+                      Value tail_flag = builder.create<arith::CmpIOp>(
+                          loc, mlir::arith::CmpIPredicate::sge, tail_len,
+                          step);
+                      builder.create<scf::IfOp>(
+                          loc, tail_flag,
+                          [&](OpBuilder &builder, Location loc) {
+                            Value b_vec =
+                                builder.create<affine::AffineVectorLoadOp>(
+                                    loc,
+                                    VectorType::get({stepPlaceHolder},
+                                                    A_elementType),
+                                    B,
+                                    AffineMap::get(
+                                        3, 0, {d0, d1, d2 * stepPlaceHolder},
+                                        rewriter.getContext()),
+                                    ValueRange{ivBatch, ivB_row, ivB_col});
+                            Value c_vec =
+                                builder.create<affine::AffineVectorLoadOp>(
+                                    loc,
+                                    VectorType::get({stepPlaceHolder},
+                                                    A_elementType),
+                                    C,
+                                    AffineMap::get(
+                                        3, 0, {d0, d1, d2 * stepPlaceHolder},
+                                        rewriter.getContext()),
+                                    ValueRange{ivBatch, ivA_row, ivB_col});
+                            Value result_vec;
+                            if (A_elementType.isIntOrFloat() && 0) { // bug:
+                              // integer path disabled via "&& 0".
+                              Value add_vec = builder.create<arith::MulIOp>(
+                                  loc, a_vec, b_vec);
+                              result_vec = builder.create<arith::AddIOp>(
+                                  loc, add_vec, c_vec);
+                            } else {
+                              result_vec = builder.create<vector::FMAOp>(
+                                  loc, a_vec, b_vec, c_vec);
+                            }
+                            builder.create<affine::AffineVectorStoreOp>(
+                                loc, result_vec, C,
+                                AffineMap::get(3, 0,
+                                               {d0, d1, d2 * stepPlaceHolder},
+                                               rewriter.getContext()),
+                                ValueRange{ivBatch, ivA_row, ivB_col});
+                            builder.create<scf::YieldOp>(loc);
+                          },
+                          [&](OpBuilder &builder, Location loc) {
+                            Value mask_vec =
+                                builder.create<vector::CreateMaskOp>(
+                                    loc,
+                                    VectorType::get({stepPlaceHolder},
+                                                    rewriter.getI1Type()),
+                                    ValueRange{tail_len});
+                            Value b_col_idx_tail =
+                                builder.create<arith::MulIOp>(loc, ivB_col,
+                                                              step);
+                            Value b_vec_tail =
+                                builder.create<vector::MaskedLoadOp>(
+                                    loc,
+                                    VectorType::get({stepPlaceHolder},
+                                                    A_elementType),
+                                    B,
+                                    ValueRange{ivBatch, ivB_row,
+                                               b_col_idx_tail},
+                                    mask_vec, c0_dynamicType_vec);
+                            Value c_vec_tail =
+                                builder.create<vector::MaskedLoadOp>(
+                                    loc,
+                                    VectorType::get({stepPlaceHolder},
+                                                    A_elementType),
+                                    C,
+                                    ValueRange{ivBatch, ivA_row,
+                                               b_col_idx_tail},
+                                    mask_vec, c0_dynamicType_vec);
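+                            // The masked loads above populate only the first
+                            // tail_len lanes; the remaining lanes come from
+                            // the zero splat and are never written back,
+                            // since the masked store below uses the same
+                            // mask.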
+                            Value result_vec_tail;
+                            if (A_elementType.isIntOrFloat() && 0) { // bug:
+                              // integer path disabled via "&& 0".
+                              Value add_vec = builder.create<arith::MulIOp>(
+                                  loc, a_vec, b_vec_tail);
+                              result_vec_tail = builder.create<arith::AddIOp>(
+                                  loc, add_vec, c_vec_tail);
+                            } else {
+                              result_vec_tail = builder.create<vector::FMAOp>(
+                                  loc, a_vec, b_vec_tail, c_vec_tail);
+                            }
+                            builder.create<vector::MaskedStoreOp>(
+                                loc, C,
+                                ValueRange{ivBatch, ivA_row, b_col_idx_tail},
+                                mask_vec, result_vec_tail);
+                            builder.create<scf::YieldOp>(loc);
+                          });
+                    });
+              });
+        });
+
+    rewriter.create<affine::AffineYieldOp>(loc);
+
+    parallelLoop.getRegion().push_back(body);
+    rewriter.setInsertionPointAfter(parallelLoop);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  int64_t stepPlaceHolder;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// BatchMatMulOptimizePass
+//===----------------------------------------------------------------------===//
+
+/// This is a partial lowering of linalg batch_matmul operations to a mixture
+/// of Affine + Vector operations.
+namespace {
+class BatchMatMulOptimizePass
+    : public PassWrapper<BatchMatMulOptimizePass, OperationPass<ModuleOp>> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BatchMatMulOptimizePass)
+  StringRef getArgument() const final { return "batchmatmul-optimize"; }
+  StringRef getDescription() const final {
+    return "BatchMatMul Optimization.";
+  }
+  BatchMatMulOptimizePass() = default;
+  BatchMatMulOptimizePass(const BatchMatMulOptimizePass &) {}
+  explicit BatchMatMulOptimizePass(int64_t stepPlaceHolderParam) {
+    stepPlaceHolder = stepPlaceHolderParam;
+  }
+
+  void runOnOperation() override;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, affine::AffineDialect,
+                    linalg::LinalgDialect, memref::MemRefDialect,
+                    scf::SCFDialect, VectorDialect>();
+  }
+
+  Option<int64_t> stepPlaceHolder{
+      *this, "step-placeholder",
+      llvm::cl::desc("Affine step placeholder size."), llvm::cl::init(64)};
+};
+} // end anonymous namespace.
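+
+// The pass drives the pattern through a partial conversion: every dialect the
+// rewrite emits is marked legal, leaving linalg.batch_matmul as the only op
+// that must be rewritten.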
+void BatchMatMulOptimizePass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  ModuleOp module = getOperation();
+
+  ConversionTarget target(*context);
+  target.addLegalDialect<arith::ArithDialect, affine::AffineDialect,
+                         scf::SCFDialect, memref::MemRefDialect,
+                         VectorDialect>();
+  target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
+  target.addLegalOp<linalg::FillOp>();
+
+  RewritePatternSet patterns(context);
+  patterns.add<BatchMatMulOptimizePattern>(context, stepPlaceHolder);
+
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+namespace mlir {
+namespace buddy {
+void registerBatchMatMulOptimizePass() {
+  PassRegistration<BatchMatMulOptimizePass>();
+}
+} // namespace buddy
+} // namespace mlir
diff --git a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt
index d13c5bcedd..069de2b521 100644
--- a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt
+++ b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt
@@ -1,3 +1,7 @@
 add_mlir_library(MatMulOptimization
   MatMulOptimize.cpp
 )
+
+add_mlir_library(BatchMatMulOptimization
+  BatchMatMulOptimize.cpp
+)
diff --git a/tools/buddy-opt/CMakeLists.txt b/tools/buddy-opt/CMakeLists.txt
index 119be40a7a..333ee42ead 100644
--- a/tools/buddy-opt/CMakeLists.txt
+++ b/tools/buddy-opt/CMakeLists.txt
@@ -20,6 +20,7 @@ target_link_libraries(buddy-opt
   BuddyRVV
   LowerRVVPass
   MatMulOptimization
+  BatchMatMulOptimization
   ConvOptimization
   VectorExp
   LowerVectorExpPass
diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp
index 3341e9583a..8bde8b5711 100644
--- a/tools/buddy-opt/buddy-opt.cpp
+++ b/tools/buddy-opt/buddy-opt.cpp
@@ -54,6 +54,7 @@ void registerLowerBudPass();
 void registerLowerDIPPass();
 void registerLowerDAPPass();
 void registerLowerRVVPass();
+void registerBatchMatMulOptimizePass();
 void registerMatMulOptimizePass();
 void registerConvOptimizePass();
 void registerLowerVectorExpPass();
@@ -80,6 +81,7 @@ int main(int argc, char **argv) {
 
   // Register Several Optimize Pass.
   mlir::buddy::registerMatMulOptimizePass();
+  mlir::buddy::registerBatchMatMulOptimizePass();
   mlir::buddy::registerConvOptimizePass();
 
   mlir::DialectRegistry registry;