Why `remat` modules must take a different name
#2251
-
When wrapping a module with `remat`, the wrapped module must be bound to a different name than the original. To see why, first define a two-layer linear model without using `remat`:

import flax.linen as nn
from flax.linen import partitioning as nn_partitioning
import jax
# Alias for flax.linen.partitioning.remat; the models below use it to apply
# gradient checkpointing (rematerialization) to Net1.
remat = nn_partitioning.remat


class Net1(nn.Module):
    """A single ``nn.Dense`` layer producing ``features`` outputs."""

    # Output width of the Dense layer.
    features: int

    def setup(self):
        # Registered under the attribute name ``linear``.
        self.linear = nn.Dense(features=self.features)

    def __call__(self, x):
        projected = self.linear(x)
        return projected
class Net2(nn.Module):
    """Chain of ``num_layers`` independent ``Net1`` blocks applied in sequence."""

    # Width passed to every Net1 in the chain.
    features: int
    # Number of Net1 blocks to chain.
    num_layers: int

    def setup(self):
        # One separately constructed Net1 per layer.
        self.layers = [Net1(features=self.features) for _ in range(self.num_layers)]

    def __call__(self, x):
        h = x
        for block in self.layers:
            h = block(h)
        return h
# One PRNG key, reused both to sample the input and to initialise parameters,
# and a single example input of shape (1, 3).
rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, shape=(1, 3))

# Plain (non-remat) two-layer model: initialise, then run a forward pass.
model = Net2(num_layers=2, features=3)
params = model.init(rng, x)
out = model.apply(params, x)
print(out.shape)
# Great, works as expected! If we now wrap `Net1` with `remat` inside `setup`,
# reusing the same name `Net1` for the wrapped module:
class RematNet2(nn.Module):
    """Like Net2, but optionally wraps each Net1 with ``remat``.

    NOTE: as written, ``setup`` fails (see the traceback below). Because
    ``Net1`` is *assigned* inside ``setup``, Python treats ``Net1`` as a local
    variable for the entire method body. The right-hand side of
    ``Net1 = remat(Net1)`` therefore reads a local that has no value yet, which
    raises ``UnboundLocalError``. With ``gradient_checkpointing=False`` the
    list comprehension reads that same never-assigned local and also fails
    with a ``NameError`` (``UnboundLocalError`` is a subclass of ``NameError``).
    """

    features: int
    num_layers: int
    # When True, setup tries to wrap Net1 with remat.
    gradient_checkpointing: bool = False

    def setup(self):
        if self.gradient_checkpointing:
            # Rebinding the global class name locally is the bug demonstrated here.
            Net1 = remat(Net1)
        self.layers = [Net1(features=self.features) for layer in range(self.num_layers)]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
# Instantiate with remat enabled. ``model.init`` runs ``setup``, which raises
# the UnboundLocalError shown in the traceback that follows.
model = RematNet2(features=3, num_layers=2, gradient_checkpointing=True)
params = model.init(rng, x)
out = model.apply(params, x)
print(out.shape)
# Traceback---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
[<ipython-input-65-9c490341f785>](https://jlqquoe2ip-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20220627-060046-RC00_457442455#) in <module>()
1 model = RematNet2(features=3, num_layers=2, gradient_checkpointing=True)
----> 2 params = model.init(rng, x)
3 out = model.apply(params, x)
16 frames
[/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py](https://jlqquoe2ip-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20220627-060046-RC00_457442455#) in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
[/usr/local/lib/python3.7/dist-packages/flax/linen/module.py](https://jlqquoe2ip-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20220627-060046-RC00_457442455#) in init(self, rngs, method, mutable, *args, **kwargs)
1228 rngs, *args,
-> 1229 method=method, mutable=mutable, **kwargs)
1230 return v_out
[/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py](https://jlqquoe2ip-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20220627-060046-RC00_457442455#) in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
[/usr/local/lib/python3.7/dist-packages/flax/linen/module.py](https://jlqquoe2ip-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20220627-060046-RC00_457442455#) in init_with_output(self, rngs, method, mutable, *args, **kwargs)
1194 return self.apply(
-> 1195 {}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
1196
[/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py](https://jlqquoe2ip-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20220627-060046-RC00_457442455#) in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
[/usr/local/lib/python3.7/dist-packages/flax/linen/module.py](https://jlqquoe2ip-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20220627-060046-RC00_457442455#) in apply(self, variables, rngs, method, mutable, capture_intermediates, *args, **kwargs)
1161 mutable=mutable, capture_intermediates=capture_intermediates
-> 1162 )(variables, *args, **kwargs, rngs=rngs)
1163
[/usr/local/lib/python3.7/dist-packages/flax/core/scope.py](https://jlqquoe2ip-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20220627-060046-RC00_457442455#) in wrapper(variables, rngs, *args, **kwargs)
830 with bind(variables, rngs=rngs, mutable=mutable).temporary() as root:
--> 831 y = fn(root, *args, **kwargs)
832 if mutable is not False:
[/usr/local/lib/python3.7/dist-packages/flax/linen/module.py](https://jlqquoe2ip-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20220627-060046-RC00_457442455#) in scope_fn(scope, *args, **kwargs)
1534 try:
-> 1535 return fn(module.clone(parent=scope), *args, **kwargs)
1536 finally:
[/usr/local/lib/python3.7/dist-packages/flax/linen/transforms.py](https://jlqquoe2ip-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20220627-060046-RC00_457442455#) in wrapped_fn(self, *args, **kwargs)
1245 or self._state.in_setup): # pylint: disable=protected-access
-> 1246 return prewrapped_fn(self, *args, **kwargs)
1247 fn_name = class_fn.__name__
[/usr/local/lib/python3.7/dist-packages/flax/linen/module.py](https://jlqquoe2ip-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20220627-060046-RC00_457442455#) in wrapped_module_method(*args, **kwargs)
351 self, args = args[0], args[1:]
--> 352 return self._call_wrapped_method(fun, args, kwargs)
353 else:
[/usr/local/lib/python3.7/dist-packages/flax/linen/module.py](https://jlqquoe2ip-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20220627-060046-RC00_457442455#) in _call_wrapped_method(self, fun, args, kwargs)
641 else:
--> 642 self._try_setup()
643
[/usr/local/lib/python3.7/dist-packages/flax/linen/module.py](https://jlqquoe2ip-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20220627-060046-RC00_457442455#) in _try_setup(self, shallow)
858 if not shallow:
--> 859 self.setup()
860 # We run static checks abstractly once for setup before any transforms
[/usr/local/lib/python3.7/dist-packages/flax/linen/module.py](https://jlqquoe2ip-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20220627-060046-RC00_457442455#) in wrapped_module_method(*args, **kwargs)
351 self, args = args[0], args[1:]
--> 352 return self._call_wrapped_method(fun, args, kwargs)
353 else:
[/usr/local/lib/python3.7/dist-packages/flax/linen/module.py](https://jlqquoe2ip-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20220627-060046-RC00_457442455#) in _call_wrapped_method(self, fun, args, kwargs)
650 try:
--> 651 y = fun(self, *args, **kwargs)
652 if _context.capture_stack:
[<ipython-input-64-f5e7b32d537a>](https://jlqquoe2ip-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20220627-060046-RC00_457442455#) in setup(self)
7 if self.gradient_checkpointing:
----> 8 Net1 = remat(Net1)
9
UnfilteredStackTrace: UnboundLocalError: local variable 'Net1' referenced before assignment
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
UnboundLocalError Traceback (most recent call last)
[<ipython-input-65-9c490341f785>](https://jlqquoe2ip-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20220627-060046-RC00_457442455#) in <module>()
1 model = RematNet2(features=3, num_layers=2, gradient_checkpointing=True)
----> 2 params = model.init(rng, x)
3 out = model.apply(params, x)
[<ipython-input-64-f5e7b32d537a>](https://jlqquoe2ip-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20220627-060046-RC00_457442455#) in setup(self)
6 def setup(self):
7 if self.gradient_checkpointing:
----> 8 Net1 = remat(Net1)
9
10 self.layers = [Net1(features=self.features) for layer in range(self.num_layers)]
# UnboundLocalError: local variable 'Net1' referenced before assignment
#
# Binding the remat-wrapped module to a different name than `Net1` fixes it:
class RematNet2(nn.Module):
    """Chain of Net1 blocks, each optionally wrapped with ``remat``.

    The (possibly wrapped) class is bound to a new local name, so ``Net1``
    itself is only ever read inside ``setup`` and refers to the global class.
    """

    features: int
    num_layers: int
    # When True, every Net1 is wrapped with remat (gradient checkpointing).
    gradient_checkpointing: bool = False

    def setup(self):
        layer_cls = remat(Net1) if self.gradient_checkpointing else Net1
        self.layers = [layer_cls(features=self.features) for _ in range(self.num_layers)]

    def __call__(self, x):
        h = x
        for block in self.layers:
            h = block(h)
        return h
# Same remat-enabled configuration, now using the renamed-local RematNet2.
model = RematNet2(num_layers=2, features=3, gradient_checkpointing=True)
params = model.init(rng, x)
out = model.apply(params, x)
print(out.shape)
# Of course, in this dummy example it would be trivial to change the name of
# the wrapped module. I noticed this pattern in other JAX/Flax repos such as
# T5x, where the remat-wrapped module is likewise given a different name.
# Full runnable code end-to-end:
# https://colab.research.google.com/drive/1bnWrLUYvdV7znVk4dgyKLXg9KeGr0yJx?usp=sharing
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 1 reply
-
# Hey @sanchit-gandhi! What you are describing is not related to `remat` or
# even Flax at all; it's just how Python non-local variables work inside
# functions:
A = 1
def f(x):
    # The assignment below makes A local to all of f, so reading A on the
    # right-hand side finds a local with no value yet: UnboundLocalError.
    A = A + 1
    return x * A
print(f(2))
# This yields the same error:
Given Python will force you to choose a third name anyway, you can get creative: I've used something like `RematOrNet1` in the past. Good luck!
Beta Was this translation helpful? Give feedback.
Hey @sanchit-gandhi! What you are describing is not related to `remat` or even Flax at all; it's just how Python non-local variables work inside functions:

This yields the same error:
Given Python will force you to choose a third name anyway, you can get creative: I've used something like `RematOrNet1` in the past. Good luck!