diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index f9d0bf60..8d17d0c8 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -133,6 +133,7 @@ class CustomModelingExportMixin: modeling_file: typing.ClassVar[str] configuration_file: typing.ClassVar[str] configuration_cls: typing.ClassVar[type[PretrainedConfig]] + generation_utils_file: str | None = None # Use custom config instead of relying on the transformers library @classmethod @@ -153,3 +154,5 @@ def _copy_modeling_files(self, config: CheckpointSaveConfig) -> None: # Copy the modeling files to the output directory shutil.copy(self.modeling_file, config.path) shutil.copy(self.configuration_file, config.path) + if self.generation_utils_file: + shutil.copy(self.generation_utils_file, config.path) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 1c79a6ec..6bf6e06c 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -56,6 +56,14 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mtp_llama" trust_remote_code: typing.ClassVar[bool] = True + +class DiffusionDreamGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "dream" + trust_remote_code: typing.ClassVar[bool] = True + +class DiffusionLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "diffusion_llama" + trust_remote_code: typing.ClassVar[bool] = True @config_class() @@ -139,6 +147,8 @@ class GPTModelConfig(FastLLMModelConfig): MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, MTPLlamaGPTHuggingfaceCheckpointFormat, + DiffusionDreamGPTHuggingfaceCheckpointFormat, + DiffusionLlamaGPTHuggingfaceCheckpointFormat, ) @classmethod diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 5c689629..82d464b0 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -37,8 +37,12 @@ MTPLlamaGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, + DiffusionDreamGPTHuggingfaceCheckpointFormat, + DiffusionLlamaGPTHuggingfaceCheckpointFormat, ) from fast_llm.models.gpt.external.mtp_llama.configuration_mtp_llama import MTPLlamaConfig +from fast_llm.models.gpt.external.diffusion_dream.configuration_dream import DreamConfig +from fast_llm.models.gpt.external.diffusion_llama.configuration_diffusion_llama import DiffusionLlamaConfig from fast_llm.models.gpt.model import GPTModel from fast_llm.tensor import SafeTensorSlice from fast_llm.utils import Assert @@ -679,6 +683,124 @@ def _create_lm_head_converters(self) -> list[WeightConverter]: return converters +class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonHuggingfaceCheckpointHandler): + + from fast_llm.models.gpt.external.diffusion_dream import configuration_dream, modeling_dream, generation_utils + + format: typing.ClassVar[type[CheckpointFormat]] = DiffusionDreamGPTHuggingfaceCheckpointFormat + modeling_file = modeling_dream.__file__ + configuration_file = configuration_dream.__file__ + generation_utils_file = generation_utils.__file__ + configuration_cls: typing.ClassVar[type[PretrainedConfig]] = DreamConfig + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + # From Qwen2HuggingfaceCheckpointHandler - Change architectures to DiffusionDream + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + ), + RenameParamConverter( + fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) + ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv" + ), + RopeScalingParamConverter( + fast_llm_names=( + ("transformer", "rotary", "type"), + ("transformer", "rotary", "scale_factor"), + ("transformer", "rotary", "low_frequency_factor"), + ("transformer", "rotary", "high_frequency_factor"), + ("transformer", "rotary", "original_context_length"), + ("transformer", "rotary", "attention_factor"), + ("transformer", "rotary", "beta_fast"), + ("transformer", "rotary", "beta_slow"), + ), + export_names=(("rope_scaling",),), + ), + IgnoreImportQwen2SlidingWindowParamsConverter(), + ConstantExportParamConverter(export_names=(("architectures",),), export_value=["DreamModel"]), + ConstantExportParamConverter( + export_names=(("auto_map",),), + export_value={ + "AutoConfig": "configuration_dream.DreamConfig", + "AutoModel": "modeling_dream.DreamModel", + }, + ), + ] + + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + # From Qwen2HuggingfaceCheckpointHandler + transformer_config: TransformerConfig = self._model.config.base_model.transformer + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + transformer_config.add_mlp_bias, + SplitWeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + transformer_config.add_mlp_bias, + MLPLayer2Converter, + ), + ] + +class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlamaHuggingfaceCheckpointHandler): + + from fast_llm.models.gpt.external.diffusion_llama import configuration_diffusion_llama, modeling_diffusion_llama, generation_utils + + format: typing.ClassVar[type[CheckpointFormat]] = DiffusionLlamaGPTHuggingfaceCheckpointFormat + modeling_file = modeling_diffusion_llama.__file__ + configuration_file = configuration_diffusion_llama.__file__ + generation_utils_file = generation_utils.__file__ + configuration_cls: typing.ClassVar[type[PretrainedConfig]] = DiffusionLlamaConfig + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + # From LlamaHuggingfaceCheckpointHandler - Update architectures to DiffusionLlama + # TODO: Llama supports biases + ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False), + ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False), + ConstantExportParamConverter(export_names=(("architectures",),), export_value=["DiffusionLlamaModel"]), + ConstantExportParamConverter( + export_names=(("auto_map",),), + export_value={ + "AutoConfig": "configuration_diffusion_llama.DiffusionLlamaConfig", + "AutoModel": "modeling_diffusion_llama.DiffusionLlamaModel", + },), + # TODO: include when the mask diffusion training is implemented; + # since the imported model (llama) for CPT doesn't have it but the exported model (diffusion llama) does need to have this token. + # RenameParamConverter( + # fast_llm_names=(("mask_token_id",),), + # export_names=(("mask_token_id",),), + # ), + ] + + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + # From LlamaHuggingfaceCheckpointHandler + transformer_config: TransformerConfig = self._model.config.base_model.transformer + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + transformer_config.add_mlp_bias, + SplitWeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + transformer_config.add_mlp_bias, + MLPLayer2Converter, + ), + ] + class AutoGPTHuggingfaceCheckpointHandler( AutoStateDictCheckpointHandler, HuggingfaceStateDictCheckpointHandler, abc.ABC @@ -691,4 +813,6 @@ class AutoGPTHuggingfaceCheckpointHandler( MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler, MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, MTPLlamaGPTHuggingfaceCheckpointFormat.name: MTPLlamaHuggingfaceCheckpointHandler, + DiffusionDreamGPTHuggingfaceCheckpointFormat.name: DiffusionDreamHuggingfaceCheckpointHandler, + DiffusionLlamaGPTHuggingfaceCheckpointFormat.name: DiffusionLlamaHuggingfaceCheckpointHandler, } diff --git a/fast_llm/models/gpt/external/diffusion_dream/configuration_dream.py b/fast_llm/models/gpt/external/diffusion_dream/configuration_dream.py new file mode 100644 index 00000000..58bbd488 --- /dev/null +++ b/fast_llm/models/gpt/external/diffusion_dream/configuration_dream.py @@ -0,0 +1,86 @@ +# coding=utf-8 +# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Dream model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class DreamConfig(PretrainedConfig): + model_type = "dream" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=False, # cache not used in diffusion + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + mask_token_id=151666, + pad_token_id=151643, # vocab_size is set to 8192 for test cases this would fail on Embedding layer check: # pad_token_id=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if use_sliding_window else None + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.mask_token_id = mask_token_id + self.pad_token_id = pad_token_id diff --git a/fast_llm/models/gpt/external/diffusion_dream/generation_utils.py b/fast_llm/models/gpt/external/diffusion_dream/generation_utils.py new file mode 100644 index 00000000..b70dcf49 --- /dev/null +++ b/fast_llm/models/gpt/external/diffusion_dream/generation_utils.py @@ -0,0 +1,1040 @@ +# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import warnings +from dataclasses import dataclass +from math import ceil +from typing import Any, Optional, Union + +import torch +import torch.distributions as dists +from torch.nn import functional as F +from transformers import __version__ +from transformers.generation.configuration_utils import GenerationConfig +from transformers.generation.utils import GenerationMixin +from transformers.utils import ModelOutput, is_torchdynamo_compiling, logging + +logger = logging.get_logger(__name__) + + +def top_p_logits(logits, top_p=None): + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device) + mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove) + logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min) + return logits + + +def top_k_logits(logits, top_k=None): + top_k = min(top_k, logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min) + return logits + + +def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False): + + if temperature > 0: + logits = logits / temperature + if top_p is not None and top_p < 1: + logits = top_p_logits(logits, top_p) + if top_k is not None: + logits = top_k_logits(logits, top_k) + probs = torch.softmax(logits, dim=-1) + + if temperature > 0: + try: + x0 = dists.Categorical(probs=probs).sample() + confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1) + except: + confidence, x0 = probs.max(dim=-1) + else: + confidence, x0 = probs.max(dim=-1) + + if margin_confidence: + sorted_probs, _ = torch.sort(probs, dim=-1, descending=True) + # Extract top1 and top2 probabilities + top1_probs = sorted_probs[:, 0] + top2_probs = sorted_probs[:, 1] + # Calculate confidence as top1 - top2 + confidence = top1_probs - top2_probs + + if neg_entropy: + epsilon = 1e-10 + log_probs = torch.log(probs + epsilon) + confidence = torch.sum(probs * log_probs, dim=-1) + + return confidence, x0 + + +# batch_sample_tokens +def batch_sample_tokens( + logits, mask_indexes, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False +): + # print(f"batch_sample_tokens: {logits.shape} ") + if temperature > 0: + logits = logits / temperature + if top_p is not None and top_p < 1: + # logit will have different lengths for each sequence so cannot stack - it is not a proper batch??? + logits = torch.stack([top_p_logits(logit[mask], top_p) for logit, mask in zip(logits, mask_indexes)], dim=0) + if top_k is not None: + logits = torch.stack([top_k_logits(logit[mask], top_k) for logit, mask in zip(logits, mask_indexes)], dim=0) + + # if logits are not of the same sequence so therefore we can pad them with -inf but need remove them back ... + probs = torch.softmax(logits, dim=-1) + # print(f"probs: {probs.shape}") + + if temperature > 0: + try: + x0 = dists.Categorical(probs=probs).sample() + confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1) + except: + confidence, x0 = probs.max(dim=-1) + else: + confidence, x0 = probs.max(dim=-1) + + # print(f"confidence: {confidence.shape} x0: {x0.shape}") + if margin_confidence: + sorted_probs, _ = torch.sort(probs, dim=-1, descending=True) + # Extract top1 and top2 probabilities + top1_probs = sorted_probs[:, 0] + top2_probs = sorted_probs[:, 1] + # Calculate confidence as top1 - top2 + confidence = top1_probs - top2_probs + + if neg_entropy: + epsilon = 1e-10 + log_probs = torch.log(probs + epsilon) + confidence = torch.sum(probs * log_probs, dim=-1) + + return confidence, x0 + + +@dataclass +class DreamModelOutput(ModelOutput): + sequences: torch.LongTensor = None + history: Optional[tuple[torch.FloatTensor]] = None + + +class DreamGenerationConfig(GenerationConfig): + def __init__(self, **kwargs): + self.temperature: float = kwargs.pop("temperature", 0.0) + self.top_p: Optional[float] = kwargs.pop("top_p", None) + self.top_k: Optional[int] = kwargs.pop("top_k", None) + self.max_length = kwargs.pop("max_length", 20) + self.max_new_tokens = kwargs.pop("max_new_tokens", None) + # diffusion specific params + self.eps: float = kwargs.pop("eps", 1e-3) + self.steps: int = kwargs.pop("steps", 512) + self.alg: str = kwargs.pop("alg", "origin") + self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None) + + # Parameters that define the output variables of `generate` + self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1) + self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False) + self.output_history: bool = kwargs.pop("output_history", False) + + # Special tokens that can be used at generation time + self.mask_token_id = kwargs.pop("mask_token_id", None) + self.pad_token_id = kwargs.pop("pad_token_id", None) + self.bos_token_id = kwargs.pop("bos_token_id", None) + self.eos_token_id = kwargs.pop("eos_token_id", None) + + # Wild card + self.generation_kwargs = kwargs.pop("generation_kwargs", {}) + + # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub + # interface. + self._from_model_config = kwargs.pop("_from_model_config", False) + self._commit_hash = kwargs.pop("_commit_hash", None) + self.transformers_version = kwargs.pop("transformers_version", __version__) + + # Additional attributes without default values + if not self._from_model_config: + # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a + # model's default configuration file + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + # Validate the values of the attributes + self.validate(is_init=True) + + def validate(self, is_init=False): + pass + + +class DreamGenerationMixin(GenerationMixin): + @staticmethod + def _expand_inputs_for_generation( + expand_size: int = 1, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]""" + # Do not call torch.repeat_interleave if expand_size is 1 because it clones + # the input tensor and thus requires more memory although no change is applied + if expand_size == 1: + return input_ids, attention_mask + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + if attention_mask is not None: + attention_mask = attention_mask.repeat_interleave(expand_size, dim=0) + return input_ids, attention_mask + + def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length): + """Performs validation related to the resulting generated length""" + + # Can't throw warnings/exceptions during compilation + if is_torchdynamo_compiling(): + return + + # 1. Max length warnings related to poor parameterization + if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: + # 20 is the default max_length of the generation config + warnings.warn( + f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the " + "generation length. We recommend setting `max_new_tokens` to control the maximum length of the " + "generation.", + UserWarning, + ) + if input_ids_length >= generation_config.max_length: + input_ids_string = "input_ids" + raise ValueError( + f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_length` or, better yet, setting `max_new_tokens`." + ) + + def _prepare_generated_length( + self, + generation_config, + has_default_max_length, + input_ids_length, + ): + """Prepared max and min length in generation configs to avoid clashes between similar attributes""" + + if generation_config.max_new_tokens is not None: + if not has_default_max_length and generation_config.max_length is not None: + logger.warning( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + generation_config.max_length = generation_config.max_new_tokens + input_ids_length + + elif has_default_max_length: + if generation_config.max_length == DreamGenerationConfig().max_length: + generation_config.max_length = generation_config.max_length + input_ids_length + max_position_embeddings = getattr(self.config, "max_position_embeddings", None) + if max_position_embeddings is not None: + generation_config.max_length = min(generation_config.max_length, max_position_embeddings) + + return generation_config + + def _prepare_generation_config( + self, generation_config: Optional[DreamGenerationConfig], **kwargs: dict + ) -> DreamGenerationConfig: + """ + Prepares the base generation config, then applies any generation configuration options from kwargs. This + function handles retrocompatibility with respect to configuration files. + """ + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + using_model_generation_config = False + if generation_config is None: + generation_config = DreamGenerationConfig.from_model_config(self.config) + using_model_generation_config = True + + # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config` + # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an + # exception will be raised in `_validate_model_kwargs` + if not is_torchdynamo_compiling(): + generation_config = copy.deepcopy(generation_config) + generation_config.update(**kwargs) + # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model + if not using_model_generation_config: + if generation_config.bos_token_id is None: + generation_config.bos_token_id = self.generation_config.bos_token_id + if generation_config.eos_token_id is None: + generation_config.eos_token_id = self.generation_config.eos_token_id + if generation_config.pad_token_id is None: + generation_config.pad_token_id = self.generation_config.pad_token_id + if generation_config.mask_token_id is None: + generation_config.mask_token_id = self.generation_config.mask_token_id + + return generation_config + + def _prepare_special_tokens( + self, + generation_config: DreamGenerationConfig, + device: Optional[Union[torch.device, str]] = None, + ): + """ + Prepares the special tokens for generation, overwriting the generation config with their processed versions + converted to tensor. + + Note that `generation_config` is changed in place and stops being serializable after this method is called. + That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the + function). However, if called outside `generate`, consider creating a copy of `generation_config` first. + """ + + # Convert special tokens to tensors + def _tensor_or_none(token, device=None): + if token is None: + return token + + device = device if device is not None else self.device + if isinstance(token, torch.Tensor): + return token.to(device) + return torch.tensor(token, device=device, dtype=torch.long) + + bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device) + eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device) + pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device) + mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device) + + # We can have more than one eos token. Always treat it as a 1D tensor (when it exists). + if eos_token_tensor is not None and eos_token_tensor.ndim == 0: + eos_token_tensor = eos_token_tensor.unsqueeze(0) + + # Set pad token if unset (and there are conditions to do so) + if pad_token_tensor is None and eos_token_tensor is not None: + pad_token_tensor = eos_token_tensor[0] + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.") + + # Update generation config with the updated special tokens tensors + # NOTE: this must be written into a different attribute name than the one holding the original special tokens + # (in their non-tensor form), in order to enable end-to-end compilation. See + # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations + generation_config._bos_token_tensor = bos_token_tensor + generation_config._eos_token_tensor = eos_token_tensor + generation_config._pad_token_tensor = pad_token_tensor + generation_config._mask_token_tensor = mask_token_tensor + + @torch.no_grad() + def diffusion_generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[DreamGenerationConfig] = None, + **kwargs, + ) -> Union[DreamModelOutput, torch.LongTensor]: + # fix seed for reproducability torch.random.manual_seed - lm-eval is setting it + torch.random.manual_seed(0) + + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + generation_config = self._prepare_generation_config(generation_config, **kwargs) + generation_tokens_hook_func = kwargs.pop( + "generation_tokens_hook_func", lambda step, x, logits, end_of_prompt: x + ) + generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits) + + # 2. Define model inputs + assert inputs is not None + input_ids = inputs + device = input_ids.device + attention_mask = kwargs.pop("attention_mask", None) + self._prepare_special_tokens(generation_config, device=device) + + # 3. Prepare `max_length`. + input_ids_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length= # The code `has_default_max_length` is not a valid Python code + # snippet. It seems to be a placeholder or a comment in the code. + # It does not perform any specific action or functionality in + # Python. + has_default_max_length, + input_ids_length=input_ids_length, + ) + + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + + # 4. Check input_ids + if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type: + warnings.warn( + "You are calling .generate() with the `input_ids` being on a device type different" + f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" + f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." + " Please make sure that you have put `input_ids` to the" + f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" + " running `.generate()`.", + UserWarning, + ) + if ( + hasattr(generation_config, "pad_token_id") + and torch.any(input_ids == generation_config.pad_token_id) + and attention_mask is None + ): + warnings.warn( + "Padding was detected but no attention mask is passed here. For correct " + "generation results, please set `attention_mask` when batch-padding inputs.", + UserWarning, + ) + + input_ids, attention_mask = self._expand_inputs_for_generation( + expand_size=generation_config.num_return_sequences, input_ids=input_ids, attention_mask=attention_mask + ) + + block_size = kwargs.pop("block_size", None) + use_cache = kwargs.pop("use_cache", False) + causal_cache = kwargs.pop("causal_cache", False) + + if block_size is None: + # Default diffusion generation + result = self._sample( + input_ids, + attention_mask=attention_mask, + generation_config=generation_config, + generation_tokens_hook_func=generation_tokens_hook_func, + generation_logits_hook_func=generation_logits_hook_func, + ) + return result + else: + if causal_cache: + # Block generation with casual KV Caching only works for Flash attention + result = self._sample_with_block_with_causal_kv( + input_ids, + attention_mask=attention_mask, + generation_config=generation_config, + generation_tokens_hook_func=generation_tokens_hook_func, + generation_logits_hook_func=generation_logits_hook_func, + block_size=block_size, + ) + return result + else: + # Block generation with (diffusion) KV Caching + result = self._sample_with_block( + input_ids, + attention_mask=attention_mask, + generation_config=generation_config, + generation_tokens_hook_func=generation_tokens_hook_func, + generation_logits_hook_func=generation_logits_hook_func, + block_size=block_size, + use_cache=use_cache, + ) + return result + + # loop confidence implementation - working same results for bs 1 + def _sample( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.LongTensor], + generation_config: DreamGenerationConfig, + generation_tokens_hook_func, + generation_logits_hook_func, + ) -> Union[DreamModelOutput, torch.LongTensor]: + # init values + output_history = generation_config.output_history + return_dict_in_generate = generation_config.return_dict_in_generate + max_length = generation_config.max_length + mask_token_id = generation_config.mask_token_id + steps = generation_config.steps + eps = generation_config.eps + alg = generation_config.alg + alg_temp = generation_config.alg_temp + temperature = generation_config.temperature + top_p = generation_config.top_p + top_k = generation_config.top_k + + histories = [] if (return_dict_in_generate and output_history) else None + + pad_token_id = generation_config.pad_token_id + eos_token_id = generation_config.eos_token_id + + # pad input_ids to max_length + x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id) + + if attention_mask is not None and torch.any(attention_mask == 0.0): + # we do not mask the [MASK] tokens so value = 1.0 + attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0) + tok_idx = attention_mask.long().cumsum(-1) - 1 + tok_idx.masked_fill_(attention_mask == 0, 1) + # attention_mask is of shape [B, N] + # broadcast to [B, 1, N, N] + attention_mask = torch.logical_and( + attention_mask.unsqueeze(1).unsqueeze(-2), + attention_mask.unsqueeze(1).unsqueeze(-1), + ) + else: + tok_idx = None + attention_mask = "full" + + timesteps = torch.linspace(1, eps, steps + 1, device=x.device) + + input_ids_length = input_ids.shape[1] + batch_size = input_ids.shape[0] + + # this allows user-defined token control of the intermediate steps + # x = generation_tokens_hook_func(None, x, None, input_ids_length) + + for i in range(steps): + + logits = self(x, attention_mask, tok_idx).logits + logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1) + + # this allows user-defined logits control of the intermediate steps + logits = generation_logits_hook_func(i, x, logits) + + t = timesteps[i] + s = timesteps[i + 1] + + # loop around the batch + for b in range(batch_size): + x_row = x[b, :] + mask_index = x_row == mask_token_id + # if the sequence is already completed, skip it + if mask_index.sum() == 0: + continue + mask_logits = logits[b, mask_index] + + if alg == "origin": + # p_transfer = 1 - s / t if i < steps - 1 else 1 + # x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id + # transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer + # _, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k) + # x[mask_index] = x0.clone() + raise RuntimeError("batch origin alg is not supported") + else: + if alg == "maskgit_plus": + confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k) + elif alg == "topk_margin": + confidence, x0 = sample_tokens( + mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True + ) + elif alg == "entropy": + confidence, x0 = sample_tokens( + mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True + ) + else: + raise RuntimeError(f"Unknown alg: {alg}") + num_mask_token = mask_index.sum() + number_transfer_tokens = ceil(num_mask_token * (1 - s / t)) if i < steps - 1 else num_mask_token + + if number_transfer_tokens > 0: + if alg_temp is None or alg_temp == 0: + _, transfer_index = torch.topk(confidence, number_transfer_tokens) + + else: + confidence = confidence / alg_temp + confidence = F.softmax(confidence, dim=-1) + + transfer_index = torch.multinomial(confidence, num_samples=number_transfer_tokens) + x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id + x0_[transfer_index] = x0[transfer_index].clone() + x[b, mask_index] = x0_ + + # this allows user-defined token control of the intermediate steps + x = generation_tokens_hook_func(i, x, logits, input_ids_length) + + if not torch.any(x == mask_token_id): + break + + # Update attention mask based on pad_token_id and eos_token_id + attention_mask_tmp = torch.where( + (x == pad_token_id) | (x == eos_token_id), + torch.tensor(0, device=x.device, dtype=torch.bool), + torch.tensor(1, device=x.device, dtype=torch.bool), + ) + attention_mask_tmp = torch.logical_and( + attention_mask_tmp.unsqueeze(1).unsqueeze(-2), + attention_mask_tmp.unsqueeze(1).unsqueeze(-1), + ) + # print(f"attention_mask: {attention_mask_tmp.shape} {attention_mask_tmp}") + attention_mask = attention_mask_tmp + + if histories is not None: + histories.append(x.clone()) + + if return_dict_in_generate: + return DreamModelOutput( + sequences=x, + history=histories, + ) + else: + return x + + # block generation with kv cache + def _sample_with_block( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.LongTensor], + generation_config: DreamGenerationConfig, + block_size: int, + use_cache: bool, + generation_tokens_hook_func, + generation_logits_hook_func, + ) -> Union[DreamModelOutput, torch.LongTensor]: + # init values + output_history = generation_config.output_history + return_dict_in_generate = generation_config.return_dict_in_generate + max_length = generation_config.max_length + mask_token_id = generation_config.mask_token_id + steps = generation_config.steps + eps = generation_config.eps + alg = generation_config.alg + alg_temp = generation_config.alg_temp + temperature = generation_config.temperature + top_p = generation_config.top_p + top_k = generation_config.top_k + use_cache = use_cache + + histories = [] if (return_dict_in_generate and output_history) else None + + pad_token_id = generation_config.pad_token_id + eos_token_id = generation_config.eos_token_id + block_size = block_size + gen_length = generation_config.max_new_tokens + num_of_blocks = gen_length // block_size + steps = steps // num_of_blocks + + assert gen_length % block_size == 0, "gen_length should be divisible by block_size" + + # pad input_ids to max_length + x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id) + + # TODO: Avoid this check and all future checks by always creating a mask + # If any padding tokens i.e 0 in attention mask + if attention_mask is not None and torch.any(attention_mask == 0.0): + # we do not mask the [MASK] tokens so value = 1.0 + attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0) + tok_idx_base = attention_mask.long().cumsum(-1) - 1 + tok_idx_base.masked_fill_(attention_mask == 0, 1) + # Leave padding out "<|endoftext|>1+1=2 2+2=" -> [ 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17] + + # attention_mask is of shape [B, N] + # broadcast to [B, 1, N, N] + # Set False for padding tokens and rest True + attention_mask = torch.logical_and( + attention_mask.unsqueeze(1).unsqueeze(-2), + attention_mask.unsqueeze(1).unsqueeze(-1), + ) + else: + tok_idx_base = None + attention_mask = "full" + + timesteps = torch.linspace(1, eps, steps + 1, device=x.device) + + input_ids_length = input_ids.shape[1] + batch_size = input_ids.shape[0] + + past_key_values = None + past_length = 0 + settled_length = input_ids_length + x_input = x.clone() + tok_idx = tok_idx_base.clone() if tok_idx_base is not None else None + + for blk_indx in range(num_of_blocks): + current_block = (num_of_blocks - (blk_indx + 1)) * block_size + + for i in range(steps): + + model_outputs = self( + x_input, attention_mask, tok_idx, use_cache=use_cache, past_key_values=past_key_values + ) + + logits = model_outputs.logits + logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1) + + t = timesteps[i] + s = timesteps[i + 1] + + # loop around the batch + for b in range(batch_size): + x_row = x_input[b, :] + mask_index = x_row == mask_token_id + + # if the sequence is already completed, skip it + if mask_index.sum() == 0: + continue + + if current_block > 0: + mask_index[-current_block:] = False + mask_logits = logits[b, mask_index] + + if alg == "origin": + # p_transfer = 1 - s / t if i < steps - 1 else 1 + # x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id + # transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer + # _, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k) + # x[mask_index] = x0.clone() + raise RuntimeError("batch origin alg is not supported") + else: + if alg == "maskgit_plus": + confidence, x0 = sample_tokens( + mask_logits, temperature=temperature, top_p=top_p, top_k=top_k + ) + elif alg == "topk_margin": + confidence, x0 = sample_tokens( + mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True + ) + elif alg == "entropy": + confidence, x0 = sample_tokens( + mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True + ) + else: + raise RuntimeError(f"Unknown alg: {alg}") + num_mask_token = mask_index.sum() + number_transfer_tokens = ( + ceil(num_mask_token * (1 - s / t)) if i < steps - 1 else num_mask_token + ) + + # print(f"block: {blk_indx} step: {i} batch: {b} confidence: {confidence} x0: {x0}") + # print(f"number_transfer_tokens: {number_transfer_tokens} num_mask_token: {num_mask_token} ") + if number_transfer_tokens > 0: + if alg_temp is None or alg_temp == 0: + _, transfer_index = torch.topk(confidence, number_transfer_tokens) + + else: + confidence = confidence / alg_temp + confidence = F.softmax(confidence, dim=-1) + + transfer_index = torch.multinomial(confidence, num_samples=number_transfer_tokens) + x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id + x0_[transfer_index] = x0[transfer_index].clone() + x_input[b, mask_index] = x0_ + + # this allows user-defined token control of the intermediate steps + x_input = generation_tokens_hook_func(i, x_input, logits, input_ids_length) + + if use_cache: + # 1. Update settled tokens + x[:, past_length:] = x_input + + # TODO: We can avoid these updates by setting a flag in the Attention call to not set KVs for these forward passes and only set when we reach end of the block + # Prepare for next forward pass + # 2. Update past_key_values to include only settled tokens from previous blocks + past_key_values = model_outputs.past_key_values + # need to reset this since the Attention call will add new KVs + past_key_values.crop(settled_length) + # past_key_values are already set from the last forward pass + past_length = past_key_values.get_seq_length() + + # 3. Generic cache-dependent input and position index + # https://github.com/huggingface/transformers/blob/5f4ecf2d9f867a1255131d2461d75793c0cf1db2/src/transformers/generation/utils.py#L410C13-L410C53 + x_input = x[:, past_length:].clone() + tok_idx = tok_idx_base[:, past_length:] if tok_idx is not None else None + + # TODO: optimize this we don't need to compute this every forward pass maybe only change location where tokens are settled; adhering to early stopping + # 4. Set attention mask + # Update attention mask based from the full x to capture past eos and pad tokens masks + attention_mask_tmp = torch.where( + (x == pad_token_id) | (x == eos_token_id), + torch.tensor(0, device=x.device, dtype=torch.bool), + torch.tensor(1, device=x.device, dtype=torch.bool), + ) + attention_mask_tmp = torch.logical_and( + attention_mask_tmp.unsqueeze(1).unsqueeze(-2), + attention_mask_tmp.unsqueeze(1).unsqueeze(-1), + ) + + # Drop values from the 3rd dimension to the size of new x_input so that it current Qs (aka inputs) + # [B, 1, Q_dim, K_dim] + attention_mask_tmp = attention_mask_tmp[:, :, past_length:, :] + attention_mask = attention_mask_tmp + # print(f"attention_mask: {attention_mask_tmp.shape}") + + else: + x = x_input + + # Set attention mask + # Update attention mask based on pad_token_id and eos_token_id + attention_mask_tmp = torch.where( + (x_input == pad_token_id) | (x_input == eos_token_id), + torch.tensor(0, device=x.device, dtype=torch.bool), + torch.tensor(1, device=x.device, dtype=torch.bool), + ) + attention_mask_tmp = torch.logical_and( + attention_mask_tmp.unsqueeze(1).unsqueeze(-2), + attention_mask_tmp.unsqueeze(1).unsqueeze(-1), + ) + attention_mask = attention_mask_tmp + # No need to update tok_idx since we are computing all KVs with original positions again + + if histories is not None: + histories.append(x.clone()) + + if not torch.any(x == mask_token_id): + # print("unmasked all tokens in current x exiting") + break + + # print(f"x_input: {x_input.shape} tok_idx: {tok_idx.shape if tok_idx is not None else None} {tok_idx}") + + # A block is completed update settled tokens length + if not torch.any(x == mask_token_id): + # print("unmasked all tokens in current x exiting") + break + settled_length += block_size + + if return_dict_in_generate: + return DreamModelOutput( + sequences=x, + history=histories, + ) + else: + return x + + # block generation with casual kv cache for flash attention ONLY + def _sample_with_block_with_causal_kv( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.LongTensor], + generation_config: DreamGenerationConfig, + block_size: int, + generation_tokens_hook_func, + generation_logits_hook_func, + ) -> Union[DreamModelOutput, torch.LongTensor]: + # init values + output_history = generation_config.output_history + return_dict_in_generate = generation_config.return_dict_in_generate + max_length = generation_config.max_length + mask_token_id = generation_config.mask_token_id + steps = generation_config.steps + eps = generation_config.eps + alg = generation_config.alg + alg_temp = generation_config.alg_temp + temperature = generation_config.temperature + top_p = generation_config.top_p + top_k = generation_config.top_k + generation_config.pad_token_id + generation_config.eos_token_id + + histories = [] if (return_dict_in_generate and output_history) else None + + block_size = block_size + gen_length = generation_config.max_new_tokens + num_of_blocks = gen_length // block_size + steps = steps // num_of_blocks + + assert gen_length % block_size == 0, "gen_length should be divisible by block_size" + + # pad input_ids to max_length + x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id) + + # If any padding tokens i.e 0 in attention mask + if attention_mask is not None and torch.any(attention_mask == 0.0): + # we do not mask the [MASK] tokens so value = 1.0 + attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0) + tok_idx_base = attention_mask.long().cumsum(-1) - 1 + tok_idx_base.masked_fill_(attention_mask == 0, 1) + # Leave padding out "<|endoftext|>1+1=2 2+2=" -> [ 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17] + + # attention_mask is of shape [B, N] + # broadcast to [B, 1, N, N] + # Set False for padding tokens and rest True + attention_mask = torch.logical_and( + attention_mask.unsqueeze(1).unsqueeze(-2), + attention_mask.unsqueeze(1).unsqueeze(-1), + ) + else: + tok_idx_base = None + attention_mask = "full" + + timesteps = torch.linspace(1, eps, steps + 1, device=x.device) + + input_ids_length = input_ids.shape[1] + batch_size = input_ids.shape[0] + + past_key_values = None + past_length = 0 + tok_idx = tok_idx_base.clone() if tok_idx_base is not None else None + x_input = x.clone() + # initial settled length is the context/prompt length + settled_length = input_ids_length + # 1. Do first forward pass to get past_key_values for context in casual attention + model_outputs = self( + x_input, + attention_mask=attention_mask, + position_ids=tok_idx, + use_cache=True, + past_key_values=past_key_values, + is_causal=True, + ) + past_key_values = model_outputs.past_key_values + # 2. Crop past_key_values to include only context tokens + past_key_values.crop(settled_length) + past_length = past_key_values.get_seq_length() + # 3. Create new input for prediction + x_input = x[:, past_length:].clone() + tok_idx = tok_idx_base[:, past_length:] if tok_idx_base is not None else None + + # print(f"settled_length: {settled_length} past_length: {past_length} x_input: {x_input.shape} past_key_values: {past_key_values.get_seq_length()}") + + for blk_indx in range(num_of_blocks): + current_block = (num_of_blocks - (blk_indx + 1)) * block_size + + for i in range(steps): + model_outputs = self( + x_input, + attention_mask= # The above code is defining a variable + # named `attention_mask` in Python. + attention_mask, + position_ids=tok_idx, + use_cache=True, + past_key_values=past_key_values, + is_causal=False, + ) + + logits = model_outputs.logits + logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1) + + t = timesteps[i] + s = timesteps[i + 1] + + # loop around the batch + for b in range(batch_size): + x_row = x_input[b, :] + mask_index = x_row == mask_token_id + + # if the sequence is already completed, skip it + if mask_index.sum() == 0: + continue + + if current_block > 0: + mask_index[-current_block:] = False + + mask_logits = logits[b, mask_index] + + if alg == "origin": + # p_transfer = 1 - s / t if i < steps - 1 else 1 + # x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id + # transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer + # _, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k) + # x[mask_index] = x0.clone() + raise RuntimeError("batch origin alg is not supported") + else: + if alg == "maskgit_plus": + confidence, x0 = sample_tokens( + mask_logits, temperature=temperature, top_p=top_p, top_k=top_k + ) + elif alg == "topk_margin": + confidence, x0 = sample_tokens( + mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True + ) + elif alg == "entropy": + confidence, x0 = sample_tokens( + mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True + ) + else: + raise RuntimeError(f"Unknown alg: {alg}") + num_mask_token = mask_index.sum() + number_transfer_tokens = ( + ceil(num_mask_token * (1 - s / t)) if i < steps - 1 else num_mask_token + ) + + # print(f"block: {blk_indx} step: {i} batch: {b} confidence: {confidence} x0: {x0}") + # print(f"number_transfer_tokens: {number_transfer_tokens} num_mask_token: {num_mask_token}") + if number_transfer_tokens > 0: + if alg_temp is None or alg_temp == 0: + _, transfer_index = torch.topk(confidence, number_transfer_tokens) + + else: + confidence = confidence / alg_temp + confidence = F.softmax(confidence, dim=-1) + + transfer_index = torch.multinomial(confidence, num_samples=number_transfer_tokens) + x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id + x0_[transfer_index] = x0[transfer_index].clone() + x_input[b, mask_index] = x0_ + + # this allows user-defined token control of the intermediate steps + x_input = generation_tokens_hook_func(i, x_input, logits, input_ids_length) + + # Update settled tokens + x[:, past_length:] = x_input + + # Prepare for next forward pass + + # 1. Update past_key_values to include only settled tokens from previous blocks + # past_key_values = model_outputs.past_key_values + + # needed bcuz Attention module adds them to cache so we need to remove them for next forward pass + # we can stop the Attention module from adding them with a param for speedup + past_key_values.crop(settled_length) + # past_length = past_key_values.get_seq_length() + # print(f"past_length: {past_length} x_input: {x_input.shape} past_length: {past_length} past_key_values: {past_key_values.get_seq_length()}") + + # # only works for sdpa + # attention_mask_tmp = torch.where( + # (x == pad_token_id) | (x == eos_token_id), + # torch.tensor(0, device=x.device, dtype=torch.bool), + # torch.tensor(1, device=x.device, dtype=torch.bool) + # ) + # attention_mask_tmp = torch.logical_and( + # attention_mask_tmp.unsqueeze(1).unsqueeze(-2), + # attention_mask_tmp.unsqueeze(1).unsqueeze(-1), + # ) + + # # Drop values from the 3rd dimension to the size of new x_input so that it current Qs (aka inputs) + # # [B, 1, Q_dim, K_dim] + # attention_mask_tmp = attention_mask_tmp[:, :, past_length:, :] + # attention_mask = attention_mask_tmp + + if histories is not None: + histories.append(x.clone()) + + # A block is completed update settled tokens length + if not torch.any(x == mask_token_id): + break + settled_length += block_size + model_outputs = self( + x_input, + attention_mask=attention_mask, + position_ids=tok_idx, + use_cache=True, + past_key_values=past_key_values, + is_causal=True, + ) + past_key_values = model_outputs.past_key_values + past_key_values.crop(settled_length) + past_length = past_key_values.get_seq_length() + x_input = x[:, past_length:].clone() + tok_idx = tok_idx_base[:, past_length:] if tok_idx is not None else None + + # # Only works for sdpa + # attention_mask_tmp = torch.where( + # (x == pad_token_id) | (x == eos_token_id), + # torch.tensor(0, device=x.device, dtype=torch.bool), + # torch.tensor(1, device=x.device, dtype=torch.bool) + # ) + # attention_mask_tmp = torch.logical_and( + # attention_mask_tmp.unsqueeze(1).unsqueeze(-2), + # attention_mask_tmp.unsqueeze(1).unsqueeze(-1), + # ) + + # # Drop values from the 3rd dimension to the size of new x_input so that it current Qs (aka inputs) + # # [B, 1, Q_dim, K_dim] + # attention_mask_tmp = attention_mask_tmp[:, :, past_length:, :] + # attention_mask = attention_mask_tmp + # print(f"settled_length: {settled_length} past_length: {past_length} x_input: {x_input.shape} past_key_values: {past_key_values.get_seq_length()}") + + if return_dict_in_generate: + return DreamModelOutput( + sequences=x, + history=histories, + ) + else: + return x diff --git a/fast_llm/models/gpt/external/diffusion_dream/modeling_dream.py b/fast_llm/models/gpt/external/diffusion_dream/modeling_dream.py new file mode 100644 index 00000000..e041d618 --- /dev/null +++ b/fast_llm/models/gpt/external/diffusion_dream/modeling_dream.py @@ -0,0 +1,996 @@ +# coding=utf-8 +# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT and Qwen implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT and Qwen used by the Meta AI and Qwen team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Dream model.""" + +import math +from typing import List, Optional, Tuple, Union +import os +import torch +import torch.utils.checkpoint +from torch import nn +from dataclasses import dataclass + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPast, + MaskedLMOutput, +) +from transformers.utils import ModelOutput +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, +) +from transformers import PretrainedConfig +from .configuration_dream import DreamConfig +from .generation_utils import DreamGenerationMixin, DreamGenerationConfig + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + from flash_attn import flash_attn_with_kvcache, flash_attn_func + + +logger = logging.get_logger(__name__) + + +_CHECKPOINT_FOR_DOC = "Dream-7B" +_CONFIG_FOR_DOC = "DreamConfig" + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Dream +class DreamRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DreamRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Dream +class DreamRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[DreamConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`DreamRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def reset_parameters(self): + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, self.inv_freq.device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Dream +class DreamMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class DreamAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: DreamConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = False + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = DreamRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + # Luke: Computing K and Vs for all tokens upto now q_len w/o using cache ? + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class DreamSdpaAttention(DreamAttention): + """ + Dream attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `DreamAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from DreamAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + is_causal: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "DreamModel is using DreamSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # print(f"query_states {query_states.shape} {query_states}") + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + # is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + +class DreamFlashAttention(DreamAttention): + """ + Dream attention module using Flash attention 2. + """ + + # Adapted from DreamAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + is_causal: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "DreamModel is using DreamFlashAttention, it does not support `output_attentions=True`. Falling back to the manual attention implementation, " + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # print(f"hidden_states: {hidden_states.shape} query_states {query_states.shape} key_states {key_states.shape} value_states {value_states.shape}") + # print(f"position_ids {position_ids} {position_ids.shape}") + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + + # if query_states.device.type == "cuda" and attention_mask is not None: + # query_states = query_states.contiguous() + # key_states = key_states.contiguous() + # value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + # is_causal = True if causal_mask is None and q_len > 1 else False + + # attn_output_sdpa = torch.nn.functional.scaled_dot_product_attention( + # query_states, + # key_states, + # value_states, + # attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None, + # dropout_p=self.attention_dropout if self.training else 0.0, + # is_causal=False, # hard coded + # ) + + # print(f"query_states {query_states.shape} key_states {key_states.shape} value_states {value_states.shape}") + + # replacing with flash attention + attn_output = flash_attn_with_kvcache( + # q dim (batch_size, seqlen, nheads, headdim) + q=query_states.transpose(1, 2).contiguous(), + k_cache=key_states.transpose(1, 2).contiguous(), + v_cache=value_states.transpose(1, 2).contiguous(), + causal=is_causal, # hard coded + softmax_scale=1.0 / math.sqrt(self.head_dim), + ) + + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + +class DreamDecoderLayer(nn.Module): + def __init__(self, config: DreamConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + + # self.self_attn = Dream_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + if config._attn_implementation == "flash_attention_2": + self.self_attn = DreamFlashAttention(config, layer_idx) + else: + self.self_attn = DreamSdpaAttention(config, layer_idx) + + self.mlp = DreamMLP(config) + self.input_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + # print(f"DreamDecoderLayer: past_key_value {past_key_value} use_cache {use_cache}") + + is_casual = kwargs.get("is_casual", False) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + is_causal=is_casual, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + # When use_cache is True, outputs will have length: + # - 2 if output_attentions is False (hidden_states, present_key_value) + # - 3 if output_attentions is True (hidden_states, self_attn_weights, present_key_value) + # print(f"DreamDecoderLayer: outputs {len(outputs)}") + return outputs + +class DreamPreTrainedModel(PreTrainedModel): + config_class = DreamConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DreamDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + use_safetensors: Optional[bool] = None, + weights_only: bool = True, + **kwargs, + ): + _model = super().from_pretrained( + pretrained_model_name_or_path, + *model_args, + config=config, + cache_dir=cache_dir, + ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + use_safetensors=use_safetensors, + weights_only=weights_only, + **kwargs, + ) + # NOTE(Lin): we need to override the generation config + # because the generation config loaded in `from_pretrained` + # does not include all the attributes of DreamGenerationConfig + resume_download = kwargs.get("resume_download", None) + proxies = kwargs.get("proxies", None) + subfolder = kwargs.get("subfolder", "") + from_auto_class = kwargs.get("_from_auto", False) + from_pipeline = kwargs.get("_from_pipeline", None) + _model.generation_config = DreamGenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + ) + return _model + +class DreamBaseModel(DreamPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DreamDecoderLayer`] + + Args: + config: DreamConfig + """ + + def __init__(self, config: DreamConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [DreamDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = DreamRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + is_casual: Optional[bool] = False, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + # print("DreamBaseModel: past_key_values", past_key_values, "use_cache", use_cache,) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + is_casual=is_casual, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attns] if v is not None) + # return BaseModelOutput( + # last_hidden_state=hidden_states, + # hidden_states=all_hidden_states, + # attentions=all_self_attns, + # ) + if use_cache: + # print("past_key_values", past_key_values, "use_cache", use_cache, "layer_outputs", layer_outputs) + past_key_values = layer_outputs[-1] + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) +@dataclass +class MaskedLMOutputWithPast(ModelOutput): + """ + Base class for masked language models outputs with past. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + past_key_values: Optional[Tuple[Cache]] = None + +class DreamModel(DreamGenerationMixin, DreamPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = DreamBaseModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def reset_rope_parameters(self): + self.model.rotary_emb.reset_parameters() + for layer in self.model.layers: + layer.self_attn.rotary_emb.reset_parameters() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, + ) -> Union[Tuple, MaskedLMOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + # print("DreamModel: past_key_values", past_key_values, "use_cache", use_cache) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + is_casual=loss_kwargs.get("is_casual", False), + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return MaskedLMOutputWithPast( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + past_key_values=outputs.past_key_values, + ) \ No newline at end of file diff --git a/fast_llm/models/gpt/external/diffusion_llama/configuration_diffusion_llama.py b/fast_llm/models/gpt/external/diffusion_llama/configuration_diffusion_llama.py new file mode 100644 index 00000000..28bf0efe --- /dev/null +++ b/fast_llm/models/gpt/external/diffusion_llama/configuration_diffusion_llama.py @@ -0,0 +1,453 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LLaMA model configuration""" + +import math +from typing import Optional, Tuple + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import is_torch_available, logging + +logger = logging.get_logger(__name__) + +if is_torch_available(): + import torch + +# Update yarn implementation for RoPE (Taken from Llama but updated to use original_max_position_embeddings) +def _compute_default_rope_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + base = rope_kwargs["base"] + dim = rope_kwargs["dim"] + elif config is not None: + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, attention_factor + +def _compute_yarn_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with NTK scaling. Please refer to the + [original paper](https://arxiv.org/abs/2309.00071) + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # No need to keep BC with yarn, unreleased when this new pattern was created. + if len(rope_kwargs) > 0: + raise ValueError( + f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}" + ) + + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + # Apriel: Use original max_position_embeddings instead of max_position_embeddings + max_position_embeddings = config.rope_scaling.get("original_max_position_embeddings") + factor = config.rope_scaling["factor"] + + # Sets the attention factor as suggested in the paper + attention_factor = config.rope_scaling.get("attention_factor") + if attention_factor is None: + attention_factor = 0.1 * math.log(factor) + 1.0 + + # Optional config options + # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) + beta_fast = config.rope_scaling.get("beta_fast") or 32 + beta_slow = config.rope_scaling.get("beta_slow") or 1 + + # Compute the inverse frequencies + def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + """Inverse dimension formula to find the dimension based on the number of rotations""" + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings): + """Find dimension range bounds based on rotations""" + low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs + # to expand the possible context length. In other words, interpolation = apply scaling factor. + pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings) + + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor + ) + + return inv_freq, attention_factor + +def _check_received_keys( + rope_type: str, + received_keys: set, + required_keys: set, + optional_keys: Optional[set] = None, + ignore_keys: Optional[set] = None, +): + + """Compare the received keys in `config.rope_scaling` against the expected and optional keys""" + # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present + if "type" in received_keys: + received_keys -= {"type"} + required_keys.add("rope_type") + + # Some models need to store model-specific keys, and we don't want to throw warning at them + if ignore_keys is not None: + received_keys -= ignore_keys + + missing_keys = required_keys - received_keys + if missing_keys: + raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}") + + if optional_keys is not None: + unused_keys = received_keys - required_keys - optional_keys + else: + unused_keys = received_keys - required_keys + if unused_keys: + logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}") + + +def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + +def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor", "original_max_position_embeddings"} + optional_keys = {"attention_factor", "beta_fast", "beta_slow"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_scaling.get("attention_factor") + if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0): + logger.warning( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + beta_fast = rope_scaling.get("beta_fast") + if beta_fast is not None and not isinstance(beta_fast, float): + logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + beta_slow = rope_scaling.get("beta_slow") + if beta_slow is not None and not isinstance(beta_slow, float): + logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + + if (beta_fast or 32) < (beta_slow or 1): + logger.warning( + f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " + f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" + ) +# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters +# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE +# parameterizations, as long as the callable has the same signature. +ROPE_INIT_FUNCTIONS = { + "default": _compute_default_rope_parameters, + "yarn": _compute_yarn_parameters, +} + +# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types. +ROPE_VALIDATION_FUNCTIONS = { + "default": _validate_default_rope_parameters, + "yarn": _validate_yarn_parameters, +} + +def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None): + """ + Validate the RoPE config arguments, given a `PretrainedConfig` object + """ + rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig` + if rope_scaling is None: + return + + # BC: "rope_type" was originally "type" + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) + validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type) + if validation_fn is not None: + validation_fn(config, ignore_keys=ignore_keys) + else: + logger.warning( + f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" + ) + +class DiffusionLlamaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + e.g. [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`LlamaModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, + Llama 2 up to 4096, CodeLlama up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to + understand more about it. This value is necessary to ensure exact reproducibility of the pretraining + results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + head_dim (`int`, *optional*): + The attention head dimension. If None, it will default to hidden_size // num_attention_heads + + ```python + >>> from transformers import LlamaModel, LlamaConfig + + >>> # Initializing a LLaMA llama-7b style configuration + >>> configuration = LlamaConfig() + + >>> # Initializing a model from the llama-7b style configuration + >>> model = LlamaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "diffusion_llama" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `LlamaModel` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=False, # cache not implemented in diffusion + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + head_dim=None, + # mask_token_id= TODO: add the mask_token_id we will be using, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + # TODO: self.mask_token_id = mask_token_id + +__all__ = ["LlamaConfig"] \ No newline at end of file diff --git a/fast_llm/models/gpt/external/diffusion_llama/generation_utils.py b/fast_llm/models/gpt/external/diffusion_llama/generation_utils.py new file mode 100644 index 00000000..b70dcf49 --- /dev/null +++ b/fast_llm/models/gpt/external/diffusion_llama/generation_utils.py @@ -0,0 +1,1040 @@ +# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import warnings +from dataclasses import dataclass +from math import ceil +from typing import Any, Optional, Union + +import torch +import torch.distributions as dists +from torch.nn import functional as F +from transformers import __version__ +from transformers.generation.configuration_utils import GenerationConfig +from transformers.generation.utils import GenerationMixin +from transformers.utils import ModelOutput, is_torchdynamo_compiling, logging + +logger = logging.get_logger(__name__) + + +def top_p_logits(logits, top_p=None): + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device) + mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove) + logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min) + return logits + + +def top_k_logits(logits, top_k=None): + top_k = min(top_k, logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min) + return logits + + +def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False): + + if temperature > 0: + logits = logits / temperature + if top_p is not None and top_p < 1: + logits = top_p_logits(logits, top_p) + if top_k is not None: + logits = top_k_logits(logits, top_k) + probs = torch.softmax(logits, dim=-1) + + if temperature > 0: + try: + x0 = dists.Categorical(probs=probs).sample() + confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1) + except: + confidence, x0 = probs.max(dim=-1) + else: + confidence, x0 = probs.max(dim=-1) + + if margin_confidence: + sorted_probs, _ = torch.sort(probs, dim=-1, descending=True) + # Extract top1 and top2 probabilities + top1_probs = sorted_probs[:, 0] + top2_probs = sorted_probs[:, 1] + # Calculate confidence as top1 - top2 + confidence = top1_probs - top2_probs + + if neg_entropy: + epsilon = 1e-10 + log_probs = torch.log(probs + epsilon) + confidence = torch.sum(probs * log_probs, dim=-1) + + return confidence, x0 + + +# batch_sample_tokens +def batch_sample_tokens( + logits, mask_indexes, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False +): + # print(f"batch_sample_tokens: {logits.shape} ") + if temperature > 0: + logits = logits / temperature + if top_p is not None and top_p < 1: + # logit will have different lengths for each sequence so cannot stack - it is not a proper batch??? + logits = torch.stack([top_p_logits(logit[mask], top_p) for logit, mask in zip(logits, mask_indexes)], dim=0) + if top_k is not None: + logits = torch.stack([top_k_logits(logit[mask], top_k) for logit, mask in zip(logits, mask_indexes)], dim=0) + + # if logits are not of the same sequence so therefore we can pad them with -inf but need remove them back ... + probs = torch.softmax(logits, dim=-1) + # print(f"probs: {probs.shape}") + + if temperature > 0: + try: + x0 = dists.Categorical(probs=probs).sample() + confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1) + except: + confidence, x0 = probs.max(dim=-1) + else: + confidence, x0 = probs.max(dim=-1) + + # print(f"confidence: {confidence.shape} x0: {x0.shape}") + if margin_confidence: + sorted_probs, _ = torch.sort(probs, dim=-1, descending=True) + # Extract top1 and top2 probabilities + top1_probs = sorted_probs[:, 0] + top2_probs = sorted_probs[:, 1] + # Calculate confidence as top1 - top2 + confidence = top1_probs - top2_probs + + if neg_entropy: + epsilon = 1e-10 + log_probs = torch.log(probs + epsilon) + confidence = torch.sum(probs * log_probs, dim=-1) + + return confidence, x0 + + +@dataclass +class DreamModelOutput(ModelOutput): + sequences: torch.LongTensor = None + history: Optional[tuple[torch.FloatTensor]] = None + + +class DreamGenerationConfig(GenerationConfig): + def __init__(self, **kwargs): + self.temperature: float = kwargs.pop("temperature", 0.0) + self.top_p: Optional[float] = kwargs.pop("top_p", None) + self.top_k: Optional[int] = kwargs.pop("top_k", None) + self.max_length = kwargs.pop("max_length", 20) + self.max_new_tokens = kwargs.pop("max_new_tokens", None) + # diffusion specific params + self.eps: float = kwargs.pop("eps", 1e-3) + self.steps: int = kwargs.pop("steps", 512) + self.alg: str = kwargs.pop("alg", "origin") + self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None) + + # Parameters that define the output variables of `generate` + self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1) + self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False) + self.output_history: bool = kwargs.pop("output_history", False) + + # Special tokens that can be used at generation time + self.mask_token_id = kwargs.pop("mask_token_id", None) + self.pad_token_id = kwargs.pop("pad_token_id", None) + self.bos_token_id = kwargs.pop("bos_token_id", None) + self.eos_token_id = kwargs.pop("eos_token_id", None) + + # Wild card + self.generation_kwargs = kwargs.pop("generation_kwargs", {}) + + # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub + # interface. + self._from_model_config = kwargs.pop("_from_model_config", False) + self._commit_hash = kwargs.pop("_commit_hash", None) + self.transformers_version = kwargs.pop("transformers_version", __version__) + + # Additional attributes without default values + if not self._from_model_config: + # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a + # model's default configuration file + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + # Validate the values of the attributes + self.validate(is_init=True) + + def validate(self, is_init=False): + pass + + +class DreamGenerationMixin(GenerationMixin): + @staticmethod + def _expand_inputs_for_generation( + expand_size: int = 1, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]""" + # Do not call torch.repeat_interleave if expand_size is 1 because it clones + # the input tensor and thus requires more memory although no change is applied + if expand_size == 1: + return input_ids, attention_mask + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + if attention_mask is not None: + attention_mask = attention_mask.repeat_interleave(expand_size, dim=0) + return input_ids, attention_mask + + def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length): + """Performs validation related to the resulting generated length""" + + # Can't throw warnings/exceptions during compilation + if is_torchdynamo_compiling(): + return + + # 1. Max length warnings related to poor parameterization + if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: + # 20 is the default max_length of the generation config + warnings.warn( + f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the " + "generation length. We recommend setting `max_new_tokens` to control the maximum length of the " + "generation.", + UserWarning, + ) + if input_ids_length >= generation_config.max_length: + input_ids_string = "input_ids" + raise ValueError( + f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_length` or, better yet, setting `max_new_tokens`." + ) + + def _prepare_generated_length( + self, + generation_config, + has_default_max_length, + input_ids_length, + ): + """Prepared max and min length in generation configs to avoid clashes between similar attributes""" + + if generation_config.max_new_tokens is not None: + if not has_default_max_length and generation_config.max_length is not None: + logger.warning( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + generation_config.max_length = generation_config.max_new_tokens + input_ids_length + + elif has_default_max_length: + if generation_config.max_length == DreamGenerationConfig().max_length: + generation_config.max_length = generation_config.max_length + input_ids_length + max_position_embeddings = getattr(self.config, "max_position_embeddings", None) + if max_position_embeddings is not None: + generation_config.max_length = min(generation_config.max_length, max_position_embeddings) + + return generation_config + + def _prepare_generation_config( + self, generation_config: Optional[DreamGenerationConfig], **kwargs: dict + ) -> DreamGenerationConfig: + """ + Prepares the base generation config, then applies any generation configuration options from kwargs. This + function handles retrocompatibility with respect to configuration files. + """ + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + using_model_generation_config = False + if generation_config is None: + generation_config = DreamGenerationConfig.from_model_config(self.config) + using_model_generation_config = True + + # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config` + # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an + # exception will be raised in `_validate_model_kwargs` + if not is_torchdynamo_compiling(): + generation_config = copy.deepcopy(generation_config) + generation_config.update(**kwargs) + # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model + if not using_model_generation_config: + if generation_config.bos_token_id is None: + generation_config.bos_token_id = self.generation_config.bos_token_id + if generation_config.eos_token_id is None: + generation_config.eos_token_id = self.generation_config.eos_token_id + if generation_config.pad_token_id is None: + generation_config.pad_token_id = self.generation_config.pad_token_id + if generation_config.mask_token_id is None: + generation_config.mask_token_id = self.generation_config.mask_token_id + + return generation_config + + def _prepare_special_tokens( + self, + generation_config: DreamGenerationConfig, + device: Optional[Union[torch.device, str]] = None, + ): + """ + Prepares the special tokens for generation, overwriting the generation config with their processed versions + converted to tensor. + + Note that `generation_config` is changed in place and stops being serializable after this method is called. + That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the + function). However, if called outside `generate`, consider creating a copy of `generation_config` first. + """ + + # Convert special tokens to tensors + def _tensor_or_none(token, device=None): + if token is None: + return token + + device = device if device is not None else self.device + if isinstance(token, torch.Tensor): + return token.to(device) + return torch.tensor(token, device=device, dtype=torch.long) + + bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device) + eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device) + pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device) + mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device) + + # We can have more than one eos token. Always treat it as a 1D tensor (when it exists). + if eos_token_tensor is not None and eos_token_tensor.ndim == 0: + eos_token_tensor = eos_token_tensor.unsqueeze(0) + + # Set pad token if unset (and there are conditions to do so) + if pad_token_tensor is None and eos_token_tensor is not None: + pad_token_tensor = eos_token_tensor[0] + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.") + + # Update generation config with the updated special tokens tensors + # NOTE: this must be written into a different attribute name than the one holding the original special tokens + # (in their non-tensor form), in order to enable end-to-end compilation. See + # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations + generation_config._bos_token_tensor = bos_token_tensor + generation_config._eos_token_tensor = eos_token_tensor + generation_config._pad_token_tensor = pad_token_tensor + generation_config._mask_token_tensor = mask_token_tensor + + @torch.no_grad() + def diffusion_generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[DreamGenerationConfig] = None, + **kwargs, + ) -> Union[DreamModelOutput, torch.LongTensor]: + # fix seed for reproducability torch.random.manual_seed - lm-eval is setting it + torch.random.manual_seed(0) + + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + generation_config = self._prepare_generation_config(generation_config, **kwargs) + generation_tokens_hook_func = kwargs.pop( + "generation_tokens_hook_func", lambda step, x, logits, end_of_prompt: x + ) + generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits) + + # 2. Define model inputs + assert inputs is not None + input_ids = inputs + device = input_ids.device + attention_mask = kwargs.pop("attention_mask", None) + self._prepare_special_tokens(generation_config, device=device) + + # 3. Prepare `max_length`. + input_ids_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length= # The code `has_default_max_length` is not a valid Python code + # snippet. It seems to be a placeholder or a comment in the code. + # It does not perform any specific action or functionality in + # Python. + has_default_max_length, + input_ids_length=input_ids_length, + ) + + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + + # 4. Check input_ids + if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type: + warnings.warn( + "You are calling .generate() with the `input_ids` being on a device type different" + f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" + f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." + " Please make sure that you have put `input_ids` to the" + f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" + " running `.generate()`.", + UserWarning, + ) + if ( + hasattr(generation_config, "pad_token_id") + and torch.any(input_ids == generation_config.pad_token_id) + and attention_mask is None + ): + warnings.warn( + "Padding was detected but no attention mask is passed here. For correct " + "generation results, please set `attention_mask` when batch-padding inputs.", + UserWarning, + ) + + input_ids, attention_mask = self._expand_inputs_for_generation( + expand_size=generation_config.num_return_sequences, input_ids=input_ids, attention_mask=attention_mask + ) + + block_size = kwargs.pop("block_size", None) + use_cache = kwargs.pop("use_cache", False) + causal_cache = kwargs.pop("causal_cache", False) + + if block_size is None: + # Default diffusion generation + result = self._sample( + input_ids, + attention_mask=attention_mask, + generation_config=generation_config, + generation_tokens_hook_func=generation_tokens_hook_func, + generation_logits_hook_func=generation_logits_hook_func, + ) + return result + else: + if causal_cache: + # Block generation with casual KV Caching only works for Flash attention + result = self._sample_with_block_with_causal_kv( + input_ids, + attention_mask=attention_mask, + generation_config=generation_config, + generation_tokens_hook_func=generation_tokens_hook_func, + generation_logits_hook_func=generation_logits_hook_func, + block_size=block_size, + ) + return result + else: + # Block generation with (diffusion) KV Caching + result = self._sample_with_block( + input_ids, + attention_mask=attention_mask, + generation_config=generation_config, + generation_tokens_hook_func=generation_tokens_hook_func, + generation_logits_hook_func=generation_logits_hook_func, + block_size=block_size, + use_cache=use_cache, + ) + return result + + # loop confidence implementation - working same results for bs 1 + def _sample( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.LongTensor], + generation_config: DreamGenerationConfig, + generation_tokens_hook_func, + generation_logits_hook_func, + ) -> Union[DreamModelOutput, torch.LongTensor]: + # init values + output_history = generation_config.output_history + return_dict_in_generate = generation_config.return_dict_in_generate + max_length = generation_config.max_length + mask_token_id = generation_config.mask_token_id + steps = generation_config.steps + eps = generation_config.eps + alg = generation_config.alg + alg_temp = generation_config.alg_temp + temperature = generation_config.temperature + top_p = generation_config.top_p + top_k = generation_config.top_k + + histories = [] if (return_dict_in_generate and output_history) else None + + pad_token_id = generation_config.pad_token_id + eos_token_id = generation_config.eos_token_id + + # pad input_ids to max_length + x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id) + + if attention_mask is not None and torch.any(attention_mask == 0.0): + # we do not mask the [MASK] tokens so value = 1.0 + attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0) + tok_idx = attention_mask.long().cumsum(-1) - 1 + tok_idx.masked_fill_(attention_mask == 0, 1) + # attention_mask is of shape [B, N] + # broadcast to [B, 1, N, N] + attention_mask = torch.logical_and( + attention_mask.unsqueeze(1).unsqueeze(-2), + attention_mask.unsqueeze(1).unsqueeze(-1), + ) + else: + tok_idx = None + attention_mask = "full" + + timesteps = torch.linspace(1, eps, steps + 1, device=x.device) + + input_ids_length = input_ids.shape[1] + batch_size = input_ids.shape[0] + + # this allows user-defined token control of the intermediate steps + # x = generation_tokens_hook_func(None, x, None, input_ids_length) + + for i in range(steps): + + logits = self(x, attention_mask, tok_idx).logits + logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1) + + # this allows user-defined logits control of the intermediate steps + logits = generation_logits_hook_func(i, x, logits) + + t = timesteps[i] + s = timesteps[i + 1] + + # loop around the batch + for b in range(batch_size): + x_row = x[b, :] + mask_index = x_row == mask_token_id + # if the sequence is already completed, skip it + if mask_index.sum() == 0: + continue + mask_logits = logits[b, mask_index] + + if alg == "origin": + # p_transfer = 1 - s / t if i < steps - 1 else 1 + # x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id + # transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer + # _, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k) + # x[mask_index] = x0.clone() + raise RuntimeError("batch origin alg is not supported") + else: + if alg == "maskgit_plus": + confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k) + elif alg == "topk_margin": + confidence, x0 = sample_tokens( + mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True + ) + elif alg == "entropy": + confidence, x0 = sample_tokens( + mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True + ) + else: + raise RuntimeError(f"Unknown alg: {alg}") + num_mask_token = mask_index.sum() + number_transfer_tokens = ceil(num_mask_token * (1 - s / t)) if i < steps - 1 else num_mask_token + + if number_transfer_tokens > 0: + if alg_temp is None or alg_temp == 0: + _, transfer_index = torch.topk(confidence, number_transfer_tokens) + + else: + confidence = confidence / alg_temp + confidence = F.softmax(confidence, dim=-1) + + transfer_index = torch.multinomial(confidence, num_samples=number_transfer_tokens) + x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id + x0_[transfer_index] = x0[transfer_index].clone() + x[b, mask_index] = x0_ + + # this allows user-defined token control of the intermediate steps + x = generation_tokens_hook_func(i, x, logits, input_ids_length) + + if not torch.any(x == mask_token_id): + break + + # Update attention mask based on pad_token_id and eos_token_id + attention_mask_tmp = torch.where( + (x == pad_token_id) | (x == eos_token_id), + torch.tensor(0, device=x.device, dtype=torch.bool), + torch.tensor(1, device=x.device, dtype=torch.bool), + ) + attention_mask_tmp = torch.logical_and( + attention_mask_tmp.unsqueeze(1).unsqueeze(-2), + attention_mask_tmp.unsqueeze(1).unsqueeze(-1), + ) + # print(f"attention_mask: {attention_mask_tmp.shape} {attention_mask_tmp}") + attention_mask = attention_mask_tmp + + if histories is not None: + histories.append(x.clone()) + + if return_dict_in_generate: + return DreamModelOutput( + sequences=x, + history=histories, + ) + else: + return x + + # block generation with kv cache + def _sample_with_block( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.LongTensor], + generation_config: DreamGenerationConfig, + block_size: int, + use_cache: bool, + generation_tokens_hook_func, + generation_logits_hook_func, + ) -> Union[DreamModelOutput, torch.LongTensor]: + # init values + output_history = generation_config.output_history + return_dict_in_generate = generation_config.return_dict_in_generate + max_length = generation_config.max_length + mask_token_id = generation_config.mask_token_id + steps = generation_config.steps + eps = generation_config.eps + alg = generation_config.alg + alg_temp = generation_config.alg_temp + temperature = generation_config.temperature + top_p = generation_config.top_p + top_k = generation_config.top_k + use_cache = use_cache + + histories = [] if (return_dict_in_generate and output_history) else None + + pad_token_id = generation_config.pad_token_id + eos_token_id = generation_config.eos_token_id + block_size = block_size + gen_length = generation_config.max_new_tokens + num_of_blocks = gen_length // block_size + steps = steps // num_of_blocks + + assert gen_length % block_size == 0, "gen_length should be divisible by block_size" + + # pad input_ids to max_length + x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id) + + # TODO: Avoid this check and all future checks by always creating a mask + # If any padding tokens i.e 0 in attention mask + if attention_mask is not None and torch.any(attention_mask == 0.0): + # we do not mask the [MASK] tokens so value = 1.0 + attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0) + tok_idx_base = attention_mask.long().cumsum(-1) - 1 + tok_idx_base.masked_fill_(attention_mask == 0, 1) + # Leave padding out "<|endoftext|>1+1=2 2+2=" -> [ 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17] + + # attention_mask is of shape [B, N] + # broadcast to [B, 1, N, N] + # Set False for padding tokens and rest True + attention_mask = torch.logical_and( + attention_mask.unsqueeze(1).unsqueeze(-2), + attention_mask.unsqueeze(1).unsqueeze(-1), + ) + else: + tok_idx_base = None + attention_mask = "full" + + timesteps = torch.linspace(1, eps, steps + 1, device=x.device) + + input_ids_length = input_ids.shape[1] + batch_size = input_ids.shape[0] + + past_key_values = None + past_length = 0 + settled_length = input_ids_length + x_input = x.clone() + tok_idx = tok_idx_base.clone() if tok_idx_base is not None else None + + for blk_indx in range(num_of_blocks): + current_block = (num_of_blocks - (blk_indx + 1)) * block_size + + for i in range(steps): + + model_outputs = self( + x_input, attention_mask, tok_idx, use_cache=use_cache, past_key_values=past_key_values + ) + + logits = model_outputs.logits + logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1) + + t = timesteps[i] + s = timesteps[i + 1] + + # loop around the batch + for b in range(batch_size): + x_row = x_input[b, :] + mask_index = x_row == mask_token_id + + # if the sequence is already completed, skip it + if mask_index.sum() == 0: + continue + + if current_block > 0: + mask_index[-current_block:] = False + mask_logits = logits[b, mask_index] + + if alg == "origin": + # p_transfer = 1 - s / t if i < steps - 1 else 1 + # x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id + # transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer + # _, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k) + # x[mask_index] = x0.clone() + raise RuntimeError("batch origin alg is not supported") + else: + if alg == "maskgit_plus": + confidence, x0 = sample_tokens( + mask_logits, temperature=temperature, top_p=top_p, top_k=top_k + ) + elif alg == "topk_margin": + confidence, x0 = sample_tokens( + mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True + ) + elif alg == "entropy": + confidence, x0 = sample_tokens( + mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True + ) + else: + raise RuntimeError(f"Unknown alg: {alg}") + num_mask_token = mask_index.sum() + number_transfer_tokens = ( + ceil(num_mask_token * (1 - s / t)) if i < steps - 1 else num_mask_token + ) + + # print(f"block: {blk_indx} step: {i} batch: {b} confidence: {confidence} x0: {x0}") + # print(f"number_transfer_tokens: {number_transfer_tokens} num_mask_token: {num_mask_token} ") + if number_transfer_tokens > 0: + if alg_temp is None or alg_temp == 0: + _, transfer_index = torch.topk(confidence, number_transfer_tokens) + + else: + confidence = confidence / alg_temp + confidence = F.softmax(confidence, dim=-1) + + transfer_index = torch.multinomial(confidence, num_samples=number_transfer_tokens) + x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id + x0_[transfer_index] = x0[transfer_index].clone() + x_input[b, mask_index] = x0_ + + # this allows user-defined token control of the intermediate steps + x_input = generation_tokens_hook_func(i, x_input, logits, input_ids_length) + + if use_cache: + # 1. Update settled tokens + x[:, past_length:] = x_input + + # TODO: We can avoid these updates by setting a flag in the Attention call to not set KVs for these forward passes and only set when we reach end of the block + # Prepare for next forward pass + # 2. Update past_key_values to include only settled tokens from previous blocks + past_key_values = model_outputs.past_key_values + # need to reset this since the Attention call will add new KVs + past_key_values.crop(settled_length) + # past_key_values are already set from the last forward pass + past_length = past_key_values.get_seq_length() + + # 3. Generic cache-dependent input and position index + # https://github.com/huggingface/transformers/blob/5f4ecf2d9f867a1255131d2461d75793c0cf1db2/src/transformers/generation/utils.py#L410C13-L410C53 + x_input = x[:, past_length:].clone() + tok_idx = tok_idx_base[:, past_length:] if tok_idx is not None else None + + # TODO: optimize this we don't need to compute this every forward pass maybe only change location where tokens are settled; adhering to early stopping + # 4. Set attention mask + # Update attention mask based from the full x to capture past eos and pad tokens masks + attention_mask_tmp = torch.where( + (x == pad_token_id) | (x == eos_token_id), + torch.tensor(0, device=x.device, dtype=torch.bool), + torch.tensor(1, device=x.device, dtype=torch.bool), + ) + attention_mask_tmp = torch.logical_and( + attention_mask_tmp.unsqueeze(1).unsqueeze(-2), + attention_mask_tmp.unsqueeze(1).unsqueeze(-1), + ) + + # Drop values from the 3rd dimension to the size of new x_input so that it current Qs (aka inputs) + # [B, 1, Q_dim, K_dim] + attention_mask_tmp = attention_mask_tmp[:, :, past_length:, :] + attention_mask = attention_mask_tmp + # print(f"attention_mask: {attention_mask_tmp.shape}") + + else: + x = x_input + + # Set attention mask + # Update attention mask based on pad_token_id and eos_token_id + attention_mask_tmp = torch.where( + (x_input == pad_token_id) | (x_input == eos_token_id), + torch.tensor(0, device=x.device, dtype=torch.bool), + torch.tensor(1, device=x.device, dtype=torch.bool), + ) + attention_mask_tmp = torch.logical_and( + attention_mask_tmp.unsqueeze(1).unsqueeze(-2), + attention_mask_tmp.unsqueeze(1).unsqueeze(-1), + ) + attention_mask = attention_mask_tmp + # No need to update tok_idx since we are computing all KVs with original positions again + + if histories is not None: + histories.append(x.clone()) + + if not torch.any(x == mask_token_id): + # print("unmasked all tokens in current x exiting") + break + + # print(f"x_input: {x_input.shape} tok_idx: {tok_idx.shape if tok_idx is not None else None} {tok_idx}") + + # A block is completed update settled tokens length + if not torch.any(x == mask_token_id): + # print("unmasked all tokens in current x exiting") + break + settled_length += block_size + + if return_dict_in_generate: + return DreamModelOutput( + sequences=x, + history=histories, + ) + else: + return x + + # block generation with casual kv cache for flash attention ONLY + def _sample_with_block_with_causal_kv( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.LongTensor], + generation_config: DreamGenerationConfig, + block_size: int, + generation_tokens_hook_func, + generation_logits_hook_func, + ) -> Union[DreamModelOutput, torch.LongTensor]: + # init values + output_history = generation_config.output_history + return_dict_in_generate = generation_config.return_dict_in_generate + max_length = generation_config.max_length + mask_token_id = generation_config.mask_token_id + steps = generation_config.steps + eps = generation_config.eps + alg = generation_config.alg + alg_temp = generation_config.alg_temp + temperature = generation_config.temperature + top_p = generation_config.top_p + top_k = generation_config.top_k + generation_config.pad_token_id + generation_config.eos_token_id + + histories = [] if (return_dict_in_generate and output_history) else None + + block_size = block_size + gen_length = generation_config.max_new_tokens + num_of_blocks = gen_length // block_size + steps = steps // num_of_blocks + + assert gen_length % block_size == 0, "gen_length should be divisible by block_size" + + # pad input_ids to max_length + x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id) + + # If any padding tokens i.e 0 in attention mask + if attention_mask is not None and torch.any(attention_mask == 0.0): + # we do not mask the [MASK] tokens so value = 1.0 + attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0) + tok_idx_base = attention_mask.long().cumsum(-1) - 1 + tok_idx_base.masked_fill_(attention_mask == 0, 1) + # Leave padding out "<|endoftext|>1+1=2 2+2=" -> [ 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17] + + # attention_mask is of shape [B, N] + # broadcast to [B, 1, N, N] + # Set False for padding tokens and rest True + attention_mask = torch.logical_and( + attention_mask.unsqueeze(1).unsqueeze(-2), + attention_mask.unsqueeze(1).unsqueeze(-1), + ) + else: + tok_idx_base = None + attention_mask = "full" + + timesteps = torch.linspace(1, eps, steps + 1, device=x.device) + + input_ids_length = input_ids.shape[1] + batch_size = input_ids.shape[0] + + past_key_values = None + past_length = 0 + tok_idx = tok_idx_base.clone() if tok_idx_base is not None else None + x_input = x.clone() + # initial settled length is the context/prompt length + settled_length = input_ids_length + # 1. Do first forward pass to get past_key_values for context in casual attention + model_outputs = self( + x_input, + attention_mask=attention_mask, + position_ids=tok_idx, + use_cache=True, + past_key_values=past_key_values, + is_causal=True, + ) + past_key_values = model_outputs.past_key_values + # 2. Crop past_key_values to include only context tokens + past_key_values.crop(settled_length) + past_length = past_key_values.get_seq_length() + # 3. Create new input for prediction + x_input = x[:, past_length:].clone() + tok_idx = tok_idx_base[:, past_length:] if tok_idx_base is not None else None + + # print(f"settled_length: {settled_length} past_length: {past_length} x_input: {x_input.shape} past_key_values: {past_key_values.get_seq_length()}") + + for blk_indx in range(num_of_blocks): + current_block = (num_of_blocks - (blk_indx + 1)) * block_size + + for i in range(steps): + model_outputs = self( + x_input, + attention_mask= # The above code is defining a variable + # named `attention_mask` in Python. + attention_mask, + position_ids=tok_idx, + use_cache=True, + past_key_values=past_key_values, + is_causal=False, + ) + + logits = model_outputs.logits + logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1) + + t = timesteps[i] + s = timesteps[i + 1] + + # loop around the batch + for b in range(batch_size): + x_row = x_input[b, :] + mask_index = x_row == mask_token_id + + # if the sequence is already completed, skip it + if mask_index.sum() == 0: + continue + + if current_block > 0: + mask_index[-current_block:] = False + + mask_logits = logits[b, mask_index] + + if alg == "origin": + # p_transfer = 1 - s / t if i < steps - 1 else 1 + # x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id + # transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer + # _, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k) + # x[mask_index] = x0.clone() + raise RuntimeError("batch origin alg is not supported") + else: + if alg == "maskgit_plus": + confidence, x0 = sample_tokens( + mask_logits, temperature=temperature, top_p=top_p, top_k=top_k + ) + elif alg == "topk_margin": + confidence, x0 = sample_tokens( + mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True + ) + elif alg == "entropy": + confidence, x0 = sample_tokens( + mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True + ) + else: + raise RuntimeError(f"Unknown alg: {alg}") + num_mask_token = mask_index.sum() + number_transfer_tokens = ( + ceil(num_mask_token * (1 - s / t)) if i < steps - 1 else num_mask_token + ) + + # print(f"block: {blk_indx} step: {i} batch: {b} confidence: {confidence} x0: {x0}") + # print(f"number_transfer_tokens: {number_transfer_tokens} num_mask_token: {num_mask_token}") + if number_transfer_tokens > 0: + if alg_temp is None or alg_temp == 0: + _, transfer_index = torch.topk(confidence, number_transfer_tokens) + + else: + confidence = confidence / alg_temp + confidence = F.softmax(confidence, dim=-1) + + transfer_index = torch.multinomial(confidence, num_samples=number_transfer_tokens) + x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id + x0_[transfer_index] = x0[transfer_index].clone() + x_input[b, mask_index] = x0_ + + # this allows user-defined token control of the intermediate steps + x_input = generation_tokens_hook_func(i, x_input, logits, input_ids_length) + + # Update settled tokens + x[:, past_length:] = x_input + + # Prepare for next forward pass + + # 1. Update past_key_values to include only settled tokens from previous blocks + # past_key_values = model_outputs.past_key_values + + # needed bcuz Attention module adds them to cache so we need to remove them for next forward pass + # we can stop the Attention module from adding them with a param for speedup + past_key_values.crop(settled_length) + # past_length = past_key_values.get_seq_length() + # print(f"past_length: {past_length} x_input: {x_input.shape} past_length: {past_length} past_key_values: {past_key_values.get_seq_length()}") + + # # only works for sdpa + # attention_mask_tmp = torch.where( + # (x == pad_token_id) | (x == eos_token_id), + # torch.tensor(0, device=x.device, dtype=torch.bool), + # torch.tensor(1, device=x.device, dtype=torch.bool) + # ) + # attention_mask_tmp = torch.logical_and( + # attention_mask_tmp.unsqueeze(1).unsqueeze(-2), + # attention_mask_tmp.unsqueeze(1).unsqueeze(-1), + # ) + + # # Drop values from the 3rd dimension to the size of new x_input so that it current Qs (aka inputs) + # # [B, 1, Q_dim, K_dim] + # attention_mask_tmp = attention_mask_tmp[:, :, past_length:, :] + # attention_mask = attention_mask_tmp + + if histories is not None: + histories.append(x.clone()) + + # A block is completed update settled tokens length + if not torch.any(x == mask_token_id): + break + settled_length += block_size + model_outputs = self( + x_input, + attention_mask=attention_mask, + position_ids=tok_idx, + use_cache=True, + past_key_values=past_key_values, + is_causal=True, + ) + past_key_values = model_outputs.past_key_values + past_key_values.crop(settled_length) + past_length = past_key_values.get_seq_length() + x_input = x[:, past_length:].clone() + tok_idx = tok_idx_base[:, past_length:] if tok_idx is not None else None + + # # Only works for sdpa + # attention_mask_tmp = torch.where( + # (x == pad_token_id) | (x == eos_token_id), + # torch.tensor(0, device=x.device, dtype=torch.bool), + # torch.tensor(1, device=x.device, dtype=torch.bool) + # ) + # attention_mask_tmp = torch.logical_and( + # attention_mask_tmp.unsqueeze(1).unsqueeze(-2), + # attention_mask_tmp.unsqueeze(1).unsqueeze(-1), + # ) + + # # Drop values from the 3rd dimension to the size of new x_input so that it current Qs (aka inputs) + # # [B, 1, Q_dim, K_dim] + # attention_mask_tmp = attention_mask_tmp[:, :, past_length:, :] + # attention_mask = attention_mask_tmp + # print(f"settled_length: {settled_length} past_length: {past_length} x_input: {x_input.shape} past_key_values: {past_key_values.get_seq_length()}") + + if return_dict_in_generate: + return DreamModelOutput( + sequences=x, + history=histories, + ) + else: + return x diff --git a/fast_llm/models/gpt/external/diffusion_llama/modeling_diffusion_llama.py b/fast_llm/models/gpt/external/diffusion_llama/modeling_diffusion_llama.py new file mode 100644 index 00000000..7e0bd797 --- /dev/null +++ b/fast_llm/models/gpt/external/diffusion_llama/modeling_diffusion_llama.py @@ -0,0 +1,1426 @@ +import math +import os +from dataclasses import dataclass + +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Optional, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers import PretrainedConfig +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.integrations import use_kernel_forward_from_hub +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs + +# from transformers.modeling_layers import GradientCheckpointingLayer # Update transformer +from transformers.modeling_outputs import BaseModelOutputWithPast, MaskedLMOutput +from transformers.modeling_rope_utils import dynamic_rope_update +from transformers.modeling_utils import PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( # auto_docstring + LossKwargs, + ModelOutput, + can_return_tuple, + is_torch_flex_attn_available, + logging, +) + +from .configuration_diffusion_llama import ROPE_INIT_FUNCTIONS, DiffusionLlamaConfig +from .generation_utils import DreamGenerationConfig, DreamGenerationMixin + +if is_torch_flex_attn_available(): + from flash_attn import flash_attn_with_kvcache + + +logger = logging.get_logger(__name__) + + +@use_kernel_forward_from_hub("RMSNorm") +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, config: DiffusionLlamaConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.integrations.sdpa_attention +def sdpa_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + # is_causal: Optional[bool] = None, + **kwargs, +) -> tuple[torch.Tensor, None]: + if hasattr(module, "num_key_value_groups"): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) + + # Note: Updates from Dream + # causal_mask = attention_mask + # if attention_mask is not None: + # causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions + # Reference: https://github.com/pytorch/pytorch/issues/112577. + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # Note: Updates from Dream + # if is_causal is None: + # is_causal = causal_mask is None and query.shape[2] > 1 + + # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. + # We convert it to a bool for the SDPA kernel that only accepts bools. + # note: Updates from Dream + # if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor): + # is_causal = is_causal.item() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + # Note: Updates from Dream + # attn_mask=causal_mask, + attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None, + dropout_p=dropout, + scale=scaling, + # is_causal=is_causal, + is_causal=False, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, None + + +def sdpa_attention_from_dream_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + is_causal: Optional[bool] = False, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "DreamModel is using DreamSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # print(f"query_states {query_states.shape} {query_states}") + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + # is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +def flash_attention_from_dreamforward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + is_causal: Optional[bool] = False, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "DreamModel is using DreamFlashAttention, it does not support `output_attentions=True`. Falling back to the manual attention implementation, " + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # print(f"hidden_states: {hidden_states.shape} query_states {query_states.shape} key_states {key_states.shape} value_states {value_states.shape}") + # print(f"position_ids {position_ids} {position_ids.shape}") + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # if query_states.device.type == "cuda" and attention_mask is not None: + # query_states = query_states.contiguous() + # key_states = key_states.contiguous() + # value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + # is_causal = True if causal_mask is None and q_len > 1 else False + + # attn_output_sdpa = torch.nn.functional.scaled_dot_product_attention( + # query_states, + # key_states, + # value_states, + # attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None, + # dropout_p=self.attention_dropout if self.training else 0.0, + # is_causal=False, # hard coded + # ) + + # print(f"query_states {query_states.shape} key_states {key_states.shape} value_states {value_states.shape}") + + # replacing with flash attention + attn_output = flash_attn_with_kvcache( + # q dim (batch_size, seqlen, nheads, headdim) + q=query_states.transpose(1, 2).contiguous(), + k_cache=key_states.transpose(1, 2).contiguous(), + v_cache=value_states.transpose(1, 2).contiguous(), + causal=is_causal, # hard coded + softmax_scale=1.0 / math.sqrt(self.head_dim), + ) + + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DiffusionLlamaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + elif self.config._attn_implementation == "sdpa": + attention_interface = sdpa_attention_from_dream_forward + elif self.config._attn_implementation == "flash_attention": + attention_interface = flash_attention_from_dreamforward + else: + raise ValueError(f"Unsupported attention implementation: {self.config._attn_implementation}") + # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +# TODO: Update after transformer update: class LlamaDecoderLayer(GradientCheckpointingLayer): +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: DiffusionLlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + # When use_cache is True, outputs will have length: + # - 2 if output_attentions is False (hidden_states, present_key_value) + # - 3 if output_attentions is True (hidden_states, self_attn_weights, present_key_value) + # print(f"DreamDecoderLayer: outputs {len(outputs)}") + + return outputs + + +# @auto_docstring +class DiffusionLlamaPreTrainedModel(PreTrainedModel): + config_class = DiffusionLlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = False + _supports_sdpa = False # TODO: Enable sdpa + _supports_flex_attn = False + _supports_cache_class = True + _supports_quantized_cache = False + _supports_static_cache = True + _supports_attention_backend = False + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, LlamaRMSNorm): + module.weight.data.fill_(1.0) + + # TODO: Copied from Dream Update + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + use_safetensors: Optional[bool] = None, + weights_only: bool = True, + **kwargs, + ): + _model = super().from_pretrained( + pretrained_model_name_or_path, + *model_args, + config=config, + cache_dir=cache_dir, + ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + use_safetensors=use_safetensors, + weights_only=weights_only, + **kwargs, + ) + # NOTE(Lin): we need to override the generation config + # because the generation config loaded in `from_pretrained` + # does not include all the attributes of DreamGenerationConfig + resume_download = kwargs.get("resume_download", None) + proxies = kwargs.get("proxies", None) + subfolder = kwargs.get("subfolder", "") + from_auto_class = kwargs.get("_from_auto", False) + from_pipeline = kwargs.get("_from_pipeline", None) + _model.generation_config = DreamGenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + ) + return _model + + +# @auto_docstring +class DiffusionLlamaBaseModel(DiffusionLlamaPreTrainedModel): + def __init__(self, config: DiffusionLlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + # @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # TODO: Fix attention mask for diffusion + # causal_mask = self._update_causal_mask( + # attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + # ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + # attention_mask=causal_mask, # TODO: Fix attention mask for diffusion + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if use_cache: + # print("past_key_values", past_key_values, "use_cache", use_cache, "layer_outputs", layer_outputs) + past_key_values = layer_outputs[-1] + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +@dataclass +class MaskedLMOutputWithPast(ModelOutput): + """ + Base class for masked language models outputs with past. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + past_key_values: Optional[tuple[Cache]] = None + + # TODO: Update for diffusion with bi-directional attention (later block casual masking) + # def _update_causal_mask( + # self, + # attention_mask: Union[torch.Tensor, "BlockMask"], + # input_tensor: torch.Tensor, + # cache_position: torch.Tensor, + # past_key_values: Cache, + # output_attentions: bool = False, + # ): + # if self.config._attn_implementation == "flash_attention_2": + # if attention_mask is not None and (attention_mask == 0.0).any(): + # return attention_mask + # return None + # if self.config._attn_implementation == "flex_attention": + # if isinstance(attention_mask, torch.Tensor): + # attention_mask = make_flex_block_causal_mask(attention_mask) + # return attention_mask + + # # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # # to infer the attention mask. + # past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + # using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + # if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + # if AttentionMaskConverter._ignore_causal_mask_sdpa( + # attention_mask, + # inputs_embeds=input_tensor, + # past_key_values_length=past_seen_tokens, + # is_training=self.training, + # ): + # return None + + # dtype = input_tensor.dtype + # sequence_length = input_tensor.shape[1] + # if using_compilable_cache: + # target_length = past_key_values.get_max_cache_shape() + # else: + # target_length = ( + # attention_mask.shape[-1] + # if isinstance(attention_mask, torch.Tensor) + # else past_seen_tokens + sequence_length + 1 + # ) + + # # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + # causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + # attention_mask, + # sequence_length=sequence_length, + # target_length=target_length, + # dtype=dtype, + # cache_position=cache_position, + # batch_size=input_tensor.shape[0], + # ) + + # if ( + # self.config._attn_implementation == "sdpa" + # and attention_mask is not None + # and attention_mask.device.type in ["cuda", "xpu", "npu"] + # and not output_attentions + # ): + # # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # # Details: https://github.com/pytorch/pytorch/issues/110213 + # min_dtype = torch.finfo(dtype).min + # causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + # return causal_mask + + # @staticmethod + # def _prepare_4d_causal_attention_mask_with_cache_position( + # attention_mask: torch.Tensor, + # sequence_length: int, + # target_length: int, + # dtype: torch.dtype, + # cache_position: torch.Tensor, + # batch_size: int, + # **kwargs, + # ): + # """ + # Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + # `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + # Args: + # attention_mask (`torch.Tensor`): + # A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + # `(batch_size, 1, query_length, key_value_length)`. + # sequence_length (`int`): + # The sequence length being processed. + # target_length (`int`): + # The target length: when generating with static cache, the mask should be as long as the static cache, + # to account for the 0 padding, the part of the cache that is not filled yet. + # dtype (`torch.dtype`): + # The dtype to use for the 4D attention mask. + # cache_position (`torch.Tensor`): + # Indices depicting the position of the input sequence tokens in the sequence. + # batch_size (`torch.Tensor`): + # Batch size. + # """ + # if attention_mask is not None and attention_mask.dim() == 4: + # # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + # causal_mask = attention_mask + # else: + # min_dtype = torch.finfo(dtype).min + # causal_mask = torch.full( + # (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + # ) + # if sequence_length != 1: + # causal_mask = torch.triu(causal_mask, diagonal=1) + # causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + # causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + # if attention_mask is not None: + # causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + # mask_length = attention_mask.shape[-1] + # padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + # causal_mask.device + # ) + # padding_mask = padding_mask == 0 + # causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + # padding_mask, min_dtype + # ) + + # return causal_mask + + +# class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +# @auto_docstring +class DiffusionLlamaModel(DiffusionLlamaPreTrainedModel, DreamGenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = DiffusionLlamaBaseModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + # @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, # TODO: Kwargs for Diffusion? : Unpack[KwargsForCausalLM], + ) -> MaskedLMOutput: + r""" + # TODO: Update docstring for diffusion + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + # is_casual=kwargs.get("is_casual", False), + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + # return MaskedLMOutput( + # loss=loss, + # logits=logits, + # hidden_states=outputs.hidden_states, + # attentions=outputs.attentions, + # ) + return MaskedLMOutputWithPast( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + past_key_values=outputs.past_key_values, + ) + + +# @auto_docstring +# class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): +# _tied_weights_keys = ["lm_head.weight"] +# _tp_plan = {"lm_head": "colwise_rep"} +# _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + +# def __init__(self, config): +# super().__init__(config) +# self.model = LlamaModel(config) +# self.vocab_size = config.vocab_size +# self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + +# # Initialize weights and apply final processing +# self.post_init() + +# def get_input_embeddings(self): +# return self.model.embed_tokens + +# def set_input_embeddings(self, value): +# self.model.embed_tokens = value + +# def get_output_embeddings(self): +# return self.lm_head + +# def set_output_embeddings(self, new_embeddings): +# self.lm_head = new_embeddings + +# def set_decoder(self, decoder): +# self.model = decoder + +# def get_decoder(self): +# return self.model + +# @can_return_tuple +# @auto_docstring +# def forward( +# self, +# input_ids: Optional[torch.LongTensor] = None, +# attention_mask: Optional[torch.Tensor] = None, +# position_ids: Optional[torch.LongTensor] = None, +# past_key_values: Optional[Cache] = None, +# inputs_embeds: Optional[torch.FloatTensor] = None, +# labels: Optional[torch.LongTensor] = None, +# use_cache: Optional[bool] = None, +# output_attentions: Optional[bool] = None, +# output_hidden_states: Optional[bool] = None, +# cache_position: Optional[torch.LongTensor] = None, +# logits_to_keep: Union[int, torch.Tensor] = 0, +# **kwargs: Unpack[KwargsForCausalLM], +# ) -> CausalLMOutputWithPast: +# r""" +# labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): +# Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., +# config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored +# (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + +# Example: + +# ```python +# >>> from transformers import AutoTokenizer, LlamaForCausalLM + +# >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") +# >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + +# >>> prompt = "Hey, are you conscious? Can you talk to me?" +# >>> inputs = tokenizer(prompt, return_tensors="pt") + +# >>> # Generate +# >>> generate_ids = model.generate(inputs.input_ids, max_length=30) +# >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] +# "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." +# ```""" +# output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions +# output_hidden_states = ( +# output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states +# ) + +# # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) +# outputs: BaseModelOutputWithPast = self.model( +# input_ids=input_ids, +# attention_mask=attention_mask, +# position_ids=position_ids, +# past_key_values=past_key_values, +# inputs_embeds=inputs_embeds, +# use_cache=use_cache, +# output_attentions=output_attentions, +# output_hidden_states=output_hidden_states, +# cache_position=cache_position, +# **kwargs, +# ) + +# hidden_states = outputs.last_hidden_state +# # Only compute necessary logits, and do not upcast them to float if we are not computing the loss +# slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep +# logits = self.lm_head(hidden_states[:, slice_indices, :]) + +# loss = None +# if labels is not None: +# loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + +# return CausalLMOutputWithPast( +# loss=loss, +# logits=logits, +# past_key_values=outputs.past_key_values, +# hidden_states=outputs.hidden_states, +# attentions=outputs.attentions, +# ) + + +# @auto_docstring( +# custom_intro=""" +# The LLaMa Model transformer with a sequence classification head on top (linear layer). + +# [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models +# (e.g. GPT-2) do. + +# Since it does classification on the last token, it requires to know the position of the last token. If a +# `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If +# no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the +# padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in +# each row of the batch). +# """ +# ) +# class LlamaForSequenceClassification(LlamaPreTrainedModel): +# def __init__(self, config): +# super().__init__(config) +# self.num_labels = config.num_labels +# self.model = LlamaModel(config) +# self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + +# # Initialize weights and apply final processing +# self.post_init() + +# def get_input_embeddings(self): +# return self.model.embed_tokens + +# def set_input_embeddings(self, value): +# self.model.embed_tokens = value + +# @can_return_tuple +# @auto_docstring +# def forward( +# self, +# input_ids: Optional[torch.LongTensor] = None, +# attention_mask: Optional[torch.Tensor] = None, +# position_ids: Optional[torch.LongTensor] = None, +# past_key_values: Optional[Cache] = None, +# inputs_embeds: Optional[torch.FloatTensor] = None, +# labels: Optional[torch.LongTensor] = None, +# use_cache: Optional[bool] = None, +# output_attentions: Optional[bool] = None, +# output_hidden_states: Optional[bool] = None, +# ) -> SequenceClassifierOutputWithPast: +# r""" +# labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): +# Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., +# config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If +# `config.num_labels > 1` a classification loss is computed (Cross-Entropy). +# """ + +# transformer_outputs: BaseModelOutputWithPast = self.model( +# input_ids, +# attention_mask=attention_mask, +# position_ids=position_ids, +# past_key_values=past_key_values, +# inputs_embeds=inputs_embeds, +# use_cache=use_cache, +# output_attentions=output_attentions, +# output_hidden_states=output_hidden_states, +# ) +# hidden_states = transformer_outputs.last_hidden_state +# logits = self.score(hidden_states) + +# if input_ids is not None: +# batch_size = input_ids.shape[0] +# else: +# batch_size = inputs_embeds.shape[0] + +# if self.config.pad_token_id is None and batch_size != 1: +# raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") +# if self.config.pad_token_id is None: +# last_non_pad_token = -1 +# elif input_ids is not None: +# # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id +# non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) +# token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) +# last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) +# else: +# last_non_pad_token = -1 +# logger.warning_once( +# f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " +# "unexpected if using padding tokens in conjunction with `inputs_embeds.`" +# ) + +# pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + +# loss = None +# if labels is not None: +# loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + +# return SequenceClassifierOutputWithPast( +# loss=loss, +# logits=pooled_logits, +# past_key_values=transformer_outputs.past_key_values, +# hidden_states=transformer_outputs.hidden_states, +# attentions=transformer_outputs.attentions, +# ) + + +# @auto_docstring +# class LlamaForQuestionAnswering(LlamaPreTrainedModel): +# base_model_prefix = "transformer" + +# # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama +# def __init__(self, config): +# super().__init__(config) +# self.transformer = LlamaModel(config) +# self.qa_outputs = nn.Linear(config.hidden_size, 2) + +# # Initialize weights and apply final processing +# self.post_init() + +# def get_input_embeddings(self): +# return self.transformer.embed_tokens + +# def set_input_embeddings(self, value): +# self.transformer.embed_tokens = value + +# @can_return_tuple +# @auto_docstring +# def forward( +# self, +# input_ids: Optional[torch.LongTensor] = None, +# attention_mask: Optional[torch.Tensor] = None, +# position_ids: Optional[torch.LongTensor] = None, +# past_key_values: Optional[Cache] = None, +# inputs_embeds: Optional[torch.FloatTensor] = None, +# start_positions: Optional[torch.LongTensor] = None, +# end_positions: Optional[torch.LongTensor] = None, +# output_attentions: Optional[bool] = None, +# output_hidden_states: Optional[bool] = None, +# **kwargs, +# ) -> QuestionAnsweringModelOutput: +# outputs: BaseModelOutputWithPast = self.transformer( +# input_ids, +# attention_mask=attention_mask, +# position_ids=position_ids, +# past_key_values=past_key_values, +# inputs_embeds=inputs_embeds, +# output_attentions=output_attentions, +# output_hidden_states=output_hidden_states, +# ) + +# sequence_output = outputs.last_hidden_state + +# logits = self.qa_outputs(sequence_output) +# start_logits, end_logits = logits.split(1, dim=-1) +# start_logits = start_logits.squeeze(-1).contiguous() +# end_logits = end_logits.squeeze(-1).contiguous() + +# loss = None +# if start_positions is not None and end_positions is not None: +# loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + +# return QuestionAnsweringModelOutput( +# loss=loss, +# start_logits=start_logits, +# end_logits=end_logits, +# hidden_states=outputs.hidden_states, +# attentions=outputs.attentions, +# ) + + +# @auto_docstring +# class LlamaForTokenClassification(LlamaPreTrainedModel): +# def __init__(self, config): +# super().__init__(config) +# self.num_labels = config.num_labels +# self.model = LlamaModel(config) +# if getattr(config, "classifier_dropout", None) is not None: +# classifier_dropout = config.classifier_dropout +# elif getattr(config, "hidden_dropout", None) is not None: +# classifier_dropout = config.hidden_dropout +# else: +# classifier_dropout = 0.1 +# self.dropout = nn.Dropout(classifier_dropout) +# self.score = nn.Linear(config.hidden_size, config.num_labels) + +# # Initialize weights and apply final processing +# self.post_init() + +# def get_input_embeddings(self): +# return self.model.embed_tokens + +# def set_input_embeddings(self, value): +# self.model.embed_tokens = value + +# @can_return_tuple +# @auto_docstring +# def forward( +# self, +# input_ids: Optional[torch.LongTensor] = None, +# attention_mask: Optional[torch.Tensor] = None, +# position_ids: Optional[torch.LongTensor] = None, +# past_key_values: Optional[Cache] = None, +# inputs_embeds: Optional[torch.FloatTensor] = None, +# labels: Optional[torch.LongTensor] = None, +# use_cache: Optional[bool] = None, +# output_attentions: Optional[bool] = None, +# output_hidden_states: Optional[bool] = None, +# ) -> TokenClassifierOutput: +# r""" +# labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): +# Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., +# config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If +# `config.num_labels > 1` a classification loss is computed (Cross-Entropy). +# """ + +# outputs: BaseModelOutputWithPast = self.model( +# input_ids, +# attention_mask=attention_mask, +# position_ids=position_ids, +# past_key_values=past_key_values, +# inputs_embeds=inputs_embeds, +# use_cache=use_cache, +# output_attentions=output_attentions, +# output_hidden_states=output_hidden_states, +# ) +# sequence_output = outputs.last_hidden_state +# sequence_output = self.dropout(sequence_output) +# logits = self.score(sequence_output) + +# loss = None +# if labels is not None: +# loss = self.loss_function(logits, labels, self.config) + +# return TokenClassifierOutput( +# loss=loss, +# logits=logits, +# hidden_states=outputs.hidden_states, +# attentions=outputs.attentions, +# ) + + +__all__ = [ + "DiffusionLlamaModel", +] diff --git a/tests/common.py b/tests/common.py index 992e070e..ad677381 100644 --- a/tests/common.py +++ b/tests/common.py @@ -17,6 +17,8 @@ from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.models.gpt.config import ( + DiffusionDreamGPTHuggingfaceCheckpointFormat, + DiffusionLlamaGPTHuggingfaceCheckpointFormat, LlamaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, @@ -281,6 +283,20 @@ CONFIG_LLAMA_MTP_COMMON, MTPLlamaGPTHuggingfaceCheckpointFormat, ), + "dream": ( + "gpt", + CONFIG_QWEN2_FAST_LLM, + CONFIG_QWEN2_MEGATRON, + CONFIG_QWEN2_COMMON, + DiffusionDreamGPTHuggingfaceCheckpointFormat, + ), + "diffusion_llama": ( + "gpt", + CONFIG_LLAMA_YARN_FAST_LLM, + CONFIG_LLAMA_YARN_MEGATRON, + CONFIG_LLAMA_YARN_COMMON, + DiffusionLlamaGPTHuggingfaceCheckpointFormat, + ), } TEST_MODEL_TYPE, CONFIG_FAST_LLM, CONFIG_GPT2, CONFIG_COMMON, HUGGINGFACE_CHECKPOINT_FORMAT = _CONFIGS[TEST_MODEL] diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 481a7016..458b8f36 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -406,7 +406,7 @@ def test_run_converted_model(): ) errors = [] compare = CompareConfig() - model_as_hf = transformers.AutoModelForCausalLM.from_pretrained( + model_as_hf = transformers.AutoModel.from_pretrained( CONVERT_PATH / "huggingface_0", trust_remote_code=HUGGINGFACE_CHECKPOINT_FORMAT.trust_remote_code ).cuda() for name, model in zip(