L1/L2 regularization of network weights with NNX #4160
Answered by cgarciae
pushkar5586 asked this question in Q&A
Hi all! Could you please share sample code or an example showing how to apply L1/L2 regularization on network weights using the NNX API? Many thanks!
Answered by cgarciae on Sep 5, 2024
Hey @pushkar5586, sorry for the delay. To apply global regularization you could use nnx.state to extract the Params and then follow the recipe from #1654. Here is the basic example from the landing page with L2 regularization added:

from flax import nnx
import jax
import optax
class Model(nnx.Module):
    def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
        self.linear = nnx.Linear(din, dmid, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.dropout = nnx.Dropout(0.2, rngs=rngs)
        self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x):
        x = nnx.relu(self.dropout(self.bn(self.linear(x))))
        return self.linear_out(x)

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing
def l2_loss(x, alpha):
    return alpha * (x ** 2).sum()

@nnx.jit  # automatic state management
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        y_pred = model(x)
        loss = ((y_pred - y) ** 2).mean()  # model loss
        loss += sum(  # L2 penalty over every Param leaf in the model
            l2_loss(w, alpha=0.001)
            for w in jax.tree_util.tree_leaves(nnx.state(model, nnx.Param))
        )
        return loss

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # in-place updates
    return loss
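The example above only covers the L2 term; since the question also asks about L1, here is a minimal sketch of the same pattern with an absolute-value penalty (the l1_loss helper below is an assumption, not part of the original answer):

import jax
import jax.numpy as jnp
from flax import nnx

def l1_loss(x, alpha):
    # L1 penalty: alpha times the sum of absolute values
    return alpha * jnp.abs(x).sum()

def loss_fn(model, x, y):
    y_pred = model(x)
    loss = ((y_pred - y) ** 2).mean()  # model loss
    loss += sum(  # L1 penalty over every Param leaf, same recipe as above
        l1_loss(w, alpha=0.001)
        for w in jax.tree_util.tree_leaves(nnx.state(model, nnx.Param))
    )
    return loss

Note that nnx.state(model, nnx.Param) collects every Param, including biases and the BatchNorm scale/bias, so if you only want to penalize the kernel matrices you would need to filter the extracted state first.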
Answer selected by pushkar5586