
[WIP]add mluop cholesky #1146

Open · wants to merge 47 commits into base: master
Changes from 32 commits

Commits (47)
fc47e70
complete the float type cholesky operator
dglr Apr 16, 2024
77db74c
[WIP]add mluop cholesky
dglr Apr 30, 2024
0e9a1f8
add cholesky doc
dglr May 23, 2024
6051cce
modify mathematical formula
dglr May 27, 2024
86a2c41
add complex type
dglr Jun 7, 2024
efa3d08
finish complex batch
dglr Jun 28, 2024
4872a42
fix ang bugs
dglr Jul 19, 2024
0f12676
fix nram workspace, update doc
dglr Jul 23, 2024
86ceaba
add pseudocode
dglr Jul 25, 2024
beb7e53
add comments
dglr Jul 25, 2024
40f62ba
add index.rst
dglr Jul 25, 2024
d935fb9
format code
dglr Jul 25, 2024
e42d270
[Fix](mluOpCholesky): fix format
dglr Aug 12, 2024
bfcf2b2
[Fix](mluOpCholesky): add mluoplog when sqrt
dglr Aug 15, 2024
e4f330b
[Fix](mluOpCholesky): reset workspace
dglr Aug 15, 2024
9eb9dc4
[Fix](mluOpCholesky): rename getworkspace size function
dglr Aug 15, 2024
76631ee
[Fix](mluOpCholesky): rewrite description in mlu_op
dglr Aug 15, 2024
23acef7
[Docs](mluOpCholesky): update docs
dglr Aug 15, 2024
40661cc
[Fix](mluOpCholesky): del printf
dglr Aug 15, 2024
4d34d54
[Docs](mluOpCholesky): rewrite Conjugate transpose symbol
dglr Aug 15, 2024
fc1a0ac
[Fix](mluOpCholesky): format
dglr Aug 24, 2024
b0d5b6e
[Fix](mluOpCholesky): add layout check
dglr Sep 16, 2024
9f7dcd5
[Fix](mluOpCholesky): fix mem check
dglr Sep 16, 2024
7888578
[Docs](mluOpCholesky): add test doc
dglr Sep 16, 2024
435f829
[Docs](mluOpCholesky): add coverage test
dglr Sep 20, 2024
4f9a1af
[Fix](mluOpCholesky): add dimension equals test
dglr Sep 20, 2024
43bbe67
[Fix](mluOpCholesky): add coverage function
dglr Sep 21, 2024
976a88b
[Fix](mluOpCholesky): test
Nov 13, 2024
284e6e8
[Fix](mluOpCholesky): test
Nov 14, 2024
65705ea
[Fix](mluOpCholesky): resolve conflict
Nov 14, 2024
4d18629
[Fix](mluOpCholesky): fix magic number
Nov 14, 2024
5ae8c94
[Fix](mluOpCholesky): cut one branch
Nov 14, 2024
2b5822d
[Fix](mluOpCholesky): add policy func
Nov 30, 2024
ecc80ec
[Fix](mluOpCholesky): rename variables
Dec 1, 2024
0938f55
[Fix](mluOpCholesky): add some comments
Dec 1, 2024
3213104
[Fix](mluOpCholesky): mv cnnl to cpp
Dec 1, 2024
a9c9db6
[Fix](mluOpCholesky): add new memcpy
Dec 4, 2024
876ebc2
[Fix](mluOpCholesky): remove useless sync
Dec 4, 2024
94664e5
[Fix](mluOpCholesky): remove cnrtmemcpy cnrtqueuesync
Dec 8, 2024
2c96f6f
[Fix](mluOpCholesky): update cholesky_test
Dec 8, 2024
9718b02
[Fix](mluOpCholesky): fix bugs
Dec 9, 2024
a472285
[Docs](mluOpCholesky): update doc
Dec 9, 2024
5d35eca
[Fix](mluOpCholesky): fix sync bugs
Dec 9, 2024
55ff179
[Docs](mluOpCholesky): update doc
Dec 9, 2024
b52a27a
[Fix](mluOpCholesky): add param check
Dec 10, 2024
905b0ad
[Fix](mluOpCholesky): update
Feb 18, 2025
9a3aba7
[Fix](mluOpCholesky): update
Feb 18, 2025
Binary file added docs/design_docs/cholesky/32_128性能分析.png
495 changes: 495 additions & 0 deletions docs/design_docs/cholesky/cholesky.md
Binary file added docs/design_docs/cholesky/coverage.png
Binary file added docs/design_docs/cholesky/coverage_error.png
Binary file added docs/design_docs/cholesky/divide.png
Binary file added docs/design_docs/cholesky/gemm.png
Binary file added docs/design_docs/cholesky/potrf.png
Binary file added docs/design_docs/cholesky/recur_p1.png
Binary file added docs/design_docs/cholesky/recur_p2.png
Binary file added docs/design_docs/cholesky/syrk.png
Binary file added docs/design_docs/cholesky/timeline.png
Binary file added docs/design_docs/cholesky/trsm.png
7 changes: 7 additions & 0 deletions docs/user_guide/9_operators/index.rst
Expand Up @@ -793,3 +793,10 @@ mluOpLgamma

- ``x`` 为输入张量。

.. cholesky::
Collaborator: The ``.. cholesky::`` here looks like it was added by mistake, right?

Collaborator:
Suggested change
.. cholesky::
.. _cholesky:


mluOpCholesky
Collaborator: If there is a formula for this, it should be added.

---------------
Performs the Cholesky decomposition, factoring a positive-definite matrix into its lower-triangular matrix (L) or its transposed upper-triangular matrix (U); whether it factors into upper or lower triangular depends on the parameter ``upper``.
Collaborator:
Suggested change
Performs the Cholesky decomposition, factoring a positive-definite matrix into its lower-triangular matrix (L) or its transposed upper-triangular matrix (U); whether it factors into upper or lower triangular depends on the parameter ``upper``.
Performs the Cholesky decomposition, factoring a positive-definite matrix into a lower-triangular matrix (L) or a transposed upper-triangular matrix (U). Whether it factors into upper or lower triangular depends on the parameter ``upper``.

Could the pronoun "其" ("its") be removed here?


This operator has 7 inputs: handle is the operation handle; input_desc and d_input respectively describe and supply the input matrix; two outputs: output_desc and d_output respectively describe and store the output matrix; it also has a boolean parameter upper that specifies whether the output is an upper- or lower-triangular matrix, and a workspace for temporarily storing data during the computation.
Collaborator:
Suggested change
This operator has 7 inputs: handle is the operation handle; input_desc and d_input respectively describe and supply the input matrix; two outputs: output_desc and d_output respectively describe and store the output matrix; it also has a boolean parameter upper that specifies whether the output is an upper- or lower-triangular matrix, and a workspace for temporarily storing data during the computation.
This operator has 7 inputs: handle is the operation handle; input_desc and d_input respectively describe and supply the input matrix; two outputs: output_desc and d_output respectively describe and store the output matrix; it also includes the parameter ``upper``, which specifies whether the output is an upper- or lower-triangular matrix, and ``workspace``, used for temporarily storing data during the computation.

The text claims 7 inputs, but only 3 are explained here (handle, input_desc, and d_input)? If upper and workspace also count as inputs, they should be described together with the others.
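For reference, the factorization the documentation describes is A = L·Lᵀ (or, with upper set, A = Uᵀ·U where U = Lᵀ). A minimal host-side sketch, independent of the MLU API (the helper name and layout here are illustrative, not part of mluOp):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Unblocked reference Cholesky: returns row-major lower-triangular L
// with A = L * L^T. The upper factor is simply U = L^T.
std::vector<double> cholesky_lower(const std::vector<double>& a, int n) {
  std::vector<double> l(n * n, 0.0);
  for (int j = 0; j < n; ++j) {
    double d = a[j * n + j];
    for (int k = 0; k < j; ++k) d -= l[j * n + k] * l[j * n + k];
    l[j * n + j] = std::sqrt(d);  // valid only for positive-definite A
    for (int i = j + 1; i < n; ++i) {
      double s = a[i * n + j];
      for (int k = 0; k < j; ++k) s -= l[i * n + k] * l[j * n + k];
      l[i * n + j] = s / l[j * n + j];
    }
  }
  return l;
}
```

For A = [[4, 2], [2, 5]] this yields L = [[2, 0], [1, 2]], and L·Lᵀ reproduces A.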

318 changes: 318 additions & 0 deletions kernels/cholesky/cholesky.cpp
@@ -0,0 +1,318 @@
/*************************************************************************
* Copyright (C) [2024] by Cambricon, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the
* "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/

#include "cholesky.h"
#include <cstdio>
#include <algorithm>
// calculates the required workspace size for performing the Cholesky
// decomposition on a given matrix or batch of matrices.
mluOpStatus_t MLUOP_WIN_API mluOpGetCholeskyWorkspaceSize(
mluOpTensorDescriptor_t input_desc, size_t* size) {
PARAM_CHECK("mluOpCholesky", input_desc != NULL);
Collaborator: The handle input parameter is missing here, along with a check that handle is not NULL.


PARAM_CHECK("mluOpCholesky", input_desc->dim == 2 || input_desc->dim == 3);
PARAM_CHECK("mluOpCholesky", input_desc->dims[0] > 0);
PARAM_CHECK("mluOpCholesky", input_desc->dims[1] > 0);

if (input_desc->dim == 3) {
PARAM_CHECK("mluOpCholesky", input_desc->dims[2] > 0);
Collaborator: There is no corresponding check that the input is a square matrix, e.g.:
if (input_desc->dim == 3) {
  PARAM_CHECK("mluOpCholesky", input_desc->dims[2] == input_desc->dims[1]);
} else {
  PARAM_CHECK("mluOpCholesky", input_desc->dims[1] == input_desc->dims[0]);
}

}

mluOpDataType_t dtype = input_desc->dtype;
PARAM_CHECK("mluOpCholesky",
dtype == MLUOP_DTYPE_FLOAT || dtype == MLUOP_DTYPE_COMPLEX_FLOAT);

uint64_t type_size;
MLUOP_CHECK(mluOpGetSizeOfDataType(dtype, &type_size));
int64_t size_a = 0, lda = 0, size_c = 0, ldc = 0;
int64_t batch_size = 1;
int dim = input_desc->dim;
if (dim == 2) {
size_a = input_desc->dims[0];
} else if (dim == 3) {
batch_size = input_desc->dims[0];
size_a = input_desc->dims[1];
}

if (dtype == MLUOP_DTYPE_FLOAT) {
*size = size_a * size_a * sizeof(float) * batch_size * 3;
} else {
*size = size_a * size_a * sizeof(float) * 2 * batch_size * 3;
}
return MLUOP_STATUS_SUCCESS;
}
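The size rule above reads as: three square n×n float scratch buffers per matrix in the batch, doubled for the complex type because the real and imaginary parts are stored as two separate planes. A stand-alone restatement (hypothetical helper, not part of the mluOp API):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Mirrors the branch above: n*n floats, times 2 planes for complex,
// times batch, times 3 scratch buffers.
size_t cholesky_workspace_bytes(int64_t n, int64_t batch, bool is_complex) {
  const int64_t planes = is_complex ? 2 : 1;
  return static_cast<size_t>(n * n * static_cast<int64_t>(sizeof(float)) *
                             planes * batch * 3);
}
```

For example, a single 4×4 float matrix needs 4·4·4·3 = 192 bytes; a batch of two complex 4×4 matrices needs twice the planes and twice the batch, 768 bytes.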

// performs the necessary operations to compute matrix transformations,
// potentially involving Cholesky decomposition or matrix transposition,
// depending on the input parameters.
mluOpStatus_t MLUOP_WIN_API
calculate_body(mluOpHandle_t handle, int batch_size,
const mluOpTensorDescriptor_t input_desc, float* d_input,
const mluOpTensorDescriptor_t output_desc, float* d_output,
bool upper, float* workspace) {
Collaborator: float* workspace --> void* workspace

Collaborator (author): Fixed.

mluOpDataType_t dtype = input_desc->dtype;

int recnb = REC_NB;
int gbstep = 0;
int dim = input_desc->dim;
bool is_row_major = (input_desc->strides)[dim - 1] == 1;

uint64_t type_size;
MLUOP_CHECK(mluOpGetSizeOfDataType(dtype, &type_size));
int size_a = 0, lda = 0, size_c = 0, ldc = 0;
Collaborator: Use int64_t consistently here to prevent overflow.

if (dim == 2) {
size_a = input_desc->dims[0];
lda = input_desc->dims[1];
size_c = output_desc->dims[0];
ldc = output_desc->dims[1];
} else if (dim == 3) {
size_a = input_desc->dims[1];
lda = input_desc->dims[2];
size_c = output_desc->dims[1];
ldc = output_desc->dims[2];
}

PARAM_CHECK("mluOpCholesky", lda >= size_a);
PARAM_CHECK("mluOpCholesky", ldc >= size_c);

cnrtQueue_t queue;
Collaborator: Do not create a temporary queue yourself; always use handle->queue.

mluOpGetQueue(handle, &queue);

int jb;
const float s_one = 1.0;
const float s_neg_one = -1.0;

if (dtype == MLUOP_DTYPE_FLOAT) {
if (upper == true) {
CHECK_RETURN("mluOpCholesky",
transpose(batch_size, size_a, size_a, d_input, d_output,
handle, dtype, workspace));
} else {
CNRT_CHECK(cnrtMemcpy(d_output, d_input,
type_size * size_a * lda * ((uint64_t)batch_size),
CNRT_MEM_TRANS_DIR_DEV2DEV));
}
} else {
CHECK_RETURN("mluOpCholesky",
transpose(batch_size, size_a * size_a, 2, d_input, d_output,
handle, MLUOP_DTYPE_FLOAT, workspace));
}

cnrtQueueSync(queue);
int stride = size_a * lda;

if (dtype == MLUOP_DTYPE_FLOAT) {
int row = is_row_major ? lda : size_a;
int nb = NB;
set_half_zero(batch_size, stride, d_output, lda, lda, handle);
cnrtQueueSync(queue);
for (int j = 0; j < row; j += nb) {
jb = std::min(nb, row - j);
CHECK_RETURN("mluOpCholesky",
ssyrk(batch_size, stride, false, is_row_major, jb, j,
OFFSET_ROW(d_output, j, 0), lda,
OFFSET_ROW(d_output, j, j), lda, handle, workspace));
cnrtQueueSync(queue);
CHECK_RETURN("mluOpCholesky",
mlu_spotrf_rectile(batch_size, stride, is_row_major, false,
jb, recnb, OFFSET_ROW(d_output, j, j),
lda, j, handle, workspace));
if (j + jb < row) {
CHECK_RETURN(
"mluOpCholesky",
sgemm(batch_size, !is_row_major, is_row_major, row - j - jb, jb, j,
-1.0f, 1.0f, OFFSET_ROW(d_output, j + jb, 0), lda, stride,
OFFSET_ROW(d_output, j, 0), lda, stride,
OFFSET_ROW(d_output, j + jb, j), lda, stride, handle,
workspace));
cnrtQueueSync(queue);
}
if (j + jb < row) {
CHECK_RETURN(
"mluOpCholesky",
strsm(batch_size, stride, false, is_row_major, jb, row - j - jb,
OFFSET_ROW(d_output, j, j), lda,
OFFSET_ROW(d_output, j + jb, j), lda, handle, workspace));
cnrtQueueSync(queue);
}
}

if (upper) {
cnrtQueueSync(queue);
CHECK_RETURN("mluOpCholesky",
transpose(batch_size, size_a, size_a, d_output, workspace,
handle, dtype, workspace));
cnrtQueueSync(queue);
CNRT_CHECK(cnrtMemcpy(d_output, workspace,
type_size * size_a * lda * ((uint64_t)batch_size),
CNRT_MEM_TRANS_DIR_DEV2DEV));
}
} else {
recnb = CREC_NB;
int nb = CNB;
int row = lda;
float* r_start = d_output;
float* i_start = d_output + size_a * lda;
stride *= 2;

set_half_zero(batch_size, stride, r_start, lda, lda, handle);
set_half_zero(batch_size, stride, i_start, lda, lda, handle);
cnrtQueueSync(queue);

for (int j = 0; j < row; j += nb) {
jb = std::min(nb, row - j);
CHECK_RETURN("mluOpCholesky",
cherk(batch_size, stride, jb, j, r_start + j * lda,
i_start + j * lda, lda, r_start + j * lda + j,
i_start + j * lda + j, lda, handle, workspace));
cnrtQueueSync(queue);
CHECK_RETURN("mluOpCholesky",
mlu_cpotrf_rectile(
batch_size, stride, jb, recnb, r_start + j * lda + j,
i_start + j * lda + j, lda, handle, workspace));
cnrtQueueSync(queue);
if (j + jb < row) {
CHECK_RETURN("mluOpCholesky",
cgemm(batch_size, false, true, row - j - jb, jb, j, -1.0f,
1.0f, OFFSET_ROW(r_start, j + jb, 0),
OFFSET_ROW(i_start, j + jb, 0), lda, stride,
OFFSET_ROW(r_start, j, 0), OFFSET_ROW(i_start, j, 0),
lda, stride, OFFSET_ROW(r_start, j + jb, j),
OFFSET_ROW(i_start, j + jb, j), lda, stride, handle,
workspace));

cnrtQueueSync(queue);
}
if (j + jb < row) {
CHECK_RETURN(
"mluOpCholesky",
ctrsm(batch_size, stride, jb, row - j - jb,
OFFSET_ROW(r_start, j, j), OFFSET_ROW(i_start, j, j), lda,
OFFSET_ROW(r_start, j + jb, j),
OFFSET_ROW(i_start, j + jb, j), lda, handle, workspace));
cnrtQueueSync(queue);
}
}

CHECK_RETURN("mluOpCholesky",
transpose(batch_size, 2, size_a * size_a, d_output, workspace,
handle, MLUOP_DTYPE_FLOAT, workspace));
cnrtQueueSync(queue);

if (upper) {
cnrtQueueSync(queue);
CHECK_RETURN("mluOpCholesky",
transpose(batch_size, size_a, size_a, workspace, d_output,
handle, dtype, workspace));
cnrtQueueSync(queue);
CHECK_RETURN("mluOpCholesky", conj_complex(batch_size, size_a, size_a,
d_output, d_output, handle));
cnrtQueueSync(queue);
} else {
if (batch_size > 16) {
Collaborator: Why does batch_size > 16 need special handling? Please add a comment explaining the reason. Why exactly 16, rather than, say, 15 or 12? And does this parameter differ across different boards?

CNRT_CHECK(cnrtMemcpy(d_output, workspace,
type_size * size_a * lda * 16,
CNRT_MEM_TRANS_DIR_DEV2DEV));
CNRT_CHECK(
cnrtMemcpy(d_output + type_size / 4 * size_a * lda * 16,
workspace + type_size / 4 * size_a * lda * 16,
type_size * size_a * lda * ((uint64_t)batch_size - 16),
CNRT_MEM_TRANS_DIR_DEV2DEV));
} else {
CNRT_CHECK(cnrtMemcpy(d_output, workspace,
Collaborator: Avoid cnrtMemcpy, cnrtMemset, and cnrtQueueSync; they cause problems for upper layers that use mlu_graph. Replace cnrtMemcpy with the on-chip __memcpy, replace cnrtMemset with setting the data on chip, and drop cnrtQueueSync: within the same queue, kernel launches (via <<<>>>) execute serially.

Collaborator (author): Fixed.

type_size * size_a * lda * ((uint64_t)batch_size),
CNRT_MEM_TRANS_DIR_DEV2DEV));
}
}
}

cnrtQueueSync(queue);

return MLUOP_STATUS_SUCCESS;
}
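The block-column loop above follows the classic left-looking tiled scheme: for each block column, SYRK updates the diagonal tile with previously factored columns, POTRF factors the tile, GEMM applies the same update to the panel below it, and TRSM finishes the panel. A plain-C++ sketch of the same schedule, with scalar loops standing in for the MLU kernels (names in the comments refer to the calls above; the function itself is illustrative):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Row-major, lower-triangular, in place. nb plays the role of NB above.
void blocked_cholesky(std::vector<double>& a, int n, int nb) {
  auto A = [&](int r, int c) -> double& { return a[r * n + c]; };
  for (int j = 0; j < n; j += nb) {
    int jb = std::min(nb, n - j);
    // "ssyrk": diagonal tile -= contributions of columns 0..j-1
    for (int r = j; r < j + jb; ++r)
      for (int c = j; c <= r; ++c)
        for (int k = 0; k < j; ++k) A(r, c) -= A(r, k) * A(c, k);
    // "mlu_spotrf_rectile": unblocked factorization of the jb x jb tile
    for (int c = j; c < j + jb; ++c) {
      for (int k = j; k < c; ++k) A(c, c) -= A(c, k) * A(c, k);
      A(c, c) = std::sqrt(A(c, c));
      for (int r = c + 1; r < j + jb; ++r) {
        for (int k = j; k < c; ++k) A(r, c) -= A(r, k) * A(c, k);
        A(r, c) /= A(c, c);
      }
    }
    // "sgemm": panel below the tile -= contributions of columns 0..j-1
    for (int r = j + jb; r < n; ++r)
      for (int c = j; c < j + jb; ++c)
        for (int k = 0; k < j; ++k) A(r, c) -= A(r, k) * A(c, k);
    // "strsm": solve panel * L_tile^T = panel by forward substitution
    for (int r = j + jb; r < n; ++r)
      for (int c = j; c < j + jb; ++c) {
        for (int k = j; k < c; ++k) A(r, c) -= A(r, k) * A(c, k);
        A(r, c) /= A(c, c);
      }
  }
}
```

On the SPD matrix [[4,2,2],[2,5,3],[2,3,6]] with nb = 2 this produces the lower factor [[2,·,·],[1,2,·],[1,1,2]], matching the unblocked algorithm regardless of the tile size.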

// computes the Cholesky decomposition.
// This function is designed to handle both single and batch processing of
// matrices in either 2D or 3D formats. The function ensures that the input
// matrices are either float or complex float types and performs the
// decomposition either on the upper or lower triangular part of the matrix,
// based on the 'upper' boolean flag.
mluOpStatus_t MLUOP_WIN_API
mluOpCholesky(mluOpHandle_t handle, const mluOpTensorDescriptor_t input_desc,
float* d_input, const mluOpTensorDescriptor_t output_desc,
float* d_output, bool upper, void* workspace) {
PARAM_CHECK("mluOpCholesky", handle != NULL);
Collaborator (@ArtIntAI, Nov 22, 2024):
PARAM_CHECK("mluOpCholesky", input_desc != NULL);
PARAM_CHECK("mluOpCholesky", d_input != NULL);
PARAM_CHECK("mluOpCholesky", output_desc != NULL);
PARAM_CHECK("mluOpCholesky", d_output != NULL);
Collaborator: The NULL checks for d_output and d_input should come last. For the ordering of these parameter checks, see this operator:
https://github.com/Cambricon/mlu-ops/blob/master/kernels/mutual_information/mutual_information_backward/mutual_information_backward.cpp#L322

PARAM_CHECK("mluOpCholesky", input_desc->layout == MLUOP_LAYOUT_ARRAY);
PARAM_CHECK("mluOpCholesky", output_desc->layout == MLUOP_LAYOUT_ARRAY);

PARAM_CHECK("mluOpCholesky", input_desc->dim == 2 || input_desc->dim == 3);
PARAM_CHECK("mluOpCholesky", output_desc->dim == input_desc->dim);
PARAM_CHECK("mluOpCholesky", input_desc->dims[0] > 0);
PARAM_CHECK("mluOpCholesky", input_desc->dims[1] > 0);
PARAM_CHECK("mluOpCholesky", output_desc->dims[0] > 0);
PARAM_CHECK("mluOpCholesky", output_desc->dims[1] > 0);
if (input_desc->dim == 2) {
PARAM_CHECK("mluOpCholesky", input_desc->dims[0] == input_desc->dims[1]);
PARAM_CHECK("mluOpCholesky", output_desc->dims[0] == output_desc->dims[1]);
} else {
PARAM_CHECK("mluOpCholesky", input_desc->dims[1] == input_desc->dims[2]);
PARAM_CHECK("mluOpCholesky", output_desc->dims[1] == output_desc->dims[2]);
}

cnrtQueue_t queue;
Collaborator: Use handle->queue; do not create a temporary one yourself.

mluOpGetQueue(handle, &queue);

if (input_desc->dim == 3) {
PARAM_CHECK("mluOpCholesky", input_desc->dims[2] > 0);
PARAM_CHECK("mluOpCholesky", output_desc->dims[2] > 0);
}

mluOpDataType_t dtype = input_desc->dtype;
PARAM_CHECK("mluOpCholesky", dtype == output_desc->dtype);
Collaborator: The parameter checks are missing zero-element and large-tensor handling; consider something like:

if (mluOpGetTensorElementNum(x_desc) == 0) {

PARAM_CHECK("mluOpCholesky",
dtype == MLUOP_DTYPE_FLOAT || dtype == MLUOP_DTYPE_COMPLEX_FLOAT);

int dim = input_desc->dim;
Collaborator: Missing the necessary stage-by-stage VLOG messages. Add VLOG(5) before and after each kernel assembled on the host, e.g.:

VLOG(5) << "kernel Kernel3StagePipelineAbs";

int size_a = 0, lda = 0, size_c = 0, ldc = 0;
Collaborator: int is used here; if a single dimension exceeds INT_MAX, this will overflow.
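To illustrate the reviewer's point: a single dimension around 50,000 already pushes n*n past INT_MAX, so the product must be formed in 64-bit arithmetic (the dimension value here is hypothetical):

```cpp
#include <cassert>
#include <cstdint>

// Promote to int64_t *before* multiplying; an int*int product that
// exceeds INT_MAX is undefined behavior in C++.
int64_t element_count(int n) { return static_cast<int64_t>(n) * n; }
```

element_count(50000) is 2,500,000,000, comfortably above INT32_MAX (about 2.147e9); computing it as a plain int product would overflow.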


int batch_size = 1;
if (dim == 2) {
size_a = input_desc->dims[0];
lda = input_desc->dims[1];
size_c = output_desc->dims[0];
ldc = output_desc->dims[1];
} else if (dim == 3) {
batch_size = input_desc->dims[0];
size_a = input_desc->dims[1];
lda = input_desc->dims[2];
size_c = output_desc->dims[1];
ldc = output_desc->dims[2];
}
Collaborator: The gencase code for this operator currently differs for identical shapes when the input values differ; use GEN_CASE_DATA_REAL.

calculate_body(handle, ((uint64_t)batch_size), input_desc, d_input,
output_desc, d_output, upper, (float*)workspace);
return MLUOP_STATUS_SUCCESS;
}