diff --git a/cosmos_framework/configs/base/config.py b/cosmos_framework/configs/base/config.py index e766c5c..5ac2b41 100644 --- a/cosmos_framework/configs/base/config.py +++ b/cosmos_framework/configs/base/config.py @@ -97,4 +97,5 @@ def make_config() -> Config: import cosmos_framework.configs.base.experiment.sft.vision_sft_nano # noqa: F401 import cosmos_framework.configs.base.experiment.sft.vision_sft_super # noqa: F401 import cosmos_framework.configs.base.experiment.action.posttrain_config.action_policy_droid_nano # noqa: F401 + import cosmos_framework.configs.base.experiment.action.posttrain_config.action_policy_libero_nano # noqa: F401 return c diff --git a/cosmos_framework/configs/base/experiment/action/posttrain_config/action_policy_libero_nano.py b/cosmos_framework/configs/base/experiment/action/posttrain_config/action_policy_libero_nano.py new file mode 100644 index 0000000..b05d3ea --- /dev/null +++ b/cosmos_framework/configs/base/experiment/action/posttrain_config/action_policy_libero_nano.py @@ -0,0 +1,233 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""``action_policy_libero_nano`` — Cosmos3-Nano LIBERO-10 action-policy SFT recipe. + +Feeds ``LIBEROLeRobotDataset`` (frame-wise-relative rot6d, ``quantile_rot``, +concat_view third-person + wrist) and trains the generation + action heads from +the public ``nvidia/Cosmos3-Nano`` base. Train on ``libero_10`` alone +(``LIBERO_ROOT``). See docs/action_policy_libero_sft.md. 
# Hydra ConfigStore singleton; the experiment node defined below is registered on it.
cs = ConfigStore.instance()


def _action_policy_libero_nano_model_config() -> dict:
    """Return the model config for the LIBERO recipe.

    Starts from a deep copy of ``NANO_MODEL_CONFIG`` (so the shared base dict is
    never mutated) and overrides: capped packed tokens, selective activation
    checkpointing, fresh diffusion-expert init, 10x vision flow-matching loss.
    Keep ``encode_exact_durations=[17, 61, 73]`` to match the Cosmos3-Nano base.

    Returns:
        The overridden copy of ``NANO_MODEL_CONFIG``.
    """
    cfg = copy.deepcopy(NANO_MODEL_CONFIG)  # action_gen=True, max_action_dim=64
    # Cap the packed sequence. Uncapped (-1) + a large max_samples_per_batch packs
    # one very long sequence and OOMs even on H200; 74000 keeps the GA-validated bound.
    cfg["max_num_tokens_after_packing"] = 74000
    cfg["activation_checkpointing"]["mode"] = "selective"
    cfg["diffusion_expert_config"]["load_weights_from_pretrained"] = False
    cfg["rectified_flow_training_config"]["loss_scale"] = 10.0
    cfg["rectified_flow_training_config"]["image_loss_scale"] = None
    cfg["tokenizer"]["encode_exact_durations"] = [17, 61, 73]  # match Cosmos3 base + reference SFT (do NOT reduce)
    return cfg


# Experiment node. The scheduler / trainer values marked "smoke" are short
# defaults; per the inline notes the real run overrides them via TOML.
action_policy_libero_nano = LazyDict(
    dict(
        defaults=[
            {"override /model": "mot_fsdp"},
            {"override /data_train": None},
            {"override /data_val": None},
            # FusedAdam with fp32 master_weights + eps 1e-8 (bf16 params + eps 1e-6
            # diverged on the action loss).
            {"override /optimizer": "fusedadamw"},
            {"override /scheduler": "lambdalinear"},  # linear LR decay
            {"override /checkpoint": "s3"},
            {
                "override /callbacks": [
                    "basic",
                    "optimization",
                    "job_monitor",
                ]
            },
            {"override /ema": "power"},
            {"override /tokenizer": "wan2pt2_tokenizer"},
            {"override /sound_tokenizer": None},
            {"override /vlm_config": None},
            {"override /ckpt_type": "dcp"},
            "_self_",
        ],
        job=dict(
            project="cosmos3",
            group="action_sft",
            name="action_policy_libero_nano",
            wandb_mode="disabled",
        ),
        model=dict(
            config=_action_policy_libero_nano_model_config(),
        ),
        optimizer=dict(
            betas=[0.9, 0.99],
            eps=1.0e-08,
            fused=True,  # popped by build_optimizer for FusedAdam (fused by construction)
            # Train the generation + action heads.
            keys_to_select=[
                "moe_gen",
                "time_embedder",
                "vae2llm",
                "llm2vae",
                "action2llm",
                "llm2action",
                "action_modality_embed",
            ],
            lr=5.0e-05,
            # Per-key LR multipliers: the three action-head keys get 5x the base lr.
            lr_multipliers={
                "action2llm": 5.0,
                "llm2action": 5.0,
                "action_modality_embed": 5.0,
            },
            optimizer_type="FusedAdam",
            weight_decay=0.05,
        ),
        scheduler=dict(
            lr_scheduler_type="LambdaLinear",
            cycle_lengths=[100],  # smoke: 100 iters (real run sets via TOML, GA=10000)
            f_max=[1.0],
            f_min=[0.0],
            f_start=[1.0e-06],
            verbosity_interval=0,
            warm_up_steps=[0],  # smoke (real run sets via TOML, GA=2000)
        ),
        trainer=dict(
            distributed_parallelism="fsdp",
            grad_accum_iter=1,  # real run sets via TOML (GA=2)
            logging_iter=1,
            max_iter=100,  # smoke
            max_val_iter=None,
            run_validation=False,
            run_validation_on_start=False,
            save_zero_checkpoint=False,
            seed=42,
            timeout_period=999999999,
            validation_iter=100,
            compile_config=dict(recompile_limit=8, use_duck_shape=False),
            cudnn=dict(benchmark=True, deterministic=False),
            ddp=dict(broadcast_buffers=True, find_unused_parameters=False, static_graph=True),
            grad_scaler_args=dict(enabled=False),
            callbacks=dict(
                dataloader_speed=dict(every_n=100, save_s3=False, step_size=1),
                device_monitor=dict(
                    every_n=200, log_memory_detail=True, save_s3=False, step_size=1, upload_every_n_mul=5
                ),
                grad_clip=dict(clip_norm=1.0, force_finite=True),
                heart_beat=dict(every_n=200, save_s3=False, step_size=1, update_interval_in_minute=20),
                iter_speed=dict(every_n=1, hit_thres=50, save_s3=False, save_s3_every_log_n=500),
                low_precision=dict(update_iter=1),
                manual_gc=dict(every_n=5, gc_level=1, warm_up=1),
                param_count=dict(save_s3=False),
                skip_nan_step=dict(max_consecutive_nan=100),
                training_stats=dict(log_freq=100),
            ),
        ),
        checkpoint=dict(
            broadcast_via_filesystem=False,
            dcp_async_mode_enabled=False,
            enable_gcs_patch_in_boto3=True,
            keys_not_to_resume=[],
            # Skip net_ema (EMA warm-starts from net, see dcp.py) and the action
            # heads, so they init fresh from the base (the public Cosmos3-Nano base
            # has no LIBERO-trained action heads).
            keys_to_skip_loading=[
                "net_ema.",
                "action2llm",
                "llm2action",
                "action_modality_embed",
                "action_pos_embed",
            ],
            load_ema_to_reg=False,
            load_path="???",  # Cosmos3-Nano DCP dir; supply via TOML/env
            load_training_state=False,
            only_load_scheduler_state=False,
            save_iter=100,
            strict_resume=False,  # base init: tolerate key set differences
            verbose=True,
            hf_export=dict(
                enabled=False,
                export_every_n=1,
                hf_repo_id=None,
                upload_to_object_store=dict(bucket="", credentials="", enabled=False),
            ),
            jit=dict(device="cuda", dtype="bfloat16", enabled=False, input_shape=None, strict=True),
            load_from_object_store=dict(bucket="", credentials="", enabled=False),
            save_to_object_store=dict(bucket="", credentials="", enabled=False),
        ),
        dataloader_train=L(PackingDataLoader)(
            audio_sample_rate=48000,
            dataset_name="action_libero",
            max_samples_per_batch=128,  # peak-mem bound (256 OOMs on H200); global = 128 x DP8 x grad_accum2 = 2048
            max_sequence_length=None,  # None disables token packing (TOML can't express null)
            patch_spatial=2,
            sound_latent_fps=0,
            tokenizer_spatial_compression_factor=16,
            tokenizer_temporal_compression_factor=4,
            dataloader=L(RankPartitionedDataLoader)(
                batch_size=1,
                in_order=False,
                num_workers=4,
                persistent_workers=True,
                pin_memory=True,
                prefetch_factor=4,
                sampler=None,
                # Shuffling is handled by the dataset (iterable_shuffle=True below):
                # ActionIterableShuffleDataset streams rank x worker-sharded, episode-order-
                # shuffled, sequential-within-episode.
                datasets=dict(
                    libero=dict(
                        ratio=1,
                        dataset=L(get_action_libero_sft_dataset)(
                            # Local LeRobot dir for the libero_10 suite ONLY. Use the
                            # 20 FPS nvidia/LIBERO_LeRobot_v3 (matches the bundled stats + 20 Hz eval):
                            #   hf download nvidia/LIBERO_LeRobot_v3 --repo-type dataset \
                            #       --include 'libero_10/**' --local-dir <local_dir>
                            #   LIBERO_ROOT=<local_dir>/libero_10
                            # (``<local_dir>`` is a placeholder for the download destination.)
                            root="${oc.env:LIBERO_ROOT}",
                            fps=20,  # metadata only (FPS-agnostic loader reads native fps from info.json)
                            chunk_length=16,
                            image_size=256,  # concat_view -> 256x512
                            mode="policy",
                            camera_mode="concat_view",
                            action_space="frame_wise_relative",
                            rotation_space="6d",
                            pose_coordinate_frame="native",
                            action_normalization="quantile_rot",
                            val_ratio=0.01,
                            iterable_shuffle=True,
                            episode_shuffle_seed=42,
                            resolution=None,
                            max_action_dim="${model.config.max_action_dim}",
                            cfg_dropout_rate=0.1,
                            tokenizer_config="${model.config.vlm_config.tokenizer}",
                        ),
                    ),
                ),
            ),
        ),
        dataloader_val=None,
        upload_reproducible_setup=False,
    ),
    flags={"allow_objects": True},
)


# Register each experiment under the module-level variable name it is bound to.
# The name is recovered by an identity scan of ``globals()``; the first match is
# used, so the config variable must be defined before this loop runs.
for _item in [action_policy_libero_nano]:
    _name = [k for k, v in globals().items() if v is _item][0]
    cs.store(group="experiment", package="_global_", name=_name, node=_item)
def get_action_libero_sft_dataset(
    *,
    root: str,
    fps: float = 20.0,
    chunk_length: int = 16,
    image_size: int = 256,
    mode: str = "policy",
    camera_mode: str = "concat_view",
    action_space: str = "frame_wise_relative",
    rotation_space: str = "6d",
    pose_coordinate_frame: str = "native",
    action_normalization: str | None = "quantile_rot",
    action_stats_path: str | None = None,
    split: str = "train",
    val_ratio: float = 0.01,
    seed: int = 0,
    resolution: str | int | None = None,
    max_action_dim: int = 64,
    tokenizer_config: dict | None = None,
    cfg_dropout_rate: float = 0.1,
    append_viewpoint_info: bool = True,
    append_duration_fps_timestamps: bool = True,
    append_resolution_info: bool = True,
    append_idle_frames: bool = True,
    iterable_shuffle: bool = False,
    episode_shuffle_seed: int = 42,
) -> Dataset:
    """Assemble the LIBERO action-policy SFT dataset.

    Three layers are stacked, in this order:

    1. ``LIBEROLeRobotDataset`` built from the data-related arguments
       (``root`` ... ``action_stats_path``).
    2. ``ActionSFTDataset``, which pairs that dataset with an
       ``ActionTransformPipeline`` (built from ``tokenizer_config``,
       ``cfg_dropout_rate``, ``max_action_dim`` and the ``append_*`` flags)
       and ``resolution``.
    3. Optionally ``ActionIterableShuffleDataset`` seeded with
       ``episode_shuffle_seed`` when ``iterable_shuffle`` is true.

    ``root`` is a LOCAL LeRobot directory; download the dataset once
    beforehand, e.g.
    ``hf download lerobot/libero_10 --repo-type dataset --local-dir <local_dir>``,
    and point ``root`` at the libero_10 suite alone.

    All arguments are keyword-only and are forwarded unchanged.

    Returns:
        The ``ActionSFTDataset`` itself, or its
        ``ActionIterableShuffleDataset`` wrapper when ``iterable_shuffle``
        is set.
    """
    # Arguments describing which frames/actions are read and how they are encoded.
    source_kwargs = {
        "root": root,
        "image_size": image_size,
        "chunk_length": chunk_length,
        "fps": fps,
        "mode": mode,
        "split": split,
        "val_ratio": val_ratio,
        "seed": seed,
        "camera_mode": camera_mode,
        "action_space": action_space,
        "rotation_space": rotation_space,
        "pose_coordinate_frame": pose_coordinate_frame,
        "action_normalization": action_normalization,
        "action_stats_path": action_stats_path,
    }
    # Arguments consumed by the sample transform pipeline.
    pipeline_kwargs = {
        "tokenizer_config": tokenizer_config,
        "cfg_dropout_rate": cfg_dropout_rate,
        "max_action_dim": max_action_dim,
        "append_viewpoint_info": append_viewpoint_info,
        "append_duration_fps_timestamps": append_duration_fps_timestamps,
        "append_resolution_info": append_resolution_info,
        "append_idle_frames": append_idle_frames,
    }

    libero = LIBEROLeRobotDataset(**source_kwargs)
    pipeline = ActionTransformPipeline(**pipeline_kwargs)
    sft_dataset = ActionSFTDataset(libero, pipeline, resolution)

    if not iterable_shuffle:
        return sft_dataset
    return ActionIterableShuffleDataset(sft_dataset, seed=episode_shuffle_seed)
b/cosmos_framework/data/vfm/action/datasets/base_dataset.py index 2c9c4cb..f9b0f61 100644 --- a/cosmos_framework/data/vfm/action/datasets/base_dataset.py +++ b/cosmos_framework/data/vfm/action/datasets/base_dataset.py @@ -69,18 +69,25 @@ def __init__( for path in sorted((self._root / "meta" / "episodes").glob("chunk-*/file-*.parquet")) for row in pq.read_table(path).to_pylist() } - self._tasks = { - int(row["task_index"]): str(row["task"]) - for row in pq.read_table(self._root / "meta" / "tasks.parquet").to_pylist() - } - self._rows = sorted( - ( - row - for path in sorted((self._root / "data").glob("chunk-*/file-*.parquet")) - for row in pq.read_table(path).to_pylist() - ), - key=lambda row: int(row["index"]), - ) + # ``meta/tasks.parquet`` normally has a ``task`` column. Some LeRobot + # conversions (e.g. the community LIBERO datasets) instead store the task + # string as the (unnamed) pandas index, which pyarrow surfaces as + # ``__index_level_0__``. Fall back to the lone non-``task_index`` field so + # both layouts work (datasets that have ``task`` are unaffected). + self._tasks = {} + for row in pq.read_table(self._root / "meta" / "tasks.parquet").to_pylist(): + if "task" in row: + task = row["task"] + else: + extras = [v for k, v in row.items() if k != "task_index"] + task = extras[0] if extras else "" + self._tasks[int(row["task_index"])] = str(task) + # ``self._rows`` (the flat, index-sorted list of every frame dict) is built + # lazily on first access — see the ``_rows`` property. Materializing all + # ~18M frames as Python dicts plus a full sort costs ~13 min and tens of GB; + # subclasses that build their own compact index (e.g. DROIDLeRobotDataset) + # never touch it, so they must not pay for it at construction. 
+ self._rows_cache: list[dict[str, Any]] | None = None @property def fps(self) -> float: @@ -213,5 +220,25 @@ def _build_result( **extras, } + @property + def _rows(self) -> list[dict[str, Any]]: + """Flat, index-sorted list of every frame dict, built lazily on first access. + + Only datasets that don't build their own compact index (bridge / agibot / + robomind) touch this; for them it materializes once and caches. Datasets with + a bespoke index (e.g. DROIDLeRobotDataset) never read it, so they skip the + ~13 min / tens-of-GB construction entirely. + """ + if self._rows_cache is None: + self._rows_cache = sorted( + ( + row + for path in sorted((self._root / "data").glob("chunk-*/file-*.parquet")) + for row in pq.read_table(path).to_pylist() + ), + key=lambda row: int(row["index"]), + ) + return self._rows_cache + def __len__(self) -> int: return max(0, (len(self._rows) - self._chunk_length + self._sample_stride - 1) // self._sample_stride) diff --git a/cosmos_framework/data/vfm/action/datasets/libero_lerobot_dataset.py b/cosmos_framework/data/vfm/action/datasets/libero_lerobot_dataset.py new file mode 100644 index 0000000..1e5ef01 --- /dev/null +++ b/cosmos_framework/data/vfm/action/datasets/libero_lerobot_dataset.py @@ -0,0 +1,333 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""LIBERO LeRobot dataset (frame-wise-relative action policy). + +Mirrors ``DROIDLeRobotDataset``: reads the LeRobot parquet directly, windows by +frame index, and decodes video at each frame's REAL timestamp. That makes it +FPS-agnostic — it works with the 10 FPS community ``lerobot/libero_*`` datasets +and a 20 FPS conversion alike, without LeRobot's ``delta_timestamps`` grid (which +rejects any window whose synthetic timestamps don't land on real frames). 
CameraMode = Literal["image", "wrist_image", "concat_view"]
RotationSpace = Literal["3d", "6d", "9d"]

# LeRobot feature names read from the data parquet / video files.
_ACTION_FEATURE = "action"
_IMAGE_FEATURE = "observation.images.image"
_WRIST_FEATURE = "observation.images.wrist_image"
# Statistic names kept from the stats JSON; any other key in the section is ignored.
_STAT_KEYS = ("mean", "std", "min", "max", "q01", "q99")
_NORMALIZERS_DIR = Path(__file__).parent / "stats"

# camera_mode -> viewpoint label handed to the base class.
_VIEWPOINT_BY_CAMERA = {
    "image": "third_person_view",
    "wrist_image": "wrist_view",
    "concat_view": "concat_view",
}


class LIBEROLeRobotDataset(ActionBaseDataset):
    """LIBERO action-policy dataset with frame-wise-relative rot6d actions.

    10D ``[pos_delta(3), rot6d_delta(6), gripper(1)]`` (for ``rotation_space='6d'``),
    ``concat_view`` third-person + wrist video, and ``quantile_rot`` normalization
    against the bundled stats. Reads parquet + decodes video at real timestamps,
    so the requested ``fps`` is metadata only (it sets ``conditioning_fps`` and the
    prompt duration); frame windows always use the data's actual frames.

    A sample index is a flat position over all within-episode windows of the
    kept episodes. Each window starts at one frame and spans
    ``chunk_length + 1`` frames of video and ``chunk_length`` actions.
    """

    # Upper bound on load attempts in ``__getitem__`` before giving up.
    _MAX_LOAD_ATTEMPTS = 8

    def __init__(
        self,
        root: str,
        fps: float = 20.0,
        chunk_length: int = 16,
        mode: str = "policy",
        tolerance_s: float = 1e-4,
        camera_mode: CameraMode = "concat_view",
        image_size: int = 256,
        action_space: str = "frame_wise_relative",
        rotation_space: RotationSpace = "6d",
        pose_coordinate_frame: str = "native",
        embodiment_type: str = "libero",
        action_normalization: str | None = "quantile_rot",
        action_stats_path: str | None = None,
        split: str = "train",
        val_ratio: float = 0.01,
        seed: int = 0,
        sample_stride: int = 1,
    ) -> None:
        """Validate the options, build the frame index and select the split.

        Args:
            root: Local LeRobot dataset directory (``data/chunk-*/file-*.parquet``
                is read from it here).
            fps: Requested fps; replaced by the dataset's own ``fps`` metadata
                when that is present.
            chunk_length: Number of actions per sample; must be divisible by 4.
            mode: Forwarded to the base class.
            tolerance_s: Forwarded to the base class; later passed to the video
                decoder as the timestamp tolerance.
            camera_mode: ``image``, ``wrist_image`` or ``concat_view``.
            image_size: Square side length every view is resized to.
            action_space: Only ``"frame_wise_relative"`` is supported.
            rotation_space: ``3d``, ``6d`` or ``9d`` rotation encoding.
            pose_coordinate_frame: Used to pick the bundled stats filename.
            embodiment_type: Domain name for the base class and stats filename prefix.
            action_normalization: ``None`` disables normalization; any other
                value makes the base class use ``"quantile"``. ``"quantile_rot"``
                additionally selects the ``global_raw`` stats section.
            action_stats_path: Optional stats file overriding the bundled one.
            split: ``train``, ``full`` or one of the validation aliases
                (``val``/``valid``/``validation``/``eval``/``test``).
            val_ratio: Fraction of episodes held out for validation.
            seed: Seed of the episode-level split.
            sample_stride: Forwarded to the base class.

        Raises:
            NotImplementedError: For an unsupported ``action_space``.
            ValueError: For an unsupported ``camera_mode`` / ``split`` /
                ``chunk_length`` / ``rotation_space`` / ``val_ratio``, or when
                the parquet rows are not grouped by episode once ordered by
                frame index.
            FileNotFoundError: When no data parquet or no stats file is found.
        """
        if action_space != "frame_wise_relative":
            raise NotImplementedError(
                f"This LIBERO dataset only supports action_space='frame_wise_relative', got {action_space!r}."
            )
        if camera_mode not in _VIEWPOINT_BY_CAMERA:
            raise ValueError(f"Unsupported camera_mode={camera_mode!r}. Use image/wrist_image/concat_view.")
        split = split.lower().strip()
        if split not in {"train", "val", "valid", "validation", "eval", "test", "full"}:
            raise ValueError(f"Unsupported split={split!r}. Use train/val/full.")
        if chunk_length % 4 != 0:
            raise ValueError(f"chunk_length must be divisible by 4, got {chunk_length}.")

        super().__init__(
            root=root,
            domain_name=embodiment_type,
            fps=fps,
            chunk_length=chunk_length,
            mode=mode,
            pose_convention="backward_framewise",  # unused for frame_wise deltas; satisfies the base assert
            tolerance_s=tolerance_s,
            viewpoint=_VIEWPOINT_BY_CAMERA[camera_mode],
            # frame_wise_relative ⇔ backward_framewise idle semantics. quantile_rot is a
            # LIBERO convention -> normalize with the "quantile" formula on raw-rotation
            # stats (see _load_norm_stats); pass the method the base will call.
            action_normalization=None if action_normalization is None else "quantile",
            sample_stride=sample_stride,
        )
        # FPS-agnostic loader: trust the dataset's NATIVE fps for conditioning_fps /
        # prompt duration so the metadata is truthful (10 for the public
        # lerobot/libero_*, 20 for a 20 FPS conversion). Frame sampling uses each
        # frame's real timestamp regardless, so the requested ``fps`` is ignored here.
        info_fps = self._info.get("fps")
        if info_fps:
            if int(info_fps) != int(fps):
                log.info(f"Using dataset native fps={info_fps} for conditioning (requested {fps}).")
            self._fps = float(info_fps)
            self._dt = 1.0 / self._fps
        self._camera_mode = camera_mode
        self._image_size = int(image_size)
        self._rotation_space = rotation_space.lower().strip()
        self._pose_coordinate_frame = pose_coordinate_frame
        self._embodiment_type = embodiment_type
        self._requested_normalization = action_normalization
        # quantile_rot normalizes against the raw (un-orthonormalized) rotation stats
        # under "global_raw"; everything else uses "global".
        self._stats_key = "global_raw" if action_normalization == "quantile_rot" else "global"
        self._stats_file = self._resolve_stats_file(action_stats_path)

        if self._camera_mode == "image":
            self._video_keys = [_IMAGE_FEATURE]
        elif self._camera_mode == "wrist_image":
            self._video_keys = [_WRIST_FEATURE]
        else:
            self._video_keys = [_IMAGE_FEATURE, _WRIST_FEATURE]

        # Compact, lazy frame index (mirrors DROIDLeRobotDataset): read only the
        # columns the sample builder needs into contiguous arrays, ordered by global
        # frame index, so DataLoader worker forks share them copy-on-write.
        index_parts, episode_parts, task_parts, ts_parts, action_parts = [], [], [], [], []
        for path in sorted((self._root / "data").glob("chunk-*/file-*.parquet")):
            table = pq.read_table(path, columns=["index", "episode_index", "task_index", "timestamp", _ACTION_FEATURE])
            index_parts.append(table["index"].to_numpy())
            episode_parts.append(table["episode_index"].to_numpy())
            task_parts.append(table["task_index"].to_numpy())
            ts_parts.append(table["timestamp"].to_numpy())
            action_parts.append(np.asarray(table[_ACTION_FEATURE].to_pylist(), dtype=np.float32))
        if not index_parts:
            raise FileNotFoundError(f"No data parquet found under {self._root / 'data'}.")
        order = np.argsort(np.concatenate(index_parts).astype(np.int64), kind="stable")
        self._row_episode = np.concatenate(episode_parts).astype(np.int64)[order]
        self._row_task = np.concatenate(task_parts).astype(np.int64)[order]
        self._row_timestamp = np.concatenate(ts_parts).astype(np.float64)[order]
        self._row_action = np.concatenate(action_parts, axis=0).astype(np.float32)[order]

        # The per-episode (start, count) bookkeeping below is only valid when every
        # episode occupies one contiguous run of rows. Raise explicitly rather than
        # ``assert`` so the check also runs under ``python -O``.
        if np.any(np.diff(self._row_episode) < 0):
            raise ValueError("episode_index not contiguous after sorting by frame index")
        ep_vals, ep_starts, ep_counts = np.unique(self._row_episode, return_index=True, return_counts=True)

        # Deterministic per-episode train/val split (seeded; same on every rank).
        keep = self._split_episode_ids(ep_vals.tolist(), split, val_ratio, seed)
        kept = np.array([int(v) in keep for v in ep_vals], dtype=bool)
        self._ep_vals = ep_vals.astype(np.int64)[kept]
        self._ep_starts = ep_starts.astype(np.int64)[kept]
        kept_counts = ep_counts.astype(np.int64)[kept]
        # Within-episode windows only: total - n_kept_episodes * chunk_length valid samples.
        self._valid_cum = np.cumsum(np.maximum(0, kept_counts - self._chunk_length)).astype(np.int64)

        log.info(
            f"Loaded LIBERO dataset root={self._root} split={split!r} camera_mode={camera_mode!r} "
            f"fps={self._fps} kept_episodes={len(self._ep_vals)}/{len(ep_vals)} "
            f"valid_indices={int(self._valid_cum[-1]) if self._valid_cum.size else 0}"
        )

    # ---- spec / dims -------------------------------------------------------

    @property
    def action_dim(self) -> int:
        """Action width for the configured ``rotation_space``."""
        return libero_action_dim(self._rotation_space)

    def _action_spec(self) -> ActionSpec:
        """Return the ``[position, rotation, gripper]`` action spec."""
        return build_action_spec(Pos(), Rot(libero_rotation_format(self._rotation_space)), Gripper())

    @classmethod
    def _stats_path(cls) -> Path:
        """Return the default bundled stats path (class-level fallback)."""
        # Base classmethod fallback; the instance uses self._stats_file (which also
        # honors action_stats_path + the rotation/coordinate-frame-specific filename).
        return _NORMALIZERS_DIR / "libero_native_frame_wise_relative_rot6d.json"

    # ---- normalization (nested global/global_raw + quantile_rot) ------------

    def _bundled_stats_filename(self) -> str:
        """Build ``{embodiment}_{frame}_frame_wise_relative_{rotation}.json``.

        Raises:
            ValueError: If ``rotation_space`` is not ``3d``/``6d``/``9d``.
        """
        rotation_suffix = {"3d": "3d", "6d": "rot6d", "9d": "rot9d"}.get(self._rotation_space)
        if rotation_suffix is None:
            raise ValueError(f"Unsupported rotation_space={self._rotation_space!r}.")
        action_space = "frame_wise_relative"
        return f"{self._embodiment_type}_{self._pose_coordinate_frame}_{action_space}_{rotation_suffix}.json"

    def _resolve_stats_file(self, action_stats_path: str | None) -> Path:
        """Locate the stats JSON to normalize against.

        An absolute ``action_stats_path`` is used as given. A relative one is
        looked up by file name only inside the bundled stats directory (its
        directory components are discarded). Without a path, the bundled file
        for the current settings is used.

        Raises:
            FileNotFoundError: If the selected file does not exist.
        """
        if action_stats_path:
            p = Path(action_stats_path)
            if not p.is_absolute():
                p = _NORMALIZERS_DIR / p.name
            if not p.exists():
                raise FileNotFoundError(f"action_stats_path not found: {action_stats_path!r}")
            return p
        p = _NORMALIZERS_DIR / self._bundled_stats_filename()
        if not p.exists():
            raise FileNotFoundError(
                f"Bundled LIBERO stats not found at {p}. Pass action_stats_path or recompute stats."
            )
        return p

    def _load_norm_stats(self) -> dict[str, torch.Tensor]:
        """Load (once) and return the float32 stats tensors.

        Reads the ``global`` or ``global_raw`` section chosen in ``__init__``
        and keeps only the keys listed in ``_STAT_KEYS``. The result is cached
        in ``self._norm_stats``, an attribute not initialised in this class
        (presumably set by the base class).
        """
        if self._norm_stats is None:
            raw = json.loads(self._stats_file.read_text())[self._stats_key]
            self._norm_stats = {
                k: torch.tensor(v, dtype=torch.float32) for k, v in raw.items() if k in _STAT_KEYS
            }
        return self._norm_stats

    # ---- index helpers -----------------------------------------------------

    @staticmethod
    def _split_episode_ids(ep_ids: list[int], split: str, val_ratio: float, seed: int) -> set[int]:
        """Return the episode ids that belong to ``split``.

        ``full`` keeps everything. Otherwise a seeded sample of
        ``round(len(ep_ids) * val_ratio)`` episodes (at least one, at most all
        of them) forms the validation set; ``train`` is the complement and any
        other split name returns the validation set.

        Raises:
            ValueError: If ``val_ratio`` is outside ``(0, 1)`` for a non-``full`` split.
        """
        if split == "full":
            return {int(v) for v in ep_ids}
        if not (0.0 < val_ratio < 1.0):
            raise ValueError(f"val_ratio must be in (0, 1), got {val_ratio}.")
        n_val = max(1, int(round(len(ep_ids) * val_ratio)))
        # Never ask for more episodes than exist: random.sample raises ValueError
        # when the sample is larger than the population (e.g. an empty list).
        n_val = min(n_val, len(ep_ids))
        rng = random.Random(seed)  # identical selection on every rank
        val = {int(v) for v in rng.sample(list(ep_ids), n_val)}
        if split == "train":
            return {int(v) for v in ep_ids} - val
        return val  # val/valid/validation/eval/test

    def __len__(self) -> int:
        """Total number of valid windows across the kept episodes."""
        return int(self._valid_cum[-1]) if self._valid_cum.size else 0

    def get_shuffle_blocks(self) -> list[tuple[int, int]]:
        """Per-episode ``(start, length)`` flat-index blocks for
        ``ActionIterableShuffleDataset`` (shuffle block ORDER + shard across
        ranks, sequential within a block).

        Episodes that contribute no valid window produce no block.
        """
        blocks: list[tuple[int, int]] = []
        prev = 0
        for c in np.asarray(self._valid_cum).tolist():
            c = int(c)
            if c > prev:
                blocks.append((prev, c - prev))
            prev = c
        return blocks

    # ---- sample build ------------------------------------------------------

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Return the sample at flat index ``idx``.

        Negative indices count from the end. If building the sample fails
        (e.g. a frame cannot be decoded), a different random valid window is
        tried instead, up to ``_MAX_LOAD_ATTEMPTS`` attempts in total.

        Raises:
            IndexError: If ``idx`` is out of range (including any index on an
                empty dataset). Raised before the retry loop so it is never
                mistaken for a decode failure.
            RuntimeError: If every attempt failed; chained to the last error.
        """
        n = len(self)
        requested = int(idx)
        idx = requested + n if requested < 0 else requested
        if not 0 <= idx < n:
            raise IndexError(f"LIBERO: sample index {requested} out of range for dataset of length {n}.")

        # Resample a different valid window if a frame fails to decode (bounded retries).
        last_err: Exception | None = None
        for _attempt in range(self._MAX_LOAD_ATTEMPTS):
            try:
                return self._build_item(idx)
            except Exception as e:  # noqa: BLE001 — skip past undecodable frames
                last_err = e
                log.warning(f"LIBERO: sample idx={idx} failed to load ({type(e).__name__}: {e}); resampling")
                idx = random.randint(0, n - 1)  # n > 0 is guaranteed by the range check above
        raise RuntimeError(
            f"LIBERO: failed to load a sample after {self._MAX_LOAD_ATTEMPTS} resamples; last error: {last_err}"
        ) from last_err

    def _build_item(self, idx: int) -> dict[str, Any]:
        """Build one sample for an in-range, non-negative flat index.

        Maps ``idx`` to an episode and a start row, reads ``chunk_length + 1``
        frame timestamps and ``chunk_length`` stored actions from that row on,
        decodes the video and assembles the result through the base-class
        ``_build_result``.
        """
        mode = self._choose_mode()
        idx = int(idx)
        # _valid_cum holds the cumulative window counts, so side="right" yields the
        # first episode whose cumulative count exceeds idx.
        ep = int(np.searchsorted(self._valid_cum, idx, side="right"))
        prev = int(self._valid_cum[ep - 1]) if ep > 0 else 0
        start = int(self._ep_starts[ep]) + (idx - prev)
        episode_index = int(self._ep_vals[ep])
        episode = self._episodes[episode_index]

        stop = start + self._chunk_length + 1
        timestamps = [float(self._row_timestamp[j]) for j in range(start, stop)]
        video = self._load_video(episode, timestamps)

        # frame_wise_relative: chunk per-frame deltas are the stored actions directly.
        raw = self._row_action[start : start + self._chunk_length]  # [chunk, 7]
        action = self._build_frame_wise_action(raw)

        task = self._tasks[int(self._row_task[start])]
        # A task string may hold several " | "-separated phrasings; pick one at random.
        ai_caption = random.choice([p.strip() for p in task.split(" | ") if p.strip()] or [task])

        extras: dict[str, Any] = {}
        if self._camera_mode == "concat_view":
            extras["additional_view_description"] = (
                "The left half shows the third-person view; the right half shows the wrist-mounted camera."
            )
        return self._build_result(mode=mode, video=video, action=action, ai_caption=ai_caption, **extras)

    def _build_frame_wise_action(self, raw: np.ndarray) -> torch.Tensor:
        """Re-encode the rotation columns of the stored actions.

        Columns 0-2 (translation) and 6 (gripper) are copied unchanged;
        columns 3-5 are converted from axis-angle to the configured rotation
        format via a rotation matrix.
        """
        raw_t = torch.from_numpy(np.ascontiguousarray(raw)).float()  # [chunk, 7]
        translation = raw_t[:, 0:3]
        rotation_matrix = convert_rotation(raw_t[:, 3:6], input_format="axisangle", output_format="matrix")
        rotation = convert_rotation(
            rotation_matrix, input_format="matrix", output_format=libero_rotation_format(self._rotation_space)
        )
        gripper = raw_t[:, 6:7]
        return torch.cat([translation, rotation, gripper], dim=-1)  # [chunk, action_dim]

    def _load_video(self, episode: dict[str, Any], timestamps: list[float]) -> torch.Tensor:
        """Decode and resize the frames of every configured camera view.

        Each timestamp is shifted by the episode's per-view
        ``videos/<key>/from_timestamp`` entry (0.0 when absent) before
        decoding. For ``concat_view`` the two views are concatenated along the
        last axis.
        """
        frames_by_view = {}
        for key in self._video_keys:
            from_ts = float(episode.get(f"videos/{key}/from_timestamp", 0.0))
            frames = decode_video_frames(
                self._video_path(episode, key),
                [from_ts + ts for ts in timestamps],
                self._tolerance_s,
            )  # [T, C, H, W] in [0, 1]
            frames = self._resize(frames)
            frames_by_view[key] = frames
        if self._camera_mode == "concat_view":
            # third-person (left) + wrist (right), horizontally concatenated -> [T, C, H, 2W]
            return torch.cat([frames_by_view[_IMAGE_FEATURE], frames_by_view[_WRIST_FEATURE]], dim=-1)
        return frames_by_view[self._video_keys[0]]

    def _resize(self, frames: torch.Tensor) -> torch.Tensor:
        """Bilinearly resize ``frames`` to ``image_size`` x ``image_size``.

        Frames whose last two dimensions already match are returned as-is.
        """
        if frames.shape[-1] == self._image_size and frames.shape[-2] == self._image_size:
            return frames
        return F.interpolate(
            frames, size=(self._image_size, self._image_size), mode="bilinear", align_corners=False
        )
"embodiment_type": "libero", + "pose_convention": "frame_wise_relative", + "pose_coordinate_frame": "native", + "rotation_format": "6d", + "action_dim": 10, + "skip_rotation_dims": [3, 4, 5, 6, 7, 8], + "chunk_length": 16, + "sample_stride": null, + "dataset_name": "libero", + "dataset_class": "LIBEROLeRobotDataset", + "dataset_root": ["outputs/libero_datasets/libero_10", "outputs/libero_datasets/libero_object", "outputs/libero_datasets/libero_spatial", "outputs/libero_datasets/libero_goal"], + "_comment": "Dataset paths are placeholders; the statistics values are independent of local dataset location.", + "split": "train", + "num_samples_stats": 10000, + "reservoir_size": 50000, + "max_samples": 10000, + "sampling_seed": 42 + }, + "global": { + "mean": [ 0.050704, 0.097407, -0.094833, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.476725], + "std": [ 0.333621, 0.387175, 0.457140, 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, 0.499460], + "min": [-0.937500, -0.937500, -0.937500, -1.000000, -1.000000, -1.000000, -1.000000, -1.000000, -1.000000, 0.000000], + "max": [ 0.937500, 0.937500, 0.937500, 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, 1.000000], + "q01": [-0.723214, -0.808929, -0.937500, -1.000000, -1.000000, -1.000000, -1.000000, -1.000000, -1.000000, 0.000000], + "q99": [ 0.937500, 0.870536, 0.937500, 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, 1.000000] + }, + "global_raw": { + "mean": [ 0.050704, 0.097407, -0.094833, 0.994873, -0.004579, -0.004288, 0.004389, 0.996104, 0.001109, 0.476725], + "std": [ 0.333621, 0.387175, 0.457140, 0.010807, 0.077802, 0.063386, 0.078571, 0.009994, 0.038504, 0.499460], + "min": [-0.937500, -0.937500, -0.937500, 0.902028, -0.356085, -0.367416, -0.370434, 0.921907, -0.255000, 0.000000], + "max": [ 0.937500, 0.937500, 0.937500, 1.000000, 0.368853, 0.341214, 0.356395, 1.000000, 0.348251, 1.000000], + "q01": [-0.723214, -0.808929, -0.937500, 0.934955, -0.223431, -0.189878, 
-0.334735, 0.938516, -0.107736, 0.000000], + "q99": [ 0.937500, 0.870536, 0.937500, 1.000000, 0.331000, 0.163153, 0.226216, 1.000000, 0.127158, 1.000000] + } +} diff --git a/cosmos_framework/data/vfm/action/libero_pose_utils.py b/cosmos_framework/data/vfm/action/libero_pose_utils.py new file mode 100644 index 0000000..3a4fd8e --- /dev/null +++ b/cosmos_framework/data/vfm/action/libero_pose_utils.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Small LIBERO pose helpers shared by training and closed-loop eval.""" + +from __future__ import annotations + +import numpy as np +import torch + +from cosmos_framework.data.vfm.action.pose_utils import ( + RotationConvention, + build_abs_pose_from_components, +) + +# Local-frame post-rotation pattern: +# R_opencv = R_native @ *_TO_OPENCV. +LIBERO_TO_OPENCV: np.ndarray = np.array( + [[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], + dtype=np.float32, +) + +LIBERO_ROTATION_FORMATS: dict[str, RotationConvention] = { + "3d": "axisangle", + "6d": "rot6d", + "9d": "rot9d", +} +LIBERO_ACTION_DIMS: dict[str, int] = {"3d": 7, "6d": 10, "9d": 13} + + +def libero_rotation_format(rotation_space: str) -> RotationConvention: + """Return the shared ``pose_utils`` rotation format for a LIBERO setting.""" + rotation_format = LIBERO_ROTATION_FORMATS.get(rotation_space) + if rotation_format is None: + raise ValueError(f"Unsupported rotation_space={rotation_space!r}. Use 3d/6d/9d.") + return rotation_format + + +def libero_action_dim(rotation_space: str) -> int: + """Return ``[xyz, rotation, gripper]`` action width for LIBERO.""" + action_dim = LIBERO_ACTION_DIMS.get(rotation_space) + if action_dim is None: + raise ValueError(f"Unsupported rotation_space={rotation_space!r}. 
Use 3d/6d/9d.") + return action_dim + + +def libero_rotation_space_from_action_dim(action_dim: int) -> str: + """Infer LIBERO rotation space from unpadded action width.""" + for rotation_space, dim in LIBERO_ACTION_DIMS.items(): + if dim == action_dim: + return rotation_space + raise ValueError(f"Unable to infer rotation_space from action_dim={action_dim}.") + + +def build_libero_abs_pose(state_raw: torch.Tensor | np.ndarray, *, to_opencv: bool) -> np.ndarray: + """Build absolute LIBERO EE poses from state rows. + + ``state_raw`` is ``[x,y,z,axisangle(3),gripper(2)]``. When requested, the + local EE frame is post-rotated into the shared OpenCV-style action frame. + """ + if isinstance(state_raw, torch.Tensor): + state_np = state_raw.detach().cpu().numpy().astype(np.float32, copy=False) + else: + state_np = np.asarray(state_raw, dtype=np.float32) + + poses_abs = build_abs_pose_from_components(state_np[:, :3], state_np[:, 3:6], "axisangle") + if to_opencv: + poses_abs[:, :3, :3] = poses_abs[:, :3, :3] @ LIBERO_TO_OPENCV + return poses_abs diff --git a/cosmos_framework/scripts/action_policy_server_libero.py b/cosmos_framework/scripts/action_policy_server_libero.py index 7382b97..2fa0567 100644 --- a/cosmos_framework/scripts/action_policy_server_libero.py +++ b/cosmos_framework/scripts/action_policy_server_libero.py @@ -63,6 +63,10 @@ # Action-specific helpers live in the in-tree project tree. Imports stay as # `projects.cosmos3.vfm.*` and are auto-rewritten to `cosmos3._src.vfm.*` by the # cosmos-framework release script. +from cosmos_framework.data.vfm.action.action_processing import ( + ActionProcessingRecord, + make_batched_action_processing_fields, +) from cosmos_framework.data.vfm.action.domain_utils import get_domain_id from cosmos_framework.data.vfm.action.transforms import ( build_sequence_plan_from_mode, @@ -431,7 +435,8 @@ class ActionServerArgs(pydantic.BaseModel): # ``OmniSetupOverrides`` programmatically in ``build_setup_overrides``. 
checkpoint: tyro.conf.OmitArgPrefixes[CheckpointOverrides] = CheckpointOverrides.model_construct() - """Checkpoint and config loading configuration.""" + """Checkpoint and config loading configuration. ``use_ema_weights`` lives here and + defaults True at inference (suppressed from CLI) -> evals load net_ema by default.""" output_dir: Path | None = None """Output directory for ``OmniInference`` (saved config.yaml, benchmarks). @@ -804,101 +809,53 @@ def get_info(self) -> dict[str, Any]: # Predict # ------------------------------------------------------------------ - def predict_policy(self, req: dict[str, Any]) -> dict[str, Any]: - """ - Run policy inference: given an observation image and prompt, predict actions. - - Input request format: - { - "image": "", - "prompt": "", - "domain_name": "", - "image_size": - } - - Output format: - { - "action": [[a0_0, a0_1, ...], ..., [aN_0, aN_1, ...]], - "video": ["", ...] # List of T base64-encoded PNG frames - } - - All action dimensions are returned. Video is the decoded predicted rollout as base64 PNGs. - """ - t0 = time.monotonic() - - # Get or assign request ID - injected_id = req.get("request_id", None) - if isinstance(injected_id, int) and injected_id > 0: - request_id = int(injected_id) - else: - with self._req_id_lock: - self._req_id += 1 - request_id = int(self._req_id) + def _input_video_key(self) -> str: + input_video_key = getattr(self.model, "input_video_key", None) + if input_video_key is None: + input_video_key = getattr(self.model, "config", None).input_video_key # type: ignore[union-attr] + return input_video_key - # Validate request + def _prep_policy_item(self, req: dict[str, Any]) -> dict[str, Any]: + """Validate one request and build the per-sample model inputs (video pad, + prompt augmentation, sequence_plan). 
Shared by predict_policy (batch=1) and + predict_policy_batch (batch=N) so the two paths stay byte-identical per item.""" image_b64 = req.get("image") if not isinstance(image_b64, str): raise ValueError("'image' must be a base64 string") - prompt = req.get("prompt") if not isinstance(prompt, str): raise ValueError("'prompt' must be a string") - domain_name = req.get("domain_name") if not isinstance(domain_name, str): raise ValueError("'domain_name' must be a string") - image_size = req.get("image_size") if not isinstance(image_size, int) or image_size <= 0: raise ValueError("'image_size' must be a positive integer") - # Decode image - t_decode0 = time.monotonic() img_chw_uint8 = _decode_base64_png_to_rgb_uint8(image_b64) img_h, img_w = img_chw_uint8.shape[-2:] - - # Handle resizing: for multi-view (non-square) images, scale proportionally - # to maintain aspect ratio while matching height to image_size + # Multi-view (non-square) images: scale proportionally, matching height to image_size. 
if img_h != image_size: - # Calculate new width to maintain aspect ratio scale = image_size / img_h new_w = int(round(img_w * scale)) - hwc = img_chw_uint8.permute(1, 2, 0).cpu().numpy() # [H,W,3] + hwc = img_chw_uint8.permute(1, 2, 0).cpu().numpy() resized = Image.fromarray(hwc).resize((new_w, image_size), resample=Image.Resampling.BILINEAR) arr = np.asarray(resized, dtype=np.uint8).copy() - img_chw_uint8 = torch.from_numpy(arr).permute(2, 0, 1).contiguous() # [3,H,W] # [3,H,W] - t_decode1 = time.monotonic() + img_chw_uint8 = torch.from_numpy(arr).permute(2, 0, 1).contiguous() - # Construct batch in IterativeJointDataLoader format (list-of-lists for multi-item keys) t_frames = self.cfg.action_chunk_size + 1 _, final_h, final_w = img_chw_uint8.shape video_c_t_h_w_uint8 = img_chw_uint8.unsqueeze(1).repeat(1, t_frames, 1, 1) # [3,T,H,W] - - # Apply reflection padding to match closest predefined resolution resolution = get_vision_data_resolution((final_h, final_w)) target_w, target_h = find_closest_target_size(final_h, final_w, resolution) pad_dict: dict[str, Any] = {"video": video_c_t_h_w_uint8} reflection_pad_to_target(pad_dict, ["video"], True, target_w, target_h) - video_padded = pad_dict["video"] # (C, T, target_h, target_w) - padded_image_size = pad_dict["image_size"] # (4,) - - # Action: zeros tensor as noise starting point for policy mode - action_t_d = torch.zeros( - (self.cfg.action_chunk_size, self.cfg.max_action_dim), - dtype=torch.float32, - ) # [T,action_dim] - - input_video_key = getattr(self.model, "input_video_key", None) - if input_video_key is None: - input_video_key = getattr(self.model, "config", None).input_video_key # type: ignore[union-attr] - sequence_plan = build_sequence_plan_from_mode( mode="policy", video_length=self.cfg.action_chunk_size + 1, action_length=self.cfg.action_chunk_size, has_text=True, ) - augmented_prompt = _augment_prompt_with_metadata( prompt, t_frames=t_frames, @@ -908,10 +865,126 @@ def predict_policy(self, req: 
dict[str, Any]) -> dict[str, Any]: append_duration_fps=self.append_duration_fps, append_resolution_info=self.append_resolution_info, ) + return { + "img_chw_uint8": img_chw_uint8, + "video_padded": pad_dict["video"], + "padded_image_size": pad_dict["image_size"], + "augmented_prompt": augmented_prompt, + "sequence_plan": sequence_plan, + "domain_name": domain_name, + "image_size": image_size, + } + + def predict_policy_batch(self, reqs: list[dict[str, Any]]) -> dict[str, Any]: + """Batched policy inference: N requests -> ONE diffusion forward (batch_size=N) + -> N denormalized action chunks. Skips vision decode (the vectorized eval client + only needs actions), so it is ~N x faster than N serial /predict calls.""" + t0 = time.monotonic() + if not isinstance(reqs, list) or not reqs: + raise ValueError("'items' must be a non-empty list of policy requests") + preps = [self._prep_policy_item(r) for r in reqs] + n = len(preps) + action_t_d = torch.zeros( + (self.cfg.action_chunk_size, self.cfg.max_action_dim), dtype=torch.float32 + ) + input_video_key = self._input_video_key() + batch: dict[str, Any] = { + input_video_key: [[p["video_padded"]] for p in preps], + **make_batched_action_processing_fields( + ActionProcessingRecord(raw_action_dim=self.raw_action_dim, action_normalizer=None), + batch_size=n, + ), + "action": [[action_t_d] for _ in preps], + "mode": ["policy"] * n, + "ai_caption": [p["augmented_prompt"] for p in preps], + "prompt": [p["augmented_prompt"] for p in preps], + "conditioning_fps": [torch.tensor(self.cfg.fps, dtype=torch.long) for _ in preps], + "image_size": torch.stack([p["padded_image_size"] for p in preps]).to(device="cuda"), + "domain_id": [torch.tensor(get_domain_id(p["domain_name"]), dtype=torch.long) for p in preps], + "sequence_plan": [p["sequence_plan"] for p in preps], + } + t_inf0 = time.monotonic() + with self._lock: + with torch.inference_mode(): + samples = self.model.generate_samples_from_batch( + batch, + guidance=self.cfg.guidance, 
+ seed=[self.cfg.seed] * n, + num_steps=self.cfg.num_steps, + has_negative_prompt=False, + ) + t_inf1 = time.monotonic() + actions: list[list[list[float]]] = [] + for i in range(n): + pred = samples["action"][i].float().squeeze(0) # [T,D] + pred = self._denormalize_action(pred) + actions.append(pred.detach().cpu().numpy().tolist()) + log.info( + f"[action-server] predict_batch n={n} steps={self.cfg.num_steps} " + f"ms_total={(time.monotonic() - t0) * 1000.0:.1f} ms_infer={(t_inf1 - t_inf0) * 1000.0:.1f}" + ) + return {"actions": actions} + + def predict_policy(self, req: dict[str, Any]) -> dict[str, Any]: + """ + Run policy inference: given an observation image and prompt, predict actions. + + Input request format: + { + "image": "", + "prompt": "", + "domain_name": "", + "image_size": + } + + Output format: + { + "action": [[a0_0, a0_1, ...], ..., [aN_0, aN_1, ...]], + "video": ["", ...] # List of T base64-encoded PNG frames + } + + All action dimensions are returned. Video is the decoded predicted rollout as base64 PNGs. + """ + t0 = time.monotonic() + + # Get or assign request ID + injected_id = req.get("request_id", None) + if isinstance(injected_id, int) and injected_id > 0: + request_id = int(injected_id) + else: + with self._req_id_lock: + self._req_id += 1 + request_id = int(self._req_id) + + # Per-item preprocessing (validation, decode/resize/pad, prompt, sequence_plan). 
+ t_decode0 = time.monotonic() + prep = self._prep_policy_item(req) + t_decode1 = time.monotonic() + img_chw_uint8 = prep["img_chw_uint8"] + video_padded = prep["video_padded"] + padded_image_size = prep["padded_image_size"] + augmented_prompt = prep["augmented_prompt"] + sequence_plan = prep["sequence_plan"] + domain_name = prep["domain_name"] + image_size = prep["image_size"] + + # Action: zeros tensor as noise starting point for policy mode + action_t_d = torch.zeros( + (self.cfg.action_chunk_size, self.cfg.max_action_dim), + dtype=torch.float32, + ) # [T,action_dim] + + input_video_key = self._input_video_key() batch: dict[str, Any] = { input_video_key: [[video_padded]], - "raw_action_dim": [torch.tensor(self.raw_action_dim, dtype=torch.long)], + # Provide BOTH raw_action_dim and the action_processing_record the model + # needs to externalize (invert) the generated action; building the batch + # by hand previously omitted the record -> "cannot be externalized". + **make_batched_action_processing_fields( + ActionProcessingRecord(raw_action_dim=self.raw_action_dim, action_normalizer=None), + batch_size=1, + ), "action": [[action_t_d]], "mode": ["policy"], "ai_caption": [augmented_prompt], @@ -1103,7 +1176,7 @@ def do_GET(self) -> None: # noqa: N802 self._send_json(404, {"error": "Not found"}) def do_POST(self) -> None: # noqa: N802 - if self.path not in ("/", "/predict"): + if self.path not in ("/", "/predict", "/predict_batch"): self._send_json(404, {"error": "Not found"}) return @@ -1147,13 +1220,21 @@ def do_POST(self) -> None: # noqa: N802 f"path={self.path} bytes={length}" ) + is_batch = self.path == "/predict_batch" try: - out = service.predict_policy(req) + if is_batch: + out = service.predict_policy_batch(req.get("items", [])) + else: + out = service.predict_policy(req) except Exception as e: err = str(e) traceback.print_exc() - payload = {"action": [], "error": err, "request_id": req.get("request_id")} + payload = ( + {"actions": [], "error": err} + if 
is_batch + else {"action": [], "error": err, "request_id": req.get("request_id")} + ) log.error(f"[action-server] request_id={req.get('request_id')} ERROR: {err}") # Dump failed request for offline debugging if enabled. diff --git a/cosmos_framework/simulation/__init__.py b/cosmos_framework/simulation/__init__.py new file mode 100644 index 0000000..28a81be --- /dev/null +++ b/cosmos_framework/simulation/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 diff --git a/cosmos_framework/simulation/libero/__init__.py b/cosmos_framework/simulation/libero/__init__.py new file mode 100644 index 0000000..503ec1b --- /dev/null +++ b/cosmos_framework/simulation/libero/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + diff --git a/cosmos_framework/simulation/libero/closed_loop_eval.py b/cosmos_framework/simulation/libero/closed_loop_eval.py new file mode 100644 index 0000000..660be36 --- /dev/null +++ b/cosmos_framework/simulation/libero/closed_loop_eval.py @@ -0,0 +1,1343 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +""" +Closed-loop evaluation for LIBERO using the Action HTTP inference server. + +# Single-view example (agentview camera): +PYTHONPATH=. python cosmos_framework/simulation/libero/closed_loop_eval.py \ + --server_url http://localhost:8000 \ + --task_suite libero_10 \ + --num_trials_per_task 10 \ + --action_horizon 16 \ + --camera agentview \ + --save_gifs --gif_fps 20 \ + --action_space frame_wise_relative \ + --rotation_space 6d \ + --action_dim 10 \ + --output_dir results/libero_closed_loop_10_single_view + +# Multi-view example (agentview + wrist cameras): +PYTHONPATH=. 
python cosmos_framework/simulation/libero/closed_loop_eval.py \ + --server_url http://localhost:8000 \ + --task_suite libero_goal \ + --num_trials_per_task 2 \ + --action_horizon 16 \ + --camera agentview,wrist \ + --save_gifs --gif_fps 20 \ + --action_space frame_wise_relative \ + --rotation_space 6d \ + --action_dim 10 \ + --output_dir results/libero_closed_loop_goal_multiview +""" + +from __future__ import annotations + +import argparse +import base64 +import io +import json +import os +import random +import sys +import time +from dataclasses import dataclass +from collections.abc import Callable +from pathlib import Path +from typing import Any + +import numpy as np +import requests +from PIL import Image +from scipy.spatial.transform import Rotation as R + +from cosmos_framework.data.vfm.action.libero_pose_utils import ( + libero_rotation_format, + libero_rotation_space_from_action_dim, +) +from cosmos_framework.data.vfm.action.pose_utils import convert_rotation +from cosmos_framework.data.vfm.action.viewpoint_utils import DEFAULT_VIEWPOINT_TEMPLATES + +benchmark: Any +get_libero_path: Any +OffScreenRenderEnv: Any + + +TASK_MAX_STEPS: dict[str, int] = { + "libero_spatial": 220, + "libero_object": 280, + "libero_goal": 300, + "libero_10": 520, + "libero_90": 400, +} + + +_CAMERA_PROMPT_NAMES: dict[str, str] = { + "agentview": "third-person view", + "wrist": "wrist-mounted camera", +} + + +def _append_prompt_sentence(prompt: str, sentence: str) -> str: + """Append one metadata sentence using the same separator convention as training augmentors.""" + if sentence in prompt: + return prompt + prompt = prompt.rstrip() + if not prompt: + return sentence.rstrip() + separator = " " if prompt.rstrip().endswith(".") else ". 
" + return prompt + separator + sentence.rstrip() + + +def _concat_view_layout_description(cameras: list[str]) -> str: + """Describe the horizontal camera layout sent by ``ActionEnvironmentClient``.""" + camera_names = [_CAMERA_PROMPT_NAMES[camera] for camera in cameras] + if len(camera_names) == 2: + return f"The left half shows the {camera_names[0]}; the right half shows the {camera_names[1]}." + layout = ", ".join(camera_names) + return f"The views are concatenated horizontally from left to right as: {layout}." + + +def _augment_task_prompt_with_viewpoint(task_description: str, cameras: list[str]) -> str: + """Concat-view caption augmentation for closed-loop LIBERO eval.""" + if len(cameras) <= 1: + return task_description + prompt = _append_prompt_sentence(task_description, DEFAULT_VIEWPOINT_TEMPLATES["concat_view"]) + return _append_prompt_sentence(prompt, _concat_view_layout_description(cameras)) + + +def _rotation_repr_to_mat(rotation: np.ndarray, rotation_space: str) -> np.ndarray: + """Convert a single LIBERO rotation block to a 3x3 rotation matrix.""" + matrix = convert_rotation( + rotation, + libero_rotation_format(rotation_space), + "matrix", + normalize_matrix=rotation_space != "3d", + ) + if not isinstance(matrix, np.ndarray): + raise TypeError(f"Expected NumPy rotation matrix, got {type(matrix)!r}") + return matrix + + +@dataclass +class EpisodeResult: + success: bool + steps: int + error: str | None + actions: list[list[float]] + + +class ActionEnvironmentClient: + """Client for interacting with the Action model server.""" + + server_url: str + domain_name: str + prompt: str + image_size: int + timeout: float + + def __init__( + self, + server_url: str, + domain_name: str, + prompt: str, + image_size: int, + timeout: float, + ) -> None: + self.server_url = server_url.rstrip("/") + self.domain_name = domain_name + self.prompt = prompt + self.image_size = image_size + self.timeout = timeout + + def check_health(self) -> bool: + """Check if the model 
server is healthy.""" + try: + resp = requests.get(f"{self.server_url}/", timeout=5.0) + return resp.status_code == 200 + except requests.RequestException: + return False + + def get_info(self) -> dict[str, str]: + """Get model server info.""" + resp = requests.get(f"{self.server_url}/info", timeout=5.0) + resp.raise_for_status() + return resp.json() + + def notify_next_episode(self) -> None: + """Notify server to advance to next episode (used with dataset action server).""" + try: + requests.post( + f"{self.server_url}/next_episode", + json={"prompt": self.prompt}, + timeout=5.0, + ) + except requests.RequestException: + pass + + def encode_image(self, image: np.ndarray) -> str: + """Encode a numpy image (H, W, 3) uint8 to base64 PNG, resizing to image_size.""" + if image.dtype != np.uint8: + if image.max() <= 1.0: + image = (image * 255.0).round().astype(np.uint8) + else: + image = image.astype(np.uint8) + pil_img = Image.fromarray(image) + if pil_img.size != (self.image_size, self.image_size): + pil_img = pil_img.resize( + (self.image_size, self.image_size), + resample=Image.Resampling.BILINEAR, + ) + buf = io.BytesIO() + pil_img.save(buf, format="PNG") + return base64.b64encode(buf.getvalue()).decode("ascii") + + def encode_image_raw(self, image: np.ndarray) -> str: + """Encode a numpy image (H, W, 3) uint8 to base64 PNG without resizing.""" + if image.dtype != np.uint8: + if image.max() <= 1.0: + image = (image * 255.0).round().astype(np.uint8) + else: + image = image.astype(np.uint8) + pil_img = Image.fromarray(image) + buf = io.BytesIO() + pil_img.save(buf, format="PNG") + return base64.b64encode(buf.getvalue()).decode("ascii") + + def resize_image(self, image: np.ndarray) -> np.ndarray: + """Resize image to model input size.""" + if image.dtype != np.uint8: + if image.max() <= 1.0: + image = (image * 255.0).round().astype(np.uint8) + else: + image = image.astype(np.uint8) + pil_img = Image.fromarray(image) + if pil_img.size != (self.image_size, 
self.image_size): + pil_img = pil_img.resize( + (self.image_size, self.image_size), + resample=Image.Resampling.BILINEAR, + ) + return np.array(pil_img) + + def concatenate_images(self, images: list[np.ndarray]) -> np.ndarray: + """Resize each image and concatenate horizontally (side-by-side). + + Args: + images: List of images with shape (H, W, 3). + + Returns: + Concatenated image with shape (image_size, image_size*num_views, 3). + """ + resized = [self.resize_image(img) for img in images] + return np.concatenate(resized, axis=1) + + def predict(self, observation: np.ndarray | list[np.ndarray]) -> dict[str, Any]: + """Send observation(s) to model server and get predicted actions. + + Args: + observation: Single image as np.ndarray or list of images for multi-view. + For multi-view, images are resized and concatenated horizontally before sending. + """ + if isinstance(observation, list): + # Multi-view: resize each, concatenate horizontally, and send as single image + concatenated = self.concatenate_images(observation) + encoded = self.encode_image_raw(concatenated) + else: + # Single view: send single image + encoded = self.encode_image(observation) + + payload = { + "image": encoded, + "prompt": self.prompt, + "domain_name": self.domain_name, + "image_size": self.image_size, + } + + resp = requests.post( + f"{self.server_url}/predict", + json=payload, + headers={"Content-Type": "application/json"}, + timeout=self.timeout, + ) + resp.raise_for_status() + + result = resp.json() + if "error" in result and result["error"]: + raise RuntimeError(f"Model server error: {result['error']}") + return result + + def predict_batch(self, observations: list[list[np.ndarray]]) -> list[list[list[float]]]: + """Batched inference: a list of per-env multi-view observations -> ONE + POST /predict_batch -> a list of action chunks (one per env). 
Used by the + vectorized eval so N parallel envs share a single diffusion forward.""" + items = [] + for obs_imgs in observations: + concat = self.concatenate_images(obs_imgs) if len(obs_imgs) > 1 else self.resize_image(obs_imgs[0]) + items.append( + { + "image": self.encode_image_raw(concat), + "prompt": self.prompt, + "domain_name": self.domain_name, + "image_size": self.image_size, + } + ) + resp = requests.post( + f"{self.server_url}/predict_batch", + json={"items": items}, + headers={"Content-Type": "application/json"}, + timeout=max(self.timeout, 300.0), + ) + resp.raise_for_status() + result = resp.json() + if "error" in result and result["error"]: + raise RuntimeError(f"Model server error: {result['error']}") + return result["actions"] + + +def _find_accessible_dri_nodes() -> list[Path]: + dri_path = Path("/dev/dri") + if not dri_path.exists(): + return [] + nodes = list(dri_path.glob("renderD*")) + list(dri_path.glob("card*")) + return [node for node in nodes if os.access(node, os.R_OK | os.W_OK)] + + +def _resolve_mujoco_backend(requested_backend: str) -> tuple[str, str]: + requested_backend = requested_backend.lower() + if requested_backend != "auto": + return requested_backend, "requested" + + env_backend = os.environ.get("MUJOCO_GL") + if env_backend: + return env_backend.lower(), "env" + + if _find_accessible_dri_nodes(): + return "egl", "auto-gpu" + return "osmesa", "auto-cpu" + + +def _configure_mujoco_env(requested_backend: str) -> str: + backend, source = _resolve_mujoco_backend(requested_backend) + if backend not in {"egl", "osmesa", "glfw"}: + raise ValueError(f"Unsupported MuJoCo GL backend: {backend!r}. 
Use auto, egl, osmesa, or glfw.") + + os.environ["MUJOCO_GL"] = backend + if backend == "egl": + os.environ["PYOPENGL_PLATFORM"] = "egl" + elif backend == "osmesa": + os.environ["PYOPENGL_PLATFORM"] = "osmesa" + return f"{backend} ({source})" + + +def _import_libero() -> None: + global benchmark, get_libero_path, OffScreenRenderEnv + try: + from libero.libero import benchmark as libero_benchmark + from libero.libero import get_libero_path as libero_get_libero_path + from libero.libero.envs import OffScreenRenderEnv as libero_offscreen_render_env + except ImportError as exc: # pragma: no cover - environment-specific dependency + raise RuntimeError( + "Failed to import LIBERO. Make sure the LIBERO environment is activated. " + f"python={sys.executable!r}, import_error={exc!r}" + ) from exc + + benchmark = libero_benchmark + get_libero_path = libero_get_libero_path + OffScreenRenderEnv = libero_offscreen_render_env + + +def _wait_for_server(client: ActionEnvironmentClient, timeout_s: float) -> None: + start = time.perf_counter() + while time.perf_counter() - start < timeout_s: + if client.check_health(): + return + time.sleep(1.0) + raise RuntimeError(f"Timed out waiting for server at {client.server_url}") + + +def _get_libero_env( + task: Any, + *, + resolution: int, + seed: int, + render_gpu_device_id: int, +) -> tuple[Any, str]: + task_description = str(task.language) + task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file) + env_args = { + "bddl_file_name": task_bddl_file, + "camera_heights": resolution, + "camera_widths": resolution, + "render_gpu_device_id": render_gpu_device_id, + } + env = OffScreenRenderEnv(**env_args) + env.seed(seed) + return env, task_description + + +def _get_libero_dummy_action() -> list[float]: + return [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0] + + +def _get_libero_image( + obs: dict[str, Any], + camera: str, + *, + flip_images: bool, + rotate_180: bool, +) -> np.ndarray: + if camera == "agentview": 
+ image = obs["agentview_image"] + elif camera == "wrist": + image = obs["robot0_eye_in_hand_image"] + else: + raise ValueError(f"Unsupported camera={camera!r}. Use 'agentview' or 'wrist'.") + + if rotate_180: + image = image[::-1, ::-1] + if flip_images: + image = np.flipud(image) + return image + + +def _get_libero_images( + obs: dict[str, Any], + cameras: list[str], + *, + flip_images: bool, + rotate_180: bool, +) -> list[np.ndarray]: + """Get images from multiple cameras.""" + return [_get_libero_image(obs, camera, flip_images=flip_images, rotate_180=rotate_180) for camera in cameras] + + +def _ensure_uint8_image(image: np.ndarray) -> np.ndarray: + if image.dtype != np.uint8: + if image.max() <= 1.0: + image = (image * 255.0).round().astype(np.uint8) + else: + image = image.astype(np.uint8) + return image + + +def _save_gif(frames: list[Image.Image], output_path: Path, fps: int) -> None: + if not frames: + return + duration_ms = int(1000 / fps) if fps > 0 else 100 + output_path.parent.mkdir(parents=True, exist_ok=True) + first, *rest = frames + first.save( + output_path, + save_all=True, + append_images=rest, + duration=duration_ms, + loop=0, + ) + + +def _decode_b64_frames(b64_frames: list[str]) -> list[Image.Image]: + """Decode a list of base64-encoded PNG strings into PIL Images.""" + images: list[Image.Image] = [] + for b64 in b64_frames: + raw = base64.b64decode(b64) + images.append(Image.open(io.BytesIO(raw)).convert("RGB")) + return images + + +def _save_comparison_gif( + comparison_windows: list[tuple[list[Image.Image], list[Image.Image]]], + output_path: Path, + fps: int, + target_height: int = 256, + separator_width: int = 4, +) -> None: + """Create and save a side-by-side comparison GIF (Action prediction | env rollout). + + Each window is a (action_frames, env_frames) pair from one prediction call. + Frames are paired index-by-index; the conditioning frame (index 0) of + subsequent windows is skipped to avoid duplicating the boundary frame. 
+ """ + from PIL import ImageDraw + + combined_frames: list[Image.Image] = [] + banner_h = 16 + + for window_idx, (action_frames, env_frames) in enumerate(comparison_windows): + n = min(len(action_frames), len(env_frames)) + start = 1 if window_idx > 0 else 0 + for i in range(start, n): + action_img = action_frames[i] + env_img = env_frames[i] + + action_w = int(action_img.width * target_height / action_img.height) + env_w = int(env_img.width * target_height / env_img.height) + action_resized = action_img.resize((action_w, target_height), Image.Resampling.BILINEAR) + env_resized = env_img.resize((env_w, target_height), Image.Resampling.BILINEAR) + + total_w = action_w + separator_width + env_w + total_h = target_height + banner_h + combined = Image.new("RGB", (total_w, total_h), color=0) + + draw = ImageDraw.Draw(combined) + draw.rectangle([(0, 0), (action_w, banner_h)], fill=(30, 30, 60)) + draw.rectangle([(action_w + separator_width, 0), (total_w, banner_h)], fill=(30, 60, 30)) + draw.text((4, 1), "Action Prediction", fill=(100, 180, 255)) + draw.text((action_w + separator_width + 4, 1), "Environment", fill=(100, 255, 100)) + + combined.paste(action_resized, (0, banner_h)) + combined.paste(env_resized, (action_w + separator_width, banner_h)) + combined_frames.append(combined) + + if combined_frames: + _save_gif(combined_frames, output_path, fps) + + +def _select_action_chunk(actions: list[list[float]], action_horizon: int) -> list[list[float]]: + if action_horizon <= 0 or action_horizon >= len(actions): + return actions + return actions[:action_horizon] + + +def _format_action(action: list[float], action_dim: int) -> list[float]: + if len(action) < action_dim: + raise ValueError(f"Action dimension {len(action)} smaller than expected {action_dim}") + return action[:action_dim] + + +def _remap_gripper(action: list[float], mode: str) -> list[float]: + """Map the model's gripper command to the LIBERO env's [-1, 1] (negative = open). 
+ + The right mapping depends on the gripper convention of the dataset the policy + was trained on (the server denormalizes back to that raw convention): + + * ``zero_one`` (NVIDIA LIBERO_LeRobot_v3): raw gripper in [0, 1]; the env wants + [-1, 1] with negative=open. The i4/cosmos-rl reference BINARIZES this to hard + {-1, +1} via ``-sign(2g - 1)`` (not the continuous ``1 - 2g`` from issue #50). + For a confident policy the two agree (g~0/1), but an undertrained policy emits + g~0.5 where continuous ``1-2g``~0 never actuates the gripper -> grasps fail. + Binarizing matches the reference and is robust to weak checkpoints. + * ``pm_one`` (community ``lerobot/libero_*``): raw gripper already in {-1, +1} + (robosuite convention) -> pass through (clamped). + * ``pm_one_flip``: {-1, +1} but with inverted open/close sign. + """ + action = list(action) # avoid mutating the caller's list + g = action[-1] + if mode == "zero_one": + action[-1] = max(-1.0, min(1.0, g * 2.0 - 1.0)) * -1.0 # [0,1] -> [-1,1], negative=open (issue #50) + elif mode == "pm_one": + action[-1] = max(-1.0, min(1.0, g)) + elif mode == "pm_one_flip": + action[-1] = max(-1.0, min(1.0, -g)) + else: + raise ValueError(f"Unknown gripper_mode={mode!r}. 
Use zero_one/pm_one/pm_one_flip.") + return action + + +def _infer_rotation_space(action_dim: int, rotation_space: str) -> str: + if rotation_space != "auto": + return rotation_space + return libero_rotation_space_from_action_dim(action_dim) + + +def _obs_to_pose(obs: dict[str, Any]) -> tuple[np.ndarray, np.ndarray]: + position = np.asarray(obs["robot0_eef_pos"], dtype=np.float32) + quat = np.asarray(obs["robot0_eef_quat"], dtype=np.float32) + rotation = R.from_quat(quat).as_matrix() + return position, rotation + + +def _anchored_action_to_delta( + anchored_action: np.ndarray, + base_pose: tuple[np.ndarray, np.ndarray], + current_pose: tuple[np.ndarray, np.ndarray], + rotation_space: str, +) -> np.ndarray: + anchored_translation = anchored_action[:3] + rotation_dim = anchored_action.shape[0] - 4 + anchored_rotation = anchored_action[3 : 3 + rotation_dim] + gripper = anchored_action[3 + rotation_dim : 4 + rotation_dim] + + base_pos, base_rot = base_pose + current_pos, current_rot = current_pose + + if rotation_space == "3d": + anchored_rot = R.from_rotvec(anchored_rotation).as_matrix() + elif rotation_space == "6d": + anchored_rot = _rotation_repr_to_mat(anchored_rotation, rotation_space) + elif rotation_space == "9d": + anchored_rot = anchored_rotation.reshape(3, 3) + else: + raise ValueError(f"Unsupported rotation_space={rotation_space!r}. Use 3d/6d/9d.") + target_rot = base_rot @ anchored_rot + target_pos = base_pos + base_rot @ anchored_translation + delta_pos = target_pos - current_pos + delta_rot = target_rot @ current_rot.T + delta_rotvec = R.from_matrix(delta_rot).as_rotvec() + + return np.concatenate([delta_pos, delta_rotvec, gripper], axis=0) + + +def _framewise_action_to_delta( + framewise_action: np.ndarray, + rotation_space: str, +) -> np.ndarray: + """Convert a frame-wise policy action to LIBERO's 7D simulator command. 
+ + Frame-wise actions are already per-step deltas in the LIBERO controller's + convention (see ``LiberoDataset`` with ``action_space='frame_wise_relative'``), + so the only conversion required is decoding the chosen rotation + representation back to a rotation vector. No anchor/current pose is needed. + """ + if rotation_space == "3d": + return framewise_action + + translation = framewise_action[:3] + rotation_dim = framewise_action.shape[0] - 4 + rotation_repr = framewise_action[3 : 3 + rotation_dim] + gripper = framewise_action[3 + rotation_dim : 4 + rotation_dim] + rotation_delta = _rotation_repr_to_mat(rotation_repr, rotation_space) + + delta_pos = translation + delta_rotvec = R.from_matrix(rotation_delta).as_rotvec() + return np.concatenate([delta_pos, delta_rotvec, gripper], axis=0) + + +def _run_episode( + env: Any, + client: ActionEnvironmentClient, + *, + cameras: list[str], + flip_images: bool, + rotate_180: bool, + action_horizon: int, + action_dim: int, + action_space: str, + rotation_space: str, + gripper_mode: str, + max_steps: int, + warmup_steps: int, + initial_state: np.ndarray | None, + gif_path: Path | None, + gif_fps: int, + comparison_path: Path | None = None, +) -> EpisodeResult: + env.reset() + if initial_state is not None: + obs = env.set_init_state(initial_state) + else: + obs = env.get_observation() + + action_queue: list[list[float]] = [] + base_pose: tuple[np.ndarray, np.ndarray] | None = None + step = 0 + success = False + gif_frames: list[Image.Image] = [] + action_log: list[list[float]] = [] + is_multi_view = len(cameras) > 1 + resolved_rotation_space = _infer_rotation_space(action_dim, rotation_space) + + comparison_windows: list[tuple[list[Image.Image], list[Image.Image]]] = [] + + def record_frame(current_obs: dict[str, Any]) -> None: + if gif_path is None: + return + image = _get_libero_image( + current_obs, + cameras[0], + flip_images=flip_images, + rotate_180=rotate_180, + ) + image = _ensure_uint8_image(image) + 
gif_frames.append(Image.fromarray(image).convert("RGB")) + + def capture_comparison_frame(current_obs: dict[str, Any]) -> Image.Image: + """Capture an env frame matching Action's input view (multi-view concatenated if applicable).""" + if is_multi_view: + imgs = _get_libero_images(current_obs, cameras, flip_images=flip_images, rotate_180=rotate_180) + concat = client.concatenate_images(imgs) + return Image.fromarray(_ensure_uint8_image(concat)).convert("RGB") + img = _get_libero_image(current_obs, cameras[0], flip_images=flip_images, rotate_180=rotate_180) + return Image.fromarray(_ensure_uint8_image(img)).convert("RGB") + + record_frame(obs) + + while step < max_steps: + if step < warmup_steps: + dummy = _get_libero_dummy_action() + obs, _, _, _ = env.step(dummy) + action_log.append(dummy) + step += 1 + record_frame(obs) + continue + + if not action_queue: + if is_multi_view: + observation_imgs = _get_libero_images( + obs, + cameras, + flip_images=flip_images, + rotate_180=rotate_180, + ) + result = client.predict(observation_imgs) + else: + observation_img = _get_libero_image( + obs, + cameras[0], + flip_images=flip_images, + rotate_180=rotate_180, + ) + result = client.predict(observation_img) + actions = result.get("action", []) + if not actions: + return EpisodeResult(False, step, "Empty action chunk from server", action_log) + action_queue = _select_action_chunk(actions, action_horizon) + + if comparison_path is not None: + action_video_b64 = result.get("video", []) + if action_video_b64: + action_frames = _decode_b64_frames(action_video_b64) + env_comparison_frames = [capture_comparison_frame(obs)] + comparison_windows.append((action_frames, env_comparison_frames)) + + if action_space == "relative": + base_pose = _obs_to_pose(obs) + + raw_action = _format_action(action_queue.pop(0), action_dim) + if action_space == "relative": + if base_pose is None: + raise RuntimeError("Missing base pose for relative action conversion") + current_pose = _obs_to_pose(obs) + 
action = _anchored_action_to_delta( + np.asarray(raw_action, dtype=np.float32), + base_pose, + current_pose, + resolved_rotation_space, + ) + action_list = action.tolist() + else: + action = _framewise_action_to_delta( + np.asarray(raw_action, dtype=np.float32), + resolved_rotation_space, + ) + action_list = action.tolist() + + # Map the model's gripper command to the env's [-1, 1] per the dataset convention. + action_list = _remap_gripper(action_list, gripper_mode) + + action_log.append(action_list) + obs, _, done, info = env.step(action_list) + step += 1 + record_frame(obs) + + if comparison_path is not None and comparison_windows: + comparison_windows[-1][1].append(capture_comparison_frame(obs)) + + if isinstance(info, dict) and info.get("success"): + success = True + break + if done: + success = True if not isinstance(info, dict) else bool(info.get("success", True)) + break + + if gif_path is not None: + _save_gif(gif_frames, gif_path, gif_fps) + if comparison_path is not None and comparison_windows: + _save_comparison_gif(comparison_windows, comparison_path, gif_fps) + return EpisodeResult(success, step, None, action_log) + + +def _load_initial_states( + task_suite: Any, + task_id: int, + *, + task_description: str, + initial_states_path: str, + episode_idx: int, +) -> np.ndarray | None: + default_initial_states = task_suite.get_task_init_states(task_id) + + if initial_states_path == "DEFAULT": + return np.array(default_initial_states[episode_idx]) + + with open(initial_states_path, "r", encoding="utf-8") as f: + all_initial_states = json.load(f) + + task_key = task_description.replace(" ", "_") + episode_key = f"demo_{episode_idx}" + if not all_initial_states[task_key][episode_key]["success"]: + return None + return np.array(all_initial_states[task_key][episode_key]["initial_state"]) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="LIBERO closed-loop evaluation via Action HTTP server") + parser.add_argument( + 
"--server_url", type=str, required=True, help="Base URL for Action server (e.g., http://host:8000)" + ) + parser.add_argument("--task_suite", type=str, default="libero_spatial", choices=sorted(TASK_MAX_STEPS.keys())) + parser.add_argument("--num_trials_per_task", type=int, default=10) + parser.add_argument("--task_ids", type=str, default="", help="Comma-separated task IDs to evaluate (default: all)") + parser.add_argument("--image_size", type=int, default=256, help="Model input image size") + parser.add_argument("--env_image_size", type=int, default=256, help="Environment render resolution") + parser.add_argument("--action_horizon", type=int, default=0, help="Actions to execute per request (0=full chunk)") + parser.add_argument("--action_dim", type=int, default=10, help="Action dimension for LIBERO") + parser.add_argument( + "--action_space", + type=str, + default="frame_wise_relative", + choices=["relative", "frame_wise_relative"], + help="Action space expected from the model (relative=anchored, frame_wise_relative=framewise deltas).", + ) + parser.add_argument( + "--rotation_space", + type=str, + default="auto", + choices=["auto", "3d", "6d", "9d"], + help="Rotation representation for anchored actions (auto infers from action_dim).", + ) + parser.add_argument( + "--gripper_mode", + type=str, + default="zero_one", + choices=["zero_one", "pm_one", "pm_one_flip"], + help="Gripper convention of the training data: 'zero_one' = [0,1] (NVIDIA " + "LIBERO_LeRobot_v3, mapped 1-2g); 'pm_one' = {-1,+1} (community lerobot/libero_*, " + "pass-through); 'pm_one_flip' = {-1,+1} with inverted sign.", + ) + parser.add_argument("--domain_name", type=str, default="libero") + parser.add_argument( + "--camera", + type=str, + default="agentview", + help="Camera(s) to use. Single camera: 'agentview' or 'wrist'. 
Multiple cameras: comma-separated, e.g., 'agentview,wrist'.", + ) + parser.add_argument("--flip_images", action="store_true", help="Flip images vertically before encoding") + parser.add_argument( + "--rotate_180", + action=argparse.BooleanOptionalAction, + default=True, + help="Rotate images by 180 degrees before encoding (default: True; pass --no-rotate-180 to disable)", + ) + parser.add_argument("--warmup_steps", type=int, default=10, help="Stabilization steps with dummy actions") + parser.add_argument("--max_steps", type=int, default=0, help="Override max steps per episode (0=default)") + parser.add_argument("--timeout", type=float, default=30.0, help="HTTP request timeout in seconds") + parser.add_argument("--wait_timeout", type=float, default=60.0, help="Seconds to wait for server health") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--save_gifs", action="store_true", help="Save per-episode GIFs of rendered frames") + parser.add_argument( + "--save_comparison", + action="store_true", + help="Save side-by-side comparison GIFs (Action prediction vs environment rollout)", + ) + parser.add_argument("--gif_fps", type=int, default=20, help="Frames per second for saved GIFs") + parser.add_argument( + "--mujoco_gl", + type=str, + default="auto", + choices=["auto", "egl", "osmesa", "glfw"], + help="MuJoCo GL backend (auto picks egl if /dev/dri is accessible, else osmesa).", + ) + parser.add_argument( + "--render_gpu_device_id", + type=int, + default=-1, + help="GPU device index for EGL rendering (-1 uses default device).", + ) + parser.add_argument( + "--initial_states_path", + type=str, + default="DEFAULT", + help='Path to initial states JSON. Use "DEFAULT" for benchmark defaults.', + ) + parser.add_argument( + "--num_envs", + type=int, + default=1, + help="Number of parallel LIBERO envs (SubprocVectorEnv). >1 runs trials in waves " + "with ONE batched /predict_batch per control step (~num_envs x faster). 
1 = serial.", + ) + parser.add_argument("--output_dir", type=str, default="", help="Directory to save evaluation summary JSON") + return parser.parse_args() + + +class _LiberoEnvFactory: + """Picklable env factory for SubprocVectorEnv under the spawn start method. + + spawn pickles each env_fn and re-imports this module in the child, so the + factory must be a top-level class (lambdas/closures are not picklable). The + child sets the GL backend and imports OffScreenRenderEnv locally so its EGL + context is created fresh in the worker process.""" + + def __init__( + self, + *, + bddl_file_name: str, + camera_heights: int, + camera_widths: int, + render_gpu_device_id: int, + mujoco_gl: str, + ) -> None: + self.bddl_file_name = bddl_file_name + self.camera_heights = camera_heights + self.camera_widths = camera_widths + self.render_gpu_device_id = render_gpu_device_id + self.mujoco_gl = mujoco_gl + + def __call__(self) -> Any: + # Resolve to a concrete GPU; -1 (auto) makes EGL device selection race/fail + # across spawned workers (EGLError / "'EGLGLContext' object has no attribute + # '_context'"). Set the GL backend + pin the EGL device BEFORE importing + # OffScreenRenderEnv (which dlopen's the GL stack at import). 
+ dev = self.render_gpu_device_id if self.render_gpu_device_id >= 0 else 0 + os.environ["MUJOCO_GL"] = self.mujoco_gl + if self.mujoco_gl == "egl": + os.environ["PYOPENGL_PLATFORM"] = "egl" + os.environ["MUJOCO_EGL_DEVICE_ID"] = str(dev) + os.environ["EGL_DEVICE_ID"] = str(dev) + elif self.mujoco_gl == "osmesa": + os.environ["PYOPENGL_PLATFORM"] = "osmesa" + from libero.libero.envs import OffScreenRenderEnv as _OffScreenRenderEnv + + return _OffScreenRenderEnv( + bddl_file_name=self.bddl_file_name, + camera_heights=self.camera_heights, + camera_widths=self.camera_widths, + render_gpu_device_id=dev, + ) + + +def _run_task_vectorized( + task: Any, + task_description: str, + *, + num_trials: int, + num_envs: int, + env_image_size: int, + seed: int, + render_gpu_device_id: int, + client: ActionEnvironmentClient, + cameras: list[str], + flip_images: bool, + rotate_180: bool, + action_horizon: int, + action_dim: int, + rotation_space: str, + gripper_mode: str, + max_steps: int, + warmup_steps: int, + init_states: list[np.ndarray | None], +) -> list[dict[str, Any]]: + """Run all `num_trials` of one task across `num_envs` parallel LIBERO envs + (SubprocVectorEnv), in waves. Each control step gathers obs from the ACTIVE + (not-done) envs, issues ONE batched /predict_batch, and steps all active envs; + done envs are masked out. Returns per-trial result dicts in trial order with the + same shape as the serial path's episode_results.""" + import multiprocessing as _mp + + from libero.libero.envs.venv import SubprocVectorEnv + + # LIBERO's SubprocVectorEnv defaults to the fork start method; forked children + # inherit the parent's already-dlopen'd EGL/GL state, which corrupts per-child + # render-context creation (EGLError / 'EGLGLContext' has no attribute '_context'). + # Force spawn so each env worker starts clean — exactly like the (working) serial + # single-process path. spawn pickles env_fns, so the factory below is picklable. 
+ try: + _mp.set_start_method("spawn", force=True) + except RuntimeError: # pragma: no cover - already set + pass + + resolved_rotation_space = _infer_rotation_space(action_dim, rotation_space) + bddl = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file) + + results: list[dict[str, Any]] = [None] * num_trials # type: ignore[list-item] + for t in range(num_trials): + if init_states[t] is None: + results[t] = { + "episode": t, + "success": False, + "steps": 0, + "error": "Skipped due to failed expert demo", + "elapsed_s": 0.0, + } + runnable = [t for t in range(num_trials) if init_states[t] is not None] + if not runnable: + return results + + n = min(num_envs, len(runnable)) + + mujoco_gl = os.environ.get("MUJOCO_GL", "egl") + env_fn = _LiberoEnvFactory( + bddl_file_name=bddl, + camera_heights=env_image_size, + camera_widths=env_image_size, + render_gpu_device_id=render_gpu_device_id, + mujoco_gl=mujoco_gl, + ) + venv = SubprocVectorEnv([env_fn for _ in range(n)]) + try: + venv.seed(seed) + for w0 in range(0, len(runnable), n): + wave = runnable[w0 : w0 + n] # trial indices for this wave + slots = list(range(len(wave))) # env slots in use + t_wave0 = time.perf_counter() + venv.reset(id=slots) + states = np.stack([np.asarray(init_states[t], dtype=np.float64) for t in wave]) + obs_arr = venv.set_init_state(states, id=slots) + obs_by_slot = {s: obs_arr[i] for i, s in enumerate(slots)} + done = {s: False for s in slots} + succ = {s: False for s in slots} + err: dict[int, str | None] = {s: None for s in slots} + nsteps = {s: max_steps for s in slots} + step = 0 + + for _ in range(warmup_steps): + act = np.stack([_get_libero_dummy_action() for _ in slots]) + obs_arr, _, _, _ = venv.step(act, id=slots) + for i, s in enumerate(slots): + obs_by_slot[s] = obs_arr[i] + step += 1 + + while step < max_steps: + active = [s for s in slots if not done[s]] + if not active: + break + obs_batch = [ + _get_libero_images(obs_by_slot[s], cameras, 
flip_images=flip_images, rotate_180=rotate_180) + for s in active + ] + try: + chunks = client.predict_batch(obs_batch) + except Exception as e: # noqa: BLE001 + for s in active: + done[s] = True + err[s] = f"server error: {e}" + nsteps[s] = step + break + if not chunks or len(chunks) != len(active): + for s in active: + done[s] = True + err[s] = "bad batch response from server" + nsteps[s] = step + break + chunk_by_slot = {s: chunks[k] for k, s in enumerate(active)} + horizon = action_horizon if action_horizon > 0 else len(chunks[0]) + for h in range(horizon): + cur = [s for s in slots if not done[s]] + if not cur or step >= max_steps: + break + env_actions = [] + for s in cur: + raw = _format_action(chunk_by_slot[s][h], action_dim) + a = _framewise_action_to_delta(np.asarray(raw, dtype=np.float32), resolved_rotation_space) + env_actions.append(_remap_gripper(a.tolist(), gripper_mode)) + obs_arr, _, d, info = venv.step(np.stack(env_actions), id=cur) + step += 1 + for i, s in enumerate(cur): + obs_by_slot[s] = obs_arr[i] + di = bool(d[i]) + ii = info[i] if isinstance(info, (list, np.ndarray)) else info + is_succ = bool(ii.get("success")) if isinstance(ii, dict) else False + if is_succ: + done[s], succ[s], nsteps[s] = True, True, step + elif di: + # mirror serial: done w/o explicit success defaults to success + done[s] = True + succ[s] = ii.get("success", True) if isinstance(ii, dict) else True + nsteps[s] = step + per_ep_elapsed = round((time.perf_counter() - t_wave0) / max(1, len(wave)), 3) + for s, t in zip(slots, wave): + results[t] = { + "episode": t, + "success": bool(succ[s]), + "steps": int(nsteps[s]), + "error": err[s], + "elapsed_s": per_ep_elapsed, + } + finally: + try: + venv.close() + except Exception: # noqa: BLE001 + pass + return results + + +def main() -> None: + args = _parse_args() + random.seed(args.seed) + np.random.seed(args.seed) + + if args.save_gifs and not args.output_dir: + raise ValueError("--save_gifs requires --output_dir to be set") + 
if args.save_comparison and not args.output_dir: + raise ValueError("--save_comparison requires --output_dir to be set") + + # Parse cameras from comma-separated string + cameras = [c.strip() for c in args.camera.split(",") if c.strip()] + if not cameras: + raise ValueError("At least one camera must be specified") + for cam in cameras: + if cam not in ("agentview", "wrist"): + raise ValueError(f"Unsupported camera={cam!r}. Use 'agentview' or 'wrist'.") + + mujoco_backend = _configure_mujoco_env(args.mujoco_gl) + _import_libero() + + client = ActionEnvironmentClient( + server_url=args.server_url, + domain_name=args.domain_name, + prompt="", + image_size=args.image_size, + timeout=args.timeout, + ) + print(f"MuJoCo GL backend: {mujoco_backend}", flush=True) + print("Waiting for model server...", flush=True) + _wait_for_server(client, args.wait_timeout) + print(f"Connected to model server: {client.get_info()}", flush=True) + + benchmark_dict = benchmark.get_benchmark_dict() + task_suite = benchmark_dict[args.task_suite]() + num_tasks = int(task_suite.n_tasks) + + if args.task_ids: + selected_task_ids = [int(t) for t in args.task_ids.split(",") if t.strip()] + else: + selected_task_ids = list(range(num_tasks)) + + max_steps = args.max_steps if args.max_steps > 0 else TASK_MAX_STEPS[args.task_suite] + + total_episodes = 0 + total_successes = 0 + task_results: list[dict[str, Any]] = [] + + output_dir = Path(args.output_dir) if args.output_dir else None + gif_root = output_dir / "gifs" if output_dir and args.save_gifs else None + comparison_root = output_dir / "comparisons" if output_dir and args.save_comparison else None + + for task_id in selected_task_ids: + task = task_suite.get_task(task_id) + + # ---- Vectorized path: N parallel envs + one batched /predict_batch per step ---- + if args.num_envs > 1: + task_description = str(task.language) + client.prompt = _augment_task_prompt_with_viewpoint(task_description, cameras) + init_states = [ + _load_initial_states( + 
task_suite, + task_id, + task_description=task_description, + initial_states_path=args.initial_states_path, + episode_idx=e, + ) + for e in range(args.num_trials_per_task) + ] + episode_results = _run_task_vectorized( + task, + task_description, + num_trials=args.num_trials_per_task, + num_envs=args.num_envs, + env_image_size=args.env_image_size, + seed=args.seed, + render_gpu_device_id=args.render_gpu_device_id, + client=client, + cameras=cameras, + flip_images=args.flip_images, + rotate_180=args.rotate_180, + action_horizon=args.action_horizon, + action_dim=args.action_dim, + rotation_space=args.rotation_space, + gripper_mode=args.gripper_mode, + max_steps=max_steps, + warmup_steps=args.warmup_steps, + init_states=init_states, + ) + task_episodes = 0 + task_successes = 0 + for er in episode_results: + task_episodes += 1 + total_episodes += 1 + if er["success"]: + task_successes += 1 + total_successes += 1 + print( + f"Task {task_id} | Episode {er['episode'] + 1}/{args.num_trials_per_task} | " + f"success={er['success']} steps={er['steps']} elapsed_s={er['elapsed_s']:.1f} | " + f"task SR {task_successes}/{task_episodes} ({100.0 * task_successes / max(1, task_episodes):.1f}%) | " + f"overall SR {total_successes}/{total_episodes} " + f"({100.0 * total_successes / max(1, total_episodes):.1f}%)", + flush=True, + ) + task_success_rate = float(task_successes) / float(task_episodes) if task_episodes > 0 else 0.0 + task_results.append( + { + "task_id": task_id, + "task_description": task_description, + "episodes": task_episodes, + "successes": task_successes, + "success_rate": task_success_rate, + "episode_results": episode_results, + } + ) + print( + f"Task {task_id} summary: {task_successes}/{task_episodes} ({task_success_rate * 100:.1f}%)", + flush=True, + ) + continue + + env, task_description = _get_libero_env( + task, + resolution=args.env_image_size, + seed=args.seed, + render_gpu_device_id=args.render_gpu_device_id, + ) + + task_episodes = 0 + task_successes = 0 + 
episode_results: list[dict[str, Any]] = [] + + for episode_idx in range(args.num_trials_per_task): + episode_t0 = time.perf_counter() + client.prompt = _augment_task_prompt_with_viewpoint(task_description, cameras) + initial_state = _load_initial_states( + task_suite, + task_id, + task_description=task_description, + initial_states_path=args.initial_states_path, + episode_idx=episode_idx, + ) + if initial_state is None: + episode_elapsed_s = time.perf_counter() - episode_t0 + episode_results.append( + { + "episode": episode_idx, + "success": False, + "steps": 0, + "error": "Skipped due to failed expert demo", + "elapsed_s": round(episode_elapsed_s, 3), + } + ) + print( + f"Task {task_id} | Episode {episode_idx + 1}/{args.num_trials_per_task} | " + "success=False steps=0 " + f"elapsed_s={episode_elapsed_s:.1f} " + "error='Skipped due to failed expert demo'", + flush=True, + ) + continue + + gif_path = ( + gif_root / f"task_{task_id:03d}" / f"episode_{episode_idx:03d}.gif" if gif_root is not None else None + ) + comparison_path = ( + comparison_root / f"task_{task_id:03d}" / f"episode_{episode_idx:03d}.gif" + if comparison_root is not None + else None + ) + try: + result = _run_episode( + env, + client, + cameras=cameras, + flip_images=args.flip_images, + rotate_180=args.rotate_180, + action_horizon=args.action_horizon, + action_dim=args.action_dim, + action_space=args.action_space, + rotation_space=args.rotation_space, + gripper_mode=args.gripper_mode, + max_steps=max_steps, + warmup_steps=args.warmup_steps, + initial_state=initial_state, + gif_path=gif_path, + gif_fps=args.gif_fps, + comparison_path=comparison_path, + ) + except Exception as exc: + result = EpisodeResult(False, 0, str(exc), []) + episode_elapsed_s = time.perf_counter() - episode_t0 + + task_episodes += 1 + total_episodes += 1 + if result.success: + task_successes += 1 + total_successes += 1 + + episode_results.append( + { + "episode": episode_idx, + "success": result.success, + "steps": 
result.steps, + "error": result.error, + "elapsed_s": round(episode_elapsed_s, 3), + } + ) + + # Save per-episode action log as JSON + if output_dir is not None and result.actions: + action_log_dir = output_dir / "actions" / f"task_{task_id:03d}" + action_log_dir.mkdir(parents=True, exist_ok=True) + action_log_path = action_log_dir / f"episode_{episode_idx:03d}.json" + action_log_path.write_text( + json.dumps(result.actions, indent=2), + encoding="utf-8", + ) + + client.notify_next_episode() + + print( + f"Task {task_id} | Episode {episode_idx + 1}/{args.num_trials_per_task} | " + f"success={result.success} steps={result.steps} elapsed_s={episode_elapsed_s:.1f} | " + f"task SR {task_successes}/{task_episodes} ({100.0 * task_successes / max(1, task_episodes):.1f}%) | " + f"overall SR {total_successes}/{total_episodes} ({100.0 * total_successes / max(1, total_episodes):.1f}%)", + flush=True, + ) + + task_success_rate = float(task_successes) / float(task_episodes) if task_episodes > 0 else 0.0 + task_results.append( + { + "task_id": task_id, + "task_description": task_description, + "episodes": task_episodes, + "successes": task_successes, + "success_rate": task_success_rate, + "episode_results": episode_results, + } + ) + print( + f"Task {task_id} summary: {task_successes}/{task_episodes} ({task_success_rate * 100:.1f}%)", + flush=True, + ) + # Close the env (and its EGL/MuJoCo render context) before the next task. + # Leaving it open leaks one EGL context per task and hangs after ~8 tasks. 
+ try: + env.close() + except Exception: + pass + + overall_success_rate = float(total_successes) / float(total_episodes) if total_episodes > 0 else 0.0 + summary = { + "task_suite": args.task_suite, + "total_episodes": total_episodes, + "total_successes": total_successes, + "overall_success_rate": overall_success_rate, + "num_trials_per_task": args.num_trials_per_task, + "selected_task_ids": selected_task_ids, + "action_space": args.action_space, + "rotation_space": _infer_rotation_space(args.action_dim, args.rotation_space), + "action_dim": args.action_dim, + "task_results": task_results, + } + + print( + f"Overall success rate: {total_successes}/{total_episodes} ({overall_success_rate * 100:.1f}%)", + flush=True, + ) + + if output_dir is not None: + output_dir.mkdir(parents=True, exist_ok=True) + summary_path = output_dir / "summary.json" + summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8") + print(f"Saved summary to {summary_path}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/cosmos_framework/utils/vfm/model_loader.py b/cosmos_framework/utils/vfm/model_loader.py index 6e6a0dd..51140b3 100644 --- a/cosmos_framework/utils/vfm/model_loader.py +++ b/cosmos_framework/utils/vfm/model_loader.py @@ -252,10 +252,17 @@ def _load_model( keys_to_skip_loading=keys_to_skip_loading or [], ) + # Single-rank load (e.g. the action-policy inference server): force no_dist so + # ``dcp.load`` skips the collective ``gather_object`` over the load plan, which + # pickles the plan and can fail on training/EMA DCPs. Multi-rank loads keep the + # default distributed path. 
+ no_dist = not (dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1) + dcp.load( state_dict=state_dict, storage_reader=storage_reader, planner=load_planner, + no_dist=no_dist, ) log.info(f"Successfully loaded model from {checkpoint_path}") diff --git a/docs/action_policy_libero_sft.md b/docs/action_policy_libero_sft.md new file mode 100644 index 0000000..660de11 --- /dev/null +++ b/docs/action_policy_libero_sft.md @@ -0,0 +1,109 @@ +# Cosmos3-Nano LIBERO-10 action-policy SFT + +Full SFT of the public `nvidia/Cosmos3-Nano` base into a LIBERO-10 action +policy: vision + language in, action chunks out. + +| Piece | Path | +| ---------------- | ----------------------------------------------------------------------------------------------- | +| Dataset | `cosmos_framework/data/vfm/action/datasets/libero_lerobot_dataset.py` (`LIBEROLeRobotDataset`) | +| SFT wrapper | `get_action_libero_sft_dataset` in `.../datasets/action_sft_dataset.py` | +| Norm stats | `.../datasets/stats/libero_native_frame_wise_relative_rot6d.json` | +| Experiment | `cosmos_framework/configs/base/experiment/action/posttrain_config/action_policy_libero_nano.py` | +| Run TOML | `examples/toml/sft_config/action_policy_libero_repro.toml` | +| Launch | `examples/launch_sft_action_policy_libero.sh` | +| Inference server | `cosmos_framework/scripts/action_policy_server_libero.py` | +| Closed-loop eval | `cosmos_framework/simulation/libero/closed_loop_eval.py` | + +## 1. Data + +`LIBEROLeRobotDataset` reads a local LeRobot dir (`LIBERO_ROOT`). Use the 20 FPS +[`nvidia/LIBERO_LeRobot_v3`](https://huggingface.co/datasets/nvidia/LIBERO_LeRobot_v3), +which the bundled `quantile_rot` stats and the 20 Hz eval assume. 
Train on +`libero_10` alone: + +```bash +hf download nvidia/LIBERO_LeRobot_v3 --repo-type dataset \ + --include 'libero_10/**' --local-dir /LIBERO_LeRobot_v3 +export LIBERO_ROOT=/LIBERO_LeRobot_v3/libero_10 +``` + +Actions are `frame_wise_relative` rot6d (10D = pos 3 + rot6d 6 + gripper 1), +`concat_view` (third-person + wrist, each 256×256 → 256×512), `quantile_rot` +normalized. The pipeline snaps the 256×512 concat to a 192×320 model canvas; the +eval server reproduces the same snap (§4). + +## 2. Train + +```bash +export LD_LIBRARY_PATH='' # NGC container: avoid torch._C import error +export LIBERO_ROOT=/path/to/libero_10_lerobot +export BASE_CHECKPOINT_PATH=/path/to/Cosmos3-Nano +export WAN_VAE_PATH=/path/to/wan22_vae/Wan2.2_VAE.pth +export IMAGINAIRE_OUTPUT_ROOT=/path/to/output_root + +bash examples/launch_sft_action_policy_libero.sh # HSDP 2x8; set NNODES/NODE_RANK/MASTER_ADDR per node +``` + +Recipe knobs live in `action_policy_libero_nano`; the TOML sets run-level scalars +(lr 5e-5, warmup 500, cycle 16000, `save_iter=500`, HSDP 2x8). Global batch is +2048 = `max_samples_per_batch` 128 × 16 ranks × grad_accum 1. + +## 3. Closed-loop eval + +Start the policy server on a **trained** checkpoint (the base DCP has no action +heads), then run the LIBERO simulator client against it. 
+ +```bash +python -m cosmos_framework.scripts.action_policy_server_libero \ + --experiment action_policy_libero_nano \ + --experiment-overrides "model.config.tokenizer.vae_path=$WAN_VAE_PATH" \ + --checkpoint-path /checkpoints/iter_000001500 \ + --action-normalization quantile_rot \ + --action-stats-path cosmos_framework/data/vfm/action/datasets/stats/libero_native_frame_wise_relative_rot6d.json \ + --raw-action-dim 10 --fps 20 --port 8000 +``` + +The LIBERO sim needs a separate venv (robosuite/mujoco pins conflict with the +training env): + +```bash +# Optional — only on a headless container without working GPU EGL: +# export NVIDIA_DRIVER_CAPABILITIES=all +# apt-get install -y libegl1 libglvnd0 libgl1 libglib2.0-0 ffmpeg +# mkdir -p /usr/share/glvnd/egl_vendor.d +# echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libEGL_nvidia.so.0"}}' \ +# > /usr/share/glvnd/egl_vendor.d/10_nvidia.json + +uv venv --python 3.10 .libenv && VV=.libenv/bin/python +git clone https://github.com/Lifelong-Robot-Learning/LIBERO.git && \ + uv pip install -p $VV -e LIBERO -r LIBERO/requirements.txt +uv pip install -p $VV "robosuite==1.4.1" "mujoco==2.3.7" "torch<2.6" loguru requests scipy pillow numpy +mkdir -p ~/.libero && touch ~/.libero/config.yaml +RS=$($VV -c "import robosuite,os;print(os.path.dirname(robosuite.__file__))"); $VV "$RS/scripts/setup_macros.py" +$VV -c "from libero.libero import set_libero_default_path; set_libero_default_path()" + +MUJOCO_GL=egl PYTHONPATH=$PWD:$PWD/LIBERO $VV \ + cosmos_framework/simulation/libero/closed_loop_eval.py \ + --server_url http://localhost:8000 \ + --task_suite libero_10 --num_trials_per_task 50 --num_envs 8 \ + --camera agentview,wrist --image_size 256 \ + --action_space frame_wise_relative --rotation_space 6d --action_dim 10 \ + --output_dir results/libero_closed_loop_10 +``` + +## 4. 
Heads-up + +- **Lower-memory GPUs** — reduce the per-rank batch: + `--opts dataloader_train.max_samples_per_batch=64` (scale `replicate` to keep + global batch 2048). + +Eval parity — the client/server already handle these; verify if accuracy is low: + +- **Concat layout** — run with `--camera agentview,wrist --image_size 256` so the + 256×512 concat matches training (the server snaps it to 192×320 identically). +- **Gripper** — model emits `[0, 1]`; the env wants `[-1, 1]` (negative = open). + The client applies `1 − 2·g`; flip the sign if the gripper never opens. +- **Image orientation** — sim frames are rotated 180° vs training; the client + rotates them back. +- **Normalization** — start the server with `--action-normalization quantile_rot` + and the bundled rot6d stats, or actions come out at the wrong scale. diff --git a/examples/launch_sft_action_policy_libero.sh b/examples/launch_sft_action_policy_libero.sh new file mode 100755 index 0000000..29a9a16 --- /dev/null +++ b/examples/launch_sft_action_policy_libero.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +# Structured-TOML launch for action_policy_libero_nano — Cosmos3-Nano LIBERO +# action-policy SFT (HSDP, full SFT). Drives cosmos_framework.scripts.train +# against examples/toml/sft_config/action_policy_libero_repro.toml. +# +# Point LIBERO_ROOT at the libero_10 suite ONLY. Use the 20 FPS +# nvidia/LIBERO_LeRobot_v3. The default recipe is HSDP 2x8 (global batch 2048); +# set NNODES/NODE_RANK/MASTER_ADDR per node. +# See docs/action_policy_libero_sft.md. +# +# Required env vars: +# LIBERO_ROOT local LIBERO-10 LeRobot dataset dir, e.g. 
/path/to/data/libero_10 (no default) +# Optional env vars (defaults below; override to relocate data/checkpoints): +# BASE_CHECKPOINT_PATH default: examples/checkpoints/Cosmos3-Nano +# WAN_VAE_PATH default: examples/checkpoints/wan22_vae/Wan2.2_VAE.pth +# HF_TOKEN if any tokenizer download requires gated HF access +# OUTPUT_ROOT default: outputs/train +# +# Pre-sync the 20 FPS suite once: +# hf download nvidia/LIBERO_LeRobot_v3 --repo-type dataset --include 'libero_10/**' --local-dir /path/to/data +# export LIBERO_ROOT=/path/to/data/libero_10 +# +# Usage (HSDP 2x8; set NNODES/NODE_RANK/MASTER_ADDR per node): +# LIBERO_ROOT=/path/to/data/libero_10 bash examples/launch_sft_action_policy_libero.sh + +TOML_FILE="examples/toml/sft_config/action_policy_libero_repro.toml" +: "${BASE_CHECKPOINT_PATH:=examples/checkpoints/Cosmos3-Nano}" + +# LIBEROLeRobotDataset reads ${oc.env:LIBERO_ROOT} directly (a LOCAL LeRobot dir); +# export it so torchrun (launched in this shell) inherits it. +export LIBERO_ROOT="${LIBERO_ROOT:-}" + +EXTRA_DATASET_CHECK='[[ -f "$LIBERO_ROOT/meta/info.json" ]] || { echo "ERROR: LIBERO_ROOT must be a local LeRobot dir containing meta/info.json (got: '\''$LIBERO_ROOT'\''). Pre-sync: hf download nvidia/LIBERO_LeRobot_v3 --repo-type dataset --include '\''libero_10/**'\'' --local-dir (then LIBERO_ROOT=/libero_10). See docs/action_policy_libero_sft.md" >&2; exit 1; }' + +# Extra Hydra overrides from the environment: a space-separated string word-split into +# the TAIL_OVERRIDES array. An exported string survives `bash examples/launch_sft_action_policy_libero.sh` (a child +# process), unlike a TAIL_OVERRIDES array set in your shell. Use it for smoke runs, +# e.g. EXTRA_TAIL_OVERRIDES="trainer.max_iter=5 job.wandb_mode=offline". 
+TAIL_OVERRIDES=( + ${EXTRA_TAIL_OVERRIDES:-} +) + +source "$(dirname "${BASH_SOURCE[0]}")/_sft_launcher_common.sh" diff --git a/examples/toml/sft_config/action_policy_libero_repro.toml b/examples/toml/sft_config/action_policy_libero_repro.toml new file mode 100644 index 0000000..a0c49c7 --- /dev/null +++ b/examples/toml/sft_config/action_policy_libero_repro.toml @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +# LIBERO-10 action-policy SFT run config for the `action_policy_libero_nano` +# experiment. Train on libero_10 alone (HSDP 2x8, global batch 2048). +# Env: LIBERO_ROOT, BASE_CHECKPOINT_PATH, WAN_VAE_PATH, IMAGINAIRE_OUTPUT_ROOT. +# See docs/action_policy_libero_sft.md. + +[job] +task = "vfm" +experiment = "action_policy_libero_nano" +project = "cosmos3_action_libero" +group = "action_sft" +name = "action_policy_libero_repro" +wandb_mode = "online" + +[model] +precision = "bfloat16" +max_num_tokens_after_packing = 74000 + +[model.parallelism] +data_parallel_shard_degree = 8 +data_parallel_replicate_degree = 2 # HSDP 2x8 = 16 ranks (2 nodes); minimum for gbs 2048 at grad_accum 1 + +[model.activation_checkpointing] +mode = "selective" +save_ops_regex = ["fmha"] + +[model.tokenizer] +vae_path = "${oc.env:WAN_VAE_PATH}" + +[optimizer] +lr = 5.0e-05 + +[scheduler] +cycle_lengths = [16000] +warm_up_steps = [500] + +[trainer] +max_iter = 2000 +logging_iter = 50 +grad_accum_iter = 1 # global batch = max_samples 128 x (shard 8 x replicate 2) x 1 = 2048 + +[checkpoint] +load_path = "${oc.env:BASE_CHECKPOINT_PATH}" +save_iter = 500