Skip to content

Commit

Permalink
[BugFix] Pass type directly during reduction
Browse files Browse the repository at this point in the history
ghstack-source-id: 737f03775bc98090172e397ad4a65c8e777302e5
Pull Request resolved: #1225
  • Loading branch information
vmoens committed Feb 20, 2025
1 parent f67a15c commit ad0a8dd
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 40 deletions.
32 changes: 16 additions & 16 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,17 +381,6 @@ def from_dict(
stack_dim_name=None,
stack_dim=0,
):
# if batch_size is not None:
# batch_size = list(batch_size)
# if stack_dim is None:
# stack_dim = 0
# n = batch_size.pop(stack_dim)
# if n != len(input_dict):
# raise ValueError(
# "The number of dicts and the corresponding batch-size must match, "
# f"got len(input_dict)={len(input_dict)} and batch_size[{stack_dim}]={n}."
# )
# batch_size = torch.Size(batch_size)
return LazyStackedTensorDict(
*(
TensorDict.from_dict(
Expand Down Expand Up @@ -1992,11 +1981,22 @@ def _apply_nest(
if all(r is None for r in results) and filter_empty in (None, True):
return
if not inplace:
out = type(self)(
*results,
stack_dim=self.stack_dim,
stack_dim_name=self._td_dim_name,
)
if not results or any(r is not None for r in results):
try:
out = type(self)(
*results,
stack_dim=self.stack_dim,
stack_dim_name=self._td_dim_name,
)
except Exception as e:
raise RuntimeError(
f"Failed to reconstruct the lazy stack of tensordicts with class: {type(self)}. "
f"One common issue is that the outputs of apply are a mix of None and non-None "
f"values. Check that the outputs of apply() are all None or all non-None. "
f"Otherwise, please report this bug on tensordict github."
) from e
else:
out = None
else:
out = self
if names is not NO_DEFAULT:
Expand Down
19 changes: 14 additions & 5 deletions tensordict/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
from tensordict._lazy import LazyStackedTensorDict
from tensordict._td import TensorDict

from tensordict.tensorclass import NonTensorData
from tensordict.utils import _STRDTYPE2DTYPE
from tensordict.tensorclass import NonTensorData, NonTensorStack
from tensordict.utils import _is_tensorclass, _STRDTYPE2DTYPE

CLS_MAP = {
"TensorDict": TensorDict,
"LazyStackedTensorDict": LazyStackedTensorDict,
"NonTensorData": NonTensorData,
"NonTensorStack": NonTensorStack,
}


Expand Down Expand Up @@ -57,7 +59,9 @@ def from_metadata(metadata=metadata_dict, prefix=None):
d[k] = from_metadata(
v, prefix=prefix + (k,) if prefix is not None else (k,)
)
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
if isinstance(cls, str):
cls = CLS_MAP[cls]
result = cls._from_dict_validated(d, **cls_metadata)
if is_locked:
result.lock_()
# if is_shared:
Expand Down Expand Up @@ -121,10 +125,15 @@ def from_metadata(metadata=metadata, prefix=None):
d[k] = from_metadata(
v, prefix=prefix + (k,) if prefix is not None else (k,)
)
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
if isinstance(cls, str):
cls = CLS_MAP[cls]
result = cls._from_dict_validated(d, **cls_metadata)
if is_locked:
result = result.lock_()
result._consolidated = consolidated
if _is_tensorclass(cls):
result._tensordict._consolidated = consolidated
else:
result._consolidated = consolidated
return result

return from_metadata()
Expand Down
11 changes: 3 additions & 8 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -2064,6 +2064,8 @@ def from_dict(
batch_dims=None,
names=None,
):
if _is_tensor_collection(type(input_dict)):
return input_dict
if others:
if batch_size is not None:
raise TypeError(
Expand Down Expand Up @@ -2120,14 +2122,7 @@ def from_dict(
)
if batch_size is None:
if auto_batch_size is None and batch_dims is None:
warn(
"The batch-size was not provided and auto_batch_size isn't set either. "
"Currently, from_dict will call set auto_batch_size=True but this behaviour "
"will be changed in v0.8 and auto_batch_size will be False onward. "
"To silence this warning, pass auto_batch_size directly.",
category=DeprecationWarning,
)
auto_batch_size = True
auto_batch_size = False
elif auto_batch_size is None:
auto_batch_size = True
if auto_batch_size:
Expand Down
45 changes: 34 additions & 11 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from concurrent.futures import Future, ThreadPoolExecutor, wait
from copy import copy
from functools import partial, wraps
from functools import wraps
from pathlib import Path
from textwrap import indent
from threading import Thread
Expand Down Expand Up @@ -2188,7 +2188,6 @@ def _from_dict_validated(cls, *args, **kwargs):

By default, falls back on :meth:`~.from_dict`.
"""
kwargs.setdefault("auto_batch_size", True)
return cls.from_dict(*args, **kwargs)

@abc.abstractmethod
Expand Down Expand Up @@ -4994,8 +4993,15 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):

if requires_metadata:
# metadata is nested
cls = type(self)
from tensordict._reductions import CLS_MAP

if cls.__name__ in CLS_MAP:
cls = cls.__name__
else:
pass
metadata_dict = {
"cls": type(self).__name__,
"cls": cls,
"non_tensors": {},
"leaves": {},
"cls_metadata": self._reduce_get_metadata(),
Expand Down Expand Up @@ -5055,18 +5061,27 @@ def assign(
elif _is_tensor_collection(cls):
metadata_dict_key = None
if requires_metadata:
from tensordict._reductions import CLS_MAP

if cls.__name__ in CLS_MAP:
cls = cls.__name__
else:
pass
metadata_dict_key = metadata_dict[key] = {
"cls": cls.__name__,
"cls": cls,
"non_tensors": {},
"leaves": {},
"cls_metadata": value._reduce_get_metadata(),
}
local_assign = partial(
assign,
track_key=total_key,
metadata_dict=metadata_dict_key,
flat_size=flat_size,
)

def local_assign(*t):
return assign(
*t,
track_key=total_key,
metadata_dict=metadata_dict_key,
flat_size=flat_size,
)

value._fast_apply(
local_assign,
named=True,
Expand Down Expand Up @@ -5254,7 +5269,15 @@ def consolidate(
storage.share_memory_()
else:
# Convert the dict to json
metadata_dict_json = json.dumps(metadata_dict)
try:
metadata_dict_json = json.dumps(metadata_dict)
except TypeError as e:
raise RuntimeError(
                        "Failed to convert the metadata to json. "
"This is usually due to a nested class that is unaccounted for by the serializer, "
"such as custom TensorClass. "
"If you encounter this error, please file an issue on github."
) from e
# Represent as a tensor
metadata_dict_json = torch.as_tensor(
bytearray(metadata_dict_json), dtype=torch.uint8
Expand Down

0 comments on commit ad0a8dd

Please sign in to comment.