Control Flow in Module methods #1754
Hi there, I am trying to understand how to use control flow in Flax. I am porting something from PyTorch that, when translated literally to Flax, would look like this:

```python
import jax
import flax.linen as nn


class TestModel(nn.Module):
    def setup(self):
        self.d1 = nn.Dense(4, name="first")
        self.d2 = nn.Dense(4, name="second")
        self.d3 = nn.Dense(2, name="third")

    def __call__(self, x, y):
        if y is None:
            p = self.d1(x)
        else:
            p = self.d2(x + y)
        return self.d3(p)
```

Now initialization returns parameters either for "first" and "third" (if I init with y set to None) or for "second" and "third" (otherwise). For instance, this fails:

```python
m = TestModel()
x = jax.random.randint(jax.random.PRNGKey(1), (7, 11), 1, 4)
y = jax.random.randint(jax.random.PRNGKey(2), (7, 11), 1, 4)
params = m.init(jax.random.PRNGKey(3), x, y)
results = m.apply(params, x, None)  # --> No parameter named "kernel" exists in "/first".
```

How is one supposed to handle control flow in this case? I looked into jax.lax.cond, but stopped quickly because I don't think it's supported in Flax.

EDIT: An alternative is to always call both layers and select the result with jnp.where, setting y to zeros instead of None when it is supposed to be ignored:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class TestModel(nn.Module):
    def setup(self):
        self.d1 = nn.Dense(4, name="first")
        self.d2 = nn.Dense(4, name="second")
        self.d3 = nn.Dense(2, name="third")

    def __call__(self, x, y, ignore_y=True):
        # Both layers are always called, so both get initialized.
        p1 = self.d1(x)
        p2 = self.d2(x + y)
        # Select between the two layer outputs based on ignore_y.
        p = jnp.where(ignore_y, p1, p2)
        return self.d3(p)


m = TestModel()
x = jax.random.randint(jax.random.PRNGKey(1), (7, 11), 1, 4)
y = jax.random.randint(jax.random.PRNGKey(2), (7, 11), 1, 4)
params = m.init(jax.random.PRNGKey(3), x, y)
results_no_ignore = m.apply(params, x, y, ignore_y=False)
results_ignore = m.apply(params, x, jnp.zeros_like(y), ignore_y=True)
```

Is there a better way? Thanks for your help!
Parameter initialization is lazy in Flax: for Dense, the number of input features is inferred from the input the first time the layer is called. Because of the control flow, some layers that are used during apply are never called during init, so their parameters do not exist.
There are several ways to fix this. I would avoid using jnp.where, because then you always pay the cost of both Dense layers.
One fix is to call both layers in both branches but discard one output (the dummy call will be optimized away when using jit or pmap).
You could also use a dedicated init function to make sure everything is initialized like this: