From 880ddf4f3b2b017c1641acefdb8a26b1feca32ac Mon Sep 17 00:00:00 2001 From: lhy <442488254@qq.com> Date: Wed, 31 Jul 2024 19:13:44 +0800 Subject: [PATCH] Use updateATen2Tensor in functions.cpp (this will slow down performance) --- impl/torch/functions/functions.cpp | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/impl/torch/functions/functions.cpp b/impl/torch/functions/functions.cpp index cda0ed467..afada5cee 100644 --- a/impl/torch/functions/functions.cpp +++ b/impl/torch/functions/functions.cpp @@ -1168,12 +1168,11 @@ diopiError_t diopiAddInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t inp diopiError_t diopiForeachaddScalar(diopiContextHandle_t ctx, diopiTensorHandle_t* outs, diopiConstTensorHandle_t* inputs, int64_t inputSize, const diopiScalar_t* other) { impl::aten::setCurStream(ctx); - DIOPI_IMPL_BUILD_ATEN_LIST(atOuts, outs, inputSize) DIOPI_IMPL_BUILD_ATEN_LIST(atInputs, inputs, inputSize) auto atOther = impl::aten::buildAtScalar(other); auto tempOut = CALL_ATEN_CUDA_FUNC(_foreach_add, atInputs, atOther); for (int i = 0; i < inputSize; i++) { - *(reinterpret_cast(outs[i])) = tempOut[i]; + impl::aten::updateATen2Tensor(ctx, tempOut[i], outs[i]); } return diopiSuccess; @@ -1284,11 +1283,10 @@ diopiError_t diopiForeachmulScalar(diopiContextHandle_t ctx, diopiTensorHandle_t DIOPI_CHECK_PTR(outs); impl::aten::setCurStream(ctx); DIOPI_IMPL_BUILD_ATEN_LIST(atInputs, inputs, inputSize) - DIOPI_IMPL_BUILD_ATEN_LIST(atOuts, outs, inputSize) auto atOther = impl::aten::buildAtScalar(other); auto tempOut = CALL_ATEN_CUDA_FUNC(_foreach_mul, atInputs, atOther); for (int i = 0; i < inputSize; i++) { - *(reinterpret_cast(outs[i])) = tempOut[i]; + impl::aten::updateATen2Tensor(ctx, tempOut[i], outs[i]); } return diopiSuccess; @@ -1312,7 +1310,7 @@ diopiError_t diopiForeachmulTensor(diopiContextHandle_t ctx, diopiTensorHandle_t auto atOther = impl::aten::buildATen(other); auto tempOut = CALL_ATEN_CUDA_FUNC(_foreach_mul, atInputs, atOther); 
for (int i = 0; i < inputSize; i++) { - *(reinterpret_cast(outs[i])) = tempOut[i]; + impl::aten::updateATen2Tensor(ctx, tempOut[i], outs[i]); } return diopiSuccess; @@ -3317,12 +3315,11 @@ diopiError_t diopiForeachnormScalar(diopiContextHandle_t ctx, diopiTensorHandle_ DIOPI_CHECK_PTR(outs); impl::aten::setCurStream(ctx); DIOPI_IMPL_BUILD_ATEN_LIST(atInputs, inputs, inputSize) - DIOPI_IMPL_BUILD_ATEN_LIST(atOuts, outs, inputSize) auto atP = impl::aten::buildAtScalar(p); auto tempOut = CALL_ATEN_CUDA_FUNC(_foreach_norm, atInputs, atP); for (int i = 0; i < inputSize; i++) { - // impl::aten::updateATen2Tensor(ctx, tempOut[i], out[i]); - *(reinterpret_cast(outs[i])) = tempOut[i]; + // WARN: no copy should be needed here; this path needs a faster update. + impl::aten::updateATen2Tensor(ctx, tempOut[i], outs[i]); } return diopiSuccess;