Explicit naming of modules #2046
Given two modules, SimpleModule and WrapperModule, the former of which takes an nn.Module as an argument, is it possible to explicitly list the nn.Module passed into SimpleModule under both WrapperModule and SimpleModule? Here is a code snippet that clarifies my question:

```python
from flax import linen as nn
from jax import random
from flax.core.frozen_dict import FrozenDict

# define a simple linear module that takes an nn.Dense module as an input
class SimpleModule(nn.Module):
    features: int
    dense1: nn.Dense

    def setup(self):
        # define a second dense layer
        self.dense2 = nn.Dense(self.features)

    def __call__(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return x

# define a wrapper module that passes an nn.Dense module to SimpleModule
class WrapperModule(nn.Module):
    features: int

    def setup(self):
        self.dense1 = nn.Dense(self.features)
        self.module = SimpleModule(features=self.features, dense1=self.dense1)

    def __call__(self, inputs):
        x = self.module(inputs)
        return x

# generate random keys
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4, 4))

# initialise the model
model = WrapperModule(features=5)
params = model.init(key2, x)

# utility function for viewing the param tree
def print_tree(d, depth=0, print_value=False):
    for k in d.keys():
        if isinstance(d[k], FrozenDict):
            print(' ' * depth, k)
            print_tree(d[k], depth + 1, print_value)
        else:
            if print_value:
                print(' ' * depth, k, d[k])
            else:
                print(' ' * depth, k)

print_tree(params)
```

Output:

With this model definition, the nn.Dense passed into SimpleModule appears in the param tree only under module, not under both WrapperModule and SimpleModule.
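As an aside (not part of the original snippet), the same inspection can be done with Flax's built-in flax.traverse_util.flatten_dict, which turns the nested param dict into a mapping from path tuples to arrays; a minimal sketch:

```python
# Flatten the nested param dict into path tuples, e.g.
# ('params', 'module', 'dense2', 'kernel'), and print each path and shape.
from flax.core import unfreeze
from flax.traverse_util import flatten_dict

flat_params = flatten_dict(unfreeze(params))
for path, value in flat_params.items():
    print('/'.join(path), value.shape)
```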
The reason I ask is that I am endeavouring to port a PyTorch model to Flax. For this, I need the PyTorch state dict to match the Flax param dict. For the demonstration example above, the equivalent PyTorch model is constructed as follows:

```python
import torch
from torch import nn
from typing import Optional

class SimpleModule(nn.Module):
    def __init__(self, features: int, dense1: Optional[nn.Linear] = None):
        super(SimpleModule, self).__init__()
        if dense1 is not None:
            self.dense1 = dense1
        else:
            self.dense1 = nn.Linear(features, features)
        self.dense2 = nn.Linear(features, features)

    def forward(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return x

class WrapperModule(nn.Module):
    def __init__(self, features: int):
        super(WrapperModule, self).__init__()
        self.dense1 = nn.Linear(features, features)
        self.module = SimpleModule(features=features, dense1=self.dense1)

    def forward(self, inputs):
        x = self.module(inputs)
        return x

model = WrapperModule(features=5)
print(model)
```

Output:

We see here that the shared dense1 layer is listed both directly under WrapperModule and inside SimpleModule, which is the structure I would like the Flax param dict to mirror.
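To make the comparison concrete, here is a small sketch (not from the original post) that lists the state dict keys. PyTorch registers a module under every attribute path it is assigned to, so the shared layer's tensors are expected to appear under two different keys:

```python
# Print every key in the PyTorch state dict along with the tensor shape.
# Because dense1 is registered both on WrapperModule and on SimpleModule,
# its weight and bias should appear under two paths that point at the
# same tensors: 'dense1.*' and 'module.dense1.*'.
for key, tensor in model.state_dict().items():
    print(key, tuple(tensor.shape))
```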
Disclaimer: I am not a Pytorch expert.

It seems a Pytorch state dict and Flax parameters are essentially two different things, so you shouldn't need to make them the same if you want your Flax and Pytorch models to behave the same. Because Flax Modules are stateless, we do not use a state dict (Modules are instantiated with parameters only during `apply` and `init`, and immediately destroyed afterwards). A Flax parameter dict is exactly what the name says: a dictionary of all parameters used in the Module. While you define the module inside `WrapperModule`, its parameters are only used in the submodule, so that is where they will appear in the parameter dict.

If you want to make the state dict the same as the param dict, you can move the definition of the dense layer inside the submodule in the Pytorch code. (@andsteing can maybe help here as well, since he wrote the Pytorch conversion HOWTO -- Andreas, maybe we can add a note about this discussion in the HOWTO?)