Skip to content

Commit

Permalink
zq/fix_declaration_for_some_func (DeepLink-org#1141)
Browse files Browse the repository at this point in the history
* fix_declaration_for_some_func

* fix gen_output and diopi_functions for optimizer_funcs

* reformat customized_test
  • Loading branch information
NeosZhang authored Apr 16, 2024
1 parent b4126ce commit 510359e
Show file tree
Hide file tree
Showing 14 changed files with 945 additions and 779 deletions.
572 changes: 572 additions & 0 deletions diopi_test/python/conformance/customized_test.py

Large diffs are not rendered by default.

541 changes: 245 additions & 296 deletions diopi_test/python/conformance/diopi_functions.py

Large diffs are not rendered by default.

491 changes: 66 additions & 425 deletions diopi_test/python/conformance/gen_output.py

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion impl/ascend_npu/diopi_impl/functions_ext/adamw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

namespace OP_IMPL_NS {

diopiError_t diopiAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiTensorHandle_t grad, diopiTensorHandle_t expAvg, diopiTensorHandle_t expAvgSq,
diopiError_t diopiAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t grad, diopiTensorHandle_t expAvg,
diopiTensorHandle_t expAvgSq,

diopiTensorHandle_t maxExpAvgSq, float lr, float beta1, float beta2, float eps, float weightDecay, int64_t step, bool amsgrad) {
DIOPI_CHECK(amsgrad == false, "at present, ApplyAdamW only supports amsgrad false on ascend.");
BEGIN_CALL_ACL_OP(input, grad, expAvg, expAvgSq, maxExpAvgSq);
Expand Down
4 changes: 2 additions & 2 deletions impl/camb/device_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,15 +1110,15 @@
'adadelta': dict(
name=["adadelta"],
atol_half=1e-1,
rtol_half=1e-3,
rtol_half=1e-1,
atol=1e-1,
rtol=1e-3,
tensor_para=dict(
args=[
{
# can't get correct result
"ins": ['param', 'param_grad'],
"dtype": [Skip(np.float16)],
"dtype": [Skip(np.float64)],
},
]
),
Expand Down
2 changes: 1 addition & 1 deletion impl/camb/functions/adadelta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
namespace impl {
namespace camb {

diopiError_t diopiAdadelta(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiTensorHandle_t grad, diopiTensorHandle_t squareAvg,
diopiError_t diopiAdadelta(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t grad, diopiTensorHandle_t squareAvg,
diopiTensorHandle_t accDelta, float lr, float rho, float eps, float weightDecay) {
cnnlHandle_t handle = cnnlHandlePool.get(ctx);
DiopiTensor inputTensor = DiopiTensor(input);
Expand Down
15 changes: 9 additions & 6 deletions impl/camb/functions/adam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ void bangAdamInternal(void* grad, void* m, void* v, void* vMax, void* variable,
float learningRateCorrection, int adamMode, float decay, float decayCorrection, cnrtDim3_t kDim, cnrtFunctionType_t kType,
cnrtQueue_t queue, cnrtDataType_t cnrtType, bool amsgrad);

diopiError_t bangAdam(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiTensorHandle_t grad, diopiTensorHandle_t expAvg, diopiTensorHandle_t expAvgSq,
diopiTensorHandle_t maxExpAvgSq, float lr, float beta1, float beta2, float eps, float weightDecay, int64_t step, bool amsgrad,
int adamMode = 0) {
diopiError_t bangAdam(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t grad, diopiTensorHandle_t expAvg,
diopiTensorHandle_t expAvgSq, diopiTensorHandle_t maxExpAvgSq, float lr, float beta1, float beta2, float eps, float weightDecay,
int64_t step, bool amsgrad, int adamMode = 0) {
cnrtQueue_t queue = getStream(ctx);
DiopiTensor inputTensor = DiopiTensor(input);
DiopiTensor gradTensor = DiopiTensor(grad);
Expand Down Expand Up @@ -92,13 +92,16 @@ diopiError_t bangAdam(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopi
return diopiSuccess;
}

// Adam optimizer step for the camb (Cambricon) backend.
//
// Thin dispatcher: forwards all state tensors and hyperparameters to the
// shared bangAdam kernel with adamMode = 0 (plain Adam); diopiAdamW reuses
// the same kernel with adamMode = 1.
//
// @param ctx          DIOPI context handle.
// @param input        Parameter tensor (updated by the kernel).
// @param grad         Gradient tensor; const — this function must not write it.
// @param expAvg       First-moment (exp. moving average of grad) state tensor.
// @param expAvgSq     Second-moment (exp. moving average of grad^2) state tensor.
// @param maxExpAvgSq  Running max of expAvgSq, used when amsgrad is true.
// @param lr, beta1, beta2, eps, weightDecay  Standard Adam hyperparameters.
// @param step         1-based optimizer step count (for bias correction).
// @param amsgrad      Enable the AMSGrad variant.
// @return diopiSuccess, or the first error propagated from bangAdam.
diopiError_t diopiAdam(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t grad, diopiTensorHandle_t expAvg,
                       diopiTensorHandle_t expAvgSq, diopiTensorHandle_t maxExpAvgSq, float lr, float beta1, float beta2, float eps, float weightDecay,
                       int64_t step, bool amsgrad) {
    // DIOPI_CALL returns early with the kernel's error code on failure.
    DIOPI_CALL(bangAdam(ctx, input, grad, expAvg, expAvgSq, maxExpAvgSq, lr, beta1, beta2, eps, weightDecay, step, amsgrad, /*adamMode=*/0));
    return diopiSuccess;
}

diopiError_t diopiAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiTensorHandle_t grad, diopiTensorHandle_t expAvg, diopiTensorHandle_t expAvgSq,
diopiError_t diopiAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t grad, diopiTensorHandle_t expAvg,
diopiTensorHandle_t expAvgSq,

diopiTensorHandle_t maxExpAvgSq, float lr, float beta1, float beta2, float eps, float weightDecay, int64_t step, bool amsgrad) {
DIOPI_CALL(bangAdam(ctx, input, grad, expAvg, expAvgSq, maxExpAvgSq, lr, beta1, beta2, eps, weightDecay, step, amsgrad, 1));
return diopiSuccess;
Expand Down
6 changes: 3 additions & 3 deletions impl/camb_pytorch/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1923,7 +1923,7 @@ diopiError_t diopiMaskedFillInpScalar(diopiContextHandle_t ctx, diopiTensorHandl
return diopiSuccess;
}

diopiError_t diopiAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiTensorHandle_t grad, diopiTensorHandle_t exp_avg,
diopiError_t diopiAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t grad, diopiTensorHandle_t exp_avg,
diopiTensorHandle_t exp_avg_sq, diopiTensorHandle_t max_exp_avg_sq, float lr, float beta1, float beta2, float eps, float weight_decay,
int64_t step, bool amsgrad) {
camb::aten::setCurCtx(ctx);
Expand Down Expand Up @@ -1959,7 +1959,7 @@ diopiError_t diopiAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t input, dio
return diopiSuccess;
}

diopiError_t diopiAdam(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiTensorHandle_t grad, diopiTensorHandle_t exp_avg,
diopiError_t diopiAdam(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t grad, diopiTensorHandle_t exp_avg,
diopiTensorHandle_t exp_avg_sq, diopiTensorHandle_t max_exp_avg_sq, float lr, float beta1, float beta2, float eps, float weight_decay,
int64_t step, bool amsgrad) {
camb::aten::setCurCtx(ctx);
Expand Down Expand Up @@ -1998,7 +1998,7 @@ diopiError_t diopiAdam(diopiContextHandle_t ctx, diopiTensorHandle_t input, diop
return diopiSuccess;
}

diopiError_t diopiAdadelta(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiTensorHandle_t grad, diopiTensorHandle_t square_avg,
diopiError_t diopiAdadelta(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t grad, diopiTensorHandle_t square_avg,
diopiTensorHandle_t acc_delta, float lr, float rho, float eps, float weight_decay) {
camb::aten::setCurCtx(ctx);
auto atInput = camb::aten::buildATen(input);
Expand Down
4 changes: 2 additions & 2 deletions impl/topsrider/functions/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@ DIOPI_API diopiError_t diopiExpand(diopiContextHandle_t ctx, diopiTensorHandle_t
return impl::tops::topsExpand(ctx, out, input);
}

DIOPI_API diopiError_t diopiAdam(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiTensorHandle_t grad, diopiTensorHandle_t exp_avg,
DIOPI_API diopiError_t diopiAdam(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t grad, diopiTensorHandle_t exp_avg,
diopiTensorHandle_t exp_avg_sq, diopiTensorHandle_t max_exp_avg_sq, float lr, float beta1, float beta2, float eps,
float weight_decay, int64_t step, bool amsgrad) {
TOPSOP_LOG();
Expand All @@ -739,7 +739,7 @@ DIOPI_API diopiError_t diopiAddInpScalar(diopiContextHandle_t ctx, diopiTensorHa
return impl::tops::topsAddInpScalar(ctx, input, other, alpha);
}

DIOPI_API diopiError_t diopiAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiTensorHandle_t grad, diopiTensorHandle_t exp_avg,
DIOPI_API diopiError_t diopiAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t grad, diopiTensorHandle_t exp_avg,
diopiTensorHandle_t exp_avg_sq, diopiTensorHandle_t max_exp_avg_sq, float lr, float beta1, float beta2, float eps,
float weight_decay, int64_t step, bool amsgrad) {
TOPSOP_LOG();
Expand Down
4 changes: 2 additions & 2 deletions impl/topsrider/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ DIOPI_API diopiError_t topsSgd(diopiContextHandle_t ctx, diopiTensorHandle_t w,

DIOPI_API diopiError_t topsExpand(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);

DIOPI_API diopiError_t topsAdam(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiTensorHandle_t grad, diopiTensorHandle_t exp_avg,
DIOPI_API diopiError_t topsAdam(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t grad, diopiTensorHandle_t exp_avg,
diopiTensorHandle_t exp_avg_sq, diopiTensorHandle_t max_exp_avg_sq, float lr, float beta1, float beta2, float eps,
float weight_decay, int64_t step, bool amsgrad);

Expand All @@ -321,7 +321,7 @@ DIOPI_API diopiError_t topsMaxAll(diopiContextHandle_t ctx, diopiTensorHandle_t

DIOPI_API diopiError_t topsAddInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t *other, const diopiScalar_t *alpha);

DIOPI_API diopiError_t topsAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiTensorHandle_t grad, diopiTensorHandle_t exp_avg,
DIOPI_API diopiError_t topsAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t grad, diopiTensorHandle_t exp_avg,
diopiTensorHandle_t exp_avg_sq, diopiTensorHandle_t max_exp_avg_sq, float lr, float beta1, float beta2, float eps,
float weight_decay, int64_t step, bool amsgrad);

Expand Down
8 changes: 4 additions & 4 deletions impl/torch/functions/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2730,7 +2730,7 @@ diopiError_t diopiMeshGrid(diopiContextHandle_t ctx, diopiTensorHandle_t* outs,
return diopiSuccess;
}

diopiError_t diopiAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiTensorHandle_t grad, diopiTensorHandle_t exp_avg,
diopiError_t diopiAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t grad, diopiTensorHandle_t exp_avg,
diopiTensorHandle_t exp_avg_sq, diopiTensorHandle_t max_exp_avg_sq, float lr, float beta1, float beta2, float eps, float weight_decay,
int64_t step, bool amsgrad) {
impl::aten::setCurStream(ctx);
Expand Down Expand Up @@ -2759,7 +2759,7 @@ diopiError_t diopiAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t input, dio
return diopiSuccess;
}

diopiError_t diopiAdam(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiTensorHandle_t grad, diopiTensorHandle_t exp_avg,
diopiError_t diopiAdam(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t grad, diopiTensorHandle_t exp_avg,
diopiTensorHandle_t exp_avg_sq, diopiTensorHandle_t max_exp_avg_sq, float lr, float beta1, float beta2, float eps, float weight_decay,
int64_t step, bool amsgrad) {
impl::aten::setCurStream(ctx);
Expand Down Expand Up @@ -2791,7 +2791,7 @@ diopiError_t diopiAdam(diopiContextHandle_t ctx, diopiTensorHandle_t input, diop
return diopiSuccess;
}

diopiError_t diopiAdadelta(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiTensorHandle_t grad, diopiTensorHandle_t square_avg,
diopiError_t diopiAdadelta(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t grad, diopiTensorHandle_t square_avg,
diopiTensorHandle_t acc_delta, float lr, float rho, float eps, float weight_decay) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
Expand All @@ -2813,7 +2813,7 @@ diopiError_t diopiAdadelta(diopiContextHandle_t ctx, diopiTensorHandle_t input,
return diopiSuccess;
}

diopiError_t diopiRmsprop(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiTensorHandle_t grad, diopiTensorHandle_t square_avg,
diopiError_t diopiRmsprop(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t grad, diopiTensorHandle_t square_avg,
diopiTensorHandle_t grad_avg, diopiTensorHandle_t momentum_buf, float lr, float alpha, float eps, float weight_decay, float momentum,
bool centered) {
impl::aten::setCurStream(ctx);
Expand Down
Loading

0 comments on commit 510359e

Please sign in to comment.