
Commit 2d88c37

Converter for Llama based Masked Diffusion Models (Based on Dream) (#263)
1 parent d9bb084 commit 2d88c37

File tree

11 files changed: +5195 -1 lines changed


fast_llm/engine/checkpoint/huggingface.py

Lines changed: 3 additions & 0 deletions
@@ -133,6 +133,7 @@ class CustomModelingExportMixin:
     modeling_file: typing.ClassVar[str]
     configuration_file: typing.ClassVar[str]
     configuration_cls: typing.ClassVar[type[PretrainedConfig]]
+    generation_utils_file: str | None = None

     # Use custom config instead of relying on the transformers library
     @classmethod
@@ -153,3 +154,5 @@ def _copy_modeling_files(self, config: CheckpointSaveConfig) -> None:
         # Copy the modeling files to the output directory
         shutil.copy(self.modeling_file, config.path)
         shutil.copy(self.configuration_file, config.path)
+        if self.generation_utils_file:
+            shutil.copy(self.generation_utils_file, config.path)
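
The new attribute makes the generation-utilities module optional: handlers that do not define it keep the previous behaviour. Below is a minimal standalone sketch of the copy logic added above, with made-up file names; it mirrors `_copy_modeling_files` but is not the fast_llm implementation itself (the real handlers point at module `__file__` paths, as the conversion.py diff further down shows).

import shutil
from pathlib import Path


class ExportFilesSketch:
    # Hypothetical file names, for illustration only.
    modeling_file: str = "modeling_dream.py"
    configuration_file: str = "configuration_dream.py"
    generation_utils_file: str | None = "generation_utils.py"  # optional, may be None

    def copy_modeling_files(self, path: Path) -> None:
        # The modeling and configuration files are always copied next to the weights.
        shutil.copy(self.modeling_file, path)
        shutil.copy(self.configuration_file, path)
        # The generation utilities are copied only when a handler defines them.
        if self.generation_utils_file:
            shutil.copy(self.generation_utils_file, path)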

fast_llm/models/gpt/config.py

Lines changed: 10 additions & 0 deletions
@@ -56,6 +56,14 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
 class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
     name: typing.ClassVar[str] = "mtp_llama"
     trust_remote_code: typing.ClassVar[bool] = True
+
+class DiffusionDreamGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
+    name: typing.ClassVar[str] = "dream"
+    trust_remote_code: typing.ClassVar[bool] = True
+
+class DiffusionLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
+    name: typing.ClassVar[str] = "diffusion_llama"
+    trust_remote_code: typing.ClassVar[bool] = True


 @config_class()
@@ -139,6 +147,8 @@ class GPTModelConfig(FastLLMModelConfig):
         MistralGPTHuggingfaceCheckpointFormat,
         MixtralGPTHuggingfaceCheckpointFormat,
         MTPLlamaGPTHuggingfaceCheckpointFormat,
+        DiffusionDreamGPTHuggingfaceCheckpointFormat,
+        DiffusionLlamaGPTHuggingfaceCheckpointFormat,
     )

     @classmethod
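
Each new format class carries a registered string name ("dream", "diffusion_llama") and sets `trust_remote_code`, so exported checkpoints are expected to load with the bundled custom modeling code. A quick sketch of how the registration could be inspected, assuming the tuple extended in the second hunk is exposed as `GPTModelConfig.checkpoint_formats` (the attribute name itself is not visible in this hunk):

from fast_llm.models.gpt.config import GPTModelConfig

# List the Hugging Face checkpoint formats registered on the GPT model config;
# after this commit the output should include "dream" and "diffusion_llama".
for checkpoint_format in GPTModelConfig.checkpoint_formats:
    print(checkpoint_format.name, getattr(checkpoint_format, "trust_remote_code", False))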

fast_llm/models/gpt/conversion.py

Lines changed: 124 additions & 0 deletions
@@ -37,8 +37,12 @@
     MTPLlamaGPTHuggingfaceCheckpointFormat,
     Qwen2GPTHuggingfaceCheckpointFormat,
     Starcoder2GPTHuggingfaceCheckpointFormat,
+    DiffusionDreamGPTHuggingfaceCheckpointFormat,
+    DiffusionLlamaGPTHuggingfaceCheckpointFormat,
 )
 from fast_llm.models.gpt.external.mtp_llama.configuration_mtp_llama import MTPLlamaConfig
+from fast_llm.models.gpt.external.diffusion_dream.configuration_dream import DreamConfig
+from fast_llm.models.gpt.external.diffusion_llama.configuration_diffusion_llama import DiffusionLlamaConfig
 from fast_llm.models.gpt.model import GPTModel
 from fast_llm.tensor import SafeTensorSlice
 from fast_llm.utils import Assert
@@ -679,6 +683,124 @@ def _create_lm_head_converters(self) -> list[WeightConverter]:

         return converters

+
+class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonHuggingfaceCheckpointHandler):
+
+    from fast_llm.models.gpt.external.diffusion_dream import configuration_dream, modeling_dream, generation_utils
+
+    format: typing.ClassVar[type[CheckpointFormat]] = DiffusionDreamGPTHuggingfaceCheckpointFormat
+    modeling_file = modeling_dream.__file__
+    configuration_file = configuration_dream.__file__
+    generation_utils_file = generation_utils.__file__
+    configuration_cls: typing.ClassVar[type[PretrainedConfig]] = DreamConfig
+
+    @classmethod
+    def _create_config_converters(cls) -> list[ParamConverter]:
+        return super()._create_config_converters() + [
+            # From Qwen2HuggingfaceCheckpointHandler - Change architectures to DiffusionDream
+            ConstantImportParamConverter(
+                fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm
+            ),
+            RenameParamConverter(
+                fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),)
+            ),
+            ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True),
+            ConstantImportParamConverter(
+                fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv"
+            ),
+            RopeScalingParamConverter(
+                fast_llm_names=(
+                    ("transformer", "rotary", "type"),
+                    ("transformer", "rotary", "scale_factor"),
+                    ("transformer", "rotary", "low_frequency_factor"),
+                    ("transformer", "rotary", "high_frequency_factor"),
+                    ("transformer", "rotary", "original_context_length"),
+                    ("transformer", "rotary", "attention_factor"),
+                    ("transformer", "rotary", "beta_fast"),
+                    ("transformer", "rotary", "beta_slow"),
+                ),
+                export_names=(("rope_scaling",),),
+            ),
+            IgnoreImportQwen2SlidingWindowParamsConverter(),
+            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["DreamModel"]),
+            ConstantExportParamConverter(
+                export_names=(("auto_map",),),
+                export_value={
+                    "AutoConfig": "configuration_dream.DreamConfig",
+                    "AutoModel": "modeling_dream.DreamModel",
+                },
+            ),
+        ]
+
+
+    def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
+        # From Qwen2HuggingfaceCheckpointHandler
+        transformer_config: TransformerConfig = self._model.config.base_model.transformer
+        return [
+            *self._get_weight_and_bias_converters(
+                f"{fast_llm_prefix}.mlp.layer_1",
+                (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
+                transformer_config.add_mlp_bias,
+                SplitWeightConverter,
+            ),
+            *self._get_weight_and_bias_converters(
+                f"{fast_llm_prefix}.mlp.layer_2",
+                f"{hf_prefix}.mlp.down_proj",
+                transformer_config.add_mlp_bias,
+                MLPLayer2Converter,
+            ),
+        ]
+
+class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlamaHuggingfaceCheckpointHandler):
+
+    from fast_llm.models.gpt.external.diffusion_llama import configuration_diffusion_llama, modeling_diffusion_llama, generation_utils
+
+    format: typing.ClassVar[type[CheckpointFormat]] = DiffusionLlamaGPTHuggingfaceCheckpointFormat
+    modeling_file = modeling_diffusion_llama.__file__
+    configuration_file = configuration_diffusion_llama.__file__
+    generation_utils_file = generation_utils.__file__
+    configuration_cls: typing.ClassVar[type[PretrainedConfig]] = DiffusionLlamaConfig
+
+    @classmethod
+    def _create_config_converters(cls) -> list[ParamConverter]:
+        return super()._create_config_converters() + [
+            # From LlamaHuggingfaceCheckpointHandler - Update architectures to DiffusionLlama
+            # TODO: Llama supports biases
+            ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False),
+            ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False),
+            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["DiffusionLlamaModel"]),
+            ConstantExportParamConverter(
+                export_names=(("auto_map",),),
+                export_value={
+                    "AutoConfig": "configuration_diffusion_llama.DiffusionLlamaConfig",
+                    "AutoModel": "modeling_diffusion_llama.DiffusionLlamaModel",
+                },),
+            # TODO: include when the mask diffusion training is implemented;
+            # since the imported model (llama) for CPT doesn't have it but the exported model (diffusion llama) does need to have this token.
+            # RenameParamConverter(
+            #     fast_llm_names=(("mask_token_id",),),
+            #     export_names=(("mask_token_id",),),
+            # ),
+        ]
+
+
+    def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
+        # From LlamaHuggingfaceCheckpointHandler
+        transformer_config: TransformerConfig = self._model.config.base_model.transformer
+        return [
+            *self._get_weight_and_bias_converters(
+                f"{fast_llm_prefix}.mlp.layer_1",
+                (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
+                transformer_config.add_mlp_bias,
+                SplitWeightConverter,
+            ),
+            *self._get_weight_and_bias_converters(
+                f"{fast_llm_prefix}.mlp.layer_2",
+                f"{hf_prefix}.mlp.down_proj",
+                transformer_config.add_mlp_bias,
+                MLPLayer2Converter,
+            ),
+        ]
+

 class AutoGPTHuggingfaceCheckpointHandler(
     AutoStateDictCheckpointHandler, HuggingfaceStateDictCheckpointHandler, abc.ABC
@@ -691,4 +813,6 @@ class AutoGPTHuggingfaceCheckpointHandler(
         MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler,
         MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler,
         MTPLlamaGPTHuggingfaceCheckpointFormat.name: MTPLlamaHuggingfaceCheckpointHandler,
+        DiffusionDreamGPTHuggingfaceCheckpointFormat.name: DiffusionDreamHuggingfaceCheckpointHandler,
+        DiffusionLlamaGPTHuggingfaceCheckpointFormat.name: DiffusionLlamaHuggingfaceCheckpointHandler,
     }
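
Because the handlers export an `auto_map` pointing at the copied `configuration_*`/`modeling_*` files and the formats set `trust_remote_code`, an exported directory should be loadable back through the standard transformers auto classes. A hedged sketch (the output path is hypothetical, and `AutoModel` is used rather than `AutoModelForCausalLM`, matching the exported `architectures` entries):

from transformers import AutoConfig, AutoModel

# Hypothetical path of a directory produced by the Fast-LLM export.
export_dir = "/tmp/diffusion_llama_export"

# trust_remote_code makes transformers import the copied configuration_diffusion_llama.py
# and modeling_diffusion_llama.py files referenced by the exported auto_map.
config = AutoConfig.from_pretrained(export_dir, trust_remote_code=True)
model = AutoModel.from_pretrained(export_dir, trust_remote_code=True)
print(type(model).__name__)  # expected: DiffusionLlamaModel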
fast_llm/models/gpt/external/diffusion_dream/configuration_dream.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+# coding=utf-8
+# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Dream model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DreamConfig(PretrainedConfig):
+    model_type = "dream"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=151936,
+        hidden_size=4096,
+        intermediate_size=22016,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=32,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=False,  # cache not used in diffusion
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        use_sliding_window=False,
+        sliding_window=4096,
+        max_window_layers=28,
+        attention_dropout=0.0,
+        mask_token_id=151666,
+        pad_token_id=151643,  # vocab_size is set to 8192 for test cases this would fail on Embedding layer check: # pad_token_id=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.use_sliding_window = use_sliding_window
+        self.sliding_window = sliding_window if use_sliding_window else None
+        self.max_window_layers = max_window_layers
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_dropout = attention_dropout
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, move it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self)
+
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+        self.mask_token_id = mask_token_id
+        self.pad_token_id = pad_token_id
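
The constructor largely mirrors Qwen2Config: sliding-window attention is disabled unless `use_sliding_window` is set, the KV-head count falls back to the attention-head count, and a legacy `rope_scaling["type"]` key is mirrored to `rope_type` before validation, with the Dream-specific `mask_token_id` and `pad_token_id` stored after `super().__init__`. A small usage sketch with illustrative values (not taken from a released checkpoint):

from fast_llm.models.gpt.external.diffusion_dream.configuration_dream import DreamConfig

# Illustrative values only; a real checkpoint ships its own config.json.
config = DreamConfig(
    hidden_size=1024,
    intermediate_size=2816,
    num_hidden_layers=4,
    num_attention_heads=16,
    num_key_value_heads=16,
    rope_scaling={"type": "linear", "factor": 2.0},  # legacy "type" key...
)
print(config.rope_scaling["rope_type"])  # ...is mirrored to "rope_type" before validation
print(config.sliding_window)             # None, because use_sliding_window defaults to False
print(config.mask_token_id)              # 151666, the default mask token id in this config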
