diff --git a/impl/torch/functions/functions.cpp b/impl/torch/functions/functions.cpp index cda0ed467..afada5cee 100644 --- a/impl/torch/functions/functions.cpp +++ b/impl/torch/functions/functions.cpp @@ -1168,12 +1168,11 @@ diopiError_t diopiAddInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t inp diopiError_t diopiForeachaddScalar(diopiContextHandle_t ctx, diopiTensorHandle_t* outs, diopiConstTensorHandle_t* inputs, int64_t inputSize, const diopiScalar_t* other) { impl::aten::setCurStream(ctx); - DIOPI_IMPL_BUILD_ATEN_LIST(atOuts, outs, inputSize) DIOPI_IMPL_BUILD_ATEN_LIST(atInputs, inputs, inputSize) auto atOther = impl::aten::buildAtScalar(other); auto tempOut = CALL_ATEN_CUDA_FUNC(_foreach_add, atInputs, atOther); for (int i = 0; i < inputSize; i++) { - *(reinterpret_cast(outs[i])) = tempOut[i]; + impl::aten::updateATen2Tensor(ctx, tempOut[i], outs[i]); } return diopiSuccess; @@ -1284,11 +1283,10 @@ diopiError_t diopiForeachmulScalar(diopiContextHandle_t ctx, diopiTensorHandle_t DIOPI_CHECK_PTR(outs); impl::aten::setCurStream(ctx); DIOPI_IMPL_BUILD_ATEN_LIST(atInputs, inputs, inputSize) - DIOPI_IMPL_BUILD_ATEN_LIST(atOuts, outs, inputSize) auto atOther = impl::aten::buildAtScalar(other); auto tempOut = CALL_ATEN_CUDA_FUNC(_foreach_mul, atInputs, atOther); for (int i = 0; i < inputSize; i++) { - *(reinterpret_cast(outs[i])) = tempOut[i]; + impl::aten::updateATen2Tensor(ctx, tempOut[i], outs[i]); } return diopiSuccess; @@ -1312,7 +1310,7 @@ diopiError_t diopiForeachmulTensor(diopiContextHandle_t ctx, diopiTensorHandle_t auto atOther = impl::aten::buildATen(other); auto tempOut = CALL_ATEN_CUDA_FUNC(_foreach_mul, atInputs, atOther); for (int i = 0; i < inputSize; i++) { - *(reinterpret_cast(outs[i])) = tempOut[i]; + impl::aten::updateATen2Tensor(ctx, tempOut[i], outs[i]); } return diopiSuccess; @@ -3317,12 +3315,11 @@ diopiError_t diopiForeachnormScalar(diopiContextHandle_t ctx, diopiTensorHandle_ DIOPI_CHECK_PTR(outs); impl::aten::setCurStream(ctx); DIOPI_IMPL_BUILD_ATEN_LIST(atInputs, inputs, inputSize) - DIOPI_IMPL_BUILD_ATEN_LIST(atOuts, outs, inputSize) auto atP = impl::aten::buildAtScalar(p); auto tempOut = CALL_ATEN_CUDA_FUNC(_foreach_norm, atInputs, atP); for (int i = 0; i < inputSize; i++) { - // impl::aten::updateATen2Tensor(ctx, tempOut[i], out[i]); - *(reinterpret_cast(outs[i])) = tempOut[i]; + //WARN NO NEED TO COPY HERE, WE NEED FASTER UPDATE HERE + impl::aten::updateATen2Tensor(ctx, tempOut[i], outs[i]); } return diopiSuccess;