forked from DeepLink-org/DIOPI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[muxi]Fdy/add mx rt diopi (DeepLink-org#1267)
* add mx impl
- Loading branch information
Showing
12 changed files
with
326 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
project(muxi_impl) | ||
|
||
# muxi torch config | ||
add_compile_definitions(USE_MACA=1) | ||
set(USE_MACA ON) | ||
|
||
set(BASE_TORCH_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../torch") | ||
include(${BASE_TORCH_DIR}/cmake/TorchBaseFunc.cmake) | ||
InitFindTorch() | ||
|
||
find_package(Torch REQUIRED) | ||
if (Torch_FOUND) | ||
message(STATUS "TORCH_CXX_FLAGS: ${TORCH_CXX_FLAGS}") | ||
message(STATUS "TORCH_LIBRARIES: ${TORCH_LIBRARIES}") | ||
|
||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") | ||
add_definitions(-DTORCH_VERSION_MAJOR=${Torch_VERSION_MAJOR}) | ||
add_definitions(-DTORCH_VERSION_MINOR=${Torch_VERSION_MINOR}) | ||
add_definitions(-DTORCH_VERSION_PATCH=${Torch_VERSION_PATCH}) | ||
add_definitions(-DTORCH_VERSION=${Torch_VERSION}) | ||
message(STATUS "Found Torch Version: ${Torch_VERSION}") | ||
endif() | ||
|
||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") | ||
|
||
file(GLOB REAL_IMPL_SRC | ||
${BASE_TORCH_DIR}/functions/error.cpp | ||
${BASE_TORCH_DIR}/functions/functions.cpp | ||
|
||
${BASE_TORCH_DIR}/functions/functions_lightllm.cpp | ||
${BASE_TORCH_DIR}/functions/functions_mmcv.cpp | ||
${BASE_TORCH_DIR}/helper.cpp | ||
${BASE_TORCH_DIR}/functions/functions_mmcv/*.cu | ||
|
||
${BASE_TORCH_DIR}/functions/functions_ext.cpp | ||
${BASE_TORCH_DIR}/functions/functions_ext/*.cu | ||
|
||
# mx cpp | ||
functions/functions.cpp | ||
) | ||
|
||
# adaptor | ||
set(USE_ADAPTOR ON) | ||
if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/convert_config.yaml") | ||
message(FATAL_ERROR "convert_config.yaml doesn't exist.") | ||
endif() | ||
|
||
if(USE_ADAPTOR) | ||
# dependency | ||
file(GLOB ADAPTOR_TEMPLATE_CODE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${ADAPTOR_DIR}/codegen/*.py) | ||
add_custom_target(adaptor_gen_dependency DEPENDS ${ADAPTOR_TEMPLATE_CODE}) | ||
|
||
set(ADAPTOR_CSRC_PATH "${ADAPTOR_DIR}/csrc") | ||
set(GEN_FILES ${ADAPTOR_CSRC_PATH}/diopi_adaptor.cpp ${ADAPTOR_CSRC_PATH}/impl_functions.hpp) | ||
add_custom_target(adaptor_code_gen | ||
COMMAND python3 ${ADAPTOR_DIR}/codegen/gen.py --diopi_dir=${DIOPI_IMPL_DIR}/../ --output_dir=${ADAPTOR_CSRC_PATH} | ||
--config_device=muxi --base_device=torch | ||
BYPRODUCTS ${GEN_FILES} | ||
DEPENDS adaptor_gen_dependency) | ||
list(APPEND REAL_IMPL_SRC ${ADAPTOR_CSRC_PATH}/convert.cpp ${ADAPTOR_CSRC_PATH}/diopi_adaptor.cpp ${ADAPTOR_CSRC_PATH}/composite_ops.cpp) | ||
endif() | ||
|
||
cuda_add_library(${DEVICEIMPL} SHARED ${REAL_IMPL_SRC}) | ||
target_link_libraries(${DEVICEIMPL} ${TORCH_LIBRARIES}) | ||
add_subdirectory(functions/functions_ext/flash-attention) | ||
target_link_libraries(${DEVICEIMPL} diopi_torch_ext_flash_attn) | ||
target_include_directories(${DEVICEIMPL} PRIVATE ${BASE_TORCH_DIR}) | ||
|
||
if(USE_ADAPTOR) | ||
add_dependencies(${DEVICEIMPL} adaptor_code_gen) | ||
endif() | ||
|
||
if (TEST) | ||
add_subdirectory(test) | ||
endif() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
- diopiLinear: | ||
supportComposite: true | ||
|
||
- diopiRMSNorm: | ||
tensor_dtype: | ||
inv_rms: (float16)->float32 | ||
|
||
- diopiRMSNormBackward: | ||
tensor_dtype: | ||
inv_rms: (float16)->float32 |
Oops, something went wrong.