Skip to content

Latest commit

 

History

History
476 lines (401 loc) · 20.3 KB

File metadata and controls

476 lines (401 loc) · 20.3 KB

Training Models from Scratch

DiffSynth-Studio's training engine supports training foundation models from scratch. This article shows how to train a small text-to-image model with only 0.1B parameters from scratch.

1. Building Model Architecture

1.1 Diffusion Model

From UNet [1] [2] to DiT [3] [4], the mainstream architectures of diffusion models have evolved several times. Typically, a diffusion model takes the following inputs:

  • Image tensor (latents): The encoding of the image, produced by the VAE and partially corrupted with noise
  • Text tensor (prompt_embeds): The encoding of text, generated by the text encoder
  • Timestep (timestep): A scalar used to mark which stage of the Diffusion process we are currently at

The model outputs a tensor with the same shape as the image tensor, representing the denoising direction it predicts. For the theory behind diffusion models, see Basic Principles of Diffusion Models. In this article, we build a DiT model with only 0.1B parameters, called AAADiT.

Model Architecture Code
import torch, accelerate
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange, repeat

from transformers import AutoProcessor, AutoTokenizer
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
from diffsynth.models.general_modules import TimestepEmbeddings
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
from diffsynth.models.flux2_vae import Flux2VAE


class AAAPositionalEmbedding(torch.nn.Module):
    """Learned positional embeddings for the joint [image tokens, text tokens] sequence.

    Image tokens use a learned 2D grid of embeddings stored at ``height x width``
    and bilinearly resized to the actual latent resolution. All text tokens share
    one learned vector.

    Args:
        height: Height of the stored image-embedding grid.
        width: Width of the stored image-embedding grid.
        dim: Embedding dimension.
    """

    def __init__(self, height=16, width=16, dim=1024):
        super().__init__()
        self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
        self.text_emb = torch.nn.Parameter(torch.randn((dim,)))

    def forward(self, image, text):
        """Return positional embeddings aligned with ``concat([image_tokens, text_tokens])``.

        Args:
            image: Image/latent tensor. Only its batch size, device, dtype and last
                two dimensions (H, W) are used.
            text: Text tensor whose first two dimensions are (batch, length). Its
                device and dtype are used for the text embedding.

        Returns:
            Tensor of shape (batch, H*W + length, dim): image positions first
            (row-major over H, W), then text positions.
        """
        height, width = image.shape[-2:]
        image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
        image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
        # (1, C, H, W) -> (1, H*W, C), flattening H then W (same order as "B C H W -> B (H W) C").
        image_emb = image_emb.flatten(2).transpose(1, 2)
        # The stored grid has batch size 1. Broadcast it over the batch so the
        # concat below also works for batch sizes > 1.
        image_emb = image_emb.expand(image.shape[0], -1, -1)
        text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
        text_emb = text_emb.expand(text.shape[0], text.shape[1], -1)
        return torch.concat([image_emb, text_emb], dim=1)


class AAABlock(torch.nn.Module):
    def __init__(self, dim=1024, num_heads=32):
        super().__init__()
        self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
        self.to_q = torch.nn.Linear(dim, dim)
        self.to_k = torch.nn.Linear(dim, dim)
        self.to_v = torch.nn.Linear(dim, dim)
        self.to_out = torch.nn.Linear(dim, dim)
        self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*3),
            torch.nn.SiLU(),
            torch.nn.Linear(dim*3, dim),
        )
        self.to_gate = torch.nn.Linear(dim, dim * 2)
        self.num_heads = num_heads

    def attention(self, emb, pos_emb):
        emb = self.norm_attn(emb + pos_emb)
        q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
        emb = attention_forward(
            q, k, v,
            q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
            dims={"n": self.num_heads},
        )
        emb = self.to_out(emb)
        return emb
    
    def feed_forward(self, emb, pos_emb):
        emb = self.norm_mlp(emb + pos_emb)
        emb = self.ff(emb)
        return emb
    
    def forward(self, emb, pos_emb, t_emb):
        gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
        emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
        emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
        return emb


class AAADiT(torch.nn.Module):
    """Small diffusion transformer (about 0.1B parameters with the defaults).

    Image latents and text embeddings are projected to a shared width and
    concatenated into one token sequence, with image tokens first. The sequence
    runs through a stack of timestep-conditioned ``AAABlock`` layers, and the
    image tokens are projected back to latent channels.

    Args:
        dim: Hidden width of the transformer.
        in_channels: Latent channels of the input and the prediction (default 128).
        text_dim: Feature size of ``prompt_embeds`` (default 1024).
        num_layers: Number of transformer blocks (default 10).
        num_heads: Attention heads per block (default 32).
    """

    def __init__(self, dim=1024, in_channels=128, text_dim=1024, num_layers=10, num_heads=32):
        super().__init__()
        self.pos_embedder = AAAPositionalEmbedding(dim=dim)
        self.timestep_embedder = TimestepEmbeddings(256, dim)
        self.image_embedder = torch.nn.Sequential(torch.nn.Linear(in_channels, dim), torch.nn.LayerNorm(dim))
        self.text_embedder = torch.nn.Sequential(torch.nn.Linear(text_dim, dim), torch.nn.LayerNorm(dim))
        self.blocks = torch.nn.ModuleList([AAABlock(dim, num_heads=num_heads) for _ in range(num_layers)])
        self.proj_out = torch.nn.Linear(dim, in_channels)

    def forward(
        self,
        latents,
        prompt_embeds,
        timestep,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
        """Predict the denoising direction for ``latents``.

        Args:
            latents: Latent tensor of shape (B, C, H, W).
            prompt_embeds: Text embeddings of shape (B, L, text_dim).
            timestep: Timestep tensor passed to the timestep embedder.
            use_gradient_checkpointing: Forwarded to ``gradient_checkpoint_forward``.
            use_gradient_checkpointing_offload: Forwarded to ``gradient_checkpoint_forward``.

        Returns:
            Tensor with the same (B, C, H, W) layout as ``latents``.
        """
        height, width = latents.shape[-2:]
        pos_emb = self.pos_embedder(latents, prompt_embeds)
        t_emb = self.timestep_embedder(timestep, dtype=latents.dtype)
        # Reshape to (num_timesteps, 1, dim) so each sample's embedding broadcasts
        # over its tokens. Flattening everything into one (1, 1, -1) vector would
        # break as soon as more than one timestep is given.
        t_emb = t_emb.view(-1, 1, t_emb.shape[-1])
        image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
        text = self.text_embedder(prompt_embeds)
        emb = torch.concat([image, text], dim=1)
        for block in self.blocks:
            emb = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                emb=emb,
                pos_emb=pos_emb,
                t_emb=t_emb,
            )
        # Keep only the image tokens, which come first in the sequence.
        emb = emb[:, :height * width]
        emb = self.proj_out(emb)
        return rearrange(emb, "B (H W) C -> B C H W", W=width)

1.2 Encoder-Decoder Models

Besides the Diffusion model used for denoising, we also need two other models:

  • Text Encoder: Used to encode text into tensors. We adopt the Qwen/Qwen3-0.6B model.
  • VAE Encoder-Decoder: The encoder part is used to encode images into tensors, and the decoder part is used to decode image tensors into images. We adopt the VAE model from black-forest-labs/FLUX.2-klein-4B.

The architectures of these two models are already integrated in DiffSynth-Studio, located at /diffsynth/models/z_image_text_encoder.py and /diffsynth/models/flux2_vae.py, so we don't need to modify any code.

2. Building Pipeline

The document Integrating Pipeline describes how to build a model Pipeline. The model in this article also needs a Pipeline, which connects the text encoder, the Diffusion model, and the VAE encoder-decoder.

Pipeline Code
class AAAImagePipeline(BasePipeline):
    """Text-to-image pipeline connecting the text encoder, AAADiT and the FLUX.2 VAE.

    ``self.units`` prepare the inputs: prompt encoding, noise sampling and
    optional input-image encoding. ``__call__`` then runs the denoising loop
    through ``model_fn_aaa`` and decodes the final latents with the VAE.
    """

    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
        )
        self.scheduler = FlowMatchScheduler("FLUX.2")
        # Models are attached later, e.g. by from_pretrained().
        self.text_encoder: ZImageTextEncoder = None
        self.dit: AAADiT = None
        self.vae: Flux2VAE = None
        # from_pretrained() loads this with AutoTokenizer.from_pretrained().
        self.tokenizer: AutoTokenizer = None
        # Models loaded onto the device for the whole denoising loop.
        self.in_iteration_models = ("dit",)
        # Preprocessing units, run in this order by unit_runner().
        self.units = [
            AAAUnit_PromptEmbedder(),
            AAAUnit_NoiseInitializer(),
            AAAUnit_InputImageEmbedder(),
        ]
        self.model_fn = model_fn_aaa

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = "cuda",
        model_configs: Union[list[ModelConfig], None] = None,
        tokenizer_config: ModelConfig = None,
        vram_limit: float = None,
    ):
        """Create a pipeline and load its models.

        Args:
            torch_dtype: dtype of the pipeline.
            device: Target device.
            model_configs: Model files to download and load. ``None`` (the
                default) means no files; each call gets its own fresh list.
            tokenizer_config: If given, downloaded if necessary and loaded with
                ``AutoTokenizer.from_pretrained``.
            vram_limit: Passed through to ``download_and_load_models``.

        Returns:
            The initialized ``AAAImagePipeline``. Models missing from
            ``model_configs`` are set to whatever ``fetch_model`` returns for them
            (presumably ``None``), e.g. the DiT when it is attached separately.
        """
        # A literal `[]` default would be a single list shared by every call.
        if model_configs is None:
            model_configs = []

        # Initialize pipeline
        pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)

        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
        pipe.dit = model_pool.fetch_model("aaa_dit")
        pipe.vae = model_pool.fetch_model("flux2_vae")
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)

        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 1.0,
        # Image
        input_image: Image.Image = None,
        denoising_strength: float = 1.0,
        # Shape
        height: int = 1024,
        width: int = 1024,
        # Randomness
        seed: int = None,
        rand_device: str = "cpu",
        # Steps
        num_inference_steps: int = 30,
        # Progress bar
        progress_bar_cmd = tqdm,
    ):
        """Generate an image.

        Args:
            prompt: Positive text prompt.
            negative_prompt: Prompt encoded separately for the negative CFG branch.
            cfg_scale: Classifier-free guidance scale, applied via ``cfg_guided_model_fn``.
            input_image: Optional starting image. When given, it is VAE-encoded
                and noised to the first timestep before denoising.
            denoising_strength: Passed to ``scheduler.set_timesteps``.
            height: Output height in pixels (latents are height // 16).
            width: Output width in pixels (latents are width // 16).
            seed: Seed for the initial noise.
            rand_device: Device on which the noise is sampled.
            num_inference_steps: Number of denoising steps.
            progress_bar_cmd: Iterable wrapper used for progress display.

        Returns:
            The decoded image produced by ``vae_output_to_image``.
        """
        # dynamic_shift_len is the number of latent tokens (16x downsampling per side).
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)

        # Parameters
        inputs_posi = {"prompt": prompt}
        inputs_nega = {"negative_prompt": negative_prompt}
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "input_image": input_image, "denoising_strength": denoising_strength,
            "height": height, "width": width,
            "seed": seed, "rand_device": rand_device,
            "num_inference_steps": num_inference_steps,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)

        # Decode
        self.load_models_to_device(['vae'])
        image = self.vae.decode(inputs_shared["latents"])
        image = self.vae_output_to_image(image)
        self.load_models_to_device([])

        return image


class AAAUnit_PromptEmbedder(PipelineUnit):
    """Pipeline unit that encodes the positive and negative prompts into ``prompt_embeds``."""

    def __init__(self):
        super().__init__(
            # Keyword spelled as PipelineUnit expects it.
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("prompt_embeds",),
            onload_model_names=("text_encoder",)
        )
        # Indices into the encoder's hidden states to concatenate; (-1,) is the last one.
        self.hidden_states_layers = (-1,)

    def process(self, pipe: AAAImagePipeline, prompt):
        """Apply the chat template, tokenize to 128 tokens and return encoder hidden states."""
        pipe.load_models_to_device(self.onload_model_names)
        chat = [{"role": "user", "content": prompt}]
        templated = pipe.tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        # Fixed-length batch: pad or truncate every prompt to exactly 128 tokens.
        encoded = pipe.tokenizer(
            templated,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=128,
        ).to(pipe.device)
        output = pipe.text_encoder(**encoded, output_hidden_states=True, use_cache=False)
        selected = [output.hidden_states[layer] for layer in self.hidden_states_layers]
        return {"prompt_embeds": torch.concat(selected, dim=-1)}


class AAAUnit_NoiseInitializer(PipelineUnit):
    """Pipeline unit that samples the initial latent noise for the requested image size."""

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "seed", "rand_device"),
            output_params=("noise",),
        )

    def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
        # The latent grid is 16x smaller per side and has 128 channels,
        # matching AAADiT's image_embedder / proj_out widths.
        latent_shape = (1, 128, height // 16, width // 16)
        return {
            "noise": pipe.generate_noise(
                latent_shape,
                seed=seed,
                rand_device=rand_device,
                rand_torch_dtype=pipe.torch_dtype,
            )
        }


class AAAUnit_InputImageEmbedder(PipelineUnit):
    """Pipeline unit that produces the starting ``latents`` and, if an image is given, its VAE encoding."""

    def __init__(self):
        super().__init__(
            input_params=("input_image", "noise"),
            output_params=("latents", "input_latents"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: AAAImagePipeline, input_image, noise):
        # No input image: start from pure noise, with no encoded reference.
        if input_image is None:
            return {"latents": noise, "input_latents": None}
        pipe.load_models_to_device(['vae'])
        input_latents = pipe.vae.encode(pipe.preprocess_image(input_image))
        # In training mode the latents stay pure noise; input_latents is returned
        # alongside (presumably consumed by the training loss). At inference the
        # encoded image is noised to the scheduler's first timestep.
        latents = (
            noise
            if pipe.scheduler.training
            else pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
        )
        return {"latents": latents, "input_latents": input_latents}


def model_fn_aaa(
    dit: AAADiT,
    latents=None,
    prompt_embeds=None,
    timestep=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    """Run a single AAADiT forward pass on behalf of the pipeline.

    Any additional keyword arguments are accepted through ``**kwargs`` and ignored.

    Returns:
        Whatever ``dit(...)`` returns (the predicted denoising direction).
    """
    checkpoint_options = dict(
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )
    return dit(latents, prompt_embeds, timestep, **checkpoint_options)

3. Preparing Dataset

To verify training effectiveness quickly, we use the dataset Pokemon-First Generation. It is derived from the open-source project pokemon-dataset-zh and contains all 151 first-generation Pokemon, from Bulbasaur to Mew. To use other datasets, refer to the document Preparing Datasets and to diffsynth.core.data.

modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data

4. Start Training

With the Pipeline in place, the training process takes only a few lines of code. The complete script is at ../Research_Tutorial/train_from_scratch.py. For single-GPU training, start it directly with python docs/en/Research_Tutorial/train_from_scratch.py.

For multi-GPU parallel training, first run accelerate config to set the relevant parameters. Then start training with accelerate launch docs/en/Research_Tutorial/train_from_scratch.py.

This training script has no stopping condition; stop it manually when needed. The model converges after approximately 60,000 training steps, which takes 10-20 hours on a single GPU.

Training Code
class AAATrainingModule(DiffusionTrainingModule):
    """Training wrapper: pretrained text encoder + VAE, and a freshly initialized AAADiT.

    Args:
        device: Device for the pipeline and the DiT.
        use_gradient_checkpointing: Passed to the DiT through the pipeline inputs
            (default False, the previous hard-coded value).
        use_gradient_checkpointing_offload: Same, for offloaded checkpointing
            (default False).
    """

    def __init__(self, device, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
        super().__init__()
        self.pipe = AAAImagePipeline.from_pretrained(
            torch_dtype=torch.bfloat16,
            device=device,
            model_configs=[
                ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
                ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
            ],
            tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
        )
        # Training from scratch: the DiT is randomly initialized, not loaded.
        self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
        # Presumably freezes every model except the DiT (per freeze_except's name).
        self.pipe.freeze_except(["dit"])
        self.pipe.scheduler.set_timesteps(1000, training=True)
        self._checkpointing_kwargs = {
            "use_gradient_checkpointing": use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": use_gradient_checkpointing_offload,
        }

    def forward(self, data):
        """Compute the training loss for one data item.

        Args:
            data: Mapping with "prompt" and "image". ``image.size`` is read as
                (width, height), presumably a PIL image.

        Returns:
            The value returned by ``FlowMatchSFTLoss``.
        """
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {"negative_prompt": ""}
        inputs_shared = {
            "input_image": data["image"],
            "height": data["image"].size[1],
            "width": data["image"].size[0],
            "cfg_scale": 1,
            **self._checkpointing_kwargs,
        }
        for unit in self.pipe.units:
            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
        loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
        return loss


if __name__ == "__main__":
    accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
    # Dataset downloaded to ./data. The default image operator is configured for
    # 256x256 images (presumably it resizes them).
    image_dir = "data/images"
    dataset = UnifiedDataset(
        base_path=image_dir,
        metadata_path="data/metadata_merged.csv",
        max_data_items=10000000,
        data_file_keys=("image",),
        main_data_operator=UnifiedDataset.default_image_operator(
            base_path=image_dir, height=256, width=256,
        ),
    )
    model = AAATrainingModule(device=accelerator.device)
    # Strip the "pipe.dit." prefix from checkpoint keys so they match AAADiT's own names.
    model_logger = ModelLogger("models/AAA/v1", remove_prefix_in_ckpt="pipe.dit.")
    launch_training_task(
        accelerator, dataset, model, model_logger,
        learning_rate=2e-4,
        num_workers=4,
        save_steps=50000,
        # Effectively unbounded: stop the run manually.
        num_epochs=999999,
    )

5. Verifying Training Results

If you don't want to wait for the model training to complete, you can directly download our pre-trained model.

modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel

Loading the model

from diffsynth import load_model

# Text encoder and VAE: the same pretrained models used during training.
shared_model_configs = [
    ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
    ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
]
pipe = AAAImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=shared_model_configs,
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
)
# The DiT is not in model_configs; attach the trained weights explicitly.
ckpt_path = "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors"
pipe.dit = load_model(AAADiT, ckpt_path, torch_dtype=torch.bfloat16, device="cuda")

Run inference to generate the first-generation Pokemon "starter trio". At this stage, the generated images closely match the training data.

# One prompt per starter Pokemon; each prompt's list index is also its seed.
starter_prompts = [
    "green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws",
    "orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws",
    "blue, beige, brown, turtle, water type, shell, big eyes, short limbs, curled tail",
]
generation_kwargs = dict(
    negative_prompt=" ",
    num_inference_steps=30,
    cfg_scale=10,
    height=256,
    width=256,
)
for seed, prompt in enumerate(starter_prompts):
    image = pipe(prompt=prompt, seed=seed, **generation_kwargs)
    image.save(f"image_{seed}.jpg")
Image Image Image

Run inference to generate Pokemon with "sharp claws". Different random seeds produce different images.

# Same prompt three times; only the seed (4, 5, 6) differs between images.
for index in range(3):
    image = pipe(
        prompt="sharp claws",
        negative_prompt=" ",
        num_inference_steps=30,
        cfg_scale=10,
        seed=index + 4,
        height=256, width=256,
    )
    image.save(f"image_sharp_claws_{index}.jpg")
Image Image Image

We now have a small 0.1B-parameter text-to-image model. It can already generate the 151 Pokemon, but it cannot generate other image content. By increasing the amount of data, the number of model parameters, and the number of GPUs, you can train a more powerful text-to-image model!