-
Notifications
You must be signed in to change notification settings - Fork 150
Implement pack/unpack helpers #1578
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
9568a83
2e22d34
79d9662
58c0286
5788333
ed60651
0b86851
20ab4e3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2267,17 +2267,113 @@ def pack( | |
| Parameters | ||
| ---------- | ||
| tensors: TensorVariable | ||
| Tensors to be packed into a single vector. | ||
| Tensors to be packed. Tensors can have varying shapes and dimensions, but must have the same size along each | ||
| of the dimensions specified in the `axes` parameter. | ||
| axes: int or sequence of int, optional | ||
| Axes to be concatenated. All other axes will be raveled (packed) and joined. If None, all axes will be raveled | ||
| and joined. | ||
| Axes to be preserved. All other axes will be raveled (packed), and the output is the result of concatenating | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IIUC this is changing a bit the usual meaning of axes. Axes=None is usually the same as an exhaustive list of all axes.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. axis=[] would be the same as None in this case, nothing is preserved. |
||
| on the new raveled dimension. If None, all axes will be raveled and joined. Axes can be either positive or | ||
| negative, but must be striclty increasing in both the positive and negative parts of the list. Negative axes | ||
| must come after positive axes. | ||
|
|
||
| Returns | ||
| ------- | ||
| flat_tensor: TensorVariable | ||
| A new symbolic variable representing the concatenated 1d vector of all tensor inputs | ||
| packed_shapes: list of tuples of TensorVariable | ||
| A list of tuples, where each tuple contains the symbolic shape of the original tensors. | ||
|
|
||
| Notes | ||
| ----- | ||
| This function is a helper for joining tensors of varying shapes into a single tenor. This is done by choosing a | ||
| list of axes to concatenate, and raveling all other axes. The resulting tensor are then concatenated along the | ||
| raveled axis. The original shapes of the tensors are also returned, so that they can be unpacked later. | ||
|
|
||
| The `axes` parameter determines which dimensions are *not* raveled. The requested axes must exist in all input | ||
| tensors, but there are otherwwise no restrictions on the shapes or dimensions of the input tensors. For example, if | ||
| `axes=[0]`, then the first dimension of each tensor is preserved, and all other dimensions are raveled: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| import pytensor.tensor as pt | ||
|
|
||
| x = pt.tensor("x", shape=(2, 3, 4)) | ||
| y = pt.tensor("y", shape=(2, 5)) | ||
| packed_output, shapes = pack(x, y, axes=0) | ||
| # packed_output will have shape (2, 3 * 4 + 5) = (2, 17) | ||
|
|
||
| Since axes = 0, the first dimension of both `x` and `y` is preserved. This first example is equivalent to a simple | ||
| reshape and concat operation: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| x_reshaped = x.reshape(2, -1) # shape (2, 12) | ||
| y_reshaped = y.reshape(2, -1) # shape (2, 5) | ||
| packed_output = pt.concatenate( | ||
| [x_reshaped, y_reshaped], axis=1 | ||
| ) # shape (2, 17) | ||
|
|
||
| `axes` can also be negative, in which case the axes are counted from the end of the tensor shape. For example, | ||
| if `axes=[-1]`, then the last dimension of each tensor is preserved, and all other dimensions are raveled: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The axes specification seems troublesome for vectorization. If axis were those to ravel, then vectorization of this input is very much like axis for other Ops with axes, just shift them to the right by the number of new batch dims.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nvm it's not too bad, we need to add the new batch axis to whatever was defined originally. I'm a bit worried about negative axis at the Op level still |
||
|
|
||
| .. code-block:: python | ||
|
|
||
| import pytensor.tensor as pt | ||
|
|
||
| x = pt.tensor("x", shape=(3, 4, 7)) | ||
| y = pt.tensor("y", shape=(6, 2, 1, 7)) | ||
| packed_output, shapes = pack(x, y, axes=-1) | ||
| # packed_output will have shape (3 * 4 + 6 * 2 * 1, 7) = (24, 7) | ||
|
|
||
| The most important restriction of `axes` is that there can be at most one "hole" in the axes list. A hole is | ||
| defined as a missing axis in the sequence of axes. The easiest way to define a hole is by using both positive | ||
| and negative axes together. For example, `axes=[0, -1]` has a hole between the first and last axes. In this case, | ||
| the first and last dimensions of each tensor are preserved, and all other dimensions are raveled: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| import pytensor.tensor as pt | ||
|
|
||
| x = pt.tensor("x", shape=(2, 3, 2, 3, 7)) | ||
| y = pt.tensor("y", shape=(2, 6, 7)) | ||
| packed_output, shapes = pack(x, y, axes=[0, -1]) | ||
| # packed_output will have shape (2, 3 * 2 * 3 + 6, 7) = (2, 24, 7) | ||
|
|
||
| Multiple explicit holes are not allowed. For example, `axes = [0, 2, -1]` is illegal because there are two holes, | ||
| one between axes 0 and 2, and another between axes 2 and -1. | ||
|
|
||
| Implicit holes are also possible when using only positive or only negative axes. `axes = [0]` already has an | ||
| implicit hole to the right of axis 0. `axes = [2, 3]` has two implicit holes, one to the left of axis 2, and another | ||
| to the right. This is illegal, since there are two holes. However, `axes = [2, 3]` can be made legal if we interpret | ||
| axis 3 as the last axis (-1), which closes the right implicit hole. The interpretation requires that at least one | ||
| input tensor has exactly 4 dimensions: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| import pytensor.tensor as pt | ||
|
|
||
| x = pt.tensor("x", shape=(5, 2, 3, 4)) | ||
| y = pt.tensor("y", shape=(2, 3, 4)) | ||
| packed_output, shapes = pack(x, y, axes=[2, 3]) | ||
| # packed_output will have shape (5 * 2 + 2, 3, 4) = (12, 3, 4) | ||
|
|
||
| Note here that `y` has only 3 dimensions, so axis 3 is interpreted as -1, the last axis. If no input has 4 | ||
| dimensions, or if any input has more than 4 dimensions, an error is raised in this case. | ||
|
|
||
| Negative axes have similar rules regarding implicit holes. `axes = [-1]` has an implicit hole to the left of | ||
| axis -1. `axes = [-3, -2]` has two implicit holes. To arrive at a valid interpretation, we take -3 to be axis 0, | ||
| which closes the left implicit hole. This requires that at least one input tensor has exactly 3 dimensions: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| import pytensor.tensor as pt | ||
|
|
||
| x = pt.tensor("x", shape=(2, 3, 4)) | ||
| y = pt.tensor("y", shape=(6, 4)) | ||
| packed_output, shapes = pack(x, y, axes=[-3, -2]) | ||
| # packed_output will have shape (2 + 6, 3, 4) = (8, 3, 4) | ||
|
|
||
| Similarly to the previous example, if no input has 3 dimensions, or if any input has more than 3 dimensions, an | ||
| error would be raised in this example. | ||
| """ | ||
| if not tensors: | ||
| raise ValueError("Cannot pack an empty list of tensors.") | ||
|
|
@@ -2316,6 +2412,7 @@ def pack( | |
| inputs=tensors, | ||
| outputs=[packed_output_tensor, *packed_output_shapes], | ||
| name="Pack{axes=" + str(axes) + "}", | ||
jessegrabowski marked this conversation as resolved.
Show resolved
Hide resolved
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why give name instead of just defining the
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because I looked at the docs for |
||
| inline=True, | ||
| ) | ||
|
|
||
| outputs = pack_op(*tensors) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.