diff --git a/cosmos_framework/configs/base/config.py b/cosmos_framework/configs/base/config.py
index e766c5c..5ac2b41 100644
--- a/cosmos_framework/configs/base/config.py
+++ b/cosmos_framework/configs/base/config.py
@@ -97,4 +97,5 @@ def make_config() -> Config:
import cosmos_framework.configs.base.experiment.sft.vision_sft_nano # noqa: F401
import cosmos_framework.configs.base.experiment.sft.vision_sft_super # noqa: F401
import cosmos_framework.configs.base.experiment.action.posttrain_config.action_policy_droid_nano # noqa: F401
+ import cosmos_framework.configs.base.experiment.action.posttrain_config.action_policy_libero_nano # noqa: F401
return c
diff --git a/cosmos_framework/configs/base/experiment/action/posttrain_config/action_policy_libero_nano.py b/cosmos_framework/configs/base/experiment/action/posttrain_config/action_policy_libero_nano.py
new file mode 100644
index 0000000..b05d3ea
--- /dev/null
+++ b/cosmos_framework/configs/base/experiment/action/posttrain_config/action_policy_libero_nano.py
@@ -0,0 +1,233 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: OpenMDW-1.1
+
+"""``action_policy_libero_nano`` — Cosmos3-Nano LIBERO-10 action-policy SFT recipe.
+
+Feeds ``LIBEROLeRobotDataset`` (frame-wise-relative rot6d, ``quantile_rot``,
+concat_view third-person + wrist) and trains the generation + action heads from
+the public ``nvidia/Cosmos3-Nano`` base. Train on ``libero_10`` alone
+(``LIBERO_ROOT``). See docs/action_policy_libero_sft.md.
+"""
+
+import copy
+
+from hydra.core.config_store import ConfigStore
+
+from cosmos_framework.utils.lazy_config import LazyCall as L
+from cosmos_framework.utils.lazy_config import LazyDict
+
+from cosmos_framework.configs.base.experiment.sft.models.nano_model_config import NANO_MODEL_CONFIG
+from cosmos_framework.data.vfm.joint_dataloader import (
+ PackingDataLoader,
+ RankPartitionedDataLoader,
+)
+from cosmos_framework.data.vfm.action.datasets.action_sft_dataset import get_action_libero_sft_dataset
+
+cs = ConfigStore.instance()
+
+
+def _action_policy_libero_nano_model_config() -> dict:
+ """LIBERO model config: capped packed tokens, selective activation
+ checkpointing, fresh diffusion-expert init, 10x vision flow-matching loss.
+ Keep ``encode_exact_durations=[17, 61, 73]`` to match the Cosmos3-Nano base."""
+ cfg = copy.deepcopy(NANO_MODEL_CONFIG) # action_gen=True, max_action_dim=64
+ # Cap the packed sequence. Uncapped (-1) + a large max_samples_per_batch packs
+ # one very long sequence and OOMs even on H200; 74000 keeps the GA-validated bound.
+ cfg["max_num_tokens_after_packing"] = 74000
+ cfg["activation_checkpointing"]["mode"] = "selective"
+ cfg["diffusion_expert_config"]["load_weights_from_pretrained"] = False
+ cfg["rectified_flow_training_config"]["loss_scale"] = 10.0
+ cfg["rectified_flow_training_config"]["image_loss_scale"] = None
+ cfg["tokenizer"]["encode_exact_durations"] = [17, 61, 73] # match Cosmos3 base + reference SFT (do NOT reduce)
+ return cfg
+
+
+action_policy_libero_nano = LazyDict(
+ dict(
+ defaults=[
+ {"override /model": "mot_fsdp"},
+ {"override /data_train": None},
+ {"override /data_val": None},
+ # FusedAdam with fp32 master_weights + eps 1e-8 (bf16 params + eps 1e-6
+ # diverged on the action loss).
+ {"override /optimizer": "fusedadamw"},
+ {"override /scheduler": "lambdalinear"}, # linear LR decay
+ {"override /checkpoint": "s3"},
+ {
+ "override /callbacks": [
+ "basic",
+ "optimization",
+ "job_monitor",
+ ]
+ },
+ {"override /ema": "power"},
+ {"override /tokenizer": "wan2pt2_tokenizer"},
+ {"override /sound_tokenizer": None},
+ {"override /vlm_config": None},
+ {"override /ckpt_type": "dcp"},
+ "_self_",
+ ],
+ job=dict(
+ project="cosmos3",
+ group="action_sft",
+ name="action_policy_libero_nano",
+ wandb_mode="disabled",
+ ),
+ model=dict(
+ config=_action_policy_libero_nano_model_config(),
+ ),
+ optimizer=dict(
+ betas=[0.9, 0.99],
+ eps=1.0e-08,
+ fused=True, # popped by build_optimizer for FusedAdam (fused by construction)
+ # Train the generation + action heads.
+ keys_to_select=[
+ "moe_gen",
+ "time_embedder",
+ "vae2llm",
+ "llm2vae",
+ "action2llm",
+ "llm2action",
+ "action_modality_embed",
+ ],
+ lr=5.0e-05,
+ lr_multipliers={
+ "action2llm": 5.0,
+ "llm2action": 5.0,
+ "action_modality_embed": 5.0,
+ },
+ optimizer_type="FusedAdam",
+ weight_decay=0.05,
+ ),
+ scheduler=dict(
+ lr_scheduler_type="LambdaLinear",
+ cycle_lengths=[100], # smoke: 100 iters (real run sets via TOML, GA=10000)
+ f_max=[1.0],
+ f_min=[0.0],
+ f_start=[1.0e-06],
+ verbosity_interval=0,
+ warm_up_steps=[0], # smoke (real run sets via TOML, GA=2000)
+ ),
+ trainer=dict(
+ distributed_parallelism="fsdp",
+ grad_accum_iter=1, # real run sets via TOML (GA=2)
+ logging_iter=1,
+ max_iter=100, # smoke
+ max_val_iter=None,
+ run_validation=False,
+ run_validation_on_start=False,
+ save_zero_checkpoint=False,
+ seed=42,
+ timeout_period=999999999,
+ validation_iter=100,
+ compile_config=dict(recompile_limit=8, use_duck_shape=False),
+ cudnn=dict(benchmark=True, deterministic=False),
+ ddp=dict(broadcast_buffers=True, find_unused_parameters=False, static_graph=True),
+ grad_scaler_args=dict(enabled=False),
+ callbacks=dict(
+ dataloader_speed=dict(every_n=100, save_s3=False, step_size=1),
+ device_monitor=dict(
+ every_n=200, log_memory_detail=True, save_s3=False, step_size=1, upload_every_n_mul=5
+ ),
+ grad_clip=dict(clip_norm=1.0, force_finite=True),
+ heart_beat=dict(every_n=200, save_s3=False, step_size=1, update_interval_in_minute=20),
+ iter_speed=dict(every_n=1, hit_thres=50, save_s3=False, save_s3_every_log_n=500),
+ low_precision=dict(update_iter=1),
+ manual_gc=dict(every_n=5, gc_level=1, warm_up=1),
+ param_count=dict(save_s3=False),
+ skip_nan_step=dict(max_consecutive_nan=100),
+ training_stats=dict(log_freq=100),
+ ),
+ ),
+ checkpoint=dict(
+ broadcast_via_filesystem=False,
+ dcp_async_mode_enabled=False,
+ enable_gcs_patch_in_boto3=True,
+ keys_not_to_resume=[],
+ # Skip net_ema (EMA warm-starts from net, see dcp.py) and the action
+ # heads, so they init fresh from the base (the public Cosmos3-Nano base
+ # has no LIBERO-trained action heads).
+ keys_to_skip_loading=[
+ "net_ema.",
+ "action2llm",
+ "llm2action",
+ "action_modality_embed",
+ "action_pos_embed",
+ ],
+ load_ema_to_reg=False,
+ load_path="???", # Cosmos3-Nano DCP dir; supply via TOML/env
+ load_training_state=False,
+ only_load_scheduler_state=False,
+ save_iter=100,
+ strict_resume=False, # base init: tolerate key set differences
+ verbose=True,
+ hf_export=dict(
+ enabled=False,
+ export_every_n=1,
+ hf_repo_id=None,
+ upload_to_object_store=dict(bucket="", credentials="", enabled=False),
+ ),
+ jit=dict(device="cuda", dtype="bfloat16", enabled=False, input_shape=None, strict=True),
+ load_from_object_store=dict(bucket="", credentials="", enabled=False),
+ save_to_object_store=dict(bucket="", credentials="", enabled=False),
+ ),
+ dataloader_train=L(PackingDataLoader)(
+ audio_sample_rate=48000,
+ dataset_name="action_libero",
+ max_samples_per_batch=128, # peak-mem bound (256 OOMs on H200); global = 128 x DP8 x grad_accum2 = 2048
+ max_sequence_length=None, # None disables token packing (TOML can't express null)
+ patch_spatial=2,
+ sound_latent_fps=0,
+ tokenizer_spatial_compression_factor=16,
+ tokenizer_temporal_compression_factor=4,
+ dataloader=L(RankPartitionedDataLoader)(
+ batch_size=1,
+ in_order=False,
+ num_workers=4,
+ persistent_workers=True,
+ pin_memory=True,
+ prefetch_factor=4,
+ sampler=None,
+ # Shuffling is handled by the dataset (iterable_shuffle=True below):
+ # ActionIterableShuffleDataset streams rank x worker-sharded, episode-order-
+ # shuffled, sequential-within-episode.
+ datasets=dict(
+ libero=dict(
+ ratio=1,
+ dataset=L(get_action_libero_sft_dataset)(
+ # Local LeRobot dir for the libero_10 suite ONLY. Use the
+ # 20 FPS nvidia/LIBERO_LeRobot_v3 (matches the bundled stats + 20 Hz eval):
+ # hf download nvidia/LIBERO_LeRobot_v3 --repo-type dataset \
+ # --include 'libero_10/**' --local-dir <dir>
+ # LIBERO_ROOT=<dir>/libero_10
+ root="${oc.env:LIBERO_ROOT}",
+ fps=20, # metadata only (FPS-agnostic loader reads native fps from info.json)
+ chunk_length=16,
+ image_size=256, # concat_view -> 256x512
+ mode="policy",
+ camera_mode="concat_view",
+ action_space="frame_wise_relative",
+ rotation_space="6d",
+ pose_coordinate_frame="native",
+ action_normalization="quantile_rot",
+ val_ratio=0.01,
+ iterable_shuffle=True,
+ episode_shuffle_seed=42,
+ resolution=None,
+ max_action_dim="${model.config.max_action_dim}",
+ cfg_dropout_rate=0.1,
+ tokenizer_config="${model.config.vlm_config.tokenizer}",
+ ),
+ ),
+ ),
+ ),
+ ),
+ dataloader_val=None,
+ upload_reproducible_setup=False,
+ ),
+ flags={"allow_objects": True},
+)
+
+
+for _item in [action_policy_libero_nano]:
+ _name = [k for k, v in globals().items() if v is _item][0]
+ cs.store(group="experiment", package="_global_", name=_name, node=_item)
diff --git a/cosmos_framework/data/vfm/action/datasets/__init__.py b/cosmos_framework/data/vfm/action/datasets/__init__.py
index 0b01e6b..6365693 100644
--- a/cosmos_framework/data/vfm/action/datasets/__init__.py
+++ b/cosmos_framework/data/vfm/action/datasets/__init__.py
@@ -12,6 +12,7 @@
from cosmos_framework.data.vfm.action.datasets.base_dataset import ActionBaseDataset
from cosmos_framework.data.vfm.action.datasets.bridge_orig_lerobot_dataset import BridgeOrigLeRobotDataset
from cosmos_framework.data.vfm.action.datasets.droid_lerobot_dataset import DROIDLeRobotDataset
+from cosmos_framework.data.vfm.action.datasets.libero_lerobot_dataset import LIBEROLeRobotDataset
from cosmos_framework.data.vfm.action.datasets.robomind_franka_dataset import RoboMINDFrankaDataset
__all__ = [
@@ -19,5 +20,6 @@
"AgiBotWorldBetaLeRobotDataset",
"BridgeOrigLeRobotDataset",
"DROIDLeRobotDataset",
+ "LIBEROLeRobotDataset",
"RoboMINDFrankaDataset",
]
diff --git a/cosmos_framework/data/vfm/action/datasets/action_sft_dataset.py b/cosmos_framework/data/vfm/action/datasets/action_sft_dataset.py
index 1790de5..96a5219 100644
--- a/cosmos_framework/data/vfm/action/datasets/action_sft_dataset.py
+++ b/cosmos_framework/data/vfm/action/datasets/action_sft_dataset.py
@@ -19,6 +19,7 @@
from torch.utils.data import Dataset, IterableDataset, get_worker_info
from cosmos_framework.data.vfm.action.datasets.droid_lerobot_dataset import DROIDLeRobotDataset
+from cosmos_framework.data.vfm.action.datasets.libero_lerobot_dataset import LIBEROLeRobotDataset
from cosmos_framework.data.vfm.action.transforms import ActionTransformPipeline
@@ -139,3 +140,72 @@ def get_action_droid_sft_dataset(
if iterable_shuffle:
return ActionIterableShuffleDataset(sft, seed=episode_shuffle_seed)
return sft
+
+
+def get_action_libero_sft_dataset(
+ *,
+ root: str,
+ fps: float = 20.0,
+ chunk_length: int = 16,
+ image_size: int = 256,
+ mode: str = "policy",
+ camera_mode: str = "concat_view",
+ action_space: str = "frame_wise_relative",
+ rotation_space: str = "6d",
+ pose_coordinate_frame: str = "native",
+ action_normalization: str | None = "quantile_rot",
+ action_stats_path: str | None = None,
+ split: str = "train",
+ val_ratio: float = 0.01,
+ seed: int = 0,
+ resolution: str | int | None = None,
+ max_action_dim: int = 64,
+ tokenizer_config: dict | None = None,
+ cfg_dropout_rate: float = 0.1,
+ append_viewpoint_info: bool = True,
+ append_duration_fps_timestamps: bool = True,
+ append_resolution_info: bool = True,
+ append_idle_frames: bool = True,
+ iterable_shuffle: bool = False,
+ episode_shuffle_seed: int = 42,
+) -> Dataset:
+ """Build the LIBERO action-policy SFT dataset (GA reproduction defaults).
+
+ Feeds ``LIBEROLeRobotDataset`` (frame-wise-relative rot6d actions,
+ ``quantile_rot``-normalized, concat_view third-person + wrist at 256x256 each
+ → 256x512) through ``ActionTransformPipeline``. ``root`` is a LOCAL LeRobot dir
+ (read parquet + video directly); pre-sync the HF dataset once, e.g.
+ ``hf download lerobot/libero_10 --repo-type dataset --local-dir ``. Point
+ ``root`` at libero_10 alone. The
+ dataset is FPS-agnostic (decodes at real frame timestamps); ``fps`` is metadata
+ for ``conditioning_fps`` / prompt duration.
+ """
+ dataset = LIBEROLeRobotDataset(
+ root=root,
+ image_size=image_size,
+ chunk_length=chunk_length,
+ fps=fps,
+ mode=mode,
+ split=split,
+ val_ratio=val_ratio,
+ seed=seed,
+ camera_mode=camera_mode,
+ action_space=action_space,
+ rotation_space=rotation_space,
+ pose_coordinate_frame=pose_coordinate_frame,
+ action_normalization=action_normalization,
+ action_stats_path=action_stats_path,
+ )
+ transform = ActionTransformPipeline(
+ tokenizer_config=tokenizer_config,
+ cfg_dropout_rate=cfg_dropout_rate,
+ max_action_dim=max_action_dim,
+ append_viewpoint_info=append_viewpoint_info,
+ append_duration_fps_timestamps=append_duration_fps_timestamps,
+ append_resolution_info=append_resolution_info,
+ append_idle_frames=append_idle_frames,
+ )
+ sft = ActionSFTDataset(dataset, transform, resolution)
+ if iterable_shuffle:
+ return ActionIterableShuffleDataset(sft, seed=episode_shuffle_seed)
+ return sft
diff --git a/cosmos_framework/data/vfm/action/datasets/base_dataset.py b/cosmos_framework/data/vfm/action/datasets/base_dataset.py
index 2c9c4cb..f9b0f61 100644
--- a/cosmos_framework/data/vfm/action/datasets/base_dataset.py
+++ b/cosmos_framework/data/vfm/action/datasets/base_dataset.py
@@ -69,18 +69,25 @@ def __init__(
for path in sorted((self._root / "meta" / "episodes").glob("chunk-*/file-*.parquet"))
for row in pq.read_table(path).to_pylist()
}
- self._tasks = {
- int(row["task_index"]): str(row["task"])
- for row in pq.read_table(self._root / "meta" / "tasks.parquet").to_pylist()
- }
- self._rows = sorted(
- (
- row
- for path in sorted((self._root / "data").glob("chunk-*/file-*.parquet"))
- for row in pq.read_table(path).to_pylist()
- ),
- key=lambda row: int(row["index"]),
- )
+ # ``meta/tasks.parquet`` normally has a ``task`` column. Some LeRobot
+ # conversions (e.g. the community LIBERO datasets) instead store the task
+ # string as the (unnamed) pandas index, which pyarrow surfaces as
+ # ``__index_level_0__``. Fall back to the lone non-``task_index`` field so
+ # both layouts work (datasets that have ``task`` are unaffected).
+ self._tasks = {}
+ for row in pq.read_table(self._root / "meta" / "tasks.parquet").to_pylist():
+ if "task" in row:
+ task = row["task"]
+ else:
+ extras = [v for k, v in row.items() if k != "task_index"]
+ task = extras[0] if extras else ""
+ self._tasks[int(row["task_index"])] = str(task)
+ # ``self._rows`` (the flat, index-sorted list of every frame dict) is built
+ # lazily on first access — see the ``_rows`` property. Materializing all
+ # ~18M frames as Python dicts plus a full sort costs ~13 min and tens of GB;
+ # subclasses that build their own compact index (e.g. DROIDLeRobotDataset)
+ # never touch it, so they must not pay for it at construction.
+ self._rows_cache: list[dict[str, Any]] | None = None
@property
def fps(self) -> float:
@@ -213,5 +220,25 @@ def _build_result(
**extras,
}
+ @property
+ def _rows(self) -> list[dict[str, Any]]:
+ """Flat, index-sorted list of every frame dict, built lazily on first access.
+
+ Only datasets that don't build their own compact index (bridge / agibot /
+ robomind) touch this; for them it materializes once and caches. Datasets with
+ a bespoke index (e.g. DROIDLeRobotDataset) never read it, so they skip the
+ ~13 min / tens-of-GB construction entirely.
+ """
+ if self._rows_cache is None:
+ self._rows_cache = sorted(
+ (
+ row
+ for path in sorted((self._root / "data").glob("chunk-*/file-*.parquet"))
+ for row in pq.read_table(path).to_pylist()
+ ),
+ key=lambda row: int(row["index"]),
+ )
+ return self._rows_cache
+
def __len__(self) -> int:
return max(0, (len(self._rows) - self._chunk_length + self._sample_stride - 1) // self._sample_stride)
diff --git a/cosmos_framework/data/vfm/action/datasets/libero_lerobot_dataset.py b/cosmos_framework/data/vfm/action/datasets/libero_lerobot_dataset.py
new file mode 100644
index 0000000..1e5ef01
--- /dev/null
+++ b/cosmos_framework/data/vfm/action/datasets/libero_lerobot_dataset.py
@@ -0,0 +1,333 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: OpenMDW-1.1
+
+"""LIBERO LeRobot dataset (frame-wise-relative action policy).
+
+Mirrors ``DROIDLeRobotDataset``: reads the LeRobot parquet directly, windows by
+frame index, and decodes video at each frame's REAL timestamp. That makes it
+FPS-agnostic — it works with the 10 FPS community ``lerobot/libero_*`` datasets
+and a 20 FPS conversion alike, without LeRobot's ``delta_timestamps`` grid (which
+rejects any window whose synthetic timestamps don't land on real frames).
+
+Action layout (``frame_wise_relative``): the stored 7D ``action`` is already a
+per-frame delta ``[dpos(3), drot_axisangle(3), gripper(1)]``; only the rotation is
+re-encoded to the requested ``rotation_space`` -> ``[dpos(3), rot6d(6), gripper(1)]``
+(10D for ``6d``).
+
+NOTE on FPS / stats fidelity: the bundled ``quantile_rot`` stats were computed on
+a 20 FPS conversion. Per-frame deltas at 10 FPS span 2x the wall-clock motion, so
+use a 20 FPS LIBERO dataset (or recompute stats for the dataset's FPS).
+Loading/training is correct at any FPS regardless.
+"""
+
+from __future__ import annotations
+
+import json
+import random
+from pathlib import Path
+from typing import Any, Literal
+
+import numpy as np
+import pyarrow.parquet as pq
+import torch
+import torch.nn.functional as F
+from lerobot.datasets.video_utils import decode_video_frames
+
+from cosmos_framework.utils import log
+from cosmos_framework.data.vfm.action.action_normalization import normalize_action
+from cosmos_framework.data.vfm.action.action_spec import ActionSpec, Gripper, Pos, Rot, build_action_spec
+from cosmos_framework.data.vfm.action.datasets.base_dataset import ActionBaseDataset
+from cosmos_framework.data.vfm.action.libero_pose_utils import libero_action_dim, libero_rotation_format
+from cosmos_framework.data.vfm.action.pose_utils import convert_rotation
+
+CameraMode = Literal["image", "wrist_image", "concat_view"]
+RotationSpace = Literal["3d", "6d", "9d"]
+
+_ACTION_FEATURE = "action"
+_IMAGE_FEATURE = "observation.images.image"
+_WRIST_FEATURE = "observation.images.wrist_image"
+_STAT_KEYS = ("mean", "std", "min", "max", "q01", "q99")
+_NORMALIZERS_DIR = Path(__file__).parent / "stats"
+
+_VIEWPOINT_BY_CAMERA = {
+ "image": "third_person_view",
+ "wrist_image": "wrist_view",
+ "concat_view": "concat_view",
+}
+
+
+class LIBEROLeRobotDataset(ActionBaseDataset):
+ """LIBERO action-policy dataset with frame-wise-relative rot6d actions.
+
+ 10D ``[pos_delta(3), rot6d_delta(6), gripper(1)]`` (for ``rotation_space='6d'``),
+ ``concat_view`` third-person + wrist video, and ``quantile_rot`` normalization
+ against the bundled stats. Reads parquet + decodes video at real timestamps,
+ so the requested ``fps`` is metadata only (it sets ``conditioning_fps`` and the
+ prompt duration); frame windows always use the data's actual frames.
+ """
+
+ def __init__(
+ self,
+ root: str,
+ fps: float = 20.0,
+ chunk_length: int = 16,
+ mode: str = "policy",
+ tolerance_s: float = 1e-4,
+ camera_mode: CameraMode = "concat_view",
+ image_size: int = 256,
+ action_space: str = "frame_wise_relative",
+ rotation_space: RotationSpace = "6d",
+ pose_coordinate_frame: str = "native",
+ embodiment_type: str = "libero",
+ action_normalization: str | None = "quantile_rot",
+ action_stats_path: str | None = None,
+ split: str = "train",
+ val_ratio: float = 0.01,
+ seed: int = 0,
+ sample_stride: int = 1,
+ ) -> None:
+ if action_space != "frame_wise_relative":
+ raise NotImplementedError(
+ f"This LIBERO dataset only supports action_space='frame_wise_relative', got {action_space!r}."
+ )
+ if camera_mode not in _VIEWPOINT_BY_CAMERA:
+ raise ValueError(f"Unsupported camera_mode={camera_mode!r}. Use image/wrist_image/concat_view.")
+ split = split.lower().strip()
+ if split not in {"train", "val", "valid", "validation", "eval", "test", "full"}:
+ raise ValueError(f"Unsupported split={split!r}. Use train/val/full.")
+ if chunk_length % 4 != 0:
+ raise ValueError(f"chunk_length must be divisible by 4, got {chunk_length}.")
+
+ super().__init__(
+ root=root,
+ domain_name=embodiment_type,
+ fps=fps,
+ chunk_length=chunk_length,
+ mode=mode,
+ pose_convention="backward_framewise", # unused for frame_wise deltas; satisfies the base assert
+ tolerance_s=tolerance_s,
+ viewpoint=_VIEWPOINT_BY_CAMERA[camera_mode],
+ # frame_wise_relative ⇔ backward_framewise idle semantics. quantile_rot is a
+ # LIBERO convention -> normalize with the "quantile" formula on raw-rotation
+ # stats (see _load_norm_stats); pass the method the base will call.
+ action_normalization=None if action_normalization is None else "quantile",
+ sample_stride=sample_stride,
+ )
+ # FPS-agnostic loader: trust the dataset's NATIVE fps for conditioning_fps /
+ # prompt duration so the metadata is truthful (10 for the public
+ # lerobot/libero_*, 20 for a 20 FPS conversion). Frame sampling uses each
+ # frame's real timestamp regardless, so the requested ``fps`` is ignored here.
+ info_fps = self._info.get("fps")
+ if info_fps:
+ if int(info_fps) != int(fps):
+ log.info(f"Using dataset native fps={info_fps} for conditioning (requested {fps}).")
+ self._fps = float(info_fps)
+ self._dt = 1.0 / self._fps
+ self._camera_mode = camera_mode
+ self._image_size = int(image_size)
+ self._rotation_space = rotation_space.lower().strip()
+ self._pose_coordinate_frame = pose_coordinate_frame
+ self._embodiment_type = embodiment_type
+ self._requested_normalization = action_normalization
+ # quantile_rot normalizes against the raw (un-orthonormalized) rotation stats
+ # under "global_raw"; everything else uses "global".
+ self._stats_key = "global_raw" if action_normalization == "quantile_rot" else "global"
+ self._stats_file = self._resolve_stats_file(action_stats_path)
+
+ if self._camera_mode == "image":
+ self._video_keys = [_IMAGE_FEATURE]
+ elif self._camera_mode == "wrist_image":
+ self._video_keys = [_WRIST_FEATURE]
+ else:
+ self._video_keys = [_IMAGE_FEATURE, _WRIST_FEATURE]
+
+ # Compact, lazy frame index (mirrors DROIDLeRobotDataset): read only the
+ # columns the sample builder needs into contiguous arrays, ordered by global
+ # frame index, so DataLoader worker forks share them copy-on-write.
+ index_parts, episode_parts, task_parts, ts_parts, action_parts = [], [], [], [], []
+ for path in sorted((self._root / "data").glob("chunk-*/file-*.parquet")):
+ table = pq.read_table(path, columns=["index", "episode_index", "task_index", "timestamp", _ACTION_FEATURE])
+ index_parts.append(table["index"].to_numpy())
+ episode_parts.append(table["episode_index"].to_numpy())
+ task_parts.append(table["task_index"].to_numpy())
+ ts_parts.append(table["timestamp"].to_numpy())
+ action_parts.append(np.asarray(table[_ACTION_FEATURE].to_pylist(), dtype=np.float32))
+ if not index_parts:
+ raise FileNotFoundError(f"No data parquet found under {self._root / 'data'}.")
+ order = np.argsort(np.concatenate(index_parts).astype(np.int64), kind="stable")
+ self._row_episode = np.concatenate(episode_parts).astype(np.int64)[order]
+ self._row_task = np.concatenate(task_parts).astype(np.int64)[order]
+ self._row_timestamp = np.concatenate(ts_parts).astype(np.float64)[order]
+ self._row_action = np.concatenate(action_parts, axis=0).astype(np.float32)[order]
+
+ assert np.all(np.diff(self._row_episode) >= 0), "episode_index not contiguous after sorting by frame index"
+ ep_vals, ep_starts, ep_counts = np.unique(self._row_episode, return_index=True, return_counts=True)
+
+ # Deterministic per-episode train/val split (seeded; same on every rank).
+ keep = self._split_episode_ids(ep_vals.tolist(), split, val_ratio, seed)
+ kept = np.array([int(v) in keep for v in ep_vals], dtype=bool)
+ self._ep_vals = ep_vals.astype(np.int64)[kept]
+ self._ep_starts = ep_starts.astype(np.int64)[kept]
+ kept_counts = ep_counts.astype(np.int64)[kept]
+ # Within-episode windows only: total - n_kept_episodes * chunk_length valid samples.
+ self._valid_cum = np.cumsum(np.maximum(0, kept_counts - self._chunk_length)).astype(np.int64)
+
+ log.info(
+ f"Loaded LIBERO dataset root={self._root} split={split!r} camera_mode={camera_mode!r} "
+ f"fps={self._fps} kept_episodes={len(self._ep_vals)}/{len(ep_vals)} "
+ f"valid_indices={int(self._valid_cum[-1]) if self._valid_cum.size else 0}"
+ )
+
+ # ---- spec / dims -------------------------------------------------------
+
+ @property
+ def action_dim(self) -> int:
+ return libero_action_dim(self._rotation_space)
+
+ def _action_spec(self) -> ActionSpec:
+ return build_action_spec(Pos(), Rot(libero_rotation_format(self._rotation_space)), Gripper())
+
+ @classmethod
+ def _stats_path(cls) -> Path:
+ # Base classmethod fallback; the instance uses self._stats_file (which also
+ # honors action_stats_path + the rotation/coordinate-frame-specific filename).
+ return _NORMALIZERS_DIR / "libero_native_frame_wise_relative_rot6d.json"
+
+ # ---- normalization (nested global/global_raw + quantile_rot) ------------
+
+ def _bundled_stats_filename(self) -> str:
+ rotation_suffix = {"3d": "3d", "6d": "rot6d", "9d": "rot9d"}.get(self._rotation_space)
+ if rotation_suffix is None:
+ raise ValueError(f"Unsupported rotation_space={self._rotation_space!r}.")
+ action_space = "frame_wise_relative"
+ return f"{self._embodiment_type}_{self._pose_coordinate_frame}_{action_space}_{rotation_suffix}.json"
+
+ def _resolve_stats_file(self, action_stats_path: str | None) -> Path:
+ if action_stats_path:
+ p = Path(action_stats_path)
+ if not p.is_absolute():
+ p = _NORMALIZERS_DIR / p.name
+ if not p.exists():
+ raise FileNotFoundError(f"action_stats_path not found: {action_stats_path!r}")
+ return p
+ p = _NORMALIZERS_DIR / self._bundled_stats_filename()
+ if not p.exists():
+ raise FileNotFoundError(
+ f"Bundled LIBERO stats not found at {p}. Pass action_stats_path or recompute stats."
+ )
+ return p
+
+ def _load_norm_stats(self) -> dict[str, torch.Tensor]:
+ if self._norm_stats is None:
+ raw = json.loads(self._stats_file.read_text())[self._stats_key]
+ self._norm_stats = {
+ k: torch.tensor(v, dtype=torch.float32) for k, v in raw.items() if k in _STAT_KEYS
+ }
+ return self._norm_stats
+
+ # ---- index helpers -----------------------------------------------------
+
+ @staticmethod
+ def _split_episode_ids(ep_ids: list[int], split: str, val_ratio: float, seed: int) -> set[int]:
+ if split == "full":
+ return set(int(v) for v in ep_ids)
+ if not (0.0 < val_ratio < 1.0):
+ raise ValueError(f"val_ratio must be in (0, 1), got {val_ratio}.")
+ n_val = max(1, int(round(len(ep_ids) * val_ratio)))
+ rng = random.Random(seed) # identical selection on every rank
+ val = set(int(v) for v in rng.sample(list(ep_ids), n_val))
+ if split == "train":
+ return set(int(v) for v in ep_ids) - val
+ return val # val/valid/validation/eval/test
+
+ def __len__(self) -> int:
+ return int(self._valid_cum[-1]) if self._valid_cum.size else 0
+
+ def get_shuffle_blocks(self) -> list[tuple[int, int]]:
+ """Per-episode ``(start, length)`` flat-index blocks for
+ ``ActionIterableShuffleDataset`` (shuffle block ORDER + shard across
+ ranks, sequential within a block)."""
+ blocks: list[tuple[int, int]] = []
+ prev = 0
+ for c in np.asarray(self._valid_cum).tolist():
+ c = int(c)
+ if c > prev:
+ blocks.append((prev, c - prev))
+ prev = c
+ return blocks
+
+ # ---- sample build ------------------------------------------------------
+
+ def __getitem__(self, idx: int) -> dict[str, Any]:
+ # Resample a different valid window if a frame fails to decode (bounded retries).
+ n = len(self)
+ last_err: Exception | None = None
+ for _attempt in range(8):
+ try:
+ return self._build_item(idx)
+ except Exception as e: # noqa: BLE001 — skip past undecodable frames
+ last_err = e
+ log.warning(f"LIBERO: sample idx={idx} failed to load ({type(e).__name__}: {e}); resampling")
+ if n > 0:
+ idx = random.randint(0, n - 1)
+ raise RuntimeError(f"LIBERO: failed to load a sample after 8 resamples; last error: {last_err}")
+
+ def _build_item(self, idx: int) -> dict[str, Any]:
+ mode = self._choose_mode()
+ idx = int(idx)
+ ep = int(np.searchsorted(self._valid_cum, idx, side="right"))
+ prev = int(self._valid_cum[ep - 1]) if ep > 0 else 0
+ start = int(self._ep_starts[ep]) + (idx - prev)
+ episode_index = int(self._ep_vals[ep])
+ episode = self._episodes[episode_index]
+
+ stop = start + self._chunk_length + 1
+ timestamps = [float(self._row_timestamp[j]) for j in range(start, stop)]
+ video = self._load_video(episode, timestamps)
+
+ # frame_wise_relative: chunk per-frame deltas are the stored actions directly.
+ raw = self._row_action[start : start + self._chunk_length] # [chunk, 7]
+ action = self._build_frame_wise_action(raw)
+
+ task = self._tasks[int(self._row_task[start])]
+ ai_caption = random.choice([p.strip() for p in task.split(" | ") if p.strip()] or [task])
+
+ extras: dict[str, Any] = {}
+ if self._camera_mode == "concat_view":
+ extras["additional_view_description"] = (
+ "The left half shows the third-person view; the right half shows the wrist-mounted camera."
+ )
+ return self._build_result(mode=mode, video=video, action=action, ai_caption=ai_caption, **extras)
+
+ def _build_frame_wise_action(self, raw: np.ndarray) -> torch.Tensor:
+ raw_t = torch.from_numpy(np.ascontiguousarray(raw)).float() # [chunk, 7]
+ translation = raw_t[:, 0:3]
+ rotation_matrix = convert_rotation(raw_t[:, 3:6], input_format="axisangle", output_format="matrix")
+ rotation = convert_rotation(
+ rotation_matrix, input_format="matrix", output_format=libero_rotation_format(self._rotation_space)
+ )
+ gripper = raw_t[:, 6:7]
+ return torch.cat([translation, rotation, gripper], dim=-1) # [chunk, action_dim]
+
+ def _load_video(self, episode: dict[str, Any], timestamps: list[float]) -> torch.Tensor:
+ frames_by_view = {}
+ for key in self._video_keys:
+ from_ts = float(episode.get(f"videos/{key}/from_timestamp", 0.0))
+ frames = decode_video_frames(
+ self._video_path(episode, key),
+ [from_ts + ts for ts in timestamps],
+ self._tolerance_s,
+ ) # [T, C, H, W] in [0, 1]
+ frames = self._resize(frames)
+ frames_by_view[key] = frames
+ if self._camera_mode == "concat_view":
+ # third-person (left) + wrist (right), horizontally concatenated -> [T, C, H, 2W]
+ return torch.cat([frames_by_view[_IMAGE_FEATURE], frames_by_view[_WRIST_FEATURE]], dim=-1)
+ return frames_by_view[self._video_keys[0]]
+
+ def _resize(self, frames: torch.Tensor) -> torch.Tensor:
+ if frames.shape[-1] == self._image_size and frames.shape[-2] == self._image_size:
+ return frames
+ return F.interpolate(
+ frames, size=(self._image_size, self._image_size), mode="bilinear", align_corners=False
+ )
diff --git a/cosmos_framework/data/vfm/action/datasets/stats/libero_native_frame_wise_relative_rot6d.json b/cosmos_framework/data/vfm/action/datasets/stats/libero_native_frame_wise_relative_rot6d.json
new file mode 100644
index 0000000..a705e7c
--- /dev/null
+++ b/cosmos_framework/data/vfm/action/datasets/stats/libero_native_frame_wise_relative_rot6d.json
@@ -0,0 +1,37 @@
+{
+ "metadata": {
+ "embodiment_type": "libero",
+ "pose_convention": "frame_wise_relative",
+ "pose_coordinate_frame": "native",
+ "rotation_format": "6d",
+ "action_dim": 10,
+ "skip_rotation_dims": [3, 4, 5, 6, 7, 8],
+ "chunk_length": 16,
+ "sample_stride": null,
+ "dataset_name": "libero",
+ "dataset_class": "LIBEROLeRobotDataset",
+ "dataset_root": ["outputs/libero_datasets/libero_10", "outputs/libero_datasets/libero_object", "outputs/libero_datasets/libero_spatial", "outputs/libero_datasets/libero_goal"],
+ "_comment": "Dataset paths are placeholders; the statistics values are independent of local dataset location.",
+ "split": "train",
+ "num_samples_stats": 10000,
+ "reservoir_size": 50000,
+ "max_samples": 10000,
+ "sampling_seed": 42
+ },
+ "global": {
+ "mean": [ 0.050704, 0.097407, -0.094833, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.476725],
+ "std": [ 0.333621, 0.387175, 0.457140, 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, 0.499460],
+ "min": [-0.937500, -0.937500, -0.937500, -1.000000, -1.000000, -1.000000, -1.000000, -1.000000, -1.000000, 0.000000],
+ "max": [ 0.937500, 0.937500, 0.937500, 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, 1.000000],
+ "q01": [-0.723214, -0.808929, -0.937500, -1.000000, -1.000000, -1.000000, -1.000000, -1.000000, -1.000000, 0.000000],
+ "q99": [ 0.937500, 0.870536, 0.937500, 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, 1.000000]
+ },
+ "global_raw": {
+ "mean": [ 0.050704, 0.097407, -0.094833, 0.994873, -0.004579, -0.004288, 0.004389, 0.996104, 0.001109, 0.476725],
+ "std": [ 0.333621, 0.387175, 0.457140, 0.010807, 0.077802, 0.063386, 0.078571, 0.009994, 0.038504, 0.499460],
+ "min": [-0.937500, -0.937500, -0.937500, 0.902028, -0.356085, -0.367416, -0.370434, 0.921907, -0.255000, 0.000000],
+ "max": [ 0.937500, 0.937500, 0.937500, 1.000000, 0.368853, 0.341214, 0.356395, 1.000000, 0.348251, 1.000000],
+ "q01": [-0.723214, -0.808929, -0.937500, 0.934955, -0.223431, -0.189878, -0.334735, 0.938516, -0.107736, 0.000000],
+ "q99": [ 0.937500, 0.870536, 0.937500, 1.000000, 0.331000, 0.163153, 0.226216, 1.000000, 0.127158, 1.000000]
+ }
+}
diff --git a/cosmos_framework/data/vfm/action/libero_pose_utils.py b/cosmos_framework/data/vfm/action/libero_pose_utils.py
new file mode 100644
index 0000000..3a4fd8e
--- /dev/null
+++ b/cosmos_framework/data/vfm/action/libero_pose_utils.py
@@ -0,0 +1,69 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: OpenMDW-1.1
+
+"""Small LIBERO pose helpers shared by training and closed-loop eval."""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+
+from cosmos_framework.data.vfm.action.pose_utils import (
+ RotationConvention,
+ build_abs_pose_from_components,
+)
+
+# Local-frame post-rotation pattern:
+# R_opencv = R_native @ *_TO_OPENCV.
+# The matrix is a +90 degree rotation about z (maps x -> y and y -> -x).
+LIBERO_TO_OPENCV: np.ndarray = np.array(
+    [[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
+    dtype=np.float32,
+)
+
+# LIBERO ``rotation_space`` setting -> rotation format name understood by ``pose_utils``.
+LIBERO_ROTATION_FORMATS: dict[str, RotationConvention] = {
+    "3d": "axisangle",
+    "6d": "rot6d",
+    "9d": "rot9d",
+}
+# Unpadded ``[xyz, rotation, gripper]`` width per rotation space: 3 + {3, 6, 9} + 1.
+LIBERO_ACTION_DIMS: dict[str, int] = {"3d": 7, "6d": 10, "9d": 13}
+
+
+def libero_rotation_format(rotation_space: str) -> RotationConvention:
+ """Return the shared ``pose_utils`` rotation format for a LIBERO setting."""
+ rotation_format = LIBERO_ROTATION_FORMATS.get(rotation_space)
+ if rotation_format is None:
+ raise ValueError(f"Unsupported rotation_space={rotation_space!r}. Use 3d/6d/9d.")
+ return rotation_format
+
+
+def libero_action_dim(rotation_space: str) -> int:
+ """Return ``[xyz, rotation, gripper]`` action width for LIBERO."""
+ action_dim = LIBERO_ACTION_DIMS.get(rotation_space)
+ if action_dim is None:
+ raise ValueError(f"Unsupported rotation_space={rotation_space!r}. Use 3d/6d/9d.")
+ return action_dim
+
+
+def libero_rotation_space_from_action_dim(action_dim: int) -> str:
+ """Infer LIBERO rotation space from unpadded action width."""
+ for rotation_space, dim in LIBERO_ACTION_DIMS.items():
+ if dim == action_dim:
+ return rotation_space
+ raise ValueError(f"Unable to infer rotation_space from action_dim={action_dim}.")
+
+
+def build_libero_abs_pose(state_raw: torch.Tensor | np.ndarray, *, to_opencv: bool) -> np.ndarray:
+ """Build absolute LIBERO EE poses from state rows.
+
+ ``state_raw`` is ``[x,y,z,axisangle(3),gripper(2)]``. When requested, the
+ local EE frame is post-rotated into the shared OpenCV-style action frame.
+ """
+ if isinstance(state_raw, torch.Tensor):
+ state_np = state_raw.detach().cpu().numpy().astype(np.float32, copy=False)
+ else:
+ state_np = np.asarray(state_raw, dtype=np.float32)
+
+ poses_abs = build_abs_pose_from_components(state_np[:, :3], state_np[:, 3:6], "axisangle")
+ if to_opencv:
+ poses_abs[:, :3, :3] = poses_abs[:, :3, :3] @ LIBERO_TO_OPENCV
+ return poses_abs
diff --git a/cosmos_framework/scripts/action_policy_server_libero.py b/cosmos_framework/scripts/action_policy_server_libero.py
index 7382b97..2fa0567 100644
--- a/cosmos_framework/scripts/action_policy_server_libero.py
+++ b/cosmos_framework/scripts/action_policy_server_libero.py
@@ -63,6 +63,10 @@
# Action-specific helpers live in the in-tree project tree. Imports stay as
# `projects.cosmos3.vfm.*` and are auto-rewritten to `cosmos3._src.vfm.*` by the
# cosmos-framework release script.
+from cosmos_framework.data.vfm.action.action_processing import (
+ ActionProcessingRecord,
+ make_batched_action_processing_fields,
+)
from cosmos_framework.data.vfm.action.domain_utils import get_domain_id
from cosmos_framework.data.vfm.action.transforms import (
build_sequence_plan_from_mode,
@@ -431,7 +435,8 @@ class ActionServerArgs(pydantic.BaseModel):
# ``OmniSetupOverrides`` programmatically in ``build_setup_overrides``.
checkpoint: tyro.conf.OmitArgPrefixes[CheckpointOverrides] = CheckpointOverrides.model_construct()
- """Checkpoint and config loading configuration."""
+ """Checkpoint and config loading configuration. ``use_ema_weights`` lives here and
+ defaults True at inference (suppressed from CLI) -> evals load net_ema by default."""
output_dir: Path | None = None
"""Output directory for ``OmniInference`` (saved config.yaml, benchmarks).
@@ -804,101 +809,53 @@ def get_info(self) -> dict[str, Any]:
# Predict
# ------------------------------------------------------------------
- def predict_policy(self, req: dict[str, Any]) -> dict[str, Any]:
- """
- Run policy inference: given an observation image and prompt, predict actions.
-
- Input request format:
- {
- "image": "",
- "prompt": "",
- "domain_name": "",
- "image_size":
- }
-
- Output format:
- {
- "action": [[a0_0, a0_1, ...], ..., [aN_0, aN_1, ...]],
- "video": ["", ...] # List of T base64-encoded PNG frames
- }
-
- All action dimensions are returned. Video is the decoded predicted rollout as base64 PNGs.
- """
- t0 = time.monotonic()
-
- # Get or assign request ID
- injected_id = req.get("request_id", None)
- if isinstance(injected_id, int) and injected_id > 0:
- request_id = int(injected_id)
- else:
- with self._req_id_lock:
- self._req_id += 1
- request_id = int(self._req_id)
+ def _input_video_key(self) -> str:
+ input_video_key = getattr(self.model, "input_video_key", None)
+ if input_video_key is None:
+ input_video_key = getattr(self.model, "config", None).input_video_key # type: ignore[union-attr]
+ return input_video_key
- # Validate request
+ def _prep_policy_item(self, req: dict[str, Any]) -> dict[str, Any]:
+ """Validate one request and build the per-sample model inputs (video pad,
+ prompt augmentation, sequence_plan). Shared by predict_policy (batch=1) and
+ predict_policy_batch (batch=N) so the two paths stay byte-identical per item."""
image_b64 = req.get("image")
if not isinstance(image_b64, str):
raise ValueError("'image' must be a base64 string")
-
prompt = req.get("prompt")
if not isinstance(prompt, str):
raise ValueError("'prompt' must be a string")
-
domain_name = req.get("domain_name")
if not isinstance(domain_name, str):
raise ValueError("'domain_name' must be a string")
-
image_size = req.get("image_size")
if not isinstance(image_size, int) or image_size <= 0:
raise ValueError("'image_size' must be a positive integer")
- # Decode image
- t_decode0 = time.monotonic()
img_chw_uint8 = _decode_base64_png_to_rgb_uint8(image_b64)
img_h, img_w = img_chw_uint8.shape[-2:]
-
- # Handle resizing: for multi-view (non-square) images, scale proportionally
- # to maintain aspect ratio while matching height to image_size
+ # Multi-view (non-square) images: scale proportionally, matching height to image_size.
if img_h != image_size:
- # Calculate new width to maintain aspect ratio
scale = image_size / img_h
new_w = int(round(img_w * scale))
- hwc = img_chw_uint8.permute(1, 2, 0).cpu().numpy() # [H,W,3]
+ hwc = img_chw_uint8.permute(1, 2, 0).cpu().numpy()
resized = Image.fromarray(hwc).resize((new_w, image_size), resample=Image.Resampling.BILINEAR)
arr = np.asarray(resized, dtype=np.uint8).copy()
- img_chw_uint8 = torch.from_numpy(arr).permute(2, 0, 1).contiguous() # [3,H,W] # [3,H,W]
- t_decode1 = time.monotonic()
+ img_chw_uint8 = torch.from_numpy(arr).permute(2, 0, 1).contiguous()
- # Construct batch in IterativeJointDataLoader format (list-of-lists for multi-item keys)
t_frames = self.cfg.action_chunk_size + 1
_, final_h, final_w = img_chw_uint8.shape
video_c_t_h_w_uint8 = img_chw_uint8.unsqueeze(1).repeat(1, t_frames, 1, 1) # [3,T,H,W]
-
- # Apply reflection padding to match closest predefined resolution
resolution = get_vision_data_resolution((final_h, final_w))
target_w, target_h = find_closest_target_size(final_h, final_w, resolution)
pad_dict: dict[str, Any] = {"video": video_c_t_h_w_uint8}
reflection_pad_to_target(pad_dict, ["video"], True, target_w, target_h)
- video_padded = pad_dict["video"] # (C, T, target_h, target_w)
- padded_image_size = pad_dict["image_size"] # (4,)
-
- # Action: zeros tensor as noise starting point for policy mode
- action_t_d = torch.zeros(
- (self.cfg.action_chunk_size, self.cfg.max_action_dim),
- dtype=torch.float32,
- ) # [T,action_dim]
-
- input_video_key = getattr(self.model, "input_video_key", None)
- if input_video_key is None:
- input_video_key = getattr(self.model, "config", None).input_video_key # type: ignore[union-attr]
-
sequence_plan = build_sequence_plan_from_mode(
mode="policy",
video_length=self.cfg.action_chunk_size + 1,
action_length=self.cfg.action_chunk_size,
has_text=True,
)
-
augmented_prompt = _augment_prompt_with_metadata(
prompt,
t_frames=t_frames,
@@ -908,10 +865,126 @@ def predict_policy(self, req: dict[str, Any]) -> dict[str, Any]:
append_duration_fps=self.append_duration_fps,
append_resolution_info=self.append_resolution_info,
)
+ return {
+ "img_chw_uint8": img_chw_uint8,
+ "video_padded": pad_dict["video"],
+ "padded_image_size": pad_dict["image_size"],
+ "augmented_prompt": augmented_prompt,
+ "sequence_plan": sequence_plan,
+ "domain_name": domain_name,
+ "image_size": image_size,
+ }
+
+ def predict_policy_batch(self, reqs: list[dict[str, Any]]) -> dict[str, Any]:
+ """Batched policy inference: N requests -> ONE diffusion forward (batch_size=N)
+ -> N denormalized action chunks. Skips vision decode (the vectorized eval client
+ only needs actions), so it is ~N x faster than N serial /predict calls."""
+ t0 = time.monotonic()
+ if not isinstance(reqs, list) or not reqs:
+ raise ValueError("'items' must be a non-empty list of policy requests")
+ preps = [self._prep_policy_item(r) for r in reqs]
+ n = len(preps)
+ action_t_d = torch.zeros(
+ (self.cfg.action_chunk_size, self.cfg.max_action_dim), dtype=torch.float32
+ )
+ input_video_key = self._input_video_key()
+ batch: dict[str, Any] = {
+ input_video_key: [[p["video_padded"]] for p in preps],
+ **make_batched_action_processing_fields(
+ ActionProcessingRecord(raw_action_dim=self.raw_action_dim, action_normalizer=None),
+ batch_size=n,
+ ),
+ "action": [[action_t_d] for _ in preps],
+ "mode": ["policy"] * n,
+ "ai_caption": [p["augmented_prompt"] for p in preps],
+ "prompt": [p["augmented_prompt"] for p in preps],
+ "conditioning_fps": [torch.tensor(self.cfg.fps, dtype=torch.long) for _ in preps],
+ "image_size": torch.stack([p["padded_image_size"] for p in preps]).to(device="cuda"),
+ "domain_id": [torch.tensor(get_domain_id(p["domain_name"]), dtype=torch.long) for p in preps],
+ "sequence_plan": [p["sequence_plan"] for p in preps],
+ }
+ t_inf0 = time.monotonic()
+ with self._lock:
+ with torch.inference_mode():
+ samples = self.model.generate_samples_from_batch(
+ batch,
+ guidance=self.cfg.guidance,
+ seed=[self.cfg.seed] * n,
+ num_steps=self.cfg.num_steps,
+ has_negative_prompt=False,
+ )
+ t_inf1 = time.monotonic()
+ actions: list[list[list[float]]] = []
+ for i in range(n):
+ pred = samples["action"][i].float().squeeze(0) # [T,D]
+ pred = self._denormalize_action(pred)
+ actions.append(pred.detach().cpu().numpy().tolist())
+ log.info(
+ f"[action-server] predict_batch n={n} steps={self.cfg.num_steps} "
+ f"ms_total={(time.monotonic() - t0) * 1000.0:.1f} ms_infer={(t_inf1 - t_inf0) * 1000.0:.1f}"
+ )
+ return {"actions": actions}
+
+ def predict_policy(self, req: dict[str, Any]) -> dict[str, Any]:
+ """
+ Run policy inference: given an observation image and prompt, predict actions.
+
+ Input request format:
+ {
+ "image": "",
+ "prompt": "",
+ "domain_name": "",
+ "image_size":
+ }
+
+ Output format:
+ {
+ "action": [[a0_0, a0_1, ...], ..., [aN_0, aN_1, ...]],
+ "video": ["", ...] # List of T base64-encoded PNG frames
+ }
+
+ All action dimensions are returned. Video is the decoded predicted rollout as base64 PNGs.
+ """
+ t0 = time.monotonic()
+
+ # Get or assign request ID
+ injected_id = req.get("request_id", None)
+ if isinstance(injected_id, int) and injected_id > 0:
+ request_id = int(injected_id)
+ else:
+ with self._req_id_lock:
+ self._req_id += 1
+ request_id = int(self._req_id)
+
+ # Per-item preprocessing (validation, decode/resize/pad, prompt, sequence_plan).
+ t_decode0 = time.monotonic()
+ prep = self._prep_policy_item(req)
+ t_decode1 = time.monotonic()
+ img_chw_uint8 = prep["img_chw_uint8"]
+ video_padded = prep["video_padded"]
+ padded_image_size = prep["padded_image_size"]
+ augmented_prompt = prep["augmented_prompt"]
+ sequence_plan = prep["sequence_plan"]
+ domain_name = prep["domain_name"]
+ image_size = prep["image_size"]
+
+ # Action: zeros tensor as noise starting point for policy mode
+ action_t_d = torch.zeros(
+ (self.cfg.action_chunk_size, self.cfg.max_action_dim),
+ dtype=torch.float32,
+ ) # [T,action_dim]
+
+ input_video_key = self._input_video_key()
batch: dict[str, Any] = {
input_video_key: [[video_padded]],
- "raw_action_dim": [torch.tensor(self.raw_action_dim, dtype=torch.long)],
+ # Provide BOTH raw_action_dim and the action_processing_record the model
+ # needs to externalize (invert) the generated action; building the batch
+ # by hand previously omitted the record -> "cannot be externalized".
+ **make_batched_action_processing_fields(
+ ActionProcessingRecord(raw_action_dim=self.raw_action_dim, action_normalizer=None),
+ batch_size=1,
+ ),
"action": [[action_t_d]],
"mode": ["policy"],
"ai_caption": [augmented_prompt],
@@ -1103,7 +1176,7 @@ def do_GET(self) -> None: # noqa: N802
self._send_json(404, {"error": "Not found"})
def do_POST(self) -> None: # noqa: N802
- if self.path not in ("/", "/predict"):
+ if self.path not in ("/", "/predict", "/predict_batch"):
self._send_json(404, {"error": "Not found"})
return
@@ -1147,13 +1220,21 @@ def do_POST(self) -> None: # noqa: N802
f"path={self.path} bytes={length}"
)
+ is_batch = self.path == "/predict_batch"
try:
- out = service.predict_policy(req)
+ if is_batch:
+ out = service.predict_policy_batch(req.get("items", []))
+ else:
+ out = service.predict_policy(req)
except Exception as e:
err = str(e)
traceback.print_exc()
- payload = {"action": [], "error": err, "request_id": req.get("request_id")}
+ payload = (
+ {"actions": [], "error": err}
+ if is_batch
+ else {"action": [], "error": err, "request_id": req.get("request_id")}
+ )
log.error(f"[action-server] request_id={req.get('request_id')} ERROR: {err}")
# Dump failed request for offline debugging if enabled.
diff --git a/cosmos_framework/simulation/__init__.py b/cosmos_framework/simulation/__init__.py
new file mode 100644
index 0000000..28a81be
--- /dev/null
+++ b/cosmos_framework/simulation/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: OpenMDW-1.1
diff --git a/cosmos_framework/simulation/libero/__init__.py b/cosmos_framework/simulation/libero/__init__.py
new file mode 100644
index 0000000..503ec1b
--- /dev/null
+++ b/cosmos_framework/simulation/libero/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: OpenMDW-1.1
+
diff --git a/cosmos_framework/simulation/libero/closed_loop_eval.py b/cosmos_framework/simulation/libero/closed_loop_eval.py
new file mode 100644
index 0000000..660be36
--- /dev/null
+++ b/cosmos_framework/simulation/libero/closed_loop_eval.py
@@ -0,0 +1,1343 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: OpenMDW-1.1
+
+r"""
+Closed-loop evaluation for LIBERO using the Action HTTP inference server.
+
+# Single-view example (agentview camera):
+PYTHONPATH=. python cosmos_framework/simulation/libero/closed_loop_eval.py \
+ --server_url http://localhost:8000 \
+ --task_suite libero_10 \
+ --num_trials_per_task 10 \
+ --action_horizon 16 \
+ --camera agentview \
+ --save_gifs --gif_fps 20 \
+ --action_space frame_wise_relative \
+ --rotation_space 6d \
+ --action_dim 10 \
+ --output_dir results/libero_closed_loop_10_single_view
+
+# Multi-view example (agentview + wrist cameras):
+PYTHONPATH=. python cosmos_framework/simulation/libero/closed_loop_eval.py \
+ --server_url http://localhost:8000 \
+ --task_suite libero_goal \
+ --num_trials_per_task 2 \
+ --action_horizon 16 \
+ --camera agentview,wrist \
+ --save_gifs --gif_fps 20 \
+ --action_space frame_wise_relative \
+ --rotation_space 6d \
+ --action_dim 10 \
+ --output_dir results/libero_closed_loop_goal_multiview
+"""
+
+from __future__ import annotations
+
+import argparse
+import base64
+import io
+import json
+import os
+import random
+import sys
+import time
+from dataclasses import dataclass
+from collections.abc import Callable
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import requests
+from PIL import Image
+from scipy.spatial.transform import Rotation as R
+
+from cosmos_framework.data.vfm.action.libero_pose_utils import (
+ libero_rotation_format,
+ libero_rotation_space_from_action_dim,
+)
+from cosmos_framework.data.vfm.action.pose_utils import convert_rotation
+from cosmos_framework.data.vfm.action.viewpoint_utils import DEFAULT_VIEWPOINT_TEMPLATES
+
+# LIBERO symbols are declared here (annotation only, no value) and bound at
+# runtime by ``_import_libero()``, so importing this module does not require LIBERO.
+benchmark: Any
+get_libero_path: Any
+OffScreenRenderEnv: Any
+
+
+# Step limit per LIBERO task suite.
+# NOTE(review): presumably the maximum env steps per episode -- confirm at the use site.
+TASK_MAX_STEPS: dict[str, int] = {
+    "libero_spatial": 220,
+    "libero_object": 280,
+    "libero_goal": 300,
+    "libero_10": 520,
+    "libero_90": 400,
+}
+
+
+# Camera id -> phrase used for that camera in the concat-view layout sentence.
+_CAMERA_PROMPT_NAMES: dict[str, str] = {
+    "agentview": "third-person view",
+    "wrist": "wrist-mounted camera",
+}
+
+
+def _append_prompt_sentence(prompt: str, sentence: str) -> str:
+ """Append one metadata sentence using the same separator convention as training augmentors."""
+ if sentence in prompt:
+ return prompt
+ prompt = prompt.rstrip()
+ if not prompt:
+ return sentence.rstrip()
+ separator = " " if prompt.rstrip().endswith(".") else ". "
+ return prompt + separator + sentence.rstrip()
+
+
+def _concat_view_layout_description(cameras: list[str]) -> str:
+ """Describe the horizontal camera layout sent by ``ActionEnvironmentClient``."""
+ camera_names = [_CAMERA_PROMPT_NAMES[camera] for camera in cameras]
+ if len(camera_names) == 2:
+ return f"The left half shows the {camera_names[0]}; the right half shows the {camera_names[1]}."
+ layout = ", ".join(camera_names)
+ return f"The views are concatenated horizontally from left to right as: {layout}."
+
+
+def _augment_task_prompt_with_viewpoint(task_description: str, cameras: list[str]) -> str:
+ """Concat-view caption augmentation for closed-loop LIBERO eval."""
+ if len(cameras) <= 1:
+ return task_description
+ prompt = _append_prompt_sentence(task_description, DEFAULT_VIEWPOINT_TEMPLATES["concat_view"])
+ return _append_prompt_sentence(prompt, _concat_view_layout_description(cameras))
+
+
+def _rotation_repr_to_mat(rotation: np.ndarray, rotation_space: str) -> np.ndarray:
+ """Convert a single LIBERO rotation block to a 3x3 rotation matrix."""
+ matrix = convert_rotation(
+ rotation,
+ libero_rotation_format(rotation_space),
+ "matrix",
+ normalize_matrix=rotation_space != "3d",
+ )
+ if not isinstance(matrix, np.ndarray):
+ raise TypeError(f"Expected NumPy rotation matrix, got {type(matrix)!r}")
+ return matrix
+
+
+@dataclass
+class EpisodeResult:
+    """Outcome record for one closed-loop evaluation episode."""
+
+    success: bool
+    steps: int
+    # ``None`` is allowed here; NOTE(review): presumably means "no error" -- confirm at the call site.
+    error: str | None
+    # NOTE(review): presumably one action vector per executed env step -- confirm at the call site.
+    actions: list[list[float]]
+
+
+class ActionEnvironmentClient:
+ """Client for interacting with the Action model server."""
+
+ server_url: str
+ domain_name: str
+ prompt: str
+ image_size: int
+ timeout: float
+
+ def __init__(
+ self,
+ server_url: str,
+ domain_name: str,
+ prompt: str,
+ image_size: int,
+ timeout: float,
+ ) -> None:
+ self.server_url = server_url.rstrip("/")
+ self.domain_name = domain_name
+ self.prompt = prompt
+ self.image_size = image_size
+ self.timeout = timeout
+
+ def check_health(self) -> bool:
+ """Check if the model server is healthy."""
+ try:
+ resp = requests.get(f"{self.server_url}/", timeout=5.0)
+ return resp.status_code == 200
+ except requests.RequestException:
+ return False
+
+ def get_info(self) -> dict[str, str]:
+ """Get model server info."""
+ resp = requests.get(f"{self.server_url}/info", timeout=5.0)
+ resp.raise_for_status()
+ return resp.json()
+
+ def notify_next_episode(self) -> None:
+ """Notify server to advance to next episode (used with dataset action server)."""
+ try:
+ requests.post(
+ f"{self.server_url}/next_episode",
+ json={"prompt": self.prompt},
+ timeout=5.0,
+ )
+ except requests.RequestException:
+ pass
+
+ def encode_image(self, image: np.ndarray) -> str:
+ """Encode a numpy image (H, W, 3) uint8 to base64 PNG, resizing to image_size."""
+ if image.dtype != np.uint8:
+ if image.max() <= 1.0:
+ image = (image * 255.0).round().astype(np.uint8)
+ else:
+ image = image.astype(np.uint8)
+ pil_img = Image.fromarray(image)
+ if pil_img.size != (self.image_size, self.image_size):
+ pil_img = pil_img.resize(
+ (self.image_size, self.image_size),
+ resample=Image.Resampling.BILINEAR,
+ )
+ buf = io.BytesIO()
+ pil_img.save(buf, format="PNG")
+ return base64.b64encode(buf.getvalue()).decode("ascii")
+
+ def encode_image_raw(self, image: np.ndarray) -> str:
+ """Encode a numpy image (H, W, 3) uint8 to base64 PNG without resizing."""
+ if image.dtype != np.uint8:
+ if image.max() <= 1.0:
+ image = (image * 255.0).round().astype(np.uint8)
+ else:
+ image = image.astype(np.uint8)
+ pil_img = Image.fromarray(image)
+ buf = io.BytesIO()
+ pil_img.save(buf, format="PNG")
+ return base64.b64encode(buf.getvalue()).decode("ascii")
+
+ def resize_image(self, image: np.ndarray) -> np.ndarray:
+ """Resize image to model input size."""
+ if image.dtype != np.uint8:
+ if image.max() <= 1.0:
+ image = (image * 255.0).round().astype(np.uint8)
+ else:
+ image = image.astype(np.uint8)
+ pil_img = Image.fromarray(image)
+ if pil_img.size != (self.image_size, self.image_size):
+ pil_img = pil_img.resize(
+ (self.image_size, self.image_size),
+ resample=Image.Resampling.BILINEAR,
+ )
+ return np.array(pil_img)
+
+ def concatenate_images(self, images: list[np.ndarray]) -> np.ndarray:
+ """Resize each image and concatenate horizontally (side-by-side).
+
+ Args:
+ images: List of images with shape (H, W, 3).
+
+ Returns:
+ Concatenated image with shape (image_size, image_size*num_views, 3).
+ """
+ resized = [self.resize_image(img) for img in images]
+ return np.concatenate(resized, axis=1)
+
+ def predict(self, observation: np.ndarray | list[np.ndarray]) -> dict[str, Any]:
+ """Send observation(s) to model server and get predicted actions.
+
+ Args:
+ observation: Single image as np.ndarray or list of images for multi-view.
+ For multi-view, images are resized and concatenated horizontally before sending.
+ """
+ if isinstance(observation, list):
+ # Multi-view: resize each, concatenate horizontally, and send as single image
+ concatenated = self.concatenate_images(observation)
+ encoded = self.encode_image_raw(concatenated)
+ else:
+ # Single view: send single image
+ encoded = self.encode_image(observation)
+
+ payload = {
+ "image": encoded,
+ "prompt": self.prompt,
+ "domain_name": self.domain_name,
+ "image_size": self.image_size,
+ }
+
+ resp = requests.post(
+ f"{self.server_url}/predict",
+ json=payload,
+ headers={"Content-Type": "application/json"},
+ timeout=self.timeout,
+ )
+ resp.raise_for_status()
+
+ result = resp.json()
+ if "error" in result and result["error"]:
+ raise RuntimeError(f"Model server error: {result['error']}")
+ return result
+
+ def predict_batch(self, observations: list[list[np.ndarray]]) -> list[list[list[float]]]:
+ """Batched inference: a list of per-env multi-view observations -> ONE
+ POST /predict_batch -> a list of action chunks (one per env). Used by the
+ vectorized eval so N parallel envs share a single diffusion forward."""
+ items = []
+ for obs_imgs in observations:
+ concat = self.concatenate_images(obs_imgs) if len(obs_imgs) > 1 else self.resize_image(obs_imgs[0])
+ items.append(
+ {
+ "image": self.encode_image_raw(concat),
+ "prompt": self.prompt,
+ "domain_name": self.domain_name,
+ "image_size": self.image_size,
+ }
+ )
+ resp = requests.post(
+ f"{self.server_url}/predict_batch",
+ json={"items": items},
+ headers={"Content-Type": "application/json"},
+ timeout=max(self.timeout, 300.0),
+ )
+ resp.raise_for_status()
+ result = resp.json()
+ if "error" in result and result["error"]:
+ raise RuntimeError(f"Model server error: {result['error']}")
+ return result["actions"]
+
+
+def _find_accessible_dri_nodes() -> list[Path]:
+ dri_path = Path("/dev/dri")
+ if not dri_path.exists():
+ return []
+ nodes = list(dri_path.glob("renderD*")) + list(dri_path.glob("card*"))
+ return [node for node in nodes if os.access(node, os.R_OK | os.W_OK)]
+
+
+def _resolve_mujoco_backend(requested_backend: str) -> tuple[str, str]:
+ requested_backend = requested_backend.lower()
+ if requested_backend != "auto":
+ return requested_backend, "requested"
+
+ env_backend = os.environ.get("MUJOCO_GL")
+ if env_backend:
+ return env_backend.lower(), "env"
+
+ if _find_accessible_dri_nodes():
+ return "egl", "auto-gpu"
+ return "osmesa", "auto-cpu"
+
+
+def _configure_mujoco_env(requested_backend: str) -> str:
+ backend, source = _resolve_mujoco_backend(requested_backend)
+ if backend not in {"egl", "osmesa", "glfw"}:
+ raise ValueError(f"Unsupported MuJoCo GL backend: {backend!r}. Use auto, egl, osmesa, or glfw.")
+
+ os.environ["MUJOCO_GL"] = backend
+ if backend == "egl":
+ os.environ["PYOPENGL_PLATFORM"] = "egl"
+ elif backend == "osmesa":
+ os.environ["PYOPENGL_PLATFORM"] = "osmesa"
+ return f"{backend} ({source})"
+
+
+def _import_libero() -> None:
+ global benchmark, get_libero_path, OffScreenRenderEnv
+ try:
+ from libero.libero import benchmark as libero_benchmark
+ from libero.libero import get_libero_path as libero_get_libero_path
+ from libero.libero.envs import OffScreenRenderEnv as libero_offscreen_render_env
+ except ImportError as exc: # pragma: no cover - environment-specific dependency
+ raise RuntimeError(
+ "Failed to import LIBERO. Make sure the LIBERO environment is activated. "
+ f"python={sys.executable!r}, import_error={exc!r}"
+ ) from exc
+
+ benchmark = libero_benchmark
+ get_libero_path = libero_get_libero_path
+ OffScreenRenderEnv = libero_offscreen_render_env
+
+
+def _wait_for_server(client: ActionEnvironmentClient, timeout_s: float) -> None:
+ start = time.perf_counter()
+ while time.perf_counter() - start < timeout_s:
+ if client.check_health():
+ return
+ time.sleep(1.0)
+ raise RuntimeError(f"Timed out waiting for server at {client.server_url}")
+
+
+def _get_libero_env(
+ task: Any,
+ *,
+ resolution: int,
+ seed: int,
+ render_gpu_device_id: int,
+) -> tuple[Any, str]:
+ task_description = str(task.language)
+ task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
+ env_args = {
+ "bddl_file_name": task_bddl_file,
+ "camera_heights": resolution,
+ "camera_widths": resolution,
+ "render_gpu_device_id": render_gpu_device_id,
+ }
+ env = OffScreenRenderEnv(**env_args)
+ env.seed(seed)
+ return env, task_description
+
+
+def _get_libero_dummy_action() -> list[float]:
+ return [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0]
+
+
+def _get_libero_image(
+ obs: dict[str, Any],
+ camera: str,
+ *,
+ flip_images: bool,
+ rotate_180: bool,
+) -> np.ndarray:
+ if camera == "agentview":
+ image = obs["agentview_image"]
+ elif camera == "wrist":
+ image = obs["robot0_eye_in_hand_image"]
+ else:
+ raise ValueError(f"Unsupported camera={camera!r}. Use 'agentview' or 'wrist'.")
+
+ if rotate_180:
+ image = image[::-1, ::-1]
+ if flip_images:
+ image = np.flipud(image)
+ return image
+
+
+def _get_libero_images(
+ obs: dict[str, Any],
+ cameras: list[str],
+ *,
+ flip_images: bool,
+ rotate_180: bool,
+) -> list[np.ndarray]:
+ """Get images from multiple cameras."""
+ return [_get_libero_image(obs, camera, flip_images=flip_images, rotate_180=rotate_180) for camera in cameras]
+
+
+def _ensure_uint8_image(image: np.ndarray) -> np.ndarray:
+ if image.dtype != np.uint8:
+ if image.max() <= 1.0:
+ image = (image * 255.0).round().astype(np.uint8)
+ else:
+ image = image.astype(np.uint8)
+ return image
+
+
+def _save_gif(frames: list[Image.Image], output_path: Path, fps: int) -> None:
+ if not frames:
+ return
+ duration_ms = int(1000 / fps) if fps > 0 else 100
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+ first, *rest = frames
+ first.save(
+ output_path,
+ save_all=True,
+ append_images=rest,
+ duration=duration_ms,
+ loop=0,
+ )
+
+
+def _decode_b64_frames(b64_frames: list[str]) -> list[Image.Image]:
+ """Decode a list of base64-encoded PNG strings into PIL Images."""
+ images: list[Image.Image] = []
+ for b64 in b64_frames:
+ raw = base64.b64decode(b64)
+ images.append(Image.open(io.BytesIO(raw)).convert("RGB"))
+ return images
+
+
def _save_comparison_gif(
    comparison_windows: list[tuple[list[Image.Image], list[Image.Image]]],
    output_path: Path,
    fps: int,
    target_height: int = 256,
    separator_width: int = 4,
) -> None:
    """Render a side-by-side GIF: predicted frames on the left, env frames on the right.

    ``comparison_windows`` holds one ``(action_frames, env_frames)`` pair per
    prediction call. Within a window, frames are matched by index up to the
    shorter list. For every window after the first, index 0 (the conditioning
    frame) is dropped so the boundary frame is not shown twice.

    Both images are resized to ``target_height`` keeping their aspect ratio,
    separated by ``separator_width`` black pixels, and topped with a 16-pixel
    labelled banner. Nothing is written when no combined frame was produced.
    """
    from PIL import ImageDraw

    banner_height = 16
    canvas_height = target_height + banner_height
    panels: list[Image.Image] = []

    for window_number, (predicted, rendered) in enumerate(comparison_windows):
        # zip stops at the shorter list, matching the index-by-index pairing.
        paired = list(zip(predicted, rendered))
        first_index = 0 if window_number == 0 else 1
        for predicted_img, rendered_img in paired[first_index:]:
            left_width = int(predicted_img.width * target_height / predicted_img.height)
            right_width = int(rendered_img.width * target_height / rendered_img.height)
            left_panel = predicted_img.resize((left_width, target_height), Image.Resampling.BILINEAR)
            right_panel = rendered_img.resize((right_width, target_height), Image.Resampling.BILINEAR)

            right_offset = left_width + separator_width
            canvas_width = right_offset + right_width
            canvas = Image.new("RGB", (canvas_width, canvas_height), color=0)

            pen = ImageDraw.Draw(canvas)
            pen.rectangle([(0, 0), (left_width, banner_height)], fill=(30, 30, 60))
            pen.rectangle([(right_offset, 0), (canvas_width, banner_height)], fill=(30, 60, 30))
            pen.text((4, 1), "Action Prediction", fill=(100, 180, 255))
            pen.text((right_offset + 4, 1), "Environment", fill=(100, 255, 100))

            canvas.paste(left_panel, (0, banner_height))
            canvas.paste(right_panel, (right_offset, banner_height))
            panels.append(canvas)

    if not panels:
        return
    _save_gif(panels, output_path, fps)
+
+
def _select_action_chunk(actions: list[list[float]], action_horizon: int) -> list[list[float]]:
    """Return the leading ``action_horizon`` actions of a predicted chunk.

    A non-positive horizon, or one that is at least as long as the chunk,
    selects the whole chunk; in that case the input list itself is returned
    (no copy).
    """
    if 0 < action_horizon < len(actions):
        return actions[:action_horizon]
    return actions
+
+
def _format_action(action: list[float], action_dim: int) -> list[float]:
    """Truncate ``action`` to its first ``action_dim`` entries.

    Raises:
        ValueError: if ``action`` has fewer than ``action_dim`` entries.
    """
    actual_dim = len(action)
    if actual_dim >= action_dim:
        return action[:action_dim]
    raise ValueError(f"Action dimension {actual_dim} smaller than expected {action_dim}")
+
+
def _remap_gripper(action: list[float], mode: str) -> list[float]:
    """Return a copy of ``action`` whose last entry is remapped to the env's gripper range.

    The last element is treated as the gripper command and mapped into
    ``[-1, 1]`` (the env convention documented here is negative = open). The
    mapping depends on the gripper convention of the training data:

    * ``zero_one``: raw gripper in ``[0, 1]``. Mapped with the *continuous*
      ``-(clamp(2g - 1, -1, 1))``, i.e. ``1 - 2g`` clamped to ``[-1, 1]``.
    * ``pm_one``: raw gripper already in ``[-1, 1]``; passed through, clamped.
    * ``pm_one_flip``: like ``pm_one`` but with the sign inverted before clamping.

    NOTE(review): earlier documentation of this function described ``zero_one``
    as binarizing to hard ``{-1, +1}`` (``-sign(2g - 1)``). The code does not
    binarize; confirm which behaviour is intended.

    The caller's list is never modified.

    Raises:
        ValueError: if ``mode`` is not one of the three supported modes.
    """
    remapped = list(action)  # work on a copy so the caller's list is untouched
    gripper = remapped[-1]

    if mode == "pm_one":
        command = max(-1.0, min(1.0, gripper))
    elif mode == "pm_one_flip":
        command = max(-1.0, min(1.0, -gripper))
    elif mode == "zero_one":
        # [0, 1] -> [-1, 1], then flip the sign so that negative = open (issue #50).
        command = max(-1.0, min(1.0, gripper * 2.0 - 1.0)) * -1.0
    else:
        raise ValueError(f"Unknown gripper_mode={mode!r}. Use zero_one/pm_one/pm_one_flip.")

    remapped[-1] = command
    return remapped
+
+
def _infer_rotation_space(action_dim: int, rotation_space: str) -> str:
    """Resolve the rotation representation to use.

    Any value other than ``"auto"`` is returned unchanged; ``"auto"`` defers to
    ``libero_rotation_space_from_action_dim(action_dim)``.
    """
    return libero_rotation_space_from_action_dim(action_dim) if rotation_space == "auto" else rotation_space
+
+
def _obs_to_pose(obs: dict[str, Any]) -> tuple[np.ndarray, np.ndarray]:
    """Extract the end-effector pose from an observation dict.

    Reads ``obs["robot0_eef_pos"]`` and ``obs["robot0_eef_quat"]`` as float32
    arrays and returns ``(position, rotation_matrix)``, where the matrix is
    produced by scipy's ``Rotation.from_quat(...).as_matrix()``.
    """
    eef_position = np.asarray(obs["robot0_eef_pos"], dtype=np.float32)
    eef_quaternion = np.asarray(obs["robot0_eef_quat"], dtype=np.float32)
    return eef_position, R.from_quat(eef_quaternion).as_matrix()
+
+
def _anchored_action_to_delta(
    anchored_action: np.ndarray,
    base_pose: tuple[np.ndarray, np.ndarray],
    current_pose: tuple[np.ndarray, np.ndarray],
    rotation_space: str,
) -> np.ndarray:
    """Convert an action expressed relative to an anchor pose into a delta from the current pose.

    ``anchored_action`` is laid out as ``[translation(3), rotation(K), gripper(1)]``
    with ``K = len(anchored_action) - 4``. The rotation part is decoded according
    to ``rotation_space``: ``"3d"`` as a rotation vector, ``"6d"`` through
    ``_rotation_repr_to_mat``, ``"9d"`` as a row-major 3x3 matrix.

    The target pose is the anchor pose composed with the anchored action
    (``base_rot @ anchored_rot`` and ``base_pos + base_rot @ translation``). The
    result is ``[target_pos - current_pos, rotvec(target_rot @ current_rot.T), gripper]``.

    Raises:
        ValueError: if ``rotation_space`` is not ``3d``, ``6d`` or ``9d``.
    """
    rotation_len = anchored_action.shape[0] - 4
    rotation_end = 3 + rotation_len
    translation_part = anchored_action[:3]
    rotation_part = anchored_action[3:rotation_end]
    gripper_part = anchored_action[rotation_end : rotation_end + 1]

    anchor_position, anchor_rotation = base_pose
    present_position, present_rotation = current_pose

    if rotation_space == "9d":
        relative_rotation = rotation_part.reshape(3, 3)
    elif rotation_space == "6d":
        relative_rotation = _rotation_repr_to_mat(rotation_part, rotation_space)
    elif rotation_space == "3d":
        relative_rotation = R.from_rotvec(rotation_part).as_matrix()
    else:
        raise ValueError(f"Unsupported rotation_space={rotation_space!r}. Use 3d/6d/9d.")

    # Compose the anchored action with the anchor pose to get the target pose.
    goal_rotation = anchor_rotation @ relative_rotation
    goal_position = anchor_position + anchor_rotation @ translation_part

    # Express the target relative to where the end effector currently is.
    position_step = goal_position - present_position
    rotation_step = R.from_matrix(goal_rotation @ present_rotation.T).as_rotvec()

    return np.concatenate([position_step, rotation_step, gripper_part], axis=0)
+
+
def _framewise_action_to_delta(
    framewise_action: np.ndarray,
    rotation_space: str,
) -> np.ndarray:
    """Decode a frame-wise policy action into ``[delta_pos(3), delta_rotvec(3), gripper(1)]``.

    Frame-wise actions are treated as per-step deltas already, so no anchor or
    current pose is involved; only the rotation representation is converted
    back to a rotation vector.

    With ``rotation_space == "3d"`` the input array is returned as is (same
    object). Otherwise the action is read as
    ``[translation(3), rotation(K), gripper(1)]`` with ``K = len(action) - 4``,
    and the rotation part is decoded via ``_rotation_repr_to_mat``.
    """
    if rotation_space != "3d":
        rotation_len = framewise_action.shape[0] - 4
        rotation_end = 3 + rotation_len
        position_step = framewise_action[:3]
        encoded_rotation = framewise_action[3:rotation_end]
        gripper_part = framewise_action[rotation_end : rotation_end + 1]

        rotation_matrix = _rotation_repr_to_mat(encoded_rotation, rotation_space)
        rotation_step = R.from_matrix(rotation_matrix).as_rotvec()
        return np.concatenate([position_step, rotation_step, gripper_part], axis=0)

    # Already a 7D rotation-vector action: nothing to decode.
    return framewise_action
+
+
def _run_episode(
    env: Any,
    client: ActionEnvironmentClient,
    *,
    cameras: list[str],
    flip_images: bool,
    rotate_180: bool,
    action_horizon: int,
    action_dim: int,
    action_space: str,
    rotation_space: str,
    gripper_mode: str,
    max_steps: int,
    warmup_steps: int,
    initial_state: np.ndarray | None,
    gif_path: Path | None,
    gif_fps: int,
    comparison_path: Path | None = None,
) -> EpisodeResult:
    """Run one closed-loop episode in ``env``, querying ``client`` for action chunks.

    The env is reset (and set to ``initial_state`` when given), stepped with a
    dummy action for the first ``warmup_steps`` steps, and then driven by the
    policy: whenever the local action queue is empty a new chunk is requested
    from ``client.predict`` and truncated with ``_select_action_chunk``. Each
    queued action is converted to the env's command (anchored conversion when
    ``action_space == "relative"``, frame-wise otherwise), its gripper entry is
    remapped per ``gripper_mode``, and the env is stepped. Warmup steps count
    towards ``max_steps``.

    Args:
        env: Environment exposing ``reset``, ``set_init_state``,
            ``get_observation`` and ``step``.
        client: Policy client exposing ``predict`` and ``concatenate_images``.
        cameras: Camera names; more than one enables multi-view prediction.
            Only ``cameras[0]`` is used for the plain episode GIF.
        flip_images: Forwarded to the image extraction helpers.
        rotate_180: Forwarded to the image extraction helpers.
        action_horizon: Number of actions to execute per prediction (see
            ``_select_action_chunk``).
        action_dim: Number of leading entries kept from each predicted action.
        action_space: ``"relative"`` selects the anchored conversion; any other
            value selects the frame-wise conversion.
        rotation_space: Rotation representation, resolved with
            ``_infer_rotation_space``.
        gripper_mode: Gripper convention passed to ``_remap_gripper``.
        max_steps: Upper bound on env steps (including warmup).
        warmup_steps: Number of initial dummy-action steps.
        initial_state: State passed to ``env.set_init_state``; when ``None`` the
            observation after ``env.reset()`` is used.
        gif_path: Where to save the episode GIF; ``None`` disables recording.
        gif_fps: Frame rate for saved GIFs.
        comparison_path: Where to save the prediction-vs-env comparison GIF;
            ``None`` disables it.

    Returns:
        An ``EpisodeResult`` built positionally from
        ``(success, step, error_message_or_None, action_log)``. When the server
        returns an empty action chunk the function returns immediately with a
        failure result, without writing any GIF.

    Raises:
        RuntimeError: if a relative action is converted before a base pose exists.
    """
    env.reset()
    if initial_state is not None:
        obs = env.set_init_state(initial_state)
    else:
        obs = env.get_observation()

    action_queue: list[list[float]] = []
    base_pose: tuple[np.ndarray, np.ndarray] | None = None
    step = 0
    success = False
    gif_frames: list[Image.Image] = []
    action_log: list[list[float]] = []
    is_multi_view = len(cameras) > 1
    resolved_rotation_space = _infer_rotation_space(action_dim, rotation_space)

    # One (predicted_frames, env_frames) pair per prediction call that returned video.
    comparison_windows: list[tuple[list[Image.Image], list[Image.Image]]] = []

    def record_frame(current_obs: dict[str, Any]) -> None:
        """Append the first camera's view to the episode GIF (no-op when GIFs are disabled)."""
        if gif_path is None:
            return
        image = _get_libero_image(
            current_obs,
            cameras[0],
            flip_images=flip_images,
            rotate_180=rotate_180,
        )
        image = _ensure_uint8_image(image)
        gif_frames.append(Image.fromarray(image).convert("RGB"))

    def capture_comparison_frame(current_obs: dict[str, Any]) -> Image.Image:
        """Capture an env frame matching Action's input view (multi-view concatenated if applicable)."""
        if is_multi_view:
            imgs = _get_libero_images(current_obs, cameras, flip_images=flip_images, rotate_180=rotate_180)
            concat = client.concatenate_images(imgs)
            return Image.fromarray(_ensure_uint8_image(concat)).convert("RGB")
        img = _get_libero_image(current_obs, cameras[0], flip_images=flip_images, rotate_180=rotate_180)
        return Image.fromarray(_ensure_uint8_image(img)).convert("RGB")

    record_frame(obs)

    while step < max_steps:
        # Warmup phase: step with a dummy action and skip the policy entirely.
        if step < warmup_steps:
            dummy = _get_libero_dummy_action()
            obs, _, _, _ = env.step(dummy)
            action_log.append(dummy)
            step += 1
            record_frame(obs)
            continue

        # Request a new chunk only once the previous one has been fully executed.
        if not action_queue:
            if is_multi_view:
                observation_imgs = _get_libero_images(
                    obs,
                    cameras,
                    flip_images=flip_images,
                    rotate_180=rotate_180,
                )
                result = client.predict(observation_imgs)
            else:
                observation_img = _get_libero_image(
                    obs,
                    cameras[0],
                    flip_images=flip_images,
                    rotate_180=rotate_180,
                )
                result = client.predict(observation_img)
            actions = result.get("action", [])
            if not actions:
                return EpisodeResult(False, step, "Empty action chunk from server", action_log)
            action_queue = _select_action_chunk(actions, action_horizon)

            if comparison_path is not None:
                action_video_b64 = result.get("video", [])
                if action_video_b64:
                    action_frames = _decode_b64_frames(action_video_b64)
                    # Seed the env side with the frame the prediction was conditioned on.
                    env_comparison_frames = [capture_comparison_frame(obs)]
                    comparison_windows.append((action_frames, env_comparison_frames))

            # Anchored actions are relative to the pose at the time the chunk was requested.
            if action_space == "relative":
                base_pose = _obs_to_pose(obs)

        raw_action = _format_action(action_queue.pop(0), action_dim)
        if action_space == "relative":
            if base_pose is None:
                raise RuntimeError("Missing base pose for relative action conversion")
            current_pose = _obs_to_pose(obs)
            action = _anchored_action_to_delta(
                np.asarray(raw_action, dtype=np.float32),
                base_pose,
                current_pose,
                resolved_rotation_space,
            )
            action_list = action.tolist()
        else:
            action = _framewise_action_to_delta(
                np.asarray(raw_action, dtype=np.float32),
                resolved_rotation_space,
            )
            action_list = action.tolist()

        # Map the model's gripper command to the env's [-1, 1] per the dataset convention.
        action_list = _remap_gripper(action_list, gripper_mode)

        action_log.append(action_list)
        obs, _, done, info = env.step(action_list)
        step += 1
        record_frame(obs)

        # Extend the env side of the most recent comparison window with the new frame.
        if comparison_path is not None and comparison_windows:
            comparison_windows[-1][1].append(capture_comparison_frame(obs))

        if isinstance(info, dict) and info.get("success"):
            success = True
            break
        if done:
            # ``done`` without an explicit "success" entry in ``info`` is counted as success.
            success = True if not isinstance(info, dict) else bool(info.get("success", True))
            break

    if gif_path is not None:
        _save_gif(gif_frames, gif_path, gif_fps)
    if comparison_path is not None and comparison_windows:
        _save_comparison_gif(comparison_windows, comparison_path, gif_fps)
    return EpisodeResult(success, step, None, action_log)
+
+
+def _load_initial_states(
+ task_suite: Any,
+ task_id: int,
+ *,
+ task_description: str,
+ initial_states_path: str,
+ episode_idx: int,
+) -> np.ndarray | None:
+ default_initial_states = task_suite.get_task_init_states(task_id)
+
+ if initial_states_path == "DEFAULT":
+ return np.array(default_initial_states[episode_idx])
+
+ with open(initial_states_path, "r", encoding="utf-8") as f:
+ all_initial_states = json.load(f)
+
+ task_key = task_description.replace(" ", "_")
+ episode_key = f"demo_{episode_idx}"
+ if not all_initial_states[task_key][episode_key]["success"]:
+ return None
+ return np.array(all_initial_states[task_key][episode_key]["initial_state"])
+
+
def _parse_args() -> argparse.Namespace:
    """Define the CLI of the LIBERO closed-loop evaluation and parse ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``. ``--server_url`` is the only
        required option.
    """
    parser = argparse.ArgumentParser(description="LIBERO closed-loop evaluation via Action HTTP server")
    parser.add_argument(
        "--server_url", type=str, required=True, help="Base URL for Action server (e.g., http://host:8000)"
    )
    parser.add_argument("--task_suite", type=str, default="libero_spatial", choices=sorted(TASK_MAX_STEPS.keys()))
    parser.add_argument("--num_trials_per_task", type=int, default=10)
    parser.add_argument("--task_ids", type=str, default="", help="Comma-separated task IDs to evaluate (default: all)")
    parser.add_argument("--image_size", type=int, default=256, help="Model input image size")
    parser.add_argument("--env_image_size", type=int, default=256, help="Environment render resolution")
    parser.add_argument("--action_horizon", type=int, default=0, help="Actions to execute per request (0=full chunk)")
    parser.add_argument("--action_dim", type=int, default=10, help="Action dimension for LIBERO")
    parser.add_argument(
        "--action_space",
        type=str,
        default="frame_wise_relative",
        choices=["relative", "frame_wise_relative"],
        help="Action space expected from the model (relative=anchored, frame_wise_relative=framewise deltas).",
    )
    parser.add_argument(
        "--rotation_space",
        type=str,
        default="auto",
        choices=["auto", "3d", "6d", "9d"],
        # The resolved value is used for both anchored and frame-wise decoding.
        help="Rotation representation of the model's actions (auto infers from action_dim).",
    )
    parser.add_argument(
        "--gripper_mode",
        type=str,
        default="zero_one",
        choices=["zero_one", "pm_one", "pm_one_flip"],
        help="Gripper convention of the training data: 'zero_one' = [0,1] (NVIDIA "
        "LIBERO_LeRobot_v3, mapped 1-2g); 'pm_one' = {-1,+1} (community lerobot/libero_*, "
        "pass-through); 'pm_one_flip' = {-1,+1} with inverted sign.",
    )
    parser.add_argument("--domain_name", type=str, default="libero")
    parser.add_argument(
        "--camera",
        type=str,
        default="agentview",
        help="Camera(s) to use. Single camera: 'agentview' or 'wrist'. Multiple cameras: comma-separated, e.g., 'agentview,wrist'.",
    )
    parser.add_argument("--flip_images", action="store_true", help="Flip images vertically before encoding")
    # BooleanOptionalAction derives the negative flag from each option string
    # verbatim, so "--rotate_180" alone only yields "--no-rotate_180". The
    # hyphenated alias makes the previously documented "--no-rotate-180" work
    # too; the destination stays ``rotate_180`` (taken from the first long option).
    parser.add_argument(
        "--rotate_180",
        "--rotate-180",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Rotate images by 180 degrees before encoding (default: True; pass --no-rotate_180 "
        "or --no-rotate-180 to disable)",
    )
    parser.add_argument("--warmup_steps", type=int, default=10, help="Stabilization steps with dummy actions")
    parser.add_argument("--max_steps", type=int, default=0, help="Override max steps per episode (0=default)")
    parser.add_argument("--timeout", type=float, default=30.0, help="HTTP request timeout in seconds")
    parser.add_argument("--wait_timeout", type=float, default=60.0, help="Seconds to wait for server health")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--save_gifs", action="store_true", help="Save per-episode GIFs of rendered frames")
    parser.add_argument(
        "--save_comparison",
        action="store_true",
        help="Save side-by-side comparison GIFs (Action prediction vs environment rollout)",
    )
    parser.add_argument("--gif_fps", type=int, default=20, help="Frames per second for saved GIFs")
    parser.add_argument(
        "--mujoco_gl",
        type=str,
        default="auto",
        choices=["auto", "egl", "osmesa", "glfw"],
        help="MuJoCo GL backend (auto picks egl if /dev/dri is accessible, else osmesa).",
    )
    parser.add_argument(
        "--render_gpu_device_id",
        type=int,
        default=-1,
        help="GPU device index for EGL rendering (-1 uses default device).",
    )
    parser.add_argument(
        "--initial_states_path",
        type=str,
        default="DEFAULT",
        help='Path to initial states JSON. Use "DEFAULT" for benchmark defaults.',
    )
    parser.add_argument(
        "--num_envs",
        type=int,
        default=1,
        help="Number of parallel LIBERO envs (SubprocVectorEnv). >1 runs trials in waves "
        "with ONE batched /predict_batch per control step (~num_envs x faster). 1 = serial.",
    )
    parser.add_argument("--output_dir", type=str, default="", help="Directory to save evaluation summary JSON")
    return parser.parse_args()
+
+
class _LiberoEnvFactory:
    """Picklable callable that builds one ``OffScreenRenderEnv`` inside a worker process.

    Under the ``spawn`` start method each env factory is pickled and this
    module is re-imported in the child, so the factory has to be a top-level
    class rather than a lambda or closure. Calling the instance configures the
    GL backend through environment variables and only then imports
    ``OffScreenRenderEnv``, so the render context is created in the worker.
    """

    def __init__(
        self,
        *,
        bddl_file_name: str,
        camera_heights: int,
        camera_widths: int,
        render_gpu_device_id: int,
        mujoco_gl: str,
    ) -> None:
        """Store the env construction parameters; nothing is created yet."""
        self.mujoco_gl = mujoco_gl
        self.render_gpu_device_id = render_gpu_device_id
        self.camera_widths = camera_widths
        self.camera_heights = camera_heights
        self.bddl_file_name = bddl_file_name

    def __call__(self) -> Any:
        """Configure the GL environment variables and return a new ``OffScreenRenderEnv``."""
        # A negative id (auto) is pinned to device 0: leaving EGL device
        # selection automatic races/fails across spawned workers (EGLError /
        # "'EGLGLContext' object has no attribute '_context'").
        device_id = 0 if self.render_gpu_device_id < 0 else self.render_gpu_device_id
        backend = self.mujoco_gl

        gl_settings = {"MUJOCO_GL": backend}
        if backend == "egl":
            gl_settings["PYOPENGL_PLATFORM"] = "egl"
            gl_settings["MUJOCO_EGL_DEVICE_ID"] = str(device_id)
            gl_settings["EGL_DEVICE_ID"] = str(device_id)
        elif backend == "osmesa":
            gl_settings["PYOPENGL_PLATFORM"] = "osmesa"
        os.environ.update(gl_settings)

        # Imported only after the environment is configured, because importing
        # OffScreenRenderEnv loads the GL stack.
        from libero.libero.envs import OffScreenRenderEnv as _OffScreenRenderEnv

        return _OffScreenRenderEnv(
            bddl_file_name=self.bddl_file_name,
            camera_heights=self.camera_heights,
            camera_widths=self.camera_widths,
            render_gpu_device_id=device_id,
        )
+
+
def _mark_slots_failed(
    slots: list[int],
    done: dict[int, bool],
    err: dict[int, str | None],
    nsteps: dict[int, int],
    *,
    message: str,
    step: int,
) -> None:
    """Flag every slot in ``slots`` as finished with error ``message`` at ``step``."""
    for s in slots:
        done[s] = True
        err[s] = message
        nsteps[s] = step


def _run_task_vectorized(
    task: Any,
    task_description: str,
    *,
    num_trials: int,
    num_envs: int,
    env_image_size: int,
    seed: int,
    render_gpu_device_id: int,
    client: ActionEnvironmentClient,
    cameras: list[str],
    flip_images: bool,
    rotate_180: bool,
    action_horizon: int,
    action_dim: int,
    rotation_space: str,
    gripper_mode: str,
    max_steps: int,
    warmup_steps: int,
    init_states: list[np.ndarray | None],
) -> list[dict[str, Any]]:
    """Run all ``num_trials`` of one task across parallel LIBERO envs, in waves.

    Up to ``num_envs`` envs are created with ``SubprocVectorEnv``. Each wave
    resets its envs, applies the wave's initial states, runs ``warmup_steps``
    dummy steps, and then repeats: gather observations from the envs that are
    not done, issue ONE ``client.predict_batch`` call, and execute the returned
    chunks step by step on the envs still running. Actions are always decoded
    as frame-wise deltas.

    Trials whose entry in ``init_states`` is ``None`` are not run and are
    reported as skipped.

    Args:
        task: Task object providing ``problem_folder`` and ``bddl_file``.
        task_description: Task language string (not used by this function).
        num_trials: Number of trials; ``init_states`` must have this many entries.
        num_envs: Maximum number of parallel envs.
        env_image_size: Render height and width of each env camera.
        seed: Seed passed to ``venv.seed``.
        render_gpu_device_id: GPU index forwarded to the env factory.
        client: Policy client exposing ``predict_batch``.
        cameras: Camera names used to build each observation.
        flip_images: Forwarded to ``_get_libero_images``.
        rotate_180: Forwarded to ``_get_libero_images``.
        action_horizon: Actions executed per prediction; non-positive or longer
            than the returned chunks means the full (shortest) chunk.
        action_dim: Number of leading entries kept from each predicted action.
        rotation_space: Rotation representation, resolved with
            ``_infer_rotation_space``.
        gripper_mode: Gripper convention passed to ``_remap_gripper``.
        max_steps: Upper bound on env steps per wave (including warmup).
        warmup_steps: Number of initial dummy-action steps.
        init_states: Per-trial initial states (``None`` = skip the trial).

    Returns:
        One result dict per trial, in trial order, with keys ``episode``,
        ``success``, ``steps``, ``error`` and ``elapsed_s`` (the wave's wall
        time divided evenly among its trials).
    """
    import multiprocessing as _mp

    from libero.libero.envs.venv import SubprocVectorEnv

    # LIBERO's SubprocVectorEnv defaults to the fork start method; forked children
    # inherit the parent's already-dlopen'd EGL/GL state, which corrupts per-child
    # render-context creation (EGLError / 'EGLGLContext' has no attribute '_context').
    # Force spawn so each env worker starts clean — exactly like the (working) serial
    # single-process path. spawn pickles env_fns, so the factory below is picklable.
    try:
        _mp.set_start_method("spawn", force=True)
    except RuntimeError:  # pragma: no cover - already set
        pass

    resolved_rotation_space = _infer_rotation_space(action_dim, rotation_space)
    bddl = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

    results: list[dict[str, Any]] = [None] * num_trials  # type: ignore[list-item]
    for t in range(num_trials):
        if init_states[t] is None:
            results[t] = {
                "episode": t,
                "success": False,
                "steps": 0,
                "error": "Skipped due to failed expert demo",
                "elapsed_s": 0.0,
            }
    runnable = [t for t in range(num_trials) if init_states[t] is not None]
    if not runnable:
        return results

    n = min(num_envs, len(runnable))

    mujoco_gl = os.environ.get("MUJOCO_GL", "egl")
    env_fn = _LiberoEnvFactory(
        bddl_file_name=bddl,
        camera_heights=env_image_size,
        camera_widths=env_image_size,
        render_gpu_device_id=render_gpu_device_id,
        mujoco_gl=mujoco_gl,
    )
    venv = SubprocVectorEnv([env_fn for _ in range(n)])
    try:
        venv.seed(seed)
        for w0 in range(0, len(runnable), n):
            wave = runnable[w0 : w0 + n]  # trial indices for this wave
            slots = list(range(len(wave)))  # env slots in use
            t_wave0 = time.perf_counter()
            venv.reset(id=slots)
            states = np.stack([np.asarray(init_states[t], dtype=np.float64) for t in wave])
            obs_arr = venv.set_init_state(states, id=slots)
            obs_by_slot = {s: obs_arr[i] for i, s in enumerate(slots)}
            done = {s: False for s in slots}
            succ = {s: False for s in slots}
            err: dict[int, str | None] = {s: None for s in slots}
            # Envs that never finish report the step budget as their step count.
            nsteps = {s: max_steps for s in slots}
            step = 0

            for _ in range(warmup_steps):
                act = np.stack([_get_libero_dummy_action() for _ in slots])
                obs_arr, _, _, _ = venv.step(act, id=slots)
                for i, s in enumerate(slots):
                    obs_by_slot[s] = obs_arr[i]
                step += 1

            while step < max_steps:
                active = [s for s in slots if not done[s]]
                if not active:
                    break
                obs_batch = [
                    _get_libero_images(obs_by_slot[s], cameras, flip_images=flip_images, rotate_180=rotate_180)
                    for s in active
                ]
                try:
                    chunks = client.predict_batch(obs_batch)
                except Exception as e:  # noqa: BLE001
                    _mark_slots_failed(active, done, err, nsteps, message=f"server error: {e}", step=step)
                    break
                if not chunks or len(chunks) != len(active):
                    _mark_slots_failed(
                        active, done, err, nsteps, message="bad batch response from server", step=step
                    )
                    break
                chunk_by_slot = {s: chunks[k] for k, s in enumerate(active)}

                # Never index past the shortest chunk. An empty chunk would give a
                # zero horizon, so ``step`` would never advance and this loop would
                # spin forever; fail the wave instead, like the serial path does.
                shortest = min(len(chunk) for chunk in chunks)
                if shortest == 0:
                    _mark_slots_failed(
                        active, done, err, nsteps, message="Empty action chunk from server", step=step
                    )
                    break
                horizon = min(action_horizon, shortest) if action_horizon > 0 else shortest

                for h in range(horizon):
                    cur = [s for s in slots if not done[s]]
                    if not cur or step >= max_steps:
                        break
                    env_actions = []
                    for s in cur:
                        raw = _format_action(chunk_by_slot[s][h], action_dim)
                        a = _framewise_action_to_delta(np.asarray(raw, dtype=np.float32), resolved_rotation_space)
                        env_actions.append(_remap_gripper(a.tolist(), gripper_mode))
                    obs_arr, _, d, info = venv.step(np.stack(env_actions), id=cur)
                    step += 1
                    for i, s in enumerate(cur):
                        obs_by_slot[s] = obs_arr[i]
                        di = bool(d[i])
                        ii = info[i] if isinstance(info, (list, np.ndarray)) else info
                        is_succ = bool(ii.get("success")) if isinstance(ii, dict) else False
                        if is_succ:
                            done[s], succ[s], nsteps[s] = True, True, step
                        elif di:
                            # mirror serial: done w/o explicit success defaults to success
                            done[s] = True
                            succ[s] = ii.get("success", True) if isinstance(ii, dict) else True
                            nsteps[s] = step
            per_ep_elapsed = round((time.perf_counter() - t_wave0) / max(1, len(wave)), 3)
            for s, t in zip(slots, wave):
                results[t] = {
                    "episode": t,
                    "success": bool(succ[s]),
                    "steps": int(nsteps[s]),
                    "error": err[s],
                    "elapsed_s": per_ep_elapsed,
                }
    finally:
        # Best-effort shutdown: a failure while closing must not mask the results
        # or an exception already propagating from the evaluation loop.
        try:
            venv.close()
        except Exception:  # noqa: BLE001
            pass
    return results
+
+
def main() -> None:
    """Entry point: evaluate a served action policy on a LIBERO task suite.

    Parses the CLI, configures MuJoCo rendering, waits for the model server,
    and then evaluates every selected task for ``--num_trials_per_task``
    episodes. With ``--num_envs > 1`` tasks run through
    ``_run_task_vectorized``; otherwise each episode runs serially through
    ``_run_episode``. Progress lines are printed per episode and per task, and
    when ``--output_dir`` is set a ``summary.json`` is written there. The serial
    path also writes per-episode action logs and optional GIFs.

    Episodes skipped because their custom initial state is marked as a failed
    expert demo are listed in ``episode_results`` but excluded from the
    success-rate counters, in both the serial and the vectorized path.

    Raises:
        ValueError: on inconsistent CLI options (GIF saving without an output
            directory, no/unsupported cameras, or ``--action_space relative``
            combined with ``--num_envs > 1``).
    """
    args = _parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)

    if args.save_gifs and not args.output_dir:
        raise ValueError("--save_gifs requires --output_dir to be set")
    if args.save_comparison and not args.output_dir:
        raise ValueError("--save_comparison requires --output_dir to be set")
    if args.num_envs > 1 and args.action_space != "frame_wise_relative":
        # The vectorized runner always decodes frame-wise actions; silently
        # ignoring the requested action space would produce a meaningless
        # evaluation with a misleading summary.
        raise ValueError(
            f"--num_envs > 1 only supports --action_space frame_wise_relative (got {args.action_space!r}). "
            "Use --num_envs 1 for anchored (relative) actions."
        )
    if args.num_envs > 1 and (args.save_gifs or args.save_comparison):
        print(
            "Warning: --save_gifs/--save_comparison are ignored when --num_envs > 1 "
            "(GIFs are only recorded by the serial path).",
            flush=True,
        )

    # Parse cameras from comma-separated string
    cameras = [c.strip() for c in args.camera.split(",") if c.strip()]
    if not cameras:
        raise ValueError("At least one camera must be specified")
    for cam in cameras:
        if cam not in ("agentview", "wrist"):
            raise ValueError(f"Unsupported camera={cam!r}. Use 'agentview' or 'wrist'.")

    mujoco_backend = _configure_mujoco_env(args.mujoco_gl)
    _import_libero()

    client = ActionEnvironmentClient(
        server_url=args.server_url,
        domain_name=args.domain_name,
        prompt="",
        image_size=args.image_size,
        timeout=args.timeout,
    )
    print(f"MuJoCo GL backend: {mujoco_backend}", flush=True)
    print("Waiting for model server...", flush=True)
    _wait_for_server(client, args.wait_timeout)
    print(f"Connected to model server: {client.get_info()}", flush=True)

    benchmark_dict = benchmark.get_benchmark_dict()
    task_suite = benchmark_dict[args.task_suite]()
    num_tasks = int(task_suite.n_tasks)

    if args.task_ids:
        selected_task_ids = [int(t) for t in args.task_ids.split(",") if t.strip()]
    else:
        selected_task_ids = list(range(num_tasks))

    max_steps = args.max_steps if args.max_steps > 0 else TASK_MAX_STEPS[args.task_suite]

    total_episodes = 0
    total_successes = 0
    task_results: list[dict[str, Any]] = []

    output_dir = Path(args.output_dir) if args.output_dir else None
    gif_root = output_dir / "gifs" if output_dir and args.save_gifs else None
    comparison_root = output_dir / "comparisons" if output_dir and args.save_comparison else None

    for task_id in selected_task_ids:
        task = task_suite.get_task(task_id)

        # ---- Vectorized path: N parallel envs + one batched /predict_batch per step ----
        if args.num_envs > 1:
            task_description = str(task.language)
            client.prompt = _augment_task_prompt_with_viewpoint(task_description, cameras)
            init_states = [
                _load_initial_states(
                    task_suite,
                    task_id,
                    task_description=task_description,
                    initial_states_path=args.initial_states_path,
                    episode_idx=e,
                )
                for e in range(args.num_trials_per_task)
            ]
            episode_results = _run_task_vectorized(
                task,
                task_description,
                num_trials=args.num_trials_per_task,
                num_envs=args.num_envs,
                env_image_size=args.env_image_size,
                seed=args.seed,
                render_gpu_device_id=args.render_gpu_device_id,
                client=client,
                cameras=cameras,
                flip_images=args.flip_images,
                rotate_180=args.rotate_180,
                action_horizon=args.action_horizon,
                action_dim=args.action_dim,
                rotation_space=args.rotation_space,
                gripper_mode=args.gripper_mode,
                max_steps=max_steps,
                warmup_steps=args.warmup_steps,
                init_states=init_states,
            )
            task_episodes = 0
            task_successes = 0
            for er in episode_results:
                if init_states[er["episode"]] is None:
                    # Skipped episodes never ran: report them but keep them out of
                    # the success-rate counters, exactly like the serial path.
                    print(
                        f"Task {task_id} | Episode {er['episode'] + 1}/{args.num_trials_per_task} | "
                        "success=False steps=0 "
                        f"elapsed_s={er['elapsed_s']:.1f} "
                        "error='Skipped due to failed expert demo'",
                        flush=True,
                    )
                    continue
                task_episodes += 1
                total_episodes += 1
                if er["success"]:
                    task_successes += 1
                    total_successes += 1
                print(
                    f"Task {task_id} | Episode {er['episode'] + 1}/{args.num_trials_per_task} | "
                    f"success={er['success']} steps={er['steps']} elapsed_s={er['elapsed_s']:.1f} | "
                    f"task SR {task_successes}/{task_episodes} ({100.0 * task_successes / max(1, task_episodes):.1f}%) | "
                    f"overall SR {total_successes}/{total_episodes} "
                    f"({100.0 * total_successes / max(1, total_episodes):.1f}%)",
                    flush=True,
                )
            task_success_rate = float(task_successes) / float(task_episodes) if task_episodes > 0 else 0.0
            task_results.append(
                {
                    "task_id": task_id,
                    "task_description": task_description,
                    "episodes": task_episodes,
                    "successes": task_successes,
                    "success_rate": task_success_rate,
                    "episode_results": episode_results,
                }
            )
            print(
                f"Task {task_id} summary: {task_successes}/{task_episodes} ({task_success_rate * 100:.1f}%)",
                flush=True,
            )
            continue

        # ---- Serial path: one env, one /predict per chunk ----
        env, task_description = _get_libero_env(
            task,
            resolution=args.env_image_size,
            seed=args.seed,
            render_gpu_device_id=args.render_gpu_device_id,
        )

        task_episodes = 0
        task_successes = 0
        episode_results: list[dict[str, Any]] = []

        for episode_idx in range(args.num_trials_per_task):
            episode_t0 = time.perf_counter()
            client.prompt = _augment_task_prompt_with_viewpoint(task_description, cameras)
            initial_state = _load_initial_states(
                task_suite,
                task_id,
                task_description=task_description,
                initial_states_path=args.initial_states_path,
                episode_idx=episode_idx,
            )
            if initial_state is None:
                # Recorded in the results but not counted in the success rates.
                episode_elapsed_s = time.perf_counter() - episode_t0
                episode_results.append(
                    {
                        "episode": episode_idx,
                        "success": False,
                        "steps": 0,
                        "error": "Skipped due to failed expert demo",
                        "elapsed_s": round(episode_elapsed_s, 3),
                    }
                )
                print(
                    f"Task {task_id} | Episode {episode_idx + 1}/{args.num_trials_per_task} | "
                    "success=False steps=0 "
                    f"elapsed_s={episode_elapsed_s:.1f} "
                    "error='Skipped due to failed expert demo'",
                    flush=True,
                )
                continue

            gif_path = (
                gif_root / f"task_{task_id:03d}" / f"episode_{episode_idx:03d}.gif" if gif_root is not None else None
            )
            comparison_path = (
                comparison_root / f"task_{task_id:03d}" / f"episode_{episode_idx:03d}.gif"
                if comparison_root is not None
                else None
            )
            try:
                result = _run_episode(
                    env,
                    client,
                    cameras=cameras,
                    flip_images=args.flip_images,
                    rotate_180=args.rotate_180,
                    action_horizon=args.action_horizon,
                    action_dim=args.action_dim,
                    action_space=args.action_space,
                    rotation_space=args.rotation_space,
                    gripper_mode=args.gripper_mode,
                    max_steps=max_steps,
                    warmup_steps=args.warmup_steps,
                    initial_state=initial_state,
                    gif_path=gif_path,
                    gif_fps=args.gif_fps,
                    comparison_path=comparison_path,
                )
            except Exception as exc:
                # A crashing episode is recorded as a failure (with the error
                # text) so one bad episode does not abort the whole evaluation.
                result = EpisodeResult(False, 0, str(exc), [])
            episode_elapsed_s = time.perf_counter() - episode_t0

            task_episodes += 1
            total_episodes += 1
            if result.success:
                task_successes += 1
                total_successes += 1

            episode_results.append(
                {
                    "episode": episode_idx,
                    "success": result.success,
                    "steps": result.steps,
                    "error": result.error,
                    "elapsed_s": round(episode_elapsed_s, 3),
                }
            )

            # Save per-episode action log as JSON
            if output_dir is not None and result.actions:
                action_log_dir = output_dir / "actions" / f"task_{task_id:03d}"
                action_log_dir.mkdir(parents=True, exist_ok=True)
                action_log_path = action_log_dir / f"episode_{episode_idx:03d}.json"
                action_log_path.write_text(
                    json.dumps(result.actions, indent=2),
                    encoding="utf-8",
                )

            client.notify_next_episode()

            print(
                f"Task {task_id} | Episode {episode_idx + 1}/{args.num_trials_per_task} | "
                f"success={result.success} steps={result.steps} elapsed_s={episode_elapsed_s:.1f} | "
                f"task SR {task_successes}/{task_episodes} ({100.0 * task_successes / max(1, task_episodes):.1f}%) | "
                f"overall SR {total_successes}/{total_episodes} ({100.0 * total_successes / max(1, total_episodes):.1f}%)",
                flush=True,
            )

        task_success_rate = float(task_successes) / float(task_episodes) if task_episodes > 0 else 0.0
        task_results.append(
            {
                "task_id": task_id,
                "task_description": task_description,
                "episodes": task_episodes,
                "successes": task_successes,
                "success_rate": task_success_rate,
                "episode_results": episode_results,
            }
        )
        print(
            f"Task {task_id} summary: {task_successes}/{task_episodes} ({task_success_rate * 100:.1f}%)",
            flush=True,
        )
        # Close the env (and its EGL/MuJoCo render context) before the next task.
        # Leaving it open leaks one EGL context per task and hangs after ~8 tasks.
        try:
            env.close()
        except Exception:
            pass

    overall_success_rate = float(total_successes) / float(total_episodes) if total_episodes > 0 else 0.0
    summary = {
        "task_suite": args.task_suite,
        "total_episodes": total_episodes,
        "total_successes": total_successes,
        "overall_success_rate": overall_success_rate,
        "num_trials_per_task": args.num_trials_per_task,
        "selected_task_ids": selected_task_ids,
        "action_space": args.action_space,
        "rotation_space": _infer_rotation_space(args.action_dim, args.rotation_space),
        "action_dim": args.action_dim,
        "task_results": task_results,
    }

    print(
        f"Overall success rate: {total_successes}/{total_episodes} ({overall_success_rate * 100:.1f}%)",
        flush=True,
    )

    if output_dir is not None:
        output_dir.mkdir(parents=True, exist_ok=True)
        summary_path = output_dir / "summary.json"
        summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
        print(f"Saved summary to {summary_path}", flush=True)
+
+if __name__ == "__main__":
+ main()
diff --git a/cosmos_framework/utils/vfm/model_loader.py b/cosmos_framework/utils/vfm/model_loader.py
index 6e6a0dd..51140b3 100644
--- a/cosmos_framework/utils/vfm/model_loader.py
+++ b/cosmos_framework/utils/vfm/model_loader.py
@@ -252,10 +252,17 @@ def _load_model(
keys_to_skip_loading=keys_to_skip_loading or [],
)
+ # Single-rank load (e.g. the action-policy inference server): force no_dist so
+ # ``dcp.load`` skips the collective ``gather_object`` over the load plan, which
+ # pickles the plan and can fail on training/EMA DCPs. Multi-rank loads keep the
+ # default distributed path.
+ no_dist = not (dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1)
+
dcp.load(
state_dict=state_dict,
storage_reader=storage_reader,
planner=load_planner,
+ no_dist=no_dist,
)
log.info(f"Successfully loaded model from {checkpoint_path}")
diff --git a/docs/action_policy_libero_sft.md b/docs/action_policy_libero_sft.md
new file mode 100644
index 0000000..660de11
--- /dev/null
+++ b/docs/action_policy_libero_sft.md
@@ -0,0 +1,109 @@
+# Cosmos3-Nano LIBERO-10 action-policy SFT
+
+Full SFT of the public `nvidia/Cosmos3-Nano` base into a LIBERO-10 action
+policy: vision + language in, action chunks out.
+
+| Piece | Path |
+| ---------------- | ----------------------------------------------------------------------------------------------- |
+| Dataset | `cosmos_framework/data/vfm/action/datasets/libero_lerobot_dataset.py` (`LIBEROLeRobotDataset`) |
+| SFT wrapper | `get_action_libero_sft_dataset` in `.../datasets/action_sft_dataset.py` |
+| Norm stats | `.../datasets/stats/libero_native_frame_wise_relative_rot6d.json` |
+| Experiment | `cosmos_framework/configs/base/experiment/action/posttrain_config/action_policy_libero_nano.py` |
+| Run TOML | `examples/toml/sft_config/action_policy_libero_repro.toml` |
+| Launch | `examples/launch_sft_action_policy_libero.sh` |
+| Inference server | `cosmos_framework/scripts/action_policy_server_libero.py` |
+| Closed-loop eval | `cosmos_framework/simulation/libero/closed_loop_eval.py` |
+
+## 1. Data
+
+`LIBEROLeRobotDataset` reads a local LeRobot dir (`LIBERO_ROOT`). Use the 20 FPS
+[`nvidia/LIBERO_LeRobot_v3`](https://huggingface.co/datasets/nvidia/LIBERO_LeRobot_v3),
+which the bundled `quantile_rot` stats and the 20 Hz eval assume. Train on
+`libero_10` alone:
+
+```bash
+hf download nvidia/LIBERO_LeRobot_v3 --repo-type dataset \
+ --include 'libero_10/**' --local-dir /LIBERO_LeRobot_v3
+export LIBERO_ROOT=/LIBERO_LeRobot_v3/libero_10
+```
+
+Actions are `frame_wise_relative` rot6d (10D = pos 3 + rot6d 6 + gripper 1),
+`concat_view` (third-person + wrist, each 256×256 → 256×512), `quantile_rot`
+normalized. The pipeline snaps the 256×512 concat to a 192×320 model canvas; the
+eval server reproduces the same snap (§4).
+
+## 2. Train
+
+```bash
+export LD_LIBRARY_PATH='' # NGC container: avoid torch._C import error
+export LIBERO_ROOT=/path/to/libero_10_lerobot
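+export BASE_CHECKPOINT_PATH=/path/to/Cosmos3-Nano   # launcher default: examples/checkpoints/Cosmos3-Nano
+export WAN_VAE_PATH=/path/to/Wan2.2_VAE.pth         # launcher default: examples/checkpoints/wan22_vae/Wan2.2_VAE.pth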
+export IMAGINAIRE_OUTPUT_ROOT=/path/to/output_root
+
+bash examples/launch_sft_action_policy_libero.sh # HSDP 2x8; set NNODES/NODE_RANK/MASTER_ADDR per node
+```
+
+Recipe knobs live in `action_policy_libero_nano`; the TOML sets run-level scalars
+(lr 5e-5, warmup 500, cycle 16000, `save_iter=500`, HSDP 2x8). Global batch is
+2048 = `max_samples_per_batch` 128 × 16 ranks × grad_accum 1.
+
+## 3. Closed-loop eval
+
+Start the policy server on a **trained** checkpoint (the base DCP has no action
+heads), then run the LIBERO simulator client against it.
+
+```bash
+python -m cosmos_framework.scripts.action_policy_server_libero \
+ --experiment action_policy_libero_nano \
+ --experiment-overrides "model.config.tokenizer.vae_path=$WAN_VAE_PATH" \
+ --checkpoint-path /checkpoints/iter_000001500 \
+ --action-normalization quantile_rot \
+ --action-stats-path cosmos_framework/data/vfm/action/datasets/stats/libero_native_frame_wise_relative_rot6d.json \
+ --raw-action-dim 10 --fps 20 --port 8000
+```
+
+The LIBERO sim needs a separate venv (robosuite/mujoco pins conflict with the
+training env):
+
+```bash
+# Optional — only on a headless container without working GPU EGL:
+# export NVIDIA_DRIVER_CAPABILITIES=all
+# apt-get install -y libegl1 libglvnd0 libgl1 libglib2.0-0 ffmpeg
+# mkdir -p /usr/share/glvnd/egl_vendor.d
+# echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libEGL_nvidia.so.0"}}' \
+# > /usr/share/glvnd/egl_vendor.d/10_nvidia.json
+
+uv venv --python 3.10 .libenv && VV=.libenv/bin/python
+git clone https://github.com/Lifelong-Robot-Learning/LIBERO.git && \
+ uv pip install -p $VV -e LIBERO -r LIBERO/requirements.txt
+uv pip install -p $VV "robosuite==1.4.1" "mujoco==2.3.7" "torch<2.6" loguru requests scipy pillow numpy
+mkdir -p ~/.libero && touch ~/.libero/config.yaml
+RS=$($VV -c "import robosuite,os;print(os.path.dirname(robosuite.__file__))"); $VV "$RS/scripts/setup_macros.py"
+$VV -c "from libero.libero import set_libero_default_path; set_libero_default_path()"
+
+MUJOCO_GL=egl PYTHONPATH=$PWD:$PWD/LIBERO $VV \
+ cosmos_framework/simulation/libero/closed_loop_eval.py \
+ --server_url http://localhost:8000 \
+ --task_suite libero_10 --num_trials_per_task 50 --num_envs 8 \
+ --camera agentview,wrist --image_size 256 \
+ --action_space frame_wise_relative --rotation_space 6d --action_dim 10 \
+ --output_dir results/libero_closed_loop_10
+```
+
+## 4. Heads-up
+
+- **Lower-memory GPUs** — reduce the per-rank batch:
+  `--opts dataloader_train.max_samples_per_batch=64` (then raise
+  `data_parallel_replicate_degree` from 2 to 4 to keep global batch 2048).
+
+Eval parity — the client/server already handle these; verify if accuracy is low:
+
+- **Concat layout** — run with `--camera agentview,wrist --image_size 256` so the
+ 256×512 concat matches training (the server snaps it to 192×320 identically).
+- **Gripper** — model emits `[0, 1]`; the env wants `[-1, 1]` (negative = open).
+ The client applies `1 − 2·g`; flip the sign if the gripper never opens.
+- **Image orientation** — sim frames are rotated 180° vs training; the client
+ rotates them back.
+- **Normalization** — start the server with `--action-normalization quantile_rot`
+ and the bundled rot6d stats, or actions come out at the wrong scale.
diff --git a/examples/launch_sft_action_policy_libero.sh b/examples/launch_sft_action_policy_libero.sh
new file mode 100755
index 0000000..29a9a16
--- /dev/null
+++ b/examples/launch_sft_action_policy_libero.sh
@@ -0,0 +1,46 @@
+#!/usr/bin/env bash
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: OpenMDW-1.1
+
+# Structured-TOML launch for action_policy_libero_nano — Cosmos3-Nano LIBERO
+# action-policy SFT (HSDP, full SFT). Drives cosmos_framework.scripts.train
+# against examples/toml/sft_config/action_policy_libero_repro.toml.
+#
+# Point LIBERO_ROOT at the libero_10 suite ONLY. Use the 20 FPS
+# nvidia/LIBERO_LeRobot_v3. The default recipe is HSDP 2x8 (global batch 2048);
+# set NNODES/NODE_RANK/MASTER_ADDR per node.
+# See docs/action_policy_libero_sft.md.
+#
+# Required env vars:
+# LIBERO_ROOT local LIBERO-10 LeRobot dataset dir, e.g. /libero_10 (no default)
+# Optional env vars (defaults below; override to relocate data/checkpoints):
+# BASE_CHECKPOINT_PATH default: examples/checkpoints/Cosmos3-Nano
+# WAN_VAE_PATH default: examples/checkpoints/wan22_vae/Wan2.2_VAE.pth
+# HF_TOKEN if any tokenizer download requires gated HF access
+# OUTPUT_ROOT default: outputs/train
+#
+# Pre-sync the 20 FPS suite once:
+#   hf download nvidia/LIBERO_LeRobot_v3 --repo-type dataset --include 'libero_10/**' --local-dir /LIBERO_LeRobot_v3
+#   export LIBERO_ROOT=/LIBERO_LeRobot_v3/libero_10
+#
+# Usage (HSDP 2x8; set NNODES/NODE_RANK/MASTER_ADDR per node):
+# LIBERO_ROOT=/libero_10 bash examples/launch_sft_action_policy_libero.sh
+
+TOML_FILE="examples/toml/sft_config/action_policy_libero_repro.toml"
+: "${BASE_CHECKPOINT_PATH:=examples/checkpoints/Cosmos3-Nano}"  # `:=` applies the default when unset or empty
+
+# LIBEROLeRobotDataset reads ${oc.env:LIBERO_ROOT} directly (a LOCAL LeRobot dir);
+# export it so torchrun (launched in this shell) inherits it.
+export LIBERO_ROOT="${LIBERO_ROOT:-}"
+
+EXTRA_DATASET_CHECK='[[ -f "$LIBERO_ROOT/meta/info.json" ]] || { echo "ERROR: LIBERO_ROOT must be a local LeRobot dir containing meta/info.json (got: '\''$LIBERO_ROOT'\''). Pre-sync: hf download nvidia/LIBERO_LeRobot_v3 --repo-type dataset --include '\''libero_10/**'\'' --local-dir /LIBERO_LeRobot_v3 (then LIBERO_ROOT=/LIBERO_LeRobot_v3/libero_10). See docs/action_policy_libero_sft.md" >&2; exit 1; }'
+
+# Extra Hydra overrides from the environment: a space-separated string word-split into
+# the TAIL_OVERRIDES array. An exported string survives `bash <this script>` (a child
+# process), unlike a TAIL_OVERRIDES array set in your shell. Use it for smoke runs,
+# e.g. EXTRA_TAIL_OVERRIDES="trainer.max_iter=5 job.wandb_mode=offline".
+TAIL_OVERRIDES=(
+    ${EXTRA_TAIL_OVERRIDES:-}
+)
+
+source "$(dirname "${BASH_SOURCE[0]}")/_sft_launcher_common.sh"
diff --git a/examples/toml/sft_config/action_policy_libero_repro.toml b/examples/toml/sft_config/action_policy_libero_repro.toml
new file mode 100644
index 0000000..a0c49c7
--- /dev/null
+++ b/examples/toml/sft_config/action_policy_libero_repro.toml
@@ -0,0 +1,46 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: OpenMDW-1.1
+
+# LIBERO-10 action-policy SFT run config for the `action_policy_libero_nano`
+# experiment. Train on libero_10 alone (HSDP 2x8, global batch 2048).
+# Env: LIBERO_ROOT, BASE_CHECKPOINT_PATH, WAN_VAE_PATH, IMAGINAIRE_OUTPUT_ROOT.
+# See docs/action_policy_libero_sft.md.
+
+[job]
+task = "vfm"
+experiment = "action_policy_libero_nano"
+project = "cosmos3_action_libero"
+group = "action_sft"
+name = "action_policy_libero_repro"
+wandb_mode = "online"
+
+[model]
+precision = "bfloat16"
+max_num_tokens_after_packing = 74000
+
+[model.parallelism]
+data_parallel_shard_degree = 8
+data_parallel_replicate_degree = 2 # HSDP 2x8 = 16 ranks (2 nodes); minimum for gbs 2048 at grad_accum 1
+
+[model.activation_checkpointing]
+mode = "selective"
+save_ops_regex = ["fmha"]
+
+[model.tokenizer]
+vae_path = "${oc.env:WAN_VAE_PATH}"
+
+[optimizer]
+lr = 5.0e-05
+
+[scheduler]
+cycle_lengths = [16000]
+warm_up_steps = [500]
+
+[trainer]
+max_iter = 2000
+logging_iter = 50
+grad_accum_iter = 1 # global batch = max_samples 128 x (shard 8 x replicate 2) x 1 = 2048
+
+[checkpoint]
+load_path = "${oc.env:BASE_CHECKPOINT_PATH}"
+save_iter = 500