diff --git a/diopi_test/python/configs/diopi_configs.py b/diopi_test/python/configs/diopi_configs.py index 0ad81385e..39e9c4204 100755 --- a/diopi_test/python/configs/diopi_configs.py +++ b/diopi_test/python/configs/diopi_configs.py @@ -5552,8 +5552,8 @@ 'fused_adamw': dict( name=['fused_adamw'], interface=["CustomizedTest"], - atol=1e-2, - rtol=2e-3, + atol=3e-5, + rtol=3e-5, atol_half=1e-2, rtol_half=2e-3, para=dict( diff --git a/diopi_test/python/conformance/customized_test.py b/diopi_test/python/conformance/customized_test.py index 3f351e27c..62a4f9141 100644 --- a/diopi_test/python/conformance/customized_test.py +++ b/diopi_test/python/conformance/customized_test.py @@ -163,7 +163,7 @@ def fused_adamw( amsgrad, maximize, ): - torch.optim._functional.adamw( + torch._fused_adamw_( params, grads, exp_avgs, @@ -171,13 +171,12 @@ def fused_adamw( max_exp_avg_sqs, state_steps, amsgrad=amsgrad, + lr=lr, beta1=beta1, beta2=beta2, - lr=lr, weight_decay=weight_decay, eps=eps, maximize=maximize, - fused=True, ) return params, exp_avgs, exp_avg_sqs, max_exp_avg_sqs diff --git a/impl/torch/functions/functions.cpp b/impl/torch/functions/functions.cpp index 6ee8e104e..7e5a46111 100644 --- a/impl/torch/functions/functions.cpp +++ b/impl/torch/functions/functions.cpp @@ -3360,7 +3360,7 @@ diopiError_t diopiGridSample(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiError_t diopiFusedAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t* params, diopiConstTensorHandle_t* grads, diopiTensorHandle_t* exp_avgs, diopiTensorHandle_t* exp_avg_sqs, diopiTensorHandle_t* max_exp_avg_sqs, diopiConstTensorHandle_t* state_steps, int64_t nums, - float lr, float beta1, float beta2, float eps, float weight_decay, bool amsgrad, bool maximize) { + double lr, double beta1, double beta2, double eps, double weight_decay, bool amsgrad, bool maximize) { impl::aten::setCurStream(ctx); DIOPI_CHECK_PTR(params); DIOPI_IMPL_BUILD_ATEN_LIST(atParam, params, nums); diff --git a/proto/include/diopi/functions.h b/proto/include/diopi/functions.h index 4f7dfcecb..e66d12f5e 100644 --- a/proto/include/diopi/functions.h +++ b/proto/include/diopi/functions.h @@ -2829,7 +2829,7 @@ DIOPI_API diopiError_t diopiReciprocalInp(diopiContextHandle_t ctx, diopiTensorH */ DIOPI_API diopiError_t diopiFusedAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t* params, diopiConstTensorHandle_t* grads, diopiTensorHandle_t* exp_avgs, diopiTensorHandle_t* exp_avg_sqs, diopiTensorHandle_t* max_exp_avg_sqs, diopiConstTensorHandle_t* state_steps, - int64_t nums, float lr, float beta1, float beta2, float eps, float weight_decay, bool amsgrad, bool maximize); + int64_t nums, double lr, double beta1, double beta2, double eps, double weight_decay, bool amsgrad, bool maximize); /** * @brief The function is used to implement the AdamW optimizer. Its functionality is to perform a single parameter update.