Skip to content

Commit bad63d8

Browse files
committed
Fix zeros in hidden are not tensors
1 parent 8570c25 commit bad63d8

File tree

1 file changed

+7
-5
lines changed
  • torchrl/modules/tensordict_module

1 file changed

+7
-5
lines changed

torchrl/modules/tensordict_module/rnn.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -724,9 +724,6 @@ def forward(self, tensordict: TensorDictBase):
724724
tensordict_shaped.get(key, default)
725725
for key, default in zip(self.in_keys, defaults)
726726
)
727-
batch, steps = value.shape[:2]
728-
device = value.device
729-
dtype = value.dtype
730727
# packed sequences do not help to get the accurate last hidden values
731728
# if splits is not None:
732729
# value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True)
@@ -737,8 +734,13 @@ def forward(self, tensordict: TensorDictBase):
737734
# When using the recurrent_mode=True option, the lstm can be called from
738735
# any intermediate state, hence zeroing should not be done.
739736
is_init_expand = expand_as_right(is_init, hidden0)
740-
hidden0 = torch.where(is_init_expand, 0, hidden0)
741-
hidden1 = torch.where(is_init_expand, 0, hidden1)
737+
zeros = torch.zeros_like(hidden0)
738+
hidden0 = torch.where(is_init_expand, zeros, hidden0)
739+
hidden1 = torch.where(is_init_expand, zeros, hidden1)
740+
741+
batch, steps = value.shape[:2]
742+
device = value.device
743+
dtype = value.dtype
742744

743745
val, hidden0, hidden1 = self._lstm(
744746
value, batch, steps, device, dtype, hidden0, hidden1

0 commit comments

Comments
 (0)