How can one pass a PRNG key to a custom submodule while initialising? #3036
-
Hello! I am implementing a custom layer that uses this function:

def bimap_init(n, m, key):
Q,_ = jnp.linalg.qr(random.uniform(key, shape=(n, n)))
Q = Q[:m, :]
return Q.T

The layer itself looks like this:

class BiMapLayer(nn.Module):
out_dim: int
matrix_init: Callable = bimap_init
@nn.compact
def __call__(self, inputs):
def quadratic_form(w, X):
oper_1 = jax.vmap(lambda w, X: w.T @ X, (None, 0), 0)
oper_2 = jax.vmap(lambda X, w: X @ w, (0, None), 0)
return oper_2(oper_1(w, X), w)
mapping_matrix = self.param('Map',
self.matrix_init, # Initialization function for Orthogonal matrix
(inputs.shape[-1], self.out_dim, ???)) # shape info. Attempt with self.make_rng('Map') failed
y = quadratic_form(mapping_matrix, inputs)
return y

Perhaps there is a way to get the key that is passed in at initialization?
Beta Was this translation helpful? Give feedback.
Replies: 1 comment
-
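Hey @MarioAuditore, the only requirement for param initialization functions is that `key` is the first parameter; the remaining arguments are the `*args` passed to `.param()`. Here is a working example:

import flax.linen as nn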
import jax
import jax.numpy as jnp
from jax import random
from typing import Callable
def bimap_init(key, n, m):
    """Return an ``(n, min(m, n))`` matrix with orthonormal columns.

    Intended as a Flax ``param`` initializer. Flax passes the PRNG key as
    the first positional argument and forwards the remaining ``*args``
    given to ``self.param`` as ``n`` and ``m``.

    A square matrix with uniformly distributed entries is drawn from
    ``key`` and QR-factorised. The first ``m`` rows of the orthogonal
    factor are then transposed into columns.

    Args:
        key: JAX PRNG key used to draw the random square matrix.
        n: Number of rows of the result (size of the square sample).
        m: Requested number of columns. An ``n x n`` orthogonal factor has
            only ``n`` rows, so values ``m > n`` are silently clipped to
            ``n`` columns.

    Returns:
        An array of shape ``(n, min(m, n))`` whose columns are orthonormal
        (up to floating-point error).
    """
    random_square = random.uniform(key, shape=(n, n))
    orthogonal, _ = jnp.linalg.qr(random_square)
    # Rows of a square orthogonal matrix are orthonormal, so transposing the
    # leading m rows gives orthonormal columns.
    leading_rows = orthogonal[:m]
    return jnp.transpose(leading_rows)
class BiMapLayer(nn.Module):
    """Layer that maps each input slice ``x`` to ``W.T @ x @ W``.

    Attributes:
        out_dim: Requested number of columns of the learned mapping matrix
            ``W``; forwarded to ``matrix_init`` after the input dimension.
        matrix_init: Initializer that ``self.param`` calls as
            ``matrix_init(key, inputs.shape[-1], out_dim)``; defaults to
            ``bimap_init``.
    """

    out_dim: int
    matrix_init: Callable = bimap_init

    @nn.compact
    def __call__(self, inputs):
        """Apply ``W.T @ x @ W`` to every slice ``x`` along axis 0 of ``inputs``.

        Args:
            inputs: Batched array mapped over its leading axis (assumed to be
                the batch axis; TODO confirm against callers). The last axis
                must match the row count of ``W``.

        Returns:
            The per-slice results stacked along axis 0.
        """
        # Flax calls the initializer as matrix_init(rng, *args), so the PRNG
        # key is injected automatically as the first argument. Only the
        # dimensions that follow it are passed here.
        weight = self.param(
            'Map',
            self.matrix_init,
            inputs.shape[-1],
            self.out_dim,
        )

        def sandwich(sample):
            # weight is closed over, i.e. shared (not mapped) across the batch.
            return (weight.T @ sample) @ weight

        return jax.vmap(sandwich)(inputs)
# Demo: initialise the layer on a dummy batch, then print its output and the
# shapes of the parameters it created.
module = BiMapLayer(out_dim=3)
x = jnp.ones((2, 2))
# NOTE(review): the input dimension here is 2 < out_dim=3. bimap_init slices
# only the available rows, so 'Map' presumably ends up (2, 2) rather than
# (2, 3) -- confirm this is intended.
y, variables = module.init_with_output(random.PRNGKey(0), x)
print(y)
# jax.tree_map is deprecated and removed in newer JAX releases;
# jax.tree_util.tree_map is the stable spelling.
print(jax.tree_util.tree_map(lambda leaf: leaf.shape, variables))
Beta Was this translation helpful? Give feedback.
Hey @MarioAuditore, the only requirement for param initialization functions is that `key` is the first parameter, and the rest of the arguments are the `*args` passed to `.param()`. Here is a working example: