Skip to content

Converter for Llama based Masked Diffusion Models (Based on Dream) #263

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jun 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions fast_llm/engine/checkpoint/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class CustomModelingExportMixin:
modeling_file: typing.ClassVar[str]
configuration_file: typing.ClassVar[str]
configuration_cls: typing.ClassVar[type[PretrainedConfig]]
generation_utils_file: str | None = None

# Use custom config instead of relying on the transformers library
@classmethod
Expand All @@ -153,3 +154,5 @@ def _copy_modeling_files(self, config: CheckpointSaveConfig) -> None:
# Copy the modeling files to the output directory
shutil.copy(self.modeling_file, config.path)
shutil.copy(self.configuration_file, config.path)
if self.generation_utils_file:
shutil.copy(self.generation_utils_file, config.path)
10 changes: 10 additions & 0 deletions fast_llm/models/gpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
name: typing.ClassVar[str] = "mtp_llama"
trust_remote_code: typing.ClassVar[bool] = True

class DiffusionDreamGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
name: typing.ClassVar[str] = "dream"
trust_remote_code: typing.ClassVar[bool] = True

class DiffusionLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
name: typing.ClassVar[str] = "diffusion_llama"
trust_remote_code: typing.ClassVar[bool] = True


@config_class()
Expand Down Expand Up @@ -139,6 +147,8 @@ class GPTModelConfig(FastLLMModelConfig):
MistralGPTHuggingfaceCheckpointFormat,
MixtralGPTHuggingfaceCheckpointFormat,
MTPLlamaGPTHuggingfaceCheckpointFormat,
DiffusionDreamGPTHuggingfaceCheckpointFormat,
DiffusionLlamaGPTHuggingfaceCheckpointFormat,
)

@classmethod
Expand Down
124 changes: 124 additions & 0 deletions fast_llm/models/gpt/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,12 @@
MTPLlamaGPTHuggingfaceCheckpointFormat,
Qwen2GPTHuggingfaceCheckpointFormat,
Starcoder2GPTHuggingfaceCheckpointFormat,
DiffusionDreamGPTHuggingfaceCheckpointFormat,
DiffusionLlamaGPTHuggingfaceCheckpointFormat,
)
from fast_llm.models.gpt.external.mtp_llama.configuration_mtp_llama import MTPLlamaConfig
from fast_llm.models.gpt.external.diffusion_dream.configuration_dream import DreamConfig
from fast_llm.models.gpt.external.diffusion_llama.configuration_diffusion_llama import DiffusionLlamaConfig
from fast_llm.models.gpt.model import GPTModel
from fast_llm.tensor import SafeTensorSlice
from fast_llm.utils import Assert
Expand Down Expand Up @@ -679,6 +683,124 @@ def _create_lm_head_converters(self) -> list[WeightConverter]:

return converters

class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonHuggingfaceCheckpointHandler):

from fast_llm.models.gpt.external.diffusion_dream import configuration_dream, modeling_dream, generation_utils

format: typing.ClassVar[type[CheckpointFormat]] = DiffusionDreamGPTHuggingfaceCheckpointFormat
modeling_file = modeling_dream.__file__
configuration_file = configuration_dream.__file__
generation_utils_file = generation_utils.__file__
configuration_cls: typing.ClassVar[type[PretrainedConfig]] = DreamConfig

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
return super()._create_config_converters() + [
# From Qwen2HuggingfaceCheckpointHandler - Change architectures to DiffusionDream
ConstantImportParamConverter(
fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm
),
RenameParamConverter(
fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),)
),
ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True),
ConstantImportParamConverter(
fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv"
),
RopeScalingParamConverter(
fast_llm_names=(
("transformer", "rotary", "type"),
("transformer", "rotary", "scale_factor"),
("transformer", "rotary", "low_frequency_factor"),
("transformer", "rotary", "high_frequency_factor"),
("transformer", "rotary", "original_context_length"),
("transformer", "rotary", "attention_factor"),
("transformer", "rotary", "beta_fast"),
("transformer", "rotary", "beta_slow"),
),
export_names=(("rope_scaling",),),
),
IgnoreImportQwen2SlidingWindowParamsConverter(),
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["DreamModel"]),
ConstantExportParamConverter(
export_names=(("auto_map",),),
export_value={
"AutoConfig": "configuration_dream.DreamConfig",
"AutoModel": "modeling_dream.DreamModel",
},
),
]


def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
# From Qwen2HuggingfaceCheckpointHandler
transformer_config: TransformerConfig = self._model.config.base_model.transformer
return [
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_1",
(f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
transformer_config.add_mlp_bias,
SplitWeightConverter,
),
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_2",
f"{hf_prefix}.mlp.down_proj",
transformer_config.add_mlp_bias,
MLPLayer2Converter,
),
]

class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlamaHuggingfaceCheckpointHandler):

from fast_llm.models.gpt.external.diffusion_llama import configuration_diffusion_llama, modeling_diffusion_llama, generation_utils

format: typing.ClassVar[type[CheckpointFormat]] = DiffusionLlamaGPTHuggingfaceCheckpointFormat
modeling_file = modeling_diffusion_llama.__file__
configuration_file = configuration_diffusion_llama.__file__
generation_utils_file = generation_utils.__file__
configuration_cls: typing.ClassVar[type[PretrainedConfig]] = DiffusionLlamaConfig

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
return super()._create_config_converters() + [
# From LlamaHuggingfaceCheckpointHandler - Update architectures to DiffusionLlama
# TODO: Llama supports biases
ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False),
ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False),
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["DiffusionLlamaModel"]),
ConstantExportParamConverter(
export_names=(("auto_map",),),
export_value={
"AutoConfig": "configuration_diffusion_llama.DiffusionLlamaConfig",
"AutoModel": "modeling_diffusion_llama.DiffusionLlamaModel",
},),
# TODO: include when the mask diffusion training is implemented;
# since the imported model (llama) for CPT doesn't have it but the exported model (diffusion llama) does need to have this token.
# RenameParamConverter(
# fast_llm_names=(("mask_token_id",),),
# export_names=(("mask_token_id",),),
# ),
]


def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
# From LlamaHuggingfaceCheckpointHandler
transformer_config: TransformerConfig = self._model.config.base_model.transformer
return [
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_1",
(f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
transformer_config.add_mlp_bias,
SplitWeightConverter,
),
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_2",
f"{hf_prefix}.mlp.down_proj",
transformer_config.add_mlp_bias,
MLPLayer2Converter,
),
]


class AutoGPTHuggingfaceCheckpointHandler(
AutoStateDictCheckpointHandler, HuggingfaceStateDictCheckpointHandler, abc.ABC
Expand All @@ -691,4 +813,6 @@ class AutoGPTHuggingfaceCheckpointHandler(
MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler,
MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler,
MTPLlamaGPTHuggingfaceCheckpointFormat.name: MTPLlamaHuggingfaceCheckpointHandler,
DiffusionDreamGPTHuggingfaceCheckpointFormat.name: DiffusionDreamHuggingfaceCheckpointHandler,
DiffusionLlamaGPTHuggingfaceCheckpointFormat.name: DiffusionLlamaHuggingfaceCheckpointHandler,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# coding=utf-8
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dream model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging


logger = logging.get_logger(__name__)


class DreamConfig(PretrainedConfig):
model_type = "dream"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size=151936,
hidden_size=4096,
intermediate_size=22016,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=False, # cache not used in diffusion
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=28,
attention_dropout=0.0,
mask_token_id=151666,
pad_token_id=151643, # vocab_size is set to 8192 for test cases this would fail on Embedding layer check: # pad_token_id=None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window if use_sliding_window else None
self.max_window_layers = max_window_layers

# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads

self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)

super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.mask_token_id = mask_token_id
self.pad_token_id = pad_token_id
Loading