diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9b276863..39386b54 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -9,9 +9,8 @@ cmake_minimum_required(VERSION 3.10)
 #-------------------------------------------------------------------------------
 # Project setup and globals
 #-------------------------------------------------------------------------------
-
 project(buddy-benchmark LANGUAGES CXX C)
 
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED YES)
 set(CMAKE_CXX_FLAGS "-no-pie")
@@ -23,14 +22,29 @@ set(CMAKE_C_FLAGS "-no-pie")
 set(BuddyMLIR_DIR ${BUDDY_MLIR_BUILD_DIR}/cmake)
 find_package(BuddyMLIR REQUIRED CONFIG)
 
+if(CROSS_COMPILE_RVV)
+  set(RISCV_GNU_TOOLCHAIN ${BUDDY_MLIR_BUILD_DIR}/thirdparty/riscv-gnu-toolchain)
+  set(RISCV_GNU_TOOLCHAIN_SYSROOT ${RISCV_GNU_TOOLCHAIN}/sysroot)
+  set(BUDDY_OPT_ATTR +v,+m CACHE STRING "Target Architecture.")
+  set(BUDDY_OPT_TRIPLE riscv64 CACHE STRING "Target Triple.")
+  # Without BUDDY_MLIR_BUILD_CROSS_DIR the library directory below collapses to "/lib".
+  if(NOT BUDDY_MLIR_BUILD_CROSS_DIR)
+    message(WARNING "CROSS_COMPILE_RVV is set but BUDDY_MLIR_BUILD_CROSS_DIR is not; targets that link the cross-built buddy-mlir runtime libraries will fail to link.")
+  endif()
+  set(BUDDY_MLIR_CROSS_LIB_DIR ${BUDDY_MLIR_BUILD_CROSS_DIR}/lib)
+else()
+  set(BUDDY_OPT_ATTR avx512f CACHE STRING "Target Architecture.")
+  set(BUDDY_OPT_TRIPLE x86_64-unknown-linux-gnu CACHE STRING "Target Triple.")
+  set(BUDDY_MLIR_LIB_DIR ${BUDDY_MLIR_BUILD_DIR}/lib)
+endif()
+message(STATUS "Configuring Target Architecture: ${BUDDY_OPT_ATTR}")
+message(STATUS "Configuring Target Triple: ${BUDDY_OPT_TRIPLE}")
+
 # BUDDY project.
 set(BUDDY_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
 set(BUDDY_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/bin)
-set(BUDDY_EXAMPLES_DIR ${BUDDY_SOURCE_DIR}/examples)
-set(BUDDY_OPT_ATTR avx512f CACHE STRING "Target Architecture.")
-set(BUDDY_OPT_TRIPLE x86_64-unknown-linux-gnu CACHE STRING "Target Triple.")
-message(STATUS "Configuring Target Architecture: ${BUDDY_OPT_ATTR}")
-message(STATUS "Configuring Target Triple: ${BUDDY_OPT_TRIPLE}")
+set(BUDDY_MLIR_BINARY_DIR ${BUDDY_MLIR_BUILD_DIR}/bin)
+set(BUDDY_BENCHMARK_DEEP_LEARNING_DIR ${BUDDY_SOURCE_DIR}/benchmarks/DeepLearning)
 
 set(BUILD_TESTS OFF CACHE BOOL "Build tests")
 set(BUILD_VALIDATION OFF CACHE BOOL "Build validations")
@@ -49,10 +63,9 @@ set(LLVM_MLIR_LIBRARY_DIR ${BUDDY_MLIR_BUILD_DIR}/../llvm/build/lib)
 # Helper functions.
 include(${BUDDY_SOURCE_DIR}/cmake/buddy-benchmark.cmake)
 
-#-------------------------------------------------------------------------------
+# -------------------------------------------------------------------------------
 # Deploy google/benchmark
-#-------------------------------------------------------------------------------
-
+# -------------------------------------------------------------------------------
 message(STATUS "Configuring benchmarks: google")
 
 include(ExternalProject)
@@ -65,12 +78,12 @@ ExternalProject_Add(project_googlebenchmark
   TIMEOUT 10
   BUILD_BYPRODUCTS <INSTALL_DIR>/lib/${CMAKE_STATIC_LIBRARY_PREFIX}benchmark${CMAKE_STATIC_LIBRARY_SUFFIX}
   CMAKE_ARGS
-    -DCMAKE_INSTALL_PREFIX=${CMAKE_CURRENT_BINARY_DIR}/vendor/benchmark
-    -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}
-    -DBENCHMARK_ENABLE_TESTING=OFF
-    -DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}
-    -DCMAKE_SYSTEM_PROCESSOR=${CMAKE_SYSTEM_PROCESSOR}
-    -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
+  -DCMAKE_INSTALL_PREFIX=${CMAKE_CURRENT_BINARY_DIR}/vendor/benchmark
+  -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}
+  -DBENCHMARK_ENABLE_TESTING=OFF
+  -DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}
+  -DCMAKE_SYSTEM_PROCESSOR=${CMAKE_SYSTEM_PROCESSOR}
+  -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
   UPDATE_COMMAND ""
   TEST_COMMAND "")
 
@@ -87,39 +100,34 @@ add_dependencies(GoogleBenchmark project_googlebenchmark)
 find_package(Threads)
 target_link_libraries(GoogleBenchmark INTERFACE Threads::Threads)
 
-#-------------------------------------------------------------------------------
+# -------------------------------------------------------------------------------
 # Find OpenCV
-#-------------------------------------------------------------------------------
-
+# -------------------------------------------------------------------------------
 if(DEFINED IMAGE_PROCESSING_BENCHMARKS OR OP_OPTIMIZATION_BENCHMARKS)
   find_package(OpenCV REQUIRED CONFIG)
   include_directories(${OpenCV_INCLUDE_DIRS})
 endif()
 
-#-------------------------------------------------------------------------------
+# -------------------------------------------------------------------------------
 # Find PNG
-#-------------------------------------------------------------------------------
-
+# -------------------------------------------------------------------------------
 if(DEFINED IMAGE_PROCESSING_BENCHMARKS)
   find_package(PNG REQUIRED)
   include_directories(${PNG_INCLUDE_DIR})
 endif()
 
-#-------------------------------------------------------------------------------
+# -------------------------------------------------------------------------------
 # Hardware detection
-#-------------------------------------------------------------------------------
-
+# -------------------------------------------------------------------------------
 include(${BUDDY_SOURCE_DIR}/cmake/check-simd.cmake)
 check_simd()
 
-#-------------------------------------------------------------------------------
+# -------------------------------------------------------------------------------
 # Subdirectory
-#-------------------------------------------------------------------------------
-
+# -------------------------------------------------------------------------------
add_subdirectory(benchmarks) add_subdirectory(utils) -if (BUILD_VALIDATION) +if(BUILD_VALIDATION) add_subdirectory(validation) endif() - diff --git a/benchmarks/DeepLearning/CMakeLists.txt b/benchmarks/DeepLearning/CMakeLists.txt index 61e7fa96..c09dc1c8 100644 --- a/benchmarks/DeepLearning/CMakeLists.txt +++ b/benchmarks/DeepLearning/CMakeLists.txt @@ -1 +1,2 @@ +add_subdirectory(Models) add_subdirectory(Ops) diff --git a/benchmarks/DeepLearning/Models/CMakeLists.txt b/benchmarks/DeepLearning/Models/CMakeLists.txt new file mode 100644 index 00000000..509fa5e2 --- /dev/null +++ b/benchmarks/DeepLearning/Models/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(MobileNet-V3) diff --git a/benchmarks/DeepLearning/Models/MobileNet-V3/.gitignore b/benchmarks/DeepLearning/Models/MobileNet-V3/.gitignore new file mode 100644 index 00000000..9eb1b173 --- /dev/null +++ b/benchmarks/DeepLearning/Models/MobileNet-V3/.gitignore @@ -0,0 +1,7 @@ +# model params file +arg0.data +arg1.data + +# model mlir file +forward.mlir +subgraph0.mlir diff --git a/benchmarks/DeepLearning/Models/MobileNet-V3/CMakeLists.txt b/benchmarks/DeepLearning/Models/MobileNet-V3/CMakeLists.txt new file mode 100644 index 00000000..a624e096 --- /dev/null +++ b/benchmarks/DeepLearning/Models/MobileNet-V3/CMakeLists.txt @@ -0,0 +1,161 @@ +add_custom_command( + OUTPUT + ${BUDDY_BENCHMARK_DEEP_LEARNING_DIR}/Models/MobileNet-V3/forward.mlir + ${BUDDY_BENCHMARK_DEEP_LEARNING_DIR}/Models/MobileNet-V3/subgraph0.mlir + COMMAND python3 ${BUDDY_BENCHMARK_DEEP_LEARNING_DIR}/Models/MobileNet-V3/buddy_mobilenetv3_import.py + COMMENT "Generating forward.mlir, subgraph0.mlir" +) + +add_custom_command( + OUTPUT forward_auto_vectorization.o + COMMAND cat ${BUDDY_BENCHMARK_DEEP_LEARNING_DIR}/Models/MobileNet-V3/forward.mlir | + sed -e {s/@forward/@forward_auto_vectorization/} -e {s/@subgraph0/@subgraph0_auto_vectorization/} | + ${LLVM_MLIR_BINARY_DIR}/mlir-opt + -pass-pipeline + "builtin.module(func.func(tosa-to-linalg-named, 
tosa-to-linalg, tosa-to-tensor, tosa-to-arith), \ + empty-tensor-to-alloc-tensor, convert-elementwise-to-linalg, arith-bufferize, \ + func.func(linalg-bufferize, tensor-bufferize), func-bufferize)" | + ${LLVM_MLIR_BINARY_DIR}/mlir-opt + -pass-pipeline + "builtin.module(func.func(buffer-deallocation-simplification, convert-linalg-to-loops), \ + eliminate-empty-tensors, func.func(llvm-request-c-wrappers), \ + convert-math-to-llvm, convert-math-to-libm, convert-scf-to-cf, \ + convert-arith-to-llvm, expand-strided-metadata, finalize-memref-to-llvm, \ + convert-func-to-llvm, reconcile-unrealized-casts)" | + ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_MLIR_BINARY_DIR}/llvm-as | + ${LLVM_MLIR_BINARY_DIR}/llc -O3 -mtriple=${BUDDY_OPT_TRIPLE} + -mattr=${BUDDY_OPT_ATTR} -filetype=obj + -o ${BUDDY_BINARY_DIR}/../benchmarks/DeepLearning/Models/MobileNet-V3/forward_auto_vectorization.o + DEPENDS ${BUDDY_BENCHMARK_DEEP_LEARNING_DIR}/Models/MobileNet-V3/forward.mlir + COMMENT "Building forward_auto_vectorization.o" + VERBATIM) + +add_custom_command( + OUTPUT subgraph0_auto_vectorization.o + COMMAND cat ${BUDDY_BENCHMARK_DEEP_LEARNING_DIR}/Models/MobileNet-V3/subgraph0.mlir | + sed -e {s/@subgraph0/@subgraph0_auto_vectorization/} | + ${BUDDY_MLIR_BINARY_DIR}/buddy-opt + -pass-pipeline + "builtin.module(func.func(tosa-to-linalg-named, tosa-to-arith, tosa-to-linalg, tosa-to-tensor))" | + ${BUDDY_MLIR_BINARY_DIR}/buddy-opt + -convert-elementwise-to-linalg + -func-bufferize-dynamic-offset + -arith-bufferize + -func-bufferize + -tensor-bufferize + -linalg-bufferize + -finalizing-bufferize + -convert-linalg-to-loops + -lower-affine + -convert-scf-to-cf + -llvm-request-c-wrappers + -convert-math-to-llvm + -convert-math-to-libm + -convert-arith-to-llvm + -convert-func-to-llvm + -expand-strided-metadata + -finalize-memref-to-llvm + -reconcile-unrealized-casts | + ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_MLIR_BINARY_DIR}/llvm-as | + 
${LLVM_MLIR_BINARY_DIR}/llc -O3 -mtriple=${BUDDY_OPT_TRIPLE} + -mattr=${BUDDY_OPT_ATTR} -filetype=obj + -o ${BUDDY_BINARY_DIR}/../benchmarks/DeepLearning/Models/MobileNet-V3/subgraph0_auto_vectorization.o + DEPENDS ${BUDDY_BENCHMARK_DEEP_LEARNING_DIR}/Models/MobileNet-V3/subgraph0.mlir + ${BUDDY_MLIR_BINARY_DIR}/buddy-opt + COMMENT "Building subgraph0_auto_vectorization.o" + VERBATIM) + +add_custom_command( + OUTPUT forward_vectorization.o + COMMAND cat ${BUDDY_BENCHMARK_DEEP_LEARNING_DIR}/Models/MobileNet-V3/forward.mlir | + sed -e {s/@forward/@forward_vectorization/} -e {s/@subgraph0/@subgraph0_vectorization/} | + ${LLVM_MLIR_BINARY_DIR}/mlir-opt + -pass-pipeline + "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith), \ + empty-tensor-to-alloc-tensor, convert-elementwise-to-linalg, arith-bufferize, \ + func.func(linalg-bufferize, tensor-bufferize), func-bufferize)" | + ${LLVM_MLIR_BINARY_DIR}/mlir-opt + -pass-pipeline + "builtin.module(func.func(buffer-deallocation-simplification, convert-linalg-to-loops), \ + eliminate-empty-tensors, func.func(llvm-request-c-wrappers), \ + convert-math-to-llvm, convert-math-to-libm, convert-scf-to-cf, \ + convert-arith-to-llvm, expand-strided-metadata, finalize-memref-to-llvm, \ + convert-func-to-llvm, reconcile-unrealized-casts)" | + ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_MLIR_BINARY_DIR}/llvm-as | + ${LLVM_MLIR_BINARY_DIR}/llc -O3 -mtriple=${BUDDY_OPT_TRIPLE} + -mattr=${BUDDY_OPT_ATTR} -filetype=obj + -o ${BUDDY_BINARY_DIR}/../benchmarks/DeepLearning/Models/MobileNet-V3/forward_vectorization.o + DEPENDS ${BUDDY_BENCHMARK_DEEP_LEARNING_DIR}/Models/MobileNet-V3/forward.mlir + COMMENT "Building forward_vectorization.o" + VERBATIM) + +add_custom_command( + OUTPUT subgraph0_vectorization.o + COMMAND cat ${BUDDY_BENCHMARK_DEEP_LEARNING_DIR}/Models/MobileNet-V3/subgraph0.mlir | + sed -e {s/@subgraph0/@subgraph0_vectorization/} | + ${BUDDY_MLIR_BINARY_DIR}/buddy-opt 
+ -pass-pipeline + "builtin.module(func.func(tosa-to-linalg-named, tosa-to-arith, tosa-to-linalg, tosa-to-tensor))" | + ${BUDDY_MLIR_BINARY_DIR}/buddy-opt + -convert-elementwise-to-linalg + -func-bufferize-dynamic-offset + -arith-bufferize + -func-bufferize + -tensor-bufferize + -linalg-bufferize + -finalizing-bufferize + -batchmatmul-optimize + -convert-linalg-to-affine-loops + -lower-affine + -convert-vector-to-scf + -convert-scf-to-cf + -llvm-request-c-wrappers + -convert-vector-to-llvm + -convert-math-to-llvm + -convert-math-to-libm + -convert-arith-to-llvm + -convert-func-to-llvm + -expand-strided-metadata + -finalize-memref-to-llvm + -reconcile-unrealized-casts | + ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_MLIR_BINARY_DIR}/llvm-as | + ${LLVM_MLIR_BINARY_DIR}/llc -O3 -mtriple=${BUDDY_OPT_TRIPLE} + -mattr=${BUDDY_OPT_ATTR} -filetype=obj + -o ${BUDDY_BINARY_DIR}/../benchmarks/DeepLearning/Models/MobileNet-V3/subgraph0_vectorization.o + DEPENDS ${BUDDY_BENCHMARK_DEEP_LEARNING_DIR}/Models/MobileNet-V3/subgraph0.mlir + ${BUDDY_MLIR_BINARY_DIR}/buddy-opt + COMMENT "Building subgraph0_vectorization.o" + VERBATIM) + +add_library(MOBILENETV3_AUTO_VECTORIZATION STATIC subgraph0_auto_vectorization.o forward_auto_vectorization.o) +set_target_properties(MOBILENETV3_AUTO_VECTORIZATION PROPERTIES LINKER_LANGUAGE CXX) + +add_library(MOBILENETV3_VECTORIZATION STATIC subgraph0_vectorization.o forward_vectorization.o) +set_target_properties(MOBILENETV3_VECTORIZATION PROPERTIES LINKER_LANGUAGE CXX) + +add_executable(dl-model-mobileNetV3-benchmark + GoogleBenchmarkMain.cpp +) + +set_target_properties(dl-model-mobileNetV3-benchmark PROPERTIES + LINK_FLAGS "-static" +) + +set(BenchmarkTool GoogleBenchmark) + +if(CROSS_COMPILE_RVV) + set(BUDDY_LIB_DIR ${BUDDY_MLIR_CROSS_LIB_DIR}) +else() + set(BUDDY_LIB_DIR ${BUDDY_MLIR_LIB_DIR}) +endif() + +target_link_libraries(dl-model-mobileNetV3-benchmark + ${BenchmarkTool} + MOBILENETV3_AUTO_VECTORIZATION + 
MOBILENETV3_VECTORIZATION + ${BUDDY_LIB_DIR}/libStaticMLIRCRunnerUtils.a +) diff --git a/benchmarks/DeepLearning/Models/MobileNet-V3/GoogleBenchmarkMain.cpp b/benchmarks/DeepLearning/Models/MobileNet-V3/GoogleBenchmarkMain.cpp new file mode 100644 index 00000000..1375f8ea --- /dev/null +++ b/benchmarks/DeepLearning/Models/MobileNet-V3/GoogleBenchmarkMain.cpp @@ -0,0 +1,164 @@ +//===- GoogleBenchmarkMain.cpp --------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the benchmark for Mobilenet-V3 model. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include +#include +#include + +#define INPUT_N 1 +#define INPUT_C 3 +#define INPUT_H 224 +#define INPUT_W 224 +#define OUTPUT_N 1000 + +// Helper functions and variables. +namespace { +const std::string PASS = "\033[32mPASS\033[0m"; +const std::string FAIL = "\033[31mFAIL\033[0m"; + +constexpr size_t ParamsSize = 2554968; + +bool areArraysEqual(float array1[], float array2[], int size, + float epsilon = 0.0001) { + for (int i = 0; i < size; ++i) { + if (fabs(array1[i] - array2[i]) > epsilon) { + return false; + } + } + return true; +} +} // namespace + +namespace { + +// Declare the mobilenet C interface. 
+extern "C" {
+void _mlir_ciface_forward_auto_vectorization(MemRef<float, 2> *output,
+                                             MemRef<float, 1> *arg0,
+                                             MemRef<long long, 1> *arg1,
+                                             Img<float, 4> *input);
+
+void _mlir_ciface_forward_vectorization(MemRef<float, 2> *output,
+                                        MemRef<float, 1> *arg0,
+                                        MemRef<long long, 1> *arg1,
+                                        Img<float, 4> *input);
+}
+
+template <typename Func>
+void BM_MobileNet_V3(benchmark::State &state, Func func) {
+
+  // Define the sizes of the input and output tensors.
+  intptr_t sizesInput[4] = {INPUT_N, INPUT_C, INPUT_H, INPUT_W};
+  intptr_t sizesOutput[2] = {1, OUTPUT_N};
+
+  // Total number of input elements (N * C * H * W).
+  const int inputSize = INPUT_N * INPUT_C * INPUT_H * INPUT_W;
+
+  // Create input and output containers for the image and model output.
+  Img<float, 4> input(sizesInput);
+  MemRef<float, 2> output(sizesOutput);
+
+  // Fill the model parameters with constant values.
+  MemRef<float, 1> paramsContainerf32({ParamsSize}, 2.0);
+  MemRef<long long, 1> ParamsContainerInt64({34}, 1.0);
+
+  for (auto _ : state) {
+    func(&output, &paramsContainerf32, &ParamsContainerInt64, &input);
+  }
+}
+
+} // namespace
+
+// Register benchmarking function with different arguments.
+BENCHMARK_CAPTURE(BM_MobileNet_V3, BM_MobileNet_V3_Auto_Vectorization,
+                  _mlir_ciface_forward_auto_vectorization)
+    ->Unit(benchmark::kMillisecond);
+BENCHMARK_CAPTURE(BM_MobileNet_V3, BM_MobileNet_V3_Vectorization,
+                  _mlir_ciface_forward_vectorization)
+    ->Unit(benchmark::kMillisecond);
+
+/// Correctness Verification
+/// The verification does not affect the performance.
+/// - Set the scalar case as the criteria.
+/// - Input elements are random numbers.
+/// - Output elements are initialized to zero.
+/// - Compare the output of various optimizations with the scalar version to
+///   verify correctness.
+void verification() {
+  // Set the random number generator.
+  std::random_device rd;
+  std::mt19937 generator(rd());
+  std::uniform_real_distribution<float> distribution(0.0, 1.0);
+
+  // Define the sizes of the input and output tensors.
+  intptr_t sizesInput[4] = {INPUT_N, INPUT_C, INPUT_H, INPUT_W};
+  intptr_t sizesOutput[2] = {1, OUTPUT_N};
+
+  // Generate input memref container with random numbers.
+  const int inputSize = INPUT_N * INPUT_C * INPUT_H * INPUT_W;
+  float inputRand[inputSize];
+  for (int i = 0; i < inputSize; ++i) {
+    inputRand[i] = distribution(generator);
+  }
+
+  // Create input and output containers for the image and model output.
+  Img<float, 4> input(inputRand, sizesInput);
+  MemRef<float, 2> outputScalar(sizesOutput);
+  MemRef<float, 2> outputVectorization(sizesOutput);
+
+  // Fill the model parameters with constant values (no parameter file is read).
+  MemRef<float, 1> paramsContainerf32({ParamsSize}, 3.0);
+  MemRef<long long, 1> ParamsContainerInt64({34}, 2.0);
+
+  // Call the forward function of the model.
+  _mlir_ciface_forward_auto_vectorization(&outputScalar, &paramsContainerf32,
+                                          &ParamsContainerInt64, &input);
+  _mlir_ciface_forward_vectorization(&outputVectorization, &paramsContainerf32,
+                                     &ParamsContainerInt64, &input);
+
+  auto resultScalar = outputScalar.getData();
+  auto resultVectorization = outputVectorization.getData();
+
+  // Print the verification result.
+  std::cout << "-----------------------------------------------------------"
+            << std::endl;
+  std::cout << "Correctness Verification:" << std::endl;
+  std::cout << "Transform case: "
+            << (areArraysEqual(resultScalar, resultVectorization, OUTPUT_N)
+                    ? PASS
+                    : FAIL)
+            << std::endl;
+  std::cout << "-----------------------------------------------------------"
+            << std::endl;
+}
+
+int main(int argc, char **argv) {
+  // Run benchmark.
+  ::benchmark::Initialize(&argc, argv);
+  ::benchmark::RunSpecifiedBenchmarks();
+  // Run correctness verification.
+  verification();
+  return 0;
+}
diff --git a/benchmarks/DeepLearning/Models/MobileNet-V3/buddy_mobilenetv3_import.py b/benchmarks/DeepLearning/Models/MobileNet-V3/buddy_mobilenetv3_import.py
new file mode 100644
index 00000000..bad22e94
--- /dev/null
+++ b/benchmarks/DeepLearning/Models/MobileNet-V3/buddy_mobilenetv3_import.py
@@ -0,0 +1,60 @@
+# ===- buddy_mobilenetv3_import.py ---------------------------------------------
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ===---------------------------------------------------------------------------
+#
+# This is the MobileNet V3 model AOT importer.
+#
+# ===---------------------------------------------------------------------------
+
+import os
+
+from pathlib import Path
+import numpy as np
+import torch
+import torchvision.models as models
+from torch._inductor.decomposition import decompositions as inductor_decomp
+
+from buddy.compiler.frontend import DynamoCompiler
+from buddy.compiler.graph import GraphDriver
+from buddy.compiler.graph.transform import simply_fuse
+from buddy.compiler.ops import tosa
+
+
+model = models.mobilenet_v3_small(
+    weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
+)
+model = model.eval()
+
+# Initialize Dynamo Compiler with specific configurations as an importer.
+dynamo_compiler = DynamoCompiler(
+    primary_registry=tosa.ops_registry,
+    aot_autograd_decomposition=inductor_decomp,
+)
+data = torch.randn([1, 3, 224, 224])
+# Import the model into MLIR module and parameters.
+with torch.no_grad():
+    graphs = dynamo_compiler.importer(model, data)
+assert len(graphs) == 1
+graph = graphs[0]
+params = dynamo_compiler.imported_params[graph]
+pattern_list = [simply_fuse]
+graph.fuse_ops(pattern_list)
+driver = GraphDriver(graph)
+driver.subgraphs[0].lower_to_top_level_ir()
+path_prefix = os.path.dirname(os.path.abspath(__file__))
+with open(os.path.join(path_prefix, "subgraph0.mlir"), "w") as module_file:
+    print(driver.subgraphs[0]._imported_module, file=module_file)
+with open(os.path.join(path_prefix, "forward.mlir"), "w") as module_file:
+    print(driver.construct_main_graph(True), file=module_file)
diff --git a/benchmarks/DeepLearning/README.md b/benchmarks/DeepLearning/README.md
index c4a6950b..4f519ba3 100644
--- a/benchmarks/DeepLearning/README.md
+++ b/benchmarks/DeepLearning/README.md
@@ -1,5 +1,12 @@
 # Deep Learning Benchmark
 
+## Model Level Benchmark
+The table below lists the benchmark cases at the model level.
+
+| Name | Build Target | Introduction |
+| -------------- | ------------- | ------------- |
+| MobileNet-V3 | `ninja dl-model-mobileNetV3-benchmark` | This benchmark compares multiple optimization strategies targeting the MobileNet-V3 model. |
+
 ## Operation Level Benchmark
 The table below lists the benchmark cases at the operation level.
 
@@ -24,18 +31,29 @@ The table below lists the benchmark cases at the operation level.
 | Reduce Maxf | `ninja dl-op-reduce-maxf-benchmark` | This benchmark evaluates optimization strategies for the `reduce.maxf` operation. The benchmark size can be adjusted in [this file](./Ops/ReduceMaxfOp/GoogleBenchmarkMain.cpp). |
 | Softmax Exp Sum Div | `ninja dl-op-softmax-exp-sum-div-benchmark` | This benchmark evaluates optimization strategies for the `softmax.exp_sum_div` operation. The benchmark size can be adjusted in [this file](./Ops/SoftmaxExpSumDivOp/GoogleBenchmarkMain.cpp). |
 
+### Enter Python virtual environment
+We recommend using anaconda3 to create a Python virtual environment. Install the Python packages listed in `requirements.txt`, which follows the buddy-mlir requirements.
+```bash
+$ conda activate <your virtual environment name>
+$ cd buddy-benchmark
+$ pip install -r requirements.txt
+```
+
 ### Local Hardware Platform.
 
-1. Set the `buddy-mlir` toolchain:
+1. Set the `buddy-mlir` toolchain and the `PYTHONPATH` environment variable:
+Make sure that `PYTHONPATH` includes the directory of the LLVM/MLIR Python bindings and the directory of the Buddy MLIR Python packages.
 
-```
+```bash
 $ cd buddy-mlir/build
 $ export BUDDY_MLIR_BUILD_DIR=$PWD
+$ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build
+$ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH}
 ```
 
 2. Build benchmark for local platform:
 
-```
+```bash
 $ cd buddy-benchmark
 $ mkdir build && cd build
 $ cmake -G Ninja .. \
@@ -63,15 +81,19 @@ Follow the relevant [documentation](https://github.com/buddy-compiler/buddy-mlir
 
 1. Set variables for the toolchain:
 
-```
+```bash
 $ cd buddy-mlir/build
 $ export BUDDY_MLIR_BUILD_DIR=$PWD
+$ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build
+$ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH}
 $ export RISCV_GNU_TOOLCHAIN=${BUDDY_MLIR_BUILD_DIR}/thirdparty/riscv-gnu-toolchain
+$ cd ../build-cross-rv
+$ export BUDDY_MLIR_BUILD_CROSS_DIR=$PWD
 ```
 
 2. Build the benchmark for the target platform:
 
-```
+```bash
 $ cd buddy-benchmark
 $ mkdir build && cd build
 $ cmake -G Ninja .. \
@@ -82,7 +104,8 @@ $ cmake -G Ninja .. \
     -DCMAKE_SYSTEM_PROCESSOR=riscv \
     -DCMAKE_C_COMPILER=${RISCV_GNU_TOOLCHAIN}/bin/riscv64-unknown-linux-gnu-gcc \
     -DCMAKE_CXX_COMPILER=${RISCV_GNU_TOOLCHAIN}/bin/riscv64-unknown-linux-gnu-g++ \
-    -DBUDDY_MLIR_BUILD_DIR=${BUDDY_MLIR_BUILD_DIR}
+    -DBUDDY_MLIR_BUILD_DIR=${BUDDY_MLIR_BUILD_DIR} \
+    -DBUDDY_MLIR_BUILD_CROSS_DIR=${BUDDY_MLIR_BUILD_CROSS_DIR}
 $ ninja
 // For example:
 $ ninja dl-op-linalg-matmul-benchmark
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 00000000..9818b8ec
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,14 @@
+--pre --extra-index-url https://download.pytorch.org/whl/cpu
+torch == 2.1.2
+numpy < 2
+transformers == 4.33.1
+tokenizers == 0.13.3
+sentencepiece == 0.1.99
+accelerate
+protobuf
+pybind11 == 2.11.1
+torchvision == 0.16.2
+tabulate
+datasets
+soundfile
+librosa