diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index a69047e81..88b48c326 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -626,7 +626,6 @@ def stack_fn(key, values, is_not_init, is_tensor): key: stack_fn(key, values, is_not_init, is_tensor) for key, (values, is_not_init, is_tensor) in out.items() } - result = clz._new_unsafe( out, batch_size=LazyStackedTensorDict._compute_batch_size( diff --git a/tensordict/utils.py b/tensordict/utils.py index b7f3346cf..e1baa8534 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1757,7 +1757,7 @@ def _check_keys( strict: bool = False, include_nested: bool = False, leaves_only: bool = False, -) -> set[str]: +) -> set[str] | list[str]: from tensordict.base import _is_leaf_nontensor if not len(list_of_tensordicts): @@ -1769,9 +1769,9 @@ def _check_keys( ) # TODO: compile doesn't like set() over an arbitrary object if is_compiling(): - keys = {k for k in keys} # noqa: C416 + keys_set = {k for k in keys} # noqa: C416 else: - keys: set[str] = set(keys) + keys_set: set[str] = set(keys) for td in list_of_tensordicts[1:]: k = td.keys( include_nested=include_nested, @@ -1779,17 +1779,19 @@ def _check_keys( is_leaf=_is_leaf_nontensor, ) if not strict: - keys = keys.intersection(k) + keys_set = keys_set.intersection(k) else: if is_compiling(): k = {v for v in k} # noqa: C416 else: k = set(k) - if k != keys: + if k != keys_set: raise KeyError( f"got keys {keys} and {set(td.keys())} which are incompatible" ) - return keys + if strict: + return keys + return keys_set def _set_max_batch_size(source: T, batch_dims=None): diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index a572da7bc..bede17948 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -1887,6 +1887,20 @@ class MyDataNested: ): torch.stack([data1, data3], dim=0) + def test_stack_keyorder(self): + + class MyTensorClass(TensorClass): + foo: Tensor + bar: Tensor + + tc1 = MyTensorClass(foo=torch.zeros((1,)), bar=torch.ones((1,))) + + for _ in range(10000): + assert list(torch.stack([tc1, tc1], dim=0)._tensordict.keys()) == [ + "foo", + "bar", + ] + def test_statedict_errors(self): @tensorclass class MyClass: