From 6baea11f8670cf0f6f982c9a4cb1f201c9569d84 Mon Sep 17 00:00:00 2001 From: AntoinePlumerault Date: Sat, 31 May 2025 17:54:10 +0000 Subject: [PATCH 01/18] add convolution layers --- penzai/core/named_axes.py | 527 +++++++++---------- penzai/nn/linear_and_affine.py | 777 ++++++++++++++++++++++++++--- penzai/pz/nn.py | 4 + tests/nn/linear_and_affine_test.py | 56 +++ uv.lock | 34 +- 5 files changed, 1030 insertions(+), 368 deletions(-) diff --git a/penzai/core/named_axes.py b/penzai/core/named_axes.py index 7070a56..0dafc37 100644 --- a/penzai/core/named_axes.py +++ b/penzai/core/named_axes.py @@ -170,15 +170,15 @@ def nmap(fun: Callable[..., Any]) -> Callable[..., Any]: def _nmap_with_doc( - fun: Callable[..., Any], fun_name: str, fun_doc: str | None = None + fun: Callable[..., Any], fun_name: str, fun_doc: str | None = None ) -> Callable[..., Any]: """Builds a nmap-wrapped function with a docstring.""" @functools.wraps(fun) def wrapped_fun(*args, **kwargs): arg_leaves_and_paths, arg_treedef = jax.tree_util.tree_flatten_with_path( - (args, kwargs), - is_leaf=lambda node: isinstance(node, NamedArray | NamedArrayView), + (args, kwargs), + is_leaf=lambda node: isinstance(node, NamedArray | NamedArrayView), ) arg_leaves = [leaf for _, leaf in arg_leaves_and_paths] # Extract any argument leaves that were NamedArrays or NamedArrayViews. The @@ -201,10 +201,7 @@ def wrapped_fun(*args, **kwargs): named_array_arg_leaves.append(leaf.as_namedarrayview()) if bad_names: - msg = [ - f"Inconsistent named axes in a call to nmap({fun}) for axes" - f" {bad_names}:" - ] + msg = [f"Inconsistent named axes in a call to nmap({fun}) for axes {bad_names}:"] for keypath, leaf in arg_leaves_and_paths: if isinstance(leaf, NamedArray | NamedArrayView): assert keypath @@ -274,22 +271,20 @@ def _shift_axis(other_axis): # data_array will still have the extra axis. But after running `vmap`, # it will be valid again. reduced_views.append( - NamedArrayView( - data_array=view.data_array, - data_axis_for_name={ - name: _shift_axis(data_axis) - for name, data_axis in view.data_axis_for_name.items() - if name != vmap_name - }, - data_axis_for_logical_axis=tuple( - _shift_axis(data_axis) - for data_axis in view.data_axis_for_logical_axis - ), - data_shape=( - view.data_shape[:vmap_axis] - + view.data_shape[vmap_axis + 1 :] - ), - ) + NamedArrayView( + data_array=view.data_array, + data_axis_for_name={ + name: _shift_axis(data_axis) + for name, data_axis in view.data_axis_for_name.items() + if name != vmap_name + }, + data_axis_for_logical_axis=tuple( + _shift_axis(data_axis) for data_axis in view.data_axis_for_logical_axis + ), + data_shape=( + view.data_shape[:vmap_axis] + view.data_shape[vmap_axis + 1 :] + ), + ) ) else: # This argument doesn't have this axis, so don't map over anything. @@ -297,13 +292,13 @@ def _shift_axis(other_axis): reduced_views.append(view) return jax.vmap( - functools.partial( - recursive_vectorize_step, - remaining_names=remaining_names[1:], - ), - in_axes=(vmap_axes,), - out_axes=0, - axis_name=vmap_name, + functools.partial( + recursive_vectorize_step, + remaining_names=remaining_names[1:], + ), + in_axes=(vmap_axes,), + out_axes=0, + axis_name=vmap_name, )(reduced_views) # Run the function. 
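For reference, the `nmap` machinery reformatted above is the core of this scheme: it lifts a function on positional arrays so that it is automatically mapped over all named axes. A minimal sketch of the intended behavior (the axis names `batch` and `feature` are illustrative, not part of the API):

  import jax.numpy as jnp
  from penzai.core import named_axes

  # Wrap a positional array and name both of its axes.
  x = named_axes.wrap(jnp.ones((4, 3))).tag("batch", "feature")
  # nmap maps jnp.sum over every named axis; axes exposed by `untag` stay
  # positional, so only "feature" is reduced here.
  row_sums = named_axes.nmap(jnp.sum)(x.untag("feature"))
  assert dict(row_sums.named_shape) == {"batch": 4}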
@@ -317,24 +312,24 @@ def handle_result(leaf):
       leaf = jnp.array(leaf)
     if leaf.ndim == len(all_names):
       return NamedArray(
-          data_array=leaf,
-          named_axes=collections.OrderedDict(zip(all_names, leaf.shape)),
+        data_array=leaf,
+        named_axes=collections.OrderedDict(zip(all_names, leaf.shape)),
       )
     else:
       assert leaf.ndim > len(all_names)
       return NamedArrayView(
-          data_array=leaf,
-          data_shape=leaf.shape,
-          data_axis_for_name={name: i for i, name in enumerate(all_names)},
-          data_axis_for_logical_axis=tuple(range(len(all_names), leaf.ndim)),
+        data_array=leaf,
+        data_shape=leaf.shape,
+        data_axis_for_name={name: i for i, name in enumerate(all_names)},
+        data_axis_for_logical_axis=tuple(range(len(all_names), leaf.ndim)),
       )

     return jax.tree_util.tree_map(handle_result, result_data)

   docstr = (
-      f"Name-vectorized version of `{fun_name}`. Takes similar arguments as"
-      f" `{fun_name}` but accepts and returns NamedArrays (or NamedArrayViews)"
-      " in place of regular arrays."
+    f"Name-vectorized version of `{fun_name}`. Takes similar arguments as"
+    f" `{fun_name}` but accepts and returns NamedArrays (or NamedArrayViews)"
+    " in place of regular arrays."
   )
   if fun_doc:
     docstr += f"\n\nOriginal documentation:\n\n{fun_doc}"
@@ -357,8 +352,8 @@ def _wrap_scalar_conversion(scalar_conversion):
   def wrapped_scalar_conversion(self: NamedArrayBase):
     if self.named_shape or self.positional_shape:
       raise ValueError(
-          "Cannot convert a non-scalar NamedArray or NamedArrayView with"
-          f" {scalar_conversion}"
+        "Cannot convert a non-scalar NamedArray or NamedArrayView with"
+        f" {scalar_conversion}"
       )
     return scalar_conversion(self.unwrap())
@@ -374,17 +369,17 @@ def func(array, *args, **kwargs):
   array_method = getattr(jax.Array, name)
   wrapped_func = nmap(func)
   functools.update_wrapper(
-      wrapped_func,
-      array_method,
-      assigned=("__name__", "__qualname__", "__annotations__"),
-      updated=(),
+    wrapped_func,
+    array_method,
+    assigned=("__name__", "__qualname__", "__annotations__"),
+    updated=(),
   )
   wrapped_func.__module__ = __name__
   wrapped_func.__doc__ = (
-      "Name-vectorized version of array method"
-      f" `{name} <jax.Array.{name}>`. Takes similar arguments as"
-      f" `{name} <jax.Array.{name}>` but accepts and returns NamedArrays"
-      " (or NamedArrayViews) in place of regular arrays."
+    "Name-vectorized version of array method"
+    f" `{name} <jax.Array.{name}>`. Takes similar arguments as"
+    f" `{name} <jax.Array.{name}>` but accepts and returns NamedArrays"
+    " (or NamedArrayViews) in place of regular arrays."
) return wrapped_func @@ -416,56 +411,56 @@ def unwrap(self): @functools.partial( - jax.jit, - static_argnames=[ - "indices_are_sorted", - "unique_indices", - "mode", - "fill_value", - ], + jax.jit, + static_argnames=[ + "indices_are_sorted", + "unique_indices", + "mode", + "fill_value", + ], ) @nmap def _jitted_nmapped_getitem( - array: jax.Array, - index_thunks: tuple[_StaticThunk | _DynamicThunk | _SliceThunk, ...], - *, - indices_are_sorted=False, - unique_indices=False, - mode=None, - fill_value=None, + array: jax.Array, + index_thunks: tuple[_StaticThunk | _DynamicThunk | _SliceThunk, ...], + *, + indices_are_sorted=False, + unique_indices=False, + mode=None, + fill_value=None, ): """JIT-compiled helper for getitem.""" indexer = tuple(thunk.unwrap() for thunk in index_thunks) return array.at[indexer].get( - indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, - mode=mode, - fill_value=fill_value, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode, + fill_value=fill_value, ) @functools.partial( - jax.jit, - static_argnames=["method", "indices_are_sorted", "unique_indices", "mode"], + jax.jit, + static_argnames=["method", "indices_are_sorted", "unique_indices", "mode"], ) @nmap def _jitted_nmapped_update( - array: jax.Array, - index_thunks: tuple[_StaticThunk | _DynamicThunk | _SliceThunk, ...], - values: jax.Array, - method: str, - *, - indices_are_sorted=False, - unique_indices=False, - mode=None, + array: jax.Array, + index_thunks: tuple[_StaticThunk | _DynamicThunk | _SliceThunk, ...], + values: jax.Array, + method: str, + *, + indices_are_sorted=False, + unique_indices=False, + mode=None, ): """JIT-compiled helper for in-place updates.""" indexer = tuple(thunk.unwrap() for thunk in index_thunks) return getattr(array.at[indexer], method)( - values, - indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, - mode=mode, + values, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode, ) @@ -519,16 +514,16 @@ def _partition_dict_index(self): # Sliced axis. sliced_axes.append(name) elif isinstance(index, int) or ( - isinstance(index, jax.Array | np.ndarray | NamedArrayBase) - and jnp.issubdtype(index.dtype, np.integer) + isinstance(index, jax.Array | np.ndarray | NamedArrayBase) + and jnp.issubdtype(index.dtype, np.integer) ): # Indexed axis. indexed_axes.append(name) else: raise TypeError( - "Unsupported index for a named axis using dict-style index:" - " expected a slice, an integer, an integer array, or None, but" - f" got {index}" + "Unsupported index for a named axis using dict-style index:" + " expected a slice, an integer, an integer array, or None, but" + f" got {index}" ) input_prefix_order = (*sliced_axes, *indexed_axes) @@ -543,13 +538,13 @@ def get(self, **kwargs) -> NamedArrayBase: # Dict indexing => desugar it to positional indexing over the requested # names. 
input_prefix_order, slice_order, output_prefix_order = ( - self._partition_dict_index() + self._partition_dict_index() ) return ( - self.array.untag_prefix(*input_prefix_order) - .at[tuple(self.indexer[name] for name in slice_order)] - .get(**kwargs) - .tag_prefix(*output_prefix_order) + self.array.untag_prefix(*input_prefix_order) + .at[tuple(self.indexer[name] for name in slice_order)] + .get(**kwargs) + .tag_prefix(*output_prefix_order) ) else: @@ -579,7 +574,7 @@ def _nmap_update_op(self, method: str, value, kwargs) -> NamedArrayBase: # Dict indexing => desugar it to positional indexing over the requested # names. input_prefix_order, slice_order, output_prefix_order = ( - self._partition_dict_index() + self._partition_dict_index() ) # Make sure the provided value has the necessary axes, by broadcasting it @@ -596,36 +591,32 @@ def _nmap_update_op(self, method: str, value, kwargs) -> NamedArrayBase: result_shape = result_structure.positional_shape num_new_positional_axes = len(result_shape) - len(value_shape) if num_new_positional_axes < 0 or not all( - vd == 1 or vd == rd - for vd, rd in zip( - value_shape, result_shape[len(result_shape) - len(value_shape) :] - ) + vd == 1 or vd == rd + for vd, rd in zip( + value_shape, result_shape[len(result_shape) - len(value_shape) :] + ) ): raise ValueError( - "Cannot provide updates with positional shape" - f" {value_shape} for an index whose result shape is" - f" {result_shape}! Update shape must be a" - " suffix of the result shape (or broadcastable to it)." + "Cannot provide updates with positional shape" + f" {value_shape} for an index whose result shape is" + f" {result_shape}! Update shape must be a" + " suffix of the result shape (or broadcastable to it)." ) if num_new_positional_axes: value = value[(None,) * num_new_positional_axes + (...,)] new_names = { - name: None - for name in output_prefix_order - if name not in value.named_shape + name: None for name in output_prefix_order if name not in value.named_shape } if new_names: value = value[new_names] # pylint: disable=protected-access return ( - self.array.untag_prefix(*input_prefix_order) - .at[tuple(self.indexer[name] for name in slice_order)] - ._nmap_update_op( - method, value.untag_prefix(*output_prefix_order), kwargs - ) - .tag_prefix(*input_prefix_order) + self.array.untag_prefix(*input_prefix_order) + .at[tuple(self.indexer[name] for name in slice_order)] + ._nmap_update_op(method, value.untag_prefix(*output_prefix_order), kwargs) + .tag_prefix(*input_prefix_order) ) # pylint: enable=protected-access @@ -649,7 +640,7 @@ def _nmap_update_op(self, method: str, value, kwargs) -> NamedArrayBase: index_thunks.append(_StaticThunk(c)) return _jitted_nmapped_update( - self.array, tuple(index_thunks), value, method, **kwargs + self.array, tuple(index_thunks), value, method, **kwargs ) def set(self, values, /, **kwargs): @@ -843,8 +834,7 @@ def tag_prefix(self, *axis_order: AxisName) -> NamedArray | NamedArrayView: # We implement `tag_prefix` using `tag` with temporary axis # identifiers. 
tmp_axis_ids = [
-        TmpPosAxisMarker()
-        for _ in range(len(self.positional_shape) - len(axis_order))
+      TmpPosAxisMarker() for _ in range(len(self.positional_shape) - len(axis_order))
     ]
     return self.tag(*axis_order, *tmp_axis_ids).untag(*tmp_axis_ids)

@@ -876,14 +866,14 @@ def order_as(self, *axis_order: AxisName) -> NamedArray:

     data_array = self.tag(*tmp_names).untag(*tmp_names, *axis_order).unwrap()
     return (
-        NamedArray.wrap(data_array)
-        .tag(*tmp_names, *axis_order)
-        .untag(*tmp_names)
-        .with_positional_prefix()
+      NamedArray.wrap(data_array)
+      .tag(*tmp_names, *axis_order)
+      .untag(*tmp_names)
+      .with_positional_prefix()
     )

   def order_like(
-      self, other: NamedArray | NamedArrayView
+    self, other: NamedArray | NamedArrayView
   ) -> NamedArray | NamedArrayView:
     """Ensures that this array's PyTree structure matches another array's.

@@ -918,46 +908,46 @@ def order_like(
     elif isinstance(other, NamedArrayView):
       if len(self.positional_shape) != len(other.positional_shape):
         raise ValueError(
-            "Calling `order_like` with a NamedArrayView requires the two"
-            " arrays to have the same number of positional axes, but got"
-            f" positional shapes {self.positional_shape=},"
-            f" {other.positional_shape=}"
+          "Calling `order_like` with a NamedArrayView requires the two"
+          " arrays to have the same number of positional axes, but got"
+          f" positional shapes {self.positional_shape=},"
+          f" {other.positional_shape=}"
         )
       if set(self.named_shape.keys()) != set(other.named_shape.keys()):
         raise ValueError(
-            "Calling `order_like` with a NamedArrayView requires the two"
-            " arrays to have the axis names, but got"
-            f" named shapes {self.named_shape=}, {other.named_shape=}"
+          "Calling `order_like` with a NamedArrayView requires the two"
+          " arrays to have the same axis names, but got"
+          f" named shapes {self.named_shape=}, {other.named_shape=}"
         )
       self_view = self.as_namedarrayview()
       new_to_old_data_axis = {}
       for old_data_axis, new_data_axis in zip(
-          self_view.data_axis_for_logical_axis, other.data_axis_for_logical_axis
+        self_view.data_axis_for_logical_axis, other.data_axis_for_logical_axis
       ):
         new_to_old_data_axis[new_data_axis] = old_data_axis
       for name, new_data_axis in other.data_axis_for_name.items():
         new_to_old_data_axis[new_data_axis] = self_view.data_axis_for_name[name]
       new_data_array = jnp.transpose(
-          self_view.data_array,
-          [new_to_old_data_axis[i] for i in range(self_view.data_array.ndim)],
+        self_view.data_array,
+        [new_to_old_data_axis[i] for i in range(self_view.data_array.ndim)],
       )
       return NamedArrayView(
-          data_shape=new_data_array.shape,
-          data_axis_for_logical_axis=other.data_axis_for_logical_axis,
-          data_axis_for_name=other.data_axis_for_name,
-          data_array=new_data_array,
+        data_shape=new_data_array.shape,
+        data_axis_for_logical_axis=other.data_axis_for_logical_axis,
+        data_axis_for_name=other.data_axis_for_name,
+        data_array=new_data_array,
       )
     else:
       raise TypeError(
-          "`order_like` requires a NamedArray or NamedArrayView, but got"
-          f" {type(other).__name__}"
+        "`order_like` requires a NamedArray or NamedArrayView, but got"
+        f" {type(other).__name__}"
       )

   def broadcast_to(
-      self,
-      positional_shape: Sequence[int] = (),
-      named_shape: Mapping[AxisName, int] | None = None,
+    self,
+    positional_shape: Sequence[int] = (),
+    named_shape: Mapping[AxisName, int] | None = None,
   ) -> NamedArrayBase:
     """Broadcasts a named array to a possibly-larger shape. 
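The `tag`/`untag` helpers reindented above convert between named and positional axes; a small round-trip sketch (the axis names are illustrative):

  import jax.numpy as jnp
  from penzai.core import named_axes

  arr = named_axes.wrap(jnp.zeros((2, 3))).tag("batch", "feature")
  # `untag` exposes named axes positionally, in the requested order.
  assert arr.untag("feature", "batch").positional_shape == (3, 2)
  # `tag_prefix` names only the leading positional axes.
  part = named_axes.wrap(jnp.zeros((2, 3, 5))).tag_prefix("batch")
  assert dict(part.named_shape) == {"batch": 2}
  assert part.positional_shape == (3, 5)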
@@ -977,22 +967,22 @@ def broadcast_to( named_shape = {} named_shape = dict(named_shape) if ( - self.positional_shape == tuple(positional_shape) - and dict(self.named_shape) == named_shape + self.positional_shape == tuple(positional_shape) + and dict(self.named_shape) == named_shape ): return self # Trick: create a size-zero array with the right shape so that we can # broadcast using nmap's vectorization rules. prototype_data = jnp.zeros( - tuple(named_shape.values()) + tuple(positional_shape) + (0,) + tuple(named_shape.values()) + tuple(positional_shape) + (0,) ) assert prototype_data.size == 0 prototype = NamedArray.wrap(prototype_data).tag_prefix(*named_shape.keys()) return nmap(lambda a, b: jnp.broadcast_to(a, b.shape[:-1]))(self, prototype) def broadcast_like( - self, other: NamedArrayBase | jax.typing.ArrayLike + self, other: NamedArrayBase | jax.typing.ArrayLike ) -> NamedArrayBase: """Broadcasts a named array to be compatible with another. @@ -1255,20 +1245,16 @@ def __treescope_ndarray_adapter__(self): __rsub__ = _nmap_with_doc(_swapped_binop(operator.sub), "jax.Array.__rsub__") __rmul__ = _nmap_with_doc(_swapped_binop(operator.mul), "jax.Array.__rmul__") __rtruediv__ = _nmap_with_doc( - _swapped_binop(operator.truediv), "jax.Array.__rtruediv__" + _swapped_binop(operator.truediv), "jax.Array.__rtruediv__" ) __rfloordiv__ = _nmap_with_doc( - _swapped_binop(operator.floordiv), "jax.Array.__rfloordiv__" + _swapped_binop(operator.floordiv), "jax.Array.__rfloordiv__" ) __rmod__ = _nmap_with_doc(_swapped_binop(operator.mod), "jax.Array.__rmod__") __rdivmod__ = _nmap_with_doc(_swapped_binop(divmod), "jax.Array.__rdivmod__") __rpow__ = _nmap_with_doc(_swapped_binop(operator.pow), "jax.Array.__rpow__") - __rlshift__ = _nmap_with_doc( - _swapped_binop(operator.lshift), "jax.Array.__rlshift__" - ) - __rrshift__ = _nmap_with_doc( - _swapped_binop(operator.rshift), "jax.Array.__rrshift__" - ) + __rlshift__ = _nmap_with_doc(_swapped_binop(operator.lshift), "jax.Array.__rlshift__") + __rrshift__ = _nmap_with_doc(_swapped_binop(operator.rshift), "jax.Array.__rrshift__") __rand__ = _nmap_with_doc(_swapped_binop(operator.and_), "jax.Array.__rand__") __ror__ = _nmap_with_doc(_swapped_binop(operator.or_), "jax.Array.__ror__") __rxor__ = _nmap_with_doc(_swapped_binop(operator.xor), "jax.Array.__rxor__") @@ -1410,7 +1396,7 @@ class NamedArray(NamedArrayBase, struct.Struct): """ named_axes: collections.OrderedDict[AxisName, int] = dataclasses.field( - metadata={"pytree_node": False} + metadata={"pytree_node": False} ) data_array: jax.Array @@ -1434,7 +1420,7 @@ def wrap(cls, array: jax.typing.ArrayLike, *names: AxisName) -> NamedArray: shape. 
""" wrapped = NamedArray( - named_axes=collections.OrderedDict(), data_array=jnp.asarray(array) + named_axes=collections.OrderedDict(), data_array=jnp.asarray(array) ) if names: return wrapped.tag(*names) @@ -1446,39 +1432,34 @@ def dtype(self) -> np.dtype: return self.data_array.dtype def check_valid(self) -> None: - if not hasattr(self.data_array, "shape") or not hasattr( - self.data_array, "dtype" - ): + if not hasattr(self.data_array, "shape") or not hasattr(self.data_array, "dtype"): raise ValueError( - "NamedArray.data_array must contain a jax or numpy array (or at least" - f" something with a shape and dtype), not {type(self.data_array)}" + "NamedArray.data_array must contain a jax or numpy array (or at least" + f" something with a shape and dtype), not {type(self.data_array)}" ) if not isinstance(self.named_axes, collections.OrderedDict) or not all( - isinstance(size, int) for size in self.named_axes.values() + isinstance(size, int) for size in self.named_axes.values() ): raise ValueError( - "NamedArray.named_axes must be an ordered dictionary of named" - " axis shapes" + "NamedArray.named_axes must be an ordered dictionary of named axis shapes" ) if any(isinstance(name, int) for name in self.named_axes.keys()): raise ValueError( - "Integers are not allowed as axis names, to avoid confusion with" - " positional axis indices." + "Integers are not allowed as axis names, to avoid confusion with" + " positional axis indices." ) true_suffix_shape = tuple( - self.data_array.shape[ - len(self.data_array.shape) - len(self.named_axes) : - ] + self.data_array.shape[len(self.data_array.shape) - len(self.named_axes) :] ) if true_suffix_shape != tuple(self.named_axes.values()): raise ValueError( - "The axis sizes in `named_axes` must exactly match a suffix " - " of the data array's shape, but" - f" {tuple(self.named_axes.values())} was not a suffix of" - f" {self.data_array.shape}" + "The axis sizes in `named_axes` must exactly match a suffix " + " of the data array's shape, but" + f" {tuple(self.named_axes.values())} was not a suffix of" + f" {self.data_array.shape}" ) @property @@ -1495,21 +1476,21 @@ def unwrap(self, *names) -> jax.Array: if names: if self.positional_shape: raise ValueError( - "Cannot unwrap a NamedArray by providing an axis name ordering if" - " it already has a positional shape. For advanced axis name" - " manipulation, try using `untag` and `tag` directly." + "Cannot unwrap a NamedArray by providing an axis name ordering if" + " it already has a positional shape. For advanced axis name" + " manipulation, try using `untag` and `tag` directly." ) name_bound = self.untag(*names) if name_bound.named_shape: raise ValueError( - "When calling `unwrap` with axis names, a position must be given" - f" for every axis name. Unassigned names: {name_bound.named_shape}" + "When calling `unwrap` with axis names, a position must be given" + f" for every axis name. Unassigned names: {name_bound.named_shape}" ) return name_bound.unwrap() if self.named_axes: raise ValueError( - "To unwrap a NamedArray with nonempty named shape, an ordering for" - f" its named axes must be provided. Named shape: {self.named_axes}" + "To unwrap a NamedArray with nonempty named shape, an ordering for" + f" its named axes must be provided. 
Named shape: {self.named_axes}" ) return self.data_array @@ -1522,15 +1503,15 @@ def as_namedarrayview(self) -> NamedArrayView: positional_axis_count = len(self.data_array.shape) - len(self.named_axes) data_axis_for_name = { - name: index + positional_axis_count - for index, name in enumerate(self.named_axes.keys()) + name: index + positional_axis_count + for index, name in enumerate(self.named_axes.keys()) } data_axis_for_logical_axis = tuple(range(positional_axis_count)) return NamedArrayView( - data_shape=self.data_array.shape, - data_axis_for_logical_axis=data_axis_for_logical_axis, - data_axis_for_name=data_axis_for_name, - data_array=self.data_array, + data_shape=self.data_array.shape, + data_axis_for_logical_axis=data_axis_for_logical_axis, + data_axis_for_name=data_axis_for_name, + data_array=self.data_array, ) def untag(self, *axis_order: AxisName) -> NamedArray | NamedArrayView: @@ -1569,14 +1550,12 @@ class NamedArrayView(NamedArrayBase, struct.Struct): data_array: The underlying positional-indexed array. """ - data_shape: tuple[int, ...] = dataclasses.field( - metadata={"pytree_node": False} - ) + data_shape: tuple[int, ...] = dataclasses.field(metadata={"pytree_node": False}) data_axis_for_logical_axis: tuple[int, ...] = dataclasses.field( - metadata={"pytree_node": False} + metadata={"pytree_node": False} ) data_axis_for_name: dict[AxisName, int] = dataclasses.field( - metadata={"pytree_node": False} + metadata={"pytree_node": False} ) data_array: jax.Array @@ -1588,22 +1567,20 @@ def dtype(self) -> np.dtype: def check_valid(self) -> None: # Data array has a shape. - if not hasattr(self.data_array, "shape") or not hasattr( - self.data_array, "dtype" - ): + if not hasattr(self.data_array, "shape") or not hasattr(self.data_array, "dtype"): raise ValueError( - "NamedArrayView.data_array must contain a jax or numpy array (or at" - " least something with a shape and dtype), not" - f" {type(self.data_array)}" + "NamedArrayView.data_array must contain a jax or numpy array (or at" + " least something with a shape and dtype), not" + f" {type(self.data_array)}" ) # Data shape is valid. if self.data_shape != self.data_array.shape: raise ValueError( - f"Expected data_array to have shape {self.data_shape}, but it has" - f" shape {self.data_array.shape}. Modifying the shape of the data" - " array of a NamedArrayView directly is not allowed; use `nmap`" - " instead, or call `with_positional_prefix` if you need to" - " manipulate the positional axes as prefix axes." + f"Expected data_array to have shape {self.data_shape}, but it has" + f" shape {self.data_array.shape}. Modifying the shape of the data" + " array of a NamedArrayView directly is not allowed; use `nmap`" + " instead, or call `with_positional_prefix` if you need to" + " manipulate the positional axes as prefix axes." ) # Every axis appears exactly once. seen_axes = collections.Counter() @@ -1611,28 +1588,27 @@ def check_valid(self) -> None: seen_axes.update(self.data_axis_for_name.values()) if seen_axes != collections.Counter(range(len(self.data_shape))): raise ValueError( - "Every axis index into `data_shape` must appear exactly once in" - " either `data_axis_for_logical_axis` or `data_axis_for_name`." + "Every axis index into `data_shape` must appear exactly once in" + " either `data_axis_for_logical_axis` or `data_axis_for_name`." ) # Check for bad names. 
if any(isinstance(name, int) for name in self.data_axis_for_name.keys()): raise ValueError( - "Integers are not allowed as axis names, to avoid confusion with" - " positional axis indices." + "Integers are not allowed as axis names, to avoid confusion with" + " positional axis indices." ) @property def named_shape(self) -> Mapping[AxisName, int]: return { - name: self.data_shape[data_axis] - for name, data_axis in self.data_axis_for_name.items() + name: self.data_shape[data_axis] + for name, data_axis in self.data_axis_for_name.items() } @property def positional_shape(self) -> tuple[int, ...]: return tuple( - self.data_shape[data_axis] - for data_axis in self.data_axis_for_logical_axis + self.data_shape[data_axis] for data_axis in self.data_axis_for_logical_axis ) def unwrap(self, *names) -> jax.Array: @@ -1640,22 +1616,22 @@ def unwrap(self, *names) -> jax.Array: if names: if self.positional_shape: raise ValueError( - "Cannot unwrap a NamedArrayView by providing an axis name ordering" - " if it already has a positional shape. For advanced axis name" - " manipulation, try using `untag` and `tag` directly." + "Cannot unwrap a NamedArrayView by providing an axis name ordering" + " if it already has a positional shape. For advanced axis name" + " manipulation, try using `untag` and `tag` directly." ) name_bound = self.untag(*names) if name_bound.named_shape: raise ValueError( - "When calling `unwrap` with axis names, a position must be given" - f" for every axis name. Unassigned names: {name_bound.named_shape}" + "When calling `unwrap` with axis names, a position must be given" + f" for every axis name. Unassigned names: {name_bound.named_shape}" ) return name_bound.unwrap() if self.named_shape: raise ValueError( - "To unwrap a NamedArrayView with nonempty named shape, an ordering" - " for its named axes must be provided. Named shape:" - f" {self.named_shape}" + "To unwrap a NamedArrayView with nonempty named shape, an ordering" + " for its named axes must be provided. Named shape:" + f" {self.named_shape}" ) # with_positional_prefix will perform the necessary transpositions, we can # then simply unwrap the resulting NamedArray. @@ -1681,7 +1657,7 @@ def with_positional_prefix(self) -> NamedArray: transposition.append(data_axis) # Then the named axes in sorted order (to try to avoid transposition) data_axes_and_names = sorted( - (data_axis, name) for name, data_axis in self.data_axis_for_name.items() + (data_axis, name) for name, data_axis in self.data_axis_for_name.items() ) for data_axis, name in data_axes_and_names: transposition.append(data_axis) @@ -1692,8 +1668,8 @@ def with_positional_prefix(self) -> NamedArray: return NamedArray(data_array=self.data_array, named_axes=named_axes) else: return NamedArray( - data_array=self.data_array.transpose(transposition), - named_axes=named_axes, + data_array=self.data_array.transpose(transposition), + named_axes=named_axes, ) def as_namedarrayview(self) -> NamedArrayView: @@ -1704,10 +1680,10 @@ def untag(self, *axis_order: AxisName) -> NamedArray | NamedArrayView: self.check_valid() if self.data_axis_for_logical_axis: raise ValueError( - "`untag` cannot be used to introduce positional axes for a" - " NamedArray (or NamedArrayView) that already has positional axes." - " Please assign names to the existing positional axes first using" - " `tag`." + "`untag` cannot be used to introduce positional axes for a" + " NamedArray (or NamedArrayView) that already has positional axes." 
+ " Please assign names to the existing positional axes first using" + " `tag`." ) if not axis_order: @@ -1724,7 +1700,7 @@ def untag(self, *axis_order: AxisName) -> NamedArray | NamedArrayView: bad_names = requested_axis_set.difference(actual_axis_set) if bad_names: raise ValueError( - f"Requested axis names {bad_names} are not present in the array." + f"Requested axis names {bad_names} are not present in the array." ) # Build a view. @@ -1735,25 +1711,25 @@ def untag(self, *axis_order: AxisName) -> NamedArray | NamedArrayView: del data_axis_for_name[name] return NamedArrayView( - data_shape=self.data_shape, - data_axis_for_logical_axis=tuple(data_axis_for_logical_axis), - data_axis_for_name=data_axis_for_name, - data_array=self.data_array, + data_shape=self.data_shape, + data_axis_for_logical_axis=tuple(data_axis_for_logical_axis), + data_axis_for_name=data_axis_for_name, + data_array=self.data_array, ) def tag(self, *names) -> NamedArray: self.check_valid() if len(names) != len(self.data_axis_for_logical_axis): raise ValueError( - "There must be exactly as many names given to `tag` as there" - f" are positional axes in the array, but got {names} for positional" - f" shape {self.positional_shape}" + "There must be exactly as many names given to `tag` as there" + f" are positional axes in the array, but got {names} for positional" + f" shape {self.positional_shape}" ) if any(isinstance(name, int) for name in names): raise ValueError( - "Integers are not allowed as axis names, to avoid confusion with" - " positional axis indices." + "Integers are not allowed as axis names, to avoid confusion with" + " positional axis indices." ) seen_axes = collections.Counter() @@ -1762,9 +1738,9 @@ def tag(self, *names) -> NamedArray: repeated_names = [name for name, count in seen_axes.items() if count > 1] if repeated_names: raise ValueError( - "Repeated axis names are not allowed; original names were" - f" {tuple(self.data_axis_for_name.keys())} and new names passed to" - f" tag were {names}; repeated: {repeated_names}" + "Repeated axis names are not allowed; original names were" + f" {tuple(self.data_axis_for_name.keys())} and new names passed to" + f" tag were {names}; repeated: {repeated_names}" ) names_by_index = {} @@ -1775,11 +1751,10 @@ def tag(self, *names) -> NamedArray: names_by_index[data_axis] = names[i] return NamedArray( - data_array=self.data_array, - named_axes=collections.OrderedDict([ - (names_by_index[i], self.data_shape[i]) - for i in range(len(self.data_shape)) - ]), + data_array=self.data_array, + named_axes=collections.OrderedDict( + [(names_by_index[i], self.data_shape[i]) for i in range(len(self.data_shape))] + ), ) @@ -1793,9 +1768,9 @@ def is_namedarray(value) -> typing.TypeGuard[NamedArrayBase]: def full( - named_shape: Mapping[AxisName, int], - fill_value: jax.typing.ArrayLike, - dtype: np.DTypeLike | None = None, + named_shape: Mapping[AxisName, int], + fill_value: jax.typing.ArrayLike, + dtype: np.DTypeLike | None = None, ) -> NamedArray: """Constructs a full named array with a given shape. @@ -1809,14 +1784,14 @@ def full( NamedArray with the given named shape, filled with ``fill_value``. 
""" return NamedArray( - named_axes=collections.OrderedDict(named_shape), - data_array=jnp.full(tuple(named_shape.values()), fill_value, dtype=dtype), + named_axes=collections.OrderedDict(named_shape), + data_array=jnp.full(tuple(named_shape.values()), fill_value, dtype=dtype), ) def zeros( - named_shape: Mapping[AxisName, int], - dtype: np.DTypeLike | None = None, + named_shape: Mapping[AxisName, int], + dtype: np.DTypeLike | None = None, ) -> NamedArray: """Constructs a named array of zeros with a given shape. @@ -1828,14 +1803,14 @@ def zeros( NamedArray with the given named shape, filled with zeros. """ return NamedArray( - named_axes=collections.OrderedDict(named_shape), - data_array=jnp.zeros(tuple(named_shape.values()), dtype), + named_axes=collections.OrderedDict(named_shape), + data_array=jnp.zeros(tuple(named_shape.values()), dtype), ) def ones( - named_shape: Mapping[AxisName, int], - dtype: np.DTypeLike | None = None, + named_shape: Mapping[AxisName, int], + dtype: np.DTypeLike | None = None, ) -> NamedArray: """Constructs a named array of ones with a given shape. @@ -1847,17 +1822,17 @@ def ones( NamedArray with the given named shape, filled with ones. """ return NamedArray( - named_axes=collections.OrderedDict(named_shape), - data_array=jnp.ones(tuple(named_shape.values()), dtype), + named_axes=collections.OrderedDict(named_shape), + data_array=jnp.ones(tuple(named_shape.values()), dtype), ) def arange( - name: str, - start: int, - stop: int | None = None, - step: int | None = None, - dtype: jax.typing.DTypeLike | None = None, + name: str, + start: int, + stop: int | None = None, + step: int | None = None, + dtype: jax.typing.DTypeLike | None = None, ) -> NamedArray: """Convenience function to create a range along a named axis. @@ -1879,8 +1854,8 @@ def arange( def random_split( - key: jax.Array | NamedArrayBase, - named_shape: Mapping[AxisName, int] | Sequence[tuple[AxisName, int]], + key: jax.Array | NamedArrayBase, + named_shape: Mapping[AxisName, int] | Sequence[tuple[AxisName, int]], ) -> NamedArray | NamedArrayView: """Splits a PRNG key into a `NamedArray` of PRNG keys with the given names. @@ -1910,9 +1885,9 @@ def random_split( names = sorted(unsorted_keys) except Exception as exc: raise ValueError( - "Unordered mappings must have sortable axis names when using" - " `random_split`. If necessary, you can specify a particular ordering" - " using a collections.OrderedDict or a tuple of (name, size) pairs." + "Unordered mappings must have sortable axis names when using" + " `random_split`. If necessary, you can specify a particular ordering" + " using a collections.OrderedDict or a tuple of (name, size) pairs." ) from exc sizes = [named_shape[name] for name in names] @@ -1921,12 +1896,12 @@ def random_split( flat_split_keys = nmap(jax.random.split)(key, total_size) return flat_split_keys.reshape( - tuple(sizes) + flat_split_keys.positional_shape[1:] + tuple(sizes) + flat_split_keys.positional_shape[1:] ).tag_prefix(*names) def concatenate( - arrays: Sequence[NamedArrayBase], axis_name: AxisName + arrays: Sequence[NamedArrayBase], axis_name: AxisName ) -> NamedArray | NamedArrayView: """Concatenates a sequence of named arrays along a named axis. 
@@ -1942,21 +1917,20 @@ def concatenate( ndims = set(len(array.positional_shape) for array in arrays) if len(ndims) != 1: raise ValueError( - "All arrays must have the same number of positional axes, but got" - f" {ndims}" + f"All arrays must have the same number of positional axes, but got {ndims}" ) (ndim,) = ndims orig_positional_axes = [TmpPosAxisMarker() for _ in range(ndim)] arrays_along_axis = [ - array.tag(*orig_positional_axes).untag(axis_name) for array in arrays + array.tag(*orig_positional_axes).untag(axis_name) for array in arrays ] concatenated = nmap(jnp.concatenate)(arrays_along_axis) return concatenated.tag(axis_name).untag(*orig_positional_axes) def stack( - arrays: Sequence[NamedArrayBase], axis_name: AxisName + arrays: Sequence[NamedArrayBase], axis_name: AxisName ) -> NamedArray | NamedArrayView: """Stacks a sequence of named arrays along a named axis. @@ -1970,8 +1944,7 @@ def stack( ndims = set(len(array.positional_shape) for array in arrays) if len(ndims) != 1: raise ValueError( - "All arrays must have the same number of positional axes, but got" - f" {ndims}" + f"All arrays must have the same number of positional axes, but got {ndims}" ) (ndim,) = ndims @@ -1982,7 +1955,7 @@ def stack( def unstack( - array: NamedArrayBase, axis_name: AxisName + array: NamedArrayBase, axis_name: AxisName ) -> Sequence[NamedArray | NamedArrayView]: """Splits a named array across a given named axis. @@ -2023,13 +1996,11 @@ def _fix(val, ref): else: return val - return jax.tree_util.tree_map( - _fix, value_tree, reference_tree, is_leaf=is_namedarray - ) + return jax.tree_util.tree_map(_fix, value_tree, reference_tree, is_leaf=is_namedarray) def scan( - f: Callable[[Any, Any], Any], axis: AxisName, init, xs=None, **scan_kwargs + f: Callable[[Any, Any], Any], axis: AxisName, init, xs=None, **scan_kwargs ) -> Any: """Scan a function over a named array axis while carrying along state. @@ -2092,17 +2063,15 @@ def wrapped_f(carry, x): new_carry, y = f(carry, x) new_carry = order_like(new_carry, carry) y = jax.tree_util.tree_map( - lambda v: v.with_positional_prefix() if is_namedarray(v) else v, - y, - is_leaf=is_namedarray, + lambda v: v.with_positional_prefix() if is_namedarray(v) else v, + y, + is_leaf=is_namedarray, ) return new_carry, y # Run the scan, which will slice off the positional prefix from the inputs, # and add a positional prefix to the outputs. - final_carry, ys_untagged = jax.lax.scan( - wrapped_f, init, xs_untagged, **scan_kwargs - ) + final_carry, ys_untagged = jax.lax.scan(wrapped_f, init, xs_untagged, **scan_kwargs) # Re-assign the scanned-over axis. 
def _retag(leaf):
diff --git a/penzai/nn/linear_and_affine.py b/penzai/nn/linear_and_affine.py
index f696faf..c7ec95e 100644
--- a/penzai/nn/linear_and_affine.py
+++ b/penzai/nn/linear_and_affine.py
@@ -15,13 +15,11 @@
 """Generalized linear operator layer and associated utilities."""

 from __future__ import annotations
-
 import collections
 import dataclasses
 import functools
 import itertools
-from typing import Any, Literal, Protocol, Sequence
-
+from typing import Any, Literal, Protocol, Sequence, cast
 import jax
 import jax.numpy as jnp
 from penzai.core import named_axes
@@ -161,18 +160,24 @@ def variance_scaling_initializer(
   return named_axes.wrap(array).tag(*names)


-xavier_uniform_initializer = functools.partial(
-    variance_scaling_initializer,
-    scale=1.0,
-    mode="fan_avg",
-    distribution="uniform",
+xavier_uniform_initializer = cast(
+  LinearOperatorWeightInitializer,
+  functools.partial(
+    variance_scaling_initializer,
+    scale=1.0,
+    mode="fan_avg",
+    distribution="uniform",
+  ),
 )
-xavier_normal_initializer = functools.partial(
-    variance_scaling_initializer,
-    scale=1.0,
-    mode="fan_avg",
-    distribution="normal",
+xavier_normal_initializer = cast(
+  LinearOperatorWeightInitializer,
+  functools.partial(
+    variance_scaling_initializer,
+    scale=1.0,
+    mode="fan_avg",
+    distribution="normal",
+  ),
 )


@@ -247,6 +252,37 @@ def treescope_color(self) -> tuple[str, str]:
     return "#eba875", "color-mix(in oklab, #eba875 25%, white)"


+@struct.pytree_dataclass
+class ConvInPlace(grouping.Sequential):
+  """Container for "in-place" convolution operators that preserve axis names.
+
+  This is used when initializing `Conv` layers that have overlapping names in
+  their input and output specifications. We subclass `Sequential` to make
+  this layer type easier to identify and manipulate.
+  """
+
+  sublayers: list[layer_base.Layer]
+
+  def treescope_color(self) -> tuple[str, str]:
+    return "#79eb75", "color-mix(in oklab, #79eb75 25%, white)"
+
+
+@struct.pytree_dataclass
+class ConvTransposeInPlace(grouping.Sequential):
+  """Container for "in-place" transposed convolutions that preserve axis names.
+
+  This is used when initializing `ConvTranspose` layers that have overlapping
+  names in their input and output specifications. We subclass `Sequential` to
+  make this layer type easier to identify and manipulate.
+  """
+
+  sublayers: list[layer_base.Layer]
+
+  def treescope_color(self) -> tuple[str, str]:
+    return "#c7eb75", "color-mix(in oklab, #c7eb75 25%, white)"
+
+
 def contract(
     names: str | Sequence[named_axes.AxisName],
     left: NamedArray,
@@ -395,12 +431,63 @@ def _output_structure(self) -> shapecheck.StructureAnnotation:
     )


+def maybe_rename_output_axes(
+  input_axes: dict[str, int],
+  output_axes: dict[str, int],
+  parallel_axes: dict[str, int],
+  parallel_broadcast_axes: dict[str, int],
+  rename_outputs_if_necessary: bool,
+):
+  """Renames output axes that collide with input axes.
+
+  Returns the (possibly renamed) output axes, along with the lists of primed
+  and original names (both None if no rename was needed).
+  """
+  # By default, no rename and no wrapping.
+  output_axes_after_rename = output_axes
+  primed_names, original_names = None, None
+
+  if any(name in input_axes for name in output_axes):
+    # Name overlap!
+ if rename_outputs_if_necessary: + output_axes_after_rename = {} + original_names = [] + primed_names = [] + + for old_name in output_axes.keys(): + if old_name in input_axes: + primed_name = old_name + "_out" + if primed_name in input_axes: + raise ValueError( + f"Tried to rename {old_name} to {primed_name} to avoid a" + " conflict, but both names are already in input_axes. Please" + " rename axes manually to avoid this conflict." + ) + original_names.append(old_name) + primed_names.append(primed_name) + output_axes_after_rename[primed_name] = output_axes[old_name] + else: + output_axes_after_rename[old_name] = output_axes[old_name] + else: + raise ValueError( + "input_axes and output_axes must not overlap if" + " rename_outputs_if_necessary is not set; got" + f" input_axes={input_axes}, output_axes={output_axes}." + ) + + if set(parallel_axes).intersection(set(input_axes).union(output_axes)) or set( + parallel_broadcast_axes + ).intersection(set(input_axes).union(output_axes, parallel_axes)): + raise ValueError( + "parallel_axes and parallel_broadcast_axes must not overlap with" + f" each other or with input/output axes; got input_axes={input_axes}," + f" output_axes={output_axes}, parallel_axes={parallel_axes}," + f" parallel_broadcast_axes={parallel_broadcast_axes}." + ) + return output_axes_after_rename, primed_names, original_names + + @struct.pytree_dataclass class Linear(layer_base.Layer): """A generalized linear (not affine) operator, for named arrays. Applies an arbitrary contraction to the input `NamedArray` and a weight - parameter. This can be used to express an arbitrary linear operator. + parameter. This can be used to express an arbitrary dense linear operator. ``Linear`` layers are often (but not always) followed by `AddBias` to make an affine transformation. @@ -504,80 +591,45 @@ def from_config( parallel_axes = {} if parallel_broadcast_axes is None: parallel_broadcast_axes = {} - if any(name in input_axes for name in output_axes): - # Name overlap! - if rename_outputs_if_necessary: - output_axes_after_rename = {} - original_names = [] - primed_names = [] - - for old_name in output_axes.keys(): - if old_name in input_axes: - primed_name = old_name + "_out" - if primed_name in input_axes: - raise ValueError( - f"Tried to rename {old_name} to {primed_name} to avoid a" - " conflict, but both names are already in input_axes. Please" - " rename axes manually to avoid this conflict." - ) - original_names.append(old_name) - primed_names.append(primed_name) - output_axes_after_rename[primed_name] = output_axes[old_name] - else: - output_axes_after_rename[old_name] = output_axes[old_name] - - return LinearInPlace( - sublayers=[ - cls.from_config( - name=name, - init_base_rng=init_base_rng, - input_axes=input_axes, - output_axes=output_axes_after_rename, - parallel_axes=parallel_axes, - parallel_broadcast_axes=parallel_broadcast_axes, - initializer=initializer, - dtype=dtype, - rename_outputs_if_necessary=False, - ), - RenameAxes(old=tuple(primed_names), new=tuple(original_names)), - ], - ) - else: - raise ValueError( - "input_axes and output_axes must not overlap if" - " rename_outputs_if_necessary is not set; got" - f" input_axes={input_axes}, output_axes={output_axes}." 
-      )
-
-    if set(parallel_axes).intersection(
-        set(input_axes).union(output_axes)
-    ) or set(parallel_broadcast_axes).intersection(
-        set(input_axes).union(output_axes, parallel_axes)
-    ):
-      raise ValueError(
-          "parallel_axes and parallel_broadcast_axes must not overlap with"
-          f" each other or with input/output axes; got input_axes={input_axes},"
-          f" output_axes={output_axes}, parallel_axes={parallel_axes},"
-          f" parallel_broadcast_axes={parallel_broadcast_axes}."
-      )
+    output_axes_after_rename, primed_names, original_names = (
+      maybe_rename_output_axes(
+        input_axes,
+        output_axes,
+        parallel_axes,
+        parallel_broadcast_axes,
+        rename_outputs_if_necessary,
+      )
+    )

-    return cls(
+    core_layer = cls(
         weights=parameters.make_parameter(
             f"{name}.weights",
             init_base_rng,
             initializer,
             input_axes=input_axes,
-            output_axes=output_axes,
+            output_axes=output_axes_after_rename,
             parallel_axes={**parallel_axes, **parallel_broadcast_axes},
             convolution_spatial_axes={},
             dtype=dtype,
         ),
         in_axis_names=tuple(input_axes.keys()),
         out_axis_names=(
-            tuple(output_axes.keys()) + tuple(parallel_broadcast_axes.keys())
+            tuple(output_axes_after_rename.keys())
+            + tuple(parallel_broadcast_axes.keys())
         ),
     )

+    # If the names overlap, wrap the layer so the outputs are renamed back.
+    if primed_names is not None and original_names is not None:
+      return LinearInPlace(
+        sublayers=[
+          core_layer,
+          RenameAxes(old=tuple(primed_names), new=tuple(original_names)),
+        ],
+      )
+    return core_layer
+
   def _input_structure(self):
     known_in_axes = {
         name: size
@@ -771,3 +823,584 @@ class ConstantRescale(layer_base.Layer):
   def __call__(self, value: Any, **_unused_side_inputs) -> Any:
     """Scales its input by the scaling factor."""
     return jax.tree_util.tree_map(lambda x: x * self.by, value)
+
+
+def prepare_for_conv(
+  inputs: NamedArray,
+  kernel: NamedArray,
+  spatial_axis_names: Sequence[str],
+  in_axis_names: Sequence[str],
+  out_axis_names: Sequence[str],
+):
+  """Preprocesses the input and kernel into positional lhs/rhs arrays for the
+  JAX convolution operator."""
+
+  lhs = inputs
+  rhs = kernel
+
+  in_axis_name = "in_axis-" + "-".join(in_axis_names)
+  out_axis_name = "out_axis-" + "-".join(out_axis_names)
+
+  # Merge the input axes into a single input-channel axis on both the inputs
+  # and the kernel.
+  lhs = lhs.untag(*in_axis_names).flatten().tag(in_axis_name)
+  rhs = rhs.untag(*in_axis_names).flatten().tag(in_axis_name)
+
+  # Merge the output axes into a single output-channel axis for the JAX
+  # convolution.
+  rhs = rhs.untag(*out_axis_names).flatten().tag(out_axis_name)
+
+  # Expose the spatial (and channel) axes positionally.
+  lhs = lhs.untag(*spatial_axis_names, in_axis_name)
+  rhs = rhs.untag(*spatial_axis_names, in_axis_name, out_axis_name)
+  return lhs, rhs
+
+
+def get_named_axis_back_after_conv(
+  result: NamedArray,
+  spatial_axis_names: Sequence[str],
+  out_axis_names: Sequence[str],
+  out_axis_shape: Sequence[int],
+):
+  """Postprocesses the result of the JAX convolution operator back into a
+  NamedArray."""
+  # Re-attach the named axes, splitting the merged output-channel axis.
+  return (
+    result.tag_prefix(*spatial_axis_names)
+    .reshape(out_axis_shape)
+    .tag(*out_axis_names)
+  )
+
+
+def maybe_broadcast(value: int | Sequence[int], count: int):
+  """Broadcasts an integer to a list of length ``count``; sequences pass
+  through unchanged."""
+  return [value] * count if isinstance(value, int) else value
+
+
+def get_dimension_numbers(ndim):
+  return jax.lax.ConvDimensionNumbers(
+    lhs_spec=(0, ndim + 1)
+    + tuple(range(1, ndim + 1)),  # BSpatialC -> BCSpatial
+    rhs_spec=(ndim + 1, ndim) + tuple(range(ndim)),  # SpatialIO -> OISpatial
+    out_spec=(0, ndim + 1)
+    + tuple(range(1, ndim + 1)),  # BSpatialC -> BCSpatial
+  )
+
+
+@struct.pytree_dataclass
+class AbstractGeneralConv(layer_base.Layer):
+  """Shared structure and shape logic for `Conv` and `ConvTranspose`."""
+
+  kernel: parameters.ParameterLike[NamedArray]
+  strides: Sequence[int]
+  padding: str | Sequence[tuple[int, int]]
+  kernel_dilation: Sequence[int]
+
+  spatial_axis_names: tuple[str, ...] = dataclasses.field(
+    metadata={"pytree_node": False}
+  )
+  in_axis_names: tuple[str, ...] = dataclasses.field(
+    metadata={"pytree_node": False}
+  )
+  out_axis_names: tuple[str, ...] = dataclasses.field(
+    metadata={"pytree_node": False}
+  )
+
+  def _input_structure(self):
+    known_in_axes = {
+      name: size
+      for name, size in self.kernel.value.named_shape.items()
+      if name not in self.out_axis_names
+      and name not in self.spatial_axis_names
+    }
+    return shapecheck.ArraySpec(
+      named_shape={**shapecheck.var("B"), **known_in_axes},
+      dtype=jnp.floating,
+    )
+
+  def _output_structure(self):
+    known_out_axes = {
+      name: size
+      for name, size in self.kernel.value.named_shape.items()
+      if name not in self.in_axis_names
+      and name not in self.spatial_axis_names
+    }
+    return shapecheck.ArraySpec(
+      named_shape={**shapecheck.var("B"), **known_out_axes},
+      dtype=jnp.floating,
+    )
+
+  @property
+  def input_axes(self) -> dict[str, int]:
+    """The axis names and sizes that should appear in the input only."""
+    return {  # pytype: disable=bad-return-type
+      name: size
+      for name, size in self.kernel.value.named_shape.items()
+      if name in self.in_axis_names
+    }
+
+  @property
+  def output_axes(self) -> dict[str, int]:
+    """The axis names and sizes that will appear in the output only."""
+    return {  # pytype: disable=bad-return-type
+      name: size
+      for name, size in self.kernel.value.named_shape.items()
+      if name in self.out_axis_names
+    }
+
+  @property
+  def parallel_axes(self) -> dict[str, int]:
+    """The axis names and sizes that should appear in both input and output."""
+    return {  # pytype: disable=bad-return-type
+      name: size
+      for name, size in self.kernel.value.named_shape.items()
+      if name not in self.spatial_axis_names
+      and name not in self.in_axis_names
+      and name not in self.out_axis_names
+    }
+
+  @property
+  def convolution_spatial_axes(self) -> dict[str, int]:
+    """The spatial axis names and sizes of the convolution kernel, which
+    appear in both input and output. Note that the sizes refer to the kernel
+    shape, not to the input."""
+    return {  # pytype: disable=bad-return-type
+      name: size
+      for name, size in self.kernel.value.named_shape.items()
+      if name in self.spatial_axis_names
+    }
+
+
+@struct.pytree_dataclass
+class Conv(AbstractGeneralConv):
+  """A general convolution operator, for named arrays.
+
+  Applies a convolution between the input `NamedArray` and a kernel
+  parameter. This can be used to express an arbitrary linear convolution
+  operator.
+
+  Attributes:
+    kernel: The named array holding the kernel for the convolution operator.
+    strides: The strides of the convolution operator.
+    padding: The padding to apply to the input before the convolution.
+    inputs_dilation: The input dilation of the convolution.
+    kernel_dilation: The kernel dilation of the convolution.
+    spatial_axis_names: The names of the spatial axes over which to apply the
+      convolution operator.
+    in_axis_names: The names of the axes to contract with the input, removing
+      them.
+    out_axis_names: The names of the axes that should not appear in the input
+      and will be inserted into the output.
+  """
+
+  kernel: parameters.ParameterLike[NamedArray]
+  strides: Sequence[int]
+  padding: str | Sequence[tuple[int, int]]
+  kernel_dilation: Sequence[int]
+
+  spatial_axis_names: tuple[str, ...] = dataclasses.field(
+    metadata={"pytree_node": False}
+  )
+  in_axis_names: tuple[str, ...] = dataclasses.field(
+    metadata={"pytree_node": False}
+  )
+  out_axis_names: tuple[str, ...] = dataclasses.field(
+    metadata={"pytree_node": False}
+  )
+  inputs_dilation: Sequence[int]
+
+  def __call__(self, in_array: NamedArray, **_side_inputs) -> NamedArray:
+    """Runs the convolution operator."""
+    in_struct = self._input_structure()
+
+    # pytype: disable=attribute-error
+    if isinstance(
+      self.kernel,
+      Parameter | ParameterValue,
+    ) and self.kernel.label.endswith(".kernel"):
+      error_prefix = f"({self.kernel.label[:-7]}) "
+    else:
+      error_prefix = ""
+    # pytype: enable=attribute-error
+
+    dimvars = shapecheck.check_structure(
+      in_array, in_struct, error_prefix=error_prefix
+    )
+
+    lhs, rhs = prepare_for_conv(
+      in_array,
+      self.kernel.value,
+      self.spatial_axis_names,
+      self.in_axis_names,
+      self.out_axis_names,
+    )
+
+    # Perform the actual convolution.
+    result = named_axes.nmap(
+      lambda lhs, rhs: jax.lax.conv_general_dilated(
+        lhs=lhs[None, ...],
+        rhs=rhs,
+        window_strides=self.strides,
+        padding=self.padding,
+        lhs_dilation=self.inputs_dilation,
+        rhs_dilation=self.kernel_dilation,
+        dimension_numbers=get_dimension_numbers(
+          ndim=len(self.spatial_axis_names)
+        ),
+      )[0]
+    )(lhs, rhs)
+
+    result = get_named_axis_back_after_conv(
+      result,
+      self.spatial_axis_names,
+      self.out_axis_names,
+      [self.output_axes[name] for name in self.out_axis_names],
+    )
+
+    out_struct = self._output_structure()
+    shapecheck.check_structure(
+      result, out_struct, known_vars=dimvars, error_prefix=error_prefix
+    )
+    return result
+
+  @classmethod
+  def from_config(
+    cls,
+    name: str,
+    init_base_rng: jax.Array | None,
+    input_axes: dict[str, int],
+    output_axes: dict[str, int],
+    convolution_spatial_axes: dict[str, int],
+    strides: int | Sequence[int] = 1,
+    padding: str | Sequence[tuple[int, int]] = "SAME",
+    inputs_dilation: int | Sequence[int] = 1,
+    kernel_dilation: int | Sequence[int] = 1,
+    parallel_axes: dict[str, int] | None = None,
+    parallel_broadcast_axes: dict[str, int] | None = None,
+    initializer: LinearOperatorWeightInitializer = xavier_uniform_initializer,
+    dtype: jax.typing.DTypeLike = jnp.float32,
+    rename_outputs_if_necessary: bool = True,
+  ) -> Conv | ConvInPlace:
+    """Constructs a ``Conv`` layer from a configuration.
+
+    This can be used when building a new convolution operator at the start of
+    training.
+
+    Note: For the purposes of the initializer, the ``parallel_axes`` and
+    ``parallel_broadcast_axes`` are treated in the same way, without
+    participating in output-dimension variance scaling. However, after
+    initialization, the ``parallel_broadcast_axes`` will be treated like extra
+    output axes (and assumed not to be present in the input).
+
+    Args:
+      name: The name of the layer.
+      init_base_rng: The base RNG to use for initializing model parameters.
+      input_axes: Names and lengths for axes that the convolution should
+        contract over.
+      output_axes: Names and lengths for new axes that the convolution should
+        produce. If any axis names overlap with ``input_axes``, the argument
+        ``rename_outputs_if_necessary`` must be True.
+      convolution_spatial_axes: Names and lengths for the spatial axes of the
+        convolution kernel.
+      strides: Strides of the convolution. If an integer, it is broadcast to
+        every spatial dimension.
+      padding: The padding to apply to the input before the convolution. 
Can be
+        either the strings ``"SAME"``, ``"SAME_LOWER"``, or ``"VALID"``, or a
+        sequence of n (low, high) integer pairs giving the padding to apply
+        before and after each spatial dimension. ``"SAME"`` and
+        ``"SAME_LOWER"`` add padding to produce the same output size as the
+        input when the stride is one. The padding is split between the two
+        sides equally or almost equally; if the total padding is odd, the
+        extra padding is added at the end for ``"SAME"`` and at the beginning
+        for ``"SAME_LOWER"``.
+      inputs_dilation: Input dilation of the convolution. If an integer, it is
+        broadcast to every spatial dimension.
+      kernel_dilation: Kernel dilation of the convolution. If an integer, it
+        is broadcast to every spatial dimension.
+      parallel_axes: Names and lengths for axes that should be processed in
+        parallel. These axes should appear in both the input and the output,
+        and the resulting convolution operator will apply a different operator
+        to each slice (this is similar to a grouped convolution). Must not
+        overlap with any axes named in ``input_axes`` or ``output_axes``.
+      parallel_broadcast_axes: Names and lengths for axes that should be
+        treated like ``parallel_axes`` but will only appear in the output. The
+        input will be implicitly broadcast over these axes. Must not overlap
+        with any axes named in ``input_axes``, ``output_axes`` or
+        ``parallel_axes``.
+      initializer: Function to use to initialize the kernel.
+      dtype: Dtype for the kernel.
+      rename_outputs_if_necessary: If True, and if ``output_axes`` and
+        ``input_axes`` have overlapping names, avoids name conflicts by adding
+        "primed" versions of the overlapping names, and returns an instance of
+        `ConvInPlace` instead of a ``Conv`` layer directly.
+
+    Returns:
+      A ``Conv`` layer with uninitialized kernel, or possibly a
+      `ConvInPlace` layer if ``rename_outputs_if_necessary`` is True and
+      ``input_axes`` overlaps with ``output_axes``. 
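+
+    Example: a minimal sketch of building this layer (the axis names and
+    sizes here are illustrative, not required by the API)::
+
+      conv = pz.nn.Conv.from_config(
+          name="conv",
+          init_base_rng=jax.random.PRNGKey(0),
+          input_axes={"channels": 3},
+          output_axes={"features": 16},
+          convolution_spatial_axes={"height": 3, "width": 3},
+      )
+      # The layer maps arrays with "channels", "height", and "width" axes to
+      # arrays with "features", "height", and "width" axes; any other named
+      # axes are batched over.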
+ """ + spatial_dim_count = len(convolution_spatial_axes) + + strides = maybe_broadcast(strides, spatial_dim_count) + inputs_dilation = maybe_broadcast(inputs_dilation, spatial_dim_count) + kernel_dilation = maybe_broadcast(kernel_dilation, spatial_dim_count) + + if parallel_axes is None: + parallel_axes = {} + if parallel_broadcast_axes is None: + parallel_broadcast_axes = {} + + output_axes_after_rename, primed_names, original_names = ( + maybe_rename_output_axes( + input_axes, + output_axes, + parallel_axes, + parallel_broadcast_axes, + rename_outputs_if_necessary, + ) + ) + + core_layer = cls( + kernel=parameters.make_parameter( + f"{name}.kernel", + init_base_rng, + initializer, + input_axes=input_axes, + output_axes=output_axes_after_rename, + parallel_axes={**parallel_axes, **parallel_broadcast_axes}, + convolution_spatial_axes=convolution_spatial_axes, + dtype=dtype, + ), + strides=strides, + padding=padding, + inputs_dilation=inputs_dilation, + kernel_dilation=kernel_dilation, + spatial_axis_names=tuple(convolution_spatial_axes.keys()), + in_axis_names=tuple(input_axes.keys()), + out_axis_names=( + tuple(output_axes_after_rename.keys()) + + tuple(parallel_broadcast_axes.keys()) + ), + ) + + # if name overlap wrap layer + if primed_names is not None and original_names is not None: + return ConvInPlace( + sublayers=[ + core_layer, + RenameAxes(old=tuple(primed_names), new=tuple(original_names)), + ], + ) + return core_layer + + def treescope_color(self) -> str: + return "#79eb75" + + +@struct.pytree_dataclass +class ConvTranspose(AbstractGeneralConv): + """A general transposed convolution operator, for named arrays. + + Applies an arbitrary contraction to the input `NamedArray` and a kernel + parameter. This can be used to express an arbitrary linear transposed + convolution operator. + + Attributes: + kernel: The named array holding the kernel for the convlution operator. + strides: The stride of the convolution operator + padding: The padding to apply to the input before the convolution + kernel_dilation: The kernel dilation of the convolution + convolution_spatial_axis_names: The names of the spatial axes over wich to + apply the convolution operator + in_axis_names: The names of the axes to contract with the input, removing + them. + out_axis_names: The names of the axes that should not appear in the input + and will be inserted into the output. + """ + + kernel: parameters.ParameterLike[NamedArray] + strides: Sequence[int] + padding: str | Sequence[tuple[int, int]] + kernel_dilation: Sequence[int] + + spatial_axis_names: tuple[str, ...] = dataclasses.field( + metadata={"pytree_node": False} + ) + in_axis_names: tuple[str, ...] = dataclasses.field( + metadata={"pytree_node": False} + ) + out_axis_names: tuple[str, ...] 
= dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+
+  def __call__(self, in_array: NamedArray, **_side_inputs) -> NamedArray:
+    """Runs the transposed convolution operator."""
+    in_struct = self._input_structure()
+
+    # pytype: disable=attribute-error
+    if isinstance(
+        self.kernel,
+        Parameter | ParameterValue,
+    ) and self.kernel.label.endswith(".kernel"):
+      # Strip the ".kernel" suffix (7 characters) to recover the layer name.
+      error_prefix = f"({self.kernel.label[:-7]}) "
+    else:
+      error_prefix = ""
+    # pytype: enable=attribute-error
+
+    dimvars = shapecheck.check_structure(
+        in_array, in_struct, error_prefix=error_prefix
+    )
+
+    lhs, rhs = prepare_for_conv(
+        in_array,
+        self.kernel.value,
+        self.spatial_axis_names,
+        self.in_axis_names,
+        self.out_axis_names,
+    )
+
+    # Perform the actual transposed convolution.
+    result = named_axes.nmap(
+        lambda lhs, rhs: jax.lax.conv_transpose(
+            lhs=lhs[None, ...],
+            rhs=rhs,
+            strides=self.strides,
+            padding=self.padding,
+            rhs_dilation=self.kernel_dilation,
+            dimension_numbers=get_dimension_numbers(
+                ndim=len(self.spatial_axis_names)
+            ),
+        )[0]
+    )(lhs, rhs)
+
+    result = get_named_axis_back_after_conv(
+        result,
+        self.spatial_axis_names,
+        self.out_axis_names,
+        [self.output_axes[name] for name in self.out_axis_names],
+    )
+
+    out_struct = self._output_structure()
+    shapecheck.check_structure(
+        result, out_struct, known_vars=dimvars, error_prefix=error_prefix
+    )
+    return result
+
+  @classmethod
+  def from_config(
+      cls,
+      name: str,
+      init_base_rng: jax.Array | None,
+      input_axes: dict[str, int],
+      output_axes: dict[str, int],
+      convolution_spatial_axes: dict[str, int],
+      strides: int | Sequence[int] = 1,
+      padding: str | Sequence[tuple[int, int]] = "SAME",
+      kernel_dilation: int | Sequence[int] = 1,
+      parallel_axes: dict[str, int] | None = None,
+      parallel_broadcast_axes: dict[str, int] | None = None,
+      initializer: LinearOperatorWeightInitializer = xavier_uniform_initializer,
+      dtype: jax.typing.DTypeLike = jnp.float32,
+      rename_outputs_if_necessary: bool = True,
+  ) -> ConvTranspose | ConvTransposeInPlace:
+    """Constructs a ``ConvTranspose`` layer from a configuration.
+
+    This can be used when building a new transposed convolution operator at
+    the start of training.
+
+    Note: For the purposes of the initializer, the ``parallel_axes`` and
+    ``parallel_broadcast_axes`` are treated in the same way, without
+    participating in output-dimension variance scaling. However, after
+    initialization, the ``parallel_broadcast_axes`` will be treated like extra
+    output axes (and assumed not to be present in the input).
+
+    Args:
+      name: The name of the layer.
+      init_base_rng: The base RNG to use for initializing model parameters.
+      input_axes: Names and lengths for axes that the linear operator should
+        contract over.
+      output_axes: Names and lengths for new axes that the linear operator
+        should produce. If any axis names overlap with ``input_axes``, the
+        argument ``rename_outputs_if_necessary`` must be True.
+      convolution_spatial_axes: Names and lengths for the spatial axes of the
+        convolution kernel.
+      strides: Strides of the convolution. If an integer, it is broadcast to
+        every spatial dimension.
+      padding: The padding to apply to the input before the convolution. Can be
+        either the strings 'SAME', 'SAME_LOWER', or 'VALID', or a sequence
+        of n (low, high) integer pairs giving the padding to apply before and
+        after each spatial dimension. 'SAME' and 'SAME_LOWER' add padding to
+        produce the same output size as the input when the stride is one. The
+        padding is split between the two sides equally or almost equally. If
the total padding is odd, the extra padding is added at the end for
+        'SAME' and at the beginning for 'SAME_LOWER'.
+      kernel_dilation: Kernel dilation of the convolution. If an integer, it
+        is broadcast to every spatial dimension.
+      parallel_axes: Names and lengths for axes that should be processed in
+        parallel. These axes should appear in both the input and the output,
+        and the resulting convolution operator will apply a different operator
+        to each slice. (This is similar to a grouped convolution.) Must not
+        overlap with any axes named in ``input_axes`` or ``output_axes``.
+      parallel_broadcast_axes: Names and lengths for axes that should be
+        treated like ``parallel_axes`` but will only appear in the output. The
+        input will be implicitly broadcast over these axes. Must not overlap
+        with any axes named in ``input_axes``, ``output_axes`` or
+        ``parallel_axes``.
+      initializer: Function to use to initialize the kernel.
+      dtype: Dtype for the kernel.
+      rename_outputs_if_necessary: If True, and if ``output_axes`` and
+        ``input_axes`` have overlapping names, avoids name conflicts by adding
+        "primed" versions of the overlapping names, and returns an instance of
+        ``ConvTransposeInPlace`` instead of a ``ConvTranspose`` layer directly.
+
+    Returns:
+      A ``ConvTranspose`` layer with uninitialized kernel, or possibly a
+      ``ConvTransposeInPlace`` layer if ``rename_outputs_if_necessary`` is
+      True and ``input_axes`` overlaps with ``output_axes``.
+    """
+    spatial_dim_count = len(convolution_spatial_axes)
+
+    strides = maybe_broadcast(strides, spatial_dim_count)
+    kernel_dilation = maybe_broadcast(kernel_dilation, spatial_dim_count)
+
+    if parallel_axes is None:
+      parallel_axes = {}
+    if parallel_broadcast_axes is None:
+      parallel_broadcast_axes = {}
+
+    output_axes_after_rename, primed_names, original_names = (
+        maybe_rename_output_axes(
+            input_axes,
+            output_axes,
+            parallel_axes,
+            parallel_broadcast_axes,
+            rename_outputs_if_necessary,
+        )
+    )
+
+    core_layer = cls(
+        kernel=parameters.make_parameter(
+            f"{name}.kernel",
+            init_base_rng,
+            initializer,
+            input_axes=input_axes,
+            output_axes=output_axes_after_rename,
+            parallel_axes={**parallel_axes, **parallel_broadcast_axes},
+            convolution_spatial_axes=convolution_spatial_axes,
+            dtype=dtype,
+        ),
+        strides=strides,
+        padding=padding,
+        kernel_dilation=kernel_dilation,
+        spatial_axis_names=tuple(convolution_spatial_axes.keys()),
+        in_axis_names=tuple(input_axes.keys()),
+        out_axis_names=(
+            tuple(output_axes_after_rename.keys())
+            + tuple(parallel_broadcast_axes.keys())
+        ),
+    )
+
+    # If input and output axis names overlap, wrap the core layer so that the
+    # primed output names are renamed back to the requested names.
+    if primed_names is not None and original_names is not None:
+      return ConvTransposeInPlace(
+          sublayers=[
+              core_layer,
+              RenameAxes(old=tuple(primed_names), new=tuple(original_names)),
+          ],
+      )
+    return core_layer
+
+  def treescope_color(self) -> str:
+    return "#c7eb75"
diff --git a/penzai/pz/nn.py b/penzai/pz/nn.py
index fcd79fc..6069c86 100644
--- a/penzai/pz/nn.py
+++ b/penzai/pz/nn.py
@@ -74,6 +74,10 @@
     Linear,
     LinearOperatorWeightInitializer,
     LinearInPlace,
+    Conv,
+    ConvInPlace,
+    ConvTranspose,
+    ConvTransposeInPlace,
     RenameAxes,
     contract,
     variance_scaling_initializer,
diff --git a/tests/nn/linear_and_affine_test.py b/tests/nn/linear_and_affine_test.py
index 3c62467..f652b02 100644
--- a/tests/nn/linear_and_affine_test.py
+++ b/tests/nn/linear_and_affine_test.py
@@ -163,6 +163,62 @@ def test_affine(self):
         ),
     )
 
+  def test_conv(self):
+    layer = pz.nn.Conv.from_config(
+        name="test",
+
init_base_rng=jax.random.key(1), + input_axes={"foo": 3}, + output_axes={"foo": 5}, + convolution_spatial_axes={"height": 3, "width": 3}, + parallel_axes={"baz": 7}, + parallel_broadcast_axes={"qux": 11}, + rename_outputs_if_necessary=True, + ) + result = layer( + pz.nx.ones({"batch": 1, "height": 10, "width": 15, "foo": 3, "baz": 7}), + ) + pz.chk.check_structure( + result, + pz.chk.ArraySpec( + named_shape={ + "batch": 1, + "height": 10, + "width": 15, + "foo": 5, + "baz": 7, + "qux": 11, + } + ), + ) + + def test_conv_transpose(self): + layer = pz.nn.ConvTranspose.from_config( + name="test", + init_base_rng=jax.random.key(1), + input_axes={"foo": 3}, + output_axes={"foo": 5}, + convolution_spatial_axes={"height": 3, "width": 3}, + parallel_axes={"baz": 7}, + parallel_broadcast_axes={"qux": 11}, + rename_outputs_if_necessary=True, + ) + result = layer( + pz.nx.ones({"batch": 1, "height": 10, "width": 15, "foo": 3, "baz": 7}), + ) + pz.chk.check_structure( + result, + pz.chk.ArraySpec( + named_shape={ + "batch": 1, + "height": 10, + "width": 15, + "foo": 5, + "baz": 7, + "qux": 11, + } + ), + ) + def test_constant_rescale(self): layer = pz.nn.ConstantRescale(3.0) result = layer(pz.nx.ones({"foo": 3})) diff --git a/uv.lock b/uv.lock index 7e1588b..575abf5 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.10" resolution-markers = [ "python_full_version < '3.11'", @@ -359,7 +360,7 @@ name = "click" version = "8.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } wheels = [ @@ -862,7 +863,7 @@ name = "ipykernel" version = "6.29.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "platform_system == 'Darwin'" }, + { name = "appnope", marker = "sys_platform == 'darwin'" }, { name = "comm" }, { name = "debugpy" }, { name = "ipython" }, @@ -1980,7 +1981,6 @@ name = "nvidia-nccl-cu12" version = "2.20.5" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/bb/d09dda47c881f9ff504afd6f9ca4f502ded6d8fc2f572cacc5e39da91c28/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01", size = 176238458 }, { url = "https://files.pythonhosted.org/packages/4b/2a/0a131f572aa09f741c30ccd45a8e56316e8be8dfc7bc19bf0ab7cfef7b19/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56", size = 176249402 }, ] @@ -1989,7 +1989,6 @@ name = "nvidia-nvjitlink-cu12" version = "12.6.68" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/58/8c/69c9e39cd6bfa813852a94e9bd3c075045e2707d163e9dc2326c82d2c330/nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_aarch64.whl", hash = "sha256:b3fd0779845f68b92063ab1393abab1ed0a23412fc520df79a8190d098b5cd6b", size = 19253287 }, { url = "https://files.pythonhosted.org/packages/a8/48/a9775d377cb95585fb188b469387f58ba6738e268de22eae2ad4cedb2c41/nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_x86_64.whl", hash = 
"sha256:125a6c2a44e96386dda634e13d944e60b07a0402d391a070e8fb4104b34ea1ab", size = 19725597 }, ] @@ -2211,6 +2210,7 @@ requires-dist = [ { name = "treescope", specifier = ">=0.1.9" }, { name = "typing-extensions", specifier = ">=4.2" }, ] +provides-extras = ["dev", "docs", "extras", "notebook"] [[package]] name = "pexpect" @@ -3438,19 +3438,19 @@ dependencies = [ { name = "fsspec" }, { name = "jinja2" }, { name = "networkx" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "sympy" }, - { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions" }, ] wheels = [ @@ -3491,7 +3491,7 @@ name = "tqdm" version = "4.66.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = 
"sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 } wheels = [ From d025e56b79c71a92e752d8759edf7d9587dd926b Mon Sep 17 00:00:00 2001 From: AntoinePlumerault Date: Mon, 9 Jun 2025 21:09:05 +0000 Subject: [PATCH 02/18] fix formatting --- penzai/core/named_axes.py | 527 ++++++++++++++++++++------------------ 1 file changed, 279 insertions(+), 248 deletions(-) diff --git a/penzai/core/named_axes.py b/penzai/core/named_axes.py index 0dafc37..6d64df4 100644 --- a/penzai/core/named_axes.py +++ b/penzai/core/named_axes.py @@ -170,15 +170,15 @@ def nmap(fun: Callable[..., Any]) -> Callable[..., Any]: def _nmap_with_doc( - fun: Callable[..., Any], fun_name: str, fun_doc: str | None = None + fun: Callable[..., Any], fun_name: str, fun_doc: str | None = None ) -> Callable[..., Any]: """Builds a nmap-wrapped function with a docstring.""" @functools.wraps(fun) def wrapped_fun(*args, **kwargs): arg_leaves_and_paths, arg_treedef = jax.tree_util.tree_flatten_with_path( - (args, kwargs), - is_leaf=lambda node: isinstance(node, NamedArray | NamedArrayView), + (args, kwargs), + is_leaf=lambda node: isinstance(node, NamedArray | NamedArrayView), ) arg_leaves = [leaf for _, leaf in arg_leaves_and_paths] # Extract any argument leaves that were NamedArrays or NamedArrayViews. The @@ -201,7 +201,10 @@ def wrapped_fun(*args, **kwargs): named_array_arg_leaves.append(leaf.as_namedarrayview()) if bad_names: - msg = [f"Inconsistent named axes in a call to nmap({fun}) for axes {bad_names}:"] + msg = [ + f"Inconsistent named axes in a call to nmap({fun}) for axes" + f" {bad_names}:" + ] for keypath, leaf in arg_leaves_and_paths: if isinstance(leaf, NamedArray | NamedArrayView): assert keypath @@ -271,20 +274,22 @@ def _shift_axis(other_axis): # data_array will still have the extra axis. But after running `vmap`, # it will be valid again. reduced_views.append( - NamedArrayView( - data_array=view.data_array, - data_axis_for_name={ - name: _shift_axis(data_axis) - for name, data_axis in view.data_axis_for_name.items() - if name != vmap_name - }, - data_axis_for_logical_axis=tuple( - _shift_axis(data_axis) for data_axis in view.data_axis_for_logical_axis - ), - data_shape=( - view.data_shape[:vmap_axis] + view.data_shape[vmap_axis + 1 :] - ), - ) + NamedArrayView( + data_array=view.data_array, + data_axis_for_name={ + name: _shift_axis(data_axis) + for name, data_axis in view.data_axis_for_name.items() + if name != vmap_name + }, + data_axis_for_logical_axis=tuple( + _shift_axis(data_axis) + for data_axis in view.data_axis_for_logical_axis + ), + data_shape=( + view.data_shape[:vmap_axis] + + view.data_shape[vmap_axis + 1 :] + ), + ) ) else: # This argument doesn't have this axis, so don't map over anything. @@ -292,13 +297,13 @@ def _shift_axis(other_axis): reduced_views.append(view) return jax.vmap( - functools.partial( - recursive_vectorize_step, - remaining_names=remaining_names[1:], - ), - in_axes=(vmap_axes,), - out_axes=0, - axis_name=vmap_name, + functools.partial( + recursive_vectorize_step, + remaining_names=remaining_names[1:], + ), + in_axes=(vmap_axes,), + out_axes=0, + axis_name=vmap_name, )(reduced_views) # Run the function. 
@@ -312,24 +317,24 @@ def handle_result(leaf): leaf = jnp.array(leaf) if leaf.ndim == len(all_names): return NamedArray( - data_array=leaf, - named_axes=collections.OrderedDict(zip(all_names, leaf.shape)), + data_array=leaf, + named_axes=collections.OrderedDict(zip(all_names, leaf.shape)), ) else: assert leaf.ndim > len(all_names) return NamedArrayView( - data_array=leaf, - data_shape=leaf.shape, - data_axis_for_name={name: i for i, name in enumerate(all_names)}, - data_axis_for_logical_axis=tuple(range(len(all_names), leaf.ndim)), + data_array=leaf, + data_shape=leaf.shape, + data_axis_for_name={name: i for i, name in enumerate(all_names)}, + data_axis_for_logical_axis=tuple(range(len(all_names), leaf.ndim)), ) return jax.tree_util.tree_map(handle_result, result_data) docstr = ( - f"Name-vectorized version of `{fun_name}`. Takes similar arguments as" - f" `{fun_name}` but accepts and returns NamedArrays (or NamedArrayViews)" - " in place of regular arrays." + f"Name-vectorized version of `{fun_name}`. Takes similar arguments as" + f" `{fun_name}` but accepts and returns NamedArrays (or NamedArrayViews)" + " in place of regular arrays." ) if fun_doc: docstr += f"\n\nOriginal documentation:\n\n{fun_doc}" @@ -352,8 +357,8 @@ def _wrap_scalar_conversion(scalar_conversion): def wrapped_scalar_conversion(self: NamedArrayBase): if self.named_shape or self.positional_shape: raise ValueError( - "Cannot convert a non-scalar NamedArray or NamedArrayView with" - f" {scalar_conversion}" + "Cannot convert a non-scalar NamedArray or NamedArrayView with" + f" {scalar_conversion}" ) return scalar_conversion(self.unwrap()) @@ -369,17 +374,17 @@ def func(array, *args, **kwargs): array_method = getattr(jax.Array, name) wrapped_func = nmap(func) functools.update_wrapper( - wrapped_func, - array_method, - assigned=("__name__", "__qualname__", "__annotations__"), - updated=(), + wrapped_func, + array_method, + assigned=("__name__", "__qualname__", "__annotations__"), + updated=(), ) wrapped_func.__module__ = __name__ wrapped_func.__doc__ = ( - "Name-vectorized version of array method" - f" `{name} `. Takes similar arguments as" - f" `{name} ` but accepts and returns NamedArrays" - " (or NamedArrayViews) in place of regular arrays." + "Name-vectorized version of array method" + f" `{name} `. Takes similar arguments as" + f" `{name} ` but accepts and returns NamedArrays" + " (or NamedArrayViews) in place of regular arrays." 
) return wrapped_func @@ -411,56 +416,56 @@ def unwrap(self): @functools.partial( - jax.jit, - static_argnames=[ - "indices_are_sorted", - "unique_indices", - "mode", - "fill_value", - ], + jax.jit, + static_argnames=[ + "indices_are_sorted", + "unique_indices", + "mode", + "fill_value", + ], ) @nmap def _jitted_nmapped_getitem( - array: jax.Array, - index_thunks: tuple[_StaticThunk | _DynamicThunk | _SliceThunk, ...], - *, - indices_are_sorted=False, - unique_indices=False, - mode=None, - fill_value=None, + array: jax.Array, + index_thunks: tuple[_StaticThunk | _DynamicThunk | _SliceThunk, ...], + *, + indices_are_sorted=False, + unique_indices=False, + mode=None, + fill_value=None, ): """JIT-compiled helper for getitem.""" indexer = tuple(thunk.unwrap() for thunk in index_thunks) return array.at[indexer].get( - indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, - mode=mode, - fill_value=fill_value, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode, + fill_value=fill_value, ) @functools.partial( - jax.jit, - static_argnames=["method", "indices_are_sorted", "unique_indices", "mode"], + jax.jit, + static_argnames=["method", "indices_are_sorted", "unique_indices", "mode"], ) @nmap def _jitted_nmapped_update( - array: jax.Array, - index_thunks: tuple[_StaticThunk | _DynamicThunk | _SliceThunk, ...], - values: jax.Array, - method: str, - *, - indices_are_sorted=False, - unique_indices=False, - mode=None, + array: jax.Array, + index_thunks: tuple[_StaticThunk | _DynamicThunk | _SliceThunk, ...], + values: jax.Array, + method: str, + *, + indices_are_sorted=False, + unique_indices=False, + mode=None, ): """JIT-compiled helper for in-place updates.""" indexer = tuple(thunk.unwrap() for thunk in index_thunks) return getattr(array.at[indexer], method)( - values, - indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, - mode=mode, + values, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode, ) @@ -514,16 +519,16 @@ def _partition_dict_index(self): # Sliced axis. sliced_axes.append(name) elif isinstance(index, int) or ( - isinstance(index, jax.Array | np.ndarray | NamedArrayBase) - and jnp.issubdtype(index.dtype, np.integer) + isinstance(index, jax.Array | np.ndarray | NamedArrayBase) + and jnp.issubdtype(index.dtype, np.integer) ): # Indexed axis. indexed_axes.append(name) else: raise TypeError( - "Unsupported index for a named axis using dict-style index:" - " expected a slice, an integer, an integer array, or None, but" - f" got {index}" + "Unsupported index for a named axis using dict-style index:" + " expected a slice, an integer, an integer array, or None, but" + f" got {index}" ) input_prefix_order = (*sliced_axes, *indexed_axes) @@ -538,13 +543,13 @@ def get(self, **kwargs) -> NamedArrayBase: # Dict indexing => desugar it to positional indexing over the requested # names. 
input_prefix_order, slice_order, output_prefix_order = ( - self._partition_dict_index() + self._partition_dict_index() ) return ( - self.array.untag_prefix(*input_prefix_order) - .at[tuple(self.indexer[name] for name in slice_order)] - .get(**kwargs) - .tag_prefix(*output_prefix_order) + self.array.untag_prefix(*input_prefix_order) + .at[tuple(self.indexer[name] for name in slice_order)] + .get(**kwargs) + .tag_prefix(*output_prefix_order) ) else: @@ -574,7 +579,7 @@ def _nmap_update_op(self, method: str, value, kwargs) -> NamedArrayBase: # Dict indexing => desugar it to positional indexing over the requested # names. input_prefix_order, slice_order, output_prefix_order = ( - self._partition_dict_index() + self._partition_dict_index() ) # Make sure the provided value has the necessary axes, by broadcasting it @@ -591,32 +596,36 @@ def _nmap_update_op(self, method: str, value, kwargs) -> NamedArrayBase: result_shape = result_structure.positional_shape num_new_positional_axes = len(result_shape) - len(value_shape) if num_new_positional_axes < 0 or not all( - vd == 1 or vd == rd - for vd, rd in zip( - value_shape, result_shape[len(result_shape) - len(value_shape) :] - ) + vd == 1 or vd == rd + for vd, rd in zip( + value_shape, result_shape[len(result_shape) - len(value_shape) :] + ) ): raise ValueError( - "Cannot provide updates with positional shape" - f" {value_shape} for an index whose result shape is" - f" {result_shape}! Update shape must be a" - " suffix of the result shape (or broadcastable to it)." + "Cannot provide updates with positional shape" + f" {value_shape} for an index whose result shape is" + f" {result_shape}! Update shape must be a" + " suffix of the result shape (or broadcastable to it)." ) if num_new_positional_axes: value = value[(None,) * num_new_positional_axes + (...,)] new_names = { - name: None for name in output_prefix_order if name not in value.named_shape + name: None + for name in output_prefix_order + if name not in value.named_shape } if new_names: value = value[new_names] # pylint: disable=protected-access return ( - self.array.untag_prefix(*input_prefix_order) - .at[tuple(self.indexer[name] for name in slice_order)] - ._nmap_update_op(method, value.untag_prefix(*output_prefix_order), kwargs) - .tag_prefix(*input_prefix_order) + self.array.untag_prefix(*input_prefix_order) + .at[tuple(self.indexer[name] for name in slice_order)] + ._nmap_update_op( + method, value.untag_prefix(*output_prefix_order), kwargs + ) + .tag_prefix(*input_prefix_order) ) # pylint: enable=protected-access @@ -640,7 +649,7 @@ def _nmap_update_op(self, method: str, value, kwargs) -> NamedArrayBase: index_thunks.append(_StaticThunk(c)) return _jitted_nmapped_update( - self.array, tuple(index_thunks), value, method, **kwargs + self.array, tuple(index_thunks), value, method, **kwargs ) def set(self, values, /, **kwargs): @@ -834,7 +843,8 @@ def tag_prefix(self, *axis_order: AxisName) -> NamedArray | NamedArrayView: # We implement `tag_prefix` using `tag` with temporary axis # identifiers. 
tmp_axis_ids = [ - TmpPosAxisMarker() for _ in range(len(self.positional_shape) - len(axis_order)) + TmpPosAxisMarker() + for _ in range(len(self.positional_shape) - len(axis_order)) ] return self.tag(*axis_order, *tmp_axis_ids).untag(*tmp_axis_ids) @@ -866,14 +876,14 @@ def order_as(self, *axis_order: AxisName) -> NamedArray: data_array = self.tag(*tmp_names).untag(*tmp_names, *axis_order).unwrap() return ( - NamedArray.wrap(data_array) - .tag(*tmp_names, *axis_order) - .untag(*tmp_names) - .with_positional_prefix() + NamedArray.wrap(data_array) + .tag(*tmp_names, *axis_order) + .untag(*tmp_names) + .with_positional_prefix() ) def order_like( - self, other: NamedArray | NamedArrayView + self, other: NamedArray | NamedArrayView ) -> NamedArray | NamedArrayView: """Ensures that this array's PyTree structure matches another array's. @@ -908,46 +918,46 @@ def order_like( elif isinstance(other, NamedArrayView): if len(self.positional_shape) != len(other.positional_shape): raise ValueError( - "Calling `order_like` with a NamedArrayView requires the two" - " arrays to have the same number of positional axes, but got" - f" positional shapes {self.positional_shape=}," - f" {other.positional_shape=}" + "Calling `order_like` with a NamedArrayView requires the two" + " arrays to have the same number of positional axes, but got" + f" positional shapes {self.positional_shape=}," + f" {other.positional_shape=}" ) if set(self.named_shape.keys()) != set(other.named_shape.keys()): raise ValueError( - "Calling `order_like` with a NamedArrayView requires the two" - " arrays to have the axis names, but got" - f" named shapes {self.named_shape=}, {other.named_shape=}" + "Calling `order_like` with a NamedArrayView requires the two" + " arrays to have the axis names, but got" + f" named shapes {self.named_shape=}, {other.named_shape=}" ) self_view = self.as_namedarrayview() new_to_old_data_axis = {} for old_data_axis, new_data_axis in zip( - self_view.data_axis_for_logical_axis, other.data_axis_for_logical_axis + self_view.data_axis_for_logical_axis, other.data_axis_for_logical_axis ): new_to_old_data_axis[new_data_axis] = old_data_axis for name, new_data_axis in other.data_axis_for_name.items(): new_to_old_data_axis[new_data_axis] = self_view.data_axis_for_name[name] new_data_array = jnp.transpose( - self_view.data_array, - [new_to_old_data_axis[i] for i in range(self_view.data_array.ndim)], + self_view.data_array, + [new_to_old_data_axis[i] for i in range(self_view.data_array.ndim)], ) return NamedArrayView( - data_shape=new_data_array.shape, - data_axis_for_logical_axis=other.data_axis_for_logical_axis, - data_axis_for_name=other.data_axis_for_name, - data_array=new_data_array, + data_shape=new_data_array.shape, + data_axis_for_logical_axis=other.data_axis_for_logical_axis, + data_axis_for_name=other.data_axis_for_name, + data_array=new_data_array, ) else: raise TypeError( - "`order_like` requires a NamedArray or NamedArrayView, but got" - f" {type(other).__name__}" + "`order_like` requires a NamedArray or NamedArrayView, but got" + f" {type(other).__name__}" ) def broadcast_to( - self, - positional_shape: Sequence[int] = (), - named_shape: Mapping[AxisName, int] | None = None, + self, + positional_shape: Sequence[int] = (), + named_shape: Mapping[AxisName, int] | None = None, ) -> NamedArrayBase: """Broadcasts a named array to a possibly-larger shape. 
@@ -967,22 +977,22 @@ def broadcast_to( named_shape = {} named_shape = dict(named_shape) if ( - self.positional_shape == tuple(positional_shape) - and dict(self.named_shape) == named_shape + self.positional_shape == tuple(positional_shape) + and dict(self.named_shape) == named_shape ): return self # Trick: create a size-zero array with the right shape so that we can # broadcast using nmap's vectorization rules. prototype_data = jnp.zeros( - tuple(named_shape.values()) + tuple(positional_shape) + (0,) + tuple(named_shape.values()) + tuple(positional_shape) + (0,) ) assert prototype_data.size == 0 prototype = NamedArray.wrap(prototype_data).tag_prefix(*named_shape.keys()) return nmap(lambda a, b: jnp.broadcast_to(a, b.shape[:-1]))(self, prototype) def broadcast_like( - self, other: NamedArrayBase | jax.typing.ArrayLike + self, other: NamedArrayBase | jax.typing.ArrayLike ) -> NamedArrayBase: """Broadcasts a named array to be compatible with another. @@ -1245,16 +1255,20 @@ def __treescope_ndarray_adapter__(self): __rsub__ = _nmap_with_doc(_swapped_binop(operator.sub), "jax.Array.__rsub__") __rmul__ = _nmap_with_doc(_swapped_binop(operator.mul), "jax.Array.__rmul__") __rtruediv__ = _nmap_with_doc( - _swapped_binop(operator.truediv), "jax.Array.__rtruediv__" + _swapped_binop(operator.truediv), "jax.Array.__rtruediv__" ) __rfloordiv__ = _nmap_with_doc( - _swapped_binop(operator.floordiv), "jax.Array.__rfloordiv__" + _swapped_binop(operator.floordiv), "jax.Array.__rfloordiv__" ) __rmod__ = _nmap_with_doc(_swapped_binop(operator.mod), "jax.Array.__rmod__") __rdivmod__ = _nmap_with_doc(_swapped_binop(divmod), "jax.Array.__rdivmod__") __rpow__ = _nmap_with_doc(_swapped_binop(operator.pow), "jax.Array.__rpow__") - __rlshift__ = _nmap_with_doc(_swapped_binop(operator.lshift), "jax.Array.__rlshift__") - __rrshift__ = _nmap_with_doc(_swapped_binop(operator.rshift), "jax.Array.__rrshift__") + __rlshift__ = _nmap_with_doc( + _swapped_binop(operator.lshift), "jax.Array.__rlshift__" + ) + __rrshift__ = _nmap_with_doc( + _swapped_binop(operator.rshift), "jax.Array.__rrshift__" + ) __rand__ = _nmap_with_doc(_swapped_binop(operator.and_), "jax.Array.__rand__") __ror__ = _nmap_with_doc(_swapped_binop(operator.or_), "jax.Array.__ror__") __rxor__ = _nmap_with_doc(_swapped_binop(operator.xor), "jax.Array.__rxor__") @@ -1396,7 +1410,7 @@ class NamedArray(NamedArrayBase, struct.Struct): """ named_axes: collections.OrderedDict[AxisName, int] = dataclasses.field( - metadata={"pytree_node": False} + metadata={"pytree_node": False} ) data_array: jax.Array @@ -1420,7 +1434,7 @@ def wrap(cls, array: jax.typing.ArrayLike, *names: AxisName) -> NamedArray: shape. 
""" wrapped = NamedArray( - named_axes=collections.OrderedDict(), data_array=jnp.asarray(array) + named_axes=collections.OrderedDict(), data_array=jnp.asarray(array) ) if names: return wrapped.tag(*names) @@ -1432,34 +1446,39 @@ def dtype(self) -> np.dtype: return self.data_array.dtype def check_valid(self) -> None: - if not hasattr(self.data_array, "shape") or not hasattr(self.data_array, "dtype"): + if not hasattr(self.data_array, "shape") or not hasattr( + self.data_array, "dtype" + ): raise ValueError( - "NamedArray.data_array must contain a jax or numpy array (or at least" - f" something with a shape and dtype), not {type(self.data_array)}" + "NamedArray.data_array must contain a jax or numpy array (or at least" + f" something with a shape and dtype), not {type(self.data_array)}" ) if not isinstance(self.named_axes, collections.OrderedDict) or not all( - isinstance(size, int) for size in self.named_axes.values() + isinstance(size, int) for size in self.named_axes.values() ): raise ValueError( - "NamedArray.named_axes must be an ordered dictionary of named axis shapes" + "NamedArray.named_axes must be an ordered dictionary of named axis" + " shapes" ) if any(isinstance(name, int) for name in self.named_axes.keys()): raise ValueError( - "Integers are not allowed as axis names, to avoid confusion with" - " positional axis indices." + "Integers are not allowed as axis names, to avoid confusion with" + " positional axis indices." ) true_suffix_shape = tuple( - self.data_array.shape[len(self.data_array.shape) - len(self.named_axes) :] + self.data_array.shape[ + len(self.data_array.shape) - len(self.named_axes) : + ] ) if true_suffix_shape != tuple(self.named_axes.values()): raise ValueError( - "The axis sizes in `named_axes` must exactly match a suffix " - " of the data array's shape, but" - f" {tuple(self.named_axes.values())} was not a suffix of" - f" {self.data_array.shape}" + "The axis sizes in `named_axes` must exactly match a suffix " + " of the data array's shape, but" + f" {tuple(self.named_axes.values())} was not a suffix of" + f" {self.data_array.shape}" ) @property @@ -1476,21 +1495,21 @@ def unwrap(self, *names) -> jax.Array: if names: if self.positional_shape: raise ValueError( - "Cannot unwrap a NamedArray by providing an axis name ordering if" - " it already has a positional shape. For advanced axis name" - " manipulation, try using `untag` and `tag` directly." + "Cannot unwrap a NamedArray by providing an axis name ordering if" + " it already has a positional shape. For advanced axis name" + " manipulation, try using `untag` and `tag` directly." ) name_bound = self.untag(*names) if name_bound.named_shape: raise ValueError( - "When calling `unwrap` with axis names, a position must be given" - f" for every axis name. Unassigned names: {name_bound.named_shape}" + "When calling `unwrap` with axis names, a position must be given" + f" for every axis name. Unassigned names: {name_bound.named_shape}" ) return name_bound.unwrap() if self.named_axes: raise ValueError( - "To unwrap a NamedArray with nonempty named shape, an ordering for" - f" its named axes must be provided. Named shape: {self.named_axes}" + "To unwrap a NamedArray with nonempty named shape, an ordering for" + f" its named axes must be provided. 
Named shape: {self.named_axes}" ) return self.data_array @@ -1503,15 +1522,15 @@ def as_namedarrayview(self) -> NamedArrayView: positional_axis_count = len(self.data_array.shape) - len(self.named_axes) data_axis_for_name = { - name: index + positional_axis_count - for index, name in enumerate(self.named_axes.keys()) + name: index + positional_axis_count + for index, name in enumerate(self.named_axes.keys()) } data_axis_for_logical_axis = tuple(range(positional_axis_count)) return NamedArrayView( - data_shape=self.data_array.shape, - data_axis_for_logical_axis=data_axis_for_logical_axis, - data_axis_for_name=data_axis_for_name, - data_array=self.data_array, + data_shape=self.data_array.shape, + data_axis_for_logical_axis=data_axis_for_logical_axis, + data_axis_for_name=data_axis_for_name, + data_array=self.data_array, ) def untag(self, *axis_order: AxisName) -> NamedArray | NamedArrayView: @@ -1550,12 +1569,14 @@ class NamedArrayView(NamedArrayBase, struct.Struct): data_array: The underlying positional-indexed array. """ - data_shape: tuple[int, ...] = dataclasses.field(metadata={"pytree_node": False}) + data_shape: tuple[int, ...] = dataclasses.field( + metadata={"pytree_node": False} + ) data_axis_for_logical_axis: tuple[int, ...] = dataclasses.field( - metadata={"pytree_node": False} + metadata={"pytree_node": False} ) data_axis_for_name: dict[AxisName, int] = dataclasses.field( - metadata={"pytree_node": False} + metadata={"pytree_node": False} ) data_array: jax.Array @@ -1567,20 +1588,22 @@ def dtype(self) -> np.dtype: def check_valid(self) -> None: # Data array has a shape. - if not hasattr(self.data_array, "shape") or not hasattr(self.data_array, "dtype"): + if not hasattr(self.data_array, "shape") or not hasattr( + self.data_array, "dtype" + ): raise ValueError( - "NamedArrayView.data_array must contain a jax or numpy array (or at" - " least something with a shape and dtype), not" - f" {type(self.data_array)}" + "NamedArrayView.data_array must contain a jax or numpy array (or at" + " least something with a shape and dtype), not" + f" {type(self.data_array)}" ) # Data shape is valid. if self.data_shape != self.data_array.shape: raise ValueError( - f"Expected data_array to have shape {self.data_shape}, but it has" - f" shape {self.data_array.shape}. Modifying the shape of the data" - " array of a NamedArrayView directly is not allowed; use `nmap`" - " instead, or call `with_positional_prefix` if you need to" - " manipulate the positional axes as prefix axes." + f"Expected data_array to have shape {self.data_shape}, but it has" + f" shape {self.data_array.shape}. Modifying the shape of the data" + " array of a NamedArrayView directly is not allowed; use `nmap`" + " instead, or call `with_positional_prefix` if you need to" + " manipulate the positional axes as prefix axes." ) # Every axis appears exactly once. seen_axes = collections.Counter() @@ -1588,27 +1611,28 @@ def check_valid(self) -> None: seen_axes.update(self.data_axis_for_name.values()) if seen_axes != collections.Counter(range(len(self.data_shape))): raise ValueError( - "Every axis index into `data_shape` must appear exactly once in" - " either `data_axis_for_logical_axis` or `data_axis_for_name`." + "Every axis index into `data_shape` must appear exactly once in" + " either `data_axis_for_logical_axis` or `data_axis_for_name`." ) # Check for bad names. 
if any(isinstance(name, int) for name in self.data_axis_for_name.keys()): raise ValueError( - "Integers are not allowed as axis names, to avoid confusion with" - " positional axis indices." + "Integers are not allowed as axis names, to avoid confusion with" + " positional axis indices." ) @property def named_shape(self) -> Mapping[AxisName, int]: return { - name: self.data_shape[data_axis] - for name, data_axis in self.data_axis_for_name.items() + name: self.data_shape[data_axis] + for name, data_axis in self.data_axis_for_name.items() } @property def positional_shape(self) -> tuple[int, ...]: return tuple( - self.data_shape[data_axis] for data_axis in self.data_axis_for_logical_axis + self.data_shape[data_axis] + for data_axis in self.data_axis_for_logical_axis ) def unwrap(self, *names) -> jax.Array: @@ -1616,22 +1640,22 @@ def unwrap(self, *names) -> jax.Array: if names: if self.positional_shape: raise ValueError( - "Cannot unwrap a NamedArrayView by providing an axis name ordering" - " if it already has a positional shape. For advanced axis name" - " manipulation, try using `untag` and `tag` directly." + "Cannot unwrap a NamedArrayView by providing an axis name ordering" + " if it already has a positional shape. For advanced axis name" + " manipulation, try using `untag` and `tag` directly." ) name_bound = self.untag(*names) if name_bound.named_shape: raise ValueError( - "When calling `unwrap` with axis names, a position must be given" - f" for every axis name. Unassigned names: {name_bound.named_shape}" + "When calling `unwrap` with axis names, a position must be given" + f" for every axis name. Unassigned names: {name_bound.named_shape}" ) return name_bound.unwrap() if self.named_shape: raise ValueError( - "To unwrap a NamedArrayView with nonempty named shape, an ordering" - " for its named axes must be provided. Named shape:" - f" {self.named_shape}" + "To unwrap a NamedArrayView with nonempty named shape, an ordering" + " for its named axes must be provided. Named shape:" + f" {self.named_shape}" ) # with_positional_prefix will perform the necessary transpositions, we can # then simply unwrap the resulting NamedArray. @@ -1657,7 +1681,7 @@ def with_positional_prefix(self) -> NamedArray: transposition.append(data_axis) # Then the named axes in sorted order (to try to avoid transposition) data_axes_and_names = sorted( - (data_axis, name) for name, data_axis in self.data_axis_for_name.items() + (data_axis, name) for name, data_axis in self.data_axis_for_name.items() ) for data_axis, name in data_axes_and_names: transposition.append(data_axis) @@ -1668,8 +1692,8 @@ def with_positional_prefix(self) -> NamedArray: return NamedArray(data_array=self.data_array, named_axes=named_axes) else: return NamedArray( - data_array=self.data_array.transpose(transposition), - named_axes=named_axes, + data_array=self.data_array.transpose(transposition), + named_axes=named_axes, ) def as_namedarrayview(self) -> NamedArrayView: @@ -1680,10 +1704,10 @@ def untag(self, *axis_order: AxisName) -> NamedArray | NamedArrayView: self.check_valid() if self.data_axis_for_logical_axis: raise ValueError( - "`untag` cannot be used to introduce positional axes for a" - " NamedArray (or NamedArrayView) that already has positional axes." - " Please assign names to the existing positional axes first using" - " `tag`." + "`untag` cannot be used to introduce positional axes for a" + " NamedArray (or NamedArrayView) that already has positional axes." 
+ " Please assign names to the existing positional axes first using" + " `tag`." ) if not axis_order: @@ -1700,7 +1724,7 @@ def untag(self, *axis_order: AxisName) -> NamedArray | NamedArrayView: bad_names = requested_axis_set.difference(actual_axis_set) if bad_names: raise ValueError( - f"Requested axis names {bad_names} are not present in the array." + f"Requested axis names {bad_names} are not present in the array." ) # Build a view. @@ -1711,25 +1735,25 @@ def untag(self, *axis_order: AxisName) -> NamedArray | NamedArrayView: del data_axis_for_name[name] return NamedArrayView( - data_shape=self.data_shape, - data_axis_for_logical_axis=tuple(data_axis_for_logical_axis), - data_axis_for_name=data_axis_for_name, - data_array=self.data_array, + data_shape=self.data_shape, + data_axis_for_logical_axis=tuple(data_axis_for_logical_axis), + data_axis_for_name=data_axis_for_name, + data_array=self.data_array, ) def tag(self, *names) -> NamedArray: self.check_valid() if len(names) != len(self.data_axis_for_logical_axis): raise ValueError( - "There must be exactly as many names given to `tag` as there" - f" are positional axes in the array, but got {names} for positional" - f" shape {self.positional_shape}" + "There must be exactly as many names given to `tag` as there" + f" are positional axes in the array, but got {names} for positional" + f" shape {self.positional_shape}" ) if any(isinstance(name, int) for name in names): raise ValueError( - "Integers are not allowed as axis names, to avoid confusion with" - " positional axis indices." + "Integers are not allowed as axis names, to avoid confusion with" + " positional axis indices." ) seen_axes = collections.Counter() @@ -1738,9 +1762,9 @@ def tag(self, *names) -> NamedArray: repeated_names = [name for name, count in seen_axes.items() if count > 1] if repeated_names: raise ValueError( - "Repeated axis names are not allowed; original names were" - f" {tuple(self.data_axis_for_name.keys())} and new names passed to" - f" tag were {names}; repeated: {repeated_names}" + "Repeated axis names are not allowed; original names were" + f" {tuple(self.data_axis_for_name.keys())} and new names passed to" + f" tag were {names}; repeated: {repeated_names}" ) names_by_index = {} @@ -1751,10 +1775,11 @@ def tag(self, *names) -> NamedArray: names_by_index[data_axis] = names[i] return NamedArray( - data_array=self.data_array, - named_axes=collections.OrderedDict( - [(names_by_index[i], self.data_shape[i]) for i in range(len(self.data_shape))] - ), + data_array=self.data_array, + named_axes=collections.OrderedDict([ + (names_by_index[i], self.data_shape[i]) + for i in range(len(self.data_shape)) + ]), ) @@ -1768,9 +1793,9 @@ def is_namedarray(value) -> typing.TypeGuard[NamedArrayBase]: def full( - named_shape: Mapping[AxisName, int], - fill_value: jax.typing.ArrayLike, - dtype: np.DTypeLike | None = None, + named_shape: Mapping[AxisName, int], + fill_value: jax.typing.ArrayLike, + dtype: np.DTypeLike | None = None, ) -> NamedArray: """Constructs a full named array with a given shape. @@ -1784,14 +1809,14 @@ def full( NamedArray with the given named shape, filled with ``fill_value``. 
""" return NamedArray( - named_axes=collections.OrderedDict(named_shape), - data_array=jnp.full(tuple(named_shape.values()), fill_value, dtype=dtype), + named_axes=collections.OrderedDict(named_shape), + data_array=jnp.full(tuple(named_shape.values()), fill_value, dtype=dtype), ) def zeros( - named_shape: Mapping[AxisName, int], - dtype: np.DTypeLike | None = None, + named_shape: Mapping[AxisName, int], + dtype: np.DTypeLike | None = None, ) -> NamedArray: """Constructs a named array of zeros with a given shape. @@ -1803,14 +1828,14 @@ def zeros( NamedArray with the given named shape, filled with zeros. """ return NamedArray( - named_axes=collections.OrderedDict(named_shape), - data_array=jnp.zeros(tuple(named_shape.values()), dtype), + named_axes=collections.OrderedDict(named_shape), + data_array=jnp.zeros(tuple(named_shape.values()), dtype), ) def ones( - named_shape: Mapping[AxisName, int], - dtype: np.DTypeLike | None = None, + named_shape: Mapping[AxisName, int], + dtype: np.DTypeLike | None = None, ) -> NamedArray: """Constructs a named array of ones with a given shape. @@ -1822,17 +1847,17 @@ def ones( NamedArray with the given named shape, filled with ones. """ return NamedArray( - named_axes=collections.OrderedDict(named_shape), - data_array=jnp.ones(tuple(named_shape.values()), dtype), + named_axes=collections.OrderedDict(named_shape), + data_array=jnp.ones(tuple(named_shape.values()), dtype), ) def arange( - name: str, - start: int, - stop: int | None = None, - step: int | None = None, - dtype: jax.typing.DTypeLike | None = None, + name: str, + start: int, + stop: int | None = None, + step: int | None = None, + dtype: jax.typing.DTypeLike | None = None, ) -> NamedArray: """Convenience function to create a range along a named axis. @@ -1854,8 +1879,8 @@ def arange( def random_split( - key: jax.Array | NamedArrayBase, - named_shape: Mapping[AxisName, int] | Sequence[tuple[AxisName, int]], + key: jax.Array | NamedArrayBase, + named_shape: Mapping[AxisName, int] | Sequence[tuple[AxisName, int]], ) -> NamedArray | NamedArrayView: """Splits a PRNG key into a `NamedArray` of PRNG keys with the given names. @@ -1885,9 +1910,9 @@ def random_split( names = sorted(unsorted_keys) except Exception as exc: raise ValueError( - "Unordered mappings must have sortable axis names when using" - " `random_split`. If necessary, you can specify a particular ordering" - " using a collections.OrderedDict or a tuple of (name, size) pairs." + "Unordered mappings must have sortable axis names when using" + " `random_split`. If necessary, you can specify a particular ordering" + " using a collections.OrderedDict or a tuple of (name, size) pairs." ) from exc sizes = [named_shape[name] for name in names] @@ -1896,12 +1921,12 @@ def random_split( flat_split_keys = nmap(jax.random.split)(key, total_size) return flat_split_keys.reshape( - tuple(sizes) + flat_split_keys.positional_shape[1:] + tuple(sizes) + flat_split_keys.positional_shape[1:] ).tag_prefix(*names) def concatenate( - arrays: Sequence[NamedArrayBase], axis_name: AxisName + arrays: Sequence[NamedArrayBase], axis_name: AxisName ) -> NamedArray | NamedArrayView: """Concatenates a sequence of named arrays along a named axis. 
@@ -1917,20 +1942,21 @@ def concatenate( ndims = set(len(array.positional_shape) for array in arrays) if len(ndims) != 1: raise ValueError( - f"All arrays must have the same number of positional axes, but got {ndims}" + "All arrays must have the same number of positional axes, but got" + f" {ndims}" ) (ndim,) = ndims orig_positional_axes = [TmpPosAxisMarker() for _ in range(ndim)] arrays_along_axis = [ - array.tag(*orig_positional_axes).untag(axis_name) for array in arrays + array.tag(*orig_positional_axes).untag(axis_name) for array in arrays ] concatenated = nmap(jnp.concatenate)(arrays_along_axis) return concatenated.tag(axis_name).untag(*orig_positional_axes) def stack( - arrays: Sequence[NamedArrayBase], axis_name: AxisName + arrays: Sequence[NamedArrayBase], axis_name: AxisName ) -> NamedArray | NamedArrayView: """Stacks a sequence of named arrays along a named axis. @@ -1944,7 +1970,8 @@ def stack( ndims = set(len(array.positional_shape) for array in arrays) if len(ndims) != 1: raise ValueError( - f"All arrays must have the same number of positional axes, but got {ndims}" + "All arrays must have the same number of positional axes, but got" + f" {ndims}" ) (ndim,) = ndims @@ -1955,7 +1982,7 @@ def stack( def unstack( - array: NamedArrayBase, axis_name: AxisName + array: NamedArrayBase, axis_name: AxisName ) -> Sequence[NamedArray | NamedArrayView]: """Splits a named array across a given named axis. @@ -1996,11 +2023,13 @@ def _fix(val, ref): else: return val - return jax.tree_util.tree_map(_fix, value_tree, reference_tree, is_leaf=is_namedarray) + return jax.tree_util.tree_map( + _fix, value_tree, reference_tree, is_leaf=is_namedarray + ) def scan( - f: Callable[[Any, Any], Any], axis: AxisName, init, xs=None, **scan_kwargs + f: Callable[[Any, Any], Any], axis: AxisName, init, xs=None, **scan_kwargs ) -> Any: """Scan a function over a named array axis while carrying along state. @@ -2063,15 +2092,17 @@ def wrapped_f(carry, x): new_carry, y = f(carry, x) new_carry = order_like(new_carry, carry) y = jax.tree_util.tree_map( - lambda v: v.with_positional_prefix() if is_namedarray(v) else v, - y, - is_leaf=is_namedarray, + lambda v: v.with_positional_prefix() if is_namedarray(v) else v, + y, + is_leaf=is_namedarray, ) return new_carry, y # Run the scan, which will slice off the positional prefix from the inputs, # and add a positional prefix to the outputs. - final_carry, ys_untagged = jax.lax.scan(wrapped_f, init, xs_untagged, **scan_kwargs) + final_carry, ys_untagged = jax.lax.scan( + wrapped_f, init, xs_untagged, **scan_kwargs + ) # Re-assign the scanned-over axis. 
def _retag(leaf):

From b3252f4692b9b1a0f410f8757f45c35a2cf37536 Mon Sep 17 00:00:00 2001
From: AntoinePlumerault
Date: Mon, 9 Jun 2025 21:12:24 +0000
Subject: [PATCH 03/18] removed unnecessary cast

---
 penzai/nn/linear_and_affine.py | 27 +++++++++++----------------
 1 file changed, 11 insertions(+), 16 deletions(-)

diff --git a/penzai/nn/linear_and_affine.py b/penzai/nn/linear_and_affine.py
index c7ec95e..7a64eb4 100644
--- a/penzai/nn/linear_and_affine.py
+++ b/penzai/nn/linear_and_affine.py
@@ -160,24 +160,19 @@ def variance_scaling_initializer(
   return named_axes.wrap(array).tag(*names)
 
 
-xavier_uniform_initializer = cast(
-    LinearOperatorWeightInitializer,
-    functools.partial(
-        variance_scaling_initializer,
-        scale=1.0,
-        mode="fan_avg",
-        distribution="uniform",
-    ),
+xavier_uniform_initializer = functools.partial(
+    variance_scaling_initializer,
+    scale=1.0,
+    mode="fan_avg",
+    distribution="uniform",
 )
 
-xavier_normal_initializer = cast(
-    LinearOperatorWeightInitializer,
-    functools.partial(
-        variance_scaling_initializer,
-        scale=1.0,
-        mode="fan_avg",
-        distribution="normal",
-    ),
+
+xavier_normal_initializer = functools.partial(
+    variance_scaling_initializer,
+    scale=1.0,
+    mode="fan_avg",
+    distribution="normal",
 )

From dc8e5a2df3689138fef6ee4b92ec3ce0f6bd9fb5 Mon Sep 17 00:00:00 2001
From: AntoinePlumerault
Date: Mon, 9 Jun 2025 21:22:34 +0000
Subject: [PATCH 04/18] added underscore and docstring for the axis name
 collision

---
 penzai/nn/linear_and_affine.py | 31 +++++++++++++++++++++++++------
 1 file changed, 25 insertions(+), 6 deletions(-)

diff --git a/penzai/nn/linear_and_affine.py b/penzai/nn/linear_and_affine.py
index 7a64eb4..5b7d114 100644
--- a/penzai/nn/linear_and_affine.py
+++ b/penzai/nn/linear_and_affine.py
@@ -19,7 +19,7 @@
 import dataclasses
 import functools
 import itertools
-from typing import Any, Literal, Protocol, Sequence, cast
+from typing import Any, Literal, Protocol, Sequence
 import jax
 import jax.numpy as jnp
 from penzai.core import named_axes
@@ -29,7 +29,7 @@
 from penzai.nn import grouping
 from penzai.nn import layer as layer_base
 from penzai.nn import parameters
-import abc
+
 
 NamedArray = named_axes.NamedArray
 Parameter = variables.Parameter
@@ -426,13 +426,32 @@ def _output_structure(self) -> shapecheck.StructureAnnotation:
     )
 
 
-def maybe_rename_output_axes(
+def _maybe_rename_output_axes(
    input_axes: dict[str, int],
    output_axes: dict[str, int],
    parallel_axes: dict[str, int],
    parallel_broadcast_axes: dict[str, int],
    rename_outputs_if_necessary: bool,
 ):
+  """Checks for input/output axis name overlap, renaming outputs if needed.
+
+  Args:
+    input_axes: Names and lengths for axes that the linear operator should
+      contract over.
+    output_axes: Names and lengths for new axes that the linear operator should
+      produce.
+    parallel_axes: Names and lengths for axes that should be processed in
+      parallel. These axes should appear in both the input and the output, and
+      the resulting linear operator will apply a different operator to each
+      slice. (This is similar to a block-diagonal matrix.)
+    parallel_broadcast_axes: Names and lengths for axes that should be treated
+      like `parallel_axes` but will only appear in the output. The input will
+      be implicitly broadcast over these axes.
+    rename_outputs_if_necessary: If True, renames output axes that overlap with
+      input axes by appending "_out" to their names.
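+
+  Returns:
+    A tuple ``(output_axes_after_rename, primed_names, original_names)``,
+    where ``primed_names`` and ``original_names`` are None when no renaming
+    was needed.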
+  """
+
+  # By default, no renaming and no wrapping are needed.
   output_axes_after_rename = output_axes
   primed_names, original_names = None, None
@@ -588,7 +607,7 @@ def from_config(
       parallel_broadcast_axes = {}
 
     output_axes_after_rename, primed_names, original_names = (
-        maybe_rename_output_axes(
+        _maybe_rename_output_axes(
             input_axes,
             output_axes,
             parallel_axes,
             parallel_broadcast_axes,
             rename_outputs_if_necessary,
         )
     )
@@ -1138,7 +1157,7 @@ def from_config(
       parallel_broadcast_axes = {}
 
     output_axes_after_rename, primed_names, original_names = (
-        maybe_rename_output_axes(
+        _maybe_rename_output_axes(
             input_axes,
             output_axes,
             parallel_axes,
             parallel_broadcast_axes,
             rename_outputs_if_necessary,
         )
     )
@@ -1356,7 +1375,7 @@ def from_config(
       parallel_broadcast_axes = {}
 
     output_axes_after_rename, primed_names, original_names = (
-        maybe_rename_output_axes(
+        _maybe_rename_output_axes(
             input_axes,
             output_axes,
             parallel_axes,
             parallel_broadcast_axes,
             rename_outputs_if_necessary,
         )
     )

From a4e47a8d0001a6aeef64e9efc04aedcdca1bc090 Mon Sep 17 00:00:00 2001
From: AntoinePlumerault
Date: Mon, 9 Jun 2025 21:37:03 +0000
Subject: [PATCH 05/18] added docstring and underscore on utils functions

---
 penzai/nn/linear_and_affine.py | 98 +++++++++++++++++++++++++++-------
 1 file changed, 78 insertions(+), 20 deletions(-)

diff --git a/penzai/nn/linear_and_affine.py b/penzai/nn/linear_and_affine.py
index 5b7d114..52baa18 100644
--- a/penzai/nn/linear_and_affine.py
+++ b/penzai/nn/linear_and_affine.py
@@ -839,14 +839,34 @@ def __call__(self, value: Any, **_unused_side_inputs) -> Any:
     return jax.tree_util.tree_map(lambda x: x * self.by, value)
 
 
-def prepare_for_conv(
+def _prepare_for_conv(
     inputs: NamedArray,
     kernel: NamedArray,
     spatial_axis_names: Sequence[str],
     in_axis_names: Sequence[str],
     out_axis_names: Sequence[str],
-):
-  """Preprocess lhs and rhs for jax convolution operator"""
+) -> tuple[NamedArray, NamedArray]:
+  """Preprocesses lhs and rhs for the jax convolution operator.
+
+  Merges the in axes of the inputs into a single in-channel axis, and merges
+  the out axes of the kernel into a single out-channel axis. This is necessary
+  to use the jax convolution operator, which expects the inputs to have a
+  single in-channel axis and the kernel to have a single out-channel axis.
+
+  Args:
+    inputs: The input named array.
+    kernel: The kernel named array.
+    spatial_axis_names: Names of the spatial axes in the input and kernel.
+    in_axis_names: Names of the input axes that will be contracted with the
+      kernel.
+    out_axis_names: Names of the output axes that will be produced by the
+      convolution.
+
+  Returns:
+    A tuple ``(lhs, rhs)``: the input with its in axes merged into a single
+    in-channel axis, and the kernel with its in axes merged into a single
+    in-channel axis and its out axes merged into a single out-channel axis.
+  """
   lhs = inputs
   rhs = kernel
@@ -867,14 +887,32 @@ def prepare_for_conv(
   return lhs, rhs
 
 
-def get_named_axis_back_after_conv(
+def _get_named_axis_back_after_conv(
     result: NamedArray,
     spatial_axis_names: Sequence[str],
     out_axis_names: Sequence[str],
     out_axis_shape: Sequence[int],
-):
-  """Postprocess result from jax convolution operator"""
+) -> NamedArray:
+  """Postprocesses the result of the jax convolution operator.
+
+  Restores the spatial axes and output axes to the result of the jax
+  convolution operator. The spatial axes are tagged back, and the output axes
+  are reshaped to the original shape and tagged back. This is necessary to
+  restore the original shape of the output after the convolution operator has
+  been applied, since we flattened the output axes into a single axis before
+  applying the convolution.
+  """
   lhs = inputs
   rhs = kernel
@@ -867,14 +887,32 @@ def prepare_for_conv(
   return lhs, rhs
 
 
-def get_named_axis_back_after_conv(
+def _get_named_axis_back_after_conv(
    result: NamedArray,
    spatial_axis_names: Sequence[str],
    out_axis_names: Sequence[str],
    out_axis_shape: Sequence[int],
-):
-  """Postprocess result from jax convolution operator"""
-  # Get named axes back
+) -> NamedArray:
+  """Postprocess result from jax convolution operator.
+
+  Restores the spatial axes and output axes to the result of the jax convolution
+  operator. The spatial axes are tagged back, and the output axes are reshaped
+  to the original shape and tagged back.
+  This is necessary to restore the original shape of the output after the
+  convolution operator has been applied, since we flattened the output axes into
+  a single axis before applying the convolution.
+
+  Args:
+    result: The result of the jax convolution operator.
+    spatial_axis_names: Names of the spatial axes in the input and kernel.
+    out_axis_names: Names of the output axes that will be produced by the
+      convolution.
+    out_axis_shape: The shape of the output axes, which will be used to reshape
+      the result back to the original shape.
+  Returns:
+    A named array with the spatial axes and output axes tagged back, and the
+    output axes reshaped to the original shape.
+  """
   return (
       result.tag_prefix(*spatial_axis_names)
       .reshape(out_axis_shape)
@@ -882,11 +920,31 @@ def get_named_axis_back_after_conv(
   )
 
 
-def maybe_broadcast(value: int | Sequence[int], count: int):
+def _maybe_broadcast(value: int | Sequence[int], count: int) -> Sequence[int]:
+  """Broadcasts a value to a sequence of the given count.
+
+  If the value is an integer, it will be repeated `count` times.
+  If the value is already a sequence, it will be returned as is.
+
+  Args:
+    value: The value to broadcast, either an integer or a sequence of integers.
+    count: The number of times to repeat the value if it is an integer.
+  Returns:
+    A sequence of integers with the value repeated `count` times if it was an
+    integer, or the original sequence if it was already a sequence.
+  """
   return [value] * count if isinstance(value, int) else value
 
 
-def get_dimension_numbers(ndim):
+def _get_dimension_numbers(ndim: int) -> jax.lax.ConvDimensionNumbers:
+  """Returns the dimension numbers for a convolution operator.
+  Args:
+    ndim: The number of spatial dimensions of the convolution operator.
+  Returns:
+    A `jax.lax.ConvDimensionNumbers` object that specifies the dimension numbers
+    for the convolution operator.
+  """
+
   return jax.lax.ConvDimensionNumbers(
       lhs_spec=(0, ndim + 1)
       + tuple(range(1, ndim + 1)),  # BHSpatial -> BCSpatial
       rhs_spec=(ndim + 1, ndim) + tuple(range(ndim)),  # SpatialIO -> OISpatial
       out_spec=(0, ndim + 1) + tuple(range(1, ndim + 1)),  # BSpatialC -> BCSpatial
@@ -1034,7 +1092,7 @@ def __call__(self, in_array: NamedArray, **_side_inputs) -> NamedArray:
     )
 
     print(in_array)
-    lhs, rhs = prepare_for_conv(
+    lhs, rhs = _prepare_for_conv(
         in_array,
         self.kernel.value,
         self.spatial_axis_names,
@@ -1051,13 +1109,13 @@ def __call__(self, in_array: NamedArray, **_side_inputs) -> NamedArray:
         padding=self.padding,
         lhs_dilation=self.inputs_dilation,
         rhs_dilation=self.kernel_dilation,
-        dimension_numbers=get_dimension_numbers(
+        dimension_numbers=_get_dimension_numbers(
            ndim=len(self.spatial_axis_names)
        ),
     )[0]
 )(lhs, rhs)
 
-    result = get_named_axis_back_after_conv(
+    result = _get_named_axis_back_after_conv(
        result,
        self.spatial_axis_names,
        self.out_axis_names,
@@ -1147,9 +1205,9 @@ def from_config(
     """
     spatial_dim_count = len(convolution_spatial_axes)
 
-    strides = maybe_broadcast(strides, spatial_dim_count)
-    inputs_dilation = maybe_broadcast(inputs_dilation, spatial_dim_count)
-    kernel_dilation = maybe_broadcast(kernel_dilation, spatial_dim_count)
+    strides = _maybe_broadcast(strides, spatial_dim_count)
+    inputs_dilation = _maybe_broadcast(inputs_dilation, spatial_dim_count)
+    kernel_dilation = _maybe_broadcast(kernel_dilation, spatial_dim_count)
 
     if parallel_axes is None:
       parallel_axes = {}
@@ -1257,7 +1315,7 @@ def __call__(self, in_array: NamedArray, **_side_inputs) -> NamedArray:
         in_array, in_struct, error_prefix=error_prefix
     )
 
-    lhs, rhs = prepare_for_conv(
+    lhs, rhs = _prepare_for_conv(
         in_array,
         self.kernel.value,
         self.spatial_axis_names,
@@ -1273,13 +1331,13 @@ def __call__(self, in_array: NamedArray, **_side_inputs) -> NamedArray:
         strides=self.strides,
         padding=self.padding,
         rhs_dilation=self.kernel_dilation,
-        dimension_numbers=get_dimension_numbers(
+        dimension_numbers=_get_dimension_numbers(
            ndim=len(self.spatial_axis_names)
        ),
     )[0]
 )(lhs, rhs)
 
-    result = get_named_axis_back_after_conv(
+    result = _get_named_axis_back_after_conv(
        result,
        self.spatial_axis_names,
        self.out_axis_names,
@@ -1366,8 +1424,8 @@ def from_config(
     """
     spatial_dim_count = len(convolution_spatial_axes)
 
-    strides = maybe_broadcast(strides, spatial_dim_count)
-    kernel_dilation = maybe_broadcast(kernel_dilation, spatial_dim_count)
+    strides = _maybe_broadcast(strides, spatial_dim_count)
+    kernel_dilation = _maybe_broadcast(kernel_dilation, spatial_dim_count)
 
     if parallel_axes is None:
       parallel_axes = {}

From f6fefcef20e3a66e2c531615d9fc4721da03ca02 Mon Sep 17 00:00:00 2001
From: AntoinePlumerault
Date: Sat, 14 Jun 2025 15:30:36 +0000
Subject: [PATCH 06/18] reduced code duplication by moving logic inside
 AbstractGeneralConv

---
 penzai/nn/linear_and_affine.py | 390 ++++++++++++++++-----------------
 1 file changed, 187 insertions(+), 203 deletions(-)

diff --git a/penzai/nn/linear_and_affine.py b/penzai/nn/linear_and_affine.py
index 52baa18..a794405 100644
--- a/penzai/nn/linear_and_affine.py
+++ b/penzai/nn/linear_and_affine.py
@@ -959,7 +959,6 @@ class AbstractGeneralConv(layer_base.Layer):
   kernel: parameters.ParameterLike[NamedArray]
   strides: Sequence[int]
   padding: str | Sequence[tuple[int, int]]
-  kernel_dilation: Sequence[int]
 
   spatial_axis_names: tuple[str, ...] = dataclasses.field(
       metadata={"pytree_node": False}
@@ -971,6 +970,160 @@ class AbstractGeneralConv(layer_base.Layer):
       metadata={"pytree_node": False}
   )
 
+  kernel_dilation: Sequence[int]
+  inputs_dilation: Sequence[int]
+
+  def __call__(self, in_array: NamedArray, **_side_inputs) -> NamedArray:
+    """Runs the Convolution operator."""
+    in_struct = self._input_structure()
+
+    # pytype: disable=attribute-error
+    if isinstance(
+        self.kernel,
+        Parameter | ParameterValue,
+    ) and self.kernel.label.endswith(".kernel"):
+      error_prefix = f"({self.kernel.label[:-7]}) "
+    else:
+      error_prefix = ""
+    # pytype: enable=attribute-error
+
+    dimvars = shapecheck.check_structure(
+        in_array, in_struct, error_prefix=error_prefix
+    )
+
+    lhs, rhs = _prepare_for_conv(
+        in_array,
+        self.kernel.value,
+        self.spatial_axis_names,
+        self.in_axis_names,
+        self.out_axis_names,
+    )
+
+    if self._is_transposed():
+      # Perform actual transposed convolution
+      result = named_axes.nmap(
+          lambda lhs, rhs: jax.lax.conv_transpose(
+              lhs=lhs[None, ...],
+              rhs=rhs,
+              strides=self.strides,
+              padding=self.padding,
+              rhs_dilation=self.kernel_dilation,
+              dimension_numbers=_get_dimension_numbers(
+                  ndim=len(self.spatial_axis_names)
+              ),
+          )[0]
+      )(lhs, rhs)
+    else:
+      # Perform actual convolution
+      result = named_axes.nmap(
+          lambda lhs, rhs: jax.lax.conv_general_dilated(
+              lhs=lhs[None, ...],
+              rhs=rhs,
+              window_strides=self.strides,
+              padding=self.padding,
+              lhs_dilation=self.inputs_dilation,
+              rhs_dilation=self.kernel_dilation,
+              dimension_numbers=_get_dimension_numbers(
+                  ndim=len(self.spatial_axis_names)
+              ),
+          )[0]
+      )(lhs, rhs)
+
+    result = _get_named_axis_back_after_conv(
+        result,
+        self.spatial_axis_names,
+        self.out_axis_names,
+        [self.output_axes[name] for name in self.out_axis_names],
+    )
+
+    out_struct = self._output_structure()
+    shapecheck.check_structure(
+        result, out_struct, known_vars=dimvars, error_prefix=error_prefix
+    )
+    return result
+
+  @classmethod
+  def from_config(
+      cls,
+      name: str,
+      init_base_rng: jax.Array | None,
+      input_axes: dict[str, int],
+      output_axes: dict[str, int],
+      convolution_spatial_axes: dict[str, int],
+      strides: int | Sequence[int] = 1,
+      padding: str | Sequence[tuple[int, int]] = "SAME",
+      inputs_dilation: int | Sequence[int] = 1,
+      kernel_dilation: int | Sequence[int] = 1,
+      parallel_axes: dict[str, int] | None = None,
+      parallel_broadcast_axes: dict[str, int] | None = None,
+      initializer: LinearOperatorWeightInitializer = xavier_uniform_initializer,
+      dtype: jax.typing.DTypeLike = jnp.float32,
+      rename_outputs_if_necessary: bool = True,
+  ) -> Conv | ConvInPlace:
+    """Constructs an ``AbstractGeneralConv`` layer from a configuration.
+
+    This can be used when building a new convolution or transposed convolution
+    operator at the start of training. For more details see Conv or
+    ConvTranspose.
+    """
+
+    spatial_dim_count = len(convolution_spatial_axes)
+
+    strides = _maybe_broadcast(strides, spatial_dim_count)
+    inputs_dilation = _maybe_broadcast(inputs_dilation, spatial_dim_count)
+    kernel_dilation = _maybe_broadcast(kernel_dilation, spatial_dim_count)
+
+    if parallel_axes is None:
+      parallel_axes = {}
+    if parallel_broadcast_axes is None:
+      parallel_broadcast_axes = {}
+
+    output_axes_after_rename, primed_names, original_names = (
+        _maybe_rename_output_axes(
+            input_axes,
+            output_axes,
+            parallel_axes,
+            parallel_broadcast_axes,
+            rename_outputs_if_necessary,
+        )
+    )
+
+    core_layer = cls(
+        kernel=parameters.make_parameter(
+            f"{name}.kernel",
+            init_base_rng,
+            initializer,
+            input_axes=input_axes,
+            output_axes=output_axes_after_rename,
+            parallel_axes={**parallel_axes, **parallel_broadcast_axes},
+            convolution_spatial_axes=convolution_spatial_axes,
+            dtype=dtype,
+        ),
+        strides=strides,
+        padding=padding,
+        inputs_dilation=inputs_dilation,
+        kernel_dilation=kernel_dilation,
+        spatial_axis_names=tuple(convolution_spatial_axes.keys()),
+        in_axis_names=tuple(input_axes.keys()),
+        out_axis_names=(
+            tuple(output_axes_after_rename.keys())
+            + tuple(parallel_broadcast_axes.keys())
+        ),
+    )
+
+    # If the axis names overlap, wrap the layer to rename them back.
+    if primed_names is not None and original_names is not None:
+      return ConvInPlace(
+          sublayers=[
+              core_layer,
+              RenameAxes(old=tuple(primed_names), new=tuple(original_names)),
+          ],
+      )
+    return core_layer
+
+  def _is_transposed(self) -> bool:
+    ...
+
   def _input_structure(self):
     known_in_axes = {
         name: size
@@ -1041,7 +1194,8 @@ class Conv(AbstractGeneralConv):
   """A general convolution operator, for named arrays.
 
   Applies an arbitrary contraction to the input `NamedArray` and a weight
-  parameter. This can be used to express an arbitrary linear convolution operator.
+  parameter. This can be used to express an arbitrary linear convolution
+  operator.
 
   Attributes:
     kernel: The named array holding the kernel for the convolution operator.
@@ -1060,7 +1214,6 @@ class Conv(AbstractGeneralConv):
   kernel: parameters.ParameterLike[NamedArray]
   strides: Sequence[int]
   padding: str | Sequence[tuple[int, int]]
-  kernel_dilation: Sequence[int]
 
   spatial_axis_names: tuple[str, ...] = dataclasses.field(
       metadata={"pytree_node": False}
@@ -1071,62 +1224,9 @@ class Conv(AbstractGeneralConv):
   out_axis_names: tuple[str, ...] = dataclasses.field(
       metadata={"pytree_node": False}
   )
 
-  inputs_dilation: Sequence[int]
-
-  def __call__(self, in_array: NamedArray, **_side_inputs) -> NamedArray:
-    """Runs the Convolution operator."""
-    in_struct = self._input_structure()
-
-    # pytype: disable=attribute-error
-    if isinstance(
-        self.kernel,
-        Parameter | ParameterValue,
-    ) and self.kernel.label.endswith(".kernel"):
-      error_prefix = f"({self.kernel.label[: 7]}) "
-    else:
-      error_prefix = ""
-    # pytype: enable=attribute-error
-    dimvars = shapecheck.check_structure(
-        in_array, in_struct, error_prefix=error_prefix
-    )
-
-    print(in_array)
-    lhs, rhs = _prepare_for_conv(
-        in_array,
-        self.kernel.value,
-        self.spatial_axis_names,
-        self.in_axis_names,
-        self.out_axis_names,
-    )
-
-    # Perform actual convolution
-    result = named_axes.nmap(
-        lambda lhs, rhs: jax.lax.conv_general_dilated(
-            lhs=lhs[None, ...],
-            rhs=rhs,
-            window_strides=self.strides,
-            padding=self.padding,
-            lhs_dilation=self.inputs_dilation,
-            rhs_dilation=self.kernel_dilation,
-            dimension_numbers=_get_dimension_numbers(
-                ndim=len(self.spatial_axis_names)
-            ),
-        )[0]
-    )(lhs, rhs)
-
-    result = _get_named_axis_back_after_conv(
-        result,
-        self.spatial_axis_names,
-        self.out_axis_names,
-        [self.output_axes[name] for name in self.out_axis_names],
-    )
-
-    out_struct = self._output_structure()
-    shapecheck.check_structure(
-        result, out_struct, known_vars=dimvars, error_prefix=error_prefix
-    )
-    return result
+  kernel_dilation: Sequence[int]
+  inputs_dilation: Sequence[int]
 
   @classmethod
   def from_config(
@@ -1203,59 +1303,26 @@ def from_config(
       `ConvInPlace` layer if ``rename_outputs_if_necessary`` is True and
       ``input_axes`` overlaps with ``output_axes``.
     """
-    spatial_dim_count = len(convolution_spatial_axes)
-
-    strides = _maybe_broadcast(strides, spatial_dim_count)
-    inputs_dilation = _maybe_broadcast(inputs_dilation, spatial_dim_count)
-    kernel_dilation = _maybe_broadcast(kernel_dilation, spatial_dim_count)
-
-    if parallel_axes is None:
-      parallel_axes = {}
-    if parallel_broadcast_axes is None:
-      parallel_broadcast_axes = {}
-
-    output_axes_after_rename, primed_names, original_names = (
-        _maybe_rename_output_axes(
-            input_axes,
-            output_axes,
-            parallel_axes,
-            parallel_broadcast_axes,
-            rename_outputs_if_necessary,
-        )
-    )
-
-    core_layer = cls(
-        kernel=parameters.make_parameter(
-            f"{name}.kernel",
-            init_base_rng,
-            initializer,
-            input_axes=input_axes,
-            output_axes=output_axes_after_rename,
-            parallel_axes={**parallel_axes, **parallel_broadcast_axes},
-            convolution_spatial_axes=convolution_spatial_axes,
-            dtype=dtype,
-        ),
+    return super().from_config(
+        name=name,
+        init_base_rng=init_base_rng,
+        input_axes=input_axes,
+        output_axes=output_axes,
+        convolution_spatial_axes=convolution_spatial_axes,
         strides=strides,
         padding=padding,
         inputs_dilation=inputs_dilation,
         kernel_dilation=kernel_dilation,
-        spatial_axis_names=tuple(convolution_spatial_axes.keys()),
-        in_axis_names=tuple(input_axes.keys()),
-        out_axis_names=(
-            tuple(output_axes_after_rename.keys())
-            + tuple(parallel_broadcast_axes.keys())
-        ),
+        parallel_axes=parallel_axes,
+        parallel_broadcast_axes=parallel_broadcast_axes,
+        initializer=initializer,
+        dtype=dtype,
+        rename_outputs_if_necessary=rename_outputs_if_necessary,
     )
 
-    # if name overlap wrap layer
-    if primed_names is not None and original_names is not None:
-      return ConvInPlace(
-          sublayers=[
-              core_layer,
-              RenameAxes(old=tuple(primed_names), new=tuple(original_names)),
-          ],
-      )
-    return core_layer
+  def _is_transposed(self):
+    return False
 
   def treescope_color(self) -> str:
     return "#79eb75"
@@ -1285,7 +1352,6 @@ class ConvTranspose(AbstractGeneralConv):
   kernel: parameters.ParameterLike[NamedArray]
   strides: Sequence[int]
   padding: str | Sequence[tuple[int, int]]
-  kernel_dilation: Sequence[int]
 
   spatial_axis_names: tuple[str, ...] = dataclasses.field(
       metadata={"pytree_node": False}
@@ -1297,58 +1363,8 @@ class ConvTranspose(AbstractGeneralConv):
       metadata={"pytree_node": False}
   )
 
-  def __call__(self, in_array: NamedArray, **_side_inputs) -> NamedArray:
-    """Runs the Convolution operator."""
-    in_struct = self._input_structure()
-
-    # pytype: disable=attribute-error
-    if isinstance(
-        self.kernel,
-        Parameter | ParameterValue,
-    ) and self.kernel.label.endswith(".kernel"):
-      error_prefix = f"({self.kernel.label[: 7]}) "
-    else:
-      error_prefix = ""
-    # pytype: enable=attribute-error
-
-    dimvars = shapecheck.check_structure(
-        in_array, in_struct, error_prefix=error_prefix
-    )
-
-    lhs, rhs = _prepare_for_conv(
-        in_array,
-        self.kernel.value,
-        self.spatial_axis_names,
-        self.in_axis_names,
-        self.out_axis_names,
-    )
-
-    # Perform actual transposed convolution
-    result = named_axes.nmap(
-        lambda lhs, rhs: jax.lax.conv_transpose(
-            lhs=lhs[None, ...],
-            rhs=rhs,
-            strides=self.strides,
-            padding=self.padding,
-            rhs_dilation=self.kernel_dilation,
-            dimension_numbers=_get_dimension_numbers(
-                ndim=len(self.spatial_axis_names)
-            ),
-        )[0]
-    )(lhs, rhs)
-
-    result = _get_named_axis_back_after_conv(
-        result,
-        self.spatial_axis_names,
-        self.out_axis_names,
-        [self.output_axes[name] for name in self.out_axis_names],
-    )
-
-    out_struct = self._output_structure()
-    shapecheck.check_structure(
-        result, out_struct, known_vars=dimvars, error_prefix=error_prefix
-    )
-    return result
+  kernel_dilation: Sequence[int]
+  inputs_dilation: Sequence[int]
 
   @classmethod
   def from_config(
@@ -1422,57 +1438,25 @@ def from_config(
       `ConvTransposeInPlace` layer if ``rename_outputs_if_necessary`` is True
       and ``input_axes`` overlaps with ``output_axes``.
""" - spatial_dim_count = len(convolution_spatial_axes) - - strides = _maybe_broadcast(strides, spatial_dim_count) - kernel_dilation = _maybe_broadcast(kernel_dilation, spatial_dim_count) - - if parallel_axes is None: - parallel_axes = {} - if parallel_broadcast_axes is None: - parallel_broadcast_axes = {} - - output_axes_after_rename, primed_names, original_names = ( - _maybe_rename_output_axes( - input_axes, - output_axes, - parallel_axes, - parallel_broadcast_axes, - rename_outputs_if_necessary, - ) - ) - - core_layer = cls( - kernel=parameters.make_parameter( - f"{name}.kernel", - init_base_rng, - initializer, - input_axes=input_axes, - output_axes=output_axes_after_rename, - parallel_axes={**parallel_axes, **parallel_broadcast_axes}, - convolution_spatial_axes=convolution_spatial_axes, - dtype=dtype, - ), + return super().from_config( + name=name, + init_base_rng=init_base_rng, + input_axes=input_axes, + output_axes=output_axes, + convolution_spatial_axes=convolution_spatial_axes, strides=strides, padding=padding, kernel_dilation=kernel_dilation, - spatial_axis_names=tuple(convolution_spatial_axes.keys()), - in_axis_names=tuple(input_axes.keys()), - out_axis_names=( - tuple(output_axes_after_rename.keys()) - + tuple(parallel_broadcast_axes.keys()) - ), + inputs_dilation=[], # not used for transposed convolutions + parallel_axes=parallel_axes, + parallel_broadcast_axes=parallel_broadcast_axes, + initializer=initializer, + dtype=dtype, + rename_outputs_if_necessary=rename_outputs_if_necessary, ) - # if name overlap wrap layer - if primed_names is not None and original_names is not None: - return ConvTransposeInPlace( - sublayers=[ - core_layer, - RenameAxes(old=tuple(primed_names), new=tuple(original_names)), - ], - ) - return core_layer + def _is_transposed(self): + return True def treescope_color(self) -> str: return "#c7eb75" From fa1504086351c82bdbae6b2444d85164b64eb743 Mon Sep 17 00:00:00 2001 From: AntoinePlumerault Date: Wed, 18 Jun 2025 18:04:25 +0000 Subject: [PATCH 07/18] use TmpPosAxisMarker helper --- penzai/nn/linear_and_affine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/penzai/nn/linear_and_affine.py b/penzai/nn/linear_and_affine.py index a794405..1c40b04 100644 --- a/penzai/nn/linear_and_affine.py +++ b/penzai/nn/linear_and_affine.py @@ -871,8 +871,8 @@ def _prepare_for_conv( lhs = inputs rhs = kernel - in_axis_name = "in_axis-" + "-".join(in_axis_names) - out_axis_name = "out_axis-" + "-".join(out_axis_names) + in_axis_name = named_axes.TmpPosAxisMarker() + out_axis_name = named_axes.TmpPosAxisMarker() # merge in axes into one in channel axis for the inputs and the kernel lhs = lhs.untag(*in_axis_names).flatten().tag(in_axis_name) From 01a9aa111d6d51ebf919fe65a48ef68d505011ef Mon Sep 17 00:00:00 2001 From: AntoinePlumerault Date: Wed, 18 Jun 2025 18:12:36 +0000 Subject: [PATCH 08/18] added precisions in _prepare_for_conv docstring --- penzai/nn/linear_and_affine.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/penzai/nn/linear_and_affine.py b/penzai/nn/linear_and_affine.py index 1c40b04..67deb50 100644 --- a/penzai/nn/linear_and_affine.py +++ b/penzai/nn/linear_and_affine.py @@ -862,10 +862,12 @@ def _prepare_for_conv( out_axis_names: Names of the output axes that will be produced by the convolution. 
   Returns:
-    A tuple of two named arrays, the first one is the input with the in axes
-    merged into a single in channel axis, and the second one is the kernel with
-    the in axes merged into a single in channel axis and the out axes merged
-    into a single out channel axis.
+    A tuple of two named arrays. The first one is the conv input with the in
+    axes merged into a single in channel axis. Its positional axis layout is
+    [spatial_axes..., channel_axis]. The second one is the convolution kernel
+    with the in axes merged into a single in channel axis and the out axes
+    merged into a single out channel axis. Its positional axis layout is
+    [spatial_axes..., in_channel_axis, out_channel_axis].
   """
 
   lhs = inputs
   rhs = kernel

From c1c3135abc6ce3987539c1f76305703f96e981ce Mon Sep 17 00:00:00 2001
From: AntoinePlumerault
Date: Wed, 18 Jun 2025 18:20:54 +0000
Subject: [PATCH 09/18] updated the doc of _get_named_axis_back_after_conv

---
 penzai/nn/linear_and_affine.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/penzai/nn/linear_and_affine.py b/penzai/nn/linear_and_affine.py
index 67deb50..53ab72a 100644
--- a/penzai/nn/linear_and_affine.py
+++ b/penzai/nn/linear_and_affine.py
@@ -899,10 +899,12 @@ def _get_named_axis_back_after_conv(
 
   Restores the spatial axes and output axes to the result of the jax convolution
   operator. The spatial axes are tagged back, and the output axes are reshaped
-  to the original shape and tagged back.
-  This is necessary to restore the original shape of the output after the
-  convolution operator has been applied, since we flattened the output axes into
-  a single axis before applying the convolution.
+  to the original shape and tagged back. It assumes that the result has a
+  positional axis layout of [spatial_axes..., out_axis], with out_axis of
+  size equal to the product of the dimensions in out_axis_shape. This is
+  necessary to restore the desired shape of the output after the convolution
+  operator has been applied, since the convolution operates on positional
+  spatial axes and only outputs a single out_axis.
 
   Args:
     result: The result of the jax convolution operator.

From f0cae27b8c68e0ed35d00bbb577e3853cb9ea31b Mon Sep 17 00:00:00 2001
From: AntoinePlumerault
Date: Wed, 18 Jun 2025 18:28:36 +0000
Subject: [PATCH 10/18] added checks in _maybe_broadcast

---
 penzai/nn/linear_and_affine.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/penzai/nn/linear_and_affine.py b/penzai/nn/linear_and_affine.py
index 53ab72a..7ddadf7 100644
--- a/penzai/nn/linear_and_affine.py
+++ b/penzai/nn/linear_and_affine.py
@@ -937,7 +937,14 @@ def _maybe_broadcast(value: int | Sequence[int], count: int) -> Sequence[int]:
     A sequence of integers with the value repeated `count` times if it was an
     integer, or the original sequence if it was already a sequence.
   """
-  return [value] * count if isinstance(value, int) else value
+
+  if isinstance(value, int):
+    return [value] * count
+  else:
+    assert (
+        len(value) == count
+    ), "If value is a sequence, it must match the count."
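+
+  # For example, _maybe_broadcast(2, 3) returns [2, 2, 2], while
+  # _maybe_broadcast((1, 2), 2) is returned unchanged.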
+  return value
 
 
 def _get_dimension_numbers(ndim: int) -> jax.lax.ConvDimensionNumbers:
@@ -1451,7 +1458,7 @@ def from_config(
         strides=strides,
         padding=padding,
         kernel_dilation=kernel_dilation,
-        inputs_dilation=[],  # not used for transposed convolutions
+        inputs_dilation=1,  # not used for transposed convolutions
         parallel_axes=parallel_axes,

From d3dce39fcdc5ce881b972bf79e7dbfb3782c7133 Mon Sep 17 00:00:00 2001
From: AntoinePlumerault
Date: Wed, 18 Jun 2025 18:44:25 +0000
Subject: [PATCH 11/18] added clarification to the _get_dimension_numbers
 docstring

---
 penzai/nn/linear_and_affine.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/penzai/nn/linear_and_affine.py b/penzai/nn/linear_and_affine.py
index 7ddadf7..714983e 100644
--- a/penzai/nn/linear_and_affine.py
+++ b/penzai/nn/linear_and_affine.py
@@ -953,12 +953,16 @@ def _get_dimension_numbers(ndim: int) -> jax.lax.ConvDimensionNumbers:
     ndim: The number of spatial dimensions of the convolution operator.
   Returns:
     A `jax.lax.ConvDimensionNumbers` object that specifies the dimension numbers
-    for the convolution operator.
+    for the convolution operator. It assumes that the input and output have the
+    following positional axis layout: [B, Spatial..., C] and the kernel has the
+    following positional axis layout: [Spatial..., I, O], where B is the batch
+    axis, C is the channel axis, I is the input channel axis, and O is the
+    output channel axis. It matches the result of _prepare_for_conv.
   """
 
   return jax.lax.ConvDimensionNumbers(
       lhs_spec=(0, ndim + 1)
-      + tuple(range(1, ndim + 1)),  # BHSpatial -> BCSpatial
+      + tuple(range(1, ndim + 1)),  # BSpatialC -> BCSpatial
       rhs_spec=(ndim + 1, ndim) + tuple(range(ndim)),  # SpatialIO -> OISpatial
       out_spec=(0, ndim + 1) + tuple(range(1, ndim + 1)),  # BSpatialC -> BCSpatial

From 142538f34004e884e56405aea1ec51fa5f4cc1ae Mon Sep 17 00:00:00 2001
From: AntoinePlumerault
Date: Wed, 18 Jun 2025 18:54:21 +0000
Subject: [PATCH 12/18] fixed bug with ConvTransposeInPlace

---
 penzai/nn/linear_and_affine.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/penzai/nn/linear_and_affine.py b/penzai/nn/linear_and_affine.py
index 714983e..85fd3b5 100644
--- a/penzai/nn/linear_and_affine.py
+++ b/penzai/nn/linear_and_affine.py
@@ -1058,8 +1058,9 @@ def __call__(self, in_array: NamedArray, **_side_inputs) -> NamedArray:
     return result
 
   @classmethod
-  def from_config(
+  def _from_config(
       cls,
+      inplace_class: type[Conv | ConvTranspose],
       name: str,
       init_base_rng: jax.Array | None,
       input_axes: dict[str, int],
@@ -1074,7 +1075,7 @@ def from_config(
       initializer: LinearOperatorWeightInitializer = xavier_uniform_initializer,
       dtype: jax.typing.DTypeLike = jnp.float32,
       rename_outputs_if_necessary: bool = True,
-  ) -> Conv | ConvInPlace:
+  ) -> Conv | ConvInPlace | ConvTranspose | ConvTransposeInPlace:
     """Constructs an ``AbstractGeneralConv`` layer from a configuration.
 
     This can be used when building a new convolution or transposed convolution
@@ -1128,12 +1129,13 @@ def _from_config(
     # If the axis names overlap, wrap the layer to rename them back.
     if primed_names is not None and original_names is not None:
-      return ConvInPlace(
+      return inplace_class(
          sublayers=[
              core_layer,
              RenameAxes(old=tuple(primed_names), new=tuple(original_names)),
          ],
      )
+
     return core_layer
 
   def _is_transposed(self) -> bool:
@@ -1319,7 +1321,8 @@ def from_config(
       ``input_axes`` overlaps with ``output_axes``.
""" - return super().from_config( + return super()._from_config( + inplace_class=ConvInPlace, name=name, init_base_rng=init_base_rng, input_axes=input_axes, @@ -1453,7 +1456,8 @@ def from_config( `ConvTransposeInPlace` layer if ``rename_outputs_if_necessary`` is True and ``input_axes`` overlaps with ``output_axes``. """ - return super().from_config( + return super()._from_config( + inplace_class=ConvTransposeInPlace, name=name, init_base_rng=init_base_rng, input_axes=input_axes, From 4fefce417be8a43baba99f46801c24873083d2c6 Mon Sep 17 00:00:00 2001 From: AntoinePlumerault Date: Wed, 18 Jun 2025 18:56:40 +0000 Subject: [PATCH 13/18] added abstract method decorator to _is_transposed --- penzai/nn/linear_and_affine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/penzai/nn/linear_and_affine.py b/penzai/nn/linear_and_affine.py index 85fd3b5..37281f4 100644 --- a/penzai/nn/linear_and_affine.py +++ b/penzai/nn/linear_and_affine.py @@ -21,6 +21,7 @@ import itertools from typing import Any, Literal, Protocol, Sequence import jax +import abc import jax.numpy as jnp from penzai.core import named_axes from penzai.core import shapecheck @@ -1138,6 +1139,7 @@ def _from_config( return core_layer + @abc.abstractmethod def _is_transposed(self) -> bool: ... From 7692832417734026e52d62db792d25992efb3ba1 Mon Sep 17 00:00:00 2001 From: AntoinePlumerault Date: Sun, 22 Jun 2025 16:08:29 +0000 Subject: [PATCH 14/18] added value tests for Conv and ConvTranspose --- tests/nn/linear_and_affine_test.py | 104 ++++++++++++++++++++++++++++- 1 file changed, 102 insertions(+), 2 deletions(-) diff --git a/tests/nn/linear_and_affine_test.py b/tests/nn/linear_and_affine_test.py index f652b02..4b26b2a 100644 --- a/tests/nn/linear_and_affine_test.py +++ b/tests/nn/linear_and_affine_test.py @@ -163,7 +163,7 @@ def test_affine(self): ), ) - def test_conv(self): + def test_conv_shape(self): layer = pz.nn.Conv.from_config( name="test", init_base_rng=jax.random.key(1), @@ -191,7 +191,57 @@ def test_conv(self): ), ) - def test_conv_transpose(self): + def test_conv_value(self): + inputs = jax.random.normal( + key=jax.random.PRNGKey(42), shape=(1, 10, 15, 3 * 7) + ) + + pz_inputs = pz.nx.wrap(inputs.reshape(1, 10, 15, 3, 7)).tag( + "batch", "height", "width", "foo", "baz" + ) + + simple_layer = pz.nn.Conv.from_config( + name="test", + init_base_rng=jax.random.key(1), + input_axes={"foo": 3, "baz": 7}, + output_axes={"foo_out": 5, "baz_out": 11}, + convolution_spatial_axes={"height": 3, "width": 3}, + parallel_axes=None, + parallel_broadcast_axes=None, + rename_outputs_if_necessary=True, + ) + + pz_outputs = ( + simple_layer(pz_inputs) + .untag("batch", "height", "width", "foo_out", "baz_out") + .reshape((1, 10, 15, 5 * 11)) + .unwrap() + ) + + # build equivalent jax conv + kernel = ( + simple_layer.kernel.value.untag( + "height", + "width", + "foo", + "baz", + "foo_out", + "baz_out", + ) + .reshape(3, 3, 3 * 7, 5 * 11) + .unwrap() + ) + outputs = jax.lax.conv_general_dilated( + inputs, + kernel, + window_strides=(1, 1), + padding="SAME", + dimension_numbers=("NHWC", "HWIO", "NHWC"), + ) + + chex.assert_trees_all_equal(pz_outputs, outputs) + + def test_conv_transpose_shape(self): layer = pz.nn.ConvTranspose.from_config( name="test", init_base_rng=jax.random.key(1), @@ -219,6 +269,56 @@ def test_conv_transpose(self): ), ) + def test_conv_transpose_value(self): + inputs = jax.random.normal( + key=jax.random.PRNGKey(42), shape=(1, 10, 15, 3 * 7) + ) + + pz_inputs = pz.nx.wrap(inputs.reshape(1, 10, 15, 3, 7)).tag( + "batch", 
"height", "width", "foo", "baz" + ) + + simple_layer = pz.nn.ConvTranspose.from_config( + name="test", + init_base_rng=jax.random.key(1), + input_axes={"foo": 3, "baz": 7}, + output_axes={"foo_out": 5, "baz_out": 11}, + convolution_spatial_axes={"height": 3, "width": 3}, + parallel_axes=None, + parallel_broadcast_axes=None, + rename_outputs_if_necessary=True, + ) + + pz_outputs = ( + simple_layer(pz_inputs) + .untag("batch", "height", "width", "foo_out", "baz_out") + .reshape((1, 10, 15, 5 * 11)) + .unwrap() + ) + + # build equivalent jax conv transpose + kernel = ( + simple_layer.kernel.value.untag( + "height", + "width", + "foo", + "baz", + "foo_out", + "baz_out", + ) + .reshape(3, 3, 3 * 7, 5 * 11) + .unwrap() + ) + outputs = jax.lax.conv_transpose( + inputs, + kernel, + padding="SAME", + strides=(1, 1), + dimension_numbers=("NHWC", "HWIO", "NHWC"), + ) + + chex.assert_trees_all_equal(pz_outputs, outputs) + def test_constant_rescale(self): layer = pz.nn.ConstantRescale(3.0) result = layer(pz.nx.ones({"foo": 3})) From c620e58d30f502e38daa80cf1fa18093e34dac40 Mon Sep 17 00:00:00 2001 From: AntoinePlumerault Date: Sun, 22 Jun 2025 16:19:34 +0000 Subject: [PATCH 15/18] should fix typing errors --- penzai/nn/linear_and_affine.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/penzai/nn/linear_and_affine.py b/penzai/nn/linear_and_affine.py index 37281f4..743acfd 100644 --- a/penzai/nn/linear_and_affine.py +++ b/penzai/nn/linear_and_affine.py @@ -19,7 +19,7 @@ import dataclasses import functools import itertools -from typing import Any, Literal, Protocol, Sequence +from typing import Any, Literal, Protocol, Sequence, cast import jax import abc import jax.numpy as jnp @@ -1076,7 +1076,7 @@ def _from_config( initializer: LinearOperatorWeightInitializer = xavier_uniform_initializer, dtype: jax.typing.DTypeLike = jnp.float32, rename_outputs_if_necessary: bool = True, - ) -> Conv | ConvInPlace | ConvTranspose | ConvTransposeInPlace: + ) -> AbstractGeneralConv | ConvInPlace | ConvTransposeInPlace: """Constructs a ``AbstractGeneralConv`` layer from a configuration. This can be used when building a new convolution or transposed convolution @@ -1323,7 +1323,7 @@ def from_config( ``input_axes`` overlaps with ``output_axes``. """ - return super()._from_config( + layer = super()._from_config( inplace_class=ConvInPlace, name=name, init_base_rng=init_base_rng, @@ -1340,6 +1340,9 @@ def from_config( dtype=dtype, rename_outputs_if_necessary=rename_outputs_if_necessary, ) + if isinstance(layer, AbstractGeneralConv): + return cast(Conv, layer) + return layer def _is_transposed(self): return False @@ -1458,7 +1461,7 @@ def from_config( `ConvTransposeInPlace` layer if ``rename_outputs_if_necessary`` is True and ``input_axes`` overlaps with ``output_axes``. 
""" - return super()._from_config( + layer = super()._from_config( inplace_class=ConvTransposeInPlace, name=name, init_base_rng=init_base_rng, @@ -1475,6 +1478,9 @@ def from_config( dtype=dtype, rename_outputs_if_necessary=rename_outputs_if_necessary, ) + if isinstance(layer, AbstractGeneralConv): + return cast(ConvTranspose, layer) + return layer def _is_transposed(self): return True From bb42f1e5b06d0fdb6ea37bab622eee3d20c10366 Mon Sep 17 00:00:00 2001 From: AntoinePlumerault Date: Sun, 22 Jun 2025 20:56:41 +0000 Subject: [PATCH 16/18] fixed error in typing --- penzai/nn/linear_and_affine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/penzai/nn/linear_and_affine.py b/penzai/nn/linear_and_affine.py index 743acfd..8e7fa65 100644 --- a/penzai/nn/linear_and_affine.py +++ b/penzai/nn/linear_and_affine.py @@ -1061,7 +1061,7 @@ def __call__(self, in_array: NamedArray, **_side_inputs) -> NamedArray: @classmethod def _from_config( cls, - inplace_class: type[Conv | ConvTranspose], + inplace_class: type[ConvInPlace | ConvTransposeInPlace], name: str, init_base_rng: jax.Array | None, input_axes: dict[str, int], @@ -1480,6 +1480,7 @@ def from_config( ) if isinstance(layer, AbstractGeneralConv): return cast(ConvTranspose, layer) + return layer def _is_transposed(self): From c31b4aa31348b9db4f330349ae67ba436344ed16 Mon Sep 17 00:00:00 2001 From: AntoinePlumerault Date: Tue, 24 Jun 2025 18:13:03 +0000 Subject: [PATCH 17/18] all typing errors should be fixed --- penzai/nn/linear_and_affine.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/penzai/nn/linear_and_affine.py b/penzai/nn/linear_and_affine.py index 8e7fa65..32b6ce2 100644 --- a/penzai/nn/linear_and_affine.py +++ b/penzai/nn/linear_and_affine.py @@ -1191,7 +1191,7 @@ def parallel_axes(self) -> dict[str, int]: return { # pytype: disable=bad-return-type name: size for name, size in self.kernel.value.named_shape.items() - if name not in self.convolution_spatial_axis_names + if name not in self.spatial_axis_names and name not in self.in_axis_names and name not in self.out_axis_names } @@ -1204,7 +1204,7 @@ def convolution_spatial_axes(self) -> dict[str, int]: return { # pytype: disable=bad-return-type name: size for name, size in self.kernel.value.named_shape.items() - if name in self.spatial_axes_names + if name in self.spatial_axis_names } @@ -1342,6 +1342,7 @@ def from_config( ) if isinstance(layer, AbstractGeneralConv): return cast(Conv, layer) + assert isinstance(layer, ConvInPlace) return layer def _is_transposed(self): @@ -1481,6 +1482,7 @@ def from_config( if isinstance(layer, AbstractGeneralConv): return cast(ConvTranspose, layer) + assert isinstance(layer, ConvTransposeInPlace) return layer def _is_transposed(self): From 4d205d8dfa9186e2e014fb126f19955f757a8f86 Mon Sep 17 00:00:00 2001 From: AntoinePlumerault Date: Tue, 1 Jul 2025 17:44:15 +0000 Subject: [PATCH 18/18] added test, fixed shape issue & jit_wrapper_issue --- penzai/nn/linear_and_affine.py | 49 +++++++++++----- tests/nn/linear_and_affine_test.py | 93 ++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 15 deletions(-) diff --git a/penzai/nn/linear_and_affine.py b/penzai/nn/linear_and_affine.py index 32b6ce2..116105a 100644 --- a/penzai/nn/linear_and_affine.py +++ b/penzai/nn/linear_and_affine.py @@ -973,8 +973,10 @@ def _get_dimension_numbers(ndim) -> jax.lax.ConvDimensionNumbers: @struct.pytree_dataclass class AbstractGeneralConv(layer_base.Layer): kernel: parameters.ParameterLike[NamedArray] - strides: 
-  strides: Sequence[int]
-  padding: str | Sequence[tuple[int, int]]
+  strides: Sequence[int] = dataclasses.field(metadata={"pytree_node": False})
+  padding: str | Sequence[tuple[int, int]] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
 
   spatial_axis_names: tuple[str, ...] = dataclasses.field(
       metadata={"pytree_node": False}
@@ -986,8 +988,12 @@ class AbstractGeneralConv(layer_base.Layer):
       metadata={"pytree_node": False}
   )
 
-  kernel_dilation: Sequence[int]
-  inputs_dilation: Sequence[int]
+  kernel_dilation: Sequence[int] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+  inputs_dilation: Sequence[int] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
 
   def __call__(self, in_array: NamedArray, **_side_inputs) -> NamedArray:
     """Runs the Convolution operator."""
@@ -1151,7 +1157,7 @@ def _input_structure(self):
       and name not in self.spatial_axis_names
     }
     return shapecheck.ArraySpec(
-        named_shape={**shapecheck.var("B"), **known_in_axes},
+        named_shape={**shapecheck.var("In"), **known_in_axes},
         dtype=jnp.floating,
     )
@@ -1162,8 +1168,8 @@ def _output_structure(self):
       if name not in self.in_axis_names
       and name not in self.spatial_axis_names
     }
     return shapecheck.ArraySpec(
-        named_shape={**shapecheck.var("B"), **known_out_axes},
+        named_shape={**shapecheck.var("Out"), **known_out_axes},
         dtype=jnp.floating,
     )
@@ -1231,8 +1238,10 @@ class Conv(AbstractGeneralConv):
   """
 
   kernel: parameters.ParameterLike[NamedArray]
-  strides: Sequence[int]
-  padding: str | Sequence[tuple[int, int]]
+  strides: Sequence[int] = dataclasses.field(metadata={"pytree_node": False})
+  padding: str | Sequence[tuple[int, int]] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
 
   spatial_axis_names: tuple[str, ...] = dataclasses.field(
       metadata={"pytree_node": False}
@@ -1244,8 +1253,12 @@ class Conv(AbstractGeneralConv):
       metadata={"pytree_node": False}
   )
 
-  kernel_dilation: Sequence[int]
-  inputs_dilation: Sequence[int]
+  kernel_dilation: Sequence[int] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+  inputs_dilation: Sequence[int] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
 
   @classmethod
   def from_config(
@@ -1374,8 +1387,10 @@ class ConvTranspose(AbstractGeneralConv):
   """
 
   kernel: parameters.ParameterLike[NamedArray]
-  strides: Sequence[int]
-  padding: str | Sequence[tuple[int, int]]
+  strides: Sequence[int] = dataclasses.field(metadata={"pytree_node": False})
+  padding: str | Sequence[tuple[int, int]] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
 
   spatial_axis_names: tuple[str, ...] = dataclasses.field(
       metadata={"pytree_node": False}
@@ -1387,8 +1402,12 @@ class ConvTranspose(AbstractGeneralConv):
       metadata={"pytree_node": False}
   )
 
-  kernel_dilation: Sequence[int]
-  inputs_dilation: Sequence[int]
+  kernel_dilation: Sequence[int] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+  inputs_dilation: Sequence[int] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
 
   @classmethod
   def from_config(
diff --git a/tests/nn/linear_and_affine_test.py b/tests/nn/linear_and_affine_test.py
index 4b26b2a..cc1487e 100644
--- a/tests/nn/linear_and_affine_test.py
+++ b/tests/nn/linear_and_affine_test.py
@@ -18,6 +18,7 @@
 import chex
 import jax
 from penzai import pz
+from penzai.toolshed import jit_wrapper
 
 
 class LinearAndAffineTest(absltest.TestCase):
@@ -191,6 +192,52 @@ def test_conv_shape(self):
         ),
     )
 
+  def test_strided_conv_shape(self):
+    layer = pz.nn.Conv.from_config(
+        name="test",
+        init_base_rng=jax.random.key(1),
+        input_axes={"foo": 3},
+        output_axes={"foo": 5},
+        convolution_spatial_axes={"height": 3, "width": 3},
+        strides=(2, 2),
+        parallel_axes={"baz": 7},
+        parallel_broadcast_axes={"qux": 11},
+        rename_outputs_if_necessary=True,
+    )
+    result = layer(
+        pz.nx.ones({"batch": 1, "height": 10, "width": 16, "foo": 3, "baz": 7}),
+    )
+    pz.chk.check_structure(
+        result,
+        pz.chk.ArraySpec(
+            named_shape={
+                "batch": 1,
+                "height": 5,
+                "width": 8,
+                "foo": 5,
+                "baz": 7,
+                "qux": 11,
+            }
+        ),
+    )
+
+  def test_conv_jit_wrapper(self):
+    layer = jit_wrapper.Jitted(
+        pz.nn.Conv.from_config(
+            name="test",
+            init_base_rng=jax.random.key(1),
+            input_axes={"foo": 3},
+            output_axes={"foo": 5},
+            convolution_spatial_axes={"height": 3, "width": 3},
+            parallel_axes={"baz": 7},
+            parallel_broadcast_axes={"qux": 11},
+            rename_outputs_if_necessary=True,
+        )
+    )
+    layer(
+        pz.nx.ones({"batch": 1, "height": 10, "width": 15, "foo": 3, "baz": 7}),
+    )
+
   def test_conv_value(self):
     inputs = jax.random.normal(
         key=jax.random.PRNGKey(42), shape=(1, 10, 15, 3 * 7)
@@ -269,6 +316,35 @@ def test_conv_transpose_shape(self):
         ),
     )
 
+  def test_strided_conv_transpose_shape(self):
+    layer = pz.nn.ConvTranspose.from_config(
+        name="test",
+        init_base_rng=jax.random.key(1),
+        input_axes={"foo": 3},
+        output_axes={"foo": 5},
+        convolution_spatial_axes={"height": 3, "width": 3},
+        strides=(2, 2),
+        parallel_axes={"baz": 7},
+        parallel_broadcast_axes={"qux": 11},
+        rename_outputs_if_necessary=True,
+    )
+    result = layer(
+        pz.nx.ones({"batch": 1, "height": 10, "width": 16, "foo": 3, "baz": 7}),
+    )
+    pz.chk.check_structure(
+        result,
+        pz.chk.ArraySpec(
+            named_shape={
+                "batch": 1,
+                "height": 20,
+                "width": 32,
+                "foo": 5,
+                "baz": 7,
+                "qux": 11,
+            }
+        ),
+    )
+
   def test_conv_transpose_value(self):
@@ -319,6 +395,23 @@ def test_conv_transpose_value(self):
 
     chex.assert_trees_all_equal(pz_outputs, outputs)
 
+  def test_conv_transpose_jit_wrapper(self):
+    layer = jit_wrapper.Jitted(
+        pz.nn.ConvTranspose.from_config(
+            name="test",
+            init_base_rng=jax.random.key(1),
+            input_axes={"foo": 3},
+            output_axes={"foo": 5},
+            convolution_spatial_axes={"height": 3, "width": 3},
+            parallel_axes={"baz": 7},
+            parallel_broadcast_axes={"qux": 11},
+            rename_outputs_if_necessary=True,
+        )
+    )
+    layer(
+        pz.nx.ones({"batch": 1, "height": 10, "width": 15, "foo": 3, "baz": 7}),
+    )
+
   def test_constant_rescale(self):
     layer = pz.nn.ConstantRescale(3.0)
     result = layer(pz.nx.ones({"foo": 3}))
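
For quick reference, a minimal usage sketch of the layers added by this
series, distilled from the tests above. The axis names ("foo", "height",
"width") and the layer name "demo" are illustrative placeholders, and
padding="SAME" and strides=1 are simply the from_config defaults:

    import jax
    from penzai import pz

    # A 2D convolution over the named spatial axes "height" and "width",
    # contracting a size-3 "foo" axis into a new size-5 "foo_out" axis.
    layer = pz.nn.Conv.from_config(
        name="demo",
        init_base_rng=jax.random.key(0),
        input_axes={"foo": 3},
        output_axes={"foo_out": 5},
        convolution_spatial_axes={"height": 3, "width": 3},
    )

    x = pz.nx.ones({"batch": 1, "height": 10, "width": 15, "foo": 3})
    y = layer(x)
    # With the default padding="SAME" and strides=1, the spatial axes keep
    # their sizes, so y has named shape
    # {"batch": 1, "height": 10, "width": 15, "foo_out": 5}.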