diff --git a/docs/source/en/api/pipelines/block_refinement.md b/docs/source/en/api/pipelines/block_refinement.md new file mode 100644 index 000000000000..96e2f64e8660 --- /dev/null +++ b/docs/source/en/api/pipelines/block_refinement.md @@ -0,0 +1,63 @@ + + +# Block Refinement + +`BlockRefinementPipeline` performs block-wise iterative refinement over a masked token template, sampling and +committing tokens based on confidence. + +## Config defaults + +You can set default sampling parameters when creating the pipeline. Passing `None` for a parameter in `__call__` +falls back to `pipe.config`. + +```py +from diffusers import BlockRefinementPipeline + +pipe = BlockRefinementPipeline( + model=model, + tokenizer=tokenizer, + gen_length=256, + block_length=32, + steps=16, + temperature=0.8, + sampling_method="multinomial", +) + +out = pipe(prompt="Explain gradient descent.") +print(out.texts[0]) +``` + +## Callbacks + +Callbacks run after each refinement step and can inspect or override the current tokens. + +```py +def on_step_end(pipe, step, timestep, callback_kwargs): + cur_x = callback_kwargs["cur_x"] + # Inspect or modify `cur_x` here. + return {"cur_x": cur_x} + +out = pipe( + prompt="Write a short poem.", + callback_on_step_end=on_step_end, + callback_on_step_end_tensor_inputs=["cur_x"], +) +``` + +## BlockRefinementPipeline +[[autodoc]] BlockRefinementPipeline + - all + - __call__ + +## BlockRefinementPipelineOutput +[[autodoc]] pipelines.BlockRefinementPipelineOutput diff --git a/docs/source/en/api/pipelines/block_token_diffusion.md b/docs/source/en/api/pipelines/block_token_diffusion.md new file mode 100644 index 000000000000..df25001ffa03 --- /dev/null +++ b/docs/source/en/api/pipelines/block_token_diffusion.md @@ -0,0 +1,23 @@ + + +# Block Token Diffusion + +`BlockTokenDiffusionPipeline` performs token diffusion by iterating over fixed-size blocks of the sequence. + +## BlockTokenDiffusionPipeline +[[autodoc]] BlockTokenDiffusionPipeline + - all + - __call__ + +## BlockTokenDiffusionPipelineOutput +[[autodoc]] pipelines.BlockTokenDiffusionPipelineOutput diff --git a/docs/source/en/api/pipelines/hybrid_token_diffusion.md b/docs/source/en/api/pipelines/hybrid_token_diffusion.md new file mode 100644 index 000000000000..56ccd61bfbc8 --- /dev/null +++ b/docs/source/en/api/pipelines/hybrid_token_diffusion.md @@ -0,0 +1,23 @@ + + +# Hybrid Token Diffusion + +`HybridTokenDiffusionPipeline` is an alias of `TokenDiffusionPipeline` for hybrid-transition schedulers. + +## HybridTokenDiffusionPipeline +[[autodoc]] HybridTokenDiffusionPipeline + - all + - __call__ + +## TokenDiffusionPipelineOutput +[[autodoc]] pipelines.TokenDiffusionPipelineOutput diff --git a/docs/source/en/api/pipelines/llada2.md b/docs/source/en/api/pipelines/llada2.md new file mode 100644 index 000000000000..ba9330e4f5b3 --- /dev/null +++ b/docs/source/en/api/pipelines/llada2.md @@ -0,0 +1,23 @@ + + +# LLaDA2 + +`LLaDA2Pipeline` adapts block refinement sampling for LLaDA2-style token diffusion models. 
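A minimal sketch of running the pipeline with a LLaDA2 checkpoint loaded through `transformers`, following `examples/discrete_diffusion/sample_llada2.py` (the model ID and generation defaults mirror that script):

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from diffusers import LLaDA2Pipeline

model_id = "inclusionAI/LLaDA2.0-mini"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"
)

pipe = LLaDA2Pipeline(model=model, tokenizer=tokenizer)
out = pipe(
    prompt="Why does Camus think that Sisyphus is happy?",
    gen_length=2048,
    block_length=32,
    steps=32,
)
print(out.texts[0])
```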
+ +## LLaDA2Pipeline +[[autodoc]] LLaDA2Pipeline + - all + - __call__ + +## LLaDA2PipelineOutput +[[autodoc]] pipelines.LLaDA2PipelineOutput diff --git a/docs/source/en/api/pipelines/overview.md b/docs/source/en/api/pipelines/overview.md index 22fcf560eaca..4124761f86a9 100644 --- a/docs/source/en/api/pipelines/overview.md +++ b/docs/source/en/api/pipelines/overview.md @@ -34,6 +34,8 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an | [AudioLDM2](audioldm2) | text2audio | | [AuraFlow](aura_flow) | text2image | | [BLIP Diffusion](blip_diffusion) | text2image | +| [Block Refinement](block_refinement) | text2text | +| [Block Token Diffusion](block_token_diffusion) | text2text | | [Bria 3.2](bria_3_2) | text2image | | [CogVideoX](cogvideox) | text2video | | [Consistency Models](consistency_models) | unconditional image generation | @@ -47,11 +49,14 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an | [Dance Diffusion](dance_diffusion) | unconditional audio generation | | [DDIM](ddim) | unconditional image generation | | [DDPM](ddpm) | unconditional image generation | +| [DFlash](dflash) | text2text | +| [SDAR](sdar) | text2text | | [DeepFloyd IF](deepfloyd_if) | text2image, image2image, inpainting, super-resolution | | [DiffEdit](diffedit) | inpainting | | [DiT](dit) | text2image | | [Flux](flux) | text2image | | [Hunyuan-DiT](hunyuandit) | text2image | +| [Hybrid Token Diffusion](hybrid_token_diffusion) | text2text | | [I2VGen-XL](i2vgenxl) | image2video | | [InstructPix2Pix](pix2pix) | image editing | | [Kandinsky 2.1](kandinsky) | text2image, image2image, inpainting, interpolation | @@ -62,6 +67,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an | [Latent Diffusion](latent_diffusion) | text2image, super-resolution | | [Latte](latte) | text2image | | [LEDITS++](ledits_pp) | image editing | +| [LLaDA2](llada2) | text2text | | [Lumina-T2X](lumina) | text2image | | [Marigold](marigold) | depth-estimation, normals-estimation, intrinsic-decomposition | | [MultiDiffusion](panorama) | text2image | @@ -83,6 +89,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an | [T2I-Adapter](stable_diffusion/adapter) | text2image | | [Text2Video](text_to_video) | text2video, video2video | | [Text2Video-Zero](text_to_video_zero) | text2video | +| [Token Diffusion](token_diffusion) | text2text | | [unCLIP](unclip) | text2image, image variation | | [UniDiffuser](unidiffuser) | text2image, image2text, image variation, text variation, unconditional image generation, unconditional audio generation | | [Value-guided planning](value_guided_sampling) | value guided sampling | diff --git a/docs/source/en/api/pipelines/token_diffusion.md b/docs/source/en/api/pipelines/token_diffusion.md new file mode 100644 index 000000000000..c105f53abece --- /dev/null +++ b/docs/source/en/api/pipelines/token_diffusion.md @@ -0,0 +1,24 @@ + + +# Token Diffusion + +`TokenDiffusionPipeline` provides a generic token-space diffusion sampler for discrete denoising over token IDs. It +pairs a token denoiser model with a token diffusion scheduler. 
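A minimal sketch, assuming the pipeline takes a masked-LM denoiser, a `TokenDiffusionScheduler`, and a tokenizer, and accepts the same `batch_size`/`seq_len`/`num_inference_steps` call arguments as the sampling scripts in `examples/discrete_diffusion`:

```py
from transformers import AutoModelForMaskedLM, AutoTokenizer

from diffusers import TokenDiffusionPipeline, TokenDiffusionScheduler

# A checkpoint saved by examples/discrete_diffusion/train_mdlm.py (model + tokenizer + scheduler config).
checkpoint = "mdlm-output/final"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForMaskedLM.from_pretrained(checkpoint)
scheduler = TokenDiffusionScheduler.from_pretrained(checkpoint)

pipe = TokenDiffusionPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
out = pipe(batch_size=4, seq_len=64, num_inference_steps=128, return_text=True)
print(out.texts[0])
```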
+ +## TokenDiffusionPipeline +[[autodoc]] TokenDiffusionPipeline + - all + - __call__ + +## TokenDiffusionPipelineOutput +[[autodoc]] pipelines.TokenDiffusionPipelineOutput diff --git a/docs/source/en/api/schedulers/block_token_diffusion.md b/docs/source/en/api/schedulers/block_token_diffusion.md new file mode 100644 index 000000000000..2ad92117fb0f --- /dev/null +++ b/docs/source/en/api/schedulers/block_token_diffusion.md @@ -0,0 +1,21 @@ + + +# BlockTokenDiffusionScheduler + +`BlockTokenDiffusionScheduler` extends `TokenDiffusionScheduler` with block-wise updates over token positions. + +## BlockTokenDiffusionScheduler +[[autodoc]] BlockTokenDiffusionScheduler + +## TokenDiffusionSchedulerOutput +[[autodoc]] schedulers.scheduling_token_diffusion.TokenDiffusionSchedulerOutput diff --git a/docs/source/en/api/schedulers/hybrid_token_diffusion.md b/docs/source/en/api/schedulers/hybrid_token_diffusion.md new file mode 100644 index 000000000000..4dcdda0ea49c --- /dev/null +++ b/docs/source/en/api/schedulers/hybrid_token_diffusion.md @@ -0,0 +1,22 @@ + + +# HybridTokenDiffusionScheduler + +`HybridTokenDiffusionScheduler` defines hybrid discrete token diffusion updates with separate transitions for +masked and unmasked tokens. + +## HybridTokenDiffusionScheduler +[[autodoc]] HybridTokenDiffusionScheduler + +## HybridTokenDiffusionSchedulerOutput +[[autodoc]] schedulers.scheduling_hybrid_token_diffusion.HybridTokenDiffusionSchedulerOutput diff --git a/docs/source/en/api/schedulers/overview.md b/docs/source/en/api/schedulers/overview.md index a57e99a3e46e..2dc8feae6964 100644 --- a/docs/source/en/api/schedulers/overview.md +++ b/docs/source/en/api/schedulers/overview.md @@ -54,6 +54,28 @@ Many schedulers are implemented from the [k-diffusion](https://github.com/crowso | exponential | init with `timestep_spacing="linspace"`, `use_exponential_sigmas=True` | | beta | init with `timestep_spacing="linspace"`, `use_beta_sigmas=True` | +## Token diffusion schedulers + +These schedulers operate over categorical token IDs instead of continuous latents. They are designed for discrete +token diffusion models and expose the same `set_timesteps`/`step` interface as other schedulers. + +Differences between the discrete token schedulers: +- `TokenDiffusionScheduler`: token-level diffusion with per-token corruption (e.g. mask/uniform) and a single-step `step` to denoise logits. +- `BlockTokenDiffusionScheduler`: block-wise token diffusion that updates fixed-size blocks in parallel. +- `HybridTokenDiffusionScheduler`: hybrid transitions that combine token- and block-wise updates in the same schedule. +- `DFlashTokenDiffusionScheduler`: block diffusion scheduler specialized for speculative decoding with a draft model and target acceptance. +- `SDARTokenDiffusionScheduler`: block diffusion scheduler with remasking strategies (sequential/low-confidence/entropy-bounded) per step. + +[[autodoc]] TokenDiffusionScheduler + +[[autodoc]] BlockTokenDiffusionScheduler + +[[autodoc]] HybridTokenDiffusionScheduler + +[[autodoc]] DFlashTokenDiffusionScheduler + +[[autodoc]] SDARTokenDiffusionScheduler + All schedulers are built from the base [`SchedulerMixin`] class which implements low level utilities shared by all schedulers. 
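For the token diffusion schedulers above, the denoising loop looks the same as for continuous schedulers except that the running sample is a tensor of token IDs. A sketch following `examples/discrete_diffusion/sample_mdlm.py`, assuming `model` is a masked-LM denoiser and `scheduler` a [`TokenDiffusionScheduler`]:

```py
import torch

scheduler.set_timesteps(128, device="cuda")

# Start from an all-mask sequence of token IDs.
x = torch.full((4, 64), scheduler.mask_token_id, dtype=torch.long, device="cuda")

for t in scheduler.timesteps:
    logits = model(input_ids=x).logits  # [batch, seq_len, vocab]
    x = scheduler.step(logits, t, x, return_dict=True).prev_sample  # token IDs for the next step
```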
## SchedulerMixin diff --git a/docs/source/en/api/schedulers/token_diffusion.md b/docs/source/en/api/schedulers/token_diffusion.md new file mode 100644 index 000000000000..fe5305c00ae5 --- /dev/null +++ b/docs/source/en/api/schedulers/token_diffusion.md @@ -0,0 +1,22 @@ + + +# TokenDiffusionScheduler + +`TokenDiffusionScheduler` defines discrete token diffusion updates over categorical token IDs and supports multiple +forward processes and alpha schedules. + +## TokenDiffusionScheduler +[[autodoc]] TokenDiffusionScheduler + +## TokenDiffusionSchedulerOutput +[[autodoc]] schedulers.scheduling_token_diffusion.TokenDiffusionSchedulerOutput diff --git a/examples/discrete_diffusion/README.md b/examples/discrete_diffusion/README.md new file mode 100644 index 000000000000..9da3d084e846 --- /dev/null +++ b/examples/discrete_diffusion/README.md @@ -0,0 +1,201 @@ +# Discrete Token Diffusion (Experimental) + +This folder contains **training examples** for *discrete diffusion over token IDs* (language-model style), built to follow the `diffusers` + `accelerate` training conventions. + +## Quickstart: block refinement with Qwen (causal LM) + +If you want a causal-LM example, start here. This trains block refinement with a CAP-style confidence loss. + +```bash +accelerate launch examples/discrete_diffusion/train_block_refinement_qwen_cap.py \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --text_column text \ + --output_dir qwen-block-refinement-output \ + --max_train_steps 1000 \ + --prompt_length 32 \ + --block_length 32 \ + --lambda_conf 2.0 \ + --conf_temperature 0.5 +``` + +## MDLM-style absorbing diffusion + +`train_mdlm.py` trains a masked/absorbing discrete diffusion model: +- Forward process: with probability `1 - alpha(t)`, replace tokens with `mask_token_id` +- Noise schedule: log-linear `alpha(t) = 1 - (1 - eps) * t` +- Loss: weighted token reconstruction NLL over masked positions + +### Run + +```bash +accelerate launch examples/discrete_diffusion/train_mdlm.py \ + --model_name_or_path bert-base-uncased \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --output_dir mdlm-output \ + --max_train_steps 1000 \ + --lambda_conf 0.0 \ + --conf_temperature 1.0 +``` + +The script saves: +- `transformers` model + tokenizer +- `diffusers.TokenDiffusionScheduler` + +into `--output_dir` checkpoints and `--output_dir/final`. + +### Sample + +```bash +python examples/discrete_diffusion/sample_mdlm.py \ + --checkpoint_path mdlm-output/final \ + --num_samples 4 \ + --seq_len 64 \ + --num_inference_steps 128 +``` + +## Block-wise sampling + +Block-wise sampling updates the sequence in chunks, refining only the active block at a time. + +```bash +python examples/discrete_diffusion/sample_block_token_diffusion.py \ + --checkpoint_path mdlm-output/final \ + --num_samples 4 \ + --seq_len 256 \ + --block_size 32 \ + --num_inference_steps 64 \ + --top_p 0.9 +``` + +## Block refinement (commit-by-confidence) with Qwen + +For causal LMs that only support a 2D `attention_mask`, run `BlockRefinementPipeline` with `--attention_mask_mode 2d`. 
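The flag maps to the pipeline's `attention_mask_mode` argument. A minimal Python sketch (mirroring `sample_block_refinement.py`, with the output directory produced by the training command below):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from diffusers import BlockRefinementPipeline

checkpoint = "qwen-block-refinement-output/final"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

pipe = BlockRefinementPipeline(model=model, tokenizer=tokenizer)
prompt_ids = tokenizer("Write a short paragraph about diffusion models.", return_tensors="pt")["input_ids"]

out = pipe(
    prompt_ids=prompt_ids,
    gen_length=128,
    block_length=32,
    steps=32,
    mask_token_id=tokenizer.mask_token_id,
    attention_mask_mode="2d",  # for causal LMs that only accept a 2D attention mask
    return_text=True,
)
print(out.texts[0])
```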
+ +### Train + +```bash +accelerate launch examples/discrete_diffusion/train_block_refinement_qwen_cap.py \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --text_column text \ + --output_dir qwen-block-refinement-output \ + --max_train_steps 1000 \ + --prompt_length 32 \ + --block_length 32 \ + --lambda_conf 2.0 \ + --conf_temperature 0.5 +``` + +If you don't want to download a dataset, you can use random-token data: + +```bash +accelerate launch examples/discrete_diffusion/train_block_refinement_qwen_cap.py \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --output_dir qwen-block-refinement-output \ + --use_dummy_data \ + --num_dummy_samples 2048 +``` + +### Sample + +```bash +python examples/discrete_diffusion/sample_block_refinement.py \ + --checkpoint_path qwen-block-refinement-output/final \ + --device cuda \ + --attention_mask_mode 2d \ + --prompt "Write a short paragraph about diffusion models." \ + --gen_length 128 +``` + +## DFlash speculative decoding + +Use a diffusion draft model with a target causal LM for block-wise speculative decoding. + +```bash +python examples/discrete_diffusion/sample_dflash.py \ + --draft_model_id z-lab/Qwen3-8B-DFlash-b16 \ + --target_model_id Qwen/Qwen3-8B \ + --prompt "How many positive whole-number divisors does 196 have?" \ + --max_new_tokens 256 \ + --use_chat_template \ + --add_generation_prompt +``` + +## SDAR block diffusion decoding + +Run SDAR-style block diffusion sampling with remasking strategies. + +```bash +python examples/discrete_diffusion/sample_sdar.py \ + --model_id JetLM/SDAR-1.7B-Chat \ + --prompt "Explain what reinforcement learning is in simple terms." \ + --max_new_tokens 256 \ + --block_length 4 \ + --denoising_steps 4 \ + --remasking_strategy low_confidence_dynamic \ + --confidence_threshold 0.9 \ + --use_chat_template \ + --add_generation_prompt +``` + +### Fine-tune (draft model) + +```bash +accelerate launch examples/discrete_diffusion/train_dflash.py \ + --draft_model_id z-lab/Qwen3-4B-DFlash-b16 \ + --target_model_id Qwen/Qwen3-4B \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --output_dir dflash-output \ + --max_train_steps 100 \ + --logging_steps 10 +``` + +## Hybrid sampling + +Hybrid sampling uses a different transition kernel than absorbing/uniform diffusion and requires a compatible scheduler +configuration saved in the checkpoint directory. 
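Concretely, the sampling script reconstructs the scheduler with `from_pretrained` on the checkpoint directory, so the scheduler config must be saved alongside the model and tokenizer. A sketch of the loading step performed by the command below:

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer

from diffusers import HybridTokenDiffusionPipeline, HybridTokenDiffusionScheduler

checkpoint = "hybrid-output/final"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForMaskedLM.from_pretrained(checkpoint)
scheduler = HybridTokenDiffusionScheduler.from_pretrained(checkpoint)  # reads the saved scheduler config

pipe = HybridTokenDiffusionPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
```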
+ +```bash +python examples/discrete_diffusion/sample_hybrid_token_diffusion.py \ + --checkpoint_path hybrid-output/final \ + --num_samples 4 \ + --seq_len 256 \ + --num_inference_steps 64 +``` + +### Train + +```bash +accelerate launch examples/discrete_diffusion/train_hybrid_token_diffusion.py \ + --model_name_or_path bert-base-uncased \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --output_dir hybrid-output \ + --max_train_steps 1000 \ + --lambda_conf 0.0 \ + --conf_temperature 1.0 +``` + +## UDLM-style uniform diffusion + +`train_udlm.py` trains a uniform token diffusion model: +- Forward process: with probability `1 - alpha(t)`, replace tokens with a uniform random token +- Noise schedule: configurable via `--alpha_schedule` (`log_linear`, `linear`, `cosine`, `geometric`) +- Loss: diffusion loss for uniform token diffusion + +### Run + +```bash +accelerate launch examples/discrete_diffusion/train_udlm.py \ + --model_name_or_path bert-base-uncased \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --output_dir udlm-output \ + --max_train_steps 1000 \ + --exclude_mask_from_uniform +``` diff --git a/examples/discrete_diffusion/sample_block_refinement.py b/examples/discrete_diffusion/sample_block_refinement.py new file mode 100644 index 000000000000..72e728d876f8 --- /dev/null +++ b/examples/discrete_diffusion/sample_block_refinement.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python + +import argparse + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from diffusers import BlockRefinementPipeline + + +def main(): + parser = argparse.ArgumentParser(description="Sample with BlockRefinementPipeline using a transformers causal LM.") + parser.add_argument("--checkpoint_path", type=str, required=True) + parser.add_argument("--cache_dir", type=str, default=None) + parser.add_argument("--prompt", type=str, default="Write a short paragraph about diffusion models.") + parser.add_argument("--gen_length", type=int, default=128) + parser.add_argument("--block_length", type=int, default=32) + parser.add_argument("--steps", type=int, default=32) + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--top_p", type=float, default=1.0) + parser.add_argument("--top_k", type=int, default=0) + parser.add_argument("--threshold", type=float, default=0.95) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + parser.add_argument("--attention_mask_mode", type=str, default="2d", choices=["auto", "4d", "2d", "none"]) + + args = parser.parse_args() + + tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, use_fast=True, cache_dir=args.cache_dir) + model = AutoModelForCausalLM.from_pretrained( + args.checkpoint_path, + torch_dtype=torch.bfloat16 if args.device.startswith("cuda") else torch.float32, + cache_dir=args.cache_dir, + ) + model.to(args.device) + model.eval() + + if tokenizer.mask_token_id is None: + raise ValueError("Tokenizer must have `mask_token_id` for block refinement sampling.") + + pipe = BlockRefinementPipeline(model=model, tokenizer=tokenizer).to(args.device) + gen = torch.Generator(device=args.device).manual_seed(args.seed) + + prompt_ids = tokenizer(args.prompt, return_tensors="pt")["input_ids"].to(args.device) + out = pipe( + prompt_ids=prompt_ids, + gen_length=int(args.gen_length), + block_length=int(args.block_length), + steps=int(args.steps), + 
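        # Sampling controls: temperature=0.0 is greedy; top_p/top_k are normalized to None
        # below so the CLI defaults (1.0 and 0) disable nucleus and top-k filtering.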
temperature=float(args.temperature), + top_p=None if args.top_p >= 1.0 else float(args.top_p), + top_k=None if args.top_k <= 0 else int(args.top_k), + threshold=float(args.threshold), + eos_early_stop=True, + eos_token_id=int(tokenizer.eos_token_id) if tokenizer.eos_token_id is not None else None, + mask_token_id=int(tokenizer.mask_token_id), + attention_mask_mode=args.attention_mask_mode, + generator=gen, + return_text=True, + ) + + print(out.texts[0] if out.texts is not None else tokenizer.decode(out.sequences[0], skip_special_tokens=True)) + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/sample_block_token_diffusion.py b/examples/discrete_diffusion/sample_block_token_diffusion.py new file mode 100644 index 000000000000..5fb7c5a3d121 --- /dev/null +++ b/examples/discrete_diffusion/sample_block_token_diffusion.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python + +import argparse +from typing import Optional + +import torch +from transformers import AutoModelForMaskedLM, AutoTokenizer + +from diffusers import BlockTokenDiffusionPipeline, BlockTokenDiffusionScheduler + + +def parse_args(): + parser = argparse.ArgumentParser(description="Sample with block-wise token diffusion.") + parser.add_argument( + "--checkpoint_path", type=str, required=True, help="Path saved by train scripts (or compatible)." + ) + parser.add_argument("--prompt", type=str, default=None, help="Optional prompt; will be used as a fixed prefix.") + parser.add_argument("--num_samples", type=int, default=4) + parser.add_argument("--seq_len", type=int, default=64) + parser.add_argument("--block_size", type=int, default=16) + parser.add_argument("--num_inference_steps", type=int, default=32) + parser.add_argument("--top_p", type=float, default=1.0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + parser.add_argument("--inject_start_token", action="store_true") + return parser.parse_args() + + +def main(): + args = parse_args() + device = torch.device(args.device) + + tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, use_fast=True) + model = AutoModelForMaskedLM.from_pretrained(args.checkpoint_path).to(device) + scheduler = BlockTokenDiffusionScheduler.from_pretrained(args.checkpoint_path) + + pipe = BlockTokenDiffusionPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer).to(device) + model.eval() + + generator: Optional[torch.Generator] = torch.Generator(device=device).manual_seed(args.seed) + + prefix_ids = None + if args.prompt is not None: + encoded = tokenizer(args.prompt, return_tensors="pt", add_special_tokens=True) + prefix_ids = encoded["input_ids"].to(device=device, dtype=torch.long) + if prefix_ids.shape[1] > args.seq_len: + raise ValueError(f"--seq_len ({args.seq_len}) must be >= prompt length ({prefix_ids.shape[1]}).") + + out = pipe( + batch_size=args.num_samples, + seq_len=args.seq_len, + block_size=args.block_size, + num_inference_steps=args.num_inference_steps, + generator=generator, + prefix_ids=prefix_ids, + inject_start_token=args.inject_start_token, + top_p=args.top_p, + return_text=True, + ) + + for i, t in enumerate(out.texts or []): + print(f"[{i}] {t}") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/sample_dflash.py b/examples/discrete_diffusion/sample_dflash.py new file mode 100644 index 000000000000..2c172dc04464 --- /dev/null +++ b/examples/discrete_diffusion/sample_dflash.py @@ -0,0 +1,144 @@ +#!/usr/bin/env 
python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Sample script for DFlash speculative decoding. + +Example: + python sample_dflash.py \ + --draft_model_id z-lab/Qwen3-8B-DFlash-b16 \ + --target_model_id Qwen/Qwen3-8B \ + --prompt "How many positive whole-number divisors does 196 have?" \ + --max_new_tokens 256 +""" + +import argparse + +import torch + +from diffusers import DFlashPipeline + + +def main(): + parser = argparse.ArgumentParser(description="Run DFlash speculative decoding.") + parser.add_argument( + "--draft_model_id", + type=str, + default="z-lab/Qwen3-8B-DFlash-b16", + help="Draft model ID or local path.", + ) + parser.add_argument( + "--target_model_id", + type=str, + default="Qwen/Qwen3-8B", + help="Target model ID or local path.", + ) + parser.add_argument( + "--prompt", + type=str, + default="How many positive whole-number divisors does 196 have?", + help="Prompt text to generate from.", + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=2048, + help="Maximum number of new tokens to generate.", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.0, + help="Sampling temperature.", + ) + parser.add_argument( + "--use_chat_template", + action="store_true", + help="Use the tokenizer chat template for the prompt.", + ) + parser.add_argument( + "--add_generation_prompt", + action="store_true", + help="Add the generation prompt when using the chat template.", + ) + parser.add_argument( + "--enable_thinking", + action="store_true", + help="Enable chat-template thinking mode if supported by the tokenizer.", + ) + parser.add_argument( + "--mask_token", + type=str, + default="<|MASK|>", + help="Mask token to add if the tokenizer does not define one.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to run inference on.", + ) + parser.add_argument( + "--dtype", + type=str, + default="auto", + choices=["auto", "float32", "float16", "bfloat16"], + help="Model dtype.", + ) + + args = parser.parse_args() + + dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} + torch_dtype = dtype_map.get(args.dtype) + + print(f"Loading draft model: {args.draft_model_id}") + print(f"Loading target model: {args.target_model_id}") + dtype_arg = torch_dtype if torch_dtype is not None else "auto" + pipe = DFlashPipeline.from_pretrained( + draft_model_id=args.draft_model_id, + target_model_id=args.target_model_id, + mask_token=args.mask_token, + draft_model_kwargs={ + "trust_remote_code": True, + "dtype": dtype_arg, + "device_map": args.device, + }, + target_model_kwargs={ + "dtype": dtype_arg, + "device_map": args.device, + }, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + ) + + chat_kwargs = {"enable_thinking": args.enable_thinking} + + print(f"\nPrompt: {args.prompt}") + output = pipe( + prompt=args.prompt, + max_new_tokens=args.max_new_tokens, + 
temperature=args.temperature, + use_chat_template=args.use_chat_template, + add_generation_prompt=args.add_generation_prompt, + chat_template_kwargs=chat_kwargs, + ) + + print("\nGenerated text:") + print(output.texts[0]) + print(f"\nGenerated {output.sequences.shape[1]} tokens") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/sample_hybrid_token_diffusion.py b/examples/discrete_diffusion/sample_hybrid_token_diffusion.py new file mode 100644 index 000000000000..81f35ae5b9c6 --- /dev/null +++ b/examples/discrete_diffusion/sample_hybrid_token_diffusion.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python + +import argparse +from typing import Optional + +import torch +from transformers import AutoModelForMaskedLM, AutoTokenizer + +from diffusers import HybridTokenDiffusionPipeline, HybridTokenDiffusionScheduler + + +def parse_args(): + parser = argparse.ArgumentParser(description="Sample with a hybrid-transition token diffusion scheduler.") + parser.add_argument( + "--checkpoint_path", type=str, required=True, help="Path containing a model + scheduler config." + ) + parser.add_argument("--prompt", type=str, default=None, help="Optional prompt; will be used as a fixed prefix.") + parser.add_argument("--num_samples", type=int, default=4) + parser.add_argument("--seq_len", type=int, default=64) + parser.add_argument("--num_inference_steps", type=int, default=64) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + parser.add_argument("--inject_start_token", action="store_true") + return parser.parse_args() + + +def main(): + args = parse_args() + device = torch.device(args.device) + + tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, use_fast=True) + model = AutoModelForMaskedLM.from_pretrained(args.checkpoint_path).to(device) + scheduler = HybridTokenDiffusionScheduler.from_pretrained(args.checkpoint_path) + + pipe = HybridTokenDiffusionPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer).to(device) + model.eval() + + generator: Optional[torch.Generator] = torch.Generator(device=device).manual_seed(args.seed) + + prefix_ids = None + if args.prompt is not None: + encoded = tokenizer(args.prompt, return_tensors="pt", add_special_tokens=True) + prefix_ids = encoded["input_ids"].to(device=device, dtype=torch.long) + if prefix_ids.shape[1] > args.seq_len: + raise ValueError(f"--seq_len ({args.seq_len}) must be >= prompt length ({prefix_ids.shape[1]}).") + + out = pipe( + batch_size=args.num_samples, + seq_len=args.seq_len, + num_inference_steps=args.num_inference_steps, + generator=generator, + prefix_ids=prefix_ids, + inject_start_token=args.inject_start_token, + return_text=True, + ) + + for i, t in enumerate(out.texts or []): + print(f"[{i}] {t}") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/sample_llada2.py b/examples/discrete_diffusion/sample_llada2.py new file mode 100644 index 000000000000..3e0fd0434008 --- /dev/null +++ b/examples/discrete_diffusion/sample_llada2.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Sample script for LLaDA2-style discrete diffusion text generation. + +This script demonstrates how to use the LLaDA2Pipeline for text generation +using block-wise iterative refinement. + +Example usage: + python sample_llada2.py --model_id inclusionAI/LLaDA2.0-mini --prompt "What is the capital of France?" + python sample_llada2.py --model_id inclusionAI/LLaDA2.0-flash-CAP --prompt "Explain quantum computing." --temperature 0.7 +""" + +import argparse + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from diffusers import LLaDA2Pipeline +from diffusers.hooks import apply_group_offloading + + +def main(): + parser = argparse.ArgumentParser( + description="Generate text using LLaDA2Pipeline with block-wise discrete diffusion." + ) + parser.add_argument( + "--model_id", + type=str, + default="inclusionAI/LLaDA2.0-mini", + help="HuggingFace model ID or path to local model.", + ) + parser.add_argument( + "--prompt", + type=str, + default="Why does Camus think that Sisyphus is happy?", + help="Text prompt to generate from.", + ) + parser.add_argument( + "--gen_length", + type=int, + default=2048, + help="Number of tokens to generate.", + ) + parser.add_argument( + "--block_length", + type=int, + default=32, + help="Size of each generation block.", + ) + parser.add_argument( + "--steps", + type=int, + default=32, + help="Number of refinement steps per block.", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.0, + help="Sampling temperature (0.0 for greedy).", + ) + parser.add_argument( + "--top_p", + type=float, + default=None, + help="Nucleus sampling probability threshold.", + ) + parser.add_argument( + "--top_k", + type=int, + default=None, + help="Top-k sampling parameter.", + ) + parser.add_argument( + "--threshold", + type=float, + default=0.95, + help="Confidence threshold for committing tokens.", + ) + parser.add_argument( + "--sampling_method", + type=str, + default="multinomial", + choices=["auto", "greedy", "multinomial"], + help="Sampling method for block refinement.", + ) + parser.add_argument( + "--eos_early_stop", + action="store_true", + help="Stop generation early when EOS token is generated.", + ) + parser.add_argument( + "--use_chat_template", + action="store_true", + help="Use the tokenizer chat template for the prompt.", + ) + parser.add_argument( + "--add_generation_prompt", + action="store_true", + help="Add the generation prompt when using the chat template.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to run inference on.", + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + choices=["float32", "float16", "bfloat16"], + help="Model dtype.", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Random seed for reproducibility.", + ) + parser.add_argument( + "--offload", + type=str, + default=None, + choices=["group", "sequential"], + help="Memory offloading strategy: 'group' for group offloading (faster), 'sequential' for sequential CPU offload (slower but lower memory).", + ) + 
+ args = parser.parse_args() + + # Parse dtype + dtype_map = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + torch_dtype = dtype_map[args.dtype] + + print(f"Loading model: {args.model_id}") + tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True) + + # Load model with appropriate memory settings based on offload strategy + if args.offload == "group": + # For group offloading, load to CPU first then apply hooks + print("Using group offloading for memory efficiency...") + model = AutoModelForCausalLM.from_pretrained( + args.model_id, + trust_remote_code=True, + torch_dtype=torch_dtype, + low_cpu_mem_usage=True, + ) + # Apply group offloading with CUDA streams for better performance + onload_device = torch.device(args.device) + offload_device = torch.device("cpu") + apply_group_offloading( + model, + onload_device=onload_device, + offload_device=offload_device, + offload_type="leaf_level", + use_stream=True, + ) + elif args.offload == "sequential": + # For sequential offloading, load to CPU first + print("Using sequential CPU offloading (slower but lower memory)...") + model = AutoModelForCausalLM.from_pretrained( + args.model_id, + trust_remote_code=True, + torch_dtype=torch_dtype, + low_cpu_mem_usage=True, + ) + # Sequential offloading will be applied via pipeline + else: + # Default: use device_map="auto" for automatic memory management + model = AutoModelForCausalLM.from_pretrained( + args.model_id, + trust_remote_code=True, + torch_dtype=torch_dtype, + device_map="auto", + low_cpu_mem_usage=True, + ) + model.eval() + + # Create pipeline + pipe = LLaDA2Pipeline(model=model, tokenizer=tokenizer) + + # Apply sequential CPU offload if requested + if args.offload == "sequential": + pipe.enable_sequential_cpu_offload() + + # Set up generator for reproducibility + generator = None + if args.seed is not None: + generator = torch.Generator(device=args.device).manual_seed(args.seed) + + print(f"\nPrompt: {args.prompt}") + print(f"Generating {args.gen_length} tokens with block_length={args.block_length}, steps={args.steps}") + print("-" * 50) + + # Generate + output = pipe( + prompt=args.prompt, + use_chat_template=args.use_chat_template, + add_generation_prompt=args.add_generation_prompt, + gen_length=args.gen_length, + block_length=args.block_length, + steps=args.steps, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + threshold=args.threshold, + sampling_method=args.sampling_method, + eos_early_stop=args.eos_early_stop, + generator=generator, + ) + + print("\nGenerated text:") + print(output.texts[0]) + + print(f"\nGenerated {output.sequences.shape[1]} tokens") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/sample_mdlm.py b/examples/discrete_diffusion/sample_mdlm.py new file mode 100644 index 000000000000..b6add10812f6 --- /dev/null +++ b/examples/discrete_diffusion/sample_mdlm.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from typing import Optional + +import torch +from transformers import AutoModelForMaskedLM, AutoTokenizer + +from diffusers import TokenDiffusionScheduler + + +def parse_args(): + parser = argparse.ArgumentParser(description="Sample from an absorbing token diffusion LM (MDLM-style).") + parser.add_argument( + "--checkpoint_path", type=str, required=True, help="Path saved by train_mdlm.py (or compatible)." + ) + parser.add_argument("--num_samples", type=int, default=4) + parser.add_argument("--seq_len", type=int, default=64) + parser.add_argument("--num_inference_steps", type=int, default=128) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + parser.add_argument("--inject_bos", action="store_true") + return parser.parse_args() + + +@torch.no_grad() +def sample( + model, + tokenizer, + scheduler: TokenDiffusionScheduler, + *, + num_samples: int, + seq_len: int, + num_inference_steps: int, + generator: Optional[torch.Generator], + inject_bos: bool, + device: torch.device, +): + scheduler.set_timesteps(num_inference_steps, device=device) + + x = torch.full((num_samples, seq_len), scheduler.mask_token_id, dtype=torch.long, device=device) + attention_mask = torch.ones_like(x, dtype=torch.long) + + if inject_bos and tokenizer.bos_token_id is not None: + x[:, 0] = int(tokenizer.bos_token_id) + + for t in scheduler.timesteps: + logits = model(input_ids=x, attention_mask=attention_mask).logits # [B, L, V] + x = scheduler.step(logits, t, x, generator=generator, return_dict=True).prev_sample + + if inject_bos and tokenizer.bos_token_id is not None: + x[:, 0] = int(tokenizer.bos_token_id) + + return x + + +def main(): + args = parse_args() + device = torch.device(args.device) + + tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, use_fast=True) + model = AutoModelForMaskedLM.from_pretrained(args.checkpoint_path).to(device) + scheduler = TokenDiffusionScheduler.from_pretrained(args.checkpoint_path) + + model.eval() + + gen = torch.Generator(device=device) + gen.manual_seed(args.seed) + + samples = sample( + model, + tokenizer, + scheduler, + num_samples=args.num_samples, + seq_len=args.seq_len, + num_inference_steps=args.num_inference_steps, + generator=gen, + inject_bos=args.inject_bos, + device=device, + ) + + texts = tokenizer.batch_decode(samples, skip_special_tokens=True) + for i, t in enumerate(texts): + print(f"[{i}] {t}") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/train_block_refinement_cap.py b/examples/discrete_diffusion/train_block_refinement_cap.py new file mode 100644 index 000000000000..21fabc8d7da1 --- /dev/null +++ b/examples/discrete_diffusion/train_block_refinement_cap.py @@ -0,0 +1,387 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import math +import os +from dataclasses import asdict, dataclass +from typing import Optional, Tuple + +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from torch.utils.data import DataLoader, Dataset + +from diffusers import BlockRefinementPipeline +from diffusers.training_utils import compute_confidence_aware_loss + + +logger = get_logger(__name__) + + +@dataclass +class TrainConfig: + output_dir: str + seed: int + max_train_steps: int + logging_steps: int + checkpointing_steps: int + + per_device_train_batch_size: int + gradient_accumulation_steps: int + learning_rate: float + weight_decay: float + + vocab_size: int + mask_token_id: int + eos_token_id: int + max_length: int + prompt_length: int + + block_length: int + steps: int + lambda_conf: float + conf_temperature: float + temperature: float + threshold: float + + +def parse_args() -> TrainConfig: + parser = argparse.ArgumentParser( + description="Train a block-wise refinement model with a confidence-aware objective (CAP-style)." + ) + + parser.add_argument("--output_dir", type=str, default="block-refinement-output") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--max_train_steps", type=int, default=1000) + parser.add_argument("--logging_steps", type=int, default=50) + parser.add_argument("--checkpointing_steps", type=int, default=500) + + parser.add_argument("--per_device_train_batch_size", type=int, default=64) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=2e-4) + parser.add_argument("--weight_decay", type=float, default=0.0) + + parser.add_argument("--vocab_size", type=int, default=256) + parser.add_argument("--mask_token_id", type=int, default=255) + parser.add_argument("--eos_token_id", type=int, default=254) + parser.add_argument("--max_length", type=int, default=64) + parser.add_argument("--prompt_length", type=int, default=8) + + parser.add_argument("--block_length", type=int, default=16) + parser.add_argument("--steps", type=int, default=16) + parser.add_argument("--lambda_conf", type=float, default=2.0) + parser.add_argument("--conf_temperature", type=float, default=0.5) + + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--threshold", type=float, default=0.95) + + args = parser.parse_args() + return TrainConfig(**vars(args)) + + +def build_block_attention_mask( + *, + num_blocks: int, + block_length: int, + total_length: int, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + block_mask = torch.tril(torch.ones(num_blocks, num_blocks, device=device, dtype=torch.bool)) + attn = ( + block_mask.repeat_interleave(block_length, dim=0) + .repeat_interleave(block_length, dim=1) + .unsqueeze(0) + .unsqueeze(0) + ) + attn = attn[:, :, :total_length, :total_length] + return torch.where( + attn, torch.zeros((), device=device, dtype=dtype), torch.full((), float("-inf"), device=device, dtype=dtype) + ) + + +def forward_process_semi_ar( + input_ids: torch.LongTensor, + *, + prompt_length: int, + block_length: int, + mask_token_id: int, + generator: Optional[torch.Generator], +) -> Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]: + batch_size, seq_len = input_ids.shape + device = input_ids.device 
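    # Complementary semi-autoregressive masking: for each block after the prompt, draw a
    # per-sample masking rate `p_mask`, mask those positions in `noisy`, and mask the
    # complementary positions in `noisy_rev`, so every post-prompt token is supervised in
    # exactly one of the two views.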
+ + noisy = input_ids.clone() + noisy_rev = input_ids.clone() + masked = torch.zeros_like(input_ids, dtype=torch.bool) + masked_rev = torch.zeros_like(input_ids, dtype=torch.bool) + + start = int(prompt_length) + for block_start in range(start, seq_len, int(block_length)): + block_end = min(seq_len, block_start + int(block_length)) + seg_len = block_end - block_start + if seg_len <= 0: + continue + + p_mask = torch.rand((batch_size, 1), device=device, generator=generator) + seg = torch.rand((batch_size, seg_len), device=device, generator=generator) < p_mask + seg_rev = ~seg + + masked[:, block_start:block_end] = seg + masked_rev[:, block_start:block_end] = seg_rev + + noisy = torch.where(masked, torch.full_like(noisy, int(mask_token_id)), noisy) + noisy_rev = torch.where(masked_rev, torch.full_like(noisy_rev, int(mask_token_id)), noisy_rev) + return noisy, noisy_rev, masked, masked_rev + + +class RandomTokenDataset(Dataset): + def __init__(self, *, num_samples: int, seq_len: int, vocab_size: int, eos_token_id: int): + self.num_samples = int(num_samples) + self.seq_len = int(seq_len) + self.vocab_size = int(vocab_size) + self.eos_token_id = int(eos_token_id) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + del idx + # Keep EOS out of the training distribution to avoid trivial early-stops during sampling. + ids = torch.randint(0, self.vocab_size - 2, (self.seq_len,), dtype=torch.long) + return {"input_ids": ids} + + +class TinyBlockRefinementLM(torch.nn.Module): + def __init__(self, *, vocab_size: int, hidden_size: int = 128, num_heads: int = 4, num_layers: int = 4): + super().__init__() + self.vocab_size = int(vocab_size) + self.hidden_size = int(hidden_size) + + self.token_emb = torch.nn.Embedding(self.vocab_size, self.hidden_size) + self.pos_emb = torch.nn.Embedding(2048, self.hidden_size) + enc_layer = torch.nn.TransformerEncoderLayer( + d_model=self.hidden_size, + nhead=int(num_heads), + dim_feedforward=self.hidden_size * 4, + dropout=0.0, + activation="gelu", + batch_first=True, + norm_first=True, + ) + self.encoder = torch.nn.TransformerEncoder(enc_layer, num_layers=int(num_layers)) + self.lm_head = torch.nn.Linear(self.hidden_size, self.vocab_size, bias=False) + + @property + def dtype(self): + return next(self.parameters()).dtype + + def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs): + if position_ids is None: + position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand_as(input_ids) + + x = self.token_emb(input_ids) + self.pos_emb(position_ids) + + attn_mask = None + if attention_mask is not None: + if attention_mask.ndim == 4: + attn_mask = attention_mask[0, 0] + elif attention_mask.ndim == 2: + attn_mask = attention_mask + else: + raise ValueError(f"Unsupported `attention_mask` shape: {attention_mask.shape}") + attn_mask = attn_mask.to(dtype=torch.float32) + + hidden = self.encoder(x, mask=attn_mask) + logits = self.lm_head(hidden) + return type("Output", (), {"logits": logits}) + + +def save_checkpoint(output_dir: str, *, model: torch.nn.Module, cfg: TrainConfig): + os.makedirs(output_dir, exist_ok=True) + torch.save(model.state_dict(), os.path.join(output_dir, "pytorch_model.bin")) + with open(os.path.join(output_dir, "training_config.json"), "w", encoding="utf-8") as f: + json.dump(asdict(cfg), f, indent=2, sort_keys=True) + + +def main(): + cfg = parse_args() + if cfg.mask_token_id >= cfg.vocab_size: + raise ValueError("`mask_token_id` must be < `vocab_size`.") + if 
cfg.eos_token_id >= cfg.vocab_size: + raise ValueError("`eos_token_id` must be < `vocab_size`.") + if cfg.prompt_length >= cfg.max_length: + raise ValueError("`prompt_length` must be < `max_length`.") + + project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs")) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + project_config=project_config, + ) + if accelerator.is_main_process: + os.makedirs(cfg.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + set_seed(cfg.seed) + logger.info("Training configuration: %s", asdict(cfg)) + + dataset = RandomTokenDataset( + num_samples=max(cfg.max_train_steps * cfg.per_device_train_batch_size, 4096), + seq_len=cfg.max_length, + vocab_size=cfg.vocab_size, + eos_token_id=cfg.eos_token_id, + ) + dataloader = DataLoader(dataset, batch_size=cfg.per_device_train_batch_size, shuffle=True, drop_last=True) + + model = TinyBlockRefinementLM(vocab_size=cfg.vocab_size) + pipe = BlockRefinementPipeline(model=model, tokenizer=None) + + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay) + + num_update_steps_per_epoch = math.ceil(len(dataloader) / cfg.gradient_accumulation_steps) + num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch) + + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + pipe = pipe.to(accelerator.device) + + global_step = 0 + model.train() + + for _epoch in range(num_train_epochs): + for batch in dataloader: + with accelerator.accumulate(model): + input_ids = batch["input_ids"] + + # Build the same attention mask that the sampler uses. + prompt_len = int(cfg.prompt_length) + num_blocks = (prompt_len + int(cfg.max_length - prompt_len) + int(cfg.block_length) - 1) // int( + cfg.block_length + ) + total_length = int(num_blocks) * int(cfg.block_length) + total_length = max(total_length, int(cfg.max_length)) + attn_mask = build_block_attention_mask( + num_blocks=(total_length + int(cfg.block_length) - 1) // int(cfg.block_length), + block_length=int(cfg.block_length), + total_length=int(cfg.max_length), + device=input_ids.device, + dtype=torch.bfloat16 if input_ids.device.type == "cuda" else torch.float32, + ) + position_ids = ( + torch.arange(int(cfg.max_length), device=input_ids.device, dtype=torch.long) + .unsqueeze(0) + .expand_as(input_ids) + ) + + gen = None + if accelerator.is_local_main_process: + gen = torch.Generator(device=input_ids.device).manual_seed(cfg.seed + global_step) + + noisy, noisy_rev, masked, masked_rev = forward_process_semi_ar( + input_ids, + prompt_length=prompt_len, + block_length=int(cfg.block_length), + mask_token_id=int(cfg.mask_token_id), + generator=gen, + ) + + logits = model(noisy, attention_mask=attn_mask, position_ids=position_ids).logits + logits_rev = model(noisy_rev, attention_mask=attn_mask, position_ids=position_ids).logits + + # Do not allow predicting mask_id. 
+ logits = logits.clone() + logits[..., int(cfg.mask_token_id)] = torch.finfo(logits.dtype).min + logits_rev = logits_rev.clone() + logits_rev[..., int(cfg.mask_token_id)] = torch.finfo(logits_rev.dtype).min + + labels = input_ids.clone() + labels[~masked] = -100 + labels_rev = input_ids.clone() + labels_rev[~masked_rev] = -100 + + weights = masked.to(dtype=logits.dtype) + weights_rev = masked_rev.to(dtype=logits.dtype) + + loss, loss_sft, loss_conf = compute_confidence_aware_loss( + logits, + labels, + lambda_conf=cfg.lambda_conf, + temperature=cfg.conf_temperature, + per_token_weights=weights, + ) + loss_rev, loss_sft_rev, loss_conf_rev = compute_confidence_aware_loss( + logits_rev, + labels_rev, + lambda_conf=cfg.lambda_conf, + temperature=cfg.conf_temperature, + per_token_weights=weights_rev, + ) + + total_loss = loss + loss_rev + accelerator.backward(total_loss) + optimizer.step() + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + global_step += 1 + + if global_step % cfg.logging_steps == 0 and accelerator.is_main_process: + logger.info( + "step=%d loss=%.4f sft=%.4f conf=%.4f", + global_step, + total_loss.item(), + (loss_sft + loss_sft_rev).item(), + (loss_conf + loss_conf_rev).item(), + ) + + if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}") + save_checkpoint(save_dir, model=accelerator.unwrap_model(model), cfg=cfg) + + if global_step >= cfg.max_train_steps: + break + + if global_step >= cfg.max_train_steps: + break + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + final_dir = os.path.join(cfg.output_dir, "final") + save_checkpoint(final_dir, model=accelerator.unwrap_model(model), cfg=cfg) + + # Quick sampler smoke to ensure the pipeline runs with the trained weights. + out = pipe( + prompt_ids=torch.randint(0, cfg.vocab_size - 2, (1, cfg.prompt_length), device=accelerator.device), + gen_length=int(cfg.max_length - cfg.prompt_length), + block_length=int(cfg.block_length), + steps=int(cfg.steps), + temperature=float(cfg.temperature), + threshold=float(cfg.threshold), + eos_early_stop=False, + eos_token_id=int(cfg.eos_token_id), + mask_token_id=int(cfg.mask_token_id), + return_text=False, + ) + logger.info("sample shape=%s", tuple(out.sequences.shape)) + + logger.info("Done.") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/train_block_refinement_qwen_cap.py b/examples/discrete_diffusion/train_block_refinement_qwen_cap.py new file mode 100644 index 000000000000..5149f3ba61d0 --- /dev/null +++ b/examples/discrete_diffusion/train_block_refinement_qwen_cap.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
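"""
Fine-tune a causal LM (e.g. Qwen2.5) with a block-wise, confidence-aware (CAP-style) refinement
objective. See examples/discrete_diffusion/README.md for example `accelerate launch` commands.
"""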
+ +import argparse +import math +import os +from dataclasses import asdict, dataclass +from typing import Dict, Optional, Tuple + +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, get_scheduler + +from diffusers.training_utils import compute_confidence_aware_loss + + +logger = get_logger(__name__) + + +@dataclass +class TrainConfig: + model_name_or_path: str + dataset_name: str + dataset_config_name: Optional[str] + text_column: str + cache_dir: Optional[str] + use_dummy_data: bool + num_dummy_samples: int + + output_dir: str + seed: int + max_train_steps: int + checkpointing_steps: int + logging_steps: int + + per_device_train_batch_size: int + gradient_accumulation_steps: int + learning_rate: float + weight_decay: float + lr_scheduler: str + lr_warmup_steps: int + + max_length: int + prompt_length: int + block_length: int + + lambda_conf: float + conf_temperature: float + + +def parse_args() -> TrainConfig: + parser = argparse.ArgumentParser(description="Train block-refinement with a confidence-aware loss on a causal LM.") + + parser.add_argument("--model_name_or_path", type=str, default="Qwen/Qwen2.5-0.5B") + parser.add_argument("--dataset_name", type=str, default="wikitext") + parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1") + parser.add_argument("--text_column", type=str, default="text") + parser.add_argument("--cache_dir", type=str, default=None) + parser.add_argument("--use_dummy_data", action="store_true", help="Use random-token data instead of downloading.") + parser.add_argument("--num_dummy_samples", type=int, default=2048) + + parser.add_argument("--output_dir", type=str, default="qwen-block-refinement-output") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--max_train_steps", type=int, default=1000) + parser.add_argument("--checkpointing_steps", type=int, default=500) + parser.add_argument("--logging_steps", type=int, default=50) + + parser.add_argument("--per_device_train_batch_size", type=int, default=1) + parser.add_argument("--gradient_accumulation_steps", type=int, default=8) + parser.add_argument("--learning_rate", type=float, default=2e-5) + parser.add_argument("--weight_decay", type=float, default=0.0) + parser.add_argument( + "--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"] + ) + parser.add_argument("--lr_warmup_steps", type=int, default=100) + + parser.add_argument("--max_length", type=int, default=256) + parser.add_argument("--prompt_length", type=int, default=32) + parser.add_argument("--block_length", type=int, default=32) + + parser.add_argument("--lambda_conf", type=float, default=2.0) + parser.add_argument("--conf_temperature", type=float, default=0.5) + + args = parser.parse_args() + return TrainConfig(**vars(args)) + + +def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int): + texts = examples[text_column] + texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0] + return tokenizer(texts, truncation=True, padding=False, max_length=max_length) + + +class RandomTokenDataset(torch.utils.data.Dataset): + def __init__(self, *, num_samples: int, seq_len: int, vocab_size: int, pad_token_id: int): + self.num_samples = int(num_samples) + 
self.seq_len = int(seq_len) + self.vocab_size = int(vocab_size) + self.pad_token_id = int(pad_token_id) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + del idx + input_ids = torch.randint(0, self.vocab_size, (self.seq_len,), dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +def forward_process_semi_ar( + input_ids: torch.LongTensor, + attention_mask: torch.LongTensor, + *, + prompt_length: int, + block_length: int, + mask_token_id: int, + generator: Optional[torch.Generator], +) -> Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]: + batch_size, seq_len = input_ids.shape + device = input_ids.device + + noisy = input_ids.clone() + noisy_rev = input_ids.clone() + masked = torch.zeros_like(input_ids, dtype=torch.bool) + masked_rev = torch.zeros_like(input_ids, dtype=torch.bool) + + # Only mask non-padding positions after the prompt. + valid = attention_mask.to(dtype=torch.bool) + start = int(prompt_length) + for block_start in range(start, seq_len, int(block_length)): + block_end = min(seq_len, block_start + int(block_length)) + seg_len = block_end - block_start + if seg_len <= 0: + continue + + p_mask = torch.rand((batch_size, 1), device=device, generator=generator) + seg = torch.rand((batch_size, seg_len), device=device, generator=generator) < p_mask + seg = seg & valid[:, block_start:block_end] + seg_rev = (~seg) & valid[:, block_start:block_end] + + masked[:, block_start:block_end] = seg + masked_rev[:, block_start:block_end] = seg_rev + + noisy = torch.where(masked, torch.full_like(noisy, int(mask_token_id)), noisy) + noisy_rev = torch.where(masked_rev, torch.full_like(noisy_rev, int(mask_token_id)), noisy_rev) + return noisy, noisy_rev, masked, masked_rev + + +def main(): + cfg = parse_args() + if cfg.prompt_length >= cfg.max_length: + raise ValueError("`prompt_length` must be < `max_length`.") + if cfg.block_length <= 0: + raise ValueError("`block_length` must be > 0.") + + project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs")) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + project_config=project_config, + ) + if accelerator.is_main_process: + os.makedirs(cfg.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + set_seed(cfg.seed) + logger.info("Training configuration: %s", asdict(cfg)) + + tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True, cache_dir=cfg.cache_dir) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + + if tokenizer.mask_token_id is None: + tokenizer.add_special_tokens({"mask_token": "[MASK]"}) + + load_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 + model = AutoModelForCausalLM.from_pretrained( + cfg.model_name_or_path, cache_dir=cfg.cache_dir, torch_dtype=load_dtype + ) + model.resize_token_embeddings(len(tokenizer)) + if load_dtype == torch.float32: + model.to(dtype=torch.float32) + + mask_token_id = int(tokenizer.mask_token_id) + + if cfg.use_dummy_data: + dataset = RandomTokenDataset( + num_samples=cfg.num_dummy_samples, + seq_len=cfg.max_length, + vocab_size=len(tokenizer), + pad_token_id=int(tokenizer.pad_token_id), + ) + train_dataloader = DataLoader( + dataset, + shuffle=True, + batch_size=cfg.per_device_train_batch_size, + drop_last=True, + ) + else: + raw_datasets = load_dataset(cfg.dataset_name, 
cfg.dataset_config_name, cache_dir=cfg.cache_dir) + if "train" not in raw_datasets: + raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.") + + with accelerator.main_process_first(): + tokenized = raw_datasets["train"].map( + lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length), + batched=True, + remove_columns=raw_datasets["train"].column_names, + desc="Tokenizing", + ) + + collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt") + train_dataloader = DataLoader( + tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True + ) + + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps) + num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler( + name=cfg.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.lr_warmup_steps, + num_training_steps=cfg.max_train_steps, + ) + + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + global_step = 0 + model.train() + + for _epoch in range(num_train_epochs): + for batch in train_dataloader: + with accelerator.accumulate(model): + input_ids = batch["input_ids"] + attention_mask = batch.get("attention_mask", torch.ones_like(input_ids)) + + gen = torch.Generator(device=input_ids.device).manual_seed(cfg.seed + global_step) + noisy, noisy_rev, masked, masked_rev = forward_process_semi_ar( + input_ids, + attention_mask, + prompt_length=int(cfg.prompt_length), + block_length=int(cfg.block_length), + mask_token_id=mask_token_id, + generator=gen, + ) + + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand_as(input_ids) + ) + + logits = model(input_ids=noisy, attention_mask=attention_mask, position_ids=position_ids).logits + logits_rev = model( + input_ids=noisy_rev, attention_mask=attention_mask, position_ids=position_ids + ).logits + + logits = logits.clone() + logits[..., mask_token_id] = torch.finfo(logits.dtype).min + logits_rev = logits_rev.clone() + logits_rev[..., mask_token_id] = torch.finfo(logits_rev.dtype).min + + valid = attention_mask.to(dtype=torch.bool) + masked = masked & valid + masked_rev = masked_rev & valid + + labels = input_ids.clone() + labels[~masked] = -100 + labels_rev = input_ids.clone() + labels_rev[~masked_rev] = -100 + + weights = masked.to(dtype=logits.dtype) + weights_rev = masked_rev.to(dtype=logits.dtype) + + loss, loss_sft, loss_conf = compute_confidence_aware_loss( + logits, + labels, + lambda_conf=cfg.lambda_conf, + temperature=cfg.conf_temperature, + per_token_weights=weights, + ) + loss_rev, loss_sft_rev, loss_conf_rev = compute_confidence_aware_loss( + logits_rev, + labels_rev, + lambda_conf=cfg.lambda_conf, + temperature=cfg.conf_temperature, + per_token_weights=weights_rev, + ) + + total_loss = loss + loss_rev + accelerator.backward(total_loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + global_step += 1 + + if global_step % cfg.logging_steps == 0 and accelerator.is_main_process: + logger.info( + "step=%d loss=%.4f sft=%.4f conf=%.4f lr=%.6g", + global_step, + total_loss.item(), + (loss_sft + loss_sft_rev).item(), + (loss_conf + loss_conf_rev).item(), + lr_scheduler.get_last_lr()[0], + ) + print( + 
f"step={global_step} loss={total_loss.item():.4f} " + f"sft={(loss_sft + loss_sft_rev).item():.4f} " + f"conf={(loss_conf + loss_conf_rev).item():.4f} " + f"lr={lr_scheduler.get_last_lr()[0]:.6g}" + ) + + if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}") + os.makedirs(save_dir, exist_ok=True) + accelerator.unwrap_model(model).save_pretrained(save_dir, save_function=accelerator.save) + tokenizer.save_pretrained(save_dir) + + if global_step >= cfg.max_train_steps: + break + + if global_step >= cfg.max_train_steps: + break + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + final_dir = os.path.join(cfg.output_dir, "final") + os.makedirs(final_dir, exist_ok=True) + accelerator.unwrap_model(model).save_pretrained(final_dir, save_function=accelerator.save) + tokenizer.save_pretrained(final_dir) + + logger.info("Done.") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/train_hybrid_token_diffusion.py b/examples/discrete_diffusion/train_hybrid_token_diffusion.py new file mode 100644 index 000000000000..ee23b430f555 --- /dev/null +++ b/examples/discrete_diffusion/train_hybrid_token_diffusion.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python + +import argparse +import math +import os +from dataclasses import asdict, dataclass +from typing import Dict, Optional + +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import ( + AutoConfig, + AutoModelForMaskedLM, + AutoTokenizer, + DataCollatorForLanguageModeling, + get_scheduler, +) + +from diffusers import HybridTokenDiffusionScheduler +from diffusers.training_utils import compute_confidence_aware_loss + + +logger = get_logger(__name__) + + +@dataclass +class TrainConfig: + model_name_or_path: str + dataset_name: str + dataset_config_name: Optional[str] + text_column: str + + output_dir: str + seed: int + max_train_steps: int + checkpointing_steps: int + logging_steps: int + + per_device_train_batch_size: int + gradient_accumulation_steps: int + learning_rate: float + weight_decay: float + lr_scheduler: str + lr_warmup_steps: int + + max_length: int + num_train_timesteps: int + t_eps: float + p_uniform: float + gamma: float + lambda_conf: float + conf_temperature: float + + +def parse_args() -> TrainConfig: + parser = argparse.ArgumentParser(description="Train a hybrid-transition token diffusion model with Accelerate.") + + parser.add_argument("--model_name_or_path", type=str, default="bert-base-uncased") + parser.add_argument("--dataset_name", type=str, default="wikitext") + parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1") + parser.add_argument("--text_column", type=str, default="text") + + parser.add_argument("--output_dir", type=str, default="hybrid-output") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--max_train_steps", type=int, default=1000) + parser.add_argument("--checkpointing_steps", type=int, default=500) + parser.add_argument("--logging_steps", type=int, default=50) + + parser.add_argument("--per_device_train_batch_size", type=int, default=8) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--learning_rate", type=float, 
default=5e-5) + parser.add_argument("--weight_decay", type=float, default=0.01) + parser.add_argument( + "--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"] + ) + parser.add_argument("--lr_warmup_steps", type=int, default=100) + + parser.add_argument("--max_length", type=int, default=256) + parser.add_argument("--num_train_timesteps", type=int, default=1000) + parser.add_argument("--t_eps", type=float, default=1e-4) + parser.add_argument("--p_uniform", type=float, default=0.0) + parser.add_argument("--gamma", type=float, default=1.0) + parser.add_argument( + "--lambda_conf", + type=float, + default=0.0, + help="Optional confidence-aware penalty weight (entropy on correctly predicted tokens).", + ) + parser.add_argument( + "--conf_temperature", + type=float, + default=1.0, + help="Temperature for the confidence term only; lower values sharpen the entropy penalty.", + ) + + args = parser.parse_args() + return TrainConfig(**vars(args)) + + +def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int): + texts = examples[text_column] + texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0] + return tokenizer( + texts, + truncation=True, + padding=False, + max_length=max_length, + return_special_tokens_mask=True, + ) + + +def main(): + cfg = parse_args() + + project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs")) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + project_config=project_config, + ) + + if accelerator.is_main_process: + os.makedirs(cfg.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + set_seed(cfg.seed) + logger.info("Training configuration: %s", asdict(cfg)) + + tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True) + if tokenizer.mask_token_id is None: + tokenizer.add_special_tokens({"mask_token": "[MASK]"}) + + config = AutoConfig.from_pretrained(cfg.model_name_or_path) + model = AutoModelForMaskedLM.from_pretrained(cfg.model_name_or_path, config=config) + model.resize_token_embeddings(len(tokenizer)) + + scheduler = HybridTokenDiffusionScheduler( + vocab_size=len(tokenizer), + mask_token_id=int(tokenizer.mask_token_id), + num_train_timesteps=cfg.num_train_timesteps, + t_eps=cfg.t_eps, + p_uniform=cfg.p_uniform, + gamma=cfg.gamma, + ) + + raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name) + if "train" not in raw_datasets: + raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.") + + with accelerator.main_process_first(): + tokenized = raw_datasets["train"].map( + lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length), + batched=True, + remove_columns=raw_datasets["train"].column_names, + desc="Tokenizing", + ) + + collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt") + train_dataloader = DataLoader( + tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True + ) + + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps) + num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler( + name=cfg.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.lr_warmup_steps, + num_training_steps=cfg.max_train_steps, + ) + + model, 
optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + global_step = 0 + model.train() + + for epoch in range(num_train_epochs): + for batch in train_dataloader: + with accelerator.accumulate(model): + input_ids = batch["input_ids"] + attention_mask = batch.get("attention_mask", torch.ones_like(input_ids)) + + timesteps = torch.randint( + 0, scheduler.num_train_timesteps, (input_ids.shape[0],), device=input_ids.device, dtype=torch.long + ) + + x_t = scheduler.add_noise(input_ids, noise=None, timesteps=timesteps) + logits = model(input_ids=x_t, attention_mask=attention_mask).logits + + # For this hybrid kernel, we use a simple denoising objective: predict x0 from z_t. + logits = logits.clone() + logits[..., scheduler.mask_token_id] = torch.finfo(logits.dtype).min + + labels = input_ids.clone() + labels[attention_mask.eq(0)] = -100 + per_token_weights = attention_mask.to(dtype=logits.dtype) + loss, loss_sft, loss_conf = compute_confidence_aware_loss( + logits, + labels, + lambda_conf=cfg.lambda_conf, + temperature=cfg.conf_temperature, + per_token_weights=per_token_weights, + ) + + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + global_step += 1 + + if global_step % cfg.logging_steps == 0 and accelerator.is_main_process: + logger.info( + "step=%d loss=%.4f loss_sft=%.4f loss_conf=%.4f lr=%.6g", + global_step, + loss.item(), + loss_sft.item(), + loss_conf.item(), + lr_scheduler.get_last_lr()[0], + ) + + if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}") + os.makedirs(save_dir, exist_ok=True) + unwrapped = accelerator.unwrap_model(model) + unwrapped.save_pretrained(save_dir, save_function=accelerator.save) + tokenizer.save_pretrained(save_dir) + scheduler.save_pretrained(save_dir) + + if global_step >= cfg.max_train_steps: + break + + if global_step >= cfg.max_train_steps: + break + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + final_dir = os.path.join(cfg.output_dir, "final") + os.makedirs(final_dir, exist_ok=True) + unwrapped = accelerator.unwrap_model(model) + unwrapped.save_pretrained(final_dir, save_function=accelerator.save) + tokenizer.save_pretrained(final_dir) + scheduler.save_pretrained(final_dir) + + logger.info("Done.") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/train_mdlm.py b/examples/discrete_diffusion/train_mdlm.py new file mode 100644 index 000000000000..59b323d35275 --- /dev/null +++ b/examples/discrete_diffusion/train_mdlm.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
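+
+# Overview of the training loop below: this example trains an MDLM-style absorbing-state token
+# diffusion LM. The forward process replaces tokens with the mask token according to alpha(t),
+# and the model is trained to reconstruct the original tokens at the masked positions, with the
+# per-token loss weighted by `scheduler.get_mdlm_loss_weights(timesteps)`.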
+ +import argparse +import math +import os +from dataclasses import asdict, dataclass +from typing import Dict, Optional + +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import ( + AutoConfig, + AutoModelForMaskedLM, + AutoTokenizer, + DataCollatorForLanguageModeling, + get_scheduler, +) + +from diffusers import TokenDiffusionScheduler +from diffusers.training_utils import compute_confidence_aware_loss + + +logger = get_logger(__name__) + + +@dataclass +class TrainConfig: + model_name_or_path: str + dataset_name: str + dataset_config_name: Optional[str] + text_column: str + + output_dir: str + seed: int + max_train_steps: int + checkpointing_steps: int + logging_steps: int + + per_device_train_batch_size: int + gradient_accumulation_steps: int + learning_rate: float + weight_decay: float + lr_scheduler: str + lr_warmup_steps: int + + max_length: int + num_train_timesteps: int + alpha_schedule: str + eps: float + sigma_min: float + sigma_max: float + min_timestep: int + lambda_conf: float + conf_temperature: float + + +def parse_args() -> TrainConfig: + parser = argparse.ArgumentParser(description="Train an absorbing token diffusion LM (MDLM-style) with Accelerate.") + + parser.add_argument("--model_name_or_path", type=str, default="bert-base-uncased") + parser.add_argument("--dataset_name", type=str, default="wikitext") + parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1") + parser.add_argument("--text_column", type=str, default="text") + + parser.add_argument("--output_dir", type=str, default="mdlm-output") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--max_train_steps", type=int, default=1000) + parser.add_argument("--checkpointing_steps", type=int, default=500) + parser.add_argument("--logging_steps", type=int, default=50) + + parser.add_argument("--per_device_train_batch_size", type=int, default=8) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=5e-5) + parser.add_argument("--weight_decay", type=float, default=0.01) + parser.add_argument( + "--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"] + ) + parser.add_argument("--lr_warmup_steps", type=int, default=100) + + parser.add_argument("--max_length", type=int, default=256) + + parser.add_argument("--num_train_timesteps", type=int, default=1000) + parser.add_argument( + "--alpha_schedule", + type=str, + default="log_linear", + choices=["log_linear", "linear", "cosine", "geometric"], + ) + parser.add_argument("--eps", type=float, default=1e-3) + parser.add_argument("--sigma_min", type=float, default=1e-4) + parser.add_argument("--sigma_max", type=float, default=20.0) + parser.add_argument("--min_timestep", type=int, default=1, help="Avoid t=0 to prevent 1/t weighting blow-ups.") + parser.add_argument( + "--lambda_conf", + type=float, + default=0.0, + help="Optional confidence-aware penalty weight (entropy on correctly predicted tokens).", + ) + parser.add_argument( + "--conf_temperature", + type=float, + default=1.0, + help="Temperature for the confidence term only; lower values sharpen the entropy penalty.", + ) + + args = parser.parse_args() + return TrainConfig(**vars(args)) + + +def tokenize_fn(examples: Dict, tokenizer, text_column: str, 
max_length: int): + texts = examples[text_column] + # drop empty lines + texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0] + return tokenizer( + texts, + truncation=True, + padding=False, + max_length=max_length, + return_special_tokens_mask=True, + ) + + +def main(): + cfg = parse_args() + + project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs")) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + project_config=project_config, + ) + + if accelerator.is_main_process: + os.makedirs(cfg.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + set_seed(cfg.seed) + logger.info("Training configuration: %s", asdict(cfg)) + + tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True) + if tokenizer.mask_token_id is None: + # MDLM-style absorbing diffusion assumes a mask token exists. + tokenizer.add_special_tokens({"mask_token": "[MASK]"}) + + config = AutoConfig.from_pretrained(cfg.model_name_or_path) + model = AutoModelForMaskedLM.from_pretrained(cfg.model_name_or_path, config=config) + model.resize_token_embeddings(len(tokenizer)) + + scheduler = TokenDiffusionScheduler( + vocab_size=len(tokenizer), + mask_token_id=int(tokenizer.mask_token_id), + num_train_timesteps=cfg.num_train_timesteps, + alpha_schedule=cfg.alpha_schedule, + eps=cfg.eps, + sigma_min=cfg.sigma_min, + sigma_max=cfg.sigma_max, + ) + + raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name) + if "train" not in raw_datasets: + raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.") + + with accelerator.main_process_first(): + tokenized = raw_datasets["train"].map( + lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length), + batched=True, + remove_columns=raw_datasets["train"].column_names, + desc="Tokenizing", + ) + + # We reuse the standard MLM collator to pad and build attention masks; we won't use its masking. + collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt") + train_dataloader = DataLoader( + tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True + ) + + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps) + num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler( + name=cfg.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.lr_warmup_steps, + num_training_steps=cfg.max_train_steps, + ) + + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + global_step = 0 + model.train() + + for epoch in range(num_train_epochs): + for batch in train_dataloader: + with accelerator.accumulate(model): + input_ids = batch["input_ids"] + attention_mask = batch.get("attention_mask", torch.ones_like(input_ids)) + + # Sample discrete time indices (avoid timestep 0 for stability with 1/t weighting). + min_t = max(1, int(cfg.min_timestep)) + max_t = scheduler.num_train_timesteps - 1 + timesteps = torch.randint(min_t, max_t + 1, (input_ids.shape[0],), device=input_ids.device) + + # Forward process q(x_t | x_0): replace tokens with [MASK] according to alpha(t). 
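+ # (Assuming the standard absorbing kernel: each position is masked independently with
+ # probability 1 - alpha(t), so larger timesteps yield more [MASK] tokens and a harder
+ # reconstruction target; unmasked positions keep their original token ids.)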
+ x_t = scheduler.add_noise(input_ids, noise=None, timesteps=timesteps) + + # Model predicts token logits for x0 reconstruction. + logits = model(input_ids=x_t, attention_mask=attention_mask).logits # [B, L, V] + + # MDLM-style constraints: + # - Do not predict the mask token as x0. + logits = logits.clone() + logits[..., scheduler.mask_token_id] = torch.finfo(logits.dtype).min + + # Only compute loss on tokens that were masked by the forward process. + mask_positions = x_t.eq(scheduler.mask_token_id) & attention_mask.to(dtype=torch.bool) + + weights = scheduler.get_mdlm_loss_weights(timesteps) + + labels = input_ids.clone() + labels[~mask_positions] = -100 + + per_token_weights = weights.to(dtype=logits.dtype).expand_as(labels) + loss, loss_sft, loss_conf = compute_confidence_aware_loss( + logits, + labels, + lambda_conf=cfg.lambda_conf, + temperature=cfg.conf_temperature, + per_token_weights=per_token_weights, + ) + + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + global_step += 1 + + if global_step % cfg.logging_steps == 0 and accelerator.is_main_process: + logger.info( + "step=%d loss=%.4f loss_sft=%.4f loss_conf=%.4f lr=%.6g", + global_step, + loss.item(), + loss_sft.item(), + loss_conf.item(), + lr_scheduler.get_last_lr()[0], + ) + + if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}") + os.makedirs(save_dir, exist_ok=True) + unwrapped = accelerator.unwrap_model(model) + unwrapped.save_pretrained(save_dir, save_function=accelerator.save) + tokenizer.save_pretrained(save_dir) + scheduler.save_pretrained(save_dir) + + if global_step >= cfg.max_train_steps: + break + + if global_step >= cfg.max_train_steps: + break + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + final_dir = os.path.join(cfg.output_dir, "final") + os.makedirs(final_dir, exist_ok=True) + unwrapped = accelerator.unwrap_model(model) + unwrapped.save_pretrained(final_dir, save_function=accelerator.save) + tokenizer.save_pretrained(final_dir) + scheduler.save_pretrained(final_dir) + + logger.info("Done.") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/train_udlm.py b/examples/discrete_diffusion/train_udlm.py new file mode 100644 index 000000000000..8c61790defc2 --- /dev/null +++ b/examples/discrete_diffusion/train_udlm.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
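+
+# Overview of the training loop below: this example trains a UDLM-style uniform-state token
+# diffusion LM. Instead of masking, the forward process resamples tokens uniformly over the
+# vocabulary (optionally excluding the mask token via --exclude_mask_from_uniform), and training
+# minimizes the continuous-time objective implemented in `udlm_diffusion_loss`.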
+ +import argparse +import math +import os +from dataclasses import asdict, dataclass +from typing import Dict, Optional + +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import ( + AutoConfig, + AutoModelForMaskedLM, + AutoTokenizer, + DataCollatorForLanguageModeling, + get_scheduler, +) + +from diffusers import TokenDiffusionScheduler + + +logger = get_logger(__name__) + + +@dataclass +class TrainConfig: + model_name_or_path: str + dataset_name: str + dataset_config_name: Optional[str] + text_column: str + + output_dir: str + seed: int + max_train_steps: int + checkpointing_steps: int + logging_steps: int + + per_device_train_batch_size: int + gradient_accumulation_steps: int + learning_rate: float + weight_decay: float + lr_scheduler: str + lr_warmup_steps: int + + max_length: int + num_train_timesteps: int + alpha_schedule: str + eps: float + sigma_min: float + sigma_max: float + min_timestep: int + exclude_mask_from_uniform: bool + + +def parse_args() -> TrainConfig: + parser = argparse.ArgumentParser(description="Train a uniform token diffusion LM (UDLM-style) with Accelerate.") + + parser.add_argument("--model_name_or_path", type=str, default="bert-base-uncased") + parser.add_argument("--dataset_name", type=str, default="wikitext") + parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1") + parser.add_argument("--text_column", type=str, default="text") + + parser.add_argument("--output_dir", type=str, default="udlm-output") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--max_train_steps", type=int, default=1000) + parser.add_argument("--checkpointing_steps", type=int, default=500) + parser.add_argument("--logging_steps", type=int, default=50) + + parser.add_argument("--per_device_train_batch_size", type=int, default=8) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=5e-5) + parser.add_argument("--weight_decay", type=float, default=0.01) + parser.add_argument( + "--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"] + ) + parser.add_argument("--lr_warmup_steps", type=int, default=100) + + parser.add_argument("--max_length", type=int, default=256) + parser.add_argument("--num_train_timesteps", type=int, default=1000) + parser.add_argument( + "--alpha_schedule", + type=str, + default="log_linear", + choices=["log_linear", "linear", "cosine", "geometric"], + help="Alpha schedule used for the uniform forward process and the continuous-time UDLM objective.", + ) + parser.add_argument("--eps", type=float, default=1e-3) + parser.add_argument("--sigma_min", type=float, default=1e-4) + parser.add_argument("--sigma_max", type=float, default=20.0) + parser.add_argument("--min_timestep", type=int, default=1) + parser.add_argument( + "--exclude_mask_from_uniform", action="store_true", help="Exclude mask token from uniform draws." 
+ ) + + args = parser.parse_args() + return TrainConfig(**vars(args)) + + +def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int): + texts = examples[text_column] + texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0] + return tokenizer( + texts, + truncation=True, + padding=False, + max_length=max_length, + return_special_tokens_mask=True, + ) + + +def udlm_diffusion_loss( + logits: torch.Tensor, + x0: torch.LongTensor, + x_t: torch.LongTensor, + *, + alpha_t: torch.Tensor, + dalpha_t: torch.Tensor, +): + """ + UDLM diffusion loss (continuous-time form). + + Args: + logits: [B, L, V] + x0: [B, L] + x_t: [B, L] + alpha_t: [B, 1] alpha(t) for the uniform forward process. + dalpha_t: [B, 1] time derivative alpha'(t) with respect to continuous time t in [0, 1]. + Returns: + loss_per_token: [B, L] + """ + log_x_theta = torch.log_softmax(logits, dim=-1) + B, L, V = log_x_theta.shape + + alpha = alpha_t.to(dtype=log_x_theta.dtype).view(B, 1, 1) + alpha_prime = dalpha_t.to(dtype=log_x_theta.dtype).view(B, 1, 1) + + x0_one_hot = F.one_hot(x0, V).to(dtype=log_x_theta.dtype) + xt_one_hot = F.one_hot(x_t, V).to(dtype=log_x_theta.dtype) + + x_bar = V * alpha * x0_one_hot + 1.0 - alpha + x_bar_theta = V * alpha * log_x_theta.exp() + 1.0 - alpha + + coeff = alpha_prime / (V * alpha.clamp_min(torch.finfo(alpha.dtype).eps)) + + x_bar_zt = (x_bar * xt_one_hot).sum(dim=-1, keepdim=True) # (B, L, 1) + x_bar_theta_zt = (x_bar_theta * xt_one_hot).sum(dim=-1, keepdim=True) # (B, L, 1) + + term1 = (V / x_bar_zt) - (V / x_bar_theta_zt) + + term2 = ((x_bar / x_bar_zt) * (x_bar_theta_zt.log() - x_bar_theta.log() + x_bar.log() - x_bar_zt.log())).sum( + dim=-1, keepdim=True + ) + + diffusion_loss = (coeff * (term1 - term2)).squeeze(-1) # (B, L) + return diffusion_loss + + +def main(): + cfg = parse_args() + + project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs")) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + project_config=project_config, + ) + + if accelerator.is_main_process: + os.makedirs(cfg.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + set_seed(cfg.seed) + logger.info("Training configuration: %s", asdict(cfg)) + + tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True) + if tokenizer.mask_token_id is None: + tokenizer.add_special_tokens({"mask_token": "[MASK]"}) + + config = AutoConfig.from_pretrained(cfg.model_name_or_path) + model = AutoModelForMaskedLM.from_pretrained(cfg.model_name_or_path, config=config) + model.resize_token_embeddings(len(tokenizer)) + + scheduler = TokenDiffusionScheduler( + vocab_size=len(tokenizer), + mask_token_id=int(tokenizer.mask_token_id), + num_train_timesteps=cfg.num_train_timesteps, + alpha_schedule=cfg.alpha_schedule, + eps=cfg.eps, + sigma_min=cfg.sigma_min, + sigma_max=cfg.sigma_max, + forward_process="uniform", + exclude_mask_from_uniform=cfg.exclude_mask_from_uniform, + ) + + raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name) + if "train" not in raw_datasets: + raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.") + + with accelerator.main_process_first(): + tokenized = raw_datasets["train"].map( + lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length), + batched=True, + remove_columns=raw_datasets["train"].column_names, + desc="Tokenizing", + ) + + collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, 
return_tensors="pt") + train_dataloader = DataLoader( + tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True + ) + + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps) + num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler( + name=cfg.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.lr_warmup_steps, + num_training_steps=cfg.max_train_steps, + ) + + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + global_step = 0 + model.train() + + for epoch in range(num_train_epochs): + for batch in train_dataloader: + with accelerator.accumulate(model): + input_ids = batch["input_ids"] + attention_mask = batch.get("attention_mask", torch.ones_like(input_ids)) + + min_t = max(1, int(cfg.min_timestep)) + max_t = scheduler.num_train_timesteps - 1 + timesteps = torch.randint(min_t, max_t + 1, (input_ids.shape[0],), device=input_ids.device) + + x_t = scheduler.add_noise(input_ids, noise=None, timesteps=timesteps) + logits = model(input_ids=x_t, attention_mask=attention_mask).logits + + if scheduler.exclude_mask_from_uniform: + logits = logits.clone() + logits[..., scheduler.mask_token_id] = torch.finfo(logits.dtype).min + + alpha_t = scheduler.get_alpha(timesteps) + dalpha_t = scheduler.get_alpha_prime(timesteps) + loss_per_token = udlm_diffusion_loss(logits, input_ids, x_t, alpha_t=alpha_t, dalpha_t=dalpha_t) + loss = (loss_per_token * attention_mask.to(loss_per_token.dtype)).sum() + loss = loss / attention_mask.sum().clamp_min(1) + + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + global_step += 1 + + if global_step % cfg.logging_steps == 0 and accelerator.is_main_process: + logger.info("step=%d loss=%.4f lr=%.6g", global_step, loss.item(), lr_scheduler.get_last_lr()[0]) + + if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}") + os.makedirs(save_dir, exist_ok=True) + unwrapped = accelerator.unwrap_model(model) + unwrapped.save_pretrained(save_dir, save_function=accelerator.save) + tokenizer.save_pretrained(save_dir) + scheduler.save_pretrained(save_dir) + + if global_step >= cfg.max_train_steps: + break + + if global_step >= cfg.max_train_steps: + break + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + final_dir = os.path.join(cfg.output_dir, "final") + os.makedirs(final_dir, exist_ok=True) + unwrapped = accelerator.unwrap_model(model) + unwrapped.save_pretrained(final_dir, save_function=accelerator.save) + tokenizer.save_pretrained(final_dir) + scheduler.save_pretrained(final_dir) + + logger.info("Done.") + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index bc59497d1db7..e03c94d28d26 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -326,10 +326,22 @@ "StableDiffusionMixin", ] ) + _import_structure["pipelines"].extend( + [ + "BlockRefinementPipeline", + "BlockRefinementPipelineOutput", + "BlockTokenDiffusionPipeline", + "BlockTokenDiffusionPipelineOutput", + "HybridTokenDiffusionPipeline", + 
"TokenDiffusionPipeline", + "TokenDiffusionPipelineOutput", + ] + ) _import_structure["quantizers"] = ["DiffusersQuantizer"] _import_structure["schedulers"].extend( [ "AmusedScheduler", + "BlockTokenDiffusionScheduler", "CMStochasticIterativeScheduler", "CogVideoXDDIMScheduler", "CogVideoXDPMScheduler", @@ -340,6 +352,8 @@ "DDPMScheduler", "DDPMWuerstchenScheduler", "DEISMultistepScheduler", + "DFlashTokenDiffusionScheduler", + "DFlashTokenDiffusionSchedulerOutput", "DPMSolverMultistepInverseScheduler", "DPMSolverMultistepScheduler", "DPMSolverSinglestepScheduler", @@ -351,6 +365,8 @@ "FlowMatchHeunDiscreteScheduler", "FlowMatchLCMScheduler", "HeunDiscreteScheduler", + "HybridTokenDiffusionScheduler", + "HybridTokenDiffusionSchedulerOutput", "IPNDMScheduler", "KarrasVeScheduler", "KDPM2AncestralDiscreteScheduler", @@ -363,7 +379,10 @@ "SchedulerMixin", "SCMScheduler", "ScoreSdeVeScheduler", + "SDARTokenDiffusionScheduler", + "SDARTokenDiffusionSchedulerOutput", "TCDScheduler", + "TokenDiffusionScheduler", "UnCLIPScheduler", "UniPCMultistepScheduler", "VQDiffusionScheduler", @@ -475,6 +494,8 @@ "CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline", "CycleDiffusionPipeline", + "DFlashPipeline", + "DFlashPipelineOutput", "EasyAnimateControlPipeline", "EasyAnimateInpaintPipeline", "EasyAnimatePipeline", @@ -541,6 +562,8 @@ "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", + "LLaDA2Pipeline", + "LLaDA2PipelineOutput", "LongCatImageEditPipeline", "LongCatImagePipeline", "LTX2ImageToVideoPipeline", @@ -587,6 +610,8 @@ "SanaSprintPipeline", "SanaVideoPipeline", "SanaVideoPipeline", + "SDARPipeline", + "SDARPipelineOutput", "SemanticStableDiffusionPipeline", "ShapEImg2ImgPipeline", "ShapEPipeline", @@ -1056,6 +1081,10 @@ AutoPipelineForText2Image, BlipDiffusionControlNetPipeline, BlipDiffusionPipeline, + BlockRefinementPipeline, + BlockRefinementPipelineOutput, + BlockTokenDiffusionPipeline, + BlockTokenDiffusionPipelineOutput, CLIPImageProjection, ConsistencyModelPipeline, DanceDiffusionPipeline, @@ -1063,6 +1092,7 @@ DDPMPipeline, DiffusionPipeline, DiTPipeline, + HybridTokenDiffusionPipeline, ImagePipelineOutput, KarrasVePipeline, LDMPipeline, @@ -1071,10 +1101,13 @@ RePaintPipeline, ScoreSdeVePipeline, StableDiffusionMixin, + TokenDiffusionPipeline, + TokenDiffusionPipelineOutput, ) from .quantizers import DiffusersQuantizer from .schedulers import ( AmusedScheduler, + BlockTokenDiffusionScheduler, CMStochasticIterativeScheduler, CogVideoXDDIMScheduler, CogVideoXDPMScheduler, @@ -1085,6 +1118,8 @@ DDPMScheduler, DDPMWuerstchenScheduler, DEISMultistepScheduler, + DFlashTokenDiffusionScheduler, + DFlashTokenDiffusionSchedulerOutput, DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, @@ -1096,6 +1131,8 @@ FlowMatchHeunDiscreteScheduler, FlowMatchLCMScheduler, HeunDiscreteScheduler, + HybridTokenDiffusionScheduler, + HybridTokenDiffusionSchedulerOutput, IPNDMScheduler, KarrasVeScheduler, KDPM2AncestralDiscreteScheduler, @@ -1108,7 +1145,10 @@ SchedulerMixin, SCMScheduler, ScoreSdeVeScheduler, + SDARTokenDiffusionScheduler, + SDARTokenDiffusionSchedulerOutput, TCDScheduler, + TokenDiffusionScheduler, UnCLIPScheduler, UniPCMultistepScheduler, VQDiffusionScheduler, @@ -1199,6 +1239,8 @@ CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline, CycleDiffusionPipeline, + DFlashPipeline, + DFlashPipelineOutput, EasyAnimateControlPipeline, EasyAnimateInpaintPipeline, EasyAnimatePipeline, @@ -1265,6 +1307,8 @@ 
LDMTextToImagePipeline, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, + LLaDA2Pipeline, + LLaDA2PipelineOutput, LongCatImageEditPipeline, LongCatImagePipeline, LTX2ImageToVideoPipeline, @@ -1310,6 +1354,8 @@ SanaSprintImg2ImgPipeline, SanaSprintPipeline, SanaVideoPipeline, + SDARPipeline, + SDARPipelineOutput, SemanticStableDiffusionPipeline, ShapEImg2ImgPipeline, ShapEPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b94319ffcbdc..db14c6f64f54 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -47,11 +47,15 @@ "AutoPipelineForInpainting", "AutoPipelineForText2Image", ] + _import_structure["block_refinement"] = ["BlockRefinementPipeline", "BlockRefinementPipelineOutput"] + _import_structure["block_token_diffusion"] = ["BlockTokenDiffusionPipeline", "BlockTokenDiffusionPipelineOutput"] _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] _import_structure["ddim"] = ["DDIMPipeline"] _import_structure["ddpm"] = ["DDPMPipeline"] _import_structure["dit"] = ["DiTPipeline"] + _import_structure["hybrid_token_diffusion"] = ["HybridTokenDiffusionPipeline"] + _import_structure["token_diffusion"] = ["TokenDiffusionPipeline", "TokenDiffusionPipelineOutput"] _import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"]) _import_structure["pipeline_utils"] = [ "AudioPipelineOutput", @@ -408,6 +412,9 @@ "Kandinsky5T2IPipeline", "Kandinsky5I2IPipeline", ] + _import_structure["dflash"] = ["DFlashPipeline", "DFlashPipelineOutput"] + _import_structure["sdar"] = ["SDARPipeline", "SDARPipelineOutput"] + _import_structure["llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"] _import_structure["z_image"] = [ "ZImageImg2ImgPipeline", "ZImagePipeline", @@ -547,12 +554,15 @@ AutoPipelineForInpainting, AutoPipelineForText2Image, ) + from .block_refinement import BlockRefinementPipeline, BlockRefinementPipelineOutput + from .block_token_diffusion import BlockTokenDiffusionPipeline, BlockTokenDiffusionPipelineOutput from .consistency_models import ConsistencyModelPipeline from .dance_diffusion import DanceDiffusionPipeline from .ddim import DDIMPipeline from .ddpm import DDPMPipeline from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline from .dit import DiTPipeline + from .hybrid_token_diffusion import HybridTokenDiffusionPipeline from .latent_diffusion import LDMSuperResolutionPipeline from .pipeline_utils import ( AudioPipelineOutput, @@ -560,6 +570,7 @@ ImagePipelineOutput, StableDiffusionMixin, ) + from .token_diffusion import TokenDiffusionPipeline, TokenDiffusionPipelineOutput try: if not (is_torch_available() and is_librosa_available()): @@ -654,6 +665,7 @@ VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, ) + from .dflash import DFlashPipeline, DFlashPipelineOutput from .easyanimate import ( EasyAnimateControlPipeline, EasyAnimateInpaintPipeline, @@ -730,6 +742,7 @@ LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, ) + from .llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline from .ltx import ( LTXConditionPipeline, @@ -792,6 +805,7 @@ SanaSprintPipeline, ) from .sana_video import SanaImageToVideoPipeline, SanaVideoPipeline + from .sdar import SDARPipeline, SDARPipelineOutput from .semantic_stable_diffusion import SemanticStableDiffusionPipeline 
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_audio import StableAudioPipeline, StableAudioProjectionModel diff --git a/src/diffusers/pipelines/block_refinement/__init__.py b/src/diffusers/pipelines/block_refinement/__init__.py new file mode 100644 index 000000000000..1eec2ee97e81 --- /dev/null +++ b/src/diffusers/pipelines/block_refinement/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .pipeline_block_refinement import BlockRefinementPipeline, BlockRefinementPipelineOutput + + +__all__ = ["BlockRefinementPipeline", "BlockRefinementPipelineOutput"] diff --git a/src/diffusers/pipelines/block_refinement/pipeline_block_refinement.py b/src/diffusers/pipelines/block_refinement/pipeline_block_refinement.py new file mode 100644 index 000000000000..9b5ea4695f4e --- /dev/null +++ b/src/diffusers/pipelines/block_refinement/pipeline_block_refinement.py @@ -0,0 +1,507 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
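+
+# Sampling sketch (summary of `BlockRefinementPipeline.__call__` below): the sequence is
+# initialized with `mask_token_id` after the prompt and refined block by block. Each refinement
+# step samples candidate tokens for the active block and commits either every candidate above the
+# confidence threshold or the top-k most confident ones, per the `_get_num_transfer_tokens` schedule.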
+ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import torch + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...utils import BaseOutput +from ..pipeline_utils import DiffusionPipeline + + +@dataclass +class BlockRefinementPipelineOutput(BaseOutput): + sequences: torch.LongTensor + texts: Optional[List[str]] = None + + +def _top_k_filtering(logits: torch.Tensor, top_k: Optional[int]) -> torch.Tensor: + if top_k is None or top_k <= 0: + return logits + if top_k >= logits.shape[-1]: + return logits + values, _ = torch.topk(logits, k=int(top_k), dim=-1) + min_keep = values[..., -1, None] + return logits.masked_fill(logits < min_keep, torch.finfo(logits.dtype).min) + + +def _top_p_filtering(logits: torch.Tensor, top_p: Optional[float]) -> torch.Tensor: + if top_p is None or top_p >= 1.0: + return logits + if not (0.0 < top_p <= 1.0): + raise ValueError(f"`top_p` must be in (0, 1], got {top_p}.") + + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + sorted_probs = torch.softmax(sorted_logits, dim=-1) + cumulative_probs = sorted_probs.cumsum(dim=-1) + + sorted_indices_to_remove = cumulative_probs > float(top_p) + sorted_indices_to_remove[..., 0] = 0 + + sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, torch.finfo(sorted_logits.dtype).min) + filtered = logits.scatter(-1, sorted_indices, sorted_logits) + return filtered + + +def _sample_with_temperature_topk_topp( + logits: torch.Tensor, + *, + temperature: float, + top_k: Optional[int], + top_p: Optional[float], + generator: Optional[torch.Generator], + use_multinomial: bool, +) -> tuple[torch.LongTensor, torch.Tensor]: + vocab_size = logits.shape[-1] + flat_logits = logits.reshape(-1, vocab_size) + + filtered = _top_k_filtering(flat_logits, top_k=top_k) + filtered = _top_p_filtering(filtered, top_p=top_p) + + if temperature < 0: + raise ValueError(f"`temperature` must be >= 0, got {temperature}.") + + scaled = filtered + if temperature > 0.0 and temperature != 1.0: + scaled = filtered / float(temperature) + + probs = torch.softmax(scaled.float(), dim=-1) + if use_multinomial: + token = torch.multinomial(probs, num_samples=1, generator=generator) + else: + token = scaled.argmax(dim=-1, keepdim=True) + token_prob = torch.gather(probs, -1, token) + + return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1]) + + +def _get_num_transfer_tokens(block_length: int, steps: int) -> torch.LongTensor: + if steps <= 0: + return torch.zeros((0,), dtype=torch.long) + base = int(block_length) // int(steps) + remainder = int(block_length) % int(steps) + out = torch.full((int(steps),), base, dtype=torch.long) + out[:remainder] += 1 + return out + + +class BlockRefinementPipeline(DiffusionPipeline): + """ + Block-wise iterative refinement pipeline for token generation. + + This pipeline maintains a template sequence filled with a `mask_token_id` and refines it in blocks. In each + refinement step, it samples candidate tokens for the active block and commits a subset based on confidence. + + The model is expected to accept an additive attention mask of shape `[batch, 1, seq, seq]` (0 for allowed, `-inf` + for disallowed) and `position_ids`, and to return logits of shape `[batch, seq, vocab_size]`. 
+ """ + + model: Any + tokenizer: Any + + _callback_tensor_inputs = ["cur_x", "x0", "x0_p", "transfer_index", "confidence", "active_block"] + + def __init__( + self, + model: Any, + tokenizer: Optional[Any] = None, + *, + gen_length: int = 128, + block_length: int = 32, + steps: int = 32, + temperature: float = 0.0, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + sampling_method: str = "auto", + threshold: float = 0.95, + minimal_topk: int = 1, + eos_early_stop: bool = False, + attention_mask_mode: str = "auto", + ): + super().__init__() + self.register_modules(model=model, tokenizer=tokenizer) + self.register_to_config( + gen_length=gen_length, + block_length=block_length, + steps=steps, + temperature=temperature, + top_p=top_p, + top_k=top_k, + sampling_method=sampling_method, + threshold=threshold, + minimal_topk=minimal_topk, + eos_early_stop=eos_early_stop, + attention_mask_mode=attention_mask_mode, + ) + + @property + def num_timesteps(self): + return self._num_timesteps + + def _model_forward_logits( + self, + input_ids: torch.LongTensor, + *, + attention_mask_4d: Optional[torch.Tensor], + attention_mask_2d: Optional[torch.Tensor], + position_ids: torch.LongTensor, + attention_mask_mode: str, + ) -> tuple[torch.Tensor, str]: + if attention_mask_mode not in {"auto", "4d", "2d", "none"}: + raise ValueError( + f"`attention_mask_mode` must be one of {{'auto','4d','2d','none'}}, got {attention_mask_mode!r}." + ) + + def _call(mask): + return self.model(input_ids, attention_mask=mask, position_ids=position_ids).logits + + if attention_mask_mode == "none": + return _call(None), "none" + if attention_mask_mode == "2d": + return _call(attention_mask_2d), "2d" + if attention_mask_mode == "4d": + return _call(attention_mask_4d), "4d" + + # auto: try 4d additive mask first, then fall back to 2d padding mask, then no mask. + try: + return _call(attention_mask_4d), "4d" + except (TypeError, ValueError, RuntimeError): + pass + try: + return _call(attention_mask_2d), "2d" + except (TypeError, ValueError, RuntimeError): + return _call(None), "none" + + def _build_block_attention_mask( + self, + *, + num_blocks: int, + block_length: int, + total_length: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + block_mask = torch.tril(torch.ones(num_blocks, num_blocks, device=device, dtype=torch.bool)) + attn = ( + block_mask.repeat_interleave(block_length, dim=0) + .repeat_interleave(block_length, dim=1) + .unsqueeze(0) + .unsqueeze(0) + ) + attn = attn[:, :, :total_length, :total_length] + return torch.where( + attn, + torch.zeros((), device=device, dtype=dtype), + torch.full((), float("-inf"), device=device, dtype=dtype), + ) + + def _encode_prompt( + self, + prompt: Optional[Union[str, List[str]]], + prompt_ids: Optional[torch.LongTensor], + *, + device: torch.device, + ) -> torch.LongTensor: + if prompt_ids is not None: + if prompt_ids.ndim == 1: + prompt_ids = prompt_ids.unsqueeze(0) + if prompt_ids.ndim != 2: + raise ValueError( + f"`prompt_ids` must have shape [prompt_len] or [batch, prompt_len], got {prompt_ids.shape}." 
+ ) + if prompt_ids.dtype != torch.long: + raise ValueError(f"`prompt_ids` must be int64 token IDs, got dtype={prompt_ids.dtype}.") + return prompt_ids.to(device=device) + + if prompt is None: + return torch.zeros((1, 0), device=device, dtype=torch.long) + if getattr(self, "tokenizer", None) is None: + raise ValueError("`prompt` requires a tokenizer, but no tokenizer was provided to the pipeline.") + + encoded = self.tokenizer(prompt, return_tensors="pt", padding=True) + return encoded["input_ids"].to(device=device) + + @torch.no_grad() + def __call__( + self, + *, + prompt: Optional[Union[str, List[str]]] = None, + prompt_ids: Optional[torch.LongTensor] = None, + gen_length: Optional[int] = None, + block_length: Optional[int] = None, + steps: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + sampling_method: Optional[str] = None, + threshold: Optional[float] = None, + minimal_topk: Optional[int] = None, + eos_early_stop: Optional[bool] = None, + eos_token_id: Optional[int] = None, + mask_token_id: Optional[int] = None, + attention_mask_mode: Optional[str] = None, + generator: Optional[torch.Generator] = None, + return_text: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + ) -> BlockRefinementPipelineOutput: + """ + Generate tokens with block-wise refinement. + + Args: + prompt (`str` or `List[str]`, *optional*): + Prompt text to encode with the tokenizer. + prompt_ids (`torch.LongTensor`, *optional*): + Pre-tokenized prompt IDs with shape `[prompt_len]` or `[batch, prompt_len]`. + gen_length (`int`, *optional*): + Number of tokens to generate. If `None`, uses `pipe.config.gen_length`. + block_length (`int`, *optional*): + Block size for refinement. If `None`, uses `pipe.config.block_length`. + steps (`int`, *optional*): + Refinement steps per block. If `None`, uses `pipe.config.steps`. + temperature (`float`, *optional*): + Sampling temperature. If `None`, uses `pipe.config.temperature`. + top_p (`float`, *optional*): + Nucleus sampling cutoff. If `None`, uses `pipe.config.top_p`. + top_k (`int`, *optional*): + Top-k sampling cutoff. If `None`, uses `pipe.config.top_k`. + sampling_method (`str`, *optional*): + Sampling method (`auto`, `greedy`, `multinomial`). If `None`, uses `pipe.config.sampling_method`. + threshold (`float`, *optional*): + Confidence threshold for committing tokens. If `None`, uses `pipe.config.threshold`. + minimal_topk (`int`, *optional*): + Minimum number of tokens to commit per step. If `None`, uses `pipe.config.minimal_topk`. + eos_early_stop (`bool`, *optional*): + Whether to stop after committing EOS in a block. If `None`, uses `pipe.config.eos_early_stop`. + eos_token_id (`int`, *optional*): + EOS token ID to use for early stopping. + mask_token_id (`int`, *optional*): + Mask token ID to use for the template. + attention_mask_mode (`str`, *optional*): + Attention mask mode (`auto`, `4d`, `2d`, `none`). If `None`, uses `pipe.config.attention_mask_mode`. + generator (`torch.Generator`, *optional*): + RNG for sampling. + return_text (`bool`, *optional*, defaults to `True`): + Whether to decode sequences into text when a tokenizer is available. 
+ callback_on_step_end (`Callable` or `PipelineCallback`, *optional*): + Callback executed after each refinement step with signature `callback_on_step_end(self, step: int, + timestep: int, callback_kwargs: Dict)`. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + Tensor keys to pass to the callback. Allowed keys: `cur_x`, `x0`, `x0_p`, `transfer_index`, + `confidence`, `active_block`. + """ + if gen_length is None: + gen_length = int(self.config.gen_length) + if block_length is None: + block_length = int(self.config.block_length) + if steps is None: + steps = int(self.config.steps) + if temperature is None: + temperature = float(self.config.temperature) + if top_p is None: + top_p = self.config.top_p + if top_k is None: + top_k = self.config.top_k + if sampling_method is None: + sampling_method = str(self.config.sampling_method) + if threshold is None: + threshold = float(self.config.threshold) + if minimal_topk is None: + minimal_topk = int(self.config.minimal_topk) + if eos_early_stop is None: + eos_early_stop = bool(self.config.eos_early_stop) + if attention_mask_mode is None: + attention_mask_mode = str(self.config.attention_mask_mode) + + if gen_length <= 0: + raise ValueError(f"`gen_length` must be > 0, got {gen_length}.") + if block_length <= 0: + raise ValueError(f"`block_length` must be > 0, got {block_length}.") + if steps <= 0: + raise ValueError(f"`steps` must be > 0, got {steps}.") + if minimal_topk <= 0: + raise ValueError(f"`minimal_topk` must be > 0, got {minimal_topk}.") + if not (0.0 <= threshold <= 1.0) and not (threshold > 1.0): + raise ValueError(f"`threshold` must be in [0, 1] (or > 1 to force top-k commits), got {threshold}.") + if sampling_method not in {"auto", "greedy", "multinomial"}: + raise ValueError( + f"`sampling_method` must be one of {{'auto','greedy','multinomial'}}, got {sampling_method!r}." 
+ ) + + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["cur_x"] + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + model_params = list(self.model.parameters()) if hasattr(self.model, "parameters") else [] + model_device = model_params[0].device if len(model_params) > 0 else torch.device("cpu") + + prompt_ids = self._encode_prompt(prompt, prompt_ids, device=model_device) + batch_size, prompt_length = prompt_ids.shape + + if eos_token_id is None: + eos_token_id = getattr(getattr(self, "tokenizer", None), "eos_token_id", None) + if mask_token_id is None: + mask_token_id = getattr(getattr(self, "tokenizer", None), "mask_token_id", None) + if mask_token_id is None: + raise ValueError("`mask_token_id` must be provided (or available on the tokenizer).") + + steps = min(int(steps), int(gen_length) // int(minimal_topk)) + + num_blocks = (prompt_length + int(gen_length) + int(block_length) - 1) // int(block_length) + total_length = int(num_blocks) * int(block_length) + + dtype = getattr(self.model, "dtype", torch.float32) + attn_dtype = torch.bfloat16 if dtype in (torch.bfloat16, torch.float16) else torch.float32 + attn_mask_4d = self._build_block_attention_mask( + num_blocks=num_blocks, + block_length=block_length, + total_length=total_length, + device=model_device, + dtype=attn_dtype, + ) + attn_mask_2d_full = torch.ones((batch_size, total_length), device=model_device, dtype=torch.long) + position_ids = ( + torch.arange(total_length, device=model_device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1) + ) + + x = torch.full((batch_size, total_length), int(mask_token_id), device=model_device, dtype=torch.long) + if prompt_length > 0: + x[:, :prompt_length] = prompt_ids.to(device=model_device) + + prefill_blocks = prompt_length // int(block_length) + self._num_timesteps = int(steps) * max(int(num_blocks) - int(prefill_blocks), 0) + transfer_schedule = _get_num_transfer_tokens(int(block_length), int(steps)).to(device=model_device) + + finished = torch.zeros((batch_size,), device=model_device, dtype=torch.bool) + resolved_attention_mode: str = str(attention_mask_mode) + + use_multinomial = sampling_method == "multinomial" or (sampling_method == "auto" and float(temperature) != 0.0) + global_step = 0 + + for num_block in range(int(prefill_blocks), int(num_blocks)): + current_window_end = (num_block + 1) * int(block_length) + cur_x = x[:, :current_window_end] + cur_attn_mask_4d = attn_mask_4d[:, :, :current_window_end, :current_window_end] + cur_attn_mask_2d = attn_mask_2d_full[:, :current_window_end] + cur_position_ids = position_ids[:, :current_window_end] + + for step_idx in range(int(steps)): + if finished.all(): + break + + active_block = cur_x[:, -int(block_length) :] == int(mask_token_id) + if active_block.sum() == 0: + break + + logits, resolved_attention_mode = self._model_forward_logits( + cur_x, + attention_mask_4d=cur_attn_mask_4d, + attention_mask_2d=cur_attn_mask_2d, + position_ids=cur_position_ids, + 
attention_mask_mode=resolved_attention_mode, + ) + block_logits = logits[:, -int(block_length) :, :] + + x0, x0_p = _sample_with_temperature_topk_topp( + block_logits, + temperature=float(temperature), + top_k=top_k, + top_p=top_p, + generator=generator, + use_multinomial=use_multinomial, + ) + + num_to_transfer = int(transfer_schedule[step_idx].item()) + transfer_index = torch.zeros_like(x0, dtype=torch.bool) + + confidence = torch.where( + active_block, x0_p.to(dtype=torch.float32), torch.full_like(x0_p, -torch.inf, dtype=torch.float32) + ) + + for b in range(batch_size): + if finished[b]: + continue + + high_conf = confidence[b] > float(threshold) + if high_conf.sum().item() >= num_to_transfer: + transfer_index[b] = high_conf + else: + k = min(num_to_transfer, int(active_block[b].sum().item())) + if k > 0: + _, idx = torch.topk(confidence[b], k=k) + transfer_index[b, idx] = True + + if transfer_index.any(): + updated = cur_x[:, -int(block_length) :].clone() + updated[transfer_index] = x0[transfer_index] + cur_x[:, -int(block_length) :] = updated + + if eos_early_stop and eos_token_id is not None: + for b in range(batch_size): + if finished[b]: + continue + eos_in_commits = (x0[b][transfer_index[b]] == int(eos_token_id)).any().item() + if not eos_in_commits: + continue + eos_pos = (cur_x[b] == int(eos_token_id)).nonzero(as_tuple=True) + if len(eos_pos[0]) == 0: + continue + eos_pos = int(eos_pos[0][0].item()) + if prompt_length >= eos_pos: + continue + if (cur_x[b, prompt_length:eos_pos] != int(mask_token_id)).all().item(): + finished[b] = True + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, global_step, step_idx, callback_kwargs) + cur_x = callback_outputs.pop("cur_x", cur_x) + + global_step += 1 + + x[:, :current_window_end] = cur_x + if eos_token_id is not None and (x[:, prompt_length:current_window_end] == int(eos_token_id)).any().item(): + if eos_early_stop: + break + + generated = x[:, : prompt_length + int(gen_length)] + sequences = generated[:, prompt_length:] + if eos_token_id is not None and batch_size == 1: + eos_positions = (sequences[0] == int(eos_token_id)).nonzero(as_tuple=True)[0] + if len(eos_positions) > 0: + sequences = sequences[:, : int(eos_positions[0].item()) + 1] + + texts = None + if return_text and getattr(self, "tokenizer", None) is not None: + texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) + + return BlockRefinementPipelineOutput(sequences=sequences.to(device=model_device), texts=texts) diff --git a/src/diffusers/pipelines/block_token_diffusion/__init__.py b/src/diffusers/pipelines/block_token_diffusion/__init__.py new file mode 100644 index 000000000000..ab331c11f765 --- /dev/null +++ b/src/diffusers/pipelines/block_token_diffusion/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
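The `__call__` loop that closes out `BlockRefinementPipeline` above commits tokens by confidence: when enough active positions beat `threshold`, all of them are committed at once; otherwise the schedule's `num_to_transfer` most confident positions are committed. A minimal, self-contained sketch of that commit rule on dummy tensors (the names below are illustrative, not diffusers API):

```python
# Sketch of the confidence-based commit rule from the refinement loop above.
# All tensors are dummies; none of these names come from the diffusers API.
import torch

torch.manual_seed(0)
block_length, vocab_size = 8, 50
probs = torch.softmax(torch.randn(block_length, vocab_size), dim=-1)
x0_p = probs.max(dim=-1).values                       # confidence of the best token per position
active = torch.ones(block_length, dtype=torch.bool)   # positions still holding a mask token

threshold, num_to_transfer = 0.95, 2
confidence = torch.where(active, x0_p, torch.full_like(x0_p, float("-inf")))

high_conf = confidence > threshold
if int(high_conf.sum()) >= num_to_transfer:
    transfer_index = high_conf                         # commit everything above the threshold
else:
    transfer_index = torch.zeros_like(active)
    k = min(num_to_transfer, int(active.sum()))
    if k > 0:
        _, idx = torch.topk(confidence, k=k)
        transfer_index[idx] = True                     # otherwise fall back to the top-k schedule
print(transfer_index.nonzero(as_tuple=True)[0])
```

Setting `threshold` above 1.0 disables the first branch entirely and forces the schedule-driven top-k commits, which is what the earlier validation message alludes to.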
+ +from .pipeline_block_token_diffusion import BlockTokenDiffusionPipeline, BlockTokenDiffusionPipelineOutput + + +__all__ = ["BlockTokenDiffusionPipeline", "BlockTokenDiffusionPipelineOutput"] diff --git a/src/diffusers/pipelines/block_token_diffusion/pipeline_block_token_diffusion.py b/src/diffusers/pipelines/block_token_diffusion/pipeline_block_token_diffusion.py new file mode 100644 index 000000000000..43f6a71621a8 --- /dev/null +++ b/src/diffusers/pipelines/block_token_diffusion/pipeline_block_token_diffusion.py @@ -0,0 +1,281 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...utils import BaseOutput +from ..pipeline_utils import DiffusionPipeline + + +@dataclass +class BlockTokenDiffusionPipelineOutput(BaseOutput): + sequences: torch.LongTensor + texts: Optional[List[str]] = None + + +def _top_p_filtering(logits: torch.Tensor, top_p: float) -> torch.Tensor: + if top_p >= 1.0: + return logits + if not (0.0 < top_p <= 1.0): + raise ValueError(f"`top_p` must be in (0, 1], got {top_p}.") + + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + sorted_probs = torch.softmax(sorted_logits, dim=-1) + cumulative_probs = sorted_probs.cumsum(dim=-1) + + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 0] = 0 + + sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, torch.finfo(sorted_logits.dtype).min) + filtered = logits.scatter(-1, sorted_indices, sorted_logits) + return filtered + + +class BlockTokenDiffusionPipeline(DiffusionPipeline): + """ + Block-wise token diffusion sampling pipeline. + + Compared to `TokenDiffusionPipeline`, this pipeline updates the sequence in blocks. Only the current block's + positions are allowed to change during the inner denoising loop. 
+ """ + + model: Any + tokenizer: Any + scheduler: Any + + _callback_tensor_inputs = ["input_ids", "logits", "block_mask"] + + def __init__( + self, + model: Any, + scheduler: Any, + tokenizer: Optional[Any] = None, + *, + seq_len: int = 64, + block_size: int = 32, + num_inference_steps: int = 64, + inject_start_token: bool = False, + top_p: float = 1.0, + ): + super().__init__() + self.register_modules(model=model, scheduler=scheduler, tokenizer=tokenizer) + self.register_to_config( + seq_len=seq_len, + block_size=block_size, + num_inference_steps=num_inference_steps, + inject_start_token=inject_start_token, + top_p=top_p, + ) + + @property + def num_timesteps(self): + return self._num_timesteps + + def _resolve_start_token_id(self) -> Optional[int]: + tok = getattr(self, "tokenizer", None) + if tok is None: + return None + for attr in ("bos_token_id", "cls_token_id"): + token_id = getattr(tok, attr, None) + if token_id is not None: + return int(token_id) + return None + + def _normalize_prefix_ids( + self, prefix_ids: torch.LongTensor, batch_size: int, device: torch.device + ) -> torch.LongTensor: + if prefix_ids.ndim == 1: + prefix_ids = prefix_ids.unsqueeze(0) + if prefix_ids.ndim != 2: + raise ValueError( + f"`prefix_ids` must have shape [prefix_len] or [batch, prefix_len], got {prefix_ids.shape}." + ) + if prefix_ids.shape[0] not in (1, batch_size): + raise ValueError( + f"`prefix_ids` batch dim must be 1 or batch_size={batch_size}, got {prefix_ids.shape[0]}." + ) + if prefix_ids.dtype != torch.long: + raise ValueError(f"`prefix_ids` must be int64 token IDs, got dtype={prefix_ids.dtype}.") + prefix_ids = prefix_ids.to(device=device) + if prefix_ids.shape[0] == 1 and batch_size > 1: + prefix_ids = prefix_ids.expand(batch_size, -1) + return prefix_ids + + def _init_latents( + self, + batch_size: int, + seq_len: int, + *, + generator: Optional[torch.Generator], + device: torch.device, + ) -> torch.LongTensor: + mask_token_id = getattr(self.scheduler, "mask_token_id", None) + if mask_token_id is None: + raise ValueError("Scheduler must define `mask_token_id` for block diffusion sampling.") + return torch.full((batch_size, seq_len), int(mask_token_id), device=device, dtype=torch.long) + + @torch.no_grad() + def __call__( + self, + *, + batch_size: int = 1, + seq_len: Optional[int] = None, + block_size: Optional[int] = None, + num_inference_steps: Optional[int] = None, + generator: Optional[torch.Generator] = None, + prefix_ids: Optional[torch.LongTensor] = None, + infill_mask: Optional[torch.BoolTensor] = None, + inject_start_token: Optional[bool] = None, + top_p: Optional[float] = None, + return_text: bool = True, + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + **model_kwargs, + ) -> Union[BlockTokenDiffusionPipelineOutput, Tuple[torch.LongTensor, Optional[List[str]]]]: + if seq_len is None: + seq_len = int(self.config.seq_len) + if block_size is None: + block_size = int(self.config.block_size) + if num_inference_steps is None: + num_inference_steps = int(self.config.num_inference_steps) + if inject_start_token is None: + inject_start_token = bool(self.config.inject_start_token) + if top_p is None: + top_p = float(self.config.top_p) + + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = 
callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["input_ids"] + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + device = self._execution_device + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + input_ids = self._init_latents(batch_size, seq_len, generator=generator, device=device) + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + + fixed_mask = None + fixed_values = None + if infill_mask is not None: + if infill_mask.shape != (batch_size, seq_len): + raise ValueError( + f"`infill_mask` must have shape {(batch_size, seq_len)}, got {tuple(infill_mask.shape)}." + ) + fixed_mask = (~infill_mask.to(device=device)).to(dtype=torch.bool) + fixed_values = input_ids.clone() + + if prefix_ids is not None: + prefix_ids = self._normalize_prefix_ids(prefix_ids, batch_size=batch_size, device=device) + prefix_len = prefix_ids.shape[1] + if prefix_len > seq_len: + raise ValueError(f"`prefix_ids` length {prefix_len} must be <= seq_len={seq_len}.") + + input_ids[:, :prefix_len] = prefix_ids + if fixed_mask is None: + fixed_mask = torch.zeros((batch_size, seq_len), device=device, dtype=torch.bool) + fixed_values = input_ids.clone() + fixed_mask[:, :prefix_len] = True + fixed_values[:, :prefix_len] = prefix_ids + + start_token_id = self._resolve_start_token_id() + if inject_start_token and start_token_id is not None: + input_ids[:, 0] = start_token_id + if fixed_mask is None: + fixed_mask = torch.zeros((batch_size, seq_len), device=device, dtype=torch.bool) + fixed_values = input_ids.clone() + fixed_mask[:, 0] = True + fixed_values[:, 0] = start_token_id + + if block_size <= 0 or block_size > seq_len: + raise ValueError(f"`block_size` must be in [1, seq_len], got block_size={block_size}, seq_len={seq_len}.") + + num_blocks = (seq_len + block_size - 1) // block_size + self._num_timesteps = len(timesteps) * int(num_blocks) + global_step = 0 + for block_idx in range(num_blocks): + start = block_idx * block_size + end = min((block_idx + 1) * block_size, seq_len) + + block_mask = torch.zeros((batch_size, seq_len), device=device, dtype=torch.bool) + block_mask[:, start:end] = True + if fixed_mask is not None: + block_mask = block_mask & (~fixed_mask) + + if not torch.any(block_mask): + continue + + input_ids = torch.where(block_mask, int(self.scheduler.mask_token_id), input_ids) + + for step_idx, t in enumerate(timesteps): + out = self.model(input_ids=input_ids, attention_mask=attention_mask, **model_kwargs) + logits = getattr(out, "logits", None) + if logits is None: + logits = out[0] + + if top_p < 1.0: + logits_block = logits[block_mask].view(-1, logits.shape[-1]) + logits_block = _top_p_filtering(logits_block, top_p=top_p) + logits = logits.clone() + logits[block_mask] = logits_block.view(-1, logits.shape[-1]) + + input_ids = self.scheduler.step( + logits, + t, + input_ids, + generator=generator, + return_dict=True, + block_mask=block_mask, + ).prev_sample + + if fixed_mask is not None: + input_ids = torch.where(fixed_mask, fixed_values, input_ids) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in 
callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, global_step, t, callback_kwargs) + input_ids = callback_outputs.pop("input_ids", input_ids) + + global_step += 1 + + texts = None + if return_text and getattr(self, "tokenizer", None) is not None: + texts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True) + + if not return_dict: + return (input_ids, texts) + return BlockTokenDiffusionPipelineOutput(sequences=input_ids, texts=texts) + + +__all__ = ["BlockTokenDiffusionPipeline", "BlockTokenDiffusionPipelineOutput"] diff --git a/src/diffusers/pipelines/dflash/__init__.py b/src/diffusers/pipelines/dflash/__init__.py new file mode 100644 index 000000000000..c5d0f5fae4cd --- /dev/null +++ b/src/diffusers/pipelines/dflash/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_dflash"] = ["DFlashPipeline", "DFlashPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_dflash import DFlashPipeline, DFlashPipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/dflash/pipeline_dflash.py b/src/diffusers/pipelines/dflash/pipeline_dflash.py new file mode 100644 index 000000000000..ae7f11f3b628 --- /dev/null +++ b/src/diffusers/pipelines/dflash/pipeline_dflash.py @@ -0,0 +1,471 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
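`BlockTokenDiffusionPipeline` above ships without an example docstring. A hypothetical usage sketch in the spirit of the other pipelines' examples, where `model`, `tokenizer`, and `scheduler` are placeholders (any token denoiser plus a discrete scheduler that exposes `mask_token_id` should fit the constructor shown above):

```python
# Hypothetical usage of BlockTokenDiffusionPipeline; `model`, `tokenizer`, and
# `scheduler` are placeholders, not objects pinned down by this PR.
import torch
from diffusers import BlockTokenDiffusionPipeline

pipe = BlockTokenDiffusionPipeline(
    model=model,          # token denoiser returning logits over the vocabulary
    scheduler=scheduler,  # discrete scheduler that defines `mask_token_id`
    tokenizer=tokenizer,
    seq_len=128,
    block_size=32,
    num_inference_steps=64,
    top_p=0.9,
)
out = pipe(batch_size=1, generator=torch.Generator().manual_seed(0))
print(out.texts[0] if out.texts else out.sequences)
```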
+ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, DynamicCache + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...schedulers import DFlashTokenDiffusionScheduler +from ...utils import BaseOutput, logging, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import DFlashPipeline + + >>> draft_id = "z-lab/Qwen3-8B-DFlash-b16" + >>> target_id = "Qwen/Qwen3-8B" + >>> pipe = DFlashPipeline.from_pretrained( + ... draft_model_id=draft_id, + ... target_model_id=target_id, + ... draft_model_kwargs={"trust_remote_code": True, "dtype": torch.bfloat16}, + ... target_model_kwargs={"dtype": torch.bfloat16}, + ... ) + >>> out = pipe(prompt="How many positive whole-number divisors does 196 have?") + >>> print(out.texts[0]) + ``` +""" + + +@dataclass +class DFlashPipelineOutput(BaseOutput): + sequences: torch.LongTensor + texts: Optional[List[str]] = None + + +def _build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> List[int]: + if num_draft_layers == 1: + return [int(num_target_layers // 2)] + start = 1 + end = int(num_target_layers) - 3 + span = end - start + return [int(round(start + (i * span) / (num_draft_layers - 1))) for i in range(int(num_draft_layers))] + + +def _extract_context_feature(hidden_states: List[torch.Tensor], layer_ids: List[int]) -> torch.Tensor: + offset = 1 + selected_states = [hidden_states[layer_id + offset] for layer_id in layer_ids] + return torch.cat(selected_states, dim=-1) + + +class DFlashPipeline(DiffusionPipeline): + r""" + Block diffusion pipeline for speculative decoding with a DFlash draft model and a target causal LM. 
+ """ + + draft_model: torch.nn.Module + target_model: torch.nn.Module + tokenizer: Optional[object] + scheduler: DFlashTokenDiffusionScheduler + _callback_tensor_inputs = ["block_output_ids", "draft_logits", "accepted_length", "next_token", "output_ids"] + + def __init__( + self, + draft_model: torch.nn.Module, + target_model: torch.nn.Module, + tokenizer: Optional[object] = None, + scheduler: Optional[DFlashTokenDiffusionScheduler] = None, + *, + max_new_tokens: int = 2048, + temperature: float = 0.0, + use_chat_template: bool = True, + add_generation_prompt: bool = True, + ): + super().__init__() + if scheduler is None: + scheduler = DFlashTokenDiffusionScheduler() + self.register_modules( + draft_model=draft_model, target_model=target_model, tokenizer=tokenizer, scheduler=scheduler + ) + self.register_to_config( + max_new_tokens=max_new_tokens, + temperature=temperature, + use_chat_template=use_chat_template, + add_generation_prompt=add_generation_prompt, + ) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[str] = None, + *, + draft_model_id: Optional[str] = None, + target_model_id: Optional[str] = None, + tokenizer_id: Optional[str] = None, + mask_token: Optional[str] = "<|MASK|>", + scheduler: Optional[DFlashTokenDiffusionScheduler] = None, + draft_model_kwargs: Optional[Dict[str, object]] = None, + target_model_kwargs: Optional[Dict[str, object]] = None, + tokenizer_kwargs: Optional[Dict[str, object]] = None, + **pipeline_kwargs, + ) -> "DFlashPipeline": + if draft_model_id is None and target_model_id is None and pretrained_model_name_or_path is not None: + return super().from_pretrained(pretrained_model_name_or_path, **pipeline_kwargs) + + if draft_model_id is None: + if pretrained_model_name_or_path is None: + raise ValueError("Provide `draft_model_id` or `pretrained_model_name_or_path`.") + draft_model_id = str(pretrained_model_name_or_path) + if target_model_id is None: + raise ValueError("`target_model_id` must be provided when loading draft/target models separately.") + + draft_model_kwargs = dict(draft_model_kwargs or {}) + draft_model_kwargs.setdefault("trust_remote_code", True) + target_model_kwargs = dict(target_model_kwargs or {}) + tokenizer_kwargs = dict(tokenizer_kwargs or {}) + + draft = AutoModel.from_pretrained(draft_model_id, **draft_model_kwargs) + target = AutoModelForCausalLM.from_pretrained(target_model_id, **target_model_kwargs) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_id or target_model_id, **tokenizer_kwargs) + + if mask_token is not None and tokenizer.mask_token_id is None: + tokenizer.add_special_tokens({"mask_token": mask_token}) + + return cls( + draft_model=draft, + target_model=target, + tokenizer=tokenizer, + scheduler=scheduler, + **pipeline_kwargs, + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + *, + messages: Optional[List[Dict[str, str]]] = None, + input_ids: Optional[torch.LongTensor] = None, + max_new_tokens: Optional[int] = None, + temperature: Optional[float] = None, + stop_token_ids: Optional[List[int]] = None, + mask_token_id: Optional[int] = None, + use_chat_template: Optional[bool] = None, + add_generation_prompt: Optional[bool] = None, + chat_template_kwargs: Optional[Dict[str, object]] = None, + return_text: bool = True, + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + 
callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + ) -> Union[DFlashPipelineOutput, Tuple[torch.LongTensor, Optional[List[str]]]]: + """ + Generate text using block-diffusion speculative decoding. + + Examples: + """ + if max_new_tokens is None: + max_new_tokens = int(self.config.max_new_tokens) + if temperature is None: + temperature = float(self.config.temperature) + if use_chat_template is None: + use_chat_template = bool(self.config.use_chat_template) + if add_generation_prompt is None: + add_generation_prompt = bool(self.config.add_generation_prompt) + + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["block_output_ids"] + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + input_ids = self._prepare_input_ids( + prompt=prompt, + messages=messages, + input_ids=input_ids, + use_chat_template=use_chat_template, + add_generation_prompt=add_generation_prompt, + chat_template_kwargs=chat_template_kwargs, + ) + + if input_ids.shape[0] != 1: + raise ValueError("DFlashPipeline currently supports batch_size=1 input_ids.") + + target_params = list(self.target_model.parameters()) if hasattr(self.target_model, "parameters") else [] + device = target_params[0].device if len(target_params) > 0 else torch.device("cpu") + input_ids = input_ids.to(device=device) + draft_params = list(self.draft_model.parameters()) if hasattr(self.draft_model, "parameters") else [] + draft_device = draft_params[0].device if len(draft_params) > 0 else device + if draft_device != device: + logger.warning( + "Draft model is on %s while target model is on %s. 
For best performance, place both on the same device.", + draft_device, + device, + ) + + if mask_token_id is None: + mask_token_id = getattr(getattr(self, "tokenizer", None), "mask_token_id", None) + if mask_token_id is None: + raise ValueError("`mask_token_id` must be provided (or available on the tokenizer).") + + if stop_token_ids is None: + eos_token_id = getattr(getattr(self, "tokenizer", None), "eos_token_id", None) + stop_token_ids = [int(eos_token_id)] if eos_token_id is not None else None + if stop_token_ids is not None: + stop_token_ids = [int(token_id) for token_id in stop_token_ids] + + self.draft_model.eval() + self.target_model.eval() + self.scheduler.set_timesteps(1, device=device) + + block_size = self._get_block_size() + target_layer_ids = self._get_target_layer_ids() + input_embeddings = self._get_target_input_embeddings() + output_embeddings = self._get_target_output_embeddings() + + num_input_tokens = input_ids.shape[1] + max_length = num_input_tokens + int(max_new_tokens) + output_ids = torch.full( + (1, max_length + int(block_size)), + int(mask_token_id), + dtype=torch.long, + device=device, + ) + position_ids = torch.arange(output_ids.shape[1], device=device).unsqueeze(0) + + past_key_values_target = DynamicCache() + past_key_values_draft = DynamicCache() + + output = self._target_forward( + input_ids=input_ids, + position_ids=position_ids[:, :num_input_tokens], + past_key_values=past_key_values_target, + output_hidden_states=True, + logits_to_keep=1, + ) + output_ids[:, :num_input_tokens] = input_ids + output_ids[:, num_input_tokens : num_input_tokens + 1] = self.scheduler.sample( + output.logits[:, -1:], temperature=temperature + ) + target_hidden = _extract_context_feature(output.hidden_states, target_layer_ids) + + start = num_input_tokens + global_step = 0 + stop_tensor = None + if stop_token_ids is not None: + stop_tensor = torch.tensor(stop_token_ids, device=device, dtype=torch.long) + + while start < max_length: + block_output_ids = output_ids[:, start : start + int(block_size)].clone() + block_position_ids = position_ids[:, start : start + int(block_size)] + noise_embedding = input_embeddings(block_output_ids) + draft_hidden = self.draft_model( + target_hidden=target_hidden, + noise_embedding=noise_embedding, + position_ids=position_ids[:, past_key_values_draft.get_seq_length() : start + int(block_size)], + past_key_values=past_key_values_draft, + use_cache=True, + is_causal=False, + ) + if not torch.is_tensor(draft_hidden): + draft_hidden = getattr(draft_hidden, "last_hidden_state", draft_hidden[0]) + draft_logits = output_embeddings(draft_hidden[:, -int(block_size) + 1 :, :]) + past_key_values_draft.crop(start) + block_output_ids[:, 1:] = self.scheduler.sample(draft_logits, temperature=temperature) + + output = self._target_forward( + input_ids=block_output_ids, + position_ids=block_position_ids, + past_key_values=past_key_values_target, + output_hidden_states=True, + logits_to_keep=None, + ) + step_output = self.scheduler.step( + block_output_ids, output.logits, temperature=temperature, return_dict=True + ) + accepted_length = step_output.accepted_length + next_token = step_output.next_token + acceptance_length = int(step_output.accepted_length[0].item()) + output_ids[:, start : start + acceptance_length + 1] = block_output_ids[:, : acceptance_length + 1] + output_ids[:, start + acceptance_length + 1] = step_output.next_token + start += acceptance_length + 1 + past_key_values_target.crop(start) + target_hidden = 
_extract_context_feature(output.hidden_states, target_layer_ids)[ + :, : acceptance_length + 1, : + ] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, global_step, 0, callback_kwargs) + output_ids = callback_outputs.pop("output_ids", output_ids) + global_step += 1 + + if stop_tensor is not None and torch.isin(output_ids[:, num_input_tokens:], stop_tensor).any(): + break + + output_ids = output_ids[:, :max_length] + output_ids = output_ids[:, output_ids[0] != int(mask_token_id)] + if stop_tensor is not None: + stop_positions = torch.isin(output_ids[0, num_input_tokens:], stop_tensor).nonzero(as_tuple=True)[0] + if stop_positions.numel() > 0: + output_ids = output_ids[:, : num_input_tokens + int(stop_positions[0].item()) + 1] + + prompt_len = input_ids.shape[1] + sequences = output_ids[:, prompt_len:] + + texts = None + if return_text and getattr(self, "tokenizer", None) is not None: + texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) + + if not return_dict: + return sequences, texts + return DFlashPipelineOutput(sequences=sequences, texts=texts) + + def _get_block_size(self) -> int: + block_size = getattr(self.draft_model, "block_size", None) + if block_size is None: + block_size = getattr(getattr(self.draft_model, "config", None), "block_size", None) + if block_size is None: + raise ValueError("`draft_model` must define `block_size` on the module or its config.") + return int(block_size) + + def _get_target_layer_ids(self) -> List[int]: + layer_ids = getattr(self.draft_model, "target_layer_ids", None) + if layer_ids is not None: + return list(layer_ids) + cfg = getattr(self.draft_model, "config", None) + num_target_layers = getattr(cfg, "num_target_layers", None) + num_hidden_layers = getattr(cfg, "num_hidden_layers", None) + if num_target_layers is None or num_hidden_layers is None: + raise ValueError("`draft_model` must define `target_layer_ids` or expose `num_target_layers` in config.") + return _build_target_layer_ids(int(num_target_layers), int(num_hidden_layers)) + + def _get_target_input_embeddings(self) -> torch.nn.Module: + embeddings = self.target_model.get_input_embeddings() + if embeddings is None: + base_model = getattr(self.target_model, "model", None) + embeddings = getattr(base_model, "embed_tokens", None) + if embeddings is None: + raise ValueError("`target_model` must provide input embeddings for DFlash decoding.") + return embeddings + + def _get_target_output_embeddings(self) -> torch.nn.Module: + embeddings = self.target_model.get_output_embeddings() + if embeddings is None: + embeddings = getattr(self.target_model, "lm_head", None) + if embeddings is None: + raise ValueError("`target_model` must provide output embeddings for DFlash decoding.") + return embeddings + + def _target_forward( + self, + *, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + past_key_values: DynamicCache, + output_hidden_states: bool, + logits_to_keep: Optional[int], + ): + kwargs = { + "input_ids": input_ids, + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": True, + "output_hidden_states": output_hidden_states, + } + if logits_to_keep is not None: + try: + return self.target_model(**kwargs, logits_to_keep=logits_to_keep) + except TypeError: + pass + return self.target_model(**kwargs) + + def _prepare_input_ids( + self, + *, + prompt: Optional[Union[str, List[str]]], + 
messages: Optional[List[Dict[str, str]]], + input_ids: Optional[torch.LongTensor], + use_chat_template: bool, + add_generation_prompt: bool, + chat_template_kwargs: Optional[Dict[str, object]], + ) -> torch.LongTensor: + if input_ids is not None: + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + if input_ids.ndim != 2: + raise ValueError(f"`input_ids` must be 2D, got shape {tuple(input_ids.shape)}.") + if input_ids.dtype != torch.long: + raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.") + return input_ids + + if self.tokenizer is None: + raise ValueError("Tokenizer is required when `input_ids` is not provided.") + + if messages is not None and prompt is not None: + raise ValueError("Provide either `prompt` or `messages`, not both.") + if messages is None and prompt is None: + raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.") + + chat_template_kwargs = chat_template_kwargs or {} + + def _extract_input_ids(encoded): + if isinstance(encoded, dict) and "input_ids" in encoded: + return encoded["input_ids"] + if hasattr(encoded, "input_ids"): + return encoded.input_ids + return encoded + + if messages is not None: + encoded = self.tokenizer.apply_chat_template( + messages, + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + **chat_template_kwargs, + ) + return _extract_input_ids(encoded) + + if use_chat_template and getattr(self.tokenizer, "chat_template", None): + if isinstance(prompt, list): + raise ValueError("`prompt` must be a string when `use_chat_template=True`.") + encoded = self.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + **chat_template_kwargs, + ) + return _extract_input_ids(encoded) + + encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list)) + return _extract_input_ids(encoded) + + +__all__ = ["DFlashPipeline", "DFlashPipelineOutput"] diff --git a/src/diffusers/pipelines/hybrid_token_diffusion/__init__.py b/src/diffusers/pipelines/hybrid_token_diffusion/__init__.py new file mode 100644 index 000000000000..9fb3efbe4709 --- /dev/null +++ b/src/diffusers/pipelines/hybrid_token_diffusion/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .pipeline_hybrid_token_diffusion import HybridTokenDiffusionPipeline + + +__all__ = ["HybridTokenDiffusionPipeline"] diff --git a/src/diffusers/pipelines/hybrid_token_diffusion/pipeline_hybrid_token_diffusion.py b/src/diffusers/pipelines/hybrid_token_diffusion/pipeline_hybrid_token_diffusion.py new file mode 100644 index 000000000000..bdc372960dee --- /dev/null +++ b/src/diffusers/pipelines/hybrid_token_diffusion/pipeline_hybrid_token_diffusion.py @@ -0,0 +1,24 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from ..token_diffusion.pipeline_token_diffusion import TokenDiffusionPipeline + + +class HybridTokenDiffusionPipeline(TokenDiffusionPipeline): + """Alias of `TokenDiffusionPipeline` for hybrid-transition schedulers.""" + + +__all__ = ["HybridTokenDiffusionPipeline"] diff --git a/src/diffusers/pipelines/llada2/__init__.py b/src/diffusers/pipelines/llada2/__init__.py new file mode 100644 index 000000000000..45a02e6851e2 --- /dev/null +++ b/src/diffusers/pipelines/llada2/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/llada2/pipeline_llada2.py b/src/diffusers/pipelines/llada2/pipeline_llada2.py new file mode 100644 index 000000000000..acdb686b3aaf --- /dev/null +++ b/src/diffusers/pipelines/llada2/pipeline_llada2.py @@ -0,0 +1,182 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
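The `_prepare_input_ids` helper in the DFlash pipeline above resolves prompts in a fixed order, and the LLaDA2 and SDAR pipelines below follow the same pattern: explicit `input_ids` win, `messages` or a chat-template-capable tokenizer go through `apply_chat_template`, and a plain prompt falls back to the tokenizer itself. A standalone sketch of that fallback, where the checkpoint name is only an example:

```python
# Sketch of the prompt-resolution fallback shared by these pipelines.
# "Qwen/Qwen3-8B" is just an example; any chat-capable tokenizer works.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
prompt = "How many positive whole-number divisors does 196 have?"

if getattr(tokenizer, "chat_template", None):
    encoded = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    input_ids = encoded["input_ids"]
else:
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

print(input_ids.shape)  # (1, prompt_length)
```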
+ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch + +from ...utils import BaseOutput, logging, replace_example_docstring +from ..block_refinement import BlockRefinementPipeline, BlockRefinementPipelineOutput + + +logger = logging.get_logger(__name__) + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from transformers import AutoModelForCausalLM, AutoTokenizer + >>> from diffusers import LLaDA2Pipeline + + >>> model_id = "inclusionAI/LLaDA2.0-mini" + >>> model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16) + >>> tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + >>> model = model.to("cuda") + + >>> pipe = LLaDA2Pipeline(model=model, tokenizer=tokenizer) + >>> output = pipe(prompt="What is the meaning of life?", gen_length=256) + >>> print(output.texts[0]) + ``` +""" + + +@dataclass +class LLaDA2PipelineOutput(BaseOutput): + sequences: torch.LongTensor + texts: Optional[List[str]] = None + + +class LLaDA2Pipeline(BlockRefinementPipeline): + r""" + Adapter pipeline for LLaDA2-style discrete diffusion generation. + + This pipeline subclasses [`BlockRefinementPipeline`] and reuses its sampling loop. It only adapts prompt + preparation (including chat templates) and output formatting. + """ + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + *, + messages: Optional[List[Dict[str, str]]] = None, + input_ids: Optional[torch.LongTensor] = None, + use_chat_template: bool = True, + add_generation_prompt: bool = True, + gen_length: int = 2048, + block_length: int = 32, + steps: int = 32, + temperature: float = 0.0, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + sampling_method: str = "multinomial", + threshold: float = 0.95, + minimal_topk: int = 1, + eos_early_stop: bool = False, + eos_token_id: Optional[int] = None, + mask_token_id: Optional[int] = None, + attention_mask_mode: str = "4d", + generator: Optional[torch.Generator] = None, + return_text: bool = True, + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + ) -> Union[LLaDA2PipelineOutput, Tuple[torch.LongTensor, Optional[List[str]]]]: + """ + Generate text with block-wise refinement. 
+ + Examples: + """ + prompt_ids = self._prepare_prompt_ids( + prompt=prompt, + messages=messages, + input_ids=input_ids, + use_chat_template=use_chat_template, + add_generation_prompt=add_generation_prompt, + ) + + output: BlockRefinementPipelineOutput = super().__call__( + prompt_ids=prompt_ids, + gen_length=gen_length, + block_length=block_length, + steps=steps, + temperature=temperature, + top_p=top_p, + top_k=top_k, + sampling_method=sampling_method, + threshold=threshold, + minimal_topk=minimal_topk, + eos_early_stop=eos_early_stop, + eos_token_id=eos_token_id, + mask_token_id=mask_token_id, + attention_mask_mode=attention_mask_mode, + generator=generator, + return_text=return_text, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + if not return_dict: + return output.sequences, output.texts + return LLaDA2PipelineOutput(sequences=output.sequences, texts=output.texts) + + def _prepare_prompt_ids( + self, + *, + prompt: Optional[Union[str, List[str]]], + messages: Optional[List[Dict[str, str]]], + input_ids: Optional[torch.LongTensor], + use_chat_template: bool, + add_generation_prompt: bool, + ) -> Optional[torch.LongTensor]: + if input_ids is not None: + return input_ids + + if self.tokenizer is None: + if prompt is None and messages is None: + return None + raise ValueError("Tokenizer is required to encode `prompt` or `messages`.") + + def _extract_input_ids(encoded): + if isinstance(encoded, dict) and "input_ids" in encoded: + return encoded["input_ids"] + if hasattr(encoded, "input_ids"): + return encoded.input_ids + return encoded + + if messages is not None: + encoded = self.tokenizer.apply_chat_template( + messages, + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + ) + return _extract_input_ids(encoded) + + if prompt is None: + return None + + if use_chat_template and getattr(self.tokenizer, "chat_template", None): + if isinstance(prompt, list): + raise ValueError("`prompt` must be a string when `use_chat_template=True`.") + encoded = self.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + ) + return _extract_input_ids(encoded) + + encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list)) + return _extract_input_ids(encoded) + + +__all__ = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"] diff --git a/src/diffusers/pipelines/sdar/__init__.py b/src/diffusers/pipelines/sdar/__init__.py new file mode 100644 index 000000000000..13f8e30c9962 --- /dev/null +++ b/src/diffusers/pipelines/sdar/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_sdar"] = ["SDARPipeline", "SDARPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise 
OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_sdar import SDARPipeline, SDARPipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/sdar/pipeline_sdar.py b/src/diffusers/pipelines/sdar/pipeline_sdar.py new file mode 100644 index 000000000000..fbd35e95adea --- /dev/null +++ b/src/diffusers/pipelines/sdar/pipeline_sdar.py @@ -0,0 +1,516 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import DynamicCache + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...schedulers import SDARTokenDiffusionScheduler +from ...utils import BaseOutput, logging, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from transformers import AutoModelForCausalLM, AutoTokenizer + >>> from diffusers import SDARPipeline + + >>> model_id = "JetLM/SDAR-1.7B-Chat" + >>> model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, dtype=torch.bfloat16) + >>> tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + >>> tokenizer.add_special_tokens({"mask_token": "<|MASK|>"}) + + >>> pipe = SDARPipeline(model=model, tokenizer=tokenizer) + >>> out = pipe(prompt="Explain what reinforcement learning is in simple terms.") + >>> print(out.texts[0]) + ``` +""" + + +@dataclass +class SDARPipelineOutput(BaseOutput): + sequences: torch.LongTensor + texts: Optional[List[str]] = None + + +class SDARPipeline(DiffusionPipeline): + r""" + Block diffusion pipeline for SDAR-style token generation. 
+ """ + + model: torch.nn.Module + tokenizer: Optional[object] + scheduler: SDARTokenDiffusionScheduler + _callback_tensor_inputs = ["cur_x", "logits", "sampled_tokens", "sampled_probs", "transfer_index"] + + def __init__( + self, + model: torch.nn.Module, + tokenizer: Optional[object] = None, + scheduler: Optional[SDARTokenDiffusionScheduler] = None, + *, + max_new_tokens: int = 256, + block_length: int = 4, + denoising_steps: int = 4, + temperature: float = 1.0, + top_k: int = 0, + top_p: float = 1.0, + remasking_strategy: str = "low_confidence_dynamic", + confidence_threshold: float = 0.9, + entropy_threshold: float = 0.35, + use_chat_template: bool = True, + add_generation_prompt: bool = True, + attention_mask_mode: str = "3d", + ): + super().__init__() + if scheduler is None: + scheduler = SDARTokenDiffusionScheduler( + block_length=block_length, + denoising_steps=denoising_steps, + remasking_strategy=remasking_strategy, + confidence_threshold=confidence_threshold, + entropy_threshold=entropy_threshold, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + self.register_modules(model=model, tokenizer=tokenizer, scheduler=scheduler) + self.register_to_config( + max_new_tokens=max_new_tokens, + block_length=block_length, + denoising_steps=denoising_steps, + temperature=temperature, + top_k=top_k, + top_p=top_p, + remasking_strategy=remasking_strategy, + confidence_threshold=confidence_threshold, + entropy_threshold=entropy_threshold, + use_chat_template=use_chat_template, + add_generation_prompt=add_generation_prompt, + attention_mask_mode=attention_mask_mode, + ) + self._store_kv_supported: Optional[bool] = None + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + *, + messages: Optional[List[Dict[str, str]]] = None, + input_ids: Optional[torch.LongTensor] = None, + max_new_tokens: Optional[int] = None, + block_length: Optional[int] = None, + denoising_steps: Optional[int] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + remasking_strategy: Optional[str] = None, + confidence_threshold: Optional[float] = None, + entropy_threshold: Optional[float] = None, + stop_token_ids: Optional[List[int]] = None, + mask_token_id: Optional[int] = None, + attention_mask_mode: Optional[str] = None, + use_chat_template: Optional[bool] = None, + add_generation_prompt: Optional[bool] = None, + chat_template_kwargs: Optional[Dict[str, object]] = None, + return_text: bool = True, + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + ) -> Union[SDARPipelineOutput, Tuple[torch.LongTensor, Optional[List[str]]]]: + """ + Generate text using SDAR-style block diffusion decoding. 
+ + Examples: + """ + if max_new_tokens is None: + max_new_tokens = int(self.config.max_new_tokens) + if block_length is None: + model_block_length = getattr(self.model, "block_length", None) + if model_block_length is None: + model_block_length = getattr(getattr(self.model, "config", None), "block_length", None) + block_length = int(model_block_length) if model_block_length is not None else int(self.config.block_length) + if denoising_steps is None: + denoising_steps = int(self.config.denoising_steps) + if temperature is None: + temperature = float(self.config.temperature) + if top_k is None: + top_k = int(self.config.top_k) + if top_p is None: + top_p = float(self.config.top_p) + if remasking_strategy is None: + remasking_strategy = str(self.config.remasking_strategy) + if confidence_threshold is None: + confidence_threshold = float(self.config.confidence_threshold) + if entropy_threshold is None: + entropy_threshold = float(self.config.entropy_threshold) + if attention_mask_mode is None: + attention_mask_mode = str(self.config.attention_mask_mode) + if use_chat_template is None: + use_chat_template = bool(self.config.use_chat_template) + if add_generation_prompt is None: + add_generation_prompt = bool(self.config.add_generation_prompt) + + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["cur_x"] + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + input_ids = self._prepare_input_ids( + prompt=prompt, + messages=messages, + input_ids=input_ids, + use_chat_template=use_chat_template, + add_generation_prompt=add_generation_prompt, + chat_template_kwargs=chat_template_kwargs, + ) + + if input_ids.shape[0] != 1: + raise ValueError("SDARPipeline currently supports batch_size=1 input_ids.") + + params = list(self.model.parameters()) if hasattr(self.model, "parameters") else [] + device = params[0].device if len(params) > 0 else torch.device("cpu") + input_ids = input_ids.to(device=device) + + if mask_token_id is None: + mask_token_id = getattr(getattr(self, "tokenizer", None), "mask_token_id", None) + if mask_token_id is None: + raise ValueError("`mask_token_id` must be provided (or available on the tokenizer).") + + if stop_token_ids is None: + eos_token_id = getattr(getattr(self, "tokenizer", None), "eos_token_id", None) + stop_token_ids = [int(eos_token_id)] if eos_token_id is not None else None + if stop_token_ids is not None: + stop_token_ids = [int(token_id) for token_id in stop_token_ids] + + self.model.eval() + if block_length <= 0: + raise ValueError(f"`block_length` must be > 0, got {block_length}.") + if denoising_steps <= 0: + raise ValueError(f"`denoising_steps` must be > 0, got {denoising_steps}.") + + self.scheduler.set_timesteps(int(denoising_steps), device=device) + + prompt_length = input_ids.shape[1] + num_blocks = (prompt_length + int(max_new_tokens) + int(block_length) - 1) // int(block_length) + total_length = int(num_blocks) * int(block_length) + + block_mask_3d = self._build_block_attention_mask_3d( + 
num_blocks=num_blocks, + block_length=block_length, + total_length=total_length, + device=device, + dtype=torch.float32, + ) + block_mask_4d = self._build_block_attention_mask_4d(block_mask_3d, dtype=torch.float32) + block_mask_2d = block_mask_3d[0] + + x = torch.full( + (1, total_length), + int(mask_token_id), + dtype=torch.long, + device=device, + ) + x[:, :prompt_length] = input_ids + + position_ids = torch.arange(total_length, device=device).unsqueeze(0) + past_key_values = DynamicCache() + + prefill_blocks = prompt_length // int(block_length) + prefill_length = int(prefill_blocks) * int(block_length) + resolved_attention_mode = str(attention_mask_mode) + + if prefill_length > 0: + cur_x = x[:, :prefill_length] + cur_position_ids = position_ids[:, :prefill_length] + cur_attn_mask = block_mask_3d[:, :prefill_length, :prefill_length] + cur_attn_mask_4d = block_mask_4d[:, :, :prefill_length, :prefill_length] + cur_attn_mask_2d = block_mask_2d[:prefill_length, :prefill_length] + _, resolved_attention_mode = self._model_forward_logits( + input_ids=cur_x, + attention_mask_3d=cur_attn_mask, + attention_mask_4d=cur_attn_mask_4d, + attention_mask_2d=cur_attn_mask_2d, + position_ids=cur_position_ids, + attention_mask_mode=resolved_attention_mode, + past_key_values=past_key_values, + store_kv=True, + ) + + num_transfer_tokens = self.scheduler.get_num_transfer_tokens(int(block_length), int(denoising_steps)).to( + device=device + ) + + stop_tensor = None + if stop_token_ids is not None: + stop_tensor = torch.tensor(stop_token_ids, device=device, dtype=torch.long) + + global_step = 0 + for block_idx in range(prefill_blocks, int(num_blocks)): + start = int(block_idx) * int(block_length) + end = start + int(block_length) + cur_x = x[:, start:end].clone() + cur_position_ids = position_ids[:, start:end] + cur_attn_mask = block_mask_3d[:, start:end, :end] + cur_attn_mask_4d = block_mask_4d[:, :, start:end, :end] + cur_attn_mask_2d = block_mask_2d[start:end, :end] + + for step in range(int(denoising_steps) + 1): + mask_index = cur_x == int(mask_token_id) + if mask_index.sum() == 0: + _, resolved_attention_mode = self._model_forward_logits( + input_ids=cur_x, + attention_mask_3d=cur_attn_mask, + attention_mask_4d=cur_attn_mask_4d, + attention_mask_2d=cur_attn_mask_2d, + position_ids=cur_position_ids, + attention_mask_mode=resolved_attention_mode, + past_key_values=past_key_values, + store_kv=True, + ) + break + + logits, resolved_attention_mode = self._model_forward_logits( + input_ids=cur_x, + attention_mask_3d=cur_attn_mask, + attention_mask_4d=cur_attn_mask_4d, + attention_mask_2d=cur_attn_mask_2d, + position_ids=cur_position_ids, + attention_mask_mode=resolved_attention_mode, + past_key_values=past_key_values, + store_kv=False, + ) + + step_output = self.scheduler.step( + logits, + step, + cur_x, + mask_token_id=int(mask_token_id), + num_transfer_tokens=num_transfer_tokens, + remasking_strategy=remasking_strategy, + confidence_threshold=confidence_threshold, + entropy_threshold=entropy_threshold, + temperature=temperature, + top_k=top_k, + top_p=top_p, + return_dict=True, + ) + cur_x = step_output.prev_sample + transfer_index = step_output.transfer_index + sampled_tokens = step_output.sampled_tokens + sampled_probs = step_output.sampled_probs + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, global_step, step, callback_kwargs) + cur_x = 
callback_outputs.pop("cur_x", cur_x) + + global_step += 1 + + x[:, start:end] = cur_x + if stop_tensor is not None and torch.isin(x[:, prompt_length:], stop_tensor).any(): + break + + output_ids = x[:, : prompt_length + int(max_new_tokens)] + if stop_tensor is not None: + stop_positions = torch.isin(output_ids[0, prompt_length:], stop_tensor).nonzero(as_tuple=True)[0] + if stop_positions.numel() > 0: + output_ids = output_ids[:, : prompt_length + int(stop_positions[0].item()) + 1] + + if output_ids.shape[0] == 1: + output_ids = output_ids[:, output_ids[0] != int(mask_token_id)] + + sequences = output_ids[:, prompt_length:] + texts = None + if return_text and getattr(self, "tokenizer", None) is not None: + texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) + + if not return_dict: + return sequences, texts + return SDARPipelineOutput(sequences=sequences, texts=texts) + + def _model_forward_logits( + self, + *, + input_ids: torch.LongTensor, + attention_mask_3d: Optional[torch.Tensor], + attention_mask_4d: Optional[torch.Tensor], + attention_mask_2d: Optional[torch.Tensor], + position_ids: torch.LongTensor, + attention_mask_mode: str, + past_key_values: DynamicCache, + store_kv: bool, + ) -> Tuple[torch.Tensor, str]: + if attention_mask_mode not in {"auto", "3d", "4d", "2d", "none"}: + raise ValueError( + f"`attention_mask_mode` must be one of {{'auto','3d','4d','2d','none'}}, got {attention_mask_mode!r}." + ) + + def _call(mask): + kwargs = { + "input_ids": input_ids, + "attention_mask": mask, + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": True, + } + if self._store_kv_supported is False: + output = self.model(**kwargs) + return output.logits if hasattr(output, "logits") else output[0] + if self._store_kv_supported is True: + kwargs["store_kv"] = store_kv + output = self.model(**kwargs) + return output.logits if hasattr(output, "logits") else output[0] + try: + kwargs["store_kv"] = store_kv + output = self.model(**kwargs) + self._store_kv_supported = True + return output.logits if hasattr(output, "logits") else output[0] + except TypeError: + output = self.model(**kwargs) + self._store_kv_supported = False + return output.logits if hasattr(output, "logits") else output[0] + + if attention_mask_mode == "none": + return _call(None), "none" + if attention_mask_mode == "2d": + return _call(attention_mask_2d), "2d" + if attention_mask_mode == "3d": + return _call(attention_mask_3d), "3d" + if attention_mask_mode == "4d": + return _call(attention_mask_4d), "4d" + + try: + return _call(attention_mask_3d), "3d" + except (TypeError, ValueError, RuntimeError): + pass + try: + return _call(attention_mask_4d), "4d" + except (TypeError, ValueError, RuntimeError): + pass + try: + return _call(attention_mask_2d), "2d" + except (TypeError, ValueError, RuntimeError): + return _call(None), "none" + + def _build_block_attention_mask_3d( + self, + *, + num_blocks: int, + block_length: int, + total_length: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + block_mask = torch.tril(torch.ones(num_blocks, num_blocks, device=device, dtype=dtype)) + attn = block_mask.repeat_interleave(block_length, dim=0).repeat_interleave(block_length, dim=1).unsqueeze(0) + return attn[:, :total_length, :total_length] + + def _build_block_attention_mask_4d(self, mask_3d: torch.Tensor, *, dtype: torch.dtype) -> torch.Tensor: + attn = mask_3d.unsqueeze(1).to(dtype=dtype) + return torch.where( + attn > 0, + torch.zeros((), device=attn.device, 
dtype=dtype), + torch.full((), float("-inf"), device=attn.device, dtype=dtype), + ) + + def _prepare_input_ids( + self, + *, + prompt: Optional[Union[str, List[str]]], + messages: Optional[List[Dict[str, str]]], + input_ids: Optional[torch.LongTensor], + use_chat_template: bool, + add_generation_prompt: bool, + chat_template_kwargs: Optional[Dict[str, object]], + ) -> torch.LongTensor: + if input_ids is not None: + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + if input_ids.ndim != 2: + raise ValueError(f"`input_ids` must be 2D, got shape {tuple(input_ids.shape)}.") + if input_ids.dtype != torch.long: + raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.") + return input_ids + + if self.tokenizer is None: + raise ValueError("Tokenizer is required when `input_ids` is not provided.") + + if messages is not None and prompt is not None: + raise ValueError("Provide either `prompt` or `messages`, not both.") + if messages is None and prompt is None: + raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.") + + chat_template_kwargs = chat_template_kwargs or {} + + def _extract_input_ids(encoded): + if isinstance(encoded, dict) and "input_ids" in encoded: + return encoded["input_ids"] + if hasattr(encoded, "input_ids"): + return encoded.input_ids + return encoded + + if messages is not None: + encoded = self.tokenizer.apply_chat_template( + messages, + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + **chat_template_kwargs, + ) + return _extract_input_ids(encoded) + + if use_chat_template and getattr(self.tokenizer, "chat_template", None): + if isinstance(prompt, list): + raise ValueError("`prompt` must be a string when `use_chat_template=True`.") + encoded = self.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + **chat_template_kwargs, + ) + return _extract_input_ids(encoded) + + encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list)) + return _extract_input_ids(encoded) + + +__all__ = ["SDARPipeline", "SDARPipelineOutput"] diff --git a/src/diffusers/pipelines/token_diffusion/__init__.py b/src/diffusers/pipelines/token_diffusion/__init__.py new file mode 100644 index 000000000000..3fc40f86ee12 --- /dev/null +++ b/src/diffusers/pipelines/token_diffusion/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .pipeline_token_diffusion import TokenDiffusionPipeline, TokenDiffusionPipelineOutput + + +__all__ = ["TokenDiffusionPipeline", "TokenDiffusionPipelineOutput"] diff --git a/src/diffusers/pipelines/token_diffusion/pipeline_token_diffusion.py b/src/diffusers/pipelines/token_diffusion/pipeline_token_diffusion.py new file mode 100644 index 000000000000..217a533a1119 --- /dev/null +++ b/src/diffusers/pipelines/token_diffusion/pipeline_token_diffusion.py @@ -0,0 +1,276 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...utils import BaseOutput +from ..pipeline_utils import DiffusionPipeline + + +@dataclass +class TokenDiffusionPipelineOutput(BaseOutput): + """ + Output class for token diffusion pipelines. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Sampled token IDs. + texts (`List[str]`, *optional*): + Decoded texts if a tokenizer was provided and `return_text=True`. + """ + + sequences: torch.LongTensor + texts: Optional[List[str]] = None + + +class TokenDiffusionPipeline(DiffusionPipeline): + """ + Generic token diffusion sampling pipeline. + + This pipeline is intended as a minimal, diffusers-native wrapper around: + - a token denoiser model (e.g. `transformers.AutoModelForMaskedLM`-like, returning logits over vocab), and + - a discrete token scheduler (e.g. `TokenDiffusionScheduler`) that implements `set_timesteps()` and `step()`. + + The pipeline supports multiple forward processes via the scheduler configuration (e.g. absorbing/mask, uniform). + Conditioning (prefix/infill) is intentionally out of scope for the first version. + """ + + model: Any + tokenizer: Any + scheduler: Any + + _callback_tensor_inputs = ["input_ids", "logits"] + + def __init__( + self, + model: Any, + scheduler: Any, + tokenizer: Optional[Any] = None, + *, + seq_len: int = 64, + num_inference_steps: int = 128, + inject_start_token: bool = False, + ): + super().__init__() + self.register_modules(model=model, scheduler=scheduler, tokenizer=tokenizer) + self.register_to_config( + seq_len=seq_len, + num_inference_steps=num_inference_steps, + inject_start_token=inject_start_token, + ) + + @property + def num_timesteps(self): + return self._num_timesteps + + def _resolve_start_token_id(self) -> Optional[int]: + tok = getattr(self, "tokenizer", None) + if tok is None: + return None + for attr in ("bos_token_id", "cls_token_id"): + token_id = getattr(tok, attr, None) + if token_id is not None: + return int(token_id) + return None + + def _init_latents( + self, + batch_size: int, + seq_len: int, + *, + generator: Optional[torch.Generator], + device: torch.device, + ) -> torch.LongTensor: + # Prefer a scheduler-provided prior if available. 
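+        # For `forward_process="uniform"` the prior is i.i.d. random token IDs (optionally excluding the mask
+        # token); for the absorbing process every position starts out as `mask_token_id`.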
+ if hasattr(self.scheduler, "forward_process") and getattr(self.scheduler, "forward_process") == "uniform": + # Uniform prior over token IDs. Mirror scheduler's exclude-mask behavior. + if getattr(self.scheduler, "exclude_mask_from_uniform", False) and hasattr( + self.scheduler, "_sample_uniform_tokens" + ): + return self.scheduler._sample_uniform_tokens( + torch.Size((batch_size, seq_len)), + device=device, + dtype=torch.long, + generator=generator, + ) + vocab_size = int(getattr(self.scheduler, "vocab_size", 0)) + if vocab_size <= 0: + raise ValueError("Scheduler must define `vocab_size` for uniform prior sampling.") + return torch.randint( + 0, vocab_size, (batch_size, seq_len), device=device, dtype=torch.long, generator=generator + ) + + mask_token_id = getattr(self.scheduler, "mask_token_id", None) + if mask_token_id is None: + raise ValueError("Scheduler must define `mask_token_id` for absorbing prior sampling.") + return torch.full((batch_size, seq_len), int(mask_token_id), device=device, dtype=torch.long) + + def _normalize_prefix_ids( + self, prefix_ids: torch.LongTensor, batch_size: int, device: torch.device + ) -> torch.LongTensor: + if prefix_ids.ndim == 1: + prefix_ids = prefix_ids.unsqueeze(0) + if prefix_ids.ndim != 2: + raise ValueError( + f"`prefix_ids` must have shape [prefix_len] or [batch, prefix_len], got {prefix_ids.shape}." + ) + if prefix_ids.shape[0] not in (1, batch_size): + raise ValueError( + f"`prefix_ids` batch dim must be 1 or batch_size={batch_size}, got {prefix_ids.shape[0]}." + ) + if prefix_ids.dtype != torch.long: + raise ValueError(f"`prefix_ids` must be int64 token IDs, got dtype={prefix_ids.dtype}.") + prefix_ids = prefix_ids.to(device=device) + if prefix_ids.shape[0] == 1 and batch_size > 1: + prefix_ids = prefix_ids.expand(batch_size, -1) + return prefix_ids + + @torch.no_grad() + def __call__( + self, + *, + batch_size: int = 1, + seq_len: Optional[int] = None, + num_inference_steps: Optional[int] = None, + generator: Optional[torch.Generator] = None, + prefix_ids: Optional[torch.LongTensor] = None, + infill_mask: Optional[torch.BoolTensor] = None, + inject_start_token: Optional[bool] = None, + return_text: bool = True, + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + **model_kwargs, + ) -> Union[TokenDiffusionPipelineOutput, Tuple[torch.LongTensor, Optional[List[str]]]]: + """ + Args: + batch_size: Number of sequences to generate. + seq_len: Sequence length in tokens. Uses `pipe.config.seq_len` when `None`. + num_inference_steps: + Number of reverse diffusion steps. Uses `pipe.config.num_inference_steps` when `None`. + generator: Optional torch generator for determinism. + prefix_ids: Optional prefix token IDs to keep fixed at the start of each sequence. Shape `[P]` or + `[batch_size, P]`. + infill_mask: + Optional boolean mask of shape `[batch_size, seq_len]` indicating which positions are editable (`True`) + vs fixed (`False`). Fixed positions are clamped to the initial values on every step. + inject_start_token: If True, inject `bos_token_id` (or `cls_token_id`) into position 0 (if available). + Uses `pipe.config.inject_start_token` when `None`. + return_text: If True and tokenizer exists, also return decoded strings. + return_dict: If True, returns a `TokenDiffusionPipelineOutput`. 
+ callback_on_step_end: A function called after each denoising step with signature + `callback_on_step_end(self, step: int, timestep: int, callback_kwargs: Dict)`. + callback_on_step_end_tensor_inputs: List of tensor keys to include in `callback_kwargs`. + model_kwargs: Forward kwargs passed to `model(...)` (e.g. attention mask overrides). + """ + if seq_len is None: + seq_len = int(self.config.seq_len) + if num_inference_steps is None: + num_inference_steps = int(self.config.num_inference_steps) + if inject_start_token is None: + inject_start_token = bool(self.config.inject_start_token) + + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["input_ids"] + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + device = self._execution_device + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + input_ids = self._init_latents(batch_size, seq_len, generator=generator, device=device) + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + + fixed_mask = None + fixed_values = None + if infill_mask is not None: + if infill_mask.shape != (batch_size, seq_len): + raise ValueError( + f"`infill_mask` must have shape {(batch_size, seq_len)}, got {tuple(infill_mask.shape)}." + ) + fixed_mask = (~infill_mask.to(device=device)).to(dtype=torch.bool) + fixed_values = input_ids.clone() + + if prefix_ids is not None: + prefix_ids = self._normalize_prefix_ids(prefix_ids, batch_size=batch_size, device=device) + prefix_len = prefix_ids.shape[1] + if prefix_len > seq_len: + raise ValueError(f"`prefix_ids` length {prefix_len} must be <= seq_len={seq_len}.") + + input_ids[:, :prefix_len] = prefix_ids + if fixed_mask is None: + fixed_mask = torch.zeros((batch_size, seq_len), device=device, dtype=torch.bool) + fixed_values = input_ids.clone() + fixed_mask[:, :prefix_len] = True + fixed_values[:, :prefix_len] = prefix_ids + + start_token_id = self._resolve_start_token_id() + if inject_start_token and start_token_id is not None: + input_ids[:, 0] = start_token_id + if fixed_mask is not None: + fixed_mask[:, 0] = True + fixed_values[:, 0] = start_token_id + + for step_idx, t in enumerate(timesteps): + out = self.model(input_ids=input_ids, attention_mask=attention_mask, **model_kwargs) + logits = getattr(out, "logits", None) + if logits is None: + # Fall back to tuple-style returns. 
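+                # Some models return plain tuples instead of a ModelOutput; logits are assumed to come first.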
+ logits = out[0] + + input_ids = self.scheduler.step(logits, t, input_ids, generator=generator, return_dict=True).prev_sample + + if fixed_mask is not None: + input_ids = torch.where(fixed_mask, fixed_values, input_ids) + + if inject_start_token and start_token_id is not None: + input_ids[:, 0] = start_token_id + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, step_idx, t, callback_kwargs) + input_ids = callback_outputs.pop("input_ids", input_ids) + + texts = None + if return_text and getattr(self, "tokenizer", None) is not None: + texts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True) + + if not return_dict: + return (input_ids, texts) + return TokenDiffusionPipelineOutput(sequences=input_ids, texts=texts) + + +__all__ = ["TokenDiffusionPipeline", "TokenDiffusionPipelineOutput"] diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 4199e75bf331..c6025293fb4c 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -40,6 +40,7 @@ else: _import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"] _import_structure["scheduling_amused"] = ["AmusedScheduler"] + _import_structure["scheduling_block_token_diffusion"] = ["BlockTokenDiffusionScheduler"] _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] _import_structure["scheduling_ddim"] = ["DDIMScheduler"] @@ -50,6 +51,10 @@ _import_structure["scheduling_ddpm_parallel"] = ["DDPMParallelScheduler"] _import_structure["scheduling_ddpm_wuerstchen"] = ["DDPMWuerstchenScheduler"] _import_structure["scheduling_deis_multistep"] = ["DEISMultistepScheduler"] + _import_structure["scheduling_dflash_token_diffusion"] = [ + "DFlashTokenDiffusionScheduler", + "DFlashTokenDiffusionSchedulerOutput", + ] _import_structure["scheduling_dpm_cogvideox"] = ["CogVideoXDPMScheduler"] _import_structure["scheduling_dpmsolver_multistep"] = ["DPMSolverMultistepScheduler"] _import_structure["scheduling_dpmsolver_multistep_inverse"] = ["DPMSolverMultistepInverseScheduler"] @@ -62,6 +67,10 @@ _import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"] _import_structure["scheduling_flow_match_lcm"] = ["FlowMatchLCMScheduler"] _import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"] + _import_structure["scheduling_hybrid_token_diffusion"] = [ + "HybridTokenDiffusionScheduler", + "HybridTokenDiffusionSchedulerOutput", + ] _import_structure["scheduling_ipndm"] = ["IPNDMScheduler"] _import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"] _import_structure["scheduling_k_dpm_2_discrete"] = ["KDPM2DiscreteScheduler"] @@ -71,8 +80,13 @@ _import_structure["scheduling_repaint"] = ["RePaintScheduler"] _import_structure["scheduling_sasolver"] = ["SASolverScheduler"] _import_structure["scheduling_scm"] = ["SCMScheduler"] + _import_structure["scheduling_sdar_token_diffusion"] = [ + "SDARTokenDiffusionScheduler", + "SDARTokenDiffusionSchedulerOutput", + ] _import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"] _import_structure["scheduling_tcd"] = ["TCDScheduler"] + _import_structure["scheduling_token_diffusion"] = ["TokenDiffusionScheduler"] _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"] 
_import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"] _import_structure["scheduling_utils"] = ["AysSchedules", "KarrasDiffusionSchedulers", "SchedulerMixin"] @@ -153,6 +167,10 @@ from .scheduling_ddpm_parallel import DDPMParallelScheduler from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler from .scheduling_deis_multistep import DEISMultistepScheduler + from .scheduling_dflash_token_diffusion import ( + DFlashTokenDiffusionScheduler, + DFlashTokenDiffusionSchedulerOutput, + ) from .scheduling_dpm_cogvideox import CogVideoXDPMScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler @@ -174,6 +192,7 @@ from .scheduling_repaint import RePaintScheduler from .scheduling_sasolver import SASolverScheduler from .scheduling_scm import SCMScheduler + from .scheduling_sdar_token_diffusion import SDARTokenDiffusionScheduler, SDARTokenDiffusionSchedulerOutput from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_tcd import TCDScheduler from .scheduling_unclip import UnCLIPScheduler diff --git a/src/diffusers/schedulers/scheduling_block_token_diffusion.py b/src/diffusers/schedulers/scheduling_block_token_diffusion.py new file mode 100644 index 000000000000..4d743cf93dbc --- /dev/null +++ b/src/diffusers/schedulers/scheduling_block_token_diffusion.py @@ -0,0 +1,95 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional, Tuple, Union + +import torch + +from .scheduling_token_diffusion import TokenDiffusionScheduler, TokenDiffusionSchedulerOutput + + +class BlockTokenDiffusionScheduler(TokenDiffusionScheduler): + """ + A token diffusion scheduler that supports updating only a subset of positions (e.g. a block). + + This scheduler reuses the same alpha schedules and forward processes as `TokenDiffusionScheduler`, but allows + callers to restrict noising/denoising to a boolean `block_mask` of shape `[batch, seq_len]`. + """ + + @classmethod + def from_config(cls, config, **kwargs): + # TokenDiffusionScheduler doesn't have compatibles; keep standard ConfigMixin behavior. + return super().from_config(config, **kwargs) + + def add_noise( + self, + original_samples: torch.LongTensor, + noise: Optional[torch.Tensor], + timesteps: torch.LongTensor, + block_mask: Optional[torch.BoolTensor] = None, + ) -> torch.LongTensor: + if block_mask is None: + return super().add_noise(original_samples=original_samples, noise=noise, timesteps=timesteps) + + if block_mask.dtype != torch.bool: + raise ValueError(f"`block_mask` must be boolean, got dtype={block_mask.dtype}.") + if block_mask.shape != original_samples.shape: + raise ValueError( + f"`block_mask` must have shape {tuple(original_samples.shape)}, got {tuple(block_mask.shape)}." 
+ ) + + noised = super().add_noise(original_samples=original_samples, noise=noise, timesteps=timesteps) + return torch.where(block_mask.to(device=original_samples.device), noised, original_samples) + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.LongTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + block_mask: Optional[torch.BoolTensor] = None, + ) -> Union[TokenDiffusionSchedulerOutput, Tuple[torch.LongTensor]]: + if block_mask is None: + return super().step( + model_output=model_output, + timestep=timestep, + sample=sample, + generator=generator, + return_dict=return_dict, + ) + + if block_mask.dtype != torch.bool: + raise ValueError(f"`block_mask` must be boolean, got dtype={block_mask.dtype}.") + if block_mask.shape != sample.shape: + raise ValueError(f"`block_mask` must have shape {tuple(sample.shape)}, got {tuple(block_mask.shape)}.") + + out = super().step( + model_output=model_output, + timestep=timestep, + sample=sample, + generator=generator, + return_dict=True, + ) + prev = out.prev_sample + prev = torch.where(block_mask.to(device=prev.device), prev, sample) + + if not return_dict: + return (prev,) + return TokenDiffusionSchedulerOutput(prev_sample=prev) + + +__all__ = ["BlockTokenDiffusionScheduler"] diff --git a/src/diffusers/schedulers/scheduling_dflash_token_diffusion.py b/src/diffusers/schedulers/scheduling_dflash_token_diffusion.py new file mode 100644 index 000000000000..24e8238321c9 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_dflash_token_diffusion.py @@ -0,0 +1,107 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class DFlashTokenDiffusionSchedulerOutput(BaseOutput): + """ + Output class for DFlash-style speculative token scheduling. + + Args: + prev_sample (`torch.LongTensor` of shape `(batch_size, block_size)`): + The proposed block tokens from the draft model. + accepted_length (`torch.LongTensor` of shape `(batch_size,)`): + Number of consecutive accepted tokens from the block. + next_token (`torch.LongTensor` of shape `(batch_size,)`): + Next token sampled from the target posterior at the first rejection. + posterior (`torch.LongTensor` of shape `(batch_size, block_size)`): + Sampled tokens from the target posterior used for acceptance checks. + """ + + prev_sample: torch.LongTensor + accepted_length: torch.LongTensor + next_token: torch.LongTensor + posterior: torch.LongTensor + + +class DFlashTokenDiffusionScheduler(SchedulerMixin, ConfigMixin): + """ + Scheduler for DFlash-style block diffusion speculative decoding. + + This scheduler samples target posteriors and computes acceptance lengths for draft blocks. 
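+
+    Example (illustrative sketch; the vocabulary size, block length, and tensors below are placeholders rather
+    than values the scheduler requires):
+
+    ```py
+    import torch
+
+    scheduler = DFlashTokenDiffusionScheduler()
+    draft_tokens = torch.randint(0, 32000, (1, 8))  # block proposed by a draft model (placeholder values)
+    target_logits = torch.randn(1, 8, 32000)        # stand-in for target-model logits over the same block
+    out = scheduler.step(draft_tokens, target_logits, temperature=0.0)
+    out.accepted_length  # number of consecutive draft tokens accepted
+    out.next_token       # token taken from the target posterior at the first rejection
+    ```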
+ """ + + order = 1 + + @register_to_config + def __init__(self): + self.num_inference_steps = 1 + self.timesteps = torch.tensor([0], dtype=torch.long) + + def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None: + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + self.num_inference_steps = int(num_inference_steps) + self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long) + + def sample(self, logits: torch.Tensor, temperature: float = 0.0) -> torch.LongTensor: + if temperature < 1e-5: + return torch.argmax(logits, dim=-1) + bsz, seq_len, vocab_size = logits.shape + flat = logits.view(-1, vocab_size) / float(temperature) + probs = torch.softmax(flat, dim=-1) + return torch.multinomial(probs, num_samples=1).view(bsz, seq_len) + + def step( + self, + draft_tokens: torch.LongTensor, + target_logits: torch.Tensor, + *, + temperature: float = 0.0, + return_dict: bool = True, + ) -> Union[ + DFlashTokenDiffusionSchedulerOutput, + Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor], + ]: + posterior = self.sample(target_logits, temperature=temperature) + if draft_tokens.shape[1] > 1: + matches = draft_tokens[:, 1:] == posterior[:, :-1] + accepted_length = matches.int().cumprod(dim=1).sum(dim=1) + else: + accepted_length = torch.zeros((draft_tokens.shape[0],), device=draft_tokens.device, dtype=torch.long) + + next_token = posterior.gather(1, accepted_length.unsqueeze(1)).squeeze(1) + + if not return_dict: + return draft_tokens, accepted_length, next_token, posterior + return DFlashTokenDiffusionSchedulerOutput( + prev_sample=draft_tokens, + accepted_length=accepted_length, + next_token=next_token, + posterior=posterior, + ) + + +__all__ = ["DFlashTokenDiffusionScheduler", "DFlashTokenDiffusionSchedulerOutput"] diff --git a/src/diffusers/schedulers/scheduling_hybrid_token_diffusion.py b/src/diffusers/schedulers/scheduling_hybrid_token_diffusion.py new file mode 100644 index 000000000000..70982cc70d24 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_hybrid_token_diffusion.py @@ -0,0 +1,243 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_token_diffusion import _gumbel_argmax +from .scheduling_utils import SchedulerMixin + + +@dataclass +class HybridTokenDiffusionSchedulerOutput(BaseOutput): + prev_sample: torch.LongTensor + + +class HybridTokenDiffusionScheduler(SchedulerMixin, ConfigMixin): + """ + Hybrid-transition discrete token diffusion scheduler. 
+ + This scheduler defines a forward transition kernel that mixes: + - keeping the current token (scaled by alpha(t)) + - moving toward a mixture distribution over tokens (beta_pi(t)) + + The scheduler exposes: + - `add_noise(...)` for forward corruption + - `step(...)` for reverse updates using the model's predicted token distribution + """ + + order = 1 + + @register_to_config + def __init__( + self, + vocab_size: int, + mask_token_id: int, + num_train_timesteps: int = 1000, + t_eps: float = 1e-4, + p_uniform: float = 0.0, + clip_noise: float = 20.0, + gamma: float = 1.0, + ): + if vocab_size <= 0: + raise ValueError(f"`vocab_size` must be > 0, got {vocab_size}.") + if not (0 <= mask_token_id < vocab_size): + raise ValueError(f"`mask_token_id` must be in [0, vocab_size), got {mask_token_id}.") + if num_train_timesteps <= 1: + raise ValueError(f"`num_train_timesteps` must be > 1, got {num_train_timesteps}.") + if not (0.0 < t_eps < 0.5): + raise ValueError(f"`t_eps` must be in (0, 0.5), got {t_eps}.") + if gamma <= 0: + raise ValueError(f"`gamma` must be > 0, got {gamma}.") + + self.vocab_size = int(vocab_size) + self.mask_token_id = int(mask_token_id) + self.num_train_timesteps = int(num_train_timesteps) + self.t_eps = float(t_eps) + + p_uniform = max(math.exp(-float(clip_noise)), float(p_uniform)) + log_B = float(gamma) * math.log(2.0) + math.log(p_uniform) - math.log(1.0 - p_uniform) + log_B = float(np.clip(log_B, -float(clip_noise), float(clip_noise))) + self.log_B = float(log_B) + self.log_gamma = float(math.log(float(gamma))) + + self.num_inference_steps = None + self.timesteps = None + self._timesteps_with_end = None + + mask = torch.zeros(self.vocab_size, dtype=torch.float32) + mask[self.mask_token_id] = 1.0 + self.mask = mask + + unif = (1.0 - mask) / max(self.vocab_size - 1, 1) + self.unif = unif + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device, None] = None) -> None: + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + self.num_inference_steps = int(num_inference_steps) + + t0 = 1.0 - float(self.t_eps) + t1 = float(self.t_eps) + timesteps = torch.linspace(t0, t1, self.num_inference_steps + 1, dtype=torch.float32, device=device) + self._timesteps_with_end = timesteps + self.timesteps = timesteps[:-1] + + def scale_model_input( + self, sample: torch.Tensor, timestep: Optional[Union[int, torch.Tensor]] = None + ) -> torch.Tensor: + return sample + + def _to_continuous_t(self, timesteps: torch.Tensor, device: torch.device) -> torch.Tensor: + if timesteps.dtype in (torch.float16, torch.float32, torch.float64, torch.bfloat16): + t = timesteps.to(device=device, dtype=torch.float32) + return t.clamp(float(self.t_eps), 1.0 - float(self.t_eps)) + + if timesteps.dtype not in (torch.int32, torch.int64): + raise ValueError(f"`timesteps` must be float or int, got dtype={timesteps.dtype}.") + + t = timesteps.to(device=device, dtype=torch.float32) / float(self.num_train_timesteps - 1) + t = (1.0 - 2.0 * float(self.t_eps)) * t + float(self.t_eps) + return t.clamp(float(self.t_eps), 1.0 - float(self.t_eps)) + + def _get_alpha_betapi(self, t: torch.Tensor, eps: float = 1e-6) -> Tuple[torch.Tensor, torch.Tensor]: + t = t.view(-1, 1) + t1m = 1.0 - t + + gamma = float(math.exp(self.log_gamma)) + B = float(math.exp(self.log_B)) + c_t = (t.pow(gamma / 2.0) * t1m.pow(gamma / 2.0) * B).to(dtype=torch.float32) + C_t = (1.0 + c_t).clamp_min(eps) + + alpha_t = t1m / C_t + beta_pi = ( + t * 
self.mask.to(device=t.device, dtype=torch.float32) + + c_t * self.unif.to(device=t.device, dtype=torch.float32) + ) / C_t + return alpha_t, beta_pi + + def _probs_at_t(self, probs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + alpha_t, beta_pi = self._get_alpha_betapi(t) + alpha_t = alpha_t.to(dtype=probs.dtype) + beta_pi = beta_pi.to(dtype=probs.dtype) + + out = probs.mul(alpha_t.unsqueeze(1)) + out[..., : beta_pi.shape[-1]].add_(beta_pi.unsqueeze(1)) + return out + + def _sample_categorical(self, probs: torch.Tensor, generator: Optional[torch.Generator]) -> torch.LongTensor: + bsz, seqlen, vocab = probs.shape + flat = probs.view(-1, vocab).clamp_min(torch.finfo(probs.dtype).tiny) + flat = flat / flat.sum(dim=-1, keepdim=True).clamp_min(torch.finfo(probs.dtype).eps) + sample = torch.multinomial(flat, num_samples=1, generator=generator).view(bsz, seqlen) + return sample.to(dtype=torch.long) + + def add_noise( + self, + original_samples: torch.LongTensor, + noise: Optional[torch.Tensor], + timesteps: torch.Tensor, + ) -> torch.LongTensor: + del noise + if original_samples.dtype != torch.long: + raise ValueError(f"`original_samples` must be int64 token IDs, got dtype={original_samples.dtype}.") + + device = original_samples.device + t = self._to_continuous_t(timesteps.to(device=device), device=device) + onehot = F.one_hot(original_samples, num_classes=self.vocab_size).to(dtype=torch.float32) + probs = self._probs_at_t(onehot, t) + return self._sample_categorical(probs, generator=None) + + def _index_for_timestep(self, timestep: Union[float, torch.Tensor]) -> int: + if self.timesteps is None: + raise ValueError("Call `set_timesteps(...)` before calling `step()`.") + + if isinstance(timestep, torch.Tensor): + t = float(timestep.detach().cpu().item()) + else: + t = float(timestep) + + idx = int(torch.argmin(torch.abs(self.timesteps.detach().cpu() - torch.tensor(t))).item()) + return idx + + def step( + self, + model_output: torch.Tensor, + timestep: Union[float, torch.Tensor], + sample: torch.LongTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[HybridTokenDiffusionSchedulerOutput, Tuple[torch.LongTensor]]: + if sample.dtype != torch.long: + raise ValueError(f"`sample` must be int64 token IDs, got dtype={sample.dtype}.") + if model_output.ndim != 3 or model_output.shape[-1] != self.vocab_size: + raise ValueError( + f"`model_output` must have shape [batch, seq_len, vocab_size={self.vocab_size}], got {tuple(model_output.shape)}." + ) + if model_output.shape[0] != sample.shape[0] or model_output.shape[1] != sample.shape[1]: + raise ValueError( + f"`model_output` batch/seq dims {tuple(model_output.shape[:2])} must match `sample` {tuple(sample.shape)}." 
+ ) + + if self._timesteps_with_end is None: + raise ValueError("Call `set_timesteps(...)` before calling `step()`.") + + device = sample.device + batch_size, seq_len = sample.shape + + step_index = self._index_for_timestep(timestep) + t_val = self._timesteps_with_end[step_index].to(device=device) + s_val = self._timesteps_with_end[step_index + 1].to(device=device) + + t = t_val * torch.ones(batch_size, device=device, dtype=torch.float32) + s = s_val * torch.ones(batch_size, device=device, dtype=torch.float32) + + logits = model_output.to(dtype=torch.float32) + logits = logits.clone() + logits[..., self.mask_token_id] = torch.finfo(logits.dtype).min + probs = logits.softmax(dim=-1) + + q_s = self._probs_at_t(probs, s) + q_t = self._probs_at_t(probs, t) + q_zt = q_t.gather(-1, sample.unsqueeze(-1)).clamp_min(torch.finfo(torch.float32).eps) + + alpha_t, beta_pi_t = self._get_alpha_betapi(t) + alpha_s, beta_pi_s = self._get_alpha_betapi(s) + + alpha_ts = (alpha_t / alpha_s).clamp_min(torch.finfo(torch.float32).eps) + beta_pi_ts = beta_pi_t - (alpha_t / alpha_s) * beta_pi_s + + vz_t = F.one_hot(sample, num_classes=self.vocab_size).to(dtype=torch.float32) + beta_pi_ts_at_zt = beta_pi_ts.unsqueeze(1).expand_as(vz_t).gather(-1, sample.unsqueeze(-1)) + q_ts = alpha_ts.view(batch_size, 1, 1) * vz_t + beta_pi_ts_at_zt + + q_st = q_ts * q_s / q_zt + q_st = q_st.clamp_min(torch.finfo(torch.float32).tiny) + q_st = q_st / q_st.sum(dim=-1, keepdim=True).clamp_min(torch.finfo(torch.float32).eps) + + x_prev = _gumbel_argmax(torch.log(q_st), generator=generator).to(dtype=torch.long) + + if not return_dict: + return (x_prev,) + return HybridTokenDiffusionSchedulerOutput(prev_sample=x_prev) + + +__all__ = ["HybridTokenDiffusionScheduler", "HybridTokenDiffusionSchedulerOutput"] diff --git a/src/diffusers/schedulers/scheduling_sdar_token_diffusion.py b/src/diffusers/schedulers/scheduling_sdar_token_diffusion.py new file mode 100644 index 000000000000..a5d86441491c --- /dev/null +++ b/src/diffusers/schedulers/scheduling_sdar_token_diffusion.py @@ -0,0 +1,238 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class SDARTokenDiffusionSchedulerOutput(BaseOutput): + """ + Output class for SDAR-style block diffusion scheduling. + + Args: + prev_sample (`torch.LongTensor` of shape `(batch_size, block_length)`): + Updated block tokens after the current denoising step. + transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`): + Boolean mask indicating which tokens were updated. + sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`): + Sampled token IDs from the model logits. 
+ sampled_probs (`torch.Tensor` of shape `(batch_size, block_length)`): + Probabilities of the sampled tokens. + """ + + prev_sample: torch.LongTensor + transfer_index: torch.BoolTensor + sampled_tokens: torch.LongTensor + sampled_probs: torch.Tensor + + +class SDARTokenDiffusionScheduler(SchedulerMixin, ConfigMixin): + """ + Scheduler for SDAR-style block diffusion decoding. + """ + + order = 1 + + @register_to_config + def __init__( + self, + block_length: int = 4, + denoising_steps: int = 4, + remasking_strategy: str = "low_confidence_dynamic", + confidence_threshold: float = 0.9, + entropy_threshold: float = 0.35, + temperature: float = 1.0, + top_k: int = 0, + top_p: float = 1.0, + ): + self.num_inference_steps = int(denoising_steps) + self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, dtype=torch.long) + + def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None: + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + self.num_inference_steps = int(num_inference_steps) + self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long) + + def get_num_transfer_tokens(self, block_length: int, num_inference_steps: int) -> torch.LongTensor: + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + base = int(block_length) // int(num_inference_steps) + remainder = int(block_length) % int(num_inference_steps) + num_transfer_tokens = torch.zeros(int(num_inference_steps), dtype=torch.long) + num_transfer_tokens += base + if remainder > 0: + num_transfer_tokens[:remainder] += 1 + return num_transfer_tokens + + def _top_k_logits(self, logits: torch.Tensor, k: int) -> torch.Tensor: + if k <= 0: + return logits + values, _ = torch.topk(logits, k) + min_values = values[..., -1, None] + return torch.where(logits < min_values, torch.full_like(logits, float("-inf")), logits) + + def _top_p_logits(self, logits: torch.Tensor, p: float) -> torch.Tensor: + if p >= 1.0: + return logits + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + sorted_mask = cumulative_probs > p + sorted_mask[..., 1:] = sorted_mask[..., :-1].clone() + sorted_mask[..., 0] = False + mask_indices = torch.scatter(torch.zeros_like(logits, dtype=torch.bool), -1, sorted_indices, sorted_mask) + return logits.masked_fill(mask_indices, float("-inf")) + + def sample( + self, + logits: torch.Tensor, + *, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + generator: Optional[torch.Generator] = None, + ) -> Tuple[torch.LongTensor, torch.Tensor]: + if temperature is None: + temperature = float(self.config.temperature) + if top_k is None: + top_k = int(self.config.top_k) + if top_p is None: + top_p = float(self.config.top_p) + + orig_shape = logits.shape[:-1] + vocab_size = logits.shape[-1] + flat = logits.view(-1, vocab_size) + + if temperature < 1e-5: + probs = F.softmax(flat, dim=-1) + tokens = torch.argmax(flat, dim=-1, keepdim=True) + token_probs = torch.gather(probs, -1, tokens) + return tokens.view(*orig_shape), token_probs.view(*orig_shape) + + flat = flat / float(temperature) + flat = self._top_k_logits(flat, int(top_k)) + flat = self._top_p_logits(flat, float(top_p)) + probs = F.softmax(flat, dim=-1) + tokens = torch.multinomial(probs, num_samples=1, 
generator=generator) + token_probs = torch.gather(probs, -1, tokens) + return tokens.view(*orig_shape), token_probs.view(*orig_shape) + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.LongTensor, + *, + mask_token_id: int, + num_transfer_tokens: torch.LongTensor, + remasking_strategy: Optional[str] = None, + confidence_threshold: Optional[float] = None, + entropy_threshold: Optional[float] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[ + SDARTokenDiffusionSchedulerOutput, Tuple[torch.LongTensor, torch.BoolTensor, torch.LongTensor, torch.Tensor] + ]: + if remasking_strategy is None: + remasking_strategy = str(self.config.remasking_strategy) + if confidence_threshold is None: + confidence_threshold = float(self.config.confidence_threshold) + if entropy_threshold is None: + entropy_threshold = float(self.config.entropy_threshold) + + sampled_tokens, sampled_probs = self.sample( + model_output, temperature=temperature, top_k=top_k, top_p=top_p, generator=generator + ) + mask_index = sample == int(mask_token_id) + transfer_index = torch.zeros_like(mask_index) + + if isinstance(timestep, torch.Tensor): + step_index = int(timestep.item()) + else: + step_index = int(timestep) + + if step_index >= int(num_transfer_tokens.numel()): + step_index = int(num_transfer_tokens.numel()) - 1 + step_transfer = int(num_transfer_tokens[step_index].item()) + + if remasking_strategy == "sequential": + for j in range(sample.shape[0]): + if not mask_index[j].any(): + continue + num_masked = int(mask_index[j].sum().item()) + k = min(step_transfer, num_masked) + first_mask_index = mask_index[j].nonzero(as_tuple=True)[0].min().item() + transfer_index[j, first_mask_index : first_mask_index + k] = True + + elif remasking_strategy in {"low_confidence_static", "low_confidence_dynamic"}: + confidence = torch.where(mask_index, sampled_probs, torch.full_like(sampled_probs, float("-inf"))) + for j in range(confidence.shape[0]): + if not mask_index[j].any(): + continue + num_masked = int(mask_index[j].sum().item()) + k = min(step_transfer, num_masked) + if remasking_strategy == "low_confidence_dynamic": + high_conf_mask = confidence[j] > confidence_threshold + if int(high_conf_mask.sum().item()) >= k: + transfer_index[j] = high_conf_mask + continue + _, idx = torch.topk(confidence[j], k) + transfer_index[j, idx] = True + + elif remasking_strategy == "entropy_bounded": + eps = 1e-12 + entropies = -(sampled_probs.clamp_min(eps) * sampled_probs.clamp_min(eps).log()).sum(dim=-1) + entropies = torch.where(mask_index, entropies, torch.full_like(sampled_probs, float("inf"))) + ent_sorted, order = torch.sort(entropies, dim=1, descending=False) + cumsum = torch.cumsum(ent_sorted, dim=1) + for j in range(sampled_probs.shape[0]): + if not mask_index[j].any(): + continue + threshold_tensor = torch.tensor(entropy_threshold, device=sampled_probs.device) + k = int(torch.searchsorted(cumsum[j], threshold_tensor, right=False).item()) + num_masked = int(mask_index[j].sum().item()) + k = max(1, min(k, num_masked)) + selected_token_indices = order[j, :k] + transfer_index[j, selected_token_indices] = True + + else: + raise ValueError(f"Unknown remasking strategy: {remasking_strategy}") + + prev_sample = sample.clone() + prev_sample[transfer_index] = sampled_tokens[transfer_index] + + if not return_dict: + return prev_sample, transfer_index, 
sampled_tokens, sampled_probs + return SDARTokenDiffusionSchedulerOutput( + prev_sample=prev_sample, + transfer_index=transfer_index, + sampled_tokens=sampled_tokens, + sampled_probs=sampled_probs, + ) + + +__all__ = ["SDARTokenDiffusionScheduler", "SDARTokenDiffusionSchedulerOutput"] diff --git a/src/diffusers/schedulers/scheduling_token_diffusion.py b/src/diffusers/schedulers/scheduling_token_diffusion.py new file mode 100644 index 000000000000..ee08a6e221a8 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_token_diffusion.py @@ -0,0 +1,467 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class TokenDiffusionSchedulerOutput(BaseOutput): + """ + Output class for discrete token schedulers. + + Args: + prev_sample (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Sample at the previous timestep. This should be fed into the model at the next denoising iteration. + """ + + prev_sample: torch.LongTensor + + +def _gumbel_argmax(logits: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.LongTensor: + """ + Sample from a categorical distribution defined by (unnormalized) logits via Gumbel-max. + + Args: + logits: Tensor of shape `(..., vocab_size)`. + generator: Optional torch generator for determinism. + + Returns: + `torch.LongTensor` of shape `logits.shape[:-1]` with sampled indices. + """ + # Gumbel(0,1) noise: -log(-log(U)) + uniform = torch.rand(logits.shape, device=logits.device, dtype=logits.dtype, generator=generator).clamp_(1e-30, 1) + gumbel = -torch.log(-torch.log(uniform)) + return (logits + gumbel).argmax(dim=-1) + + +class TokenDiffusionScheduler(SchedulerMixin, ConfigMixin): + """ + Discrete diffusion scheduler over token IDs (categorical states). + + This scheduler is designed for *token-space* diffusion (e.g. masked/absorbing diffusion language models) and + follows the diffusers scheduler API where possible: `set_timesteps()` for inference and `step()` for reverse + updates. + + Currently implemented: + - Forward process: + - `absorbing`: with probability `1 - alpha(t)` replace token with `mask_token_id`. + - `uniform`: with probability `1 - alpha(t)` replace token with a uniform random token. + - Noise schedule: selectable `alpha(t)` families with `t in [0, 1]`. + + Notes: + - `step()` expects the model to return logits over vocabulary for `x0` reconstruction. + - The mask token is treated as an *absorbing state* and is never sampled as an `x0` prediction. 
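+
+    Example (minimal sketch using random logits in place of a real denoiser; the vocabulary size, mask token ID,
+    and sequence length are placeholders):
+
+    ```py
+    import torch
+
+    scheduler = TokenDiffusionScheduler(vocab_size=32000, mask_token_id=31999)
+    scheduler.set_timesteps(16)
+
+    sample = torch.full((1, 64), 31999, dtype=torch.long)  # start from the all-mask prior
+    for t in scheduler.timesteps:
+        logits = torch.randn(1, 64, 32000)  # stands in for `model(input_ids=sample).logits`
+        sample = scheduler.step(logits, t, sample).prev_sample
+    ```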
+ """ + + order = 1 + + @register_to_config + def __init__( + self, + vocab_size: int, + mask_token_id: int, + num_train_timesteps: int = 1000, + alpha_schedule: str = "log_linear", + eps: float = 1e-3, + sigma_min: float = 1e-4, + sigma_max: float = 20.0, + forward_process: str = "absorbing", + exclude_mask_from_uniform: bool = True, + ): + if vocab_size <= 0: + raise ValueError(f"`vocab_size` must be > 0, got {vocab_size}.") + if num_train_timesteps <= 1: + raise ValueError(f"`num_train_timesteps` must be > 1, got {num_train_timesteps}.") + if not (0.0 < eps < 1.0): + raise ValueError(f"`eps` must be in (0, 1), got {eps}.") + if not (0 <= mask_token_id < vocab_size): + raise ValueError(f"`mask_token_id` must be in [0, vocab_size), got {mask_token_id}.") + alpha_schedule = str(alpha_schedule).lower() + if alpha_schedule not in {"log_linear", "linear", "cosine", "geometric"}: + raise ValueError( + "`alpha_schedule` must be one of {'log_linear','linear','cosine','geometric'}, got" + f" {alpha_schedule!r}." + ) + if sigma_min <= 0 or sigma_max <= 0: + raise ValueError( + f"`sigma_min` and `sigma_max` must be > 0, got sigma_min={sigma_min}, sigma_max={sigma_max}." + ) + if sigma_max <= sigma_min: + raise ValueError(f"`sigma_max` must be > `sigma_min`, got sigma_min={sigma_min}, sigma_max={sigma_max}.") + if forward_process not in {"absorbing", "uniform"}: + raise ValueError(f"`forward_process` must be one of {{'absorbing','uniform'}}, got {forward_process!r}.") + + self.vocab_size = int(vocab_size) + self.mask_token_id = int(mask_token_id) + self.num_train_timesteps = int(num_train_timesteps) + self.alpha_schedule = alpha_schedule + self.eps = float(eps) + self.sigma_min = float(sigma_min) + self.sigma_max = float(sigma_max) + self.forward_process = str(forward_process) + self.exclude_mask_from_uniform = bool(exclude_mask_from_uniform) + + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + def _effective_vocab_size(self) -> int: + if self.forward_process == "uniform" and self.exclude_mask_from_uniform: + return self.vocab_size - 1 + return self.vocab_size + + def _sample_uniform_tokens( + self, shape: torch.Size, device: torch.device, dtype: torch.dtype, generator: Optional[torch.Generator] = None + ) -> torch.LongTensor: + """ + Sample uniform token IDs, optionally excluding `mask_token_id` (by shifting indices around it). + """ + if self.forward_process != "uniform": + raise ValueError("Uniform token sampling is only valid for `forward_process='uniform'`.") + + if not self.exclude_mask_from_uniform: + return torch.randint(0, self.vocab_size, shape, device=device, dtype=dtype, generator=generator) + + # Sample in [0, vocab_size-1) and shift around mask_token_id. + v_eff = self.vocab_size - 1 + draw = torch.randint(0, v_eff, shape, device=device, dtype=dtype, generator=generator) + return torch.where(draw >= self.mask_token_id, draw + 1, draw) + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device, None] = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Timesteps are stored in descending order, so `timesteps[0]` is the noisiest step. + """ + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + self.num_inference_steps = int(num_inference_steps) + + # Standard diffusers behavior: map inference steps onto training step indices. 
+ timesteps = torch.linspace( + self.num_train_timesteps - 1, 0, self.num_inference_steps, dtype=torch.float32 + ).round() + self.timesteps = timesteps.to(dtype=torch.long, device=device) + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + return sample + + def _t_from_timestep(self, timestep: Union[int, torch.Tensor], device: torch.device) -> torch.Tensor: + """ + Convert an integer training timestep index into continuous time `t in [0, 1]`. + """ + if isinstance(timestep, torch.Tensor): + t_idx = timestep.to(device=device, dtype=torch.float32) + else: + t_idx = torch.tensor(float(timestep), device=device, dtype=torch.float32) + denom = float(self.num_train_timesteps - 1) + return (t_idx / denom).clamp_(0.0, 1.0) + + def _alpha_t(self, t: torch.Tensor) -> torch.Tensor: + """ + Compute alpha(t) for the configured schedule. + + The returned tensor is expected to be in (0, 1] and monotone decreasing in `t`. + """ + if self.alpha_schedule == "log_linear": + # alpha(t) = 1 - (1 - eps) * t + return 1.0 - (1.0 - self.eps) * t + + if self.alpha_schedule == "linear": + # alpha(t) = (1 - 2*eps) * (1 - t) + eps + return (1.0 - 2.0 * self.eps) * (1.0 - t) + self.eps + + if self.alpha_schedule == "cosine": + # alpha_base(t) = 1 - cos(pi/2 * (1 - t)) + # alpha(t) = (1 - 2*eps) * alpha_base(t) + eps + base = 1.0 - torch.cos(torch.pi / 2.0 * (1.0 - t)) + return (1.0 - 2.0 * self.eps) * base + self.eps + + if self.alpha_schedule == "geometric": + # total_noise(t) = sigma_min^(1-t) * sigma_max^t + # alpha(t) = exp(-total_noise(t)) + sigma_min = torch.as_tensor(self.sigma_min, device=t.device, dtype=t.dtype) + sigma_max = torch.as_tensor(self.sigma_max, device=t.device, dtype=t.dtype) + total_noise = (sigma_min ** (1.0 - t)) * (sigma_max**t) + return (-total_noise).exp() + + raise ValueError(f"Unsupported alpha schedule: {self.alpha_schedule!r}") + + def _alpha_prime_t(self, t: torch.Tensor) -> torch.Tensor: + """ + Compute d/dt alpha(t) for the configured schedule. + """ + if self.alpha_schedule == "log_linear": + return -(1.0 - self.eps) * torch.ones_like(t) + + if self.alpha_schedule == "linear": + return -(1.0 - 2.0 * self.eps) * torch.ones_like(t) + + if self.alpha_schedule == "cosine": + base_prime = -(torch.pi / 2.0) * torch.sin(torch.pi / 2.0 * (1.0 - t)) + return (1.0 - 2.0 * self.eps) * base_prime + + if self.alpha_schedule == "geometric": + sigma_min = torch.as_tensor(self.sigma_min, device=t.device, dtype=t.dtype) + sigma_max = torch.as_tensor(self.sigma_max, device=t.device, dtype=t.dtype) + total_noise = (sigma_min ** (1.0 - t)) * (sigma_max**t) + alpha = (-total_noise).exp() + rate = total_noise * (sigma_max.log() - sigma_min.log()) + return -alpha * rate + + raise ValueError(f"Unsupported alpha schedule: {self.alpha_schedule!r}") + + def get_mdlm_loss_weights(self, timesteps: torch.LongTensor) -> torch.Tensor: + """ + Return per-example positive loss weights for masked-token reconstruction objectives. + + The weight corresponds to `-alpha'(t) / (1 - alpha(t))`, which is positive for monotone decreasing alpha(t). + + Args: + timesteps (`torch.LongTensor` of shape `(batch_size,)`): + Training timestep indices in `[0, num_train_timesteps-1]`. + + Returns: + `torch.FloatTensor` of shape `(batch_size, 1)`: + Positive weights to multiply token-level cross-entropy by. 
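+
+        Example (illustrative sketch; `logits`, `labels`, and `mask` are assumed to come from a masked-token
+        training batch, with `mask` marking the corrupted positions):
+
+        ```py
+        weights = scheduler.get_mdlm_loss_weights(timesteps)                     # [batch, 1]
+        ce = F.cross_entropy(logits.transpose(1, 2), labels, reduction="none")   # [batch, seq_len]
+        loss = (weights * ce * mask.float()).sum() / mask.float().sum().clamp_min(1.0)
+        ```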
+ """ + if timesteps.dtype not in (torch.int32, torch.int64): + raise ValueError(f"`timesteps` must be an integer tensor, got dtype={timesteps.dtype}.") + device = timesteps.device + t = self._t_from_timestep(timesteps.to(device), device=device) + t = t.to(dtype=torch.float32) + alpha = self._alpha_t(t).to(dtype=torch.float32) + dalpha = self._alpha_prime_t(t).to(dtype=torch.float32) + denom = (1.0 - alpha).clamp_min(torch.finfo(torch.float32).eps) + w = (-dalpha / denom).clamp_min(torch.finfo(torch.float32).tiny) + return w.view(-1, 1) + + def get_alpha(self, timesteps: torch.LongTensor) -> torch.Tensor: + """ + Return per-example alpha(t) values for the configured schedule. + + Args: + timesteps (`torch.LongTensor` of shape `(batch_size,)`): + Training timestep indices in `[0, num_train_timesteps-1]`. + + Returns: + `torch.FloatTensor` of shape `(batch_size, 1)`: + Alpha values in `(0, 1]` for each example. + """ + if timesteps.dtype not in (torch.int32, torch.int64): + raise ValueError(f"`timesteps` must be an integer tensor, got dtype={timesteps.dtype}.") + device = timesteps.device + t = self._t_from_timestep(timesteps.to(device), device=device).to(dtype=torch.float32) + alpha = self._alpha_t(t).to(dtype=torch.float32) + return alpha.view(-1, 1) + + def get_alpha_prime(self, timesteps: torch.LongTensor) -> torch.Tensor: + """ + Return per-example time derivative alpha'(t) for the configured schedule. + + Args: + timesteps (`torch.LongTensor` of shape `(batch_size,)`): + Training timestep indices in `[0, num_train_timesteps-1]`. + + Returns: + `torch.FloatTensor` of shape `(batch_size, 1)`: + Alpha derivatives with respect to continuous time `t in [0, 1]`. + """ + if timesteps.dtype not in (torch.int32, torch.int64): + raise ValueError(f"`timesteps` must be an integer tensor, got dtype={timesteps.dtype}.") + device = timesteps.device + t = self._t_from_timestep(timesteps.to(device), device=device).to(dtype=torch.float32) + dalpha = self._alpha_prime_t(t).to(dtype=torch.float32) + return dalpha.view(-1, 1) + + def add_noise( + self, + original_samples: torch.LongTensor, + noise: Optional[torch.Tensor], + timesteps: torch.LongTensor, + ) -> torch.LongTensor: + """ + Apply the (absorbing) forward process q(x_t | x_0). + + The `noise` argument is accepted for API compatibility but is not used for the absorbing kernel. + """ + del noise + + if original_samples.dtype != torch.long: + raise ValueError(f"`original_samples` must be int64 token IDs, got dtype={original_samples.dtype}.") + if timesteps.dtype not in (torch.int32, torch.int64): + raise ValueError(f"`timesteps` must be an integer tensor, got dtype={timesteps.dtype}.") + + batch_size, seq_len = original_samples.shape + device = original_samples.device + + # Convert per-example timesteps into alpha(t) in [eps, 1]. 
+ t = self._t_from_timestep(timesteps.to(device), device=device).view(batch_size, 1) + alpha = self._alpha_t(t).to(dtype=torch.float32) + + p_replace = (1.0 - alpha).expand(batch_size, seq_len) + rand = torch.rand((batch_size, seq_len), device=device, dtype=torch.float32) + replace_positions = rand < p_replace + + if self.forward_process == "absorbing": + replacement = torch.full_like(original_samples, self.mask_token_id) + elif self.forward_process == "uniform": + replacement = self._sample_uniform_tokens( + original_samples.shape, device=device, dtype=original_samples.dtype, generator=None + ) + else: + raise ValueError(f"Unsupported forward process: {self.forward_process!r}") + + return torch.where(replace_positions, replacement, original_samples) + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.LongTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[TokenDiffusionSchedulerOutput, Tuple[torch.LongTensor]]: + """ + Reverse diffusion step for the configured forward process. + + For `forward_process="absorbing"`, the update mirrors the common absorbing posterior: + - Keep all unmasked positions fixed. + - For masked positions, with probability p_denoise replace mask by a sample from p_theta(x0 | x_t, t). + + For `forward_process="uniform"`, this implements the discrete posterior used by UDLM-style uniform token + diffusion. + """ + if sample.dtype != torch.long: + raise ValueError(f"`sample` must be int64 token IDs, got dtype={sample.dtype}.") + if model_output.ndim != 3 or model_output.shape[-1] != self.vocab_size: + raise ValueError( + f"`model_output` must have shape [batch, seq_len, vocab_size={self.vocab_size}], got {tuple(model_output.shape)}." + ) + if model_output.shape[0] != sample.shape[0] or model_output.shape[1] != sample.shape[1]: + raise ValueError( + f"`model_output` batch/seq dims {tuple(model_output.shape[:2])} must match `sample` {tuple(sample.shape)}." + ) + + device = sample.device + batch_size, seq_len = sample.shape + + # Figure out the previous timestep in the configured inference schedule. + if self.num_inference_steps is None: + raise ValueError("Call `set_timesteps(num_inference_steps, ...)` before calling `step()`.") + + if isinstance(timestep, torch.Tensor): + timestep_int = int(timestep.item()) + else: + timestep_int = int(timestep) + + # Find current index in timesteps and use the next value as "previous" time (less noisy). + # If we are at the end, perform a "noise removal" step (alpha_prev = 1). + current_indices = (self.timesteps == timestep_int).nonzero(as_tuple=False) + if current_indices.numel() == 0: + raise ValueError(f"`timestep` ({timestep_int}) must be one of `self.timesteps`.") + step_index = int(current_indices[0].item()) + is_noise_removal_step = step_index + 1 >= len(self.timesteps) + prev_timestep_int = int(self.timesteps[step_index + 1].item()) if not is_noise_removal_step else 0 + + t = self._t_from_timestep(timestep_int, device=device) + alpha_t = self._alpha_t(t).to(dtype=torch.float32) + if is_noise_removal_step: + alpha_prev = torch.tensor(1.0, device=device, dtype=torch.float32) + else: + t_prev = self._t_from_timestep(prev_timestep_int, device=device) + alpha_prev = self._alpha_t(t_prev).to(dtype=torch.float32) + + if self.forward_process == "uniform": + # Convert logits to probabilities for x0; optionally forbid mask token. 
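+            # The categorical posterior assembled below mixes the model's x0 distribution `p_x0`, the
+            # current token `x_t` (as a one-hot), and the limiting uniform distribution, with mixture
+            # weights derived from alpha_t, alpha_prev, and alpha_ts = alpha_t / alpha_prev. It is then
+            # renormalized and sampled with a Gumbel-argmax so that `generator` controls the randomness.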
+ logits = model_output.to(dtype=torch.float32) + if self.exclude_mask_from_uniform: + logits = logits.clone() + logits[..., self.mask_token_id] = torch.finfo(logits.dtype).min + p_x0 = logits.softmax(dim=-1) + + V = self.vocab_size + x = sample + xt_one_hot = F.one_hot(x, V).to(dtype=p_x0.dtype) + + alpha_ts = (alpha_t / alpha_prev).clamp_min(torch.finfo(torch.float32).eps) + + if self.exclude_mask_from_uniform: + limiting = torch.full((V,), 1.0 / float(V - 1), device=device, dtype=p_x0.dtype) + limiting[self.mask_token_id] = 0.0 + else: + limiting = torch.full((V,), 1.0 / float(V), device=device, dtype=p_x0.dtype) + limiting = limiting.view(1, 1, -1) + + alpha_t3 = alpha_t.view(1, 1, 1) + alpha_s3 = alpha_prev.view(1, 1, 1) + alpha_ts3 = alpha_ts.view(1, 1, 1) + + numerator = ( + (alpha_t3 * V * p_x0 * xt_one_hot) + + ((alpha_ts3 - alpha_t3) * xt_one_hot) + + ((alpha_s3 - alpha_t3) * p_x0) + + ((1.0 - alpha_ts3) * (1.0 - alpha_s3) * limiting) + ) + denom = (alpha_t3 * V * p_x0.gather(-1, x.unsqueeze(-1)) + (1.0 - alpha_t3)).clamp_min( + torch.finfo(torch.float32).eps + ) + + q_xs = numerator / denom + q_xs = q_xs.clamp_min(torch.finfo(torch.float32).tiny) + q_xs = q_xs / q_xs.sum(dim=-1, keepdim=True).clamp_min(torch.finfo(torch.float32).eps) + + x_prev = _gumbel_argmax(torch.log(q_xs), generator=generator).to(dtype=torch.long) + + if not return_dict: + return (x_prev,) + return TokenDiffusionSchedulerOutput(prev_sample=x_prev) + + if self.forward_process != "absorbing": + raise ValueError(f"Unsupported forward process for `step()`: {self.forward_process!r}") + + # p_denoise = (alpha_prev - alpha_t) / (1 - alpha_t) + denom = (1.0 - alpha_t).clamp_min(torch.finfo(torch.float32).eps) + p_denoise = ((alpha_prev - alpha_t) / denom).clamp(0.0, 1.0) + + # Sample x0 predictions (never sample the mask token). + logits = model_output.to(dtype=torch.float32) + logits[..., self.mask_token_id] = torch.finfo(logits.dtype).min + sampled_x0 = _gumbel_argmax(logits, generator=generator).to(dtype=torch.long) + + # Only masked positions can change. + is_masked = sample == self.mask_token_id + + # Bernoulli draw for whether to denoise at this step (only matters on masked positions). + rand = torch.rand((batch_size, seq_len), device=device, dtype=torch.float32, generator=generator) + should_denoise = rand < float(p_denoise.item()) + + x_prev = torch.where(is_masked & should_denoise, sampled_x0, sample) + + if not return_dict: + return (x_prev,) + return TokenDiffusionSchedulerOutput(prev_sample=x_prev) + + +__all__ = ["TokenDiffusionScheduler", "TokenDiffusionSchedulerOutput"] diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 3e9968d47fdd..1dceaf9044af 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -11,6 +11,7 @@ import numpy as np import torch +import torch.nn.functional as F if getattr(torch, "distributed", None) is not None: @@ -109,6 +110,92 @@ def compute_snr(noise_scheduler, timesteps): return snr +def compute_confidence_aware_loss( + logits: torch.Tensor, + labels: torch.Tensor, + *, + lambda_conf: float = 0.0, + temperature: float = 1.0, + per_token_weights: Optional[torch.Tensor] = None, + ignore_index: int = -100, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes a confidence-aware training loss for token classification-style heads. + + This loss combines: + - `loss_sft`: standard supervised cross-entropy on all non-ignored labels. 
+ - `loss_conf`: an entropy penalty applied only on tokens that are already predicted correctly. + + Args: + logits (`torch.Tensor`): Logits of shape `(..., vocab_size)`. + labels (`torch.Tensor`): Labels of shape `(...)`, matching `logits.shape[:-1]`. Values set to `ignore_index` + are excluded from both losses. + lambda_conf (`float`, *optional*, defaults to `0.0`): Weight for the confidence term. + temperature (`float`, *optional*, defaults to `1.0`): Temperature used for the entropy term only. Lower values + sharpen the distribution and change the strength of the confidence gradients. + per_token_weights (`torch.Tensor`, *optional*): Optional weights of shape `(...)` to reweight both losses per + token (e.g. schedule-aware weights). Tokens with weight `0` contribute nothing. + ignore_index (`int`, *optional*, defaults to `-100`): Ignore index for labels. + + Returns: + `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: `(loss, loss_sft, loss_conf)`. + """ + if logits.ndim < 2: + raise ValueError(f"`logits` must have at least 2 dims, got shape {tuple(logits.shape)}.") + if labels.shape != logits.shape[:-1]: + raise ValueError( + f"`labels` shape must match `logits.shape[:-1]`, got labels={tuple(labels.shape)} logits={tuple(logits.shape)}." + ) + if temperature <= 0: + raise ValueError(f"`temperature` must be > 0, got {temperature}.") + + valid = labels.ne(ignore_index) + if per_token_weights is None: + weights = torch.ones_like(labels, dtype=logits.dtype) + else: + if per_token_weights.shape != labels.shape: + raise ValueError( + f"`per_token_weights` shape must match `labels` shape, got {tuple(per_token_weights.shape)} != {tuple(labels.shape)}." + ) + weights = per_token_weights.to(dtype=logits.dtype) + + # Supervised CE (optionally weighted). + vocab_size = logits.shape[-1] + per_token_nll = F.cross_entropy( + logits.reshape(-1, vocab_size), + labels.reshape(-1), + reduction="none", + ignore_index=ignore_index, + ).reshape_as(labels) + + denom_sft = (weights * valid.to(weights.dtype)).sum().clamp_min(1) + loss_sft = (per_token_nll * weights * valid.to(per_token_nll.dtype)).sum() / denom_sft + + # Confidence loss: penalize entropy only where prediction is already correct. + if lambda_conf == 0.0: + loss_conf = torch.zeros((), device=logits.device, dtype=loss_sft.dtype) + return loss_sft, loss_sft, loss_conf + + with torch.no_grad(): + pred = logits.argmax(dim=-1) + correct = valid & pred.eq(labels) + + scaled_logits = logits.float() + if temperature != 1.0: + scaled_logits = scaled_logits / float(temperature) + + probs = torch.softmax(scaled_logits, dim=-1) + eps = torch.finfo(probs.dtype).tiny + log_probs = torch.log(probs.clamp_min(eps)) + entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype) + + denom_conf = (weights * correct.to(weights.dtype)).sum().clamp_min(1) + loss_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / denom_conf + + loss = loss_sft + float(lambda_conf) * loss_conf + return loss, loss_sft, loss_conf + + def resolve_interpolation_mode(interpolation_type: str): """ Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. 
The diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index bb94c94da360..622dd94fee35 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2034,6 +2034,66 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class BlockRefinementPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class BlockRefinementPipelineOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class BlockTokenDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class BlockTokenDiffusionPipelineOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class CLIPImageProjection(metaclass=DummyObject): _backends = ["torch"] @@ -2139,6 +2199,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class HybridTokenDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ImagePipelineOutput(metaclass=DummyObject): _backends = ["torch"] @@ -2259,6 +2334,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class TokenDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class TokenDiffusionPipelineOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DiffusersQuantizer(metaclass=DummyObject): _backends = ["torch"] @@ -2289,6 +2394,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class BlockTokenDiffusionScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, 
**kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class CMStochasticIterativeScheduler(metaclass=DummyObject): _backends = ["torch"] @@ -2439,6 +2559,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class DFlashTokenDiffusionScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DFlashTokenDiffusionSchedulerOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DPMSolverMultistepInverseScheduler(metaclass=DummyObject): _backends = ["torch"] @@ -2604,6 +2754,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class HybridTokenDiffusionScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class HybridTokenDiffusionSchedulerOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class IPNDMScheduler(metaclass=DummyObject): _backends = ["torch"] @@ -2784,6 +2964,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class SDARTokenDiffusionScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class SDARTokenDiffusionSchedulerOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class TCDScheduler(metaclass=DummyObject): _backends = ["torch"] @@ -2799,6 +3009,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class TokenDiffusionScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class UnCLIPScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py 
b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index da32b7ad8df0..15223187178f 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -887,6 +887,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class DFlashPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class DFlashPipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class EasyAnimateControlPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -1877,6 +1907,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LLaDA2Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LLaDA2PipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LongCatImageEditPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -2552,6 +2612,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class SDARPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class SDARPipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class SemanticStableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/others/test_training.py b/tests/others/test_training.py index 2038a98a813e..d8e86984ef1e 100644 --- a/tests/others/test_training.py +++ b/tests/others/test_training.py @@ -18,7 +18,7 @@ import torch from 
diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel -from diffusers.training_utils import set_seed +from diffusers.training_utils import compute_confidence_aware_loss, set_seed from ..testing_utils import slow @@ -85,3 +85,47 @@ def test_training_step_equality(self): self.assertTrue(torch.allclose(ddpm_noisy_images, ddim_noisy_images, atol=1e-5)) self.assertTrue(torch.allclose(ddpm_noise_pred, ddim_noise_pred, atol=1e-5)) + + def test_confidence_aware_loss(self): + logits = torch.tensor([[[5.0, 0.0], [0.0, 5.0]]]) + labels = torch.tensor([[0, 0]]) + weights = torch.tensor([[1.0, 2.0]]) + + loss, loss_sft, loss_conf = compute_confidence_aware_loss( + logits, labels, lambda_conf=0.0, per_token_weights=weights + ) + self.assertTrue(torch.allclose(loss, loss_sft)) + self.assertTrue(torch.allclose(loss_conf, torch.zeros_like(loss_conf))) + + lambda_conf = 0.25 + loss, loss_sft, loss_conf = compute_confidence_aware_loss( + logits, labels, lambda_conf=lambda_conf, per_token_weights=weights + ) + + # Manual expected values for the small 2-class case. + per_token_nll = torch.nn.functional.cross_entropy(logits.view(-1, 2), labels.view(-1), reduction="none").view( + 1, 2 + ) + expected_sft = (per_token_nll * weights).sum() / weights.sum() + + pred = logits.argmax(dim=-1) + correct = pred.eq(labels) + log_probs = torch.log_softmax(logits.float(), dim=-1) + probs = log_probs.exp() + entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype) + expected_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / ( + weights * correct.to(weights.dtype) + ).sum().clamp_min(1) + + expected = expected_sft + lambda_conf * expected_conf + self.assertTrue(torch.allclose(loss_sft, expected_sft)) + self.assertTrue(torch.allclose(loss_conf, expected_conf)) + self.assertTrue(torch.allclose(loss, expected)) + + # Temperature affects only the confidence term. + loss_t, loss_sft_t, loss_conf_t = compute_confidence_aware_loss( + logits, labels, lambda_conf=lambda_conf, temperature=0.5, per_token_weights=weights + ) + self.assertTrue(torch.allclose(loss_sft_t, expected_sft)) + self.assertFalse(torch.allclose(loss_conf_t, expected_conf)) + self.assertTrue(torch.allclose(loss_t, loss_sft_t + lambda_conf * loss_conf_t)) diff --git a/tests/pipelines/test_pipeline_block_refinement.py b/tests/pipelines/test_pipeline_block_refinement.py new file mode 100644 index 000000000000..d4b841268b53 --- /dev/null +++ b/tests/pipelines/test_pipeline_block_refinement.py @@ -0,0 +1,93 @@ +import unittest + +import torch + +from diffusers import BlockRefinementPipeline + + +class _DummyModelOutput: + def __init__(self, logits): + self.logits = logits + + +class _DummyCausalLM(torch.nn.Module): + def __init__(self, vocab_size: int): + super().__init__() + self.vocab_size = int(vocab_size) + self.register_buffer("_device_anchor", torch.empty(0)) + + @property + def dtype(self): + return torch.float32 + + @property + def device(self): + return self._device_anchor.device + + def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs): + batch_size, seq_len = input_ids.shape + logits = torch.zeros((batch_size, seq_len, self.vocab_size), device=input_ids.device, dtype=torch.float32) + + # Make confidence vary with token position so top-k commits are deterministic. 
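+        # Concretely, position j gets a single peaked logit of 1.0 + 0.1 * j on token id j % (vocab_size - 2),
+        # so later positions are more confident and are committed first under top-k selection.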
+ positions = torch.arange(seq_len, device=input_ids.device, dtype=torch.float32).view(1, seq_len, 1) + token_ids = (torch.arange(seq_len, device=input_ids.device) % (self.vocab_size - 2)).view(1, seq_len, 1) + logits.scatter_(2, token_ids.expand(batch_size, -1, -1), 1.0 + positions.expand(batch_size, -1, -1) * 0.1) + return _DummyModelOutput(logits=logits) + + +class _DummyCausalLM2DOnly(_DummyCausalLM): + def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs): + if attention_mask is not None and attention_mask.ndim != 2: + raise ValueError("2D attention_mask required") + return super().forward(input_ids, attention_mask=attention_mask, position_ids=position_ids, **kwargs) + + +class BlockRefinementPipelineTest(unittest.TestCase): + def test_pipeline_runs(self): + vocab_size = 32 + model = _DummyCausalLM(vocab_size=vocab_size) + pipe = BlockRefinementPipeline(model=model, tokenizer=None).to("cpu") + + prompt_ids = torch.tensor([[5, 6, 7, 8], [1, 2, 3, 4]], dtype=torch.long) + out = pipe( + prompt_ids=prompt_ids, + gen_length=24, + block_length=8, + steps=8, + temperature=0.0, + threshold=2.0, # force top-k commits + minimal_topk=1, + eos_early_stop=False, + mask_token_id=vocab_size - 1, + eos_token_id=None, + return_text=False, + ) + + self.assertEqual(out.sequences.shape, (2, 24)) + self.assertFalse((out.sequences == vocab_size - 1).any().item()) + + def test_pipeline_falls_back_to_2d_attention_mask(self): + vocab_size = 32 + model = _DummyCausalLM2DOnly(vocab_size=vocab_size) + pipe = BlockRefinementPipeline(model=model, tokenizer=None).to("cpu") + + out = pipe( + prompt_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + gen_length=16, + block_length=8, + steps=4, + temperature=0.0, + threshold=2.0, + minimal_topk=1, + eos_early_stop=False, + mask_token_id=vocab_size - 1, + eos_token_id=None, + attention_mask_mode="auto", + return_text=False, + ) + + self.assertEqual(out.sequences.shape, (1, 16)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/pipelines/test_pipeline_block_token_diffusion.py b/tests/pipelines/test_pipeline_block_token_diffusion.py new file mode 100644 index 000000000000..e005310be974 --- /dev/null +++ b/tests/pipelines/test_pipeline_block_token_diffusion.py @@ -0,0 +1,59 @@ +import unittest + +import torch + +from diffusers import BlockTokenDiffusionPipeline, BlockTokenDiffusionScheduler + + +class _DummyTokenizer: + cls_token_id = 1 + bos_token_id = None + + def batch_decode(self, sequences, skip_special_tokens=True): + return [" ".join(map(str, row)) for row in sequences.tolist()] + + +class _DummyModelOutput: + def __init__(self, logits): + self.logits = logits + + +class _DummyMLM(torch.nn.Module): + def __init__(self, vocab_size: int): + super().__init__() + self.vocab_size = vocab_size + self.register_buffer("_device_anchor", torch.empty(0)) + + @property + def dtype(self): + return torch.float32 + + @property + def device(self): + return self._device_anchor.device + + def forward(self, input_ids=None, attention_mask=None, **kwargs): + batch_size, seq_len = input_ids.shape + logits = torch.zeros((batch_size, seq_len, self.vocab_size), device=input_ids.device, dtype=torch.float32) + return _DummyModelOutput(logits=logits) + + +class BlockTokenDiffusionPipelineTest(unittest.TestCase): + def test_pipeline_runs_and_respects_prefix(self): + vocab_size = 32 + scheduler = BlockTokenDiffusionScheduler(vocab_size=vocab_size, mask_token_id=vocab_size - 1) + model = _DummyMLM(vocab_size=vocab_size) + tokenizer = 
_DummyTokenizer() + pipe = BlockTokenDiffusionPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer).to("cpu") + + prefix = torch.tensor([5, 6, 7], dtype=torch.long) + gen = torch.Generator().manual_seed(0) + out = pipe(batch_size=2, seq_len=10, block_size=4, num_inference_steps=2, generator=gen, prefix_ids=prefix) + + self.assertEqual(out.sequences.shape, (2, 10)) + self.assertTrue((out.sequences[:, :3] == prefix.view(1, -1)).all().item()) + self.assertEqual(len(out.texts), 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/pipelines/test_pipeline_hybrid_token_diffusion.py b/tests/pipelines/test_pipeline_hybrid_token_diffusion.py new file mode 100644 index 000000000000..25c2f7a9ae82 --- /dev/null +++ b/tests/pipelines/test_pipeline_hybrid_token_diffusion.py @@ -0,0 +1,56 @@ +import unittest + +import torch + +from diffusers import HybridTokenDiffusionPipeline, HybridTokenDiffusionScheduler + + +class _DummyTokenizer: + cls_token_id = 1 + bos_token_id = None + + def batch_decode(self, sequences, skip_special_tokens=True): + return [" ".join(map(str, row)) for row in sequences.tolist()] + + +class _DummyModelOutput: + def __init__(self, logits): + self.logits = logits + + +class _DummyMLM(torch.nn.Module): + def __init__(self, vocab_size: int): + super().__init__() + self.vocab_size = vocab_size + self.register_buffer("_device_anchor", torch.empty(0)) + + @property + def dtype(self): + return torch.float32 + + @property + def device(self): + return self._device_anchor.device + + def forward(self, input_ids=None, attention_mask=None, **kwargs): + batch_size, seq_len = input_ids.shape + logits = torch.zeros((batch_size, seq_len, self.vocab_size), device=input_ids.device, dtype=torch.float32) + return _DummyModelOutput(logits=logits) + + +class HybridTokenDiffusionPipelineTest(unittest.TestCase): + def test_pipeline_runs(self): + vocab_size = 32 + scheduler = HybridTokenDiffusionScheduler(vocab_size=vocab_size, mask_token_id=vocab_size - 1) + model = _DummyMLM(vocab_size=vocab_size) + tokenizer = _DummyTokenizer() + pipe = HybridTokenDiffusionPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer).to("cpu") + + gen = torch.Generator().manual_seed(0) + out = pipe(batch_size=2, seq_len=8, num_inference_steps=2, generator=gen, inject_start_token=True) + self.assertEqual(out.sequences.shape, (2, 8)) + self.assertEqual(len(out.texts), 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/pipelines/test_pipeline_token_diffusion.py b/tests/pipelines/test_pipeline_token_diffusion.py new file mode 100644 index 000000000000..57c26d13a6a2 --- /dev/null +++ b/tests/pipelines/test_pipeline_token_diffusion.py @@ -0,0 +1,129 @@ +import unittest + +import torch + +from diffusers import TokenDiffusionPipeline, TokenDiffusionScheduler + + +class _DummyTokenizer: + bos_token_id = None + cls_token_id = 1 + + def batch_decode(self, sequences, skip_special_tokens=True): + # Deterministic, cheap “decode”: join token ids as strings. 
+ out = [] + for row in sequences.tolist(): + out.append(" ".join(str(i) for i in row)) + return out + + +class _DummyModelOutput: + def __init__(self, logits): + self.logits = logits + + +class _DummyMLM(torch.nn.Module): + def __init__(self, vocab_size: int): + super().__init__() + self.vocab_size = vocab_size + self.register_buffer("_device_anchor", torch.empty(0)) + + @property + def dtype(self): + return torch.float32 + + @property + def device(self): + return self._device_anchor.device + + def forward(self, input_ids=None, attention_mask=None, **kwargs): + batch_size, seq_len = input_ids.shape + logits = torch.zeros((batch_size, seq_len, self.vocab_size), device=input_ids.device, dtype=torch.float32) + return _DummyModelOutput(logits=logits) + + +class TokenDiffusionPipelineTest(unittest.TestCase): + def test_absorbing_pipeline_runs(self): + vocab_size = 32 + scheduler = TokenDiffusionScheduler( + vocab_size=vocab_size, mask_token_id=vocab_size - 1, forward_process="absorbing" + ) + model = _DummyMLM(vocab_size=vocab_size) + tokenizer = _DummyTokenizer() + + pipe = TokenDiffusionPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer) + pipe = pipe.to("cpu") + + out = pipe(batch_size=2, seq_len=8, num_inference_steps=2, inject_start_token=True) + self.assertEqual(out.sequences.shape, (2, 8)) + self.assertEqual(len(out.texts), 2) + + def test_uniform_pipeline_runs(self): + vocab_size = 32 + scheduler = TokenDiffusionScheduler( + vocab_size=vocab_size, + mask_token_id=vocab_size - 1, + forward_process="uniform", + exclude_mask_from_uniform=True, + ) + model = _DummyMLM(vocab_size=vocab_size) + tokenizer = _DummyTokenizer() + + pipe = TokenDiffusionPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer) + pipe = pipe.to("cpu") + + gen = torch.Generator().manual_seed(0) + out = pipe(batch_size=2, seq_len=8, num_inference_steps=2, generator=gen, inject_start_token=True) + self.assertEqual(out.sequences.shape, (2, 8)) + self.assertFalse((out.sequences == scheduler.mask_token_id).any().item()) + + def test_prefix_ids_are_fixed(self): + vocab_size = 32 + scheduler = TokenDiffusionScheduler( + vocab_size=vocab_size, mask_token_id=vocab_size - 1, forward_process="absorbing" + ) + model = _DummyMLM(vocab_size=vocab_size) + tokenizer = _DummyTokenizer() + + pipe = TokenDiffusionPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer).to("cpu") + prefix = torch.tensor([5, 6, 7], dtype=torch.long) + out = pipe(batch_size=2, seq_len=8, num_inference_steps=2, prefix_ids=prefix, return_text=False) + + self.assertTrue((out.sequences[:, :3] == prefix.view(1, -1)).all().item()) + + def test_infill_mask_freezes_positions(self): + vocab_size = 32 + scheduler = TokenDiffusionScheduler( + vocab_size=vocab_size, + mask_token_id=vocab_size - 1, + forward_process="uniform", + exclude_mask_from_uniform=True, + ) + model = _DummyMLM(vocab_size=vocab_size) + tokenizer = _DummyTokenizer() + + pipe = TokenDiffusionPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer).to("cpu") + + # Only positions 2..7 are editable, first two positions are fixed to the initial values. + infill_mask = torch.ones((2, 8), dtype=torch.bool) + infill_mask[:, :2] = False + gen = torch.Generator().manual_seed(0) + out = pipe( + batch_size=2, seq_len=8, num_inference_steps=2, generator=gen, infill_mask=infill_mask, return_text=False + ) + + # Fixed positions should be unchanged from the initial latents (for uniform, these are random but clamped). 
+ # Since the model predicts uniform logits and the scheduler would otherwise resample, this checks clamping works. + out2 = pipe( + batch_size=2, + seq_len=8, + num_inference_steps=2, + generator=torch.Generator().manual_seed(0), + infill_mask=infill_mask, + return_text=False, + ) + self.assertTrue((out.sequences[:, :2] == out2.sequences[:, :2]).all().item()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/schedulers/test_scheduler_block_token_diffusion.py b/tests/schedulers/test_scheduler_block_token_diffusion.py new file mode 100644 index 000000000000..7b5e833b21d0 --- /dev/null +++ b/tests/schedulers/test_scheduler_block_token_diffusion.py @@ -0,0 +1,30 @@ +import unittest + +import torch + +from diffusers import BlockTokenDiffusionScheduler + + +class BlockTokenDiffusionSchedulerTest(unittest.TestCase): + def test_step_respects_block_mask(self): + vocab_size = 32 + scheduler = BlockTokenDiffusionScheduler(vocab_size=vocab_size, mask_token_id=vocab_size - 1) + scheduler.set_timesteps(1) + + batch_size, seq_len = 2, 8 + x = torch.full((batch_size, seq_len), scheduler.mask_token_id, dtype=torch.long) + block_mask = torch.zeros_like(x, dtype=torch.bool) + block_mask[:, :4] = True + + logits = torch.zeros((batch_size, seq_len, vocab_size), dtype=torch.float32) + gen = torch.Generator().manual_seed(0) + out = scheduler.step(logits, scheduler.timesteps[0], x, generator=gen, return_dict=True, block_mask=block_mask) + + # Block positions should be denoised (non-mask) after the final noise-removal step. + self.assertTrue((out.prev_sample[:, :4] != scheduler.mask_token_id).all().item()) + # Outside the block, tokens should remain unchanged (still mask). + self.assertTrue((out.prev_sample[:, 4:] == scheduler.mask_token_id).all().item()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/schedulers/test_scheduler_hybrid_token_diffusion.py b/tests/schedulers/test_scheduler_hybrid_token_diffusion.py new file mode 100644 index 000000000000..a3146698f22b --- /dev/null +++ b/tests/schedulers/test_scheduler_hybrid_token_diffusion.py @@ -0,0 +1,29 @@ +import unittest + +import torch + +from diffusers import HybridTokenDiffusionScheduler + + +class HybridTokenDiffusionSchedulerTest(unittest.TestCase): + def test_add_noise_and_step_shapes(self): + vocab_size = 32 + scheduler = HybridTokenDiffusionScheduler(vocab_size=vocab_size, mask_token_id=vocab_size - 1) + scheduler.set_timesteps(4, device="cpu") + + batch_size, seq_len = 2, 8 + x0 = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (batch_size,), dtype=torch.long) + x_t = scheduler.add_noise(x0, noise=None, timesteps=timesteps) + self.assertEqual(x_t.shape, x0.shape) + self.assertEqual(x_t.dtype, torch.long) + + logits = torch.zeros((batch_size, seq_len, vocab_size), dtype=torch.float32) + gen = torch.Generator().manual_seed(0) + out = scheduler.step(logits, scheduler.timesteps[0], x_t, generator=gen, return_dict=True) + self.assertEqual(out.prev_sample.shape, x0.shape) + self.assertTrue(((out.prev_sample >= 0) & (out.prev_sample < vocab_size)).all().item()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/schedulers/test_scheduler_token_diffusion.py b/tests/schedulers/test_scheduler_token_diffusion.py new file mode 100644 index 000000000000..1812f3f28630 --- /dev/null +++ b/tests/schedulers/test_scheduler_token_diffusion.py @@ -0,0 +1,130 @@ +import unittest + +import torch + +from diffusers import 
TokenDiffusionScheduler + + +class TokenDiffusionSchedulerTest(unittest.TestCase): + def get_scheduler(self, **kwargs): + config = { + "vocab_size": 128, + "mask_token_id": 127, + "num_train_timesteps": 100, + "eps": 1e-3, + } + config.update(kwargs) + return TokenDiffusionScheduler(**config) + + def test_set_timesteps(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(10) + self.assertEqual(len(scheduler.timesteps), 10) + self.assertTrue((scheduler.timesteps[:-1] >= scheduler.timesteps[1:]).all().item()) + + def test_alpha_schedule_monotone_and_bounded(self): + # alpha(t) should be in (0, 1] and non-increasing in t for supported schedules. + schedules = ["log_linear", "linear", "cosine", "geometric"] + t = torch.linspace(0, 1, 33, dtype=torch.float32) + + for name in schedules: + scheduler = self.get_scheduler(alpha_schedule=name) + alpha = scheduler._alpha_t(t) + self.assertTrue(((alpha > 0) & (alpha <= 1)).all().item()) + # monotone non-increasing: alpha[i] >= alpha[i+1] + self.assertTrue((alpha[:-1] >= alpha[1:]).all().item()) + + def test_mdlm_weights_match_log_linear_1_over_t(self): + scheduler = self.get_scheduler(alpha_schedule="log_linear", eps=1e-3, num_train_timesteps=1000) + timesteps = torch.tensor([1, 10, 100, 999], dtype=torch.long) + w = scheduler.get_mdlm_loss_weights(timesteps).squeeze(-1) + t_cont = timesteps.to(dtype=torch.float32) / float(scheduler.num_train_timesteps - 1) + expected = 1.0 / t_cont + self.assertTrue(torch.allclose(w, expected, rtol=5e-5, atol=1e-5)) + + def test_add_noise_absorbing_keeps_shape_dtype(self): + scheduler = self.get_scheduler() + batch_size, seq_len = 4, 16 + x0 = torch.randint(0, scheduler.vocab_size, (batch_size, seq_len), dtype=torch.long) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (batch_size,), dtype=torch.long) + + xt = scheduler.add_noise(x0, noise=None, timesteps=timesteps) + self.assertEqual(xt.shape, x0.shape) + self.assertEqual(xt.dtype, torch.long) + + # xt values must be valid token ids + self.assertTrue(((xt >= 0) & (xt < scheduler.vocab_size)).all().item()) + + def test_step_preserves_unmasked_tokens(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(5) + + batch_size, seq_len = 2, 12 + x_t = torch.randint(0, scheduler.vocab_size, (batch_size, seq_len), dtype=torch.long) + x_t[:, :3] = scheduler.mask_token_id # ensure some masked positions + + # Model predicts uniform logits; step should never change already unmasked positions + logits = torch.zeros((batch_size, seq_len, scheduler.vocab_size), dtype=torch.float32) + out = scheduler.step(logits, scheduler.timesteps[0], x_t, return_dict=True) + x_prev = out.prev_sample + + self.assertTrue((x_prev[:, 3:] == x_t[:, 3:]).all().item()) + + def test_step_never_samples_mask_token(self): + scheduler = self.get_scheduler() + # Use a single inference step so the scheduler denoises to t=0 in one go (p_denoise = 1). + scheduler.set_timesteps(1) + + batch_size, seq_len = 2, 12 + x_t = torch.full((batch_size, seq_len), scheduler.mask_token_id, dtype=torch.long) + logits = torch.zeros((batch_size, seq_len, scheduler.vocab_size), dtype=torch.float32) + + gen = torch.Generator().manual_seed(0) + x_prev = scheduler.step(logits, scheduler.timesteps[0], x_t, generator=gen, return_dict=True).prev_sample + + # Mask token is forbidden as an x0 prediction, and the scheduler performs a final noise-removal step. 
+ self.assertTrue((x_prev != scheduler.mask_token_id).all().item()) + + def test_uniform_add_noise_excludes_mask_if_configured(self): + scheduler = self.get_scheduler(forward_process="uniform", exclude_mask_from_uniform=True) + batch_size, seq_len = 8, 64 + x0 = torch.randint(0, scheduler.vocab_size, (batch_size, seq_len), dtype=torch.long) + # Make sure some originals are mask token too (uniform should still sample non-mask replacements). + x0[:, :5] = scheduler.mask_token_id + + # Use the noisiest time (highest replace probability). + timesteps = torch.full((batch_size,), scheduler.num_train_timesteps - 1, dtype=torch.long) + xt = scheduler.add_noise(x0, noise=None, timesteps=timesteps) + + # Mask token should be rare-to-absent under uniform corruption when excluded. + self.assertFalse((xt == scheduler.mask_token_id).any().item()) + + def test_uniform_step_runs_and_returns_valid_ids(self): + scheduler = self.get_scheduler(forward_process="uniform", exclude_mask_from_uniform=True) + scheduler.set_timesteps(2) + + batch_size, seq_len = 2, 16 + x_t = torch.randint(0, scheduler.vocab_size, (batch_size, seq_len), dtype=torch.long) + logits = torch.zeros((batch_size, seq_len, scheduler.vocab_size), dtype=torch.float32) + + gen = torch.Generator().manual_seed(0) + x_prev = scheduler.step(logits, scheduler.timesteps[0], x_t, generator=gen, return_dict=True).prev_sample + + self.assertEqual(x_prev.shape, x_t.shape) + self.assertTrue(((x_prev >= 0) & (x_prev < scheduler.vocab_size)).all().item()) + # With exclusion, mask token should not appear. + self.assertFalse((x_prev == scheduler.mask_token_id).any().item()) + + def test_alpha_helpers_shapes(self): + scheduler = self.get_scheduler(num_train_timesteps=10) + timesteps = torch.tensor([0, 1, 9], dtype=torch.long) + + alpha = scheduler.get_alpha(timesteps) + dalpha = scheduler.get_alpha_prime(timesteps) + + self.assertEqual(alpha.shape, (3, 1)) + self.assertEqual(dalpha.shape, (3, 1)) + + +if __name__ == "__main__": + unittest.main()