Skip to content

Commit

Permalink
[Minor] Expose is_leaf_nontensor and default_is_leaf
Browse files Browse the repository at this point in the history
ghstack-source-id: aeaa6981081b45430f6664530f3dab7cc3d21759
Pull Request resolved: #1219

(cherry picked from commit 8bdbf01)
  • Loading branch information
vmoens committed Feb 17, 2025
1 parent 25ec492 commit d08e74e
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 15 deletions.
4 changes: 3 additions & 1 deletion docs/source/reference/tensordict.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ or ``cat``.
:template: td_template.rst

cat
from_consolidated
default_is_leaf
from_any
from_consolidated
from_dict
from_h5
from_module
Expand All @@ -43,6 +44,7 @@ or ``cat``.
from_tuple
fromkeys
is_batchedtensor
is_leaf_nontensor
lazy_stack
load
load_memmap
Expand Down
3 changes: 3 additions & 0 deletions tensordict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from tensordict._unbatched import UnbatchedTensor

from tensordict.base import (
_default_is_leaf as default_is_leaf,
_is_leaf_nontensor as is_leaf_nontensor,
from_any,
from_dict,
from_h5,
Expand Down Expand Up @@ -56,6 +58,7 @@
assert_allclose_td,
assert_close,
is_batchedtensor,
is_non_tensor,
is_tensorclass,
lazy_legacy,
parse_tensor_dict_string,
Expand Down
2 changes: 1 addition & 1 deletion tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1770,7 +1770,7 @@ def _multithread_apply_flat(
named: bool = False,
nested_keys: bool = False,
prefix: tuple = (),
is_leaf: Callable = None,
is_leaf: Callable[[Type], bool] | None = None,
executor: ThreadPoolExecutor,
futures: List[Future],
local_futures: List,
Expand Down
4 changes: 2 additions & 2 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ def _multithread_apply_flat(
named: bool = False,
nested_keys: bool = False,
prefix: tuple = (),
is_leaf: Callable = None,
is_leaf: Callable[[Type], bool] | None = None,
executor: ThreadPoolExecutor,
futures: List[Future],
local_futures: List,
Expand Down Expand Up @@ -1336,7 +1336,7 @@ def _apply_nest(
nested_keys: bool = False,
prefix: tuple = (),
filter_empty: bool | None = None,
is_leaf: Callable | None = None,
is_leaf: Callable[[Type], bool] | None = None,
out: TensorDictBase | None = None,
**constructor_kwargs,
) -> T | None:
Expand Down
96 changes: 86 additions & 10 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6554,6 +6554,8 @@ def update(
whether an object type is to be considered a leaf and swapped
or a tensor collection.

.. seealso:: :meth:`~tensordict.is_leaf_nontensor` and :meth:`~tensordict.default_is_leaf`.

Returns:
self

Expand Down Expand Up @@ -6993,8 +6995,20 @@ def items(
Defaults to ``False``.
leaves_only (bool, optional): if ``False``, only leaves will be
returned. Defaults to ``False``.
is_leaf: an optional callable that indicates if a class is to be considered a
leaf or not.
is_leaf (callable, optional): a callable over a class type returning
a bool indicating if this class has to be considered as a leaf.

.. note:: The purpose of `is_leaf` is not to prevent recursive calls into nested tensordicts, but
rather to mark certain types as "leaves" for the purpose of filtering when `leaves_only=True`.
Even if `is_leaf(cls)` returns `True`, the nested structure of the tensordict will still be
traversed if `include_nested=True`.
In other words, `is_leaf` does not control the recursion depth, but rather provides a way to filter
out certain types from the result when `leaves_only=True`. This means that a node in the tree can
be both a leaf and a node with children.
In practice, the default value of ``is_leaf`` does exclude tensordict and tensorclass instances
from the leaf set.

.. seealso:: :meth:`~tensordict.is_leaf_nontensor` and :meth:`~tensordict.default_is_leaf`.

Keyword Args:
sort (bool, optional): whether the keys should be sorted. For nested keys,
Expand Down Expand Up @@ -7073,8 +7087,20 @@ def values(
Defaults to ``False``.
leaves_only (bool, optional): if ``False``, only leaves will be
returned. Defaults to ``False``.
is_leaf: an optional callable that indicates if a class is to be considered a
leaf or not.
is_leaf (callable, optional): a callable over a class type returning
a bool indicating if this class has to be considered as a leaf.

.. note:: The purpose of `is_leaf` is not to prevent recursive calls into nested tensordicts, but
rather to mark certain types as "leaves" for the purpose of filtering when `leaves_only=True`.
Even if `is_leaf(cls)` returns `True`, the nested structure of the tensordict will still be
traversed if `include_nested=True`.
In other words, `is_leaf` does not control the recursion depth, but rather provides a way to filter
out certain types from the result when `leaves_only=True`. This means that a node in the tree can
be both a leaf and a node with children.
In practice, the default value of ``is_leaf`` does exclude tensordict and tensorclass instances
from the leaf set.

.. seealso:: :meth:`~tensordict.is_leaf_nontensor` and :meth:`~tensordict.default_is_leaf`.

Keyword Args:
sort (bool, optional): whether the keys should be sorted. For nested keys,
Expand Down Expand Up @@ -7252,8 +7278,20 @@ def keys(
Defaults to ``False``.
leaves_only (bool, optional): if ``False``, only leaves will be
returned. Defaults to ``False``.
is_leaf: an optional callable that indicates if a class is to be considered a
leaf or not.
is_leaf (callable, optional): a callable over a class type returning
a bool indicating if this class has to be considered as a leaf.

.. note:: The purpose of `is_leaf` is not to prevent recursive calls into nested tensordicts, but
rather to mark certain types as "leaves" for the purpose of filtering when `leaves_only=True`.
Even if `is_leaf(cls)` returns `True`, the nested structure of the tensordict will still be
traversed if `include_nested=True`.
In other words, `is_leaf` does not control the recursion depth, but rather provides a way to filter
out certain types from the result when `leaves_only=True`. This means that a node in the tree can
be both a leaf and a node with children.
In practice, the default value of ``is_leaf`` does exclude tensordict and tensorclass instances
from the leaf set.

.. seealso:: :meth:`~tensordict.is_leaf_nontensor` and :meth:`~tensordict.default_is_leaf`.

Keyword Args:
sort (bool, optional): whether the keys shoulbe sorted. For nested keys,
Expand Down Expand Up @@ -8377,7 +8415,7 @@ def _multithread_apply_flat(
named: bool = False,
nested_keys: bool = False,
prefix: tuple = (),
is_leaf: Callable = None,
is_leaf: Callable[[Type], bool] | None = None,
executor: ThreadPoolExecutor,
futures: List[Future],
local_futures: List,
Expand Down Expand Up @@ -8417,7 +8455,7 @@ def _multithread_apply_nest(
nested_keys: bool = False,
prefix: tuple = (),
filter_empty: bool | None = None,
is_leaf: Callable = None,
is_leaf: Callable[[Type], bool] | None = None,
out: TensorDictBase | None = None,
num_threads: int,
call_when_done: Callable | None = None,
Expand Down Expand Up @@ -8505,7 +8543,7 @@ def _apply_nest(
nested_keys: bool = False,
prefix: tuple = (),
filter_empty: bool | None = None,
is_leaf: Callable = None,
is_leaf: Callable[[Type], bool] | None = None,
out: TensorDictBase | None = None,
**constructor_kwargs,
) -> T | None: ...
Expand All @@ -8525,7 +8563,7 @@ def _fast_apply(
# filter_empty must be False because we use _fast_apply for all sorts of ops like expand etc
# and non-tensor data will disappear if we use True by default.
filter_empty: bool | None = False,
is_leaf: Callable = None,
is_leaf: Callable[[Type], bool] | None = None,
propagate_lock: bool = False,
out: TensorDictBase | None = None,
num_threads: int = 0,
Expand Down Expand Up @@ -12207,6 +12245,18 @@ def flatten_keys(
is_leaf (callable, optional): a callable over a class type returning
a bool indicating if this class has to be considered as a leaf.

.. note:: The purpose of `is_leaf` is not to prevent recursive calls into nested tensordicts, but
rather to mark certain types as "leaves" for the purpose of filtering when `leaves_only=True`.
Even if `is_leaf(cls)` returns `True`, the nested structure of the tensordict will still be
traversed if `include_nested=True`.
In other words, `is_leaf` does not control the recursion depth, but rather provides a way to filter
out certain types from the result when `leaves_only=True`. This means that a node in the tree can
be both a leaf and a node with children.
In practice, the default value of ``is_leaf`` does exclude tensordict and tensorclass instances
from the leaf set.

.. seealso:: :meth:`~tensordict.is_leaf_nontensor` and :meth:`~tensordict.default_is_leaf`.

Examples:
>>> data = TensorDict({"a": 1, ("b", "c"): 2, ("e", "f", "g"): 3}, batch_size=[])
>>> data.flatten_keys(separator=" - ")
Expand Down Expand Up @@ -13343,10 +13393,36 @@ def is_tensor_collection(datatype: type | Any) -> bool:


def _default_is_leaf(cls: Type) -> bool:
"""Returns ``True`` if a type is not a tensor collection (tensordict or tensorclass).

Examples:
>>> from tensordict import TensorDict, default_is_leaf
>>> import torch
>>> td = TensorDict(a={}, b="a string!", c=torch.randn(()))
>>> print(td.keys(leaves_only=True, is_leaf=default_is_leaf))
_TensorDictKeysView(['c'],
include_nested=False,
leaves_only=True)

.. seealso:: :meth:`~tensordict.is_leaf_nontensor`.
"""
return not _is_tensor_collection(cls)


def _is_leaf_nontensor(cls: Type) -> bool:
"""Returns ``True`` if a type is not a tensor collection (tensordict or tensorclass) or is a non-tensor.

Examples:
>>> from tensordict import TensorDict, default_is_leaf
>>> import torch
>>> td = TensorDict(a={}, b="a string!", c=torch.randn(()))
>>> print(td.keys(leaves_only=True, is_leaf=default_is_leaf))
_TensorDictKeysView(['b', 'c'],
include_nested=False,
leaves_only=True)

.. seealso:: :meth:`~tensordict.default_is_leaf`.
"""
if _is_tensor_collection(cls):
return _pass_through_cls(cls)
# if issubclass(cls, KeyedJaggedTensor):
Expand Down
2 changes: 1 addition & 1 deletion tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def _multithread_apply_flat(
named: bool = False,
nested_keys: bool = False,
prefix: tuple = (),
is_leaf: Callable = None,
is_leaf: Callable[[Type], bool] | None = None,
executor: ThreadPoolExecutor,
futures: List[Future],
local_futures: List,
Expand Down

0 comments on commit d08e74e

Please sign in to comment.