How to make an Optax optimizer in Flax? #994

Answered by marcvanzee
marcvanzee asked this question in Q&A

Answer by @levskaya: "Below is a demo of how to wrap an Optax optimizer in Flax. Be aware that Flax optimizers return updated params, whereas Optax gradient transforms return parameter updates (transformed gradients) instead. So there's a slightly tricky, optimizer-dependent new_params, old_params --> delta calculation that you need to do, and to be safe you want to make sure the learning rate fed to this wrapped optimizer is ~1.0 so you don't get bad cancellation when reverting back to the grad via subtraction."

```python
from functools import partial
import numpy as np

import jax
from jax import random, lax, numpy as jnp

import flax
from flax import linen as nn
from flax import optim

import optax
from optax import Gradie…
```
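The demo is truncated above, but the cancellation warning can be illustrated on its own. The sketch below uses plain Python floats and no Flax/Optax APIs (`flax_style_step` and `recover_grad` are hypothetical names for illustration only): a Flax-style optimizer hands back updated params, so recovering an Optax-style update means subtracting old params from new params, and with a tiny learning rate that subtraction cancels away most of the significant bits.

```python
def flax_style_step(param, grad, lr):
    """Toy Flax-style optimizer step: returns the *updated parameter*."""
    return param - lr * grad

def recover_grad(old_param, new_param, lr):
    """Revert the (new, old) param pair back to the gradient by
    subtraction, as a wrapper around a Flax-style optimizer must do."""
    return (old_param - new_param) / lr

param, grad = 1.0, 3.0

# lr = 1.0: the step (-2.0) is large relative to the param, so the
# subtraction loses nothing and the gradient is recovered exactly.
good = recover_grad(param, flax_style_step(param, grad, 1.0), 1.0)

# lr = 1e-16: the step (3e-16) is only a few ulps of the param
# (~2.2e-16 near 1.0), so the subtraction suffers catastrophic
# cancellation and the recovered "gradient" is badly wrong.
bad = recover_grad(param, flax_style_step(param, grad, 1e-16), 1e-16)

print(good)  # 3.0 exactly
print(bad)   # far from 3.0
```

This is why the answer recommends feeding the wrapped optimizer a learning rate near 1.0: the update then stays comparable in magnitude to the parameters, and the `new_params - old_params` subtraction remains well conditioned.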

Replies: 1 comment

marcvanzee (Maintainer, Author) commented on Feb 5, 2021
Answer selected by marcvanzee