From a3493defbc6a3a50a4e2fbf5cc1b2e51205cc94e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 21 Feb 2025 13:32:01 +0000 Subject: [PATCH] [Feature] TensorDict.tolist() (#1229) --- tensordict/_lazy.py | 4 +- tensordict/base.py | 134 ++++++++++++++++++++++++++++++++++++- tensordict/tensorclass.py | 17 +++-- tensordict/tensorclass.pyi | 7 +- test/test_tensordict.py | 31 +++++++++ 5 files changed, 184 insertions(+), 9 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 7d7a86f47..f13be041f 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -360,7 +360,9 @@ def _has_exclusive_keys(self): return False @_fails_exclusive_keys - def to_dict(self, *, retain_none: bool = True) -> dict[str, Any]: ... + def to_dict( + self, *, retain_none: bool = True, convert_tensors: bool = False + ) -> dict[str, Any]: ... def _reduce_get_metadata(self): metadata = {} diff --git a/tensordict/base.py b/tensordict/base.py index 6cbfd8beb..6f3b1926c 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -11509,13 +11509,48 @@ def as_tensor(tensor): return self._fast_apply(as_tensor, propagate_lock=True) - def to_dict(self, *, retain_none: bool = True) -> dict[str, Any]: + def to_dict( + self, *, retain_none: bool = True, convert_tensors: bool = False + ) -> dict[str, Any]: """Returns a dictionary with key-value pairs matching those of the tensordict. Args: retain_none (bool): if ``True``, the ``None`` values from tensorclass instances will be written in the dictionary. Otherwise, they will be discarded. Default: ``True``. + convert_tensors (bool): if ``True``, tensors will be converted to lists when creating the dictionary. + Otherwise, they will remain as tensors. Default: ``False``. + + Returns: + A dictionary representation of the tensordict. + + .. seealso:: :meth:`~tensordict.TensorDictBase.tolist` + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> + >>> td = TensorDict( + ... a=torch.arange(24).view(2, 3, 4), + ... b=TensorDict(c=torch.arange(12).reshape(2, 3, 2), batch_size=(2, 3, 2)), + ... batch_size=(2, 3) + ... ) + >>> print(td.to_dict()) + {'a': tensor([[[ 0, 1, 2, 3], + [ 4, 5, 6, 7], + [ 8, 9, 10, 11]], + + [[12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23]]]), 'b': {'c': tensor([[[ 0, 1], + [ 2, 3], + [ 4, 5]], + + [[ 6, 7], + [ 8, 9], + [10, 11]]])}} + >>> print(td.to_dict(convert_tensors=True)) + {'a': [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]], 'b': {'c': [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]]}} """ result = {} @@ -11527,10 +11562,105 @@ def to_dict(self, *, retain_none: bool = True) -> dict[str, Any]: and value.data is None ): continue - value = value.to_dict(retain_none=retain_none) + value = value.to_dict( + retain_none=retain_none, convert_tensors=convert_tensors + ) + elif convert_tensors and hasattr(value, "tolist"): + value = value.tolist() result[key] = value return result + def tolist( + self, *, convert_nodes: bool = True, convert_tensors: bool = False + ) -> List[Any]: + """Returns a nested list representation of the tensordict. + + If the tensordict has no batch dimensions, this method returns a single list or dictionary. + Otherwise, it returns a nested list where each inner list represents a batch dimension. + + Args: + convert_nodes (bool): if ``True``, leaf nodes will be converted to dictionaries. + Otherwise, they will be returned as lists of values. Default: ``True``. + convert_tensors (bool): if ``True``, tensors will be converted to lists when creating the dictionary. + Otherwise, they will remain as tensors. Default: ``False``. + + Returns: + A nested list representation of the tensordict. + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> + >>> td = TensorDict( + ... a=torch.arange(24).view(2, 3, 4), + ... b=TensorDict(c=torch.arange(12).reshape(2, 3, 2), batch_size=(2, 3, 2)), + ... batch_size=(2, 3) + ... ) + >>> + >>> print(td.tolist()) + [[{'a': tensor([0, 1, 2, 3]), 'b': {'c': tensor([0, 1])}}, {'a': tensor([4, 5, 6, 7]), 'b': {'c': tensor([2, 3])}}, {'a': tensor([ 8, 9, 10, 11]), 'b': {'c': tensor([4, 5])}}], [{'a': tensor([12, 13, 14, 15]), 'b': {'c': tensor([6, 7])}}, {'a': tensor([16, 17, 18, 19]), 'b': {'c': tensor([8, 9])}}, {'a': tensor([20, 21, 22, 23]), 'b': {'c': tensor([10, 11])}}]] + >>> print(td.tolist(convert_tensors=True)) + [[{'a': [0, 1, 2, 3], 'b': {'c': [0, 1]}}, {'a': [4, 5, 6, 7], 'b': {'c': [2, 3]}}, {'a': [8, 9, 10, 11], 'b': {'c': [4, 5]}}], [{'a': [12, 13, 14, 15], 'b': {'c': [6, 7]}}, {'a': [16, 17, 18, 19], 'b': {'c': [8, 9]}}, {'a': [20, 21, 22, 23], 'b': {'c': [10, 11]}}]] + >>> print(td.tolist(convert_nodes=False)) + [[[tensor([0, 1, 2, 3]), TensorDict( + fields={ + c: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([2]), + device=None, + is_shared=False)], [tensor([4, 5, 6, 7]), TensorDict( + fields={ + c: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([2]), + device=None, + is_shared=False)], [tensor([ 8, 9, 10, 11]), TensorDict( + fields={ + c: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([2]), + device=None, + is_shared=False)]], [[tensor([12, 13, 14, 15]), TensorDict( + fields={ + c: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([2]), + device=None, + is_shared=False)], [tensor([16, 17, 18, 19]), TensorDict( + fields={ + c: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([2]), + device=None, + is_shared=False)], [tensor([20, 21, 22, 23]), TensorDict( + fields={ + c: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([2]), + device=None, + is_shared=False)]]] + + """ + if convert_tensors and not convert_nodes: + raise TypeError("convert_tensors requires convert_nodes to be set to True") + if not self.batch_dims: + if convert_nodes: + return self.to_dict(convert_tensors=convert_tensors) + return self + + q = collections.deque() + result = [] + q.append((self, result)) + while len(q): + val, _result = q.popleft() + vals = val.unbind(0) + if val.ndim == 1: + if convert_nodes: + vals = [v.to_dict(convert_tensors=convert_tensors) for v in vals] + else: + vals = list(vals) + _result.extend(vals) + else: + for local_val in vals: + local_res = [] + _result.append(local_res) + q.append((local_val, local_res)) + return result + def numpy(self): """Converts a tensordict to a (possibly nested) dictionary of numpy arrays. diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 6a91a93c5..2c8da7d11 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -200,6 +200,7 @@ def __subclasscheck__(self, subclass): "size", "sorted_keys", "to_struct_array", + "tolist", "values", ] @@ -1888,8 +1889,10 @@ def _len(self) -> int: return len(self._tensordict) -def _to_dict(self, *, retain_none: bool = True) -> dict: - td_dict = self._tensordict.to_dict(retain_none=retain_none) +def _to_dict(self, *, retain_none: bool = True, convert_tensors: bool = False) -> dict: + td_dict = self._tensordict.to_dict( + retain_none=retain_none, convert_tensors=convert_tensors + ) if self._non_tensordict: if retain_none: td_dict.update(self._non_tensordict) @@ -3207,7 +3210,9 @@ def _apply_nest(self, *args, out=None, **kwargs): names=kwargs.get("names"), ) - def to_dict(self, *, retain_none: bool = True): + def to_dict( + self, *, retain_none: bool = True, convert_tensors: bool = False + ) -> dict[str, Any]: # override to_dict to return just the data return self.data @@ -3528,8 +3533,10 @@ def lazy_stack( ) return result - def to_dict(self, *, retain_none: bool = True) -> dict[str, Any]: - return self.tolist() + def to_dict( + self, *, retain_none: bool = True, convert_tensors: bool = False + ) -> dict[str, Any]: + return self.tolist(convert_tensors=convert_tensors) def to_tensordict(self, *, retain_none: bool | None = None): return self diff --git a/tensordict/tensorclass.pyi b/tensordict/tensorclass.pyi index 9db6d2c71..49c5039d3 100644 --- a/tensordict/tensorclass.pyi +++ b/tensordict/tensorclass.pyi @@ -967,7 +967,12 @@ class TensorClass: self, padding: float = 0.0, mask_key: NestedKey | None = None ): ... def as_tensor(self): ... - def to_dict(self, *, retain_none: bool = True) -> dict[str, Any]: ... + def to_dict( + self, *, retain_none: bool = True, convert_tensors: bool = False + ) -> dict[str, Any]: ... + def tolist( + self, *, convert_nodes: bool = True, convert_tensors: bool = False + ) -> List[Any]: ... def numpy(self): ... def to_namedtuple(self, dest_cls: type | None = None): ... @classmethod diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 3a76b3c9a..b5369ed31 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -3039,6 +3039,37 @@ def test_to_padded_tensor(self, mask_key): if mask_key: assert (td_padded[td_padded["mask"]] != 0).all() + @pytest.mark.parametrize("convert_nodes", [False, True]) + @pytest.mark.parametrize("convert_tensors", [False, True]) + def test_tolist(self, convert_nodes, convert_tensors): + td = TensorDict( + a=torch.arange(120).view(4, 5, 6), + b=TensorDict(c=torch.arange(40).reshape(4, 5, 2), batch_size=(4, 5, 2)), + batch_size=(4, 5), + ) + with ( + pytest.raises(TypeError, match="convert_tensors") + if convert_tensors and not convert_nodes + else contextlib.nullcontext() + ): + tdlist = td.tolist( + convert_nodes=convert_nodes, convert_tensors=convert_tensors + ) + assert isinstance(tdlist, list) + assert len(tdlist) == 4 + for i in range(4): + assert len(tdlist[i]) == 5 + if not convert_tensors: + assert (tdlist[0][0]["a"] == torch.arange(6)).all() + assert (tdlist[0][0]["b"]["c"] == torch.arange(2)).all() + else: + assert tdlist[0][0]["a"] == torch.arange(6).tolist() + assert tdlist[0][0]["b"]["c"] == torch.arange(2).tolist() + if convert_nodes: + assert isinstance(tdlist[0][0]["b"], dict) + else: + assert isinstance(tdlist[0][0]["b"], TensorDict) + def test_unbind_batchsize(self): td = TensorDict({"a": TensorDict({"b": torch.zeros(2, 3)}, [2, 3])}, [2]) td["a"].batch_size