Skip to content

Commit

Permalink
[BugFix] Pass type directly during reduction
Browse files Browse the repository at this point in the history
ghstack-source-id: 9ca5bf2a5bc1f3fd88c29360fb088836ce35e8a7
Pull Request resolved: #1223
  • Loading branch information
vmoens committed Feb 19, 2025
1 parent 0b901a7 commit baf2e43
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 12 deletions.
13 changes: 8 additions & 5 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1992,11 +1992,14 @@ def _apply_nest(
if all(r is None for r in results) and filter_empty in (None, True):
return
if not inplace:
out = type(self)(
*results,
stack_dim=self.stack_dim,
stack_dim_name=self._td_dim_name,
)
if results:
out = type(self)(
*results,
stack_dim=self.stack_dim,
stack_dim_name=self._td_dim_name,
)
else:
out = None
else:
out = self
if names is not NO_DEFAULT:
Expand Down
19 changes: 14 additions & 5 deletions tensordict/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
from tensordict._lazy import LazyStackedTensorDict
from tensordict._td import TensorDict

from tensordict.tensorclass import NonTensorData
from tensordict.utils import _STRDTYPE2DTYPE
from tensordict.tensorclass import NonTensorData, NonTensorStack
from tensordict.utils import _is_tensorclass, _STRDTYPE2DTYPE

CLS_MAP = {
"TensorDict": TensorDict,
"LazyStackedTensorDict": LazyStackedTensorDict,
"NonTensorData": NonTensorData,
"NonTensorStack": NonTensorStack,
}


Expand Down Expand Up @@ -57,7 +59,9 @@ def from_metadata(metadata=metadata_dict, prefix=None):
d[k] = from_metadata(
v, prefix=prefix + (k,) if prefix is not None else (k,)
)
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
if isinstance(cls, str):
cls = CLS_MAP[cls]
result = cls._from_dict_validated(d, **cls_metadata)
if is_locked:
result.lock_()
# if is_shared:
Expand Down Expand Up @@ -121,10 +125,15 @@ def from_metadata(metadata=metadata, prefix=None):
d[k] = from_metadata(
v, prefix=prefix + (k,) if prefix is not None else (k,)
)
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
if isinstance(cls, str):
cls = CLS_MAP[cls]
result = cls._from_dict_validated(d, **cls_metadata)
if is_locked:
result = result.lock_()
result._consolidated = consolidated
if _is_tensorclass(cls):
result._tensordict._consolidated = consolidated
else:
result._consolidated = consolidated
return result

return from_metadata()
Expand Down
4 changes: 2 additions & 2 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4994,7 +4994,7 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):
if requires_metadata:
# metadata is nested
metadata_dict = {
"cls": type(self).__name__,
"cls": type(self),
"non_tensors": {},
"leaves": {},
"cls_metadata": self._reduce_get_metadata(),
Expand Down Expand Up @@ -5055,7 +5055,7 @@ def assign(
metadata_dict_key = None
if requires_metadata:
metadata_dict_key = metadata_dict[key] = {
"cls": cls.__name__,
"cls": cls,
"non_tensors": {},
"leaves": {},
"cls_metadata": value._reduce_get_metadata(),
Expand Down

0 comments on commit baf2e43

Please sign in to comment.