Skip to content

Commit

Permalink
[Feature] TensorDict.tolist() (#1229)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 21, 2025
1 parent bbd5ba8 commit a3493de
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 9 deletions.
4 changes: 3 additions & 1 deletion tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,9 @@ def _has_exclusive_keys(self):
return False

@_fails_exclusive_keys
def to_dict(self, *, retain_none: bool = True) -> dict[str, Any]: ...
def to_dict(
self, *, retain_none: bool = True, convert_tensors: bool = False
) -> dict[str, Any]: ...

def _reduce_get_metadata(self):
metadata = {}
Expand Down
134 changes: 132 additions & 2 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11509,13 +11509,48 @@ def as_tensor(tensor):

return self._fast_apply(as_tensor, propagate_lock=True)

def to_dict(self, *, retain_none: bool = True) -> dict[str, Any]:
def to_dict(
self, *, retain_none: bool = True, convert_tensors: bool = False
) -> dict[str, Any]:
"""Returns a dictionary with key-value pairs matching those of the tensordict.

Args:
retain_none (bool): if ``True``, the ``None`` values from tensorclass instances
will be written in the dictionary.
Otherwise, they will be discarded. Default: ``True``.
convert_tensors (bool): if ``True``, tensors will be converted to lists when creating the dictionary.
Otherwise, they will remain as tensors. Default: ``False``.

Returns:
A dictionary representation of the tensordict.

.. seealso:: :meth:`~tensordict.TensorDictBase.tolist`

Examples:
>>> import torch
>>> from tensordict import TensorDict
>>>
>>> td = TensorDict(
... a=torch.arange(24).view(2, 3, 4),
... b=TensorDict(c=torch.arange(12).reshape(2, 3, 2), batch_size=(2, 3, 2)),
... batch_size=(2, 3)
... )
>>> print(td.to_dict())
{'a': tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],

[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]]), 'b': {'c': tensor([[[ 0, 1],
[ 2, 3],
[ 4, 5]],

[[ 6, 7],
[ 8, 9],
[10, 11]]])}}
>>> print(td.to_dict(convert_tensors=True))
{'a': [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]], 'b': {'c': [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]]}}

"""
result = {}
Expand All @@ -11527,10 +11562,105 @@ def to_dict(self, *, retain_none: bool = True) -> dict[str, Any]:
and value.data is None
):
continue
value = value.to_dict(retain_none=retain_none)
value = value.to_dict(
retain_none=retain_none, convert_tensors=convert_tensors
)
elif convert_tensors and hasattr(value, "tolist"):
value = value.tolist()
result[key] = value
return result

def tolist(
self, *, convert_nodes: bool = True, convert_tensors: bool = False
) -> List[Any]:
"""Returns a nested list representation of the tensordict.

If the tensordict has no batch dimensions, this method returns a single list or dictionary.
Otherwise, it returns a nested list where each inner list represents a batch dimension.

Args:
convert_nodes (bool): if ``True``, leaf nodes will be converted to dictionaries.
Otherwise, they will be returned as lists of values. Default: ``True``.
convert_tensors (bool): if ``True``, tensors will be converted to lists when creating the dictionary.
Otherwise, they will remain as tensors. Default: ``False``.

Returns:
A nested list representation of the tensordict.

Examples:
>>> import torch
>>> from tensordict import TensorDict
>>>
>>> td = TensorDict(
... a=torch.arange(24).view(2, 3, 4),
... b=TensorDict(c=torch.arange(12).reshape(2, 3, 2), batch_size=(2, 3, 2)),
... batch_size=(2, 3)
... )
>>>
>>> print(td.tolist())
[[{'a': tensor([0, 1, 2, 3]), 'b': {'c': tensor([0, 1])}}, {'a': tensor([4, 5, 6, 7]), 'b': {'c': tensor([2, 3])}}, {'a': tensor([ 8, 9, 10, 11]), 'b': {'c': tensor([4, 5])}}], [{'a': tensor([12, 13, 14, 15]), 'b': {'c': tensor([6, 7])}}, {'a': tensor([16, 17, 18, 19]), 'b': {'c': tensor([8, 9])}}, {'a': tensor([20, 21, 22, 23]), 'b': {'c': tensor([10, 11])}}]]
>>> print(td.tolist(convert_tensors=True))
[[{'a': [0, 1, 2, 3], 'b': {'c': [0, 1]}}, {'a': [4, 5, 6, 7], 'b': {'c': [2, 3]}}, {'a': [8, 9, 10, 11], 'b': {'c': [4, 5]}}], [{'a': [12, 13, 14, 15], 'b': {'c': [6, 7]}}, {'a': [16, 17, 18, 19], 'b': {'c': [8, 9]}}, {'a': [20, 21, 22, 23], 'b': {'c': [10, 11]}}]]
>>> print(td.tolist(convert_nodes=False))
[[[tensor([0, 1, 2, 3]), TensorDict(
fields={
c: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)], [tensor([4, 5, 6, 7]), TensorDict(
fields={
c: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)], [tensor([ 8, 9, 10, 11]), TensorDict(
fields={
c: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)]], [[tensor([12, 13, 14, 15]), TensorDict(
fields={
c: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)], [tensor([16, 17, 18, 19]), TensorDict(
fields={
c: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)], [tensor([20, 21, 22, 23]), TensorDict(
fields={
c: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)]]]

"""
if convert_tensors and not convert_nodes:
raise TypeError("convert_tensors requires convert_nodes to be set to True")
if not self.batch_dims:
if convert_nodes:
return self.to_dict(convert_tensors=convert_tensors)
return self

q = collections.deque()
result = []
q.append((self, result))
while len(q):
val, _result = q.popleft()
vals = val.unbind(0)
if val.ndim == 1:
if convert_nodes:
vals = [v.to_dict(convert_tensors=convert_tensors) for v in vals]
else:
vals = list(vals)
_result.extend(vals)
else:
for local_val in vals:
local_res = []
_result.append(local_res)
q.append((local_val, local_res))
return result

def numpy(self):
"""Converts a tensordict to a (possibly nested) dictionary of numpy arrays.

Expand Down
17 changes: 12 additions & 5 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def __subclasscheck__(self, subclass):
"size",
"sorted_keys",
"to_struct_array",
"tolist",
"values",
]

Expand Down Expand Up @@ -1888,8 +1889,10 @@ def _len(self) -> int:
return len(self._tensordict)


def _to_dict(self, *, retain_none: bool = True) -> dict:
td_dict = self._tensordict.to_dict(retain_none=retain_none)
def _to_dict(self, *, retain_none: bool = True, convert_tensors: bool = False) -> dict:
td_dict = self._tensordict.to_dict(
retain_none=retain_none, convert_tensors=convert_tensors
)
if self._non_tensordict:
if retain_none:
td_dict.update(self._non_tensordict)
Expand Down Expand Up @@ -3207,7 +3210,9 @@ def _apply_nest(self, *args, out=None, **kwargs):
names=kwargs.get("names"),
)

def to_dict(self, *, retain_none: bool = True):
def to_dict(
self, *, retain_none: bool = True, convert_tensors: bool = False
) -> dict[str, Any]:
# override to_dict to return just the data
return self.data

Expand Down Expand Up @@ -3528,8 +3533,10 @@ def lazy_stack(
)
return result

def to_dict(self, *, retain_none: bool = True) -> dict[str, Any]:
return self.tolist()
def to_dict(
self, *, retain_none: bool = True, convert_tensors: bool = False
) -> dict[str, Any]:
return self.tolist(convert_tensors=convert_tensors)

def to_tensordict(self, *, retain_none: bool | None = None):
return self
Expand Down
7 changes: 6 additions & 1 deletion tensordict/tensorclass.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,12 @@ class TensorClass:
self, padding: float = 0.0, mask_key: NestedKey | None = None
): ...
def as_tensor(self): ...
def to_dict(self, *, retain_none: bool = True) -> dict[str, Any]: ...
def to_dict(
self, *, retain_none: bool = True, convert_tensors: bool = False
) -> dict[str, Any]: ...
def tolist(
self, *, convert_nodes: bool = True, convert_tensors: bool = False
) -> List[Any]: ...
def numpy(self): ...
def to_namedtuple(self, dest_cls: type | None = None): ...
@classmethod
Expand Down
31 changes: 31 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3039,6 +3039,37 @@ def test_to_padded_tensor(self, mask_key):
if mask_key:
assert (td_padded[td_padded["mask"]] != 0).all()

@pytest.mark.parametrize("convert_nodes", [False, True])
@pytest.mark.parametrize("convert_tensors", [False, True])
def test_tolist(self, convert_nodes, convert_tensors):
td = TensorDict(
a=torch.arange(120).view(4, 5, 6),
b=TensorDict(c=torch.arange(40).reshape(4, 5, 2), batch_size=(4, 5, 2)),
batch_size=(4, 5),
)
with (
pytest.raises(TypeError, match="convert_tensors")
if convert_tensors and not convert_nodes
else contextlib.nullcontext()
):
tdlist = td.tolist(
convert_nodes=convert_nodes, convert_tensors=convert_tensors
)
assert isinstance(tdlist, list)
assert len(tdlist) == 4
for i in range(4):
assert len(tdlist[i]) == 5
if not convert_tensors:
assert (tdlist[0][0]["a"] == torch.arange(6)).all()
assert (tdlist[0][0]["b"]["c"] == torch.arange(2)).all()
else:
assert tdlist[0][0]["a"] == torch.arange(6).tolist()
assert tdlist[0][0]["b"]["c"] == torch.arange(2).tolist()
if convert_nodes:
assert isinstance(tdlist[0][0]["b"], dict)
else:
assert isinstance(tdlist[0][0]["b"], TensorDict)

def test_unbind_batchsize(self):
td = TensorDict({"a": TensorDict({"b": torch.zeros(2, 3)}, [2, 3])}, [2])
td["a"].batch_size
Expand Down

2 comments on commit a3493de

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: a3493de Previous: c21ded4 Ratio
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last 75448.96286408497 iter/sec (stddev: 9.143071678840789e-7) 231606.51975156958 iter/sec (stddev: 3.528242313624077e-7) 3.07
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last 74693.3892447179 iter/sec (stddev: 7.726010987554742e-7) 232940.6037365483 iter/sec (stddev: 4.011178612554823e-7) 3.12

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: a3493de Previous: bbd5ba8 Ratio
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last 121845.18434419009 iter/sec (stddev: 6.234414122016771e-7) 324669.70037109946 iter/sec (stddev: 4.621471867034303e-7) 2.66
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last 120196.52404798503 iter/sec (stddev: 5.862319969022119e-7) 327790.13414549315 iter/sec (stddev: 3.3158778622712117e-7) 2.73

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.