diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index db8e01997..9f31ddf6d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -353,21 +353,6 @@ jobs: source scripts/ci/ascend/ci_ascend_env.sh bash scripts/ci/ascend/ci_ascend_script.sh build_dipu \ || ( cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/ && rm -rf ${GITHUB_JOB} && exit 1 ) - - Build-Ascend-910b-with-autocompare: - name: Build-dipu-ascend-910b-with-autocompare - needs: [Build-PyTorch-For-Ascend-910b] - runs-on: tps-ascend-ci-910b - steps: - - name: Build dipu - run: | - set -ex - export USE_COVERAGE=ON - export USE_AUTOCOMPARE=ON - cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/ && rm -rf ${GITHUB_JOB} && cp -R source ${GITHUB_JOB} && cd ${GITHUB_JOB}/dipu - source scripts/ci/ascend/ci_ascend_env.sh - bash scripts/ci/ascend/ci_ascend_script.sh build_dipu \ - || ( cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/ && rm -rf ${GITHUB_JOB} && exit 1 ) Test-Ascend-910b: name: Test-dipu-ascend-910b diff --git a/dipu/QuickStart.md b/dipu/QuickStart.md index f2b91ec51..7e19ee8d2 100644 --- a/dipu/QuickStart.md +++ b/dipu/QuickStart.md @@ -158,9 +158,10 @@ sh ./tests/python/run_tests.sh ### 算子库拓展功能 -#### 算子 Fallback +#### 算子Fallback功能 -Fallback 给定算子: +Fallback指的是使用算子的CPU实现,而非设备实现。 +Fallback给定算子: ```bash export DIPU_FORCE_FALLBACK_OPS_LIST=add.out,conv2d @@ -181,20 +182,13 @@ export DIPU_FORCE_FALLBACK_OPS_LIST='.*' python -c "import torch_dipu" ``` -#### 算子精度自动对比功能介绍 +#### 算子精度自动对比功能 -由于该功能默认不开启,使用该功能时需要打开该功能并重新编译DIPU。 - -可以通过设置环境变量USE_AUTOCOMPARE=ON,来开启该功能,然后需要重新编译DIPU。 - -```shell -export USE_AUTOCOMPARE=ON -``` - -以上方法是对所有算子开启自动精度对比。如果只需要对特定算子做精度对比,也可只给需要的算子做精度对比,只需要在相关的配置文件(如 `dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml`)给相应的算子添加 `autocompare: True` 即可。 +算子精度自动对比功能(autocompare)用于确保算子计算结果的正确性,通过将设备参数拷贝到CPU上,对比CPU和设备的计算结果来判断精度是否达标。以下是算子精度自动对比功能的使用例子: ```shell -$ unset DIPU_FORCE_FALLBACK_OPS_LIST # 主要是确保要比较的算子没有强制 fallback 到 cpu, 可选 +$ unset DIPU_FORCE_FALLBACK_OPS_LIST # 主要是确保要比较的算子没有强制 fallback 到 CPU, 可选 +$ export DIPU_AUTOCOMPARE_OPS_LIST=add.out # 对add.out算子开启autocompare功能 $ python >>> import torch >>> import torch_dipu @@ -220,11 +214,33 @@ autocompare: add.out other: allclose >>> ``` -可以看到,CPU 计算结果与设备计算结果 `allclose`,也能看到 CPU 和设备计算结果的 `shape`、`dtype` 等信息。特别的,需要注意以下几个问题: +可以看到,输出包括 CPU 和设备计算结果的 `shape`、`stride`、`dtype` 等信息, 最终结果是CPU和设备的self和out都是allclose的。 + +##### 算子精度自动对比功能的设置 + +算子精度自动对比功能默认不开启,可以设置环境变量`DIPU_AUTOCOMPARE_OPS_LIST`来控制该功能,在开启算子自动对比功能前,必须unset `DIPU_FORCE_FALLBACK_OPS_LIST` + +- 可以通过设置环境变量`DIPU_AUTOCOMPARE_OPS_LIST='.*'`,开启全局的精度对比,这种情况下所有调用的算子都会进行精度对比。 + +```shell +# 开启全局的算子精度自动对比功能 +export DIPU_AUTOCOMPARE_OPS_LIST='.*' +``` + +- 可以设置`DIPU_AUTOCOMPARE_OPS_LIST`来指定算子开启自动精度对比,支持正则表达式匹配,也可以指定多个算子开启自动精度对比。算子名可以参考[diopi_functions.yaml](https://github.com/DeepLink-org/deeplink.framework/blob/main/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml)。 + +```shell +# 指定匹配add.*?的算子进行自动精度对比 +export DIPU_AUTOCOMPARE_OPS_LIST=add.*? +# 指定add.out、sub.out算子进行自动精度对比 +export DIPU_AUTOCOMPARE_OPS_LIST="add.out, sub.out" +``` + +NOTE: -1. `dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml` 中配置了 `autograd:True` 的算子 (`cross_entropy_loss`、`conv2d`、`dropout`、`dropout_`、`linear`) 暂不支持 *backward* 的精度自动对比。如模型精度对不齐,可根据需要先将这几个算子 fallback 到 CPU 来确定问题。 -2. 随机数生成相关的算子(`dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml` 中配置了 `autocompare:False`)没有做 `autocompare`,因为结果总是 `not_allclose`。 -3. 对输入做检查是确保算子输入不被意外修改。 +1. 
部分算子并不支持自动精度对比功能,可以查看[diopi_functions.yaml](https://github.com/DeepLink-org/deeplink.framework/blob/main/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml),其中的`autocompare`配置项为`disable`即不支持自动精度对比功能,同时也可以修改`diopi_functions.yaml`,将某些算子的`autocompare`配置项设置为`disable`来禁用自动对比功能。 +2. `dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml` 中配置了 `autograd:True` 的算子 (`cross_entropy_loss`、`conv2d`、`dropout`、`dropout_`、`linear`) 暂不支持 *backward* 的精度自动对比。如模型精度对不齐,可根据需要先将这几个算子 fallback 到 CPU 来确定问题。 +3. 对输入参数(self)做检查是确保算子的输入不被意外修改。 #### 抓取算子参数 diff --git a/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py b/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py index 84c4aa2ac..e973b6669 100644 --- a/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py +++ b/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py @@ -11,7 +11,6 @@ op_register_template_content, custom_autograd_template_content, autocompare_template_content, - op_with_custom_fallback_register_template_content, ) @@ -458,6 +457,9 @@ def create_call_aten_cpu_cpp_function_code_from_config(fun_config): opname = re.sub("\.correction", "", opname) opname = re.sub("\.input", "", opname) opname = re.sub("\.dim_IntList", "", opname) + opname = re.sub("\.dim", "", opname) + opname = re.sub("\.mode", "", opname) + opname = opname.replace(".", "_") opname = opname.split(".")[0] if opname[-1] == "_" and len(get_function_return_param_from_schema(schema)) > 0: @@ -673,10 +675,6 @@ def create_optional_generator_process_code(arg_name): op_register_template = CodeTemplate(op_register_template_content) -op_with_custom_fallback_register_template = CodeTemplate( - op_with_custom_fallback_register_template_content -) - custom_autograd_template = CodeTemplate(custom_autograd_template_content) autocompare_template = CodeTemplate(autocompare_template_content) @@ -906,74 +904,66 @@ def functions_code_gen(fun_config): fbody += custom_autograd_function_code fun_name = wrapper_fun_name - if fun_config.get("autocompare", False) in [True, "True"] and fun_config.get( - "register_op", True - ) in [True, "True"]: + if fun_config.get("register_op", True) in [True, "True"]: auto_compare_fun_name = fun_name + "_autocompare" - autocompare_code = autocompare_template.substitute( - cppsignautre=[ - create_cpp_signature_from_schema(fun_config["schema"]).replace( - raw_fun_name, auto_compare_fun_name - ) - ], - transform_input_to_cpu_code=[ - create_transform_input_to_cpu_code(fun_config) - ], - execute_op_on_cpu_code=[ - create_call_aten_cpu_cpp_function_code_from_config(fun_config) - ], - comment=[fun_config["schema"]], - execute_op_on_device_code=[ - create_call_dipu_cpp_function_code_from_schema( - fun_config["schema"] - ).replace(raw_fun_name, fun_name) - ], - transform_result_to_cpu_code=[], - result_compare_code=[ - create_result_compare_code(fun_config) - + ( - "\nreturn result_device;\n" - if len(get_function_return_param_from_schema(fun_config["schema"])) - > 0 - else "" - ) - ], - ) + autocompare_code = "" + if fun_config.get("autocompare", True) not in [False]: + autocompare_code = autocompare_template.substitute( + cppsignautre=[ + create_cpp_signature_from_schema(fun_config["schema"]).replace( + raw_fun_name, auto_compare_fun_name + ) + ], + transform_input_to_cpu_code=[ + create_transform_input_to_cpu_code(fun_config) + ], + execute_op_on_cpu_code=[ + create_call_aten_cpu_cpp_function_code_from_config(fun_config) + ], + comment=[fun_config["schema"]], + execute_op_on_device_code=[ + create_call_dipu_cpp_function_code_from_schema( + 
fun_config["schema"] + ).replace(raw_fun_name, fun_name) + ], + transform_result_to_cpu_code=[], + result_compare_code=[ + create_result_compare_code(fun_config) + + ( + "\nreturn result_device;\n" + if len( + get_function_return_param_from_schema(fun_config["schema"]) + ) + > 0 + else "" + ) + ], + ) + if fun_config.get("autocompare", True) in [False]: + disable_autocompare_comment = ( + "// since autocompare is disabled, " + + auto_compare_fun_name + + " will do nothing.\n" + ) + autocompare_code = ( + disable_autocompare_comment + + "void " + + auto_compare_fun_name + + "() " + + "{}\n" + ) fbody += autocompare_code - fun_name = auto_compare_fun_name - - if fun_config.get("custom_fallback", False) in ["False", False]: - register_body = op_register_template.substitute( - register_name=[get_op_name_from_schema(fun_config["schema"])], - aten_fun_name=["dipu::native::" + fun_name], - diopi_fun_name=[ - get_fun_name_from_cppsignature(diopi_interface).replace( - "diopi", "::diopi" - ) - ], - ) - else: - register_body = op_with_custom_fallback_register_template.substitute( - register_name=[get_op_name_from_schema(fun_config["schema"])], - aten_fun_name=["dipu::native::" + fun_name], - diopi_fun_name=[ - get_fun_name_from_cppsignature(diopi_interface).replace( - "diopi", "::diopi" - ) - ], - force_fallback=[ - ( - "false" - if fun_config.get("force_fallback", False) in [False, "False"] - else "true" - ) - ], - fallbackFunc=[ - "dipu::native::" - + "custom_fallback_" - + fun_name.replace("_autocompare", "") - ], - ) + + # generate the op_register code + register_body = op_register_template.substitute( + register_name=[get_op_name_from_schema(fun_config["schema"])], + aten_fun_name=["dipu::native::" + fun_name], + diopi_fun_name=[ + get_fun_name_from_cppsignature(diopi_interface).replace("diopi", "::diopi") + ], + custom_fallback_config=str(fun_config.get("custom_fallback", False)).lower(), + autocompare_config=str(fun_config.get("autocompare", True)).lower(), + ) return fbody, register_body @@ -1039,12 +1029,6 @@ def parse_args(): type=boolean_string, help="whether generate code that prints op args", ) - parser.add_argument( - "--autocompare", - default=False, - type=boolean_string, - help="whether generate code that compare device calculation results with cpu calculation results", - ) parser.add_argument( "--fun_config_dict", type=json.loads, diff --git a/dipu/scripts/autogen_diopi_wrapper/autogen_wrapped_code.sh b/dipu/scripts/autogen_diopi_wrapper/autogen_wrapped_code.sh index 3d6e0dd18..fd6e01b11 100755 --- a/dipu/scripts/autogen_diopi_wrapper/autogen_wrapped_code.sh +++ b/dipu/scripts/autogen_diopi_wrapper/autogen_wrapped_code.sh @@ -5,17 +5,16 @@ DIPU_DIR=$(readlink -f $(dirname $(readlink -f "$0"))/../..) 
AUTOGEN_DIOPI_WRAPPER=$DIPU_DIR/scripts/autogen_diopi_wrapper -USE_AUTOCOMPARE=${1:-OFF} -UsedVendor=${2:-cuda} -Torch_VERSION=${3:-2.1.0} -GENERATED_KERNELS_SCRIPT=${4:-$AUTOGEN_DIOPI_WRAPPER/autogen_diopi_wrapper.py} -GENERATED_KERNELS_CONFIG=${5:-$AUTOGEN_DIOPI_WRAPPER/diopi_functions.yaml} -GENERATED_KERNELS=${6:-$DIPU_DIR/torch_dipu/csrc_dipu/aten/ops/AutoGenedKernels.cpp} +UsedVendor=${1:-cuda} +Torch_VERSION=${2:-2.1.0} +GENERATED_KERNELS_SCRIPT=${3:-$AUTOGEN_DIOPI_WRAPPER/autogen_diopi_wrapper.py} +GENERATED_KERNELS_CONFIG=${4:-$AUTOGEN_DIOPI_WRAPPER/diopi_functions.yaml} +GENERATED_KERNELS=${5:-$DIPU_DIR/torch_dipu/csrc_dipu/aten/ops/AutoGenedKernels.cpp} GENERATED_KERNELS_VENDOR=${DIPU_DIR}/third_party/DIOPI/impl/${UsedVendor}/convert_config.yaml PYTHON_CMD="python3 ${GENERATED_KERNELS_SCRIPT} --out=${GENERATED_KERNELS} --config=${GENERATED_KERNELS_CONFIG} \ - --autocompare=${USE_AUTOCOMPARE} --print_op_arg=True --use_diopi_adapter=False --print_func_call_info=True \ + --print_op_arg=True --use_diopi_adapter=False --print_func_call_info=True \ --fun_config_dict='{\"current_device\":\"${UsedVendor}\",\"current_torch_ver\":\"${Torch_VERSION}\"}'" if [ -f "$GENERATED_KERNELS_VENDOR" ]; then diff --git a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml index 16d324dd0..e1b81f832 100755 --- a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml +++ b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml @@ -1,5 +1,5 @@ - schema: "exampleop.overloadname(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)" - autocompare: disable + autocompare: False # op gen only on these torch version. use it only if op has different signature on different torch. # if it's only different on implementation , please use compile macro DIPU_TORCHXXX. # torch version number, 5 in total: {X-major}{XX-minor}{XX-patch} @@ -309,7 +309,7 @@ interface: diopiLayerNormBackward(ctx, grad_input, grad_weight, grad_bias, grad_out, input, weight, bias, mean, rstd, normalized_shape); - schema: "adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)" - #autocompare: disable # TODO: cpu impl not support half now + #autocompare: False # TODO: cpu impl not support half now interface: diopiAdaptiveAvgPool2d(ctx, out, self, output_size) - schema: "_adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor" @@ -486,13 +486,13 @@ interface: diopiRelu(ctx, out, self) - schema: "randperm.out(int n, *, Tensor(a!) out) -> Tensor(a!)" - autocompare: disable + autocompare: False custom_code_at_the_beginning: | diopiGeneratorHandle_t generatorDiopiGenerator = toDiopiGeneratorHandle(getDefaultDIPUGenerator()); interface: diopiRandperm(ctx, out, n, generatorDiopiGenerator) - schema: "randperm.generator_out(int n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)" - autocompare: disable + autocompare: False interface: diopiRandperm(ctx, out, n, generator) - schema: "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, * ScalarType? dtype=None) -> Tensor" @@ -686,7 +686,7 @@ interface: diopiMul(ctx, out, out, out) - schema: "bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!)" - autocompare: disable + autocompare: False interface: diopiBernoulliScalar(ctx, self, p, generatorDiopiGenerator); - schema: "log.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)" @@ -1095,7 +1095,7 @@ interface: diopiRsqrt(ctx, out, self) - schema: "uniform_(Tensor(a!) 
self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)" - autocompare: disable + autocompare: False interface: diopiUniformInp(ctx, self, from, to, generator) - schema: "tril(Tensor self, int diagonal=0) -> Tensor" @@ -1177,15 +1177,15 @@ interface: diopiClamp(ctx, out, self, min, max) - schema: "random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)" - autocompare: disable + autocompare: False interface: diopiRandomInp(ctx, self, 0, nullptr, generator) - schema: "random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)" - autocompare: disable + autocompare: False interface: diopiRandomInp(ctx, self, 0, &to, generator) - schema: "random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)" - autocompare: disable + autocompare: False interface: "diopiRandomInp(ctx, self, from, to.has_value() ? &to.value() : nullptr, generator)" - schema: "nonzero(Tensor self) -> Tensor" @@ -1239,7 +1239,7 @@ interface: diopiProd(ctx, out, self_dtype_diopi, &dim) - schema: repeat(Tensor self, SymInt[] repeats) -> Tensor - autocompare: disable + autocompare: False custom_code_at_the_beginning: | std::vector output_size(repeats.size()); for (int i = 0;i< repeats.size();++i) { @@ -1655,41 +1655,41 @@ interface: diopiReciprocal(ctx, out, self) - schema: "normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)" - autocompare: disable + autocompare: False interface: diopiNormalInp(ctx, self, mean, std, generator) - schema: normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) - autocompare: disable + autocompare: False interface: diopiNormalTensorScalar(ctx, out, mean, std, generator) - schema: normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor - autocompare: disable + autocompare: False custom_code_at_the_beginning: | auto out = nodispatch::empty_like(mean); interface: diopiNormalTensorScalar(ctx, out, mean, std, generator) - schema: normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) - autocompare: disable + autocompare: False interface: diopiNormalScalarTensor(ctx, out, mean, std, generator) - schema: normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor - autocompare: disable + autocompare: False custom_code_at_the_beginning: | auto out = nodispatch::empty_like(std); interface: diopiNormalScalarTensor(ctx, out, mean, std, generator) - schema: normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) - autocompare: disable + autocompare: False interface: diopiNormalTensor(ctx, out, mean, std, generator) - schema: normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor - autocompare: disable + autocompare: False custom_code_at_the_beginning: | auto out = nodispatch::empty_like(mean); interface: diopiNormalTensor(ctx, out, mean, std, generator) - schema: normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) 
- autocompare: disable + autocompare: False interface: diopiNormal(ctx, out, mean, std, generator) - schema: "mm(Tensor self, Tensor mat2) -> Tensor" @@ -1829,7 +1829,7 @@ - schema: "ctc_loss_tensor_backward(Tensor grad_output, Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, int reduction=Mean, bool zero_infinity=False) -> Tensor grad_input" device: [camb] - autocompare: disable + autocompare: False register_op: False custom_code_at_the_beginning: | const auto reductionDiopi = static_cast<::diopiReduction_t>(reduction); @@ -1925,7 +1925,7 @@ - schema: "ctc_loss_intlist_backward(Tensor grad_output, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, int reduction=Mean, bool zero_infinity=False) -> Tensor grad_input" device: [camb] - autocompare: disable + autocompare: False register_op: False ins: [input_lengths_tensor, target_lengths_tensor] custom_code_at_the_beginning: | @@ -2748,7 +2748,6 @@ # this copy_ aten op may use both diopiCastDtype and diopiCopyInp. it's a proxy/composite op - schema: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) - autocompare: disable dummy_call_diopi: True custom_fallback: True device: [cuda, camb, ascend, droplet, supa, kunlunxin] @@ -2760,7 +2759,6 @@ # vendor who has no fully implemented diopi and proper fallback DIPUCopy sub-class - schema: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) - autocompare: disable custom_fallback: True dummy_call_diopi: True custom_code_at_the_beginning: | @@ -2768,8 +2766,7 @@ device: [topsrider] interface: diopiCopyInp(ctx, src, self) -- schema: _amp_foreach_non_finite_check_and_unscale_(at::TensorList self, Tensor(b!) found_inf, Tensor inv_scale) -> void - autocompare: disable +- schema: _amp_foreach_non_finite_check_and_unscale_(at::TensorList self, Tensor(b!) found_inf, Tensor inv_scale) -> () custom_fallback: True custom_code_at_the_beginning: | std::vector diopiTensorHandles(self.size(), nullptr); @@ -2780,8 +2777,6 @@ }); // NOLINTEND(cppcoreguidelines-pro-type-const-cast) interface: diopiAmpForeachNonFiniteCheckAndUnscaleInp(ctx, diopiTensorHandles.data(), static_cast(self.size()), found_inf, inv_scale) - # TODO(someone): fix this issue when `autocompare` is on - autocompare: disable - schema: _amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!) 
custom_fallback: True diff --git a/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py b/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py index ba723da1b..6ca764a05 100644 --- a/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py +++ b/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py @@ -50,6 +50,7 @@ #include "csrc_dipu/aten/ops/DIPUCopy.hpp" #include "csrc_dipu/aten/ops/NodispatchUtils.hpp" #include "csrc_dipu/aten/ops/OpUtils.hpp" +#include "csrc_dipu/aten/ops/OpRegexMatch.hpp" #include "csrc_dipu/base/basedef.h" #include "csrc_dipu/diopirt/diopirt_impl.h" #include "csrc_dipu/profiler/profiler.h" @@ -128,11 +129,7 @@ """ op_register_template_content = """ -DIOPI_ATEN_FUNC("$register_name", $diopi_fun_name, $aten_fun_name); -""" - -op_with_custom_fallback_register_template_content = """ -DIOPI_ATEN_FUNC_CUSTOM_FALLBACK("$register_name", $diopi_fun_name, $force_fallback /*whether force fallback*/, $aten_fun_name, $fallbackFunc); +DIOPI_ATEN_FUNC("$register_name", $diopi_fun_name, $aten_fun_name, $custom_fallback_config, $autocompare_config); """ custom_autograd_template_content = """ diff --git a/dipu/scripts/ci/ascend/ci_ascend_script.sh b/dipu/scripts/ci/ascend/ci_ascend_script.sh index 5c976271a..c1c986c71 100644 --- a/dipu/scripts/ci/ascend/ci_ascend_script.sh +++ b/dipu/scripts/ci/ascend/ci_ascend_script.sh @@ -12,9 +12,6 @@ function build_diopi_lib() { function config_dipu_ascend_cmake() { mkdir -p build && cd ./build cmake_args="-DCMAKE_BUILD_TYPE=Release -DDEVICE=ascend -DWITH_DIOPI_LIBRARY=DISABLE" - if [ -n "$USE_AUTOCOMPARE" ]; then - cmake_args+=" -DUSE_AUTOCOMPARE=${USE_AUTOCOMPARE}" - fi cmake ../ $cmake_args cd ../ } @@ -22,9 +19,6 @@ function config_dipu_ascend_cmake() { function config_all_ascend_cmake() { mkdir -p build && cd ./build cmake_args="-DCMAKE_BUILD_TYPE=Release -DDEVICE=ascend -DENABLE_COVERAGE=${USE_COVERAGE} -DWITH_DIOPI=INTERNAL" - if [ -n "$USE_AUTOCOMPARE" ]; then - cmake_args+=" -DUSE_AUTOCOMPARE=${USE_AUTOCOMPARE}" - fi cmake ../ $cmake_args cd ../ } diff --git a/dipu/torch_dipu/csrc_dipu/CMakeLists.txt b/dipu/torch_dipu/csrc_dipu/CMakeLists.txt index 20bb442fe..457d706f4 100644 --- a/dipu/torch_dipu/csrc_dipu/CMakeLists.txt +++ b/dipu/torch_dipu/csrc_dipu/CMakeLists.txt @@ -1,5 +1,4 @@ #[[ Dependencies ]] -option(USE_AUTOCOMPARE "whether to use USE_AUTOCOMPARE" OFF) # Import Python3::Python, Python3_EXECUTABLE # Also see https://cmake.org/cmake/help/latest/module/FindPython3.html @@ -58,7 +57,7 @@ endif() add_custom_command( OUTPUT "${GENERATED_KERNELS}" - COMMAND bash -c "${AUTOGEN_CODE_SH} ${USE_AUTOCOMPARE} ${UsedVendor} ${Torch_VERSION} ${GENERATED_KERNELS_SCRIPT} ${GENERATED_KERNELS_CONFIG} ${GENERATED_KERNELS}" + COMMAND bash -c "${AUTOGEN_CODE_SH} ${UsedVendor} ${Torch_VERSION} ${GENERATED_KERNELS_SCRIPT} ${GENERATED_KERNELS_CONFIG} ${GENERATED_KERNELS}" COMMENT "Generating ${GENERATED_KERNELS}$<$: with ${GENERATED_KERNELS_VENDOR}>" DEPENDS "${GENERATED_KERNELS_SCRIPT}" @@ -76,6 +75,7 @@ set(TORCH_DIPU_SOURCE aten/ops/PinMemoryKernel.cpp aten/ops/EmptyOpsKernel.cpp aten/ops/CustomFallbackFunctionsForCopy.cpp + aten/ops/OpRegexMatch.cpp aten/RegisterDIPU.cpp aten/CPUFallback.cpp diff --git a/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.cpp b/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.cpp index 51d856a81..c99cac94e 100644 --- a/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.cpp +++ b/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.cpp @@ -21,57 +21,6 @@ namespace dnative = dipu::native::dipu_aten; namespace dipu 
{ -namespace { - -std::vector load_fallback_matcher() { - auto constexpr env_name = "DIPU_FORCE_FALLBACK_OPS_LIST"; - auto constexpr file_name = ".dipu_force_fallback_op_list.config"; - - auto append = [](std::istream& input, std::vector& output) { - auto constexpr separator = ','; - - auto line = std::string(); - while (std::getline(input, line)) { - auto buffer = std::istringstream(line); - auto pattern = std::string(); - while (std::getline(buffer, pattern, separator)) { - if (pattern.empty()) { - continue; - } - try { - output.emplace_back(pattern); - } catch (const std::regex_error& e) { - TORCH_CHECK(false, e.what()); - } - } - } - }; - - auto list = std::vector(); - if (auto env = std::getenv(env_name)) { - auto iss = std::istringstream(env); - append(iss, list); - } - if (auto file = std::ifstream(file_name, std::ios::binary)) { - append(file, list); - } - return list; -} - -auto const force_fallback_matchers = load_fallback_matcher(); - -} // end of namespace - -bool get_force_fallback(const char* opname) { - if (force_fallback_matchers.empty() || opname == nullptr) { - return false; - } - - return std::any_of( - force_fallback_matchers.begin(), force_fallback_matchers.end(), - [&opname](auto& matcher) { return std::regex_match(opname, matcher); }); -} - namespace native { void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack); } // end of namespace native diff --git a/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.hpp b/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.hpp index 82f36671b..d3d7e6b24 100644 --- a/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.hpp +++ b/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.hpp @@ -1,6 +1,9 @@ // Copyright (c) 2023, DeepLink. #pragma once +#include +#include // for printf +#include // for std::getenv #include #include #include @@ -8,16 +11,10 @@ #include +#include "csrc_dipu/aten/ops/OpRegexMatch.hpp" #include "csrc_dipu/aten/ops/OpUtils.hpp" -namespace dipu { - -bool get_force_fallback(const char* opname); - -}; // namespace dipu - namespace at { - void dipu_fallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack); @@ -52,43 +49,34 @@ void dipu_fallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, // It mat be necessary to determine whether to keep torchop default impl // for non-custom ops through function dipuKeepTorchopDefaultImpl firstly in the // future, and we use force fallback to keep torchop default impl now. -#define DIOPI_ATEN_FUNC(opname, diopiFunc, wapperFunc) \ - do { \ - if ((reinterpret_cast(diopiFunc) != nullptr) && \ - (!dipu::get_force_fallback(opname))) { \ - m.impl(opname, TORCH_FN(wapperFunc)); \ - } else { \ - if ((reinterpret_cast(diopiFunc) == nullptr)) { \ - DIPU_OP_LOG_WARNING_ONCE(#diopiFunc << " is not yet implemented, "); \ - } else { \ - DIPU_OP_LOG_WARNING_ONCE("force fallback has been set, "); \ - } \ - DIPU_OP_LOG_WARNING_ONCE((opname) << " will be fallback to cpu" \ - << "\n"); \ - } \ - } while (false); - -// Determine whether to keep torchop default impl for custom ops through -// function dipuKeepTorchopDefaultImpl firstly. 
-#define DIOPI_ATEN_FUNC_CUSTOM_FALLBACK(opname, diopi_func, force_fallback, \ - wapper_func, custom_fallback_func) \ +#define CONCAT_NAME(a, b) a##b +#define DIOPI_ATEN_FUNC(opname, diopiFunc, wrapperFunc, customFallbackConfig, \ + autocompareConfig) \ do { \ - if (dipu::native::dipuKeepTorchopDefaultImpl(opname)) { \ + bool isAutoCompareMatch = dipu::op_regex_match::isOpMatch( \ + opname, dipu::op_regex_match::autocompareMatchers); \ + bool isFallbackMatch = dipu::op_regex_match::isOpMatch( \ + opname, dipu::op_regex_match::fallbackMatchers); \ + if (reinterpret_cast(diopiFunc) == nullptr) { \ + DIPU_OP_LOG_WARNING_ONCE(#diopiFunc << " is not yet implemented, " \ + << (opname) \ + << " will be fallback to cpu" \ + << "\n"); \ + break; \ + } \ + if ((autocompareConfig) && isAutoCompareMatch && \ + reinterpret_cast(wrapperFunc##_autocompare) != nullptr) { \ + m.impl(opname, TORCH_FN(wrapperFunc##_autocompare)); \ break; \ } \ - if ((reinterpret_cast(diopi_func) != nullptr) && \ - !((force_fallback) || dipu::get_force_fallback(opname))) { \ - m.impl(opname, TORCH_FN(wapper_func)); \ - } else { \ - if ((reinterpret_cast(diopi_func) == nullptr)) { \ - DIPU_OP_LOG_WARNING_ONCE(#diopi_func << " is not yet implemented, "); \ - } else { \ - DIPU_OP_LOG_WARNING_ONCE("force fallback has been set, "); \ - } \ - DIPU_OP_LOG_WARNING_ONCE((opname) << " will be fallback to cpu" \ - << "\n"); \ - m.impl(opname, TORCH_FN(custom_fallback_func)); \ + if ((customFallbackConfig) || isFallbackMatch) { \ + DIPU_OP_LOG_WARNING_ONCE("force fallback has been set, " \ + << (opname) << " will be fallback to cpu" \ + << "\n"); \ + break; \ } \ + m.impl(opname, TORCH_FN(wrapperFunc)); \ + \ } while (false); class DIPUOpRegister { diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.hpp b/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.hpp index 41bde2531..41819eadf 100644 --- a/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.hpp +++ b/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.hpp @@ -35,6 +35,8 @@ enum class DIPUCopyType { D2H, // from host to device H2D, + // from host to host + H2H, }; // Align with pytorch's behavior, see TensorIterator.cpp compute_mem_overlaps() @@ -59,16 +61,21 @@ inline void tryRecordStream(const at::Tensor& tensor, DIPUStream& curStream, inline DIPUCopyType getCopyType(const at::Tensor& dst, const at::Tensor& src) { bool isSrcDevice = dipu::isDeviceTensor(src); bool isDstDevice = dipu::isDeviceTensor(dst); - if (!isSrcDevice) { - return DIPUCopyType::H2D; // this op not handle h2h, dest always device + if (!isSrcDevice && isDstDevice) { + return DIPUCopyType::H2D; } - if (!isDstDevice) { - return DIPUCopyType::D2H; // here src always device + if (!isDstDevice && isSrcDevice) { + return DIPUCopyType::D2H; } - if (src.device().index() != dst.device().index()) { + if (isSrcDevice && isDstDevice && + src.device().index() != dst.device().index()) { return DIPUCopyType::D2OtherD; } - return DIPUCopyType::D2Self; + if (isSrcDevice && isDstDevice && + src.device().index() == dst.device().index()) { + return DIPUCopyType::D2Self; + } + return DIPUCopyType::H2H; } inline int64_t getMemCopyBytes(const at::Tensor& dst, const at::Tensor& src, @@ -117,6 +124,17 @@ inline void doMemCopyD2H(const at::Tensor& dst, const at::Tensor& src, } } +inline void doMemCopyH2H(const at::Tensor& dst, const at::Tensor& src, + int64_t nbytes) { + if (!dst.is_contiguous() || !src.is_contiguous()) { + std::cerr << "Tensors must be contiguous for memory copy." 
<< std::endl; + return; + } + void* src_ptr = src.data_ptr(); + void* dst_ptr = dst.data_ptr(); + memcpy(dst_ptr, src_ptr, nbytes); +} + inline void doMemCopyD2D(const at::Tensor& dst, const at::Tensor& src, dipu::DIPUStream& stream, int64_t nbytes, bool isSynchronousCopy) { @@ -148,6 +166,9 @@ inline void memCopy(const at::Tensor& dst, const at::Tensor& src, // dst is cpu. doMemCopyD2H(dst, src, stream, nbytes, isSynchronousCopy); break; + case DIPUCopyType::H2H: + doMemCopyH2H(dst, src, nbytes); + break; default: // device to device doMemCopyD2D(dst, src, stream, nbytes, isSynchronousCopy); } @@ -294,7 +315,6 @@ class DIPUCopyInplace : public DIPUCopyBase { if (native::dumpOpArgLevel() > 0) { printf("--%-50s %-30s \n", "[copy_]:", "doDirectMemCopy"); } - memCopy(dst, src, curStream, copyType, /*nonOverlappingAndDense=*/true, /*isSynchronousCopy=*/false); diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/OpRegexMatch.cpp b/dipu/torch_dipu/csrc_dipu/aten/ops/OpRegexMatch.cpp new file mode 100644 index 000000000..1e0c02cef --- /dev/null +++ b/dipu/torch_dipu/csrc_dipu/aten/ops/OpRegexMatch.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2024, DeepLink. +#include "OpRegexMatch.hpp" + +#include +#include +#include +#include +#include + +#include + +// loadMatcher is used to get regex matcher from env_name and config +// fallback_env_name = "DIPU_FORCE_FALLBACK_OPS_LIST"; fallback_config_name = +// ".dipu_force_fallback_op_list.config" specified_autocompare_env_name = +// "DIPU_AUTOCOMPARE_OPS_LIST"; specified_autocompare_config_name = +// ".specified_autocompare_op_list.config" + +namespace dipu { +namespace op_regex_match { +std::vector loadMatcher(const char* env_name, + const char* config_name) { + auto append = [](std::istream& input, std::vector& output) { + auto constexpr separator = ','; + + auto line = std::string(); + while (std::getline(input, line)) { + auto buffer = std::istringstream(line); + auto pattern = std::string(); + while (std::getline(buffer, pattern, separator)) { + if (pattern.empty()) { + continue; + } + try { + output.emplace_back(pattern); + } catch (const std::regex_error& e) { + TORCH_CHECK(false, e.what()); + } + } + } + }; + + auto list = std::vector(); + if (auto env = std::getenv(env_name)) { + auto iss = std::istringstream(env); + append(iss, list); + } + if (auto file = std::ifstream(config_name, std::ios::binary)) { + append(file, list); + } + return list; +} + +bool isOpMatch(const char* opname, + const std::vector& regexMatchers) { + if (regexMatchers.empty() || opname == nullptr) { + return false; + } + + return std::any_of( + regexMatchers.begin(), regexMatchers.end(), + [&opname](auto& matcher) { return std::regex_match(opname, matcher); }); +} + +constexpr const char* fallback_env_name = "DIPU_FORCE_FALLBACK_OPS_LIST"; +constexpr const char* fallback_config_name = + ".dipu_force_fallback_op_list.config"; +const std::vector fallbackMatchers = + dipu::op_regex_match::loadMatcher(fallback_env_name, fallback_config_name); + +constexpr const char* specified_autocompare_env_name = + "DIPU_AUTOCOMPARE_OPS_LIST"; +constexpr const char* specified_autocompare_config_name = + ".specified_autocompare_op_list.config"; +const std::vector autocompareMatchers = + dipu::op_regex_match::loadMatcher(specified_autocompare_env_name, + specified_autocompare_config_name); +} // namespace op_regex_match +} // namespace dipu diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/OpRegexMatch.hpp b/dipu/torch_dipu/csrc_dipu/aten/ops/OpRegexMatch.hpp new file mode 100644 index 
000000000..2d8dda86d
--- /dev/null
+++ b/dipu/torch_dipu/csrc_dipu/aten/ops/OpRegexMatch.hpp
@@ -0,0 +1,18 @@
+// Copyright (c) 2024, DeepLink.
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace dipu {
+namespace op_regex_match {
+std::vector<std::regex> loadMatcher(const char* env_name,
+                                    const char* config_name);
+bool isOpMatch(const char* opname,
+               const std::vector<std::regex>& regexMatchers);
+extern const std::vector<std::regex> fallbackMatchers;
+extern const std::vector<std::regex> autocompareMatchers;
+}  // namespace op_regex_match
+}  // namespace dipu
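
A minimal usage sketch of the `dipu::op_regex_match` helpers this patch introduces. It only assumes the declarations from the new `OpRegexMatch.hpp` (`isOpMatch` plus the two prebuilt matcher lists); the `main()` harness and the op names in the loop are illustrative, not part of the patch.

```cpp
// Sketch only: exercising the regex-based op matching added in this patch.
// autocompareMatchers / fallbackMatchers are built once at load time from
// DIPU_AUTOCOMPARE_OPS_LIST / DIPU_FORCE_FALLBACK_OPS_LIST (and their
// .config files), so exporting e.g. DIPU_AUTOCOMPARE_OPS_LIST='add.*'
// before this runs makes the first two ops below match.
#include <cstdio>

#include "csrc_dipu/aten/ops/OpRegexMatch.hpp"

int main() {
  using dipu::op_regex_match::autocompareMatchers;
  using dipu::op_regex_match::fallbackMatchers;
  using dipu::op_regex_match::isOpMatch;

  const char* ops[] = {"add.out", "add_.Tensor", "conv2d"};
  for (const char* op : ops) {
    std::printf("%-12s autocompare=%d force_fallback=%d\n", op,
                isOpMatch(op, autocompareMatchers),
                isOpMatch(op, fallbackMatchers));
  }
  return 0;
}
```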
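
The reworked `DIOPI_ATEN_FUNC` macro in `RegisterDIPU.hpp` folds the old `DIOPI_ATEN_FUNC_CUSTOM_FALLBACK` variant into a single registration path. Below is a plain-function restatement of its branch order, assuming the four flags have already been computed (the macro derives them from the yaml config values and the matcher lists); the enum and function names are illustrative only.

```cpp
// Illustrative restatement of the branch order inside DIOPI_ATEN_FUNC; the
// real macro additionally calls m.impl(opname, TORCH_FN(...)) for the chosen
// wrapper and checks that the <wrapper>_autocompare symbol is non-null.
enum class OpRegistration { kFallbackToCpu, kAutocompareWrapper, kDeviceWrapper };

inline OpRegistration chooseRegistration(
    bool hasDiopiImpl,          // diopiFunc != nullptr
    bool autocompareConfig,     // yaml "autocompare" (defaults to True)
    bool customFallbackConfig,  // yaml "custom_fallback"
    bool matchesAutocompareList,
    bool matchesFallbackList) {
  if (!hasDiopiImpl) {
    return OpRegistration::kFallbackToCpu;       // diopi op not implemented
  }
  if (autocompareConfig && matchesAutocompareList) {
    return OpRegistration::kAutocompareWrapper;  // run device + CPU and compare
  }
  if (customFallbackConfig || matchesFallbackList) {
    return OpRegistration::kFallbackToCpu;       // forced fallback to CPU
  }
  return OpRegistration::kDeviceWrapper;         // normal device wrapper
}
```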
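
For completeness, a rough approximation of what a generated `<op>_autocompare` wrapper does, following the fields the autogen script fills into `autocompare_template` (copy inputs to CPU, run the aten CPU op, run the device wrapper, compare). The concrete wrapper is generated per operator by `autogen_diopi_wrapper.py`; the `add.out` signature and the use of `at::allclose` here are for illustration only.

```cpp
// Hand-written approximation of a generated add.out_autocompare wrapper.
// Real generated code also prints shape/stride/dtype of each argument,
// as shown in the QuickStart.md example output.
#include <ATen/ATen.h>
#include <iostream>

at::Tensor& add_out_autocompare_sketch(const at::Tensor& self,
                                       const at::Tensor& other,
                                       const at::Scalar& alpha,
                                       at::Tensor& out) {
  // transform_input_to_cpu_code: copy inputs onto the CPU
  at::Tensor self_cpu = self.cpu();
  at::Tensor other_cpu = other.cpu();
  at::Tensor out_cpu = out.cpu();
  // execute_op_on_cpu_code: reference result from the aten CPU kernel
  at::add_out(out_cpu, self_cpu, other_cpu, alpha);
  // execute_op_on_device_code: dispatches to the DIPU device wrapper
  at::add_out(out, self, other, alpha);
  // result_compare_code: report allclose / not_allclose, as in the doc example
  std::cout << "autocompare: add.out out: "
            << (at::allclose(out_cpu, out.cpu()) ? "allclose" : "not_allclose")
            << std::endl;
  return out;
}
```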