Question about LBFGS #307
Hello JAXopt Team,
Replies: 1 comment, 3 replies
Hi DoTulip,

The tool `jaxopt.ScipyMinimize` is just a wrapper around Scipy - it is equivalent to calling `scipy.optimize.minimize` on your function directly (the same code runs under the hood). In particular, this code is not jittable and does not benefit from GPU/TPU speed-up. The one exception is that it is actually possible to differentiate through the wrapper, thanks to implicit differentiation.
The tool `jaxopt.LBFGS` is a pure re-implementation of L-BFGS in JAX: it is differentiable, runs on GPU/TPU, and can be wrapped in `jax.jit`. This should be your preferred tool if performance is an issue (it is definitely what you want to use for the `train_step` function of your neural network).

For neural networks, I assume you are interested in a stochastic variant of…