Skip to content

Commit d6e8e14

Browse files
author
Flax Authors
committed
Merge pull request #5055 from mohsinm-dev:fix-bound-method-auto-unbinding
PiperOrigin-RevId: 833378860
2 parents 544a050 + dece6b3 commit d6e8e14

File tree

5 files changed

+280
-18
lines changed

5 files changed

+280
-18
lines changed

flax/nnx/transforms/autodiff.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@
2727
)
2828
from flax.nnx.statelib import State
2929
import jax
30-
import jax.core
31-
import jax.stages
3230

3331
from flax.nnx.transforms import general
34-
from flax.nnx.transforms.transforms import resolve_kwargs
32+
from flax.nnx.transforms.transforms import (
33+
resolve_kwargs,
34+
_resolve_bound_callable,
35+
_raise_bound_method_error,
36+
)
3537
from flax.typing import MISSING, Missing
3638

3739

@@ -60,6 +62,7 @@ class DiffState:
6062
filter: filterlib.Filter
6163

6264

65+
6366
@dataclasses.dataclass(eq=False)
6467
class GradFn:
6568
f: tp.Callable[..., tp.Any]
@@ -312,15 +315,23 @@ def grad(
312315
holomorphic=holomorphic,
313316
allow_int=allow_int,
314317
)
315-
return _grad_general(
316-
f,
318+
# Detect bound nnx.Module methods and raise error.
319+
f_unbound, _, was_bound = _resolve_bound_callable(f)
320+
321+
if was_bound:
322+
_raise_bound_method_error('grad')
323+
324+
grad_fn = _grad_general(
325+
f_unbound,
317326
argnums,
318327
has_aux,
319328
holomorphic,
320329
allow_int,
321330
return_value=False,
322331
)
323332

333+
return grad_fn
334+
324335

325336
@tp.overload
326337
def value_and_grad(
@@ -366,8 +377,14 @@ def value_and_grad(
366377
holomorphic=holomorphic,
367378
allow_int=allow_int,
368379
)
380+
# Detect bound nnx.Module methods and raise error.
381+
f_unbound, _, was_bound = _resolve_bound_callable(f)
382+
383+
if was_bound:
384+
_raise_bound_method_error('value_and_grad')
385+
369386
return _grad_general(
370-
f,
387+
f_unbound,
371388
argnums,
372389
has_aux,
373390
holomorphic,
@@ -829,7 +846,13 @@ def custom_vjp(
829846
"""
830847
if isinstance(fun, Missing):
831848
return functools.partial(custom_vjp, nondiff_argnums=nondiff_argnums)
832-
return CustomVjp(fun, nondiff_argnums)
849+
850+
# Detect bound nnx.Module methods and raise error.
851+
fun_unbound, _, was_bound = _resolve_bound_callable(fun)
852+
if was_bound:
853+
_raise_bound_method_error('custom_vjp')
854+
855+
return CustomVjp(fun_unbound, nondiff_argnums)
833856

834857

835858
# -------------------------------
@@ -881,11 +904,18 @@ def remat(
881904
policy=policy,
882905
) # type: ignore[return-value]
883906

884-
return resolve_kwargs()(
907+
# Detect bound nnx.Module methods and raise error.
908+
f_unbound, _, was_bound = _resolve_bound_callable(f)
909+
910+
if was_bound:
911+
_raise_bound_method_error('remat')
912+
913+
# Unbound function path: preserve the concise composition used in NNX.
914+
return resolve_kwargs()( # type: ignore[return-value]
885915
graph.update_context('remat')(
886916
general.split_inputs(
887917
jax.checkpoint(
888-
general.merge_inputs(f, ctxtag='remat'),
918+
general.merge_inputs(f_unbound, ctxtag='remat'),
889919
prevent_cse=prevent_cse,
890920
static_argnums=static_argnums,
891921
policy=policy,

flax/nnx/transforms/compilation.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030
statelib,
3131
variablelib,
3232
)
33+
from flax.nnx.transforms.transforms import (
34+
_resolve_bound_callable,
35+
_raise_bound_method_error,
36+
)
3337
from flax.typing import MISSING, Missing, PathParts
3438

3539
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
@@ -218,6 +222,11 @@ def jit(
218222
JAX keeps a weak reference to ``fun`` for use as a compilation cache key,
219223
so the object ``fun`` must be weakly-referenceable. Most :class:`Callable`
220224
objects will already satisfy this requirement.
225+
226+
.. note::
227+
Bound methods (e.g., ``module.method``) are not supported. Use the
228+
decorator form ``@nnx.jit`` on the method definition or call
229+
``nnx.jit(MyClass.method)(instance, ...)`` with the unbound method.
221230
in_shardings: Pytree of structure matching that of arguments to ``fun``,
222231
with all actual arguments replaced by resource assignment specifications.
223232
It is also valid to specify a pytree prefix (e.g. one value in place of a
@@ -335,8 +344,13 @@ def jit(
335344
inline=inline,
336345
abstracted_axes=abstracted_axes,
337346
) # type: ignore[return-value]
347+
# Detect bound nnx.Module methods and raise error.
348+
fun_unbound, _, was_bound = _resolve_bound_callable(fun)
349+
if was_bound:
350+
_raise_bound_method_error('jit')
351+
338352
return JitWrapped(
339-
fun,
353+
fun_unbound,
340354
in_shardings=in_shardings,
341355
out_shardings=out_shardings,
342356
static_argnums=static_argnums,
@@ -986,6 +1000,11 @@ def f(m, x):
9861000
) # type: ignore[return-value]
9871001
assert not isinstance(f, type)
9881002

1003+
# Detect bound nnx.Module methods and raise error.
1004+
f_unbound, _, was_bound = _resolve_bound_callable(f)
1005+
if was_bound:
1006+
_raise_bound_method_error('shard_map')
1007+
9891008
kwarg_specs = PartitionSpec()
9901009
jax_in_specs = jax.tree.map(
9911010
lambda x: extract.NodeStates(
@@ -1033,7 +1052,7 @@ def shard_map_wrapper(*args, **kwargs):
10331052
return out
10341053

10351054
shard_map_fn = jax.shard_map(
1036-
ShardMapFn(f, in_specs, out_specs, kwarg_specs, shard_map_wrapper),
1055+
ShardMapFn(f_unbound, in_specs, out_specs, kwarg_specs, shard_map_wrapper),
10371056
mesh=mesh,
10381057
in_specs=jax_in_specs,
10391058
out_specs=(jax_in_specs, kwarg_specs, jax_out_specs), # type: ignore

flax/nnx/transforms/iteration.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
from flax.nnx import statelib
2626
from flax.nnx.module import Module
2727
from flax.nnx.statelib import State
28-
from flax.nnx.transforms.transforms import resolve_kwargs
28+
from flax.nnx.transforms.transforms import (
29+
resolve_kwargs,
30+
_resolve_bound_callable,
31+
_raise_bound_method_error,
32+
)
2933
from flax.typing import Leaf, Missing, PytreeDeque
3034
import jax
3135
import jax.core
@@ -329,6 +333,11 @@ def vmap(
329333
transform_metadata=transform_metadata,
330334
) # type: ignore[return-value]
331335

336+
# Detect bound nnx.Module methods and raise error.
337+
f_unbound, _, was_bound = _resolve_bound_callable(f)
338+
if was_bound:
339+
_raise_bound_method_error('vmap')
340+
332341
jax_in_axes = jax.tree.map(
333342
lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x)
334343
if isinstance(x, StateAxes)
@@ -342,7 +351,7 @@ def vmap(
342351
out_axes,
343352
)
344353
vmapped_fn = jax.vmap(
345-
VmapFn(f, transform_metadata, in_axes, out_axes),
354+
VmapFn(f_unbound, transform_metadata, in_axes, out_axes),
346355
in_axes=jax_in_axes,
347356
out_axes=(jax_in_axes, jax_out_axes),
348357
axis_name=axis_name,
@@ -551,6 +560,11 @@ def pmap(
551560
transform_metadata=transform_metadata,
552561
) # type: ignore[return-value]
553562

563+
# Detect bound nnx.Module methods and raise error.
564+
f_unbound, _, was_bound = _resolve_bound_callable(f)
565+
if was_bound:
566+
_raise_bound_method_error('pmap')
567+
554568
jax_in_axes = jax.tree.map(
555569
lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x)
556570
if isinstance(x, StateAxes)
@@ -564,7 +578,7 @@ def pmap(
564578
out_axes,
565579
)
566580
pmapped_fn = jax.pmap(
567-
PmapFn(f, transform_metadata, in_axes, out_axes),
581+
PmapFn(f_unbound, transform_metadata, in_axes, out_axes),
568582
axis_name=axis_name,
569583
in_axes=jax_in_axes,
570584
out_axes=(jax_in_axes, jax_out_axes),
@@ -1258,6 +1272,10 @@ def forward(x, model):
12581272
transform_metadata=transform_metadata,
12591273
) # type: ignore[return-value]
12601274

1275+
f_unbound, _, was_bound = _resolve_bound_callable(f)
1276+
if was_bound:
1277+
_raise_bound_method_error('scan')
1278+
12611279
_check_out_axes(out_axes)
12621280

12631281
input_carry_argnum = _get_carry_argnum(in_axes, is_in_axes=True)
@@ -1272,7 +1290,7 @@ def forward(x, model):
12721290
)
12731291

12741292
scan_fn = ScanFn(
1275-
f,
1293+
f_unbound,
12761294
input_carry_argnum,
12771295
output_carry_argnum,
12781296
in_axes,

flax/nnx/transforms/transforms.py

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,66 @@ def resolve_kwargs_wrapper(*args, **kwargs):
9494

9595

9696

97+
# -------------------------------
98+
# helper utilities for bound methods
99+
# -------------------------------
100+
101+
def _resolve_bound_callable(
102+
f: tp.Callable[..., tp.Any],
103+
) -> tuple[tp.Callable[..., tp.Any], tp.Any | None, bool]:
104+
"""Detects and extracts bound methods from NNX Module callables.
105+
106+
This function unwraps functools.partial layers to reach the underlying
107+
callable before checking if it's a bound method of an NNX Module.
108+
109+
Args:
110+
f: A callable that may be a bound method of an NNX Module, potentially
111+
wrapped in functools.partial.
112+
113+
Returns:
114+
A tuple of (unbound_fn, bound_self, was_bound) where:
115+
- unbound_fn: The unbound function (or original if not bound)
116+
- bound_self: The Module instance if f was bound, None otherwise
117+
- was_bound: True if f was a bound method, False otherwise
118+
119+
Note:
120+
Preserves functools.partial wrappers around the callable and follows
121+
the same detection pattern as _get_unbound_fn in bridge/module.py.
122+
Detection occurs before any argnum shifting or index normalization.
123+
"""
124+
# Unwrap functools.partial layers to reach the underlying callable.
125+
partials: list[tuple[tuple[tp.Any, ...], dict[str, tp.Any] | None]] = []
126+
g = f
127+
while isinstance(g, functools.partial): # type: ignore[arg-type]
128+
partials.append((g.args or (), g.keywords)) # type: ignore[attr-defined]
129+
g = g.func # type: ignore[attr-defined]
130+
131+
bound_self = getattr(g, "__self__", None)
132+
was_bound = bool(inspect.ismethod(g) and isinstance(bound_self, Module))
133+
if was_bound:
134+
g = g.__func__ # type: ignore[attr-defined]
135+
136+
# Reapply partials in reverse unwrap order.
137+
for args, kwargs in reversed(partials):
138+
kwargs = {} if kwargs is None else kwargs
139+
g = functools.partial(g, *args, **kwargs)
140+
141+
return g, (bound_self if was_bound else None), was_bound
142+
143+
144+
def _raise_bound_method_error(transform_name: str):
145+
"""Raises a standardized error for bound method usage with NNX transforms.
146+
147+
Args:
148+
transform_name: Name of the transform (e.g., 'grad', 'jit', 'remat').
149+
"""
150+
raise ValueError(
151+
f"nnx.{transform_name} does not support bound methods. "
152+
f"Use the decorator form @nnx.{transform_name} or call "
153+
f"nnx.{transform_name}(MyClass.method)(instance, ...) with the unbound method."
154+
)
155+
156+
97157
class LiftedModule(tp.Generic[M], Module): # type: ignore[ignored-abstractmethod]
98158
@abstractmethod
99159
def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> tp.Any:
@@ -195,12 +255,18 @@ def eval_shape(
195255
performing any floating point operations (FLOPs) which can be expensive. This can be
196256
useful for performing shape inference, for example.
197257
"""
258+
# Detect bound nnx.Module methods and raise error.
259+
f_call, _, was_bound = _resolve_bound_callable(f)
260+
261+
if was_bound:
262+
_raise_bound_method_error('eval_shape')
263+
198264
args, kwargs = extract.to_tree((args, kwargs))
199265

200266
@functools.wraps(f)
201267
def _eval_shape_fn(*args, **kwargs):
202268
args, kwargs = extract.from_tree((args, kwargs))
203-
out = f(*args, **kwargs)
269+
out = f_call(*args, **kwargs)
204270
return _to_value_metadata(extract.to_tree(out))
205271

206272
out = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
@@ -255,7 +321,13 @@ def checkify(
255321
>>> print(err)
256322
Error(nan generated by primitive: sin.)
257323
"""
258-
checkify_fn = checkify_lib.checkify(CheckifyFn(f), errors)
324+
# Detect bound nnx.Module methods and raise error.
325+
f_call, _, was_bound = _resolve_bound_callable(f)
326+
327+
if was_bound:
328+
_raise_bound_method_error('checkify')
329+
330+
checkify_fn = checkify_lib.checkify(CheckifyFn(f_call), errors)
259331

260332
@functools.wraps(f)
261333
@graph.update_context('checkify')
@@ -307,4 +379,3 @@ def switch(
307379
[general.merge_inputs(f, ctxtag='switch') for f in branches],
308380
*operands,
309381
)
310-

0 commit comments

Comments
 (0)