103 changes: 100 additions & 3 deletions pytensor/tensor/extra_ops.py
@@ -2267,17 +2267,113 @@ def pack(
Parameters
----------
tensors: TensorVariable
Tensors to be packed. Tensors can have varying shapes and dimensions, but must have the same size along each
of the dimensions specified in the `axes` parameter.
axes: int or sequence of int, optional
Axes to be preserved. All other axes will be raveled (packed), and the output is the result of concatenating
on the new raveled dimension. If None, all axes will be raveled and joined. Axes can be either positive or
negative, but must be strictly increasing in both the positive and negative parts of the list. Negative axes
must come after positive axes.

Member: IIUC this is changing a bit the usual meaning of axes. Axes=None is usually the same as an exhaustive list of all axes, a.sum(None) == a.sum(tuple(range(a.ndim))), but here it is flipped?

Member: axis=[] would be the same as None in this case, nothing is preserved.

Returns
-------
flat_tensor: TensorVariable
A new symbolic variable representing the packed tensor: each input is raveled along its non-preserved axes,
and the results are concatenated along the raveled dimension. If `axes` is None, this is a 1d vector containing
all inputs.
packed_shapes: list of tuples of TensorVariable
A list of tuples, where each tuple contains the symbolic shape of one of the original input tensors.

Notes
-----
This function is a helper for joining tensors of varying shapes into a single tensor. This is done by choosing a
list of axes to preserve and raveling all other axes. The resulting tensors are then concatenated along the
raveled axis. The original shapes of the tensors are also returned, so that they can be unpacked later.
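
For intuition, a round trip might look like the following sketch. This is illustrative only: no unpacking helper is
shown here, and the splitting code below is an assumption built from the documented return values. With `axes=None`
every input is fully raveled, so the packed vector can be cut back apart using the returned shapes.

.. code-block:: python

import pytensor.tensor as pt

a = pt.tensor("a", shape=(2, 3))
b = pt.tensor("b", shape=(4,))
flat, shapes = pack(a, b, axes=None)  # one long 1d vector

# Hypothetical manual unpacking: split by the number of elements in each
# input, then restore the original shapes.
sizes = [pt.prod(pt.stack(list(shape))) for shape in shapes]
a_flat, b_flat = pt.split(flat, splits_size=sizes, n_splits=2)
a_back = a_flat.reshape(shapes[0])
b_back = b_flat.reshape(shapes[1])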

The `axes` parameter determines which dimensions are *not* raveled. The requested axes must exist in all input
tensors, but there are otherwise no restrictions on the shapes or dimensions of the input tensors. For example, if
`axes=[0]`, then the first dimension of each tensor is preserved, and all other dimensions are raveled:

.. code-block:: python

import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 3, 4))
y = pt.tensor("y", shape=(2, 5))
packed_output, shapes = pack(x, y, axes=0)
# packed_output will have shape (2, 3 * 4 + 5) = (2, 17)
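
To make the return values concrete, the example can be evaluated numerically. The check below is a sketch added for
illustration (it is not part of the original docstring); `.eval` substitutes concrete arrays for the symbolic inputs,
and the expected shape follows from the static shapes above.

.. code-block:: python

import numpy as np

test_values = {x: np.zeros((2, 3, 4), dtype=x.dtype), y: np.zeros((2, 5), dtype=y.dtype)}
out_val = packed_output.eval(test_values)
assert out_val.shape == (2, 17)
# `shapes` holds the symbolic shapes of the inputs, corresponding to
# (2, 3, 4) and (2, 5) here.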

Since axes = 0, the first dimension of both `x` and `y` is preserved. This first example is equivalent to a simple
reshape and concat operation:

.. code-block:: python

x_reshaped = x.reshape(2, -1) # shape (2, 12)
y_reshaped = y.reshape(2, -1) # shape (2, 5)
packed_output = pt.concatenate(
[x_reshaped, y_reshaped], axis=1
) # shape (2, 17)
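
A quick numerical check of this equivalence is sketched below. It assumes, as claimed above, that the two
computations produce the same values (not only the same shape); it is not part of the original docstring.

.. code-block:: python

import numpy as np

x_val = np.arange(24, dtype=x.dtype).reshape(2, 3, 4)
y_val = np.arange(10, dtype=y.dtype).reshape(2, 5)

via_pack, _ = pack(x, y, axes=0)
via_reshape = pt.concatenate([x.reshape((2, -1)), y.reshape((2, -1))], axis=1)

np.testing.assert_allclose(
    via_pack.eval({x: x_val, y: y_val}),
    via_reshape.eval({x: x_val, y: y_val}),
)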

`axes` can also be negative, in which case the axes are counted from the end of the tensor shape. For example,
if `axes=[-1]`, then the last dimension of each tensor is preserved, and all other dimensions are raveled:
Member: The axes specification seems troublesome for vectorization. If the axes were those to ravel, then vectorization of this input would be very much like axis for other Ops with axes: just shift them to the right by the number of new batch dims.

Member: Nvm, it's not too bad, we need to add the new batch axis to whatever was defined originally. I'm still a bit worried about negative axes at the Op level.


.. code-block:: python

import pytensor.tensor as pt

x = pt.tensor("x", shape=(3, 4, 7))
y = pt.tensor("y", shape=(6, 2, 1, 7))
packed_output, shapes = pack(x, y, axes=-1)
# packed_output will have shape (3 * 4 + 6 * 2 * 1, 7) = (24, 7)

The most important restriction of `axes` is that there can be at most one "hole" in the axes list. A hole is
defined as a missing axis in the sequence of axes. The easiest way to define a hole is by using both positive
and negative axes together. For example, `axes=[0, -1]` has a hole between the first and last axes. In this case,
the first and last dimensions of each tensor are preserved, and all other dimensions are raveled:

.. code-block:: python

import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 3, 2, 3, 7))
y = pt.tensor("y", shape=(2, 6, 7))
packed_output, shapes = pack(x, y, axes=[0, -1])
# packed_output will have shape (2, 3 * 2 * 3 + 6, 7) = (2, 24, 7)

Multiple explicit holes are not allowed. For example, `axes = [0, 2, -1]` is illegal because there are two holes,
one between axes 0 and 2, and another between axes 2 and -1.
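
For instance, a call like the one below would be expected to be rejected. This snippet is a sketch: the exact
exception type and message are not specified in this docstring and are assumptions here.

.. code-block:: python

import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 3, 4, 5, 7))
y = pt.tensor("y", shape=(2, 6, 7))
# Two holes (between 0 and 2, and between 2 and -1), so this should raise,
# most likely a ValueError.
packed_output, shapes = pack(x, y, axes=[0, 2, -1])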

Implicit holes are also possible when using only positive or only negative axes. `axes = [0]` already has an
implicit hole to the right of axis 0. `axes = [2, 3]` has two implicit holes, one to the left of axis 2, and another
to the right. This is illegal, since there are two holes. However, `axes = [2, 3]` can be made legal if we interpret
axis 3 as the last axis (-1), which closes the right implicit hole. The interpretation requires that at least one
input tensor has exactly 4 dimensions:

.. code-block:: python

import pytensor.tensor as pt

x = pt.tensor("x", shape=(5, 2, 3, 4))
y = pt.tensor("y", shape=(2, 3, 4))
packed_output, shapes = pack(x, y, axes=[2, 3])
# packed_output will have shape (5 * 2 + 2, 3, 4) = (12, 3, 4)

Note here that `y` has only 3 dimensions, so axis 3 is interpreted as -1, the last axis. If no input has 4
dimensions, or if any input has more than 4 dimensions, an error is raised in this case.
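
Concretely, something like the following would be expected to fail, since the 5-dimensional input cannot be
reconciled with `axes=[2, 3]` (again a sketch; the exception type is an assumption):

.. code-block:: python

import pytensor.tensor as pt

x = pt.tensor("x", shape=(5, 6, 2, 3, 4))  # 5 dimensions: more than axis 3 allows
y = pt.tensor("y", shape=(2, 3, 4))
packed_output, shapes = pack(x, y, axes=[2, 3])  # expected to raise an error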

Negative axes have similar rules regarding implicit holes. `axes = [-1]` has an implicit hole to the left of
axis -1. `axes = [-3, -2]` has two implicit holes. To arrive at a valid interpretation, we take -3 to be axis 0,
which closes the left implicit hole. This requires that at least one input tensor has exactly 3 dimensions:

.. code-block:: python

import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 3, 4))
y = pt.tensor("y", shape=(6, 4))
packed_output, shapes = pack(x, y, axes=[-3, -2])
# packed_output will have shape (2 + 6, 3, 4) = (8, 3, 4)

As in the previous example, an error would be raised if no input has exactly 3 dimensions, or if any input has
more than 3 dimensions.
"""
if not tensors:
raise ValueError("Cannot pack an empty list of tensors.")
@@ -2316,6 +2412,7 @@ def pack(
inputs=tensors,
outputs=[packed_output_tensor, *packed_output_shapes],
name="Pack{axes=" + str(axes) + "}",
Member: Why give name instead of just defining the __str__ of the Pack?

Member (Author): Because I looked at the docs for OpFromGraph and saw there was a name field I could pass.

inline=True,
)

outputs = pack_op(*tensors)
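
Regarding the name-vs-__str__ exchange above: one way the reviewer's suggestion might look is sketched below. This is
an assumption for illustration only, not what the PR does; the PR passes `name` to a plain OpFromGraph rather than
defining a subclass.

.. code-block:: python

from pytensor.compile.builders import OpFromGraph


class Pack(OpFromGraph):
    """Hypothetical OpFromGraph subclass that controls its own printed form."""

    def __init__(self, inputs, outputs, *, axes=None, **kwargs):
        self.axes = axes
        super().__init__(inputs, outputs, **kwargs)

    def __str__(self):
        return f"Pack{{axes={self.axes}}}"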