@@ -724,9 +724,6 @@ def forward(self, tensordict: TensorDictBase):
724724 tensordict_shaped .get (key , default )
725725 for key , default in zip (self .in_keys , defaults )
726726 )
727- batch , steps = value .shape [:2 ]
728- device = value .device
729- dtype = value .dtype
730727 # packed sequences do not help to get the accurate last hidden values
731728 # if splits is not None:
732729 # value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True)
@@ -737,8 +734,13 @@ def forward(self, tensordict: TensorDictBase):
737734 # When using the recurrent_mode=True option, the lstm can be called from
738735 # any intermediate state, hence zeroing should not be done.
739736 is_init_expand = expand_as_right (is_init , hidden0 )
740- hidden0 = torch .where (is_init_expand , 0 , hidden0 )
741- hidden1 = torch .where (is_init_expand , 0 , hidden1 )
737+ zeros = torch .zeros_like (hidden0 )
738+ hidden0 = torch .where (is_init_expand , zeros , hidden0 )
739+ hidden1 = torch .where (is_init_expand , zeros , hidden1 )
740+
741+ batch , steps = value .shape [:2 ]
742+ device = value .device
743+ dtype = value .dtype
742744
743745 val , hidden0 , hidden1 = self ._lstm (
744746 value , batch , steps , device , dtype , hidden0 , hidden1
0 commit comments