How does nn.vjp work? #2217

Answered by jheek
pablo2909 asked this question in Q&A
Jun 22, 2022 · 1 comment · 4 replies

In this example you are only calling init. Because the params are not initialized yet when nn.vjp is called, you get a vjp for h but an empty dict as the vjp for the params. When you call vf1.apply, z2 will contain a cotangent for the params as well.

You can also preinitialize the weights to get a param vjp even during init:

```python
import flax.linen as nn

# VF is the user's module from the question; its definition is not shown here.
class JVP_VF_params(nn.Module):
    @nn.compact
    def __call__(self, h, t, x):
        vf = VF(init_val=-3)
        vf(h)  # make sure params are initialized before nn.vjp runs
        primal_out, vjpfun_params = nn.vjp(lambda mdl, h: mdl(h, t), vf, h)
        print(f"primal out {primal_out}")
        z1, z2 = vjpfun_params(x)
        print(z1, z2)
        return z1
```

Answer selected by pablo2909