Using `bind` within loss functions instead of `apply`?
#3472
Replies: 1 comment
-
Could you define the
Flax encourages a functional programming paradigm to make it compatible with Jax. Using
Whereas writing it functionally is a little more explicit and clear that the
I think in the example you provided using |
-
In my code, I have a module that has nested submodules, something like the following:
Right now, in my train step, I have to jump through some hoops to compute the losses on both submodules: I define an intermediate function and call `nn.Module.apply` with that function.
It seems that if `nn.Module.bind` works as advertised, I could write this code much more simply. However, the documentation suggests that `module.bind()` is highly discouraged and meant only for interactive sessions. Why is `bind` discouraged in more traditional training workflows? Does this design pattern with `module.bind()` not work in Flax? What are the gotchas, and where does it fail?