Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions penzai/core/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,44 @@
T = typing.TypeVar("T")


def _shift_negative_indices(indices: Iterable[int], shift: int) -> tuple[int, ...]:
"""Adds `shift` to negative indices and leaves non-negative indices unchanged

Can be used to handle negative indices. For example, if we expect indices in
`r = range(6)` and we get `[0, 3, -2]` as input, we can use

```py
shift_negative_indices([0, 3, -2], len(r))
```

to get `(0, 3, 4)`. The same can be achieved in more generality with

```py
pz.select((0, 3, -2)) \
.at_instances_of(int) \
.where(lambda i: i < 0) \
.apply(lambda i: i + shift)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if I should just use this instead of defining _shift_negative_indices. I'm weary of doing so because we're in the module that defines selectors here. Also, we don't need that kind of generality at this point and _shift_negative_indices avoids some overhead

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Defining a new function here like you've done makes sense to me, the logic is probably too simple to benefit much from using selectors. (You could also just inline the logic into the call site since it's only used once.)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could also just inline the logic into the call site since it's only used once

I'd prefer to keep it separate so code blocks stay short and hence easy to interpret

```

which is why this method is private to this module

Args:
indices: The integers to shift
shift: The offset to add to negative indices. Usually, this is the largest
index + 1, i.e. the length of the range of indices

Returns:
The indices as a tuple, with negative indices increased by `shift`
"""
maybe_shifted_indices = []
for index in indices:
if index < 0:
maybe_shifted_indices.append(index + shift)
else:
maybe_shifted_indices.append(index)
return tuple(maybe_shifted_indices)


@struct.pytree_dataclass
class SelectionHole(struct.Struct):
"""A hole in a structure, taking the place of a selected subtree.
Expand Down Expand Up @@ -1356,6 +1394,8 @@ def pick_nth_selected(self, n: int | Sequence[int]) -> Selection:
else:
indices = n

indices = _shift_negative_indices(indices, len(self.selected_by_path))

with _wrap_selection_errors(self):
keep = _InProgressSelectionBoundary
new_selected_by_path = collections.OrderedDict({
Expand Down
41 changes: 40 additions & 1 deletion tests/core/selectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,31 @@

import collections
import dataclasses
from typing import Any
from typing import Any, Iterable

from absl.testing import absltest
import jax
from penzai import pz
import penzai.core.selectors
import pytest


class ShiftNegativeIndicesTest(absltest.TestCase):
  """Tests for the private `_shift_negative_indices` helper.

  Written absltest-style (rather than as a pytest-parametrized function) to
  stay consistent with the other tests in this file.
  """

  def test_shift_negative_indices(self):
    """Checks the empty input and a mix of negative/non-negative indices."""
    # Each case is (input_indices, shift, expected_output).
    cases = [
        ((), 1, ()),
        ([0, 3, -2], len(range(6)), (0, 3, 4)),
    ]
    for input_indices, shift, expected_output in cases:
      # subTest reports each case separately instead of stopping at the
      # first failing one.
      with self.subTest(input_indices=input_indices, shift=shift):
        self.assertEqual(
            penzai.core.selectors._shift_negative_indices(
                input_indices, shift=shift
            ),
            expected_output,
        )
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind adding this as an absltest-style test (as a method inside SelectorsTest) instead of using pytest here? The rest of the tests do not use pytest yet and I'd prefer to be consistent here for now.



@dataclasses.dataclass
Expand Down Expand Up @@ -565,6 +584,26 @@ def test_pick_nth_selected(self):
),
[0, 1, 2, 3, SELECTED_PART(value=4), 5, 6, 7, 8, 9],
)
# Test negative indices for `pick_nth_selected`
self.assertEqual(
(
pz.select(list(range(10)))
.at_instances_of(int)
.pick_nth_selected(-2)
.apply(SELECTED_PART)
),
[0, 1, 2, 3, 4, 5, 6, 7, SELECTED_PART(value=8), 9],
)
# Don't select anything if index is out of range
self.assertEqual(
(
pz.select([0, 1, 2])
.at_instances_of(int)
.pick_nth_selected(5)
.apply(SELECTED_PART)
),
[0, 1, 2],
)

def test_invert__example_1(self):
predicate = lambda node: isinstance(node, CustomLeaf) and node.tag <= 10
Expand Down
Loading