Skip to content

Commit

Permalink
[BugFix] Pass type directly during reduction
Browse files Browse the repository at this point in the history
ghstack-source-id: 2a0f011758991f07958b2b1742d3d2136b6e9fb8
Pull Request resolved: #1223
  • Loading branch information
vmoens committed Feb 19, 2025
1 parent 0b901a7 commit 154fdf5
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
19 changes: 14 additions & 5 deletions tensordict/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
from tensordict._lazy import LazyStackedTensorDict
from tensordict._td import TensorDict

from tensordict.tensorclass import NonTensorData
from tensordict.utils import _STRDTYPE2DTYPE
from tensordict.tensorclass import NonTensorData, NonTensorStack
from tensordict.utils import _is_tensorclass, _STRDTYPE2DTYPE

CLS_MAP = {
"TensorDict": TensorDict,
"LazyStackedTensorDict": LazyStackedTensorDict,
"NonTensorData": NonTensorData,
"NonTensorStack": NonTensorStack,
}


Expand Down Expand Up @@ -57,7 +59,9 @@ def from_metadata(metadata=metadata_dict, prefix=None):
d[k] = from_metadata(
v, prefix=prefix + (k,) if prefix is not None else (k,)
)
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
if isinstance(cls, str):
cls = CLS_MAP[cls]
result = cls._from_dict_validated(d, **cls_metadata)
if is_locked:
result.lock_()
# if is_shared:
Expand Down Expand Up @@ -121,10 +125,15 @@ def from_metadata(metadata=metadata, prefix=None):
d[k] = from_metadata(
v, prefix=prefix + (k,) if prefix is not None else (k,)
)
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
if isinstance(cls, str):
cls = CLS_MAP[cls]
result = cls._from_dict_validated(d, **cls_metadata)
if is_locked:
result = result.lock_()
result._consolidated = consolidated
if _is_tensorclass(cls):
result._tensordict._consolidated = consolidated
else:
result._consolidated = consolidated
return result

return from_metadata()
Expand Down
4 changes: 2 additions & 2 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4994,7 +4994,7 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):
if requires_metadata:
# metadata is nested
metadata_dict = {
"cls": type(self).__name__,
"cls": type(self),
"non_tensors": {},
"leaves": {},
"cls_metadata": self._reduce_get_metadata(),
Expand Down Expand Up @@ -5055,7 +5055,7 @@ def assign(
metadata_dict_key = None
if requires_metadata:
metadata_dict_key = metadata_dict[key] = {
"cls": cls.__name__,
"cls": cls,
"non_tensors": {},
"leaves": {},
"cls_metadata": value._reduce_get_metadata(),
Expand Down

0 comments on commit 154fdf5

Please sign in to comment.