diff --git a/impl/torch/functions/functions.cpp b/impl/torch/functions/functions.cpp
index 1d7ac869e..6ee8e104e 100644
--- a/impl/torch/functions/functions.cpp
+++ b/impl/torch/functions/functions.cpp
@@ -3429,7 +3429,8 @@ diopiError_t diopiAdam(diopiContextHandle_t ctx, diopiTensorHandle_t param, diop
     if (weight_decay != 0) {
         grad_d = grad_d.add(atParam, weight_decay);
     }
-    atExpAvg.mul_(beta1).add_(grad_d, 1 - beta1);
+    // atExpAvg.mul_(beta1).add_(grad_d, 1 - beta1);
+    atExpAvg.lerp_(grad_d, 1 - beta1);
     atExpAvgSq.mul_(beta2).addcmul_(grad_d, grad_d.conj(), 1 - beta2);
 
     at::Tensor denom;
@@ -4036,19 +4037,6 @@ diopiError_t diopiNormBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gra
     return diopiSuccess;
 }
 
-/*
-diopiError_t diopiNormBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_output, diopiConstTensorHandle_t input, diopiConstTensorHandle_t grad_input,
-diopiConstTensorHandle_t result, const diopiScalar_t* p, diopiSize_t dim) { impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); auto
-atGradInput = impl::aten::buildATen(grad_input); auto atP = impl::aten::buildAtScalar(p); auto atResult = impl::aten::buildATen(result); at::IntArrayRef atDim =
-impl::aten::buildAtIntArray(dim);
-
-    bool keepdim = true;
-    auto atGradOutput = torch::autograd::generated::details::norm_backward(atGradInput, atInput, atP, atResult, atDim, keepdim);
-
-    return diopiSuccess;
-}
-*/
-
 diopiError_t diopiForeachnormScalar(diopiContextHandle_t ctx, diopiTensorHandle_t* outs, diopiConstTensorHandle_t* inputs, int64_t inputSize,
                                     const diopiScalar_t* ord) {
     DIOPI_CHECK_PTR(outs);
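
Note on the first hunk: the first-moment update m = beta1 * m + (1 - beta1) * g is the same recurrence as m = m + (1 - beta1) * (g - m), i.e. a linear interpolation from m toward g with weight (1 - beta1), so atExpAvg.lerp_(grad_d, 1 - beta1) is an algebraic rewrite of the old mul_/add_ chain (and matches the form upstream PyTorch uses for Adam's first moment). Below is a minimal standalone sketch, not part of the patch and with illustrative names, that checks the two forms agree, assuming a libtorch build:

#include <torch/torch.h>

#include <iostream>

int main() {
    const double beta1 = 0.9;
    at::Tensor grad = torch::randn({4});
    at::Tensor expAvgMulAdd = torch::randn({4});
    at::Tensor expAvgLerp = expAvgMulAdd.clone();

    // Old form: m = beta1 * m + (1 - beta1) * g
    expAvgMulAdd.mul_(beta1).add_(grad, 1 - beta1);
    // New form: m = m + (1 - beta1) * (g - m), the same recurrence
    expAvgLerp.lerp_(grad, 1 - beta1);

    // Expect "true": the two updates agree up to floating-point rounding.
    std::cout << std::boolalpha << torch::allclose(expAvgMulAdd, expAvgLerp) << std::endl;
    return 0;
}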