From c8df23a156ad496805d75dbe195cd44021f0bd38 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 16 Sep 2024 16:19:46 -0700 Subject: [PATCH] [Performance] Remove list against list check during set ghstack-source-id: 0cd65696a91d83674212ca9a62dce02d1cabf44d Pull Request resolved: https://github.com/pytorch/tensordict/pull/954 --- tensordict/_td.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 376c93b97..3ee19e2f3 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -2323,18 +2323,16 @@ def _set_at_str(self, key, value, idx, *, validated, non_blocking: bool): if is_non_tensor(value) and not (self._is_shared or self._is_memmap): dest = tensor_in - is_diff = dest[idx].tolist() != value.tolist() - if is_diff: - dest_val = dest.maybe_to_stack() - dest_val[idx] = value - if dest_val is not dest: - self._set_str( - key, - dest_val, - validated=True, - inplace=False, - ignore_lock=True, - ) + dest_val = dest.maybe_to_stack() + dest_val[idx] = value + if dest_val is not dest: + self._set_str( + key, + dest_val, + validated=True, + inplace=False, + ignore_lock=True, + ) return if isinstance(idx, tuple) and len(idx) and isinstance(idx[0], tuple):