fix!(proto): change fused_adamw's arg type float->double #1359

Merged: 1 commit, Nov 22, 2024
fix adamw
ustclight-sls committed Nov 6, 2024
commit aab0ca0239c9e76b35d9465a48e555bcd5297d38
4 changes: 2 additions & 2 deletions diopi_test/python/configs/diopi_configs.py
@@ -5552,8 +5552,8 @@
     'fused_adamw': dict(
         name=['fused_adamw'],
         interface=["CustomizedTest"],
-        atol=1e-2,
-        rtol=2e-3,
+        atol=3e-5,
+        rtol=3e-5,
         atol_half=1e-2,
         rtol_half=2e-3,
         para=dict(
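The tolerances tighten from 1e-2/2e-3 to 3e-5/3e-5, presumably because the Python reference below now runs the same fused kernel as the backend under test, so full-precision results should agree much more closely (the half-precision tolerances stay loose). For context, atol/rtol values like these usually feed an elementwise closeness check; the following is a minimal sketch, where the function name allclose_check and the call shape are assumptions, not the actual diopi_test harness code:

import torch

def allclose_check(actual, expected, atol=3e-5, rtol=3e-5):
    # Tolerance semantics as in torch.allclose:
    #   |actual - expected| <= atol + rtol * |expected|
    return torch.allclose(actual, expected, atol=atol, rtol=rtol)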
5 changes: 2 additions & 3 deletions diopi_test/python/conformance/customized_test.py
@@ -163,21 +163,20 @@ def fused_adamw(
     amsgrad,
     maximize,
 ):
-    torch.optim._functional.adamw(
+    torch._fused_adamw_(
         params,
         grads,
         exp_avgs,
         exp_avg_sqs,
         max_exp_avg_sqs,
         state_steps,
         amsgrad=amsgrad,
-        lr=lr,
         beta1=beta1,
         beta2=beta2,
+        lr=lr,
         weight_decay=weight_decay,
         eps=eps,
         maximize=maximize,
-        fused=True,
     )
     return params, exp_avgs, exp_avg_sqs, max_exp_avg_sqs

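The reference implementation now invokes the in-place fused kernel directly rather than going through torch.optim._functional.adamw with fused=True. For anyone reproducing the reference outside the harness, a minimal standalone invocation might look like the sketch below; the shapes and hyperparameter values are illustrative only, and it assumes a CUDA build of PyTorch that exposes torch._fused_adamw_:

import torch

params      = [torch.randn(4, device="cuda")]
grads       = [torch.randn(4, device="cuda")]
exp_avgs    = [torch.zeros(4, device="cuda")]
exp_avg_sqs = [torch.zeros(4, device="cuda")]
# With amsgrad=False, the max_exp_avg_sqs list can stay empty.
max_exp_avg_sqs = []
# The fused variant takes per-parameter step counters as device tensors.
state_steps = [torch.ones((), device="cuda")]

torch._fused_adamw_(
    params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
    lr=1e-3, beta1=0.9, beta2=0.999,
    weight_decay=1e-2, eps=1e-8,
    amsgrad=False, maximize=False,
)  # updates params and the optimizer state in place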
2 changes: 1 addition & 1 deletion impl/torch/functions/functions.cpp
@@ -3360,7 +3360,7 @@ diopiError_t diopiGridSample(diopiContextHandle_t ctx, diopiTensorHandle_t out,
 
 diopiError_t diopiFusedAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t* params, diopiConstTensorHandle_t* grads, diopiTensorHandle_t* exp_avgs,
                              diopiTensorHandle_t* exp_avg_sqs, diopiTensorHandle_t* max_exp_avg_sqs, diopiConstTensorHandle_t* state_steps, int64_t nums,
-                             float lr, float beta1, float beta2, float eps, float weight_decay, bool amsgrad, bool maximize) {
+                             double lr, double beta1, double beta2, double eps, double weight_decay, bool amsgrad, bool maximize) {
     impl::aten::setCurStream(ctx);
     DIOPI_CHECK_PTR(params);
     DIOPI_IMPL_BUILD_ATEN_LIST(atParam, params, nums);
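This signature change is the substance of the breaking change flagged by the fix! prefix: every scalar hyperparameter widens from float to double. Python floats are IEEE-754 doubles, so the old float parameters silently rounded whatever the frontend passed. A quick plain-Python sketch of that effect, with an illustrative value:

import struct

lr = 1e-3                     # a Python float is a 64-bit double
lr_narrowed = struct.unpack("f", struct.pack("f", lr))[0]
print(lr)           # 0.001
print(lr_narrowed)  # 0.0010000000474974513, what a `float lr` parameter saw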
2 changes: 1 addition & 1 deletion proto/include/diopi/functions.h
@@ -2829,7 +2829,7 @@ DIOPI_API diopiError_t diopiReciprocalInp(diopiContextHandle_t ctx, diopiTensorH
  */
 DIOPI_API diopiError_t diopiFusedAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t* params, diopiConstTensorHandle_t* grads, diopiTensorHandle_t* exp_avgs,
                                        diopiTensorHandle_t* exp_avg_sqs, diopiTensorHandle_t* max_exp_avg_sqs, diopiConstTensorHandle_t* state_steps,
-                                       int64_t nums, float lr, float beta1, float beta2, float eps, float weight_decay, bool amsgrad, bool maximize);
+                                       int64_t nums, double lr, double beta1, double beta2, double eps, double weight_decay, bool amsgrad, bool maximize);
 
 /**
  * @brief The function is used to implement the AdamW optimizer. Its functionality is to perform a single parameter update.
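For reference, the single-step update this family of entry points implements is standard AdamW with decoupled weight decay (written here for maximize=False and without the amsgrad max-tracking):

\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2\, v_{t-1} + (1-\beta_2)\, g_t^2 \\
\hat m_t &= m_t / (1-\beta_1^t), \qquad \hat v_t = v_t / (1-\beta_2^t) \\
\theta_t &= \theta_{t-1} - \mathrm{lr} \left( \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \mathrm{weight\_decay} \cdot \theta_{t-1} \right)
\end{aligned}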