Access and replace layers #2906
-
Hi @gianlucadetommaso, …
-
By convention, Flax defines layers through code execution. Other groups have written simple template systems on top of Flax that allow fine-grained configuration with config systems such as Gin or Fiddle, but Flax doesn't offer that out of the box.
-
Thanks for letting me know.
-
Please bear with me while I iterate on this once more: I'd like to be sure there is no way around this issue. Consider a scenario where a user would like to inspect a …
-
@gianlucadetommaso Because Flax modules are lazy, you cannot do this in general. You can only replace submodules when they are inputs to a module (i.e. dataclass fields):

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class Foo(nn.Module):
  bar: nn.Module  # the submodule is a dataclass field, so it can be swapped

  def __call__(self, x):
    return self.bar(x)


x = jnp.ones((1, 32, 32, 3))

module = Foo(bar=nn.Dense(10))
module.bar = nn.Conv(features=10, kernel_size=(3, 3), padding='SAME')  # replace
variables = module.init(jax.random.PRNGKey(0), x)

print("\nFoo\n-----------")
print(jax.tree_util.tree_map(lambda a: a.shape, variables))

#--------------------
# Sequential
#--------------------
module = nn.Sequential([
    nn.Dense(10),
    nn.relu,
    nn.Dense(20),
])
# Replace every Dense layer with a Conv; keep everything else (e.g. nn.relu).
module.layers = [
    nn.Conv(features=10, kernel_size=(3, 3), padding='SAME') if isinstance(layer, nn.Dense) else layer
    for layer in module.layers
]
variables = module.init(jax.random.PRNGKey(0), x)

print("\nSequential\n-----------")
print(jax.tree_util.tree_map(lambda a: a.shape, variables))
```

However, you cannot easily (if at all) replace modules that are defined inside …
-
Thanks for the exhaustive answer!
-
@cgarciae As a follow-up question: suppose I give up on working with an instantiated model and instead work directly with the model class (e.g. …). For example, similarly to your example above, I may want to find all …
-
Hello,
given an arbitrary Flax model, is it possible to access, and potentially replace, the layers it is made of?
For example, consider the following simple model:
Given `model`, is there any method that would let me access the sequence of layers used (`nn.Dense(features=10)` and `nn.Dense(features=1)`) and potentially replace them? Thanks! 🙏