How can one pass a PRNG key to a custom submodule while initialising? #3036

Answered by cgarciae
MarioAuditore asked this question in Q&A
Hey @MarioAuditore, the only requirement for a param initialization function is that the PRNG key is its first parameter; the remaining parameters receive the extra positional arguments (*args) passed to .param() after the function itself. Here is a working example:

import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import random
from typing import Callable

def bimap_init(key, n, m):
    Q,_ = jnp.linalg.qr(random.uniform(key, shape=(n, n)))
    Q = Q[:m, :]
    return Q.T

class BiMapLayer(nn.Module):
    out_dim: int
    matrix_init: Callable = bimap_init

    @nn.compact
    def __call__(self, inputs):
        
        def quadratic_form(w, X):
            oper_1 = jax.vmap(lambda w, X: w.T @ X, (None, 0), 0)
            oper_2 = jax

Answer selected by MarioAuditore