diff --git a/sgd.lua b/sgd.lua
index e21c696..a93c5cf 100644
--- a/sgd.lua
+++ b/sgd.lua
@@ -36,6 +36,7 @@ function optim.sgd(opfunc, x, config, state)
    local nesterov = config.nesterov or false
    local lrs = config.learningRates
    local wds = config.weightDecays
+   local clip = config.clipval
    state.evalCounter = state.evalCounter or 0
    local nevals = state.evalCounter
    assert(not nesterov or (mom > 0 and damp == 0), "Nesterov momentum requires a momentum and zero dampening")
@@ -71,7 +72,13 @@ function optim.sgd(opfunc, x, config, state)
    -- (4) learning rate decay (annealing)
    local clr = lr / (1 + nevals*lrd)
 
-   -- (5) parameter update with single or individual learning rates
+   -- (5) gradient clipping
+   local ndfdx = dfdx:norm()
+   if clip and ndfdx >= clip then
+      dfdx:div(ndfdx):mul(clip)
+   end
+
+   -- (6) parameter update with single or individual learning rates
    if lrs then
       if not state.deltaParameters then
          state.deltaParameters = torch.Tensor():typeAs(x):resizeAs(dfdx)
@@ -82,7 +89,7 @@ function optim.sgd(opfunc, x, config, state)
       x:add(-clr, dfdx)
    end
 
-   -- (6) update evaluation counter
+   -- (7) update evaluation counter
    state.evalCounter = state.evalCounter + 1
 
    -- return x*, f(x) before optimization
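
A minimal usage sketch of the new option, assuming the patched optim.sgd above; the toy objective, the learning rate, and the clipval of 5 are illustrative placeholders, not part of the patch:

require 'optim'

-- toy objective f(x) = 0.5 * ||x||^2 with gradient x, just to exercise clipping
local params = torch.randn(10) * 100
local opfunc = function(x) return 0.5 * x:dot(x), x:clone() end

local config = {learningRate = 0.1, clipval = 5}  -- clip gradient L2 norm at 5
local state = {}

for i = 1, 100 do
   -- with clipval set, each step rescales dfdx to norm <= 5 before the update
   params, fx = optim.sgd(opfunc, params, config, state)
end

Because the gradient is renormalized rather than element-wise clamped, its direction is preserved; only its magnitude is capped at clipval.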