Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions cosmos_framework/callbacks/dit_image_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: OpenMDW-1.1

"""Fixed class-conditioned image sampling callback for DiT training."""

from __future__ import annotations

from contextlib import nullcontext
from typing import Any

import torch
import torchvision
import wandb

from cosmos_framework.callbacks.every_n import EveryN
from cosmos_framework.model._base import ImaginaireModel
from cosmos_framework.trainer import ImaginaireTrainer
from cosmos_framework.utils import distributed


class DiTImageSampleCallback(EveryN):
"""Generate fixed ImageNet class samples through ``model.generate_image``."""

def __init__(
self,
every_n: int = 5000,
class_ids: list[int] | None = None,
cfg_scales: list[float] | None = None,
num_steps: int = 50,
seed: int = 0,
is_ema: bool = True,
run_at_start: bool = False,
) -> None:
super().__init__(every_n=every_n, run_at_start=run_at_start)
self.class_ids = class_ids or [0, 1, 2, 3]
self.cfg_scales = cfg_scales or [1.0, 1.25, 1.5, 2.0]
self.num_steps = num_steps
self.seed = seed
self.is_ema = is_ema
self.rank = distributed.get_rank()

@torch.no_grad()
def every_n_impl(
self,
trainer: ImaginaireTrainer,
model: ImaginaireModel,
data_batch: dict[str, torch.Tensor],
output_batch: dict[str, torch.Tensor],
loss: torch.Tensor,
iteration: int,
) -> None:
del trainer, data_batch, output_batch, loss

if not hasattr(model, "generate_image"):
raise AttributeError("DiTImageSampleCallback requires model.generate_image().")
if self.is_ema and not model.config.ema.enabled:
return

was_training = model.training
context: Any = model.ema_scope("dit_image_sample") if self.is_ema else nullcontext()
generated_rows: list[torch.Tensor] = []
seed_list = [self.seed + sample_idx for sample_idx in range(len(self.class_ids))]
try:
with context:
for cfg_scale in self.cfg_scales:
images = model.generate_image(
class_ids=self.class_ids,
num_steps=self.num_steps,
cfg_scale=cfg_scale,
seed=seed_list,
) # [B,3,H,W]
if self.rank == 0:
generated_rows.append(images.detach().float().cpu()) # [B,3,H,W]
finally:
if was_training:
model.train()

if self.rank != 0 or wandb.run is None or not generated_rows:
return

grid_images = torch.cat(generated_rows, dim=0) # [R*B,3,H,W]
grid = torchvision.utils.make_grid(grid_images, nrow=len(self.class_ids), padding=2, normalize=False) # [3,H,W]
grid_np = grid.clamp(0.0, 1.0).permute(1, 2, 0).numpy() # [H,W,3]
tag = "ema" if self.is_ema else "reg"
caption = f"classes={self.class_ids}, cfg={self.cfg_scales}, steps={self.num_steps}, seed={self.seed}, {tag}"
wandb.log(
{f"dit_image_sample/{tag}": wandb.Image(grid_np, caption=caption)},
step=iteration,
)
15 changes: 14 additions & 1 deletion cosmos_framework/callbacks/iter_speed.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,27 @@ def every_n_impl(
)

per_sample_batch_counter = dict()
if hasattr(model, "is_image_batch"):
# for VFM
if hasattr(model, "is_image_batch") and hasattr(model, "input_image_key") and hasattr(model, "input_video_key"):
is_image_batch = model.is_image_batch(data_batch)
if is_image_batch:
image_batch_size = len(data_batch[model.input_image_key])
per_sample_batch_counter["image_batch_size"] = image_batch_size
else:
video_batch_size = len(data_batch[model.input_video_key])
per_sample_batch_counter["video_batch_size"] = video_batch_size
# for LLM training only
elif "input_ids" in data_batch:
mbs = data_batch["input_ids"].shape[0]
dp_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
grad_accum_iter = int(trainer.config.trainer.grad_accum_iter)
per_sample_batch_counter["token_batch_size"] = mbs
per_sample_batch_counter["token_global_batch_size"] = mbs * dp_size * grad_accum_iter
# Cumulative token count (LLM analog of sample_counter). Set by
# ``LLMPretrainModel.training_step`` into a persistent buffer on
# ``model.net``, so this value survives checkpoint resume.
if hasattr(model, "token_counter"):
per_sample_batch_counter["token_counter"] = model.token_counter

if wandb.run:
sample_counter = getattr(trainer, "sample_counter", iteration)
Expand Down
2 changes: 1 addition & 1 deletion cosmos_framework/callbacks/norm_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from cosmos_framework.utils import distributed, log, misc
from cosmos_framework.utils.callback import Callback
from cosmos_framework.utils.easy_io import easy_io
from cosmos_framework.data.vfm.sequence_packing import get_gen_seq
from cosmos_framework.data.vfm.sequence_packing.runtime import get_gen_seq

try:
from apex.contrib.layer_norm import FastLayerNorm
Expand Down
11 changes: 6 additions & 5 deletions cosmos_framework/callbacks/sequence_packing_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import torch
import wandb

import cosmos_framework.data.vfm.sequence_packing as sequence_packing
from cosmos_framework.callbacks.every_n import EveryN
from cosmos_framework.model._base import ImaginaireModel
from cosmos_framework.trainer import ImaginaireTrainer
from cosmos_framework.data.vfm.sequence_packing.runtime import get_padding_stats


class SequencePackingPadding(EveryN):
Expand All @@ -32,11 +32,12 @@ def every_n_impl(
iteration: int,
) -> None:
if wandb.run:
padding_stats = get_padding_stats()
log_dict = {
"SequencePackingPadding/max_causal_len_image_batch": sequence_packing.MAX_CAUSAL_LEN_IMAGE_BATCH,
"SequencePackingPadding/max_full_len_image_batch": sequence_packing.MAX_FULL_LEN_IMAGE_BATCH,
"SequencePackingPadding/max_causal_len_video_batch": sequence_packing.MAX_CAUSAL_LEN_VIDEO_BATCH,
"SequencePackingPadding/max_full_len_video_batch": sequence_packing.MAX_FULL_LEN_VIDEO_BATCH,
"SequencePackingPadding/max_causal_len_image_batch": padding_stats["MAX_CAUSAL_LEN_IMAGE_BATCH"],
"SequencePackingPadding/max_full_len_image_batch": padding_stats["MAX_FULL_LEN_IMAGE_BATCH"],
"SequencePackingPadding/max_causal_len_video_batch": padding_stats["MAX_CAUSAL_LEN_VIDEO_BATCH"],
"SequencePackingPadding/max_full_len_video_batch": padding_stats["MAX_FULL_LEN_VIDEO_BATCH"],
}
modality = "video"
if "is_image_batch" in output_batch:
Expand Down
70 changes: 70 additions & 0 deletions cosmos_framework/configs/base/defaults/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from cosmos_framework.utils.lazy_config import LazyCall as L
from cosmos_framework.utils.callback import LowPrecisionCallback, WandBCallback
from cosmos_framework.callbacks.compile_tokenizer import CompileTokenizer

from cosmos_framework.callbacks.device_monitor import DeviceMonitor
from cosmos_framework.callbacks.dit_image_sample import DiTImageSampleCallback
from cosmos_framework.callbacks.every_n_draw_sample import EveryNDrawSample
from cosmos_framework.callbacks.expert_heatmap import ExpertHeatmap
from cosmos_framework.callbacks.grad_clip import GradClip
Expand Down Expand Up @@ -72,6 +74,41 @@
ofu=L(OFUCallback)(every_n="${trainer.logging_iter}"),
)

# LLM-only subset of BASIC_CALLBACKS.
# Drops VFM-specific callbacks:
# - CompileTokenizer: requires model.tokenizer_vision_gen (VAE)
# - ExpertHeatmap: requires MoE language_model with mlp_moe_gen
# - SigmaLossAnalysis: rectified-flow specific
# - SequencePackingPadding: VFM multi-modal packing specific
# - NormMonitor: param filter assumes "moe_gen" params → logs nothing for dense LLM
# Drops Necessary but not supported Callbacks:
# - MFU: @TODO
BASIC_LLM_CALLBACKS = dict(
iter_speed=L(IterSpeed)(
every_n="${trainer.logging_iter}",
save_s3="${upload_reproducible_setup}",
save_s3_every_log_n=500,
hit_thres=50,
),
manual_gc=L(ManualGarbageCollection)(every_n=5),
wandb=L(WandBCallback)(),
wandb_2x=L(WandBCallbackMultiplier)(
logging_iter_multipler=2,
save_logging_iter_multipler=1,
save_s3="${upload_reproducible_setup}",
),
param_count=L(ParamCount)(
save_s3="${upload_reproducible_setup}",
),
wandb_val=L(WandBCallbackEval)(
save_s3="${upload_reproducible_setup}",
),
ofu=L(OFUCallback)(every_n="${trainer.logging_iter}"),
)

# DiT-safe subset for LLM-backed rectified-flow image training.
BASIC_DIT_CALLBACKS = dict(BASIC_LLM_CALLBACKS)

JOB_MONITOR_CALLBACKS = dict(
heart_beat=L(HeartBeat)(
every_n=200,
Expand All @@ -94,6 +131,19 @@
low_precision=L(LowPrecisionCallback)(update_iter=1, config=PLACEHOLDER, trainer=PLACEHOLDER), # use model
)

OPTIMIZATION_LLM_CALLBACKS = dict(
skip_nan_step=L(SkipNaNStep)(max_consecutive_nan=100),
grad_clip=L(GradClip)(clip_norm=1.0, track_per_modality=False),
low_precision=L(LowPrecisionCallback)(update_iter=1, config=PLACEHOLDER, trainer=PLACEHOLDER),
)

# DiT reuses the same GradClip callback as LLM, without VFM image/video grad-norm split.
OPTIMIZATION_DIT_CALLBACKS = dict(
skip_nan_step=L(SkipNaNStep)(max_consecutive_nan=100),
grad_clip=L(GradClip)(clip_norm=1.0, track_per_modality=False),
low_precision=L(LowPrecisionCallback)(update_iter=1, config=PLACEHOLDER, trainer=PLACEHOLDER),
)

VIZ_ONLINE_SAMPLING_CALLBACKS = dict(
every_n_sample_reg=L(EveryNDrawSample)(
every_n=5000,
Expand All @@ -108,22 +158,42 @@
),
)

DIT_IMAGE_SAMPLING_CALLBACKS = dict(
dit_image_sample_ema=L(DiTImageSampleCallback)(
every_n=5000,
class_ids=[0, 1, 2, 3],
cfg_scales=[1.0, 1.25, 1.5, 2.0],
num_steps=50,
seed=0,
is_ema=True,
),
)


def register_callbacks():
cs = ConfigStore.instance()
cs.store(group="callbacks", package="trainer.callbacks", name="basic", node=BASIC_CALLBACKS)
cs.store(group="callbacks", package="trainer.callbacks", name="job_monitor", node=JOB_MONITOR_CALLBACKS)
cs.store(group="callbacks", package="trainer.callbacks", name="optimization", node=OPTIMIZATION_CALLBACKS)
cs.store(group="callbacks", package="trainer.callbacks", name="optimization_llm", node=OPTIMIZATION_LLM_CALLBACKS)
cs.store(group="callbacks", package="trainer.callbacks", name="optimization_dit", node=OPTIMIZATION_DIT_CALLBACKS)
# Online sampling generation callback
cs.store(
group="callbacks", package="trainer.callbacks", name="viz_online_sampling", node=VIZ_ONLINE_SAMPLING_CALLBACKS
)
# Register "generation" as alias for "viz_online_sampling" (expected by base config.py defaults)
cs.store(group="callbacks", package="trainer.callbacks", name="generation", node=VIZ_ONLINE_SAMPLING_CALLBACKS)
cs.store(
group="callbacks", package="trainer.callbacks", name="dit_image_sampling", node=DIT_IMAGE_SAMPLING_CALLBACKS
)

TRAINING_STATS_CALLBACKS = dict(
training_stats=L(TrainingStatsCallback)(
log_freq=100,
)
)
cs.store(group="callbacks", package="trainer.callbacks", name="training_stats", node=TRAINING_STATS_CALLBACKS)

# Only for LLM training, removed callbacks that is not working for llm training
cs.store(group="callbacks", package="trainer.callbacks", name="basic_llm", node=BASIC_LLM_CALLBACKS)
cs.store(group="callbacks", package="trainer.callbacks", name="basic_dit", node=BASIC_DIT_CALLBACKS)
64 changes: 64 additions & 0 deletions cosmos_framework/configs/base/defaults/llm_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: OpenMDW-1.1

"""Register LLM and DiT model configs alongside the existing ``mot_fsdp``."""

from hydra.core.config_store import ConfigStore

from cosmos_framework.utils.lazy_config import LazyCall as L
from cosmos_framework.model.vfm.llm.dit.image_dit_model import DiTPretrainModel, DiTPretrainModelConfig
from cosmos_framework.model.vfm.llm.llm_pretrain_model import LLMPretrainModel, LLMPretrainModelConfig

# ── FSDP config (production, multi-GPU) ──────────────────────────────────────

LLM_FSDP_CONFIG = dict(
trainer=dict(
distributed_parallelism="fsdp",
),
model=L(LLMPretrainModel)(
config=LLMPretrainModelConfig(),
_recursive_=False,
),
)

# ── DDP config (debug, single-node) ─────────────────────────────────────────

LLM_DDP_CONFIG = dict(
trainer=dict(
distributed_parallelism="ddp",
),
model=L(LLMPretrainModel)(
config=LLMPretrainModelConfig(),
_recursive_=False,
),
)

# ── Image DiT configs ───────────────────────────────────────────────────────

DIT_FSDP_CONFIG = dict(
trainer=dict(
distributed_parallelism="fsdp",
),
model=L(DiTPretrainModel)(
config=DiTPretrainModelConfig(),
_recursive_=False,
),
)

DIT_DDP_CONFIG = dict(
trainer=dict(
distributed_parallelism="ddp",
),
model=L(DiTPretrainModel)(
config=DiTPretrainModelConfig(),
_recursive_=False,
),
)


def register_llm_model():
cs = ConfigStore.instance()
cs.store(group="model", package="_global_", name="llm_fsdp", node=LLM_FSDP_CONFIG)
cs.store(group="model", package="_global_", name="llm_ddp", node=LLM_DDP_CONFIG)
cs.store(group="model", package="_global_", name="dit_fsdp", node=DIT_FSDP_CONFIG)
cs.store(group="model", package="_global_", name="dit_ddp", node=DIT_DDP_CONFIG)
19 changes: 4 additions & 15 deletions cosmos_framework/configs/base/defaults/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ class RectifiedFlowTrainingConfig:
loss_scale: float = 1.0 # Loss scale
image_loss_scale: float | None = None # If set, overrides loss_scale for images
sound_loss_scale: float | None = None # If set, overrides loss_scale for sound
use_high_sigma_strategy: bool = False # Whether to use high sigma strategy
high_sigma_ratio: float = 0.05 # Ratio of using high sigmas
high_sigma_timesteps_min: int = 995 # Minimum timestep for high sigma
high_sigma_timesteps_max: int = 1000 # Maximum timestep for high sigma
use_discrete_rf: bool = False # Whether to use discrete formulation of rectified flow

# user: please adjust this value according to loss_scale to balance the action loss with the video loss.
Expand All @@ -93,21 +89,16 @@ class RectifiedFlowTrainingConfig:

# Independent noise schedule for action. When False (default), action shares the sigma
# sampled from the vision RF on every step — legacy behavior. When True, action samples
# its own sigma from `rectified_flow_action` using `shift_action` and
# `use_high_sigma_strategy_action`. Action always uses a shared scalar sigma per sample
# ([B,1]), independent of vision's DF mode. If action opts in to the high-sigma strategy,
# it reuses the global ratio / min / max.
# its own sigma from `rectified_flow_action` using `shift_action`. Action always uses a
# shared scalar sigma per sample ([B,1]), independent of vision's DF mode.
independent_action_schedule: bool = False
shift_action: int | None = None # must be int; None → inherit `shift` (which must also be int)
use_high_sigma_strategy_action: bool = False

# Independent noise schedule for sound. When False (default), sound shares the vision
# sigma schedule, reindexed to the dense audio-bearing subset. When True, sound samples
# its own scalar sigma per sample ([B,1]) from `rectified_flow_sound` using `shift_sound`
# and `use_high_sigma_strategy_sound`.
# its own scalar sigma per sample ([B,1]) from `rectified_flow_sound` using `shift_sound`.
independent_sound_schedule: bool = False
shift_sound: int | None = None # must be int; None → inherit `shift` (which must also be int)
use_high_sigma_strategy_sound: bool = False

# When True, per-instance flow-matching loss is normalized by the count of
# active (noisy) elements rather than all elements — preserves sum/active_count
Expand Down Expand Up @@ -204,9 +195,7 @@ class OmniMoTModelConfig:
# Attention implementation for joint understanding + generation
# Note "two_way" and "three_way" disallow and remove "End-of-Vision" or other text token in the generation tower.
# "three_way" must only be used when introducing sparsity
joint_attn_implementation: str = (
"two_way" # "two_way", "three_way" or "flex" (NOTICE: We are planning to remove "flex" soon)
)
joint_attn_implementation: str = "two_way" # "two_way" or "three_way"

# Per-layer NATTEN parameters
# Must use "three_way" attention if used.
Expand Down
Loading