From c4a6054332f88f293bea1a0a41fe9571f811036d Mon Sep 17 00:00:00 2001 From: fandaoyi Date: Mon, 24 Jun 2024 10:19:25 +0800 Subject: [PATCH] [muxi]Fdy/add mx rt diopi (#1267) * add mx impl --- .gitignore | 1 + adaptor/codegen/gen.py | 170 +++++++++++------- impl/CMakeLists.txt | 19 ++ impl/muxi/CMakeLists.txt | 75 ++++++++ impl/muxi/convert_config.yaml | 10 ++ impl/muxi/device_configs.py | 7 + impl/muxi/functions/functions.cpp | 32 ++++ .../flash-attention/CMakeLists.txt | 5 + impl/muxi/test/CMakeLists.txt | 49 +++++ impl/scripts/build_impl.sh | 10 +- impl/torch/CMakeLists.txt | 3 + impl/torch/cmake/TorchBaseFunc.cmake | 11 ++ 12 files changed, 326 insertions(+), 66 deletions(-) create mode 100644 impl/muxi/CMakeLists.txt create mode 100644 impl/muxi/convert_config.yaml create mode 100755 impl/muxi/device_configs.py create mode 100644 impl/muxi/functions/functions.cpp create mode 100644 impl/muxi/functions/functions_ext/flash-attention/CMakeLists.txt create mode 100644 impl/muxi/test/CMakeLists.txt create mode 100644 impl/torch/cmake/TorchBaseFunc.cmake diff --git a/.gitignore b/.gitignore index 5d2d5f3609..7a4a68c4d7 100644 --- a/.gitignore +++ b/.gitignore @@ -141,3 +141,4 @@ kernel_meta_* # generate files impl/*/test/export_functions.cpp proto/include/diopi/diopi_adaptors.hpp +diopilib \ No newline at end of file diff --git a/adaptor/codegen/gen.py b/adaptor/codegen/gen.py index 2d4e71cbc8..5069d370b2 100644 --- a/adaptor/codegen/gen.py +++ b/adaptor/codegen/gen.py @@ -68,6 +68,9 @@ "AdamW": ["param", "exp_avg", "exp_avg_sq", "max_exp_avg_sq"], } +def remap_impl_device(device): + return "cuda" if device == "torch" else device + def findAllFile(base: str) -> Iterator[str]: for root, ds, fs in os.walk(base): @@ -178,16 +181,21 @@ def prepare() -> Tuple[dict, str]: help="if functinos are implemented with plugin mode once more, then compile both of them.", default=False, ) + parser.add_argument( + '--base_device', + help="if set base device, generator scan base device dir first and then override them if functions defined in device dir", + default=None, + ) options = parser.parse_args() source = os.path.join(options.diopi_dir, "proto/include/diopi") config_path = os.path.join( options.diopi_dir, "impl/", options.config_device ) - device = ( - "cuda" if options.config_device == "torch" else options.config_device - ) - plugin = options.impl_plugin + device = remap_impl_device(options.config_device) + + impl_plugin = options.impl_plugin + base_device = options.base_device def create_if_not_exist(name): if not os.path.exists(name): @@ -197,7 +205,7 @@ def create_if_not_exist(name): dirs = dict( source=source, output_dir=options.output_dir, config_path=config_path ) - return dirs, device, plugin + return dirs, device, impl_plugin, base_device def get_func_info(content: list) -> Tuple[list, list, list, dict]: @@ -529,7 +537,7 @@ def memory_format_to_str(memory_format): def autogen_op_adaptor( - op_configs: dict, device: str, func_infos: dict, impl_funcs: dict, impl_plugin: bool, plugin_config: dict + op_configs: dict, device: str, func_infos: dict, impl_funcs: dict, func_device_map: dict ) -> list: adaptors_code = [] cast = ( @@ -546,7 +554,7 @@ def autogen_op_adaptor( device_mapping = "composite" else: continue - device_mapping = plugin_config.get(func, device) + device_mapping = func_device_map.get(func, device) if ( ( func not in op_configs.keys() @@ -700,8 +708,73 @@ def get_composite_funcs_declaration( return composite_funcs_decl +def get_all_impl_functions(impl_base_dir) -> dict: + # get the 
implemented functions + impl_func_dir = os.path.join(impl_base_dir, "functions") + impl_func_ext_dir = os.path.join(impl_base_dir, "functions_ext") + impl_functions = obtain_impl_func(impl_func_dir) + impl_functions_ext = obtain_impl_func(impl_func_ext_dir) + impl_functions.update(impl_functions_ext) + return impl_functions + + +def gen_ascend_impl_plugin_funcs(dirs: dict, impl_base_dir: str, impl_functions: dict): + ascend_config_path = os.path.join(dirs.get("config_path"), "../ascend_npu/ascend_config.yaml") + try: + with open(ascend_config_path, "r") as f: + ascend_configs = yaml.safe_load(f) + except Exception as e: + print(e) + return + func_device_map = ascend_func_impl_config(ascend_configs) + + impl_plugin_dir = os.path.join(impl_base_dir, "../ascend_npu/diopi_impl") + impl_npu_functions = obtain_impl_func(impl_plugin_dir) + + #check config items all implemented + not_impled = [] + for op in ascend_configs['ascend']: + if op not in impl_functions.keys(): + not_impled.append(op) + if not_impled != []: + print(f"[GenAscendConfig] {not_impled} not implemented in ascend namespace") + return + not_impled.clear() + for op in ascend_configs['ascend_npu']: + if op not in impl_npu_functions.keys(): + not_impled.append(op) + if not_impled != []: + print(f"[GenAscendConfig] {not_impled} not implemented in ascend_npu namespace.") + return + + funcs_info, funcs_decl_raw = get_functions_support(dirs.get("source")) + funcs_npu_decl = get_impl_funcs_declaration( + funcs_decl_raw, funcs_info, impl_npu_functions.keys(), True, + ) + return funcs_npu_decl, impl_npu_functions, func_device_map + +def gen_base_device_impl_funcs(device: str, base_device: str, dirs: dict, impl_functions: dict): + base_device_impl_dir = os.path.join(os.path.dirname(dirs.get("config_path")), base_device) + impl_basedev_functions = get_all_impl_functions(base_device_impl_dir) + # remove ops already exist in device impl. 
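+    # (i.e. keep only base-device ops that the target device does not implement itself;
+    #  for muxi this means falling back to the torch/cuda kernels for any op missing under impl/muxi)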
+ impl_basedev_functions = {op: args for op, args in impl_basedev_functions.items() if op not in impl_functions} + + funcs_info, funcs_decl_raw = get_functions_support(dirs.get("source")) + func_base_decl = get_impl_funcs_declaration( + funcs_decl_raw, funcs_info, impl_basedev_functions.keys(), True, + ) + + func_device_map = {} + for op in impl_basedev_functions.keys(): + func_device_map[op] = remap_impl_device(base_device) + for op in impl_functions.keys(): + func_device_map[op] = device + + return func_base_decl, impl_basedev_functions, func_device_map + def gen_autogen_operators( - dirs: dict, device: str, adaptor_fm: FileManager, impl_plugin: bool + dirs: dict, device: str, adaptor_fm: FileManager, impl_plugin: bool, + base_device: str, ) -> None: config_file_path = os.path.join( dirs.get("config_path"), "convert_config.yaml" @@ -713,46 +786,8 @@ def gen_autogen_operators( print(e) return - if impl_plugin: - ascend_config_path = os.path.join(dirs.get("config_path"), "../ascend_npu/ascend_config.yaml") - try: - with open(ascend_config_path, "r") as f: - ascend_configs = yaml.safe_load(f) - except Exception as e: - print(e) - return - ascend_impl_configs = ascend_func_impl_config(ascend_configs) - else: - ascend_impl_configs = {} - - # get the implemented functions impl_base_dir = os.path.dirname(config_file_path) - impl_func_dir = os.path.join(impl_base_dir, "functions") - impl_func_ext_dir = os.path.join(impl_base_dir, "functions_ext") - impl_functions = obtain_impl_func(impl_func_dir) - impl_functions_ext = obtain_impl_func(impl_func_ext_dir) - impl_functions.update(impl_functions_ext) - - if impl_plugin: - impl_plugin_dir = os.path.join(impl_base_dir, "../ascend_npu/diopi_impl") - impl_npu_functions = obtain_impl_func(impl_plugin_dir) - - #check config items all implemented - not_impled = [] - for op in ascend_configs['ascend']: - if op not in impl_functions.keys(): - not_impled.append(op) - if not_impled != []: - print(f"[GenAscendConfig] {not_impled} not implemented in ascend namespace") - return - not_impled.clear() - for op in ascend_configs['ascend_npu']: - if op not in impl_npu_functions.keys(): - not_impled.append(op) - if not_impled != []: - print(f"[GenAscendConfig] {not_impled} not implemented in ascend_npu namespace.") - return - + impl_functions = get_all_impl_functions(impl_base_dir) impl_funcs = impl_functions.keys() # generate func information and declarations by scanning functions.h @@ -765,28 +800,38 @@ def gen_autogen_operators( funcs_decl = get_impl_funcs_declaration( funcs_decl_raw, funcs_info, impl_funcs, impl_plugin ) - composite_funcs_decl = get_composite_funcs_declaration( - funcs_decl_raw, funcs_info, impl_funcs, op_configs - ) - impl_functions_content = [OT.impl_declaration_content_template.substitute(dict( device=device, impl_declaration=list(funcs_decl.values()), ))] + composite_funcs_decl = get_composite_funcs_declaration( + funcs_decl_raw, funcs_info, impl_funcs, op_configs + ) + impl_functions_content.append(OT.impl_declaration_content_template.substitute(dict( + device='composite', + impl_declaration=list(composite_funcs_decl.values()), + ))) + + func_device_map = {} + if base_device: + funcs_basedev_decl, impl_basedev_functions, func_device_map = gen_base_device_impl_funcs(device, + base_device, dirs, impl_functions) + impl_functions_content.append(OT.impl_declaration_content_template.substitute(dict( + device= remap_impl_device(base_device), + impl_declaration=list(funcs_basedev_decl.values()), + ))) + impl_funcs = {*impl_funcs, 
*impl_basedev_functions.keys()} + if impl_plugin: - funcs_npu_decl = get_impl_funcs_declaration( - funcs_decl_raw, funcs_info, impl_npu_functions.keys(), impl_plugin - ) + funcs_npu_decl, impl_npu_functions, func_device_map = gen_ascend_impl_plugin_funcs( + dirs, impl_base_dir, impl_functions) impl_functions_content.append(OT.impl_declaration_content_template.substitute(dict( device=device + '_npu', impl_declaration=list(funcs_npu_decl.values()), ))) + impl_funcs = {*impl_funcs, *impl_npu_functions.keys()} - impl_functions_content.append(OT.impl_declaration_content_template.substitute(dict( - device='composite', - impl_declaration=list(composite_funcs_decl.values()), - ))) adaptor_fm.write( "impl_functions.hpp", @@ -796,12 +841,9 @@ def gen_autogen_operators( ), ) - if impl_plugin: - impl_funcs = {*impl_funcs, *impl_npu_functions.keys()} - # generate adaptor implementation codes adaptors_code = autogen_op_adaptor( - op_configs, device, funcs_info, impl_funcs, impl_plugin, ascend_impl_configs + op_configs, device, funcs_info, impl_funcs, func_device_map ) adaptor_fm.write( @@ -817,10 +859,10 @@ def declare_outputs(adaptor_fm: FileManager) -> None: def gen_all_codes() -> None: - dirs, device, impl_plugin = prepare() + dirs, device, impl_plugin, base_device = prepare() adaptor_fm = FileManager(dirs.get("output_dir", ".")) declare_outputs(adaptor_fm) - gen_autogen_operators(dirs, device, adaptor_fm, impl_plugin) + gen_autogen_operators(dirs, device, adaptor_fm, impl_plugin, base_device) adaptor_fm.check_all_files_written() diff --git a/impl/CMakeLists.txt b/impl/CMakeLists.txt index 0ba79a3404..d4b45462a6 100644 --- a/impl/CMakeLists.txt +++ b/impl/CMakeLists.txt @@ -47,6 +47,7 @@ set(LIBRARY_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/lib) list(APPEND IMPL_CUDA "CUDA" "cuda") list(APPEND IMPL_TORCH "TORCH" "LIBTORCH" "ATEN" "PYTORCH" "torch" "libtorch" "aten" "pytorch" "PyTorch") +list(APPEND IMPL_MUXI "MUXI" "muxi") list(APPEND IMPL_TOPS "TOPS" "tops" "TOPSRIDER" "topsrider") list(APPEND IMPL_CAMB_TORCH "CAMB_PYTORCH" "camb_pytorch") list(APPEND IMPL_CAMB "CAMB" "camb") @@ -61,6 +62,8 @@ elseif(${IMPL_OPT} IN_LIST IMPL_TOPS) add_subdirectory(topsrider) elseif (${IMPL_OPT} IN_LIST IMPL_TORCH) add_subdirectory(torch) +elseif (${IMPL_OPT} IN_LIST IMPL_MUXI) + add_subdirectory(muxi) elseif (${IMPL_OPT} IN_LIST IMPL_CAMB_TORCH) add_subdirectory(camb_pytorch) elseif (${IMPL_OPT} IN_LIST IMPL_CAMB) @@ -77,6 +80,22 @@ else() message(WARNING "No implementation module is compiled, cmake requires option -DIMPL_OPT=CUDA or TORCH") endif() +# 1.the lib ${DEVICEIMPL} in which all exported symbols are 'weak' can be considered as 'no-needed' lib. +# some compilers force link such libs by default, but others having 'as-needed' default link-config may +# throw away these libs. so we manually add '-no-as-needed' here to guarantee linking ${DEVICEIMPL}. +# eg: if you compiler don't link 'no-needed' libs by default, please use 'g++ -dumpspecs' and see '*link:' section +# to check if it contains policy like '%{!fsanitize=*:--as-needed}' or other policy having '--as-needed' set. +# you can change compiler's default spec by typing 'gcc -specs=./new.specs' but it's hard to use. +# Supplementary: https://gcc.gnu.org/onlinedocs/gcc/Spec-Files.html + +# 2. when the code below adding "-no-as-needed" opt to link.txt, the opt isn't be added exactly before the place +# ${DEVICEIMPL} is linked but as a link option before any link-items. 
if another link-item change linking-policy +# as "-Wl,--no-as-needed,\"\$\" -Wl,--as-needed" and ${DEVICEIMPL} is linked just after +# this item; it will still be linked by -as-needed and finally be throw away by ld and cause error !! +# so if this error happens, please add link-option "-Wl,-no-as-needed" to the lib which link ${DEVICEIMPL}. + +target_link_options(${DEVICEIMPL} INTERFACE "LINKER:-no-as-needed") + # install install(DIRECTORY ${DIOPI_IMPL_DIR}/../proto/include/ TYPE INCLUDE) install(FILES lib/lib${DEVICEIMPL}.so TYPE LIB) diff --git a/impl/muxi/CMakeLists.txt b/impl/muxi/CMakeLists.txt new file mode 100644 index 0000000000..eee0940191 --- /dev/null +++ b/impl/muxi/CMakeLists.txt @@ -0,0 +1,75 @@ +project(muxi_impl) + +# muxi torch config +add_compile_definitions(USE_MACA=1) +set(USE_MACA ON) + +set(BASE_TORCH_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../torch") +include(${BASE_TORCH_DIR}/cmake/TorchBaseFunc.cmake) +InitFindTorch() + +find_package(Torch REQUIRED) +if (Torch_FOUND) + message(STATUS "TORCH_CXX_FLAGS: ${TORCH_CXX_FLAGS}") + message(STATUS "TORCH_LIBRARIES: ${TORCH_LIBRARIES}") + + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") + add_definitions(-DTORCH_VERSION_MAJOR=${Torch_VERSION_MAJOR}) + add_definitions(-DTORCH_VERSION_MINOR=${Torch_VERSION_MINOR}) + add_definitions(-DTORCH_VERSION_PATCH=${Torch_VERSION_PATCH}) + add_definitions(-DTORCH_VERSION=${Torch_VERSION}) + message(STATUS "Found Torch Version: ${Torch_VERSION}") +endif() + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") + +file(GLOB REAL_IMPL_SRC + ${BASE_TORCH_DIR}/functions/error.cpp + ${BASE_TORCH_DIR}/functions/functions.cpp + + ${BASE_TORCH_DIR}/functions/functions_lightllm.cpp + ${BASE_TORCH_DIR}/functions/functions_mmcv.cpp + ${BASE_TORCH_DIR}/helper.cpp + ${BASE_TORCH_DIR}/functions/functions_mmcv/*.cu + + ${BASE_TORCH_DIR}/functions/functions_ext.cpp + ${BASE_TORCH_DIR}/functions/functions_ext/*.cu + + # mx cpp + functions/functions.cpp +) + +# adaptor +set(USE_ADAPTOR ON) +if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/convert_config.yaml") + message(FATAL_ERROR "convert_config.yaml doesn't exist.") +endif() + +if(USE_ADAPTOR) + # dependency + file(GLOB ADAPTOR_TEMPLATE_CODE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${ADAPTOR_DIR}/codegen/*.py) + add_custom_target(adaptor_gen_dependency DEPENDS ${ADAPTOR_TEMPLATE_CODE}) + + set(ADAPTOR_CSRC_PATH "${ADAPTOR_DIR}/csrc") + set(GEN_FILES ${ADAPTOR_CSRC_PATH}/diopi_adaptor.cpp ${ADAPTOR_CSRC_PATH}/impl_functions.hpp) + add_custom_target(adaptor_code_gen + COMMAND python3 ${ADAPTOR_DIR}/codegen/gen.py --diopi_dir=${DIOPI_IMPL_DIR}/../ --output_dir=${ADAPTOR_CSRC_PATH} + --config_device=muxi --base_device=torch + BYPRODUCTS ${GEN_FILES} + DEPENDS adaptor_gen_dependency) + list(APPEND REAL_IMPL_SRC ${ADAPTOR_CSRC_PATH}/convert.cpp ${ADAPTOR_CSRC_PATH}/diopi_adaptor.cpp ${ADAPTOR_CSRC_PATH}/composite_ops.cpp) +endif() + +cuda_add_library(${DEVICEIMPL} SHARED ${REAL_IMPL_SRC}) +target_link_libraries(${DEVICEIMPL} ${TORCH_LIBRARIES}) +add_subdirectory(functions/functions_ext/flash-attention) +target_link_libraries(${DEVICEIMPL} diopi_torch_ext_flash_attn) +target_include_directories(${DEVICEIMPL} PRIVATE ${BASE_TORCH_DIR}) + +if(USE_ADAPTOR) + add_dependencies(${DEVICEIMPL} adaptor_code_gen) +endif() + +if (TEST) + add_subdirectory(test) +endif() diff --git a/impl/muxi/convert_config.yaml b/impl/muxi/convert_config.yaml new file mode 100644 index 0000000000..fb7787cd14 --- /dev/null +++ b/impl/muxi/convert_config.yaml @@ -0,0 +1,10 @@ +- diopiLinear: + 
supportComposite: true + +- diopiRMSNorm: + tensor_dtype: + inv_rms: (float16)->float32 + +- diopiRMSNormBackward: + tensor_dtype: + inv_rms: (float16)->float32 \ No newline at end of file diff --git a/impl/muxi/device_configs.py b/impl/muxi/device_configs.py new file mode 100755 index 0000000000..ffaf933c4b --- /dev/null +++ b/impl/muxi/device_configs.py @@ -0,0 +1,7 @@ +# Copyright (c) 2023, DeepLink. +import numpy as np +from skip import Skip + +device_configs = { + +} diff --git a/impl/muxi/functions/functions.cpp b/impl/muxi/functions/functions.cpp new file mode 100644 index 0000000000..37a3409d72 --- /dev/null +++ b/impl/muxi/functions/functions.cpp @@ -0,0 +1,32 @@ +/** + * @file + * @author DeepLink + * @copyright (c) 2024, DeepLink. + */ +#include +#include +#include +#include +#include +#include + +#include "helper.hpp" + +static const char* name = "MuxiDevice"; +const char* diopiGetVendorName() { return name; } + +namespace impl { +namespace muxi { + +diopiError_t diopiCat(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t* tensors, int64_t insNum, int64_t dim) { + impl::aten::setCurStream(ctx); + DIOPI_CHECK_PTR(tensors); + auto tensorList = impl::aten::buildATenList(tensors, insNum); + auto atOut = impl::aten::buildATen(out); + at::cat_out(atOut, tensorList, dim); + + return diopiSuccess; +} + +} // namespace muxi +} // namespace impl diff --git a/impl/muxi/functions/functions_ext/flash-attention/CMakeLists.txt b/impl/muxi/functions/functions_ext/flash-attention/CMakeLists.txt new file mode 100644 index 0000000000..96a86507de --- /dev/null +++ b/impl/muxi/functions/functions_ext/flash-attention/CMakeLists.txt @@ -0,0 +1,5 @@ +message(STATUS "flash-attention DISABLED") +add_library(diopi_torch_ext_flash_attn INTERFACE) + +target_include_directories(diopi_torch_ext_flash_attn INTERFACE + ${BASE_TORCH_DIR}/functions/functions_ext/flash-attention/include) diff --git a/impl/muxi/test/CMakeLists.txt b/impl/muxi/test/CMakeLists.txt new file mode 100644 index 0000000000..4b12592047 --- /dev/null +++ b/impl/muxi/test/CMakeLists.txt @@ -0,0 +1,49 @@ +set(RT_PYBIND export_runtime) +set(FUNC_PYBIND export_functions) + +add_subdirectory(${CMAKE_SOURCE_DIR}/third_party/pybind11 build) +set(DIOPI_TEST_DIR "${CMAKE_SOURCE_DIR}/../diopi_test") + +include_directories(SYSTEM "${PROJECT_SOURCE_DIR}/../third_party/pybind11/include") +include_directories(SYSTEM "${DIOPI_TEST_DIR}/diopi_stub/include") +set(TEST_CSRC_PATH "${DIOPI_TEST_DIR}/diopi_stub/csrc") + + +# diopt rt test. 
+set(RUNTIME_SRC + ${TEST_CSRC_PATH}/litert.cpp + # use torch cuda runtime + ${BASE_TORCH_DIR}/test/conform_test.cpp +) +add_library(diopirt SHARED ${RUNTIME_SRC}) +message(STATUS "test diopirt CUDA_LIBRARIES is:" ${CUDA_LIBRARIES}) +target_link_libraries(diopirt ${CUDA_LIBRARIES}) + +# rt test py export +set(RT_PYBIND_SRC + ${TEST_CSRC_PATH}/export_runtime.cpp +) +pybind11_add_module(${RT_PYBIND} SHARED ${RT_PYBIND_SRC}) +target_link_libraries(${RT_PYBIND} PRIVATE diopirt) + +# diopi op func test py export +file(GLOB TEST_TEMPLATE_CODE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${DIOPI_TEST_DIR}/diopi_stub/codegen/*.py) +add_custom_target(test_gen_dependency DEPENDS ${TEST_TEMPLATE_CODE}) +set(GEN_FILES ${TEST_CSRC_PATH}/export_functions.cpp) +set(OP_GEN_PATH "${DIOPI_TEST_DIR}/diopi_stub/codegen") +add_custom_target(test_code_gen ALL + COMMAND python3 ${OP_GEN_PATH}/gen.py --device=muxi + BYPRODUCTS ${GEN_FILES} + DEPENDS test_gen_dependency) + +set(FUNCTIONS_SRC ${GEN_FILES}) + +pybind11_add_module(${FUNC_PYBIND} SHARED ${FUNCTIONS_SRC}) +target_link_libraries(${FUNC_PYBIND} PRIVATE diopirt ${DEVICEIMPL}) +add_dependencies(${FUNC_PYBIND} test_code_gen) + +file(MAKE_DIRECTORY ${CMAKE_SOURCE_DIR}/../diopi_test/python) +add_custom_target(python_copy ALL + COMMAND ln -f ${LIBRARY_OUTPUT_PATH}/$ ${CMAKE_SOURCE_DIR}/../diopi_test/python/diopilib + COMMAND ln -f ${LIBRARY_OUTPUT_PATH}/$ ${CMAKE_SOURCE_DIR}/../diopi_test/python/diopilib + DEPENDS ${FUNC_PYBIND} ${RT_PYBIND}) diff --git a/impl/scripts/build_impl.sh b/impl/scripts/build_impl.sh index bf93da360d..2c91e9a08b 100644 --- a/impl/scripts/build_impl.sh +++ b/impl/scripts/build_impl.sh @@ -26,15 +26,21 @@ case $1 in torch) mkdir -p build && cd build cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DIMPL_OPT=torch -DCMAKE_BUILD_TYPE=Release -DTEST=ON \ - -DENABLE_COVERAGE=${USE_COVERAGE} -DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)') + -DENABLE_COVERAGE=${USE_COVERAGE} make -j8 ;; torch_dyload) mkdir -p build && cd build cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DIMPL_OPT=torch -DCMAKE_BUILD_TYPE=Release -DDYLOAD=ON -DTEST=ON \ - -DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)') && make -j8 + && make -j8 mkdir -p ${DIOPI_TEST_PATH}/lib && ln -sf ${CURRENT_DIR}/../lib/libdiopi_real_impl.so ${DIOPI_TEST_PATH}/lib ;; + muxi) + mkdir -p build && cd build + cmake_maca .. -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DIMPL_OPT=muxi -DCMAKE_BUILD_TYPE=Release -DTEST=ON \ + -DENABLE_COVERAGE=${USE_COVERAGE} + make_maca -j8 + ;; camb_pytorch) mkdir -p build && cd build cmake .. 
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DIMPL_OPT=camb_pytorch -DCMAKE_BUILD_TYPE=Release -DTEST=ON \ diff --git a/impl/torch/CMakeLists.txt b/impl/torch/CMakeLists.txt index 8a3b39afbd..c97c79fb95 100644 --- a/impl/torch/CMakeLists.txt +++ b/impl/torch/CMakeLists.txt @@ -3,6 +3,9 @@ project(torch_impl) option(HIP "Whether to use HIP when available" OFF) +include(cmake/TorchBaseFunc.cmake) +InitFindTorch() + find_package(Torch 2.0 REQUIRED) if (Torch_FOUND) message(STATUS "TORCH_CXX_FLAGS: ${TORCH_CXX_FLAGS}") diff --git a/impl/torch/cmake/TorchBaseFunc.cmake b/impl/torch/cmake/TorchBaseFunc.cmake new file mode 100644 index 0000000000..a8c925aa07 --- /dev/null +++ b/impl/torch/cmake/TorchBaseFunc.cmake @@ -0,0 +1,11 @@ + +macro(InitFindTorch) + execute_process( + COMMAND sh -c "python -c 'import torch;print(torch.utils.cmake_prefix_path)'" + OUTPUT_VARIABLE DIOPI_TORCH_CMAKE_PREFIX + OUTPUT_STRIP_TRAILING_WHITESPACE) + message(STATUS "DIOPI_TORCH_CMAKE_PREFIX:${DIOPI_TORCH_CMAKE_PREFIX}") + if(DIOPI_TORCH_CMAKE_PREFIX) + list(APPEND CMAKE_PREFIX_PATH ${DIOPI_TORCH_CMAKE_PREFIX}) + endif() +endmacro()
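
For reference, a minimal sketch of how the muxi backend added by this patch might be built and its adaptor code generated, assuming the MACA toolchain wrappers (cmake_maca, make_maca) are on PATH and the commands are run from the repository's impl/ directory; <ADAPTOR_CSRC_PATH> is a placeholder for the adaptor csrc output directory that the build normally supplies, and the codegen flags mirror the add_custom_target in impl/muxi/CMakeLists.txt:

    # build the muxi impl (wraps cmake_maca .. -DIMPL_OPT=muxi ... && make_maca -j8,
    # see the muxi case added to impl/scripts/build_impl.sh)
    sh scripts/build_impl.sh muxi

    # equivalent adaptor codegen step run by CMake: scan impl/torch first (--base_device),
    # then let ops implemented under impl/muxi override the torch/cuda fallbacks
    python3 ../adaptor/codegen/gen.py \
        --diopi_dir=.. \
        --output_dir=<ADAPTOR_CSRC_PATH> \
        --config_device=muxi \
        --base_device=torch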