Skip to content

Commit

Permalink
[BugFix] Fix non-deterministic key order in stack
Browse files Browse the repository at this point in the history
ghstack-source-id: 7f394789b783d6359a78a300aaf449eb25adb5e3
Pull Request resolved: #1230
  • Loading branch information
vmoens committed Feb 22, 2025
1 parent a3493de commit f5f84fd
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
1 change: 0 additions & 1 deletion tensordict/_torch_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,6 @@ def stack_fn(key, values, is_not_init, is_tensor):
key: stack_fn(key, values, is_not_init, is_tensor)
for key, (values, is_not_init, is_tensor) in out.items()
}

result = clz._new_unsafe(
out,
batch_size=LazyStackedTensorDict._compute_batch_size(
Expand Down
14 changes: 8 additions & 6 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1757,7 +1757,7 @@ def _check_keys(
strict: bool = False,
include_nested: bool = False,
leaves_only: bool = False,
) -> set[str]:
) -> set[str] | list[str]:
from tensordict.base import _is_leaf_nontensor

if not len(list_of_tensordicts):
Expand All @@ -1769,27 +1769,29 @@ def _check_keys(
)
# TODO: compile doesn't like set() over an arbitrary object
if is_compiling():
keys = {k for k in keys} # noqa: C416
keys_set = {k for k in keys} # noqa: C416
else:
keys: set[str] = set(keys)
keys_set: set[str] = set(keys)
for td in list_of_tensordicts[1:]:
k = td.keys(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=_is_leaf_nontensor,
)
if not strict:
keys = keys.intersection(k)
keys_set = keys_set.intersection(k)
else:
if is_compiling():
k = {v for v in k} # noqa: C416
else:
k = set(k)
if k != keys:
if k != keys_set:
raise KeyError(
f"got keys {keys} and {set(td.keys())} which are incompatible"
)
return keys
if strict:
return keys
return keys_set


def _set_max_batch_size(source: T, batch_dims=None):
Expand Down
14 changes: 14 additions & 0 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1887,6 +1887,20 @@ class MyDataNested:
):
torch.stack([data1, data3], dim=0)

def test_stack_keyorder(self):

class MyTensorClass(TensorClass):
foo: Tensor
bar: Tensor

tc1 = MyTensorClass(foo=torch.zeros((1,)), bar=torch.ones((1,)))

for _ in range(10000):
assert list(torch.stack([tc1, tc1], dim=0)._tensordict.keys()) == [
"foo",
"bar",
]

def test_statedict_errors(self):
@tensorclass
class MyClass:
Expand Down

0 comments on commit f5f84fd

Please sign in to comment.