
What is meant by Flax Linen using "shape inference"? #2077

Answered by marcvanzee
marcvanzee asked this question in Q&A

It means that we only declare the number of output features we want from the model, not the size of its input. Flax infers the input size by itself from the example input it sees during initialization. Example:

Create a single dense layer instance, passing only the `features` argument (the output size):

```python
import flax.linen as nn
model = nn.Dense(features=5)  # only the output size is specified
```

Parameters are not stored with the model itself. To create them, call the model's `init` method with a PRNG key and a dummy input; the shape of the dummy input determines the shapes of the parameters.

```python
import jax
from jax import random

key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,))  # Dummy input with 10 features
params = model.init(key2, x)  # Initialization call; input size is inferred from x
print(jax.tree_util.tree_map(lambda p: p.shape, params))  # Check the parameter shapes
```

Output:

```
FrozenDict({
    params: {
        b…
```

Answer selected by marcvanzee