Skip to content

Commit

Permalink
Use updateATen2Tensor in functions.cpp (this will slow down performance)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrench-Git committed Jul 31, 2024
1 parent 981ea99 commit 880ddf4
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions impl/torch/functions/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1168,12 +1168,11 @@ diopiError_t diopiAddInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t inp
diopiError_t diopiForeachaddScalar(diopiContextHandle_t ctx, diopiTensorHandle_t* outs, diopiConstTensorHandle_t* inputs, int64_t inputSize,
const diopiScalar_t* other) {
impl::aten::setCurStream(ctx);
DIOPI_IMPL_BUILD_ATEN_LIST(atOuts, outs, inputSize)
DIOPI_IMPL_BUILD_ATEN_LIST(atInputs, inputs, inputSize)
auto atOther = impl::aten::buildAtScalar(other);
auto tempOut = CALL_ATEN_CUDA_FUNC(_foreach_add, atInputs, atOther);
for (int i = 0; i < inputSize; i++) {
*(reinterpret_cast<at::Tensor*>(outs[i])) = tempOut[i];
impl::aten::updateATen2Tensor(ctx, tempOut[i], outs[i]);
}

return diopiSuccess;
Expand Down Expand Up @@ -1284,11 +1283,10 @@ diopiError_t diopiForeachmulScalar(diopiContextHandle_t ctx, diopiTensorHandle_t
DIOPI_CHECK_PTR(outs);
impl::aten::setCurStream(ctx);
DIOPI_IMPL_BUILD_ATEN_LIST(atInputs, inputs, inputSize)
DIOPI_IMPL_BUILD_ATEN_LIST(atOuts, outs, inputSize)
auto atOther = impl::aten::buildAtScalar(other);
auto tempOut = CALL_ATEN_CUDA_FUNC(_foreach_mul, atInputs, atOther);
for (int i = 0; i < inputSize; i++) {
*(reinterpret_cast<at::Tensor*>(outs[i])) = tempOut[i];
impl::aten::updateATen2Tensor(ctx, tempOut[i], outs[i]);
}

return diopiSuccess;
Expand All @@ -1312,7 +1310,7 @@ diopiError_t diopiForeachmulTensor(diopiContextHandle_t ctx, diopiTensorHandle_t
auto atOther = impl::aten::buildATen(other);
auto tempOut = CALL_ATEN_CUDA_FUNC(_foreach_mul, atInputs, atOther);
for (int i = 0; i < inputSize; i++) {
*(reinterpret_cast<at::Tensor*>(outs[i])) = tempOut[i];
impl::aten::updateATen2Tensor(ctx, tempOut[i], outs[i]);
}

return diopiSuccess;
Expand Down Expand Up @@ -3317,12 +3315,11 @@ diopiError_t diopiForeachnormScalar(diopiContextHandle_t ctx, diopiTensorHandle_
DIOPI_CHECK_PTR(outs);
impl::aten::setCurStream(ctx);
DIOPI_IMPL_BUILD_ATEN_LIST(atInputs, inputs, inputSize)
DIOPI_IMPL_BUILD_ATEN_LIST(atOuts, outs, inputSize)
auto atP = impl::aten::buildAtScalar(p);
auto tempOut = CALL_ATEN_CUDA_FUNC(_foreach_norm, atInputs, atP);
for (int i = 0; i < inputSize; i++) {
// impl::aten::updateATen2Tensor(ctx, tempOut[i], out[i]);
*(reinterpret_cast<at::Tensor*>(outs[i])) = tempOut[i];
// WARN: a copy should not be necessary here; a faster update is needed.
impl::aten::updateATen2Tensor(ctx, tempOut[i], outs[i]);
}

return diopiSuccess;
Expand Down

0 comments on commit 880ddf4

Please sign in to comment.