What is meant by that Flax Linen uses "shape inference"? #2077
-
This is a question that comes up often, and I think it is useful to clarify it in a discussion.
Replies: 1 comment
-
We only declared the number of features we want in the output of the model, not the size of the input. Flax figures out the correct input size by itself.

Example: we create one dense layer instance (taking a `features` parameter):

```python
import jax
from jax import random
from flax import linen as nn

model = nn.Dense(features=5)
```

Parameters are not stored with the models themselves. You need to initialize them by calling the `init` function, passing a PRNGKey and a dummy input:

```python
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,))                     # Dummy input
params = model.init(key2, x)                       # Initialization call
jax.tree_util.tree_map(lambda x: x.shape, params)  # Check the parameter shapes
```

Outputs:

```python
{'params': {'bias': (5,), 'kernel': (10, 5)}}
```

The result is what we expect: bias and kernel parameters of the correct size. Under the hood, the dummy input variable `x` is only used for its shape; its values do not affect the initialized parameters.