Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Pass type directly during reduction #1223

Open
wants to merge 4 commits into
base: gh/vmoens/47/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,17 +381,6 @@ def from_dict(
stack_dim_name=None,
stack_dim=0,
):
# if batch_size is not None:
# batch_size = list(batch_size)
# if stack_dim is None:
# stack_dim = 0
# n = batch_size.pop(stack_dim)
# if n != len(input_dict):
# raise ValueError(
# "The number of dicts and the corresponding batch-size must match, "
# f"got len(input_dict)={len(input_dict)} and batch_size[{stack_dim}]={n}."
# )
# batch_size = torch.Size(batch_size)
return LazyStackedTensorDict(
*(
TensorDict.from_dict(
Expand Down Expand Up @@ -1992,11 +1981,22 @@ def _apply_nest(
if all(r is None for r in results) and filter_empty in (None, True):
return
if not inplace:
out = type(self)(
*results,
stack_dim=self.stack_dim,
stack_dim_name=self._td_dim_name,
)
if not results or any(r is not None for r in results):
try:
out = type(self)(
*results,
stack_dim=self.stack_dim,
stack_dim_name=self._td_dim_name,
)
except Exception as e:
raise RuntimeError(
f"Failed to reconstruct the lazy stack of tensordicts with class: {type(self)}. "
f"One common issue is that the outputs of apply are a mix of None and non-None "
f"values. Check that the outputs of apply() are all None or all non-None. "
f"Otherwise, please report this bug on tensordict github."
) from e
else:
out = None
else:
out = self
if names is not NO_DEFAULT:
Expand Down
19 changes: 14 additions & 5 deletions tensordict/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
from tensordict._lazy import LazyStackedTensorDict
from tensordict._td import TensorDict

from tensordict.tensorclass import NonTensorData
from tensordict.utils import _STRDTYPE2DTYPE
from tensordict.tensorclass import NonTensorData, NonTensorStack
from tensordict.utils import _is_tensorclass, _STRDTYPE2DTYPE

# Maps class names (the strings stored in serialized/consolidated metadata)
# back to the concrete classes, so that reduction can rebuild the original
# container type. Classes not listed here are passed through as types rather
# than names — presumably custom tensorclasses; verify against the callers
# in _reduce_vals_and_metadata.
CLS_MAP = {
    "TensorDict": TensorDict,
    "LazyStackedTensorDict": LazyStackedTensorDict,
    "NonTensorData": NonTensorData,
    "NonTensorStack": NonTensorStack,
}


Expand Down Expand Up @@ -57,7 +59,9 @@ def from_metadata(metadata=metadata_dict, prefix=None):
d[k] = from_metadata(
v, prefix=prefix + (k,) if prefix is not None else (k,)
)
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
if isinstance(cls, str):
cls = CLS_MAP[cls]
result = cls._from_dict_validated(d, **cls_metadata)
if is_locked:
result.lock_()
# if is_shared:
Expand Down Expand Up @@ -121,10 +125,15 @@ def from_metadata(metadata=metadata, prefix=None):
d[k] = from_metadata(
v, prefix=prefix + (k,) if prefix is not None else (k,)
)
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
if isinstance(cls, str):
cls = CLS_MAP[cls]
result = cls._from_dict_validated(d, **cls_metadata)
if is_locked:
result = result.lock_()
result._consolidated = consolidated
if _is_tensorclass(cls):
result._tensordict._consolidated = consolidated
else:
result._consolidated = consolidated
return result

return from_metadata()
Expand Down
11 changes: 3 additions & 8 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -2064,6 +2064,8 @@ def from_dict(
batch_dims=None,
names=None,
):
if _is_tensor_collection(type(input_dict)):
return input_dict
if others:
if batch_size is not None:
raise TypeError(
Expand Down Expand Up @@ -2120,14 +2122,7 @@ def from_dict(
)
if batch_size is None:
if auto_batch_size is None and batch_dims is None:
warn(
"The batch-size was not provided and auto_batch_size isn't set either. "
"Currently, from_dict will call set auto_batch_size=True but this behaviour "
"will be changed in v0.8 and auto_batch_size will be False onward. "
"To silence this warning, pass auto_batch_size directly.",
category=DeprecationWarning,
)
auto_batch_size = True
auto_batch_size = False
elif auto_batch_size is None:
auto_batch_size = True
if auto_batch_size:
Expand Down
45 changes: 34 additions & 11 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from concurrent.futures import Future, ThreadPoolExecutor, wait
from copy import copy
from functools import partial, wraps
from functools import wraps
from pathlib import Path
from textwrap import indent
from threading import Thread
Expand Down Expand Up @@ -2187,7 +2187,6 @@ def _from_dict_validated(cls, *args, **kwargs):

By default, falls back on :meth:`~.from_dict`.
"""
kwargs.setdefault("auto_batch_size", True)
return cls.from_dict(*args, **kwargs)

@abc.abstractmethod
Expand Down Expand Up @@ -4993,8 +4992,15 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):

if requires_metadata:
# metadata is nested
cls = type(self)
from tensordict._reductions import CLS_MAP

if cls.__name__ in CLS_MAP:
cls = cls.__name__
else:
pass
metadata_dict = {
"cls": type(self).__name__,
"cls": cls,
"non_tensors": {},
"leaves": {},
"cls_metadata": self._reduce_get_metadata(),
Expand Down Expand Up @@ -5054,18 +5060,27 @@ def assign(
elif _is_tensor_collection(cls):
metadata_dict_key = None
if requires_metadata:
from tensordict._reductions import CLS_MAP

if cls.__name__ in CLS_MAP:
cls = cls.__name__
else:
pass
metadata_dict_key = metadata_dict[key] = {
"cls": cls.__name__,
"cls": cls,
"non_tensors": {},
"leaves": {},
"cls_metadata": value._reduce_get_metadata(),
}
local_assign = partial(
assign,
track_key=total_key,
metadata_dict=metadata_dict_key,
flat_size=flat_size,
)

def local_assign(*t):
return assign(
*t,
track_key=total_key,
metadata_dict=metadata_dict_key,
flat_size=flat_size,
)

value._fast_apply(
local_assign,
named=True,
Expand Down Expand Up @@ -5253,7 +5268,15 @@ def consolidate(
storage.share_memory_()
else:
# Convert the dict to json
metadata_dict_json = json.dumps(metadata_dict)
try:
metadata_dict_json = json.dumps(metadata_dict)
except TypeError as e:
raise RuntimeError(
"Failed to convert the metatdata to json. "
"This is usually due to a nested class that is unaccounted for by the serializer, "
"such as custom TensorClass. "
"If you encounter this error, please file an issue on github."
) from e
# Represent as a tensor
metadata_dict_json = torch.as_tensor(
bytearray(metadata_dict_json), dtype=torch.uint8
Expand Down
27 changes: 27 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from tensordict import (
capture_non_tensor_stack,
get_defaults_to_none,
lazy_stack,
LazyStackedTensorDict,
make_tensordict,
PersistentTensorDict,
Expand Down Expand Up @@ -11234,6 +11235,32 @@ def non_tensor_data(self):
batch_size=[],
)

@set_capture_non_tensor_stack(False)
def test_consolidate_nested(self):
    """Round-trip a doubly lazy-stacked TensorDict containing a custom
    tensorclass and non-tensor data through ``consolidate()`` and pickle,
    asserting equality with the original after every transformation."""
    import pickle

    @tensorclass
    class A:
        a: str
        b: torch.Tensor

    # Base leaf: nests a tensorclass instance and a NonTensorData value.
    leaf = TensorDict(
        a=TensorDict(b=A(a="a string!", b=torch.randn(10))),
        c=TensorDict(d=NonTensorData("another string!")),
    )
    # Stack twice — first along dim 0, then along the last dim — to build
    # a nested LazyStackedTensorDict.
    inner = lazy_stack([leaf.clone(), leaf.clone()])
    td = lazy_stack([inner.clone(), inner.clone()], -1)

    consolidated = td.consolidate()
    assert (consolidated == td).all()

    # Both the lazy stack and its consolidated form must survive pickling.
    for original in (td, consolidated):
        restored = pickle.loads(pickle.dumps(original))
        assert (restored == td).all()

def test_comparison(self, non_tensor_data):
non_tensor_data = non_tensor_data.exclude(("nested", "str"))
assert (non_tensor_data | non_tensor_data).get_non_tensor(("nested", "bool"))
Expand Down
Loading