@@ -218,7 +218,9 @@ def __init__(
218218 ** {key : val for key , val in _zip_strict (modules [0 ], modules_vals )}
219219 )
220220 super ().__init__ (
221- module = nn .ModuleDict (modules ), in_keys = in_keys , out_keys = out_keys
221+ module = nn .ModuleDict (modules ),
222+ in_keys = in_keys ,
223+ out_keys = out_keys ,
222224 )
223225 elif len (modules ) == 1 and isinstance (
224226 modules [0 ], collections .abc .MutableSequence
@@ -227,20 +229,25 @@ def __init__(
227229 in_keys , out_keys = self ._compute_in_and_out_keys (modules )
228230 self ._complete_out_keys = list (out_keys )
229231 super ().__init__ (
230- module = nn .ModuleList (modules ), in_keys = in_keys , out_keys = out_keys
232+ module = nn .ModuleList (modules ),
233+ in_keys = in_keys ,
234+ out_keys = out_keys ,
231235 )
232236 elif len (modules ) == 1 and isinstance (modules [0 ], dict ):
233237 return self .__init__ (
234238 collections .OrderedDict (modules [0 ]),
235239 partial_tolerant = partial_tolerant ,
236240 selected_out_keys = selected_out_keys ,
241+ inplace = inplace ,
237242 )
238243 else :
239244 modules = self ._convert_modules (modules )
240245 in_keys , out_keys = self ._compute_in_and_out_keys (modules )
241246 self ._complete_out_keys = list (out_keys )
242247 super ().__init__ (
243- module = nn .ModuleList (list (modules )), in_keys = in_keys , out_keys = out_keys
248+ module = nn .ModuleList (list (modules )),
249+ in_keys = in_keys ,
250+ out_keys = out_keys ,
244251 )
245252
246253 self .inplace = inplace
@@ -628,6 +635,7 @@ def forward(
628635 )
629636 if tensordict_out is not None :
630637 result = tensordict_out
638+ print ('here! update' )
631639 result .update (tensordict_exec , keys_to_update = self .out_keys )
632640 else :
633641 result = tensordict_exec
0 commit comments