Taking gradients with respect to network activations (not weights) #2501
Unanswered. awtaw5q25ASF asked this question in Q&A.
Replies: 1 comment
-
There is a new `Module.perturb` API in Flax for exactly this. For example:

```python
from typing import Callable

import jax
import jax.numpy as jnp
import flax.linen as nn


class Block(nn.Module):
    units: int
    activation: Callable

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.units)(x)
        x = self.activation(x)
        # Register the activation as a zero-valued perturbation so gradients
        # can later be taken with respect to it.
        x = self.perturb('act', x)
        return x


class Model(nn.Module):
    units: int
    num_blocks: int
    activation: Callable

    @nn.compact
    def __call__(self, x):
        for _ in range(self.num_blocks):
            x = Block(units=self.units, activation=self.activation)(x)
        return x


x = jnp.ones((1, 4))
y = jax.random.normal(jax.random.PRNGKey(0), (1, 3))

model = Model(units=3, num_blocks=2, activation=nn.relu)
# init creates a 'perturbations' collection alongside 'params'.
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))


def loss_fn(params, perturbations, x, y):
    y_pred = model.apply({'params': params, 'perturbations': perturbations}, x)
    return jnp.mean((y_pred - y) ** 2)


# Differentiating with respect to the perturbations (which are all zero) gives
# the gradients with respect to the activations themselves.
activation_grads = jax.grad(loss_fn, argnums=1)(variables['params'], variables['perturbations'], x, y)
print(activation_grads)
```
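Because `perturb` adds a zero-valued variable to the intermediate value, the gradient with respect to that variable equals the gradient with respect to the activation itself. The result is a nested dict mirroring the module tree; assuming the snippet above, one quick way to inspect it is:

```python
# Print the shape of each activation gradient in the nested structure.
print(jax.tree_util.tree_map(jnp.shape, activation_grads))
```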
-
Consider a feedforward neural network

z(x) = sigma_N(f_N(... sigma_1(f_1(x)) ...)),

where each sigma_n is an activation function, and each f_n is a layer (e.g., linear). We can define the network's activations recursively as

a_n = f_n(sigma_{n-1}(a_{n-1})),

where a_0 = x and sigma_0 is the identity. Suppose we also have a loss function L(y, z(x)).

My question is: can we use JAX and Flax to easily calculate the set of gradients dL/da_n for n = 1, ..., N, i.e., the gradient of the loss with respect to each of the network's activations?
Thank you!
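For reference, here is a minimal pure-JAX sketch of the same idea used in the reply above (the layer sizes, names, and shapes are illustrative, not from the thread): add a zero perturbation delta_n to each activation a_n and differentiate the loss with respect to those perturbations; evaluated at zero, those gradients equal dL/da_n.

```python
import jax
import jax.numpy as jnp

def forward(params, perturbations, x):
    # a_n = f_n(sigma_{n-1}(a_{n-1})) + delta_n, with a_0 = x and sigma_0 the identity.
    a = x
    for n, ((W, b), delta) in enumerate(zip(params, perturbations)):
        h = a if n == 0 else jax.nn.relu(a)  # sigma_{n-1}(a_{n-1})
        a = h @ W + b + delta                # f_n(...) plus a zero perturbation
    return a                                 # the final activation a_N

def loss_fn(params, perturbations, x, y):
    return jnp.mean((forward(params, perturbations, x) - y) ** 2)

key = jax.random.PRNGKey(0)
sizes = [4, 3, 3]  # input dimension followed by layer widths
params = [
    (0.1 * jax.random.normal(key, (m, n)), jnp.zeros(n))
    for m, n in zip(sizes[:-1], sizes[1:])
]
perturbations = [jnp.zeros((1, n)) for n in sizes[1:]]  # one zero delta_n per activation

x = jnp.ones((1, 4))
y = jax.random.normal(key, (1, 3))

# Gradients with respect to the zero perturbations == gradients with respect to each a_n.
activation_grads = jax.grad(loss_fn, argnums=1)(params, perturbations, x, y)
```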