     MTPLlamaGPTHuggingfaceCheckpointFormat,
     Qwen2GPTHuggingfaceCheckpointFormat,
     Starcoder2GPTHuggingfaceCheckpointFormat,
+    DiffusionDreamGPTHuggingfaceCheckpointFormat,
+    DiffusionLlamaGPTHuggingfaceCheckpointFormat,
 )
 from fast_llm.models.gpt.external.mtp_llama.configuration_mtp_llama import MTPLlamaConfig
+from fast_llm.models.gpt.external.diffusion_dream.configuration_dream import DreamConfig
+from fast_llm.models.gpt.external.diffusion_llama.configuration_diffusion_llama import DiffusionLlamaConfig
 from fast_llm.models.gpt.model import GPTModel
 from fast_llm.tensor import SafeTensorSlice
 from fast_llm.utils import Assert
@@ -679,6 +683,124 @@ def _create_lm_head_converters(self) -> list[WeightConverter]:
 
         return converters
 
+class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonHuggingfaceCheckpointHandler):
+
+    from fast_llm.models.gpt.external.diffusion_dream import configuration_dream, modeling_dream, generation_utils
+
+    format: typing.ClassVar[type[CheckpointFormat]] = DiffusionDreamGPTHuggingfaceCheckpointFormat
+    modeling_file = modeling_dream.__file__
+    configuration_file = configuration_dream.__file__
+    generation_utils_file = generation_utils.__file__
+    configuration_cls: typing.ClassVar[type[PretrainedConfig]] = DreamConfig
+
+    @classmethod
+    def _create_config_converters(cls) -> list[ParamConverter]:
+        return super()._create_config_converters() + [
+            # From Qwen2HuggingfaceCheckpointHandler - change architectures to DiffusionDream
+            ConstantImportParamConverter(
+                fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm
+            ),
+            RenameParamConverter(
+                fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),)
+            ),
+            ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True),
+            ConstantImportParamConverter(
+                fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv"
+            ),
+            RopeScalingParamConverter(
+                fast_llm_names=(
+                    ("transformer", "rotary", "type"),
+                    ("transformer", "rotary", "scale_factor"),
+                    ("transformer", "rotary", "low_frequency_factor"),
+                    ("transformer", "rotary", "high_frequency_factor"),
+                    ("transformer", "rotary", "original_context_length"),
+                    ("transformer", "rotary", "attention_factor"),
+                    ("transformer", "rotary", "beta_fast"),
+                    ("transformer", "rotary", "beta_slow"),
+                ),
+                export_names=(("rope_scaling",),),
+            ),
+            IgnoreImportQwen2SlidingWindowParamsConverter(),
+            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["DreamModel"]),
+            ConstantExportParamConverter(
+                export_names=(("auto_map",),),
+                export_value={
+                    "AutoConfig": "configuration_dream.DreamConfig",
+                    "AutoModel": "modeling_dream.DreamModel",
+                },
+            ),
+        ]
+
+
+    def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
+        # From Qwen2HuggingfaceCheckpointHandler
+        transformer_config: TransformerConfig = self._model.config.base_model.transformer
+        return [
+            *self._get_weight_and_bias_converters(
+                f"{fast_llm_prefix}.mlp.layer_1",
+                (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
+                transformer_config.add_mlp_bias,
+                SplitWeightConverter,
+            ),
+            *self._get_weight_and_bias_converters(
+                f"{fast_llm_prefix}.mlp.layer_2",
+                f"{hf_prefix}.mlp.down_proj",
+                transformer_config.add_mlp_bias,
+                MLPLayer2Converter,
+            ),
+        ]
+
+class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlamaHuggingfaceCheckpointHandler):
+
+    from fast_llm.models.gpt.external.diffusion_llama import configuration_diffusion_llama, modeling_diffusion_llama, generation_utils
+
+    format: typing.ClassVar[type[CheckpointFormat]] = DiffusionLlamaGPTHuggingfaceCheckpointFormat
+    modeling_file = modeling_diffusion_llama.__file__
+    configuration_file = configuration_diffusion_llama.__file__
+    generation_utils_file = generation_utils.__file__
+    configuration_cls: typing.ClassVar[type[PretrainedConfig]] = DiffusionLlamaConfig
+
+    @classmethod
+    def _create_config_converters(cls) -> list[ParamConverter]:
+        return super()._create_config_converters() + [
+            # From LlamaHuggingfaceCheckpointHandler - update architectures to DiffusionLlama
+            # TODO: Llama supports biases
+            ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False),
+            ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False),
+            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["DiffusionLlamaModel"]),
+            ConstantExportParamConverter(
+                export_names=(("auto_map",),),
+                export_value={
+                    "AutoConfig": "configuration_diffusion_llama.DiffusionLlamaConfig",
+                    "AutoModel": "modeling_diffusion_llama.DiffusionLlamaModel",
+                },
+            ),
+            # TODO: Include once masked-diffusion training is implemented: the imported (Llama) model used for
+            # CPT does not define mask_token_id, but the exported diffusion-Llama model needs this token.
+            # RenameParamConverter(
+            #     fast_llm_names=(("mask_token_id",),),
+            #     export_names=(("mask_token_id",),),
+            # ),
+        ]
+
+
+    def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
+        # From LlamaHuggingfaceCheckpointHandler
+        transformer_config: TransformerConfig = self._model.config.base_model.transformer
+        return [
+            *self._get_weight_and_bias_converters(
+                f"{fast_llm_prefix}.mlp.layer_1",
+                (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
+                transformer_config.add_mlp_bias,
+                SplitWeightConverter,
+            ),
+            *self._get_weight_and_bias_converters(
+                f"{fast_llm_prefix}.mlp.layer_2",
+                f"{hf_prefix}.mlp.down_proj",
+                transformer_config.add_mlp_bias,
+                MLPLayer2Converter,
+            ),
+        ]
+
 
 
 class AutoGPTHuggingfaceCheckpointHandler(
     AutoStateDictCheckpointHandler, HuggingfaceStateDictCheckpointHandler, abc.ABC
@@ -691,4 +813,6 @@ class AutoGPTHuggingfaceCheckpointHandler(
         MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler,
         MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler,
         MTPLlamaGPTHuggingfaceCheckpointFormat.name: MTPLlamaHuggingfaceCheckpointHandler,
+        DiffusionDreamGPTHuggingfaceCheckpointFormat.name: DiffusionDreamHuggingfaceCheckpointHandler,
+        DiffusionLlamaGPTHuggingfaceCheckpointFormat.name: DiffusionLlamaHuggingfaceCheckpointHandler,
     }
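
For orientation, the final hunk registers the two new formats in the name-to-handler map that AutoGPTHuggingfaceCheckpointHandler uses for dispatch. Below is a minimal, self-contained sketch of that dispatch pattern under stated assumptions: the identifiers (FormatStub, HandlerStub, HANDLER_MAP, get_handler_class) and the format-name string are illustrative placeholders, not the repository's actual names or values.

# Hypothetical sketch of the format-name -> handler-class registry extended above.
class FormatStub:
    """Stands in for a CheckpointFormat class exposing a `name` class variable."""

    name = "diffusion_dream"  # placeholder string, not the real .name value


class HandlerStub:
    """Stands in for DiffusionDreamHuggingfaceCheckpointHandler."""


# Mirrors the dict entries added in the final hunk: each format's .name keys the
# handler class that knows how to import/export checkpoints in that format.
HANDLER_MAP: dict[str, type] = {
    FormatStub.name: HandlerStub,
}


def get_handler_class(format_name: str) -> type:
    # Dispatch by format name; unknown names raise KeyError.
    return HANDLER_MAP[format_name]


assert get_handler_class("diffusion_dream") is HandlerStub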