Skip to content

Commit

Permalink
[Feature] capture_non_tensor_stack
Browse files Browse the repository at this point in the history
ghstack-source-id: 8667892d782a5904e2c5117a1b039edcdaacb9e0
Pull Request resolved: #1221
  • Loading branch information
vmoens committed Feb 19, 2025
1 parent 3ec3be7 commit 7b0fd93
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 42 deletions.
8 changes: 5 additions & 3 deletions docs/source/reference/tensordict.rst
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,15 @@ Utils
utils.expand_right
utils.isin
utils.remove_duplicates
capture_non_tensor_stack
dense_stack_tds
is_batchedtensor
is_tensor_collection
lazy_legacy
make_tensordict
merge_tensordicts
pad
pad_sequence
dense_stack_tds
set_lazy_legacy
lazy_legacy
parse_tensor_dict_string
set_capture_non_tensor_stack
set_lazy_legacy
2 changes: 2 additions & 0 deletions tensordict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,13 @@
from tensordict.utils import (
assert_allclose_td,
assert_close,
capture_non_tensor_stack,
is_batchedtensor,
is_non_tensor,
is_tensorclass,
lazy_legacy,
parse_tensor_dict_string,
set_capture_non_tensor_stack,
set_lazy_legacy,
unravel_key,
unravel_key_list,
Expand Down
5 changes: 5 additions & 0 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
_unravel_key_to_tuple,
_zip_strict,
cache,
capture_non_tensor_stack,
convert_ellipsis_to_idx,
DeviceType,
erase_cache,
Expand Down Expand Up @@ -6229,6 +6230,10 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT):
value = self._get_str(key, default=default)

if is_non_tensor(value):
from tensordict import NonTensorStack

if isinstance(value, NonTensorStack) and not capture_non_tensor_stack():
return value.tolist()
data = getattr(value, "data", None)
if data is None:
return value.tolist()
Expand Down
26 changes: 21 additions & 5 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@
_TENSORCLASS_MEMO,
_unravel_key_to_tuple,
_zip_strict,
capture_non_tensor_stack,
DeviceType,
IndexType,
is_tensorclass,
KeyDependentDefaultDict,
set_capture_non_tensor_stack,
)
from torch import multiprocessing as mp, Tensor
from torch.multiprocessing import Manager
Expand Down Expand Up @@ -3219,7 +3221,7 @@ def _stack_non_tensor(cls, list_of_non_tensor, dim=0, raise_if_non_unique=False)

ids = set()
firstdata = NO_DEFAULT
return_stack = False
return_stack = not capture_non_tensor_stack()
for data in list_of_non_tensor:
if not isinstance(data, NonTensorData):
if raise_if_non_unique:
Expand All @@ -3242,8 +3244,18 @@ def _stack_non_tensor(cls, list_of_non_tensor, dim=0, raise_if_non_unique=False)
return_stack = True
break
else:
return_stack = False
return_stack = not capture_non_tensor_stack()
if not return_stack:
if capture_non_tensor_stack(allow_none=True) is None:
warnings.warn(
"The default behavior of stacking non-tensor data will change in "
"version v0.9 and switch from True to False (current default). "
"To prepare for this change, use set_capture_non_tensor_stack(val: bool) as a decorator or context "
"manager, or set the environment variable CAPTURE_NONTENSOR_STACK "
"to 'False'.",
FutureWarning,
stacklevel=2,
)
batch_size = list(first.batch_size)
batch_size.insert(dim, len(list_of_non_tensor))
return NonTensorData(
Expand Down Expand Up @@ -3772,9 +3784,13 @@ def data(self):
Raises a ValueError if there is more than one unique value.
"""
try:
return NonTensorData._stack_non_tensor(
self.tensordicts, raise_if_non_unique=True
).data
with set_capture_non_tensor_stack(True):
nt = NonTensorData._stack_non_tensor(
self.tensordicts, raise_if_non_unique=True
)
if not isinstance(nt, NonTensorData):
raise ValueError
return nt.data
except ValueError:
raise AttributeError(
"Cannot get the non-unique data of a NonTensorStack. Use .tolist() instead."
Expand Down
88 changes: 87 additions & 1 deletion tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2098,7 +2098,7 @@ def set(self) -> None:

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
global _LAZY_OP
_LAZY_OP = bool(self._old_mode)
_LAZY_OP = self._old_mode
os.environ["LAZY_LEGACY_OP"] = str(_LAZY_OP)


Expand All @@ -2121,6 +2121,92 @@ def _legacy_lazy(func):
return func


# non tensor stack control
_DEFAULT_CAPTURE_NONTENSOR_STACK = True
# Raw value from the environment (a string like "True"/"False"), or None if unset.
_CAPTURE_NONTENSOR_STACK = os.environ.get("CAPTURE_NONTENSOR_STACK")


class set_capture_non_tensor_stack(_DecoratorContextManager):
    """A context manager or decorator to control whether identical non-tensor data should be stacked into a single NonTensorData object or a NonTensorStack.

    Args:
        mode (bool): Whether to capture non-tensor stacks. If ``False``, identical
            non-tensor data will be stacked into a :class:`~tensordict.NonTensorStack`. If ``True``,
            a single :class:`~tensordict.NonTensorData` object will contain the unique value, but with the desired batch-size.
            Defaults to ``True``.

    .. note:: Until v0.9, this will raise a warning if the same value is encountered and the value is not set
        explicitly (`capture_non_tensor_stack() = True` default behavior).
        You can set the value of :func:`~tensordict.capture_non_tensor_stack` through:

        - The ``CAPTURE_NONTENSOR_STACK`` environment variable;
        - By setting ``set_capture_non_tensor_stack(val: bool).set()`` at the beginning of your script;
        - By using ``set_capture_non_tensor_stack(val: bool)`` as a context manager or a decorator.

        It is recommended to use the `set_capture_non_tensor_stack(False)` behavior.

    .. seealso:: :func:`~tensordict.capture_non_tensor_stack`

    Examples:
        >>> with set_capture_non_tensor_stack(False):
        ...     torch.stack([NonTensorData("a"), NonTensorData("a")])
        NonTensorStack(["a", "a"], stack_dim=0)
        >>> @set_capture_non_tensor_stack(False)
        ... def my_function():
        ...     return torch.stack([NonTensorData("a"), NonTensorData("a")])
        >>> my_function()
        NonTensorStack(["a", "a"], stack_dim=0)
    """

    def __init__(self, mode: bool) -> None:
        super().__init__()
        self.mode = mode

    def clone(self) -> set_capture_non_tensor_stack:
        # override this method if your children class takes __init__ parameters
        return type(self)(self.mode)

    def __enter__(self) -> None:
        self.set()

    def set(self) -> None:
        """Apply the mode globally, saving the previous state for restoration on exit."""
        global _CAPTURE_NONTENSOR_STACK
        self._old_mode = _CAPTURE_NONTENSOR_STACK
        _CAPTURE_NONTENSOR_STACK = bool(self.mode)
        # we do this such that sub-processes see the same setting as the main one
        os.environ["CAPTURE_NONTENSOR_STACK"] = str(_CAPTURE_NONTENSOR_STACK)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        global _CAPTURE_NONTENSOR_STACK
        _CAPTURE_NONTENSOR_STACK = self._old_mode
        if _CAPTURE_NONTENSOR_STACK is None:
            # The previous state was "unset": writing str(None) == "None" to the
            # env var would break strtobool() parsing and mislead sub-processes,
            # so remove the variable instead of stringifying None.
            os.environ.pop("CAPTURE_NONTENSOR_STACK", None)
        else:
            os.environ["CAPTURE_NONTENSOR_STACK"] = str(_CAPTURE_NONTENSOR_STACK)


def capture_non_tensor_stack(allow_none=False):
    """Get the current setting for capturing non-tensor stacks.

    Args:
        allow_none (bool, optional): If ``True``, returns ``None`` if no setting has been
            specified. Otherwise, returns the default setting. Defaults to ``False``.

    .. seealso:: :func:`~tensordict.set_capture_non_tensor_stack`

    Returns:
        bool or None: The current setting for capturing non-tensor stacks.
    """
    setting = _CAPTURE_NONTENSOR_STACK
    if setting is None:
        # No explicit choice was made: either surface that fact (allow_none)
        # or fall back to the library-wide default.
        return None if allow_none else _DEFAULT_CAPTURE_NONTENSOR_STACK
    # Values sourced from the environment arrive as strings ("True", "0", ...).
    if isinstance(setting, str):
        return strtobool(setting)
    return setting


# Process initializer for map
def _proc_init(base_seed, queue, num_threads):
worker_id = queue.get(timeout=120)
Expand Down
42 changes: 25 additions & 17 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,6 @@
import pytest
import tensordict.utils
import torch
from tensordict import TensorClass
from tensordict.tensorclass import from_dataclass

try:
import torchsnapshot

_has_torchsnapshot = True
TORCHSNAPSHOT_ERR = ""
except ImportError as err:
_has_torchsnapshot = False
TORCHSNAPSHOT_ERR = str(err)

from _utils_internal import get_available_devices

Expand All @@ -44,14 +33,26 @@
lazy_legacy,
LazyStackedTensorDict,
MemoryMappedTensor,
set_capture_non_tensor_stack,
tensorclass,
TensorClass,
TensorDict,
TensorDictBase,
)
from tensordict._lazy import _PermutedTensorDict, _ViewedTensorDict
from tensordict.base import _GENERIC_NESTED_ERR
from tensordict.tensorclass import from_dataclass
from torch import Tensor

try:
import torchsnapshot

_has_torchsnapshot = True
TORCHSNAPSHOT_ERR = ""
except ImportError as err:
_has_torchsnapshot = False
TORCHSNAPSHOT_ERR = str(err)

# Capture all warnings
pytestmark = [
pytest.mark.filterwarnings("error"),
Expand Down Expand Up @@ -381,7 +382,8 @@ class MyData:
data3 = MyData(D, B, A, C=C, E=E, batch_size=[3, 4])
data4 = MyData(D, B, A, C, E=E, batch_size=[3, 4])
data5 = MyData(D, B, A, C, E, batch_size=[3, 4])
data = torch.stack([data1, data2, data3, data4, data5], 0)
with set_capture_non_tensor_stack(True):
data = torch.stack([data1, data2, data3, data4, data5], 0)
assert (data.A == A).all()
assert (data.B == B).all()
assert (data.C == C).all()
Expand Down Expand Up @@ -1857,7 +1859,8 @@ class MyDataNested:
if lazy:
stacked_tc = LazyStackedTensorDict.lazy_stack([data1, data2], 0)
else:
stacked_tc = torch.stack([data1, data2], 0)
with set_capture_non_tensor_stack(True):
stacked_tc = torch.stack([data1, data2], 0)
assert type(stacked_tc) is type(data1)
assert isinstance(stacked_tc.y, type(data1.y))
assert stacked_tc.X.shape == torch.Size([2, 3, 4, 5])
Expand Down Expand Up @@ -2145,7 +2148,8 @@ def z(self) -> torch.Tensor:
y1 = Y(weakref.ref(obj), batch_size=[1])
y = torch.cat([y0, y1])
assert y.z.shape == torch.Size(())
y = torch.stack([y0, y1])
with set_capture_non_tensor_stack(True):
y = torch.stack([y0, y1])
assert y.z.shape == torch.Size(())


Expand Down Expand Up @@ -2253,9 +2257,13 @@ class TensorClass:
def get_nested(self):
c = self.TensorClass(torch.ones(1), ("a", "b", "c"), "Hello", batch_size=[])

td = torch.stack(
[TensorDict({"t": torch.ones(1), "c": c}, batch_size=[]) for _ in range(3)]
)
with set_capture_non_tensor_stack(True):
td = torch.stack(
[
TensorDict({"t": torch.ones(1), "c": c}, batch_size=[])
for _ in range(3)
]
)
return td

def test_apply(self):
Expand Down
Loading

0 comments on commit 7b0fd93

Please sign in to comment.