From 8e71fe63250b9604a0f7abf8c89ab2a2b24dd389 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 19 Feb 2025 17:57:59 +0000 Subject: [PATCH 1/4] Update [ghstack-poisoned] --- tensordict/_reductions.py | 12 +++++++++--- tensordict/base.py | 4 ++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tensordict/_reductions.py b/tensordict/_reductions.py index be8aa42f1..ecfe956be 100644 --- a/tensordict/_reductions.py +++ b/tensordict/_reductions.py @@ -11,12 +11,14 @@ from tensordict._lazy import LazyStackedTensorDict from tensordict._td import TensorDict -from tensordict.tensorclass import NonTensorData +from tensordict.tensorclass import NonTensorData, NonTensorStack from tensordict.utils import _STRDTYPE2DTYPE CLS_MAP = { "TensorDict": TensorDict, "LazyStackedTensorDict": LazyStackedTensorDict, + "NonTensorData": NonTensorData, + "NonTensorStack": NonTensorStack, } @@ -57,7 +59,9 @@ def from_metadata(metadata=metadata_dict, prefix=None): d[k] = from_metadata( v, prefix=prefix + (k,) if prefix is not None else (k,) ) - result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata) + if isinstance(cls, str): + cls = CLS_MAP[cls] + result = cls._from_dict_validated(d, **cls_metadata) if is_locked: result.lock_() # if is_shared: @@ -121,7 +125,9 @@ def from_metadata(metadata=metadata, prefix=None): d[k] = from_metadata( v, prefix=prefix + (k,) if prefix is not None else (k,) ) - result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata) + if isinstance(cls, str): + cls = CLS_MAP[cls] + result = cls._from_dict_validated(d, **cls_metadata) if is_locked: result = result.lock_() result._consolidated = consolidated diff --git a/tensordict/base.py b/tensordict/base.py index 21a4ab133..b8d47c075 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -4994,7 +4994,7 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata): if requires_metadata: # metadata is nested metadata_dict = { - "cls": type(self).__name__, + "cls": type(self), "non_tensors": {}, "leaves": {}, "cls_metadata": self._reduce_get_metadata(), @@ -5055,7 +5055,7 @@ def assign( metadata_dict_key = None if requires_metadata: metadata_dict_key = metadata_dict[key] = { - "cls": cls.__name__, + "cls": cls, "non_tensors": {}, "leaves": {}, "cls_metadata": value._reduce_get_metadata(), From 6b24fccf061f8eb1e718b655a0166ada46b1e3d8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 19 Feb 2025 18:05:31 +0000 Subject: [PATCH 2/4] Update [ghstack-poisoned] --- tensordict/_reductions.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensordict/_reductions.py b/tensordict/_reductions.py index ecfe956be..1816143ed 100644 --- a/tensordict/_reductions.py +++ b/tensordict/_reductions.py @@ -12,7 +12,7 @@ from tensordict._td import TensorDict from tensordict.tensorclass import NonTensorData, NonTensorStack -from tensordict.utils import _STRDTYPE2DTYPE +from tensordict.utils import _is_tensorclass, _STRDTYPE2DTYPE CLS_MAP = { "TensorDict": TensorDict, @@ -130,7 +130,10 @@ def from_metadata(metadata=metadata, prefix=None): result = cls._from_dict_validated(d, **cls_metadata) if is_locked: result = result.lock_() - result._consolidated = consolidated + if _is_tensorclass(cls): + result._tensordict._consolidated = consolidated + else: + result._consolidated = consolidated return result return from_metadata() From 4da1714636abf196b8afcb5bc0d152d7230cb143 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 19 Feb 2025 18:19:19 +0000 Subject: [PATCH 3/4] Update [ghstack-poisoned] 
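This patch guards the reconstruction of the result stack in `LazyStackedTensorDict._apply_nest`: when `apply` produces no per-tensordict results to stack, `out` is set to `None` instead of calling `type(self)(*results, ...)` on an empty list (see the hunk below). For reference, the ordinary non-empty path rebuilds a lazy stack of the outputs along the original stack dimension; a minimal sketch of that behaviour, assuming only the public `TensorDict`/`LazyStackedTensorDict` API from a build that includes this stack:

    import torch
    from tensordict import LazyStackedTensorDict, TensorDict

    # Two tensordicts lazily stacked along dim 0.
    td = LazyStackedTensorDict(
        TensorDict({"a": torch.zeros(3)}, batch_size=[]),
        TensorDict({"a": torch.ones(3)}, batch_size=[]),
        stack_dim=0,
    )

    # apply() maps the function over each stacked tensordict and rebuilds a
    # lazy stack of the outputs along the same stack dimension.
    out = td.apply(lambda x: x + 1)
    assert isinstance(out, LazyStackedTensorDict)
    assert out.stack_dim == td.stack_dim
    assert (out["a"] == td["a"] + 1).all()
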
--- tensordict/_lazy.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index dcb9945f1..70d68261b 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -1992,11 +1992,14 @@ def _apply_nest( if all(r is None for r in results) and filter_empty in (None, True): return if not inplace: - out = type(self)( - *results, - stack_dim=self.stack_dim, - stack_dim_name=self._td_dim_name, - ) + if results: + out = type(self)( + *results, + stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, + ) + else: + out = None else: out = self if names is not NO_DEFAULT: From 016fd56274cf599c7d97ca07bff443171ba9f795 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 20 Feb 2025 10:18:11 +0000 Subject: [PATCH 4/4] Update [ghstack-poisoned] --- tensordict/_lazy.py | 31 ++++++++++++++--------------- tensordict/_td.py | 11 +++-------- tensordict/base.py | 43 +++++++++++++++++++++++++++++++---------- test/test_tensordict.py | 27 ++++++++++++++++++++++++++ 4 files changed, 77 insertions(+), 35 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 70d68261b..7d7a86f47 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -381,17 +381,6 @@ def from_dict( stack_dim_name=None, stack_dim=0, ): - # if batch_size is not None: - # batch_size = list(batch_size) - # if stack_dim is None: - # stack_dim = 0 - # n = batch_size.pop(stack_dim) - # if n != len(input_dict): - # raise ValueError( - # "The number of dicts and the corresponding batch-size must match, " - # f"got len(input_dict)={len(input_dict)} and batch_size[{stack_dim}]={n}." - # ) - # batch_size = torch.Size(batch_size) return LazyStackedTensorDict( *( TensorDict.from_dict( @@ -1992,12 +1981,20 @@ def _apply_nest( if all(r is None for r in results) and filter_empty in (None, True): return if not inplace: - if results: - out = type(self)( - *results, - stack_dim=self.stack_dim, - stack_dim_name=self._td_dim_name, - ) + if not results or any(r is not None for r in results): + try: + out = type(self)( + *results, + stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, + ) + except Exception as e: + raise RuntimeError( + f"Failed to reconstruct the lazy stack of tensordicts with class: {type(self)}. " + f"One common issue is that the outputs of apply are a mix of None and non-None " + f"values. Check that the outputs of apply() are all None or all non-None. " + f"Otherwise, please report this bug on tensordict github." + ) from e else: out = None else: diff --git a/tensordict/_td.py b/tensordict/_td.py index 4b8b4e86d..dcc192200 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -2064,6 +2064,8 @@ def from_dict( batch_dims=None, names=None, ): + if _is_tensor_collection(type(input_dict)): + return input_dict if others: if batch_size is not None: raise TypeError( @@ -2120,14 +2122,7 @@ def from_dict( ) if batch_size is None: if auto_batch_size is None and batch_dims is None: - warn( - "The batch-size was not provided and auto_batch_size isn't set either. " - "Currently, from_dict will call set auto_batch_size=True but this behaviour " - "will be changed in v0.8 and auto_batch_size will be False onward. 
" - "To silence this warning, pass auto_batch_size directly.", - category=DeprecationWarning, - ) - auto_batch_size = True + auto_batch_size = False elif auto_batch_size is None: auto_batch_size = True if auto_batch_size: diff --git a/tensordict/base.py b/tensordict/base.py index b8d47c075..d4cd5a8e2 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -23,7 +23,7 @@ from concurrent.futures import Future, ThreadPoolExecutor, wait from copy import copy -from functools import partial, wraps +from functools import wraps from pathlib import Path from textwrap import indent from threading import Thread @@ -2187,7 +2187,6 @@ def _from_dict_validated(cls, *args, **kwargs): By default, falls back on :meth:`~.from_dict`. """ - kwargs.setdefault("auto_batch_size", True) return cls.from_dict(*args, **kwargs) @abc.abstractmethod @@ -4993,8 +4992,15 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata): if requires_metadata: # metadata is nested + cls = type(self) + from tensordict._reductions import CLS_MAP + + if cls.__name__ in CLS_MAP: + cls = cls.__name__ + else: + pass metadata_dict = { - "cls": type(self), + "cls": cls, "non_tensors": {}, "leaves": {}, "cls_metadata": self._reduce_get_metadata(), @@ -5054,18 +5060,27 @@ def assign( elif _is_tensor_collection(cls): metadata_dict_key = None if requires_metadata: + from tensordict._reductions import CLS_MAP + + if cls.__name__ in CLS_MAP: + cls = cls.__name__ + else: + pass metadata_dict_key = metadata_dict[key] = { "cls": cls, "non_tensors": {}, "leaves": {}, "cls_metadata": value._reduce_get_metadata(), } - local_assign = partial( - assign, - track_key=total_key, - metadata_dict=metadata_dict_key, - flat_size=flat_size, - ) + + def local_assign(*t): + return assign( + *t, + track_key=total_key, + metadata_dict=metadata_dict_key, + flat_size=flat_size, + ) + value._fast_apply( local_assign, named=True, @@ -5253,7 +5268,15 @@ def consolidate( storage.share_memory_() else: # Convert the dict to json - metadata_dict_json = json.dumps(metadata_dict) + try: + metadata_dict_json = json.dumps(metadata_dict) + except TypeError as e: + raise RuntimeError( + "Failed to convert the metatdata to json. " + "This is usually due to a nested class that is unaccounted for by the serializer, " + "such as custom TensorClass. " + "If you encounter this error, please file an issue on github." 
+ ) from e # Represent as a tensor metadata_dict_json = torch.as_tensor( bytearray(metadata_dict_json), dtype=torch.uint8 diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 18aeecf33..d347db593 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -40,6 +40,7 @@ from tensordict import ( capture_non_tensor_stack, get_defaults_to_none, + lazy_stack, LazyStackedTensorDict, make_tensordict, PersistentTensorDict, @@ -11234,6 +11235,32 @@ def non_tensor_data(self): batch_size=[], ) + @set_capture_non_tensor_stack(False) + def test_consolidate_nested(self): + import pickle + + @tensorclass + class A: + a: str + b: torch.Tensor + + td = TensorDict( + a=TensorDict(b=A(a="a string!", b=torch.randn(10))), + c=TensorDict(d=NonTensorData("another string!")), + ) + td = lazy_stack([td.clone(), td.clone()]) + td = lazy_stack([td.clone(), td.clone()], -1) + + tdc = td.consolidate() + + assert (tdc == td).all() + + tdr = pickle.loads(pickle.dumps(td)) + assert (tdr == td).all() + + tdcr = pickle.loads(pickle.dumps(tdc)) + assert (tdcr == td).all() + def test_comparison(self, non_tensor_data): non_tensor_data = non_tensor_data.exclude(("nested", "str")) assert (non_tensor_data | non_tensor_data).get_non_tensor(("nested", "bool"))
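
The new `test_consolidate_nested` test above exercises the whole stack end to end: a lazy stack of tensordicts holding a custom tensorclass and `NonTensorData` leaves is consolidated and pickled, and both round-trips compare equal to the original. A standalone sketch of the same flow, assuming a tensordict build that includes these patches (the actual test additionally disables non-tensor stack capture via `set_capture_non_tensor_stack(False)`):

    import pickle

    import torch
    from tensordict import lazy_stack, NonTensorData, TensorDict, tensorclass

    @tensorclass
    class A:
        a: str
        b: torch.Tensor

    # Nested structure mixing a tensorclass leaf and non-tensor data,
    # mirroring the test above.
    td = TensorDict(
        a=TensorDict(b=A(a="a string!", b=torch.randn(10))),
        c=TensorDict(d=NonTensorData("another string!")),
    )
    td = lazy_stack([td.clone(), td.clone()])

    # consolidate() packs the leaves into a single storage and keeps metadata
    # describing how to rebuild each nested node, so pickling the consolidated
    # tensordict preserves the nested classes.
    tdc = td.consolidate()
    assert (tdc == td).all()

    tdr = pickle.loads(pickle.dumps(tdc))
    assert (tdr == td).all()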