Skip to content

Commit

Permalink
optimize performance (DeepLink-org#1238)
Browse files Browse the repository at this point in the history
* optimize performance

* remove unnecessary GIL lock

* add ContextManger

* use cast when buildATen

* fix unnecessary copy and sync

* fix compiler bug

* Put the two different implementations into different namespaces and control which version is used via environment variables

* fix: rename variable in ascend_npu CMakeLists.txt

* fix transpose bug

---------

Co-authored-by: chenchiyu <[email protected]>
  • Loading branch information
zhaoguochun1995 and CyCle1024 authored Jun 20, 2024
1 parent 390d7dd commit 2ee61ea
Show file tree
Hide file tree
Showing 21 changed files with 208 additions and 142 deletions.
4 changes: 2 additions & 2 deletions impl/ascend_npu/diopi_impl/baddbmm.cpp
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "helper.hpp"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"

namespace OP_IMPL_NS {

Expand All @@ -16,8 +17,7 @@ diopiError_t diopiBaddbmm(diopiContextHandle_t ctx, diopiTensorHandle_t out, dio
auto alphaAt = at::Scalar(alpha);

if (batch1At.numel() == 0 || batch2At.numel() == 0) {
auto outMul = op_api::mul(inputAt, betaAt);
outAt.copy_(outMul);
EXEC_NPU_CMD(aclnnMuls, inputAt, betaAt, outAt);
END_CALL_ACL_OP();
}

Expand Down
3 changes: 1 addition & 2 deletions impl/ascend_npu/diopi_impl/bmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
namespace OP_IMPL_NS {

diopiError_t diopiBmm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t mat2) {
BEGIN_CALL_ACL_OP(out);
BEGIN_CALL_ACL_OP(out, input, mat2);
if (outAt.numel() == 0) {
return diopiSuccess;
}
BEGIN_CALL_ACL_OP(input, mat2);
// op_api::bmm_out(inputAt, mat2At, outAt);
signed char cubeMathType = at_npu::native::OpPreparation::get_cube_math_type(at_npu::native::env::IsAllowMatmulHF32());
EXEC_NPU_CMD(aclnnBatchMatMul, inputAt, mat2At, outAt, cubeMathType);
Expand Down
2 changes: 1 addition & 1 deletion impl/ascend_npu/diopi_impl/cat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ diopiError_t diopiCat(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiCo
}
}
if (outAt.scalar_type() != outTempAt.scalar_type()) {
outAt.copy_(outTempAt);
outAt.copy_(outTempAt, true);
}

END_CALL_ACL_OP();
Expand Down
23 changes: 11 additions & 12 deletions impl/ascend_npu/diopi_impl/clamp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "helper.hpp"
#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"

namespace OP_IMPL_NS {
diopiError_t diopiClamp(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t min,
Expand All @@ -17,7 +18,7 @@ diopiError_t diopiClamp(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopi
if (inputAt.numel() == 0) {
return diopiSuccess;
}
at::Tensor inputTmp = inputAt.to(outAt.scalar_type());
at::Tensor inputTmp = inputAt.to(outAt.scalar_type(), true);
if (minAt.defined() && maxAt.defined()) {
op_api::clamp_out(inputTmp, minAt, maxAt, outAt);
} else {
Expand All @@ -38,7 +39,7 @@ diopiError_t diopiClampScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out,
if (inputAt.numel() == 0) {
return diopiSuccess;
}
at::Tensor inputTmp = inputAt.to(outAt.scalar_type());
at::Tensor inputTmp = inputAt.to(outAt.scalar_type(), true);
op_api::clamp_out(inputTmp, minAt, maxAt, outAt);
END_CALL_ACL_OP();
} else {
Expand All @@ -47,7 +48,7 @@ diopiError_t diopiClampScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out,
if (inputAt.numel() == 0) {
return diopiSuccess;
}
at::Tensor inputTmp = inputAt.to(outAt.scalar_type());
at::Tensor inputTmp = inputAt.to(outAt.scalar_type(), true);
op_api::clamp_min_out(inputTmp, minAt, outAt);
END_CALL_ACL_OP();
}
Expand All @@ -56,9 +57,8 @@ diopiError_t diopiClampScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out,
if (inputAt.numel() == 0) {
return diopiSuccess;
}
at::Tensor inputTmp = inputAt.to(outAt.scalar_type());
at::Tensor tmp = op_api::clamp_max(inputTmp, maxAt);
outAt.copy_(tmp);
at::Tensor inputTmp = inputAt.to(outAt.scalar_type(), true);
EXEC_NPU_CMD(aclnnClampMaxTensor, inputTmp, maxAt, outAt);
END_CALL_ACL_OP();
}
}
Expand Down Expand Up @@ -133,7 +133,7 @@ diopiError_t diopiClampMinScalar(diopiContextHandle_t ctx, diopiTensorHandle_t o
if (inputAt.numel() == 0) {
return diopiSuccess;
}
at::Tensor inputTmp = inputAt.to(outAt.scalar_type());
at::Tensor inputTmp = inputAt.to(outAt.scalar_type(), true);
op_api::clamp_min_out(inputTmp, minAt, outAt);
END_CALL_ACL_OP();
}
Expand All @@ -143,7 +143,7 @@ diopiError_t diopiClampMin(diopiContextHandle_t ctx, diopiTensorHandle_t out, di
if (inputAt.numel() == 0) {
return diopiSuccess;
}
at::Tensor inputTmp = inputAt.to(outAt.scalar_type());
at::Tensor inputTmp = inputAt.to(outAt.scalar_type(), true);
op_api::clamp_min_out(inputTmp, minAt, outAt);
END_CALL_ACL_OP();
}
Expand Down Expand Up @@ -171,9 +171,8 @@ diopiError_t diopiClampMaxScalar(diopiContextHandle_t ctx, diopiTensorHandle_t o
if (inputAt.numel() == 0) {
return diopiSuccess;
}
at::Tensor inputTmp = inputAt.to(outAt.scalar_type());
at::Tensor tmp = op_api::clamp_max(inputTmp, maxAt);
outAt.copy_(tmp);
at::Tensor inputTmp = inputAt.to(outAt.scalar_type(), true);
EXEC_NPU_CMD(aclnnClampMaxTensor, inputTmp, maxAt, outAt);
END_CALL_ACL_OP();
}

Expand All @@ -182,7 +181,7 @@ diopiError_t diopiClampMax(diopiContextHandle_t ctx, diopiTensorHandle_t out, di
if (inputAt.numel() == 0) {
return diopiSuccess;
}
at::Tensor inputTmp = inputAt.to(outAt.scalar_type());
at::Tensor inputTmp = inputAt.to(outAt.scalar_type(), true);
op_api::clamp_max_out(inputTmp, maxAt, outAt);
END_CALL_ACL_OP();
}
Expand Down
2 changes: 1 addition & 1 deletion impl/ascend_npu/diopi_impl/copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ diopiError_t diopiCopyInp(diopiContextHandle_t ctx, diopiConstTensorHandle_t src
if (src == nullptr || dest == nullptr || !srcAt.defined() || !destAt.defined() || srcAt.numel() <= 0 || destAt.numel() <= 0) {
return diopiSuccess;
}
at_npu::native::NPUNativeOpApiFunctions::copy_(destAt, srcAt, false);
at_npu::native::NPUNativeOpApiFunctions::copy_(destAt, srcAt, true);
END_CALL_ACL_OP();
}
#if 0
Expand Down
2 changes: 1 addition & 1 deletion impl/ascend_npu/diopi_impl/dropout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ diopiError_t diopiDropout(diopiContextHandle_t ctx, diopiTensorHandle_t out, dio
DIOPI_CHECK(maskAt.defined(), "[DIOPI][Ascend] Check if mask tensor defined");

if (p == 0 || train == false) {
outAt.copy_(inputAt);
outAt.copy_(inputAt, true);
op_api::fill_(maskAt, c10::Scalar(1));
END_CALL_ACL_OP();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ at::Tensor torchContextAttention(at::Tensor xq, at::Tensor xk, at::Tensor xv, in
at::Tensor mask = op_api::tril(op_api::ones({seqLen, seqLen}, at::kFloat, layout, device)).unsqueeze(0).unsqueeze(0);
mask.masked_fill_(mask == 0., -100000000.0);
mask = mask.repeat({batchSize, head, 1, 1});
at::Tensor scores = op_api::matmul(xq.to(at::kFloat), xk.transpose(2, 3).to(at::kFloat)) / std::sqrt(dim);
at::Tensor output = op_api::matmul((scores + mask).softmax(-1), xv.to(at::kFloat)).transpose(1, 2).to(dtype);
at::Tensor scores = op_api::matmul(xq.to(at::kFloat, true), xk.transpose(2, 3).to(at::kFloat, true)) / std::sqrt(dim);
at::Tensor output = op_api::matmul((scores + mask).softmax(-1), xv.to(at::kFloat, true)).transpose(1, 2).to(dtype, true);
output = output.view({output.numel() / static_cast<int64_t>(head * dim), head, dim});
return output;
}
Expand Down
4 changes: 2 additions & 2 deletions impl/ascend_npu/diopi_impl/functions_ext/rms_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ diopiError_t diopiRMSNormBackward(diopiContextHandle_t ctx, diopiTensorHandle_t
at::Tensor gradWeightTempAt = at_npu::native::OpPreparation::apply_tensor_with_format(
op_infer::rms_norm_grad_npu_output_size(inputAt, weightAt)[1], gradWeightAt.options().dtype(at::kFloat), ACL_FORMAT_ND);
EXEC_NPU_CMD(aclnnRmsNormGrad, gradOutputAt, inputAt, invRmsAt, weightAt, gradInputAt, gradWeightTempAt);
gradWeightAt.copy_(gradWeightTempAt);
gradWeightAt.copy_(gradWeightTempAt, true);
} else {
EXEC_NPU_CMD(aclnnRmsNormGrad, gradOutputAt, inputAt, invRmsAt, weightAt, gradInputAt, gradWeightAt);
}
Expand All @@ -60,7 +60,7 @@ diopiError_t diopiRMSNormBackward(diopiContextHandle_t ctx, diopiTensorHandle_t
std::iota(sumDims.begin(), sumDims.end(), 0);
op_api::sum_out(gradOutputAt, sumDims, false, gradBiasAt.scalar_type(), gradBiasAt);
} else {
gradBiasAt.copy_(gradOutputAt);
gradBiasAt.copy_(gradOutputAt, true);
}
}
END_CALL_ACL_OP();
Expand Down
6 changes: 3 additions & 3 deletions impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ DIOPI_API diopiError_t diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTenso

if (xAt.dim() >= 5) {
set_last_error_string("rotary embedding not support 5D tensor yet");
impl::aten::unsetCurCtx();
return diopi5DNotSupported;
}

Expand All @@ -68,8 +67,9 @@ DIOPI_API diopiError_t diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTenso

std::vector<at::Tensor> chunkResult = xView.chunk(2, -1);
at::Tensor xNew = op_api::cat({chunkResult[1] * (-1), chunkResult[0]}, -1);
at::Tensor result = op_api::mul(cosCat, xView) + op_api::mul(sinCat, xNew);
outView.copy_(result);
auto result1 = op_api::mul(cosCat, xView);
auto result2 = op_api::mul(sinCat, xNew);
op_api::add_out(result1, result2, 1.0, outView);

END_CALL_ACL_OP();
}
Expand Down
6 changes: 2 additions & 4 deletions impl/ascend_npu/diopi_impl/group_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@ namespace OP_IMPL_NS {

diopiError_t diopiGroupNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t saveMean, diopiTensorHandle_t saveInvstd,
diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t numGroups, double eps) {
BEGIN_CALL_ACL_OP(input);
BEGIN_CALL_ACL_OP(input, weight, bias, out, saveMean, saveInvstd);
if (!inputAt.defined() || inputAt.numel() == 0) {
return diopiSuccess;
}
BEGIN_CALL_ACL_OP(weight, bias, out, saveMean, saveInvstd);
int64_t n = inputAt.sizes()[0];
int64_t c = inputAt.sizes()[1];
int64_t hw = inputAt.numel() / (n * c);
Expand All @@ -28,11 +27,10 @@ diopiError_t diopiGroupNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, d
diopiError_t diopiGroupNormBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiTensorHandle_t gradWeight, diopiTensorHandle_t gradBias,
diopiConstTensorHandle_t gradOutput, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight,
diopiConstTensorHandle_t mean, diopiConstTensorHandle_t rstd, int64_t numGroups) {
BEGIN_CALL_ACL_OP(input);
BEGIN_CALL_ACL_OP(input, gradWeight, gradBias);
if (!inputAt.defined()) {
return diopiSuccess;
}
BEGIN_CALL_ACL_OP(gradWeight, gradBias);
if (inputAt.numel() == 0) {
if (inputAt.sizes()[0] == 0) {
op_api::fill_(gradWeightAt, c10::Scalar(0.0));
Expand Down
13 changes: 7 additions & 6 deletions impl/ascend_npu/diopi_impl/helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,10 @@ inline int debugLevel() {
if (debugLevel()) { \
std::cout << __FILE__ << ":" << __LINE__ << " :" << __FUNCTION__ << std::endl; \
} \
impl::aten::setCurCtx(ctx); \
impl::aten::ContextManger contextManger(ctx); \
BUILD_ATEN_ARGS(__VA_ARGS__)

#define END_CALL_ACL_OP() \
impl::aten::unsetCurCtx(); \
if (debugLevel()) { \
std::cout << __FILE__ << ":" << __LINE__ << " :" << __FUNCTION__ << " over" << std::endl; \
} \
Expand Down Expand Up @@ -198,9 +197,11 @@ namespace impl {

namespace aten {

void setCurCtx(diopiContextHandle_t ctx);

void unsetCurCtx();
/// Scope guard around a DIOPI context handle.
///
/// BEGIN_CALL_ACL_OP creates one as a function-local object from `ctx`, so its
/// destructor runs on every exit path of the calling function. Only the
/// declarations are visible here; presumably the constructor installs `context`
/// as the current context and the destructor clears it again -- TODO confirm
/// against the out-of-line definitions.
///
/// The class name keeps its existing spelling because the macro refers to it
/// by that name.
class ContextManger {
public:
    /// @param context  Handle this guard is constructed from.
    /// `explicit` so a bare handle never converts to a guard implicitly.
    explicit ContextManger(diopiContextHandle_t context);
    ~ContextManger();

    // A guard belongs to the scope that created it. Copying or moving one
    // would run the destructor more than once for a single constructor call.
    ContextManger(const ContextManger&) = delete;
    ContextManger& operator=(const ContextManger&) = delete;
    ContextManger(ContextManger&&) = delete;
    ContextManger& operator=(ContextManger&&) = delete;
};

inline void sync(diopiContextHandle_t ctx) {
diopiStreamHandle_t streamHandle;
Expand Down Expand Up @@ -405,8 +406,8 @@ inline decltype(auto) buildATenList(T* tensors, int64_t numTensors) {
}

inline void updateATen2Tensor(diopiContextHandle_t ctx, const at::Tensor& atOut, diopiTensorHandle_t out) {
// TODO(fengsibo): add device and nbytes check
if (out != nullptr) {
TORCH_WARN(false, "can be optimized: there is no need copy");
at::Tensor atOutput = buildATen(out);
atOutput.reshape_as(atOut).copy_(atOut, true);
}
Expand Down
2 changes: 1 addition & 1 deletion impl/ascend_npu/diopi_impl/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ diopiError_t diopiIndexBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gr

auto indicesCast = impl::aten::castIntIndicesToLongIndices(indicesAtList);
op_api::_index_put_impl_(zerosLikeInputAt, indicesCast, gradOutputAt, true, false);
gradInputAt.copy_(zerosLikeInputAt);
gradInputAt.copy_(zerosLikeInputAt, true);
END_CALL_ACL_OP();
}

Expand Down
2 changes: 1 addition & 1 deletion impl/ascend_npu/diopi_impl/index_put.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ diopiError_t diopiIndexPut(diopiContextHandle_t ctx, diopiTensorHandle_t out, di
indicesAtList.emplace_back(impl::aten::buildATen(indices[i]));
}

outAt.copy_(inputAt);
outAt.copy_(inputAt, true);
auto indicesCast = impl::aten::castIntIndicesToLongIndices(indicesAtList);
op_api::_index_put_impl_(outAt, indicesCast, valuesAt, accumulate, false);
END_CALL_ACL_OP();
Expand Down
4 changes: 2 additions & 2 deletions impl/ascend_npu/diopi_impl/index_select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ diopiError_t diopiIndexSelect(diopiContextHandle_t ctx, diopiTensorHandle_t out,

at::Tensor indexTempAt = indexAt;
if (indexAt.scalar_type() != at::kInt || indexAt.scalar_type() != at::kLong) {
indexTempAt = indexAt.to(at::kLong);
indexTempAt = indexAt.to(at::kLong, true);
}

if (false) {
Expand All @@ -41,7 +41,7 @@ diopiError_t diopiIndexSelectBackward(diopiContextHandle_t ctx, diopiTensorHandl

at::Tensor indexTempAt = indexAt;
if (indexAt.scalar_type() != at::kInt || indexAt.scalar_type() != at::kLong) {
indexTempAt = indexAt.to(at::kLong);
indexTempAt = indexAt.to(at::kLong, true);
}

at::Scalar zero{0.0};
Expand Down
4 changes: 2 additions & 2 deletions impl/ascend_npu/diopi_impl/masked_fill.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ diopiError_t diopiMaskedFill(diopiContextHandle_t ctx, diopiTensorHandle_t out,
return diopiSuccess;
}
if (outAt.data_ptr() != inputAt.data_ptr()) {
outAt.copy_(inputAt);
outAt.copy_(inputAt, true);
}
op_api::masked_fill_(outAt, maskAt, valueAt);
END_CALL_ACL_OP();
Expand All @@ -38,7 +38,7 @@ diopiError_t diopiMaskedFillScalar(diopiContextHandle_t ctx, diopiTensorHandle_t
return diopiSuccess;
}
if (outAt.data_ptr() != inputAt.data_ptr()) {
outAt.copy_(inputAt);
outAt.copy_(inputAt, true);
}
op_api::masked_fill_(outAt, maskAt, valueAt);
END_CALL_ACL_OP();
Expand Down
7 changes: 2 additions & 5 deletions impl/ascend_npu/diopi_impl/mul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,11 @@
namespace OP_IMPL_NS {

diopiError_t diopiMul(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t other) {
BEGIN_CALL_ACL_OP(input, other);
BEGIN_CALL_ACL_OP(input, other, out);
if (!inputAt.defined() || inputAt.numel() == 0 || !otherAt.defined() || otherAt.numel() == 0) {
return diopiSuccess;
}

BEGIN_CALL_ACL_OP(out);
// op_api::mul_out(inputAt, otherAt, outAt);
EXEC_NPU_CMD(aclnnMul, inputAt, otherAt, outAt);
END_CALL_ACL_OP();
}
Expand All @@ -34,12 +32,11 @@ diopiError_t diopiMulInp(diopiContextHandle_t ctx, diopiTensorHandle_t input, di
}

diopiError_t diopiMulScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* other) {
BEGIN_CALL_ACL_OP(input, other);
BEGIN_CALL_ACL_OP(input, other, out);
if (!inputAt.defined() || inputAt.numel() == 0) {
return diopiSuccess;
}

BEGIN_CALL_ACL_OP(out);
EXEC_NPU_CMD(aclnnMuls, inputAt, otherAt, outAt);
END_CALL_ACL_OP();
}
Expand Down
2 changes: 1 addition & 1 deletion impl/ascend_npu/diopi_impl/repeat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ diopiError_t diopiRepeat(diopiContextHandle_t ctx, diopiTensorHandle_t out, diop
TORCH_CHECK(inputAt.dim() <= repeatSize.len, "repeats size should not be smaller than input tensor dim on ascend!");
// When repeatSize.len is equal to 0, out is the same as input.
if (repeatSize.len == 0) {
outAt.copy_(inputAt);
outAt.copy_(inputAt, true);
END_CALL_ACL_OP();
}

Expand Down
6 changes: 3 additions & 3 deletions impl/ascend_npu/diopi_impl/scatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ diopiError_t diopiScatter(diopiContextHandle_t ctx, diopiTensorHandle_t out, dio
TORCH_CHECK((inputAt.dim() == srcAt.dim() && inputAt.dim() == indexAt.dim()) || indexAt.dim() == 0,
"input,src,index must have same ndim! only exception is index is empty");
if (indexAt.dim() == 0) {
outAt.copy_(inputAt);
outAt.copy_(inputAt, true);
return diopiSuccess;
}
// input to output type
at::Tensor inputTmpAt = inputAt;
if (outAt.scalar_type() != inputAt.scalar_type()) {
inputTmpAt = inputAt.to(outAt.scalar_type());
inputTmpAt = inputAt.to(outAt.scalar_type(), true);
}
int64_t reduction = getReduce(reduce);
EXEC_NPU_CMD(aclnnScatter, inputTmpAt, dim, indexAt, srcAt, reduction, outAt);
Expand Down Expand Up @@ -89,7 +89,7 @@ diopiError_t diopiScatterScalar(diopiContextHandle_t ctx, diopiTensorHandle_t ou
}
// check index
TORCH_CHECK(inputAt.dim() == indexAt.dim() || indexAt.dim() == 0, "input,index must have same ndim! only exception is index is empty");
outAt.copy_(inputAt);
outAt.copy_(inputAt, true);
if (indexAt.dim() == 0) {
return diopiSuccess;
}
Expand Down
4 changes: 2 additions & 2 deletions impl/ascend_npu/diopi_impl/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace OP_IMPL_NS {
diopiError_t diopiTranspose(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, int64_t dim0, int64_t dim1) {
BEGIN_CALL_ACL_OP(input, out);
if (0 == inputAt.dim()) {
outAt.copy_(inputAt);
outAt.copy_(inputAt, true);
return diopiSuccess;
}

Expand All @@ -31,7 +31,7 @@ diopiError_t diopiTranspose(diopiContextHandle_t ctx, diopiTensorHandle_t out, d
diopiError_t diopiPermute(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiSize_t dims) {
BEGIN_CALL_ACL_OP(input, dims, out);
if (0 == dims.len) {
outAt.copy_(inputAt);
outAt.copy_(inputAt, true);
return diopiSuccess;
}
EXEC_NPU_CMD(aclnnPermute, inputAt, dimsAt, outAt);
Expand Down
Loading

0 comments on commit 2ee61ea

Please sign in to comment.