Why doesn't this loss function take the batch as input? #2319
-
```python
@jax.jit
def train_step(state, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits = CNN().apply({'params': params}, batch['image'])
    loss = cross_entropy_loss(logits=logits, labels=batch['label'])
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits=logits, labels=batch['label'])
  return state, metrics
```

My guess is that, since we are inside jit, all values including the batch sample are already traced, and the gradient only needs to be computed with respect to the parameters. The reason we are passing the parameters is solely to tell the gradient transform what to differentiate with respect to.
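Here is a minimal, self-contained sketch of what I mean; the toy linear model, array shapes, and dict keys are made up for illustration and are not from the Flax example above:

```python
import jax
import jax.numpy as jnp

params = {'w': jnp.ones((3,)), 'b': jnp.zeros(())}
batch = {'x': jnp.arange(6.0).reshape(2, 3), 'y': jnp.array([1.0, 2.0])}

@jax.jit
def train_step(params, batch):
  def loss_fn(params):
    # `batch` is captured from the enclosing scope: inside jit it is a traced
    # value like everything else, but no gradient is produced for it.
    pred = batch['x'] @ params['w'] + params['b']
    return jnp.mean((pred - batch['y']) ** 2)
  # value_and_grad differentiates only w.r.t. loss_fn's (single) argument.
  loss, grads = jax.value_and_grad(loss_fn)(params)
  return loss, grads

loss, grads = train_step(params, batch)
print(loss)        # scalar loss
print(grads['w'])  # grads has the same pytree structure as params
```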
-
It is not related to jitting, but indeed, `jax.grad` has the following explanation:

> `fun`: Function to be differentiated. Its arguments at positions specified by `argnums` should be arrays, scalars, or standard Python containers. (...) It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)

Note the default value of `argnums` is 0, so by default `grad` will return a function that evaluates the gradients of your input wrt the first argument, which in the case of your example is `params`. You could also pass the batch as a second argument to `loss_fn` (which would …
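As a rough sketch of that alternative (the toy model, shapes, and names below are invented for illustration, not taken from the Flax example), passing the batch explicitly changes nothing about what gets differentiated, because `argnums` defaults to 0:

```python
import jax
import jax.numpy as jnp

params = {'w': jnp.ones((3,)), 'b': jnp.zeros(())}
batch = {'x': jnp.arange(6.0).reshape(2, 3), 'y': jnp.array([1.0, 2.0])}

def loss_fn(params, batch):
  pred = batch['x'] @ params['w'] + params['b']
  return jnp.mean((pred - batch['y']) ** 2)

# argnums=0 is the default; written out explicitly to show why only `params`
# is differentiated even though `batch` is now a regular argument.
grad_fn = jax.value_and_grad(loss_fn, argnums=0)
loss, grads = grad_fn(params, batch)  # batch is passed through, not differentiated
print(jax.tree_util.tree_map(jnp.shape, grads))  # same pytree structure as params
```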