Why doesn't this loss function take the batch as input? #2319

uduse asked this question in Q&A
Answered by marcvanzee

This is not related to jitting. The documentation for jax.grad explains:

  • fun: Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. (...) It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
  • Returns: A function with the same arguments as fun, that evaluates the gradient of fun.

Note that the default value of argnums is 0, so by default grad returns a function that evaluates the gradient of fun with respect to its first argument, which in your example is params.
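
For illustration, here is a minimal sketch of that default behavior; the function f and its inputs are made up, not taken from the question:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # jax.grad requires a scalar output.
    return jnp.sum(x ** 2) + jnp.sum(y)

x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])

# argnums defaults to 0, so the gradient is taken w.r.t. x only.
print(jax.grad(f)(x, y))             # [2. 4.]

# Pass argnums=1 to differentiate w.r.t. y instead.
print(jax.grad(f, argnums=1)(x, y))  # [1. 1.]
```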

You could also pass the batch as a second argument to loss_fn (which would …
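
Concretely, such a loss_fn might look like the sketch below; the linear model and mean-squared-error loss are illustrative stand-ins, not the code from the question:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # batch arrives as an ordinary argument; grad treats it as a constant
    # because argnums defaults to 0 (i.e. params).
    inputs, targets = batch
    preds = inputs @ params["w"] + params["b"]
    return jnp.mean((preds - targets) ** 2)  # scalar, as grad requires

params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
batch = (jnp.ones((8, 3)), jnp.zeros((8,)))

# Gradients are computed only w.r.t. params; batch flows through unchanged.
grads = jax.grad(loss_fn)(params, batch)
```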
