From 0b901a77763ea2101f96f96cd68d87bbb9dbcda1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 19 Feb 2025 16:29:16 +0000 Subject: [PATCH] [BugFix] Consolidate lazy stacks of non-tensors ghstack-source-id: afb1480da5702ec582d4c8438ce16e569b819d9b Pull Request resolved: https://github.com/pytorch/tensordict/pull/1222 --- tensordict/base.py | 6 ++++-- test/test_tensordict.py | 11 +++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index a0e57dfab..21a4ab133 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -5043,7 +5043,8 @@ def assign( cls = type(value) if issubclass(cls, torch.Tensor): pass - elif _is_non_tensor(cls): + # We want to skip NonTensorStacks + elif _is_non_tensor(cls) and not issubclass(cls, TensorDictBase): if requires_metadata: metadata_dict["non_tensors"][key] = ( value.data, @@ -5411,7 +5412,8 @@ def _view_and_pad(tensor): if non_blocking and device.type != "cuda": # sync if needed self._sync_all() - torch.cat(items, out=storage) + if items: + torch.cat(items, out=storage) for v, (k, oldv) in _zip_strict( storage.split(flat_size), list(flat_dict.items()) ): diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 5c96fd90d..18aeecf33 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -11404,6 +11404,17 @@ def test_stack(self, non_tensor_data): LazyStackedTensorDict, ) + def test_stack_consolidate(self): + td = torch.stack( + [ + TensorDict(a="a string", b="b string"), + TensorDict(a="another string", b="bnother string"), + ] + ) + tdc = td.consolidate() + assert (tdc == td).all() + assert tdc["a"] == ["a string", "another string"] + def test_assign_non_tensor(self): data = TensorDict({}, [1, 10])