@@ -687,36 +687,14 @@ def apply_weights(self, destination: Any, weights: Any) -> None:
687687
688688 # Auto-detect format from weights type
689689 if isinstance (weights , dict ):
690- # Apply state_dict format
691- if isinstance (destination , nn .Module ):
692- destination .load_state_dict (weights )
693- elif isinstance (destination , dict ):
694- destination = TensorDict (destination )
695- weights = TensorDict (weights )
696- destination .data .update_ (weights .data )
697- elif isinstance (destination , TensorDictBase ):
698- weights_td = TensorDict (weights )
699- if (dest_keys := sorted (destination .keys (True , True ))) != sorted (
700- weights .keys (True , True )
701- ):
702- weights_td = weights_td .unflatten_keys ("." )
703- weights_keys = sorted (weights_td .keys (True , True ))
704- if dest_keys != weights_keys :
705- raise ValueError (
706- f"The keys of the weights and destination do not match: { dest_keys } != { weights_keys } "
707- )
708- destination .data .update_ (weights_td .data )
709- else :
710- raise ValueError (
711- f"Unsupported destination type for state_dict: { type (destination )} "
712- )
713- elif isinstance (weights , TensorDictBase ):
690+ weights = TensorDict (weights ).unflatten_keys ("." )
691+
692+ if isinstance (weights , TensorDictBase ):
714693 # Apply TensorDict format
715694 if isinstance (destination , nn .Module ):
716- weights .to_module (destination )
717- elif isinstance (destination , TensorDictBase ):
718- destination .data .update_ (weights .data )
719- elif isinstance (destination , dict ):
695+ destination = TensorDict .from_module (destination )
696+
697+ if isinstance (destination , dict ):
720698 destination_td = TensorDict (destination )
721699 if (dest_keys := sorted (destination_td .keys (True , True ))) != sorted (
722700 weights .keys (True , True )
@@ -727,11 +705,15 @@ def apply_weights(self, destination: Any, weights: Any) -> None:
727705 raise ValueError (
728706 f"The keys of the weights and destination do not match: { dest_keys } != { weights_keys } "
729707 )
730- destination_td .data .update_ (weights .data )
708+ destination = destination_td
709+
710+ if isinstance (destination , TensorDictBase ):
711+ destination .data .update_ (weights .data )
731712 else :
732713 raise ValueError (
733714 f"Unsupported destination type for TensorDict: { type (destination )} "
734715 )
716+
735717 else :
736718 raise ValueError (
737719 f"Unsupported weights type: { type (weights )} . Expected dict or TensorDictBase."
0 commit comments