diff --git a/.gitignore b/.gitignore
index 7e99e36..81fd351 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,21 @@
-*.pyc
\ No newline at end of file
+*.pyc
+
+# datasets / models / generated artifacts (do not commit)
+models/
+arms_lerobot/
+libero_10/
+pick-n-place-sq-lerobot-v21/
+pick-n-place-sq-lerobot-v21.tgz
+_frame_samples/
+
+# generated outputs
+arms/generated_*/
+arms/test_generated*/
+arms/submit_result/
+
+# local logs / debug outputs
+arms/*.log
+arms/submit_result_step2200*/
+arms/sample_result/
+arms/test/
+arms/train/
\ No newline at end of file
diff --git a/arms/arms_to_lingbot_input.py b/arms/arms_to_lingbot_input.py
new file mode 100644
index 0000000..329e97a
--- /dev/null
+++ b/arms/arms_to_lingbot_input.py
@@ -0,0 +1,129 @@
+import csv
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+import cv2
+import numpy as np
+
+
+@dataclass(frozen=True)
+class ArmsEpisode:
+    root: Path
+    video_path: Path
+    instruction_path: Path
+    joint_path: Path
+    action_path: Path
+
+
+def load_arms_episode(episode_dir: str | Path) -> ArmsEpisode:
+    root = Path(episode_dir)
+    return ArmsEpisode(
+        root=root,
+        video_path=root / "video.mp4",
+        instruction_path=root / "instruction.txt",
+        joint_path=root / "joint.txt",
+        action_path=root / "action.txt",
+    )
+
+
+def _read_single_line_text(path: Path) -> str:
+    return path.read_text(encoding="utf-8").strip()
+
+
+def _read_csv_with_first_index_col(path: Path) -> Tuple[np.ndarray, List[str]]:
+    """
+    The arms joint/action files are CSV:
+    - the first column is a time index (its header is sometimes empty or "Unnamed: 0")
+    - the remaining columns are the actual joint/finger dimensions
+
+    Returns:
+    - data: shape [T, D], without the index column
+    - columns: the D column names (without the index column)
+    """
+    with path.open("r", newline="") as f:
+        reader = csv.reader(f)
+        header = next(reader)
+        rows = [r for r in reader if len(r) > 0]
+
+    # drop the first index column
+    columns = header[1:]
+    data = np.asarray([[float(x) for x in r[1:]] for r in rows], dtype=np.float32)
+    return data, columns
+
+
+def decode_video_frames(
+    video_path: Path,
+    frame_indices: List[int],
+) -> List[np.ndarray]:
+    cap = cv2.VideoCapture(str(video_path))
+    if not cap.isOpened():
+        raise RuntimeError(f"Cannot open video: {video_path}")
+
+    frames: List[np.ndarray] = []
+    for idx in frame_indices:
+        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
+        ok, frame_bgr = cap.read()
+        if not ok or frame_bgr is None:
+            cap.release()
+            raise RuntimeError(f"Cannot read frame={idx} from {video_path}")
+        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
+        frames.append(frame_rgb)
+
+    cap.release()
+    return frames
+
+
+def build_obs_dict_from_arms(
+    episode: ArmsEpisode,
+    history_len: int = 16,
+    start_frame: int = 0,
+    cam_key: str = "front",
+) -> Dict[str, Any]:
+    """
+    Build a minimal structure resembling the obs dict that wan_va_server.py expects:
+    - obs['obs']: List[Dict[str, np.ndarray]], one entry per timestep, one HWC RGB uint8 image per camera
+    - obs['state']: np.ndarray, the history of states (here, the raw D-dim joint vectors)
+    - obs['action_history']: np.ndarray, the history of actions (for your own KV cache / conditioning)
+    - obs['prompt']: the instruction text
+
+    Notes:
+    - This script only reads and organizes the arms data; it does not call the LingBot-VA server.
+    - The state/action shapes depend on your downstream loader/adapter; here we provide the raw [T, D].
+    """
+    prompt = _read_single_line_text(episode.instruction_path)
+    joint, joint_cols = _read_csv_with_first_index_col(episode.joint_path)
+    action, action_cols = _read_csv_with_first_index_col(episode.action_path)
+
+    T = joint.shape[0]
+    if action.shape[0] != T:
+        raise ValueError(f"joint/action length mismatch: joint={T}, action={action.shape[0]}")
+
+    end = min(start_frame + history_len, T)
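+    # Added sanity check (assumption: callers may pass an out-of-range start_frame).
+    # Worked example: with T=100, start_frame=90, history_len=16 the window clamps
+    # to [90, 100) and yields only 10 steps, so downstream code must not assume a
+    # fixed history length.
+    if end <= start_frame:
+        raise ValueError(f"start_frame={start_frame} out of range for episode length T={T}")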
+    frame_ids = list(range(start_frame, end))
+
+    frames = decode_video_frames(episode.video_path, frame_ids)
+    obs_seq = [{cam_key: f.astype(np.uint8)} for f in frames]
+
+    obs: Dict[str, Any] = {
+        "obs": obs_seq,
+        "state": joint[start_frame:end],
+        "action_history": action[start_frame:end],
+        "prompt": prompt,
+        "meta": {
+            "joint_columns": joint_cols,
+            "action_columns": action_cols,
+            "start_frame": start_frame,
+            "history_len": history_len,
+        },
+    }
+    return obs
+
+
+if __name__ == "__main__":
+    ep = load_arms_episode("arms/train/1_1")
+    obs = build_obs_dict_from_arms(ep, history_len=16, start_frame=0, cam_key="front")
+    print("prompt:", obs["prompt"])
+    print("obs len:", len(obs["obs"]), "frame shape:", obs["obs"][0]["front"].shape)
+    print("state shape:", obs["state"].shape, "action_history shape:", obs["action_history"].shape)
+    print("action_dim:", obs["action_history"].shape[-1])
diff --git a/arms/compute_arms_norm_stat.py b/arms/compute_arms_norm_stat.py
new file mode 100644
index 0000000..7b37b40
--- /dev/null
+++ b/arms/compute_arms_norm_stat.py
@@ -0,0 +1,55 @@
+"""
+Compute quantile normalization statistics (q01/q99) for the arms_lerobot actions.
+
+Output:
+    <dataset-root>/norm_stat.json
+
+Usage:
+    conda activate gmr
+    python arms/compute_arms_norm_stat.py --dataset-root arms_lerobot
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+
+
+def main() -> None:
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--dataset-root", type=str, required=True)
+    ap.add_argument("--q-low", type=float, default=0.01)
+    ap.add_argument("--q-high", type=float, default=0.99)
+    args = ap.parse_args()
+
+    dataset_root = Path(args.dataset_root)
+    data_dir = dataset_root / "data" / "chunk-000"
+    files = sorted(data_dir.glob("episode_*.parquet"))
+    if not files:
+        raise RuntimeError(f"No parquet files under {data_dir}")
+
+    # collect all action rows
+    acts = []
+    for p in files:
+        df = pd.read_parquet(p, columns=["action"])
+        a = np.stack(df["action"].to_list()).astype(np.float32)  # [T, D]
+        acts.append(a)
+    acts_all = np.concatenate(acts, axis=0)  # [N, D]
+
+    q01 = np.quantile(acts_all, args.q_low, axis=0).tolist()
+    q99 = np.quantile(acts_all, args.q_high, axis=0).tolist()
+
+    out = {"q01": q01, "q99": q99}
+    out_path = dataset_root / "norm_stat.json"
+    out_path.write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding="utf-8")
+    print(f"Wrote: {out_path}")
+    print("action_dim:", acts_all.shape[1], "rows:", acts_all.shape[0])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/arms/convert_arms_to_lerobot.py b/arms/convert_arms_to_lerobot.py
new file mode 100644
index 0000000..d700ae6
--- /dev/null
+++ b/arms/convert_arms_to_lerobot.py
@@ -0,0 +1,264 @@
+"""
+Convert the arms dataset into a LeRobot-style directory layout, staying as close as
+possible to what this repo's README / LeRobot expect.
+
+Input (one folder per trajectory):
+    <arms-root>/<episode-id>/
+    - video.mp4
+    - instruction.txt
+    - joint.txt
+    - action.txt
+
+Output (LeRobot dataset directory):
+    <out-root>/
+    - meta/episodes.jsonl
+    - videos/chunk-000/observation.images.front/episode_000000.mp4
+    - data/chunk-000/episode_000000.parquet
+
+Notes:
+    - This script keeps the arms action/state dimensionality unchanged: after dropping
+      the first index column, the CSVs have D=26.
+    - The parquet schema uses "vector columns" similar to existing LeRobot data:
+        action: float32[D]
+        observation.state: float32[D]
+      plus index/timestamp columns.
+    - This script does not extract latents (.pth). Latents can be extracted later in
+      batch with the Wan2.2 VAE.
+
+Example (preferably in an environment with pyarrow, e.g. the gmr conda env):
+    conda activate gmr
+    python arms/convert_arms_to_lerobot.py --arms-root arms/train --out-root arms_lerobot
+"""
+
+from __future__ import annotations
+
+import argparse
+import csv
+import json
+import shutil
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import numpy as np
+import pyarrow as pa
+import pyarrow.parquet as pq
+
+try:
+    import cv2  # optional, for fps
+except Exception:  # pragma: no cover
+    cv2 = None
+
+
+@dataclass(frozen=True)
+class ArmsEpisodePaths:
+    episode_dir: Path
+    video_path: Path
+    instruction_path: Path
+    joint_path: Path
+    action_path: Path
+
+
+def _read_single_line_text(path: Path) -> str:
+    return path.read_text(encoding="utf-8").strip()
+
+
+def _read_csv_with_first_index_col(path: Path) -> Tuple[np.ndarray, List[str], np.ndarray]:
+    """
+    Returns:
+        data: [T, D] float32 (without the first index column)
+        columns: the D column names (without the first index column)
+        index: [T] int (the first index column)
+    """
+    with path.open("r", newline="") as f:
+        reader = csv.reader(f)
+        header = next(reader)
+        rows = [r for r in reader if len(r) > 0 and any(x.strip() for x in r)]
+
+    columns = header[1:]
+    idx = np.asarray([int(float(r[0])) for r in rows], dtype=np.int32)
+    data = np.asarray([[float(x) for x in r[1:]] for r in rows], dtype=np.float32)
+    return data, columns, idx
+
+
+def _get_video_fps(video_path: Path) -> float | None:
+    if cv2 is None:
+        return None
+    cap = cv2.VideoCapture(str(video_path))
+    if not cap.isOpened():
+        return None
+    fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
+    cap.release()
+    return fps if fps > 0 else None
+
+
+def _discover_arms_episodes(arms_root: Path) -> List[ArmsEpisodePaths]:
+    episodes: List[ArmsEpisodePaths] = []
+    for child in sorted(arms_root.iterdir()):
+        if not child.is_dir():
+            continue
+        video = child / "video.mp4"
+        instr = child / "instruction.txt"
+        joint = child / "joint.txt"
+        action = child / "action.txt"
+        if video.exists() and instr.exists() and joint.exists() and action.exists():
+            episodes.append(
+                ArmsEpisodePaths(
+                    episode_dir=child,
+                    video_path=video,
+                    instruction_path=instr,
+                    joint_path=joint,
+                    action_path=action,
+                )
+            )
+    return episodes
+
+
+def _ensure_dirs(out_root: Path) -> Dict[str, Path]:
+    meta = out_root / "meta"
+    videos = out_root / "videos" / "chunk-000" / "observation.images.front"
+    data = out_root / "data" / "chunk-000"
+    for p in (meta, videos, data):
+        p.mkdir(parents=True, exist_ok=True)
+    return {"meta": meta, "videos": videos, "data": data}
+
+
+def _write_episode_parquet(
+    out_parquet: Path,
+    episode_index: int,
+    task_index: int,
+    action: np.ndarray,  # [T, D]
+    state: np.ndarray,  # [T, D]
+    frame_index: np.ndarray,  # [T]
+    fps: float | None,
+) -> None:
+    T, D = action.shape
+    if state.shape != (T, D):
+        raise ValueError(f"state shape {state.shape} != action shape {action.shape}")
+
+    # timestamps: seconds; if fps unknown, store 0..T-1
+    if fps is None or fps <= 0:
+        ts = frame_index.astype(np.float32)
+    else:
+        ts = frame_index.astype(np.float32) / np.float32(fps)
+
+    # Use FixedSizeList for vector columns (consistent with how we read other datasets).
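+    # Layout sketch: action.reshape(-1) flattens row-major to
+    # [a(0,0)..a(0,D-1), a(1,0)..a(1,D-1), ...], and FixedSizeListArray groups
+    # every D consecutive floats back into one list per timestep, so
+    # table["action"][t] round-trips to action[t] exactly.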
+    action_arr = pa.FixedSizeListArray.from_arrays(pa.array(action.reshape(-1), type=pa.float32()), D)
+    state_arr = pa.FixedSizeListArray.from_arrays(pa.array(state.reshape(-1), type=pa.float32()), D)
+
+    table = pa.table(
+        {
+            "episode_index": pa.array(np.full((T,), episode_index, dtype=np.int32)),
+            "index": pa.array(np.arange(T, dtype=np.int32)),
+            "frame_index": pa.array(frame_index.astype(np.int32)),
+            "task_index": pa.array(np.full((T,), task_index, dtype=np.int32)),
+            "timestamp": pa.array(ts.astype(np.float32)),
+            "action": action_arr,
+            "observation.state": state_arr,
+        }
+    )
+    pq.write_table(table, out_parquet)
+
+
+def convert(arms_root: Path, out_root: Path) -> None:
+    eps = _discover_arms_episodes(arms_root)
+    if not eps:
+        raise RuntimeError(f"No valid episodes found under {arms_root}")
+
+    dirs = _ensure_dirs(out_root)
+    episodes_jsonl = dirs["meta"] / "episodes.jsonl"
+
+    # overwrite episodes.jsonl
+    if episodes_jsonl.exists():
+        episodes_jsonl.unlink()
+
+    id_map: Dict[int, str] = {}
+    action_columns_ref: List[str] | None = None
+    joint_columns_ref: List[str] | None = None
+
+    for episode_index, ep in enumerate(eps):
+        task = _read_single_line_text(ep.instruction_path)
+        state, joint_cols, state_idx = _read_csv_with_first_index_col(ep.joint_path)
+        action, action_cols, action_idx = _read_csv_with_first_index_col(ep.action_path)
+
+        # sanity checks
+        if state.shape[0] != action.shape[0]:
+            raise ValueError(f"{ep.episode_dir}: joint/action length mismatch")
+        if not np.array_equal(state_idx, action_idx):
+            raise ValueError(f"{ep.episode_dir}: joint/action time index mismatch")
+
+        if joint_columns_ref is None:
+            joint_columns_ref = joint_cols
+        elif joint_cols != joint_columns_ref:
+            raise ValueError(f"{ep.episode_dir}: joint columns differ from first episode")
+
+        if action_columns_ref is None:
+            action_columns_ref = action_cols
+        elif action_cols != action_columns_ref:
+            raise ValueError(f"{ep.episode_dir}: action columns differ from first episode")
+
+        T = int(state.shape[0])
+        fps = _get_video_fps(ep.video_path)
+
+        # write parquet
+        out_parquet = dirs["data"] / f"episode_{episode_index:06d}.parquet"
+        _write_episode_parquet(
+            out_parquet=out_parquet,
+            episode_index=episode_index,
+            task_index=0,
+            action=action,
+            state=state,
+            frame_index=state_idx,
+            fps=fps,
+        )
+
+        # copy video
+        out_video = dirs["videos"] / f"episode_{episode_index:06d}.mp4"
+        if not out_video.exists():
+            shutil.copy2(ep.video_path, out_video)
+
+        # write episodes.jsonl line
+        line = {
+            "episode_index": episode_index,
+            "tasks": [task],
+            "length": T,
+            "action_config": [
+                {
+                    "start_frame": int(state_idx[0]),
+                    "end_frame": int(state_idx[-1]) + 1,  # end_frame is exclusive in many tools
+                    "action_text": task,
+                    "skill": "",
+                }
+            ],
+        }
+        with episodes_jsonl.open("a", encoding="utf-8") as f:
+            f.write(json.dumps(line, ensure_ascii=False) + "\n")
+
+        id_map[episode_index] = ep.episode_dir.name
+
+    # write a small manifest for debugging
+    manifest = {
+        "arms_root": str(arms_root),
+        "num_episodes": len(eps),
+        "action_dim": len(action_columns_ref or []),
+        "state_dim": len(joint_columns_ref or []),
+        "columns": {
+            "action": action_columns_ref,
+            "state": joint_columns_ref,
+        },
+        "episode_id_map": id_map,
+    }
+    (out_root / "manifest.json").write_text(json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8")
+
+
+def main() -> None:
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--arms-root", type=str, default="arms/train", help="the arms/train directory")
+    ap.add_argument("--out-root", type=str, default="arms_lerobot", help="output LeRobot dataset directory")
+    args = ap.parse_args()
+
+    convert(Path(args.arms_root), Path(args.out_root))
+    print(f"Done. Wrote LeRobot dataset to: {args.out_root}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/arms/decode_test_latents_to_video.py b/arms/decode_test_latents_to_video.py
new file mode 100644
index 0000000..ac2dff3
--- /dev/null
+++ b/arms/decode_test_latents_to_video.py
@@ -0,0 +1,167 @@
+"""
+Decode the pred_latents.pt produced by run_arms_test_infer_latents_only.py into video.mp4.
+
+Only the VAE is loaded (no transformer), so GPU memory usage is low and decoding on GPU is fast.
+"""
+
+from __future__ import annotations
+
+import argparse
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+from diffusers import AutoencoderKLWan
+from diffusers.utils import export_to_video
+from diffusers.video_processor import VideoProcessor
+
+
+def _probe_video(path: Path) -> tuple[int, int, float]:
+    cap = cv2.VideoCapture(str(path))
+    if not cap.isOpened():
+        cap.release()
+        raise RuntimeError(f"Cannot open video: {path}")
+    fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
+    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
+    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
+    cap.release()
+    return w, h, fps
+
+
+def _count_target_frames(ep_dir: Path) -> int | None:
+    # Prefer the action.txt line count (matches the submission expectation)
+    action_path = ep_dir / "action.txt"
+    if not action_path.exists():
+        return None
+    n_lines = len(action_path.read_text(encoding="utf-8").splitlines())
+    if n_lines <= 1:
+        return None
+    return n_lines - 1
+
+
+def _resample_video_np(video: np.ndarray, target_frames: int) -> np.ndarray:
+    # video: [F,H,W,C] float/uint8
+    if target_frames <= 0:
+        raise ValueError("target_frames must be > 0")
+    f = int(video.shape[0])
+    if f == target_frames:
+        return video
+    if f <= 1:
+        return np.repeat(video, target_frames, axis=0)
+
+    # Linear interpolation in time
+    t_src = np.linspace(0.0, 1.0, num=f, endpoint=True, dtype=np.float32)
+    t_tgt = np.linspace(0.0, 1.0, num=target_frames, endpoint=True, dtype=np.float32)
+    out = np.empty((target_frames, *video.shape[1:]), dtype=video.dtype)
+    for i, t in enumerate(t_tgt):
+        j = int(np.searchsorted(t_src, t, side="right") - 1)
+        j = max(0, min(j, f - 2))
+        t0, t1 = float(t_src[j]), float(t_src[j + 1])
+        w = 0.0 if t1 == t0 else (float(t) - t0) / (t1 - t0)
+        out[i] = (1.0 - w) * video[j] + w * video[j + 1]
+    return out
+
+
+def _match_length(video: np.ndarray, target_frames: int) -> np.ndarray:
+    """Prefer trimming over interpolation to preserve sharpness."""
+    f = int(video.shape[0])
+    if f == target_frames:
+        return video
+    if f > target_frames:
+        return video[:target_frames]
+    # f < target_frames
+    if f <= 1:
+        return np.repeat(video, target_frames, axis=0)
+    # If short, interpolate (less harmful than upscaling everything)
+    return _resample_video_np(video, target_frames)
+
+
+def _resize_video_np(video: np.ndarray, target_w: int, target_h: int) -> np.ndarray:
+    if target_w <= 0 or target_h <= 0:
+        return video
+    f, h, w, c = video.shape
+    if w == target_w and h == target_h:
+        return video
+    out = np.empty((f, target_h, target_w, c), dtype=video.dtype)
+    for i in range(f):
+        out[i] = cv2.resize(video[i], (target_w, target_h), interpolation=cv2.INTER_AREA)
+    return out
+
+
+@torch.no_grad()
+def main() -> None:
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--pred-root", type=str, required=True, help="output root containing one subdirectory per episode")
+    ap.add_argument("--test-root", type=str, default=None, help="optional: original test root directory (used to align fps/resolution)")
ap.add_argument("--vae-root", type=str, required=True, help="models/lingbot-va-base 或其 vae 子目录") + ap.add_argument("--device", type=str, default="cuda:0") + ap.add_argument("--dtype", type=str, default="float16", choices=["float16", "bfloat16"]) + ap.add_argument("--force-frames", type=int, default=None, help="可选:强制输出帧数(提交通常要求 50)") + ap.add_argument("--skip-existing", action="store_true") + args = ap.parse_args() + + pred_root = Path(args.pred_root) + test_root = Path(args.test_root) if args.test_root else None + vae_root = Path(args.vae_root) + device = torch.device(args.device) + dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16 + + vae_path = vae_root / "vae" if (vae_root / "vae").exists() else vae_root + vae = AutoencoderKLWan.from_pretrained(str(vae_path), torch_dtype=dtype).to(device).eval() + video_processor = VideoProcessor(vae_scale_factor=1) + + ep_dirs = sorted([p for p in pred_root.iterdir() if p.is_dir()]) + for ep_dir in ep_dirs: + pt_path = ep_dir / "pred_latents.pt" + out_mp4 = ep_dir / "video.mp4" + if not pt_path.exists(): + continue + if args.skip_existing and out_mp4.exists(): + continue + + payload = torch.load(pt_path, map_location="cpu") + latents = payload["latents"].to(device=device, dtype=dtype) # [1,48,51,16,16] normalized + + latents_mean = torch.tensor(vae.config.latents_mean, device=device, dtype=dtype).view(1, vae.config.z_dim, 1, 1, 1) + latents_std = 1.0 / torch.tensor(vae.config.latents_std, device=device, dtype=dtype).view(1, vae.config.z_dim, 1, 1, 1) + latents_denorm = latents / latents_std + latents_mean + + video = vae.decode(latents_denorm, return_dict=False)[0] + video = video_processor.postprocess_video(video, output_type="np")[0] + + # Align to submission expectations: same FPS/resolution as original test video, and frame count as action.txt + target_frames = _count_target_frames(ep_dir) or int(video.shape[0]) + # Submission usually expects 50 future frames (README). Some scripts output 51 steps (80..130 inclusive). + if args.force_frames is not None and args.force_frames > 0: + target_frames = int(args.force_frames) + elif target_frames == 51: + target_frames = 50 + target_fps = float(payload.get("fps", 10) or 10) + target_w = target_h = 0 + if test_root is not None: + src_video = test_root / ep_dir.name / "video.mp4" + if src_video.exists(): + target_w, target_h, src_fps = _probe_video(src_video) + if src_fps > 0: + target_fps = src_fps + + video = video.astype(np.float32) + video = _match_length(video, target_frames) + video = _resize_video_np(video, target_w=target_w, target_h=target_h) + # export_to_video expects uint8 [0,255] + vmax = float(video.max()) if video.size else 0.0 + if vmax <= 1.0: + video = video * 255.0 + elif vmax <= 2.0: + # In case the range is [0,2] (rare), map to [0,255] + video = (video / 2.0) * 255.0 + video = np.clip(video, 0, 255).astype(np.uint8) + export_to_video(video, str(out_mp4), fps=int(round(target_fps))) + + print(f"Done. 
Decoded videos to: {pred_root}") + + +if __name__ == "__main__": + main() + diff --git a/arms/extract_latents_arms_lerobot.py b/arms/extract_latents_arms_lerobot.py new file mode 100644 index 0000000..91b5fab --- /dev/null +++ b/arms/extract_latents_arms_lerobot.py @@ -0,0 +1,334 @@ +""" +为 arms_lerobot 提取视频 latents(Wan2.2 VAE encoder),输出到: + arms_lerobot/latents/chunk-000/observation.images.front/episode_XXXXXX_0_T.pth + +该 .pth 文件的字段尽量对齐 README “Extract Latents” 部分: + - latent: Tensor [N, C] (bfloat16/float16) + - latent_num_frames: int + - latent_height: int + - latent_width: int + - video_num_frames: int + - video_height / video_width: int + - text_emb: Tensor [L, D] + - text: str + - frame_ids: list[int] + - start_frame / end_frame: int(end_frame 为 exclusive) + - fps / ori_fps: int/float + +依赖: + - torch, diffusers, transformers, opencv-python, pyarrow(只用来读 episodes.jsonl 的话不需要) + +运行建议(示例): + conda activate gmr + python arms/extract_latents_arms_lerobot.py \ + --dataset-root arms_lerobot \ + --wan22-path /path/to/wan2.2 \ + --device cuda:0 \ + --dtype bfloat16 \ + --target-size 256 256 \ + --max-episodes 2 +""" + +from __future__ import annotations + +import argparse +import json +import math +import os +import sys +from pathlib import Path +from typing import Dict, List, Tuple + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F + +from diffusers import AutoencoderKLWan + + +def load_vae(vae_path: str, torch_dtype: torch.dtype, device: torch.device): + vae = AutoencoderKLWan.from_pretrained(vae_path, torch_dtype=torch_dtype) + return vae.to(device) + + +def patchify(x: torch.Tensor, patch_size): + if patch_size is None or patch_size == 1: + return x + batch_size, channels, frames, height, width = x.shape + x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size) + x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous() + x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size) + return x + + +class WanVAEStreamingWrapper: + def __init__(self, vae_model): + self.vae = vae_model + self.encoder = vae_model.encoder + self.quant_conv = vae_model.quant_conv + + if hasattr(self.vae, "_cached_conv_counts"): + self.enc_conv_num = self.vae._cached_conv_counts["encoder"] + else: + count = 0 + for m in self.encoder.modules(): + if m.__class__.__name__ == "WanCausalConv3d": + count += 1 + self.enc_conv_num = count + + self.clear_cache() + + def clear_cache(self): + self.feat_cache = [None] * self.enc_conv_num + + @torch.no_grad() + def encode_chunk(self, x_chunk: torch.Tensor): + if hasattr(self.vae.config, "patch_size") and self.vae.config.patch_size is not None: + x_chunk = patchify(x_chunk, self.vae.config.patch_size) + feat_idx = [0] + out = self.encoder(x_chunk, feat_cache=self.feat_cache, feat_idx=feat_idx) + enc = self.quant_conv(out) + return enc + + +def read_video_all_frames(video_path: Path) -> Tuple[List[np.ndarray], float]: + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + raise RuntimeError(f"Cannot open video: {video_path}") + fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0) + frames: List[np.ndarray] = [] + while True: + ok, frame_bgr = cap.read() + if not ok or frame_bgr is None: + break + frames.append(frame_bgr) + cap.release() + return frames, fps + + +def sample_frames(frames: List[np.ndarray], stride: int) -> Tuple[List[np.ndarray], List[int]]: + if stride <= 0: + stride = 1 + idxs = list(range(0, len(frames), stride)) + out = 
+    out = [frames[i] for i in idxs]
+    return out, idxs
+
+
+def preprocess_frames_to_tensor(
+    frames_bgr: List[np.ndarray],
+    target_h: int,
+    target_w: int,
+    device: torch.device,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    # frames_bgr: list of HWC uint8
+    rgb = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames_bgr]
+    arr = np.stack(rgb, axis=0)  # [F, H, W, 3]
+    # resize
+    if arr.shape[1] != target_h or arr.shape[2] != target_w:
+        resized = []
+        for f in arr:
+            resized.append(cv2.resize(f, (target_w, target_h), interpolation=cv2.INTER_AREA))
+        arr = np.stack(resized, axis=0)
+    # to torch: [1,3,F,H,W] float in [-1,1]
+    x = torch.from_numpy(arr).to(device=device)
+    x = x.permute(3, 0, 1, 2).contiguous().float()  # [3,F,H,W]
+    x = (x / 255.0) * 2.0 - 1.0
+    x = x.unsqueeze(0).to(dtype=dtype)
+    return x
+
+
+def make_empty_text_emb(max_sequence_length: int, text_dim: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
+    return torch.zeros((max_sequence_length, text_dim), device=device, dtype=dtype)
+
+
+@torch.no_grad()
+def encode_video_to_latent_flat(
+    vae,
+    video_tensor: torch.Tensor,  # [1,3,F,H,W]
+    chunk_frames: int = 0,
+) -> Tuple[torch.Tensor, int, int, int]:
+    """
+    Encode directly with diffusers' `AutoencoderKLWan.encode`.
+
+    Notes:
+    - `AutoencoderKLWan.encode` internally handles the patchify / 3D causal conv caching it needs.
+    - To avoid compatibility issues with the streaming API across diffusers versions,
+      the whole clip is encoded in one pass by default.
+    """
+    _ = chunk_frames  # keep the CLI arg for backward compatibility
+    enc = vae.encode(video_tensor)
+    mu = enc.latent_dist.mean  # [1, z_dim, F', H', W']
+
+    latents_mean = torch.tensor(vae.config.latents_mean, device=mu.device, dtype=mu.dtype).view(1, -1, 1, 1, 1)
+    latents_std = torch.tensor(vae.config.latents_std, device=mu.device, dtype=mu.dtype).view(1, -1, 1, 1, 1)
+    mu_norm = (mu.float() - latents_mean.float()) * (1.0 / latents_std.float())
+    mu_norm = mu_norm.to(mu.dtype)
+
+    _, c, f, h, w = mu_norm.shape
+    flat = mu_norm.permute(0, 2, 3, 4, 1).reshape(-1, c)  # [F*H*W, C]
+    return flat, f, h, w
+
+
+def load_episodes_jsonl(path: Path) -> List[Dict]:
+    out = []
+    with path.open("r", encoding="utf-8") as f:
+        for line in f:
+            line = line.strip()
+            if not line:
+                continue
+            out.append(json.loads(line))
+    return out
+
+
+def ensure_latent_dir(dataset_root: Path, video_rel_dir: Path) -> Path:
+    # video_rel_dir like: videos/chunk-000/observation.images.front
+    latent_dir = dataset_root / "latents" / "chunk-000" / video_rel_dir.name
+    latent_dir.mkdir(parents=True, exist_ok=True)
+    return latent_dir
+
+
+def main() -> None:
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--dataset-root", type=str, required=True)
+    ap.add_argument(
+        "--out-root",
+        type=str,
+        default="",
+        help="optional: root directory for latents (default: <dataset-root>/latents); chunk-000/observation.images.front/ is created automatically",
+    )
+    ap.add_argument("--wan22-path", type=str, required=True, help="directory containing vae/tokenizer/text_encoder/transformer (only vae/tokenizer/text_encoder are used)")
+    ap.add_argument("--device", type=str, default="cuda:0")
+    ap.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"])
+    ap.add_argument("--target-size", type=int, nargs=2, default=[256, 256], metavar=("H", "W"))
+    ap.add_argument("--stride", type=int, default=1, help="take one frame from the source video every `stride` frames")
+    ap.add_argument("--max-seq-len", type=int, default=512)
+    ap.add_argument(
+        "--text-mode",
+        type=str,
+        default="empty",
+        choices=["empty", "encode"],
+        help="empty: do not load the text encoder, write an all-zero text_emb; encode: encode the instruction text with tokenizer+text_encoder",
+    )
+    ap.add_argument("--text-dim", type=int, default=4096, help="embedding dim used when text-mode=empty")
+    ap.add_argument("--max-episodes", type=int, default=-1, help="for debugging; <=0 means all episodes")
+    ap.add_argument("--skip-existing", action="store_true", help="skip if the latent file already exists")
+    ap.add_argument("--vae-chunk-frames", type=int, default=8, help="temporal chunk size for VAE encoding (avoids OOM)")
+    args = ap.parse_args()
+
+    dataset_root = Path(args.dataset_root)
+    out_root = Path(args.out_root) if str(args.out_root).strip() else (dataset_root / "latents")
+    out_root.mkdir(parents=True, exist_ok=True)
+    wan22_path = Path(args.wan22_path)
+    device = torch.device(args.device)
+    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+    target_h, target_w = int(args.target_size[0]), int(args.target_size[1])
+
+    # load VAE: support a local directory with subfolder `vae/` or an HF repo id
+    if wan22_path.exists():
+        vae = load_vae(str(wan22_path / "vae"), torch_dtype=dtype, device=device)
+        is_local = True
+    else:
+        # treat as a HuggingFace repo id
+        try:
+            vae = AutoencoderKLWan.from_pretrained(args.wan22_path, subfolder="vae", torch_dtype=dtype).to(device)
+        except Exception:
+            # fallback: the repo root itself is the VAE
+            vae = AutoencoderKLWan.from_pretrained(args.wan22_path, torch_dtype=dtype).to(device)
+        is_local = False
+    # text embedding
+    if args.text_mode == "encode":
+        from diffusers.pipelines.wan.pipeline_wan import prompt_clean  # local import to keep deps optional
+        from transformers import T5TokenizerFast, UMT5EncoderModel
+
+        if is_local:
+            tokenizer = T5TokenizerFast.from_pretrained(str(wan22_path / "tokenizer"))
+            text_encoder = UMT5EncoderModel.from_pretrained(str(wan22_path / "text_encoder"), torch_dtype=dtype).to(device)
+        else:
+            tokenizer = T5TokenizerFast.from_pretrained(args.wan22_path, subfolder="tokenizer")
+            text_encoder = UMT5EncoderModel.from_pretrained(args.wan22_path, subfolder="text_encoder", torch_dtype=dtype).to(device)
+        text_encoder.eval()
+    else:
+        tokenizer = None
+        text_encoder = None
+        # also dump empty_emb.pt for later training configs
+        empty_emb = make_empty_text_emb(args.max_seq_len, args.text_dim, device=device, dtype=dtype).cpu()
+        torch.save(empty_emb, out_root / "empty_emb.pt")
+
+    episodes = load_episodes_jsonl(dataset_root / "meta" / "episodes.jsonl")
+
+    video_dir = dataset_root / "videos" / "chunk-000" / "observation.images.front"
+    latent_dir = out_root / "chunk-000" / "observation.images.front"
+    latent_dir.mkdir(parents=True, exist_ok=True)
+
+    n = len(episodes) if args.max_episodes <= 0 else min(len(episodes), int(args.max_episodes))
+    for i in range(n):
+        ep = episodes[i]
+        episode_index = int(ep["episode_index"])
+        task_text = ep["tasks"][0]
+        start_frame = int(ep["action_config"][0]["start_frame"])
+        end_frame = int(ep["action_config"][0]["end_frame"])  # exclusive
+
+        video_path = video_dir / f"episode_{episode_index:06d}.mp4"
+        out_path = latent_dir / f"episode_{episode_index:06d}_{start_frame}_{end_frame}.pth"
+        if args.skip_existing and out_path.exists():
+            continue
+
+        frames_all, ori_fps = read_video_all_frames(video_path)
+        frames, frame_ids = sample_frames(frames_all, args.stride)
+        if len(frames) == 0:
+            raise RuntimeError(f"Empty video after sampling: {video_path}")
+
+        video_tensor = preprocess_frames_to_tensor(frames, target_h, target_w, device=device, dtype=dtype)
+
+        latent_flat, latent_f, latent_h, latent_w = encode_video_to_latent_flat(
+            vae=vae,
+            video_tensor=video_tensor,
+            chunk_frames=int(args.vae_chunk_frames),
+        )
+
+        if args.text_mode == "encode":
+            prompt = prompt_clean(task_text)
+            text_inputs = tokenizer(
+                [prompt],
padding="max_length", + max_length=args.max_seq_len, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + attn_mask = text_inputs.attention_mask.to(device) + text_emb = text_encoder(text_input_ids, attn_mask).last_hidden_state[0].to(dtype=dtype) + else: + text_emb = make_empty_text_emb(args.max_seq_len, args.text_dim, device=device, dtype=dtype) + + payload = { + "latent": latent_flat.to(dtype), + "latent_num_frames": int(latent_f), + "latent_height": int(latent_h), + "latent_width": int(latent_w), + "video_num_frames": int(len(frames)), + "video_height": int(target_h), + "video_width": int(target_w), + "text_emb": text_emb.to(dtype), + "text": task_text, + "frame_ids": [int(x) for x in frame_ids], + "start_frame": int(start_frame), + "end_frame": int(end_frame), + "fps": float(ori_fps / max(1, args.stride)) if ori_fps else 0.0, + "ori_fps": float(ori_fps) if ori_fps else 0.0, + } + + torch.save(payload, out_path) + + print(f"Done. Wrote latents to: {latent_dir}") + + +if __name__ == "__main__": + main() + diff --git a/arms/generate_arms_and_dump.py b/arms/generate_arms_and_dump.py new file mode 100644 index 0000000..2dd4cf1 --- /dev/null +++ b/arms/generate_arms_and_dump.py @@ -0,0 +1,142 @@ +""" +从 arms_lerobot 生成“后续 50 步”(按样例索引 80–130 共 51 行)的 action/joint,并落盘为 sample_result 风格。 + +注意: + 1) 该脚本假设你已经有一个可用的 LingBot-VA 模型目录(包含 transformer/vae/tokenizer/text_encoder)。 + 2) 当前实现优先把“推理跑通 + 输出格式正确”作为目标;joint 的生成默认直接使用 action(可替换为更精确的 joint 预测/解算)。 + +输出目录结构: + // + - instruction.txt + - action.txt + - joint.txt + - video.mp4 (可选:如果你把 latent 解码为视频) + +用法(示例): + conda activate gmr + python arms/compute_arms_norm_stat.py --dataset-root arms_lerobot + python arms/generate_arms_and_dump.py \\ + --dataset-root arms_lerobot \\ + --model-root /path/to/lingbot-va-checkpoint \\ + --episode-index 0 \\ + --out-root arms/generated_samples +""" + +from __future__ import annotations + +import argparse +import csv +import json +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +import pandas as pd + + +def read_manifest(dataset_root: Path) -> Dict: + return json.loads((dataset_root / "manifest.json").read_text(encoding="utf-8")) + + +def read_norm_stat(dataset_root: Path) -> Dict: + p = dataset_root / "norm_stat.json" + if not p.exists(): + raise RuntimeError(f"Missing {p}. 
Run: python arms/compute_arms_norm_stat.py --dataset-root {dataset_root}") + return json.loads(p.read_text(encoding="utf-8")) + + +def read_instruction(dataset_root: Path, episode_index: int) -> str: + # from episodes.jsonl + ep_lines = (dataset_root / "meta" / "episodes.jsonl").read_text(encoding="utf-8").strip().splitlines() + obj = json.loads(ep_lines[episode_index]) + return obj["tasks"][0] + + +def read_episode_action_state(dataset_root: Path, episode_index: int) -> Tuple[np.ndarray, np.ndarray]: + p = dataset_root / "data" / "chunk-000" / f"episode_{episode_index:06d}.parquet" + df = pd.read_parquet(p, columns=["action", "observation.state"]) + action = np.stack(df["action"].to_list()).astype(np.float32) # [T, D] + state = np.stack(df["observation.state"].to_list()).astype(np.float32) # [T, D] + return action, state + + +def write_csv_like_sample(path: Path, header: List[str], idxs: List[int], data: np.ndarray) -> None: + """ + header: 不含索引列名(样例里索引列名是 Unnamed: 0) + data: shape [len(idxs), D] + """ + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["Unnamed: 0"] + header) + for i, t in enumerate(idxs): + writer.writerow([int(t)] + [float(x) for x in data[i].tolist()]) + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--dataset-root", type=str, required=True) + ap.add_argument("--model-root", type=str, required=True, help="LingBot-VA checkpoint 根目录(含 transformer/vae/tokenizer/text_encoder)") + ap.add_argument("--episode-index", type=int, default=0) + ap.add_argument("--out-root", type=str, required=True) + ap.add_argument("--start-idx", type=int, default=80) + ap.add_argument("--end-idx", type=int, default=130) + args = ap.parse_args() + + dataset_root = Path(args.dataset_root) + out_root = Path(args.out_root) + manifest = read_manifest(dataset_root) + norm_stat = read_norm_stat(dataset_root) + + episode_index = int(args.episode_index) + episode_id = manifest["episode_id_map"][str(episode_index)] + + # columns + cols = manifest["columns"]["action"] + action_all, state_all = read_episode_action_state(dataset_root, episode_index) + prompt = read_instruction(dataset_root, episode_index) + + # 简化:先用“真值 action/ joint”来演示落盘格式(你后续换成模型预测即可) + s = int(args.start_idx) + e = int(args.end_idx) + idxs = list(range(s, e + 1)) + # 某些 episode 可能长度 < 131。为了保证“按样例输出 80–130”,这里做 padding: + # - 已有部分用真值 + # - 超出长度的部分用最后一帧重复(占位) + T = action_all.shape[0] + if s >= T: + # 全部超出:直接用最后一帧重复(或全 0) + last_a = action_all[-1] + last_s = state_all[-1] + action_out = np.repeat(last_a[None], repeats=len(idxs), axis=0) + joint_out = np.repeat(last_s[None], repeats=len(idxs), axis=0) + else: + a_part = action_all[s : min(e + 1, T)] + s_part = state_all[s : min(e + 1, T)] + need = len(idxs) - a_part.shape[0] + if need > 0: + last_a = a_part[-1] + last_s = s_part[-1] + a_pad = np.repeat(last_a[None], repeats=need, axis=0) + s_pad = np.repeat(last_s[None], repeats=need, axis=0) + action_out = np.concatenate([a_part, a_pad], axis=0) + joint_out = np.concatenate([s_part, s_pad], axis=0) + else: + action_out = a_part + joint_out = s_part + + ep_out_dir = out_root / episode_id + ep_out_dir.mkdir(parents=True, exist_ok=True) + (ep_out_dir / "instruction.txt").write_text(prompt + "\n", encoding="utf-8") + write_csv_like_sample(ep_out_dir / "action.txt", cols, idxs, action_out) + write_csv_like_sample(ep_out_dir / "joint.txt", cols, idxs, joint_out) + + print(f"Done. 
Wrote: {ep_out_dir}") + print("NOTE: 当前脚本先用真值 action/joint 演示输出格式;下一步把 action_out/joint_out 替换为模型预测即可。") + + +if __name__ == "__main__": + main() + diff --git a/arms/rebuild_submit_videos_50frames.py b/arms/rebuild_submit_videos_50frames.py new file mode 100644 index 0000000..db839a0 --- /dev/null +++ b/arms/rebuild_submit_videos_50frames.py @@ -0,0 +1,101 @@ +""" +为提交目录重建 video.mp4,使其与 sample_result 对齐:50 帧、30fps。 + +输入: + - pred_root: run_arms_test_infer_latents_only.py 的输出(含 pred_latents.pt) + - submit_root: 最终提交目录(每个 episode 子目录已经有 action/joint/instruction) + +做法: + - 从 pred_latents.pt 读取 normalized latents(latent 帧数=51) + - 用 VAE 解码得到视频(通常会上采样到 ~200 帧) + - 将解码视频均匀采样到 50 帧,再导出为 mp4(30fps) +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +import numpy as np +import torch +import cv2 +from diffusers import AutoencoderKLWan +from diffusers.utils import export_to_video +from diffusers.video_processor import VideoProcessor + + +@torch.no_grad() +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--pred-root", type=str, required=True) + ap.add_argument("--submit-root", type=str, required=True) + ap.add_argument("--vae-root", type=str, required=True) + ap.add_argument("--device", type=str, default="cuda:0") + ap.add_argument("--dtype", type=str, default="float16", choices=["float16", "bfloat16"]) + ap.add_argument("--target-frames", type=int, default=50) + ap.add_argument("--fps", type=int, default=30) + ap.add_argument("--target-size", type=int, nargs=2, default=[720, 1280], metavar=("H", "W")) + ap.add_argument("--skip-existing", action="store_true") + args = ap.parse_args() + + pred_root = Path(args.pred_root) + submit_root = Path(args.submit_root) + vae_root = Path(args.vae_root) + device = torch.device(args.device) + dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16 + + vae_path = vae_root / "vae" if (vae_root / "vae").exists() else vae_root + vae = AutoencoderKLWan.from_pretrained(str(vae_path), torch_dtype=dtype).to(device).eval() + video_processor = VideoProcessor(vae_scale_factor=1) + + for ep_dir in sorted([p for p in pred_root.iterdir() if p.is_dir()]): + ep_id = ep_dir.name + if ep_id == "real": + continue + + pt_path = ep_dir / "pred_latents.pt" + if not pt_path.exists(): + continue + + out_dir = submit_root / ep_id + out_mp4 = out_dir / "video.mp4" + if not out_dir.exists(): + continue + if args.skip_existing and out_mp4.exists(): + continue + + payload = torch.load(pt_path, map_location="cpu") + latents = payload["latents"].to(device=device, dtype=dtype) # normalized [1,48,T,16,16] + + latents_mean = torch.tensor(vae.config.latents_mean, device=device, dtype=dtype).view(1, vae.config.z_dim, 1, 1, 1) + latents_std = 1.0 / torch.tensor(vae.config.latents_std, device=device, dtype=dtype).view(1, vae.config.z_dim, 1, 1, 1) + latents_denorm = latents / latents_std + latents_mean + + video = vae.decode(latents_denorm, return_dict=False)[0] + video = video_processor.postprocess_video(video, output_type="np")[0] # [T,H,W,3] + + T = int(video.shape[0]) + target = int(args.target_frames) + if T != target: + # evenly sample to target frames + idx = np.linspace(0, max(0, T - 1), num=target) + idx = np.round(idx).astype(np.int64) + idx = np.clip(idx, 0, max(0, T - 1)) + video = video[idx] + + # resize to target resolution (GT is 1280x720) + target_h, target_w = int(args.target_size[0]), int(args.target_size[1]) + if int(video.shape[1]) != target_h or int(video.shape[2]) != target_w: + resized = [] + for fr in 
video: + resized.append(cv2.resize(fr, (target_w, target_h), interpolation=cv2.INTER_LINEAR)) + video = np.stack(resized, axis=0) + + export_to_video(video, str(out_mp4), fps=int(args.fps)) + + print(f"Done. Rebuilt submit videos in: {submit_root}") + + +if __name__ == "__main__": + main() + diff --git a/arms/run_arms_infer_all_latent.py b/arms/run_arms_infer_all_latent.py new file mode 100644 index 0000000..5c4abe5 --- /dev/null +++ b/arms/run_arms_infer_all_latent.py @@ -0,0 +1,242 @@ +""" +批量跑 arms_lerobot 的 latent 推理,并按比赛格式落盘到一个目录。 + +特点: +- 逐 episode 串行推理,避免显存/CPU 被并发打爆 +- 每个 episode 复用同一个 VA_Server(只 reset prompt + 重建 kv cache) +- 出错会记录到 errors.jsonl,继续下一个 episode + +示例: + /home/landscape-layton-ljw/miniconda3/envs/gmr/bin/python arms/run_arms_infer_all_latent.py \ + --dataset-root arms_lerobot \ + --model-root models/lingbot-va-base \ + --latents-root arms_lerobot/latents_lingbot \ + --out-root arms/generated_samples_latent_all \ + --device cuda:0 +""" + +from __future__ import annotations + +import argparse +import json +import sys +import time +import traceback +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +import pandas as pd +import torch +from easydict import EasyDict +from einops import rearrange + +# allow running as a standalone script +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from wan_va.wan_va_server import VA_Server +from wan_va.configs.va_arms_cfg import va_arms_cfg + +from arms.generate_arms_and_dump import read_manifest, write_csv_like_sample + + +def _load_episode_meta(dataset_root: Path, episode_index: int) -> Tuple[str, int, int]: + line = (dataset_root / "meta" / "episodes.jsonl").read_text(encoding="utf-8").splitlines()[episode_index] + ep = json.loads(line) + prompt = ep["tasks"][0] + cfg0 = ep["action_config"][0] + return prompt, int(cfg0["start_frame"]), int(cfg0["end_frame"]) + + +def _load_latents_5d(latent_path: Path, device: torch.device) -> torch.Tensor: + payload = torch.load(latent_path, map_location="cpu") + latent_flat: torch.Tensor = payload["latent"] # [N, C] + f = int(payload["latent_num_frames"]) + h = int(payload["latent_height"]) + w = int(payload["latent_width"]) + c = int(latent_flat.shape[-1]) + latents = rearrange(latent_flat, "(f h w) c -> 1 c f h w", f=f, h=h, w=w, c=c) + return latents.to(device) + + +def _pad_norm_stat_to_30(norm_stat: Dict) -> Dict: + out = dict(norm_stat) + for k in ["q01", "q99"]: + if k in out and len(out[k]) < 30: + pad_val = 0.0 if k == "q01" else 1.0 + out[k] = list(out[k]) + [pad_val] * (30 - len(out[k])) + return out + + +def _state_to_cf1_padded_29(state_t26: np.ndarray) -> np.ndarray: + # input: [T,26] -> [29,T,1] + state_cf1 = state_t26.astype(np.float32).T[:, :, None] + state_cf1 = np.pad(state_cf1, ((0, 3), (0, 0), (0, 0)), mode="constant", constant_values=0.0) + return state_cf1 + + +def _actions_to_t26(actions_any: np.ndarray, frame_chunk_size: int) -> np.ndarray: + a = np.asarray(actions_any, dtype=np.float32) + while a.ndim > 2 and a.shape[-1] == 1: + a = a[..., 0] + if a.ndim == 2 and a.shape[0] == 26 and a.shape[1] == frame_chunk_size: + a = a.T + if a.ndim == 2 and a.shape[0] in (26, 29, 30) and a.shape[1] != 26: + a = a.T + if a.ndim != 2: + raise RuntimeError(f"Unexpected action chunk shape: {a.shape}") + if a.shape[1] > 26: + a = a[:, :26] + if a.shape[1] != 26: + raise RuntimeError(f"Unexpected action chunk shape after trim: {a.shape}") + return a + + +def main() -> None: + ap 
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--dataset-root", type=str, required=True)
+    ap.add_argument("--model-root", type=str, required=True)
+    ap.add_argument("--latents-root", type=str, required=True)
+    ap.add_argument("--out-root", type=str, required=True)
+    ap.add_argument("--device", type=str, default="cuda:0")
+    ap.add_argument("--history-len", type=int, default=16)
+    ap.add_argument("--start-idx", type=int, default=80)
+    ap.add_argument("--end-idx", type=int, default=130)
+    ap.add_argument("--max-episodes", type=int, default=-1, help="for debugging; <=0 means all episodes")
+    args = ap.parse_args()
+
+    dataset_root = Path(args.dataset_root)
+    latents_root = Path(args.latents_root)
+    out_root = Path(args.out_root)
+    out_root.mkdir(parents=True, exist_ok=True)
+
+    manifest = read_manifest(dataset_root)
+    cols = manifest["columns"]["action"]
+
+    norm_stat = json.loads((dataset_root / "norm_stat.json").read_text(encoding="utf-8"))
+    norm_stat = _pad_norm_stat_to_30(norm_stat)
+
+    cfg = EasyDict(va_arms_cfg)
+    cfg.wan22_pretrained_model_name_or_path = str(Path(args.model_root).resolve())
+    cfg.save_root = str(out_root.resolve())
+    cfg.enable_offload = True
+    cfg.param_dtype = torch.float16
+    cfg.local_rank = int(str(args.device).split(":")[-1]) if "cuda" in args.device else 0
+    cfg.norm_stat = norm_stat
+
+    if "cuda" in args.device:
+        torch.cuda.empty_cache()
+    server = VA_Server(cfg)
+    device = server.device
+
+    episodes = (dataset_root / "meta" / "episodes.jsonl").read_text(encoding="utf-8").splitlines()
+    n = len(episodes) if args.max_episodes <= 0 else min(len(episodes), int(args.max_episodes))
+
+    errors_path = out_root / "errors.jsonl"
+    if errors_path.exists():
+        errors_path.unlink()
+
+    for episode_index in range(n):
+        t0 = time.time()
+        try:
+            episode_id = manifest["episode_id_map"][str(episode_index)]
+            prompt, start_frame, end_frame = _load_episode_meta(dataset_root, episode_index)
+            print(f"[{episode_index+1}/{n}] start {episode_id}", flush=True)
+
+            # reset prompt
+            server.infer({"reset": True, "prompt": prompt})
+
+            # state
+            parquet_path = dataset_root / "data" / "chunk-000" / f"episode_{episode_index:06d}.parquet"
+            df = pd.read_parquet(parquet_path, columns=["observation.state"])
+            state = np.stack(df["observation.state"].to_list()).astype(np.float32)  # [T, 26]
+
+            hist_len = min(int(args.history_len), int(state.shape[0]))
+            hist_latent_frames = int(np.ceil(hist_len / 4.0))
+
+            latent_path = latents_root / "chunk-000" / "observation.images.front" / f"episode_{episode_index:06d}_{start_frame}_{end_frame}.pth"
+            if not latent_path.exists():
+                raise FileNotFoundError(str(latent_path))
+
+            latents_5d = _load_latents_5d(latent_path, device=device)
+            if latents_5d.shape[2] < hist_latent_frames:
+                raise RuntimeError(f"latent frames insufficient: need {hist_latent_frames}, got {latents_5d.shape[2]}")
+            latents_hist = latents_5d[:, :, :hist_latent_frames]
+
+            server.init_latent = latents_hist[:, :, :1].to(server.dtype)
+            latent_model_input = latents_hist[:, :, 1:].to(server.dtype) if hist_latent_frames > 1 else None
+
+            state_cf1 = _state_to_cf1_padded_29(state[:hist_len])
+            action_model_input = server.preprocess_action(state_cf1).to(device=device, dtype=server.dtype)
+
+            server.transformer.clear_pred_cache(server.cache_name)
+            server.frame_st_id = 0
+            if server.init_latent is not None and latent_model_input is not None:
+                latent_full = torch.cat([server.init_latent, latent_model_input], dim=2)
+            else:
+                latent_full = server.init_latent if latent_model_input is None else latent_model_input
+
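+            # KV-cache prefill sketch (as we understand VA_Server's internals; the
+            # update_cache=2 semantics below are an assumption based on how the
+            # server uses it): the latent tokens and the action tokens are each
+            # passed through the transformer once so that subsequent rollout steps
+            # can attend to this history without re-encoding it.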
+            input_dict = server._prepare_latent_input(latent_full, action_model_input, frame_st_id=server.frame_st_id)
+            with torch.no_grad():
+                server.transformer(
+                    server._repeat_input_for_cfg(input_dict["latent_res_lst"]),
+                    update_cache=2,
+                    cache_name=server.cache_name,
+                    action_mode=False,
+                )
+                server.transformer(
+                    server._repeat_input_for_cfg(input_dict["action_res_lst"]),
+                    update_cache=2,
+                    cache_name=server.cache_name,
+                    action_mode=True,
+                )
+            server.frame_st_id += int(latent_full.shape[2])
+
+            # rollout actions
+            need = int(args.end_idx) - int(args.start_idx) + 1
+            chunks: List[np.ndarray] = []
+            for _ in range(1000):
+                out = server.infer({"obs": [], "state": state_cf1})
+                if "action" in out:
+                    chunks.append(_actions_to_t26(out["action"], frame_chunk_size=int(cfg.frame_chunk_size)))
+                if sum(x.shape[0] for x in chunks) >= need:
+                    break
+            if not chunks:
+                raise RuntimeError("No action returned.")
+
+            action_out = np.concatenate(chunks, axis=0)[:need]
+            idxs = list(range(int(args.start_idx), int(args.end_idx) + 1))
+            if action_out.shape[0] < need:
+                pad = np.repeat(action_out[-1:], repeats=need - action_out.shape[0], axis=0)
+                action_out = np.concatenate([action_out, pad], axis=0)
+
+            joint_out = action_out.copy()
+
+            ep_out_dir = out_root / episode_id
+            ep_out_dir.mkdir(parents=True, exist_ok=True)
+            (ep_out_dir / "instruction.txt").write_text(prompt + "\n", encoding="utf-8")
+            write_csv_like_sample(ep_out_dir / "action.txt", cols, idxs, action_out)
+            write_csv_like_sample(ep_out_dir / "joint.txt", cols, idxs, joint_out)
+            print(f"[{episode_index+1}/{n}] done {episode_id} in {time.time()-t0:.1f}s", flush=True)
+
+        except Exception as e:
+            rec = {
+                "episode_index": episode_index,
+                "error": repr(e),
+                "traceback": traceback.format_exc(),
+            }
+            with errors_path.open("a", encoding="utf-8") as f:
+                f.write(json.dumps(rec, ensure_ascii=False) + "\n")
+            print(f"[{episode_index+1}/{n}] ERROR in {time.time()-t0:.1f}s: {rec['error']}", flush=True)
+            continue
+
+    print(f"Done. Wrote results to: {out_root}")
+    if errors_path.exists():
+        print(f"Errors (if any): {errors_path}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/arms/run_arms_infer_and_dump.py b/arms/run_arms_infer_and_dump.py
new file mode 100644
index 0000000..daf09a6
--- /dev/null
+++ b/arms/run_arms_infer_and_dump.py
@@ -0,0 +1,246 @@
+"""
+Run inference on arms_lerobot with lingbot-va-base (a local model directory) and dump
+rows 80-130 (51 rows) following the sample layout.
+
+Current goal: get the chain "run model inference -> obtain an action sequence -> write
+action.txt/joint.txt" working end to end.
+By default the joint output is a copy of the action (same dimensionality); it can be
+replaced later with a proper joint prediction/solver.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import sys
+from pathlib import Path
+from typing import Dict, List
+
+import cv2
+import numpy as np
+import pandas as pd
+import torch
+from easydict import EasyDict
+from einops import rearrange
+
+# allow running as a standalone script
+REPO_ROOT = Path(__file__).resolve().parents[1]
+if str(REPO_ROOT) not in sys.path:
+    sys.path.insert(0, str(REPO_ROOT))
+
+from wan_va.wan_va_server import VA_Server
+from wan_va.configs.va_arms_cfg import va_arms_cfg
+
+from arms.generate_arms_and_dump import read_manifest, write_csv_like_sample
+
+
+def load_episode_prompt(dataset_root: Path, episode_index: int) -> str:
+    line = (dataset_root / "meta" / "episodes.jsonl").read_text(encoding="utf-8").splitlines()[episode_index]
+    return json.loads(line)["tasks"][0]
+
+
+def load_episode_action_range(dataset_root: Path, episode_index: int) -> tuple[int, int]:
+    line = (dataset_root / "meta" / "episodes.jsonl").read_text(encoding="utf-8").splitlines()[episode_index]
+    ep = json.loads(line)
+    cfg0 = ep["action_config"][0]
+    return int(cfg0["start_frame"]), int(cfg0["end_frame"])
+
+
+def load_latents_5d(latent_path: Path, device: torch.device) -> torch.Tensor:
+    payload = torch.load(latent_path, map_location="cpu")
+    latent_flat: torch.Tensor = payload["latent"]  # [N, C]
+    f = int(payload["latent_num_frames"])
+    h = int(payload["latent_height"])
+    w = int(payload["latent_width"])
+    c = int(latent_flat.shape[-1])
+    latents = rearrange(latent_flat, "(f h w) c -> 1 c f h w", f=f, h=h, w=w, c=c)
+    return latents.to(device)
+
+
+def load_video_frames(video_path: Path, frame_ids: List[int]) -> List[np.ndarray]:
+    cap = cv2.VideoCapture(str(video_path))
+    if not cap.isOpened():
+        raise RuntimeError(f"Cannot open {video_path}")
+    frames = []
+    for idx in frame_ids:
+        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
+        ok, frame = cap.read()
+        if not ok or frame is None:
+            raise RuntimeError(f"Cannot read frame {idx} from {video_path}")
+        # BGR -> RGB
+        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        frames.append(frame)
+    cap.release()
+    return frames
+
+
+def build_obs(front_frames_rgb: List[np.ndarray], state_seq: np.ndarray, prompt: str) -> Dict:
+    obs_seq = [{"observation.images.front": fr.astype(np.uint8)} for fr in front_frames_rgb]
+    # VA_Server.preprocess_action expects a numpy array shaped [C, F, H]
+    # We store state_seq as [F, C] -> [C, F, 1]
+    state_cf1 = state_seq.astype(np.float32).T[:, :, None]
+    return {"obs": obs_seq, "state": state_cf1, "prompt": prompt}
+
+
+def main() -> None:
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--dataset-root", type=str, required=True)
+    ap.add_argument("--model-root", type=str, required=True, help="local lingbot-va-base directory")
+    ap.add_argument("--latents-root", type=str, default="", help="optional: read latents from here first (e.g. <dataset-root>/latents_lingbot)")
+    ap.add_argument("--episode-index", type=int, default=0)
+    ap.add_argument("--out-root", type=str, required=True)
+    ap.add_argument("--history-len", type=int, default=16)
+    ap.add_argument("--start-idx", type=int, default=80)
+    ap.add_argument("--end-idx", type=int, default=130)
+    ap.add_argument("--device", type=str, default="cuda:0")
+    ap.add_argument("--debug-shapes", action="store_true")
+    args = ap.parse_args()
+
+    dataset_root = Path(args.dataset_root)
+    manifest = read_manifest(dataset_root)
+    episode_index = int(args.episode_index)
+    episode_id = manifest["episode_id_map"][str(episode_index)]
+    cols = manifest["columns"]["action"]
+
+    # load norm stat
+    norm_stat = json.loads((dataset_root / "norm_stat.json").read_text(encoding="utf-8"))
+    # pad norm stat to 30 dims (lingbot-va-base expects 30)
+    for k in ["q01", "q99"]:
+        if k in norm_stat and len(norm_stat[k]) < 30:
+            pad_val = 0.0 if k == "q01" else 1.0
+            norm_stat[k] = list(norm_stat[k]) + [pad_val] * (30 - len(norm_stat[k]))
+
+    # configure model
+    cfg = EasyDict(va_arms_cfg)
+    cfg.wan22_pretrained_model_name_or_path = str(Path(args.model_root).resolve())
+    cfg.save_root = str(Path(args.out_root).resolve())
+    # with 16 GB of GPU memory, offload the VAE/text_encoder to CPU and keep only the transformer on GPU
+    cfg.enable_offload = True
+    cfg.param_dtype = torch.float16
+    cfg.local_rank = int(str(args.device).split(":")[-1]) if "cuda" in args.device else 0
+    cfg.norm_stat = norm_stat
+
+    # init server
+    if "cuda" in args.device:
+        torch.cuda.empty_cache()
+    server = VA_Server(cfg)
+    device = server.device
+
+    # reset with prompt
+    prompt = load_episode_prompt(dataset_root, episode_index)
+    server.infer({"reset": True, "prompt": prompt})
+
+    # load episode state and frames for history
+    parquet_path = dataset_root / "data" / "chunk-000" / f"episode_{episode_index:06d}.parquet"
+    df = pd.read_parquet(parquet_path, columns=["observation.state"])
+    state = np.stack(df["observation.state"].to_list()).astype(np.float32)  # [T, 26]
+
+    # --- latent-based KV cache prefill (skip slow CPU VAE encoding) ---
+    start_frame, end_frame = load_episode_action_range(dataset_root, episode_index)
+    latents_root = Path(args.latents_root) if str(args.latents_root).strip() else (dataset_root / "latents")
+    latent_path = latents_root / "chunk-000" / "observation.images.front" / f"episode_{episode_index:06d}_{start_frame}_{end_frame}.pth"
+    if not latent_path.exists():
+        raise FileNotFoundError(
+            f"Latent file not found: {latent_path}\n"
+            f"Extract it there first with `arms/extract_latents_arms_lerobot.py` (with wan22-path pointing at models/lingbot-va-base)."
+        )
+
+    # the VAE temporal downsample factor is 4 for lingbot-va-base
+    hist_len = min(int(args.history_len), int(state.shape[0]))
+    hist_latent_frames = int(np.ceil(hist_len / 4.0))
+    latents_5d = load_latents_5d(latent_path, device=device)  # [1,48,F',16,16]
+    if latents_5d.shape[2] < hist_latent_frames:
+        raise RuntimeError(f"Not enough latent frames: need {hist_latent_frames}, got {latents_5d.shape[2]}")
+    latents_hist = latents_5d[:, :, :hist_latent_frames]
+
+    # init_latent = first latent frame, the rest go through the kv cache
+    server.init_latent = latents_hist[:, :, :1].to(server.dtype)
+    latent_model_input = latents_hist[:, :, 1:].to(server.dtype) if hist_latent_frames > 1 else None
+
+    # action/state history still uses the original frames
+    state_hist = state[:hist_len]
+    state_cf1 = state_hist.astype(np.float32).T[:, :, None]  # [26,F,1]
+    # pad to 29 dims so VA_Server.preprocess_action's +1 padding yields 30 dims
+    state_cf1 = np.pad(state_cf1, ((0, 3), (0, 0), (0, 0)), mode="constant", constant_values=0.0)
+    action_model_input = server.preprocess_action(state_cf1).to(device=device, dtype=server.dtype)
+
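+    # Dimension bookkeeping (worked example; the +1-channel behavior of
+    # preprocess_action is taken from the comment above): 26 arms joints + 3 zero
+    # pads = 29 channels here, and preprocess_action appends one more channel,
+    # giving the 30 dims that the padded norm_stat (q01/q99) above was sized for.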
+    server.transformer.clear_pred_cache(server.cache_name)
+    server.frame_st_id = 0
+    if server.init_latent is not None and latent_model_input is not None:
+        latent_full = torch.cat([server.init_latent, latent_model_input], dim=2)
+    else:
+        latent_full = server.init_latent if latent_model_input is None else latent_model_input
+
+    input_dict = server._prepare_latent_input(latent_full, action_model_input, frame_st_id=server.frame_st_id)
+    with torch.no_grad():
+        server.transformer(
+            server._repeat_input_for_cfg(input_dict["latent_res_lst"]),
+            update_cache=2,
+            cache_name=server.cache_name,
+            action_mode=False,
+        )
+        server.transformer(
+            server._repeat_input_for_cfg(input_dict["action_res_lst"]),
+            update_cache=2,
+            cache_name=server.cache_name,
+            action_mode=True,
+        )
+    server.frame_st_id += int(latent_full.shape[2])
+
+    # rollout: this repository's server produces one chunk of actions at a time.
+    # Here we call infer repeatedly until we have enough timesteps to write 80-130.
+    actions_pred: List[np.ndarray] = []
+    max_calls = 100
+    for _ in range(max_calls):
+        out = server.infer({"obs": [], "state": state_cf1})
+        if "action" in out:
+            if args.debug_shapes:
+                a = out["action"]
+                print("infer action.shape =", getattr(a, "shape", None))
+            a = np.asarray(out["action"], dtype=np.float32)
+            # the server returns (C, frame_chunk, 1) in our config
+            while a.ndim > 2 and a.shape[-1] == 1:
+                a = a[..., 0]
+            if a.ndim == 2 and a.shape[0] == 26 and a.shape[1] == cfg.frame_chunk_size:
+                a = a.T  # -> (frame_chunk, 26)
+            actions_pred.append(a)
+        if sum(x.shape[0] for x in actions_pred) >= (args.end_idx - args.start_idx + 1):
+            break
+
+    if not actions_pred:
+        raise RuntimeError("No action returned from server inference.")
+
+    action_out = np.concatenate([np.asarray(x, dtype=np.float32) for x in actions_pred], axis=0)
+    # normalize shape to [T, 26]
+    if action_out.ndim == 1:
+        action_out = action_out[None]
+    while action_out.ndim > 2 and action_out.shape[-1] == 1:
+        action_out = action_out[..., 0]
+    # sometimes returned as [C, T]
+    if action_out.ndim == 2 and action_out.shape[0] in (26, 29, 30) and action_out.shape[1] != 26:
+        action_out = action_out.T
+    if action_out.ndim != 2:
+        raise RuntimeError(f"Unexpected action_out.ndim={action_out.ndim}, shape={action_out.shape}")
+    if action_out.shape[1] > 26:
+        action_out = action_out[:, :26]
+    if action_out.shape[1] != 26:
+        raise RuntimeError(f"Unexpected action_out shape={action_out.shape}, expected (*,26)")
+
+    # pad/truncate to 51 rows
+    idxs = list(range(args.start_idx, args.end_idx + 1))
+    need = len(idxs)
+    if action_out.shape[0] < need:
+        pad = np.repeat(action_out[-1:], repeats=need - action_out.shape[0], axis=0)
+        action_out = np.concatenate([action_out, pad], axis=0)
+    action_out = action_out[:need]
+    joint_out = action_out.copy()
+
+    out_dir = Path(args.out_root) / episode_id
+    out_dir.mkdir(parents=True, exist_ok=True)
+    (out_dir / "instruction.txt").write_text(prompt + "\n", encoding="utf-8")
+    write_csv_like_sample(out_dir / "action.txt", cols, idxs, action_out)
+    write_csv_like_sample(out_dir / "joint.txt", cols, idxs, joint_out)
+
+    print(f"Done. Wrote: {out_dir}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/arms/run_arms_test_infer_and_video.py b/arms/run_arms_test_infer_and_video.py
new file mode 100644
index 0000000..b61f1f3
--- /dev/null
+++ b/arms/run_arms_test_infer_and_video.py
@@ -0,0 +1,273 @@
+"""
+Batch inference over arms/test/:
+- input: each episode directory contains video.mp4 (16 frames), instruction.txt, joint.txt (16 rows)
+- output: per episode, action.txt/joint.txt (rows 80-130, 51 rows) plus a predicted video.mp4
+
+Notes:
+- Inference runs in latent space: the 16 frames are first encoded into 48-channel latents
+  with the lingbot-va-base VAE (temporal downsample = 4).
+- KV cache prefill: 4 latent frames + 16 steps of joint history (padded into the 30-dim action space).
+- rollout: call VA_Server._infer to get back (actions, latents) and generate the video at the same time.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import sys
+import time
+from pathlib import Path
+from typing import List, Tuple
+
+import cv2
+import numpy as np
+import pandas as pd
+import torch
+from easydict import EasyDict
+from tqdm import tqdm
+
+# allow running as a standalone script
+REPO_ROOT = Path(__file__).resolve().parents[1]
+if str(REPO_ROOT) not in sys.path:
+    sys.path.insert(0, str(REPO_ROOT))
+
+from diffusers.video_processor import VideoProcessor
+from diffusers.utils import export_to_video
+
+from wan_va.configs.va_arms_cfg import va_arms_cfg
+from wan_va.wan_va_server import VA_Server
+
+
+def _read_video_frames_rgb(video_path: Path) -> Tuple[List[np.ndarray], float]:
+    cap = cv2.VideoCapture(str(video_path))
+    if not cap.isOpened():
+        raise RuntimeError(f"Cannot open video: {video_path}")
+    fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
+    frames: List[np.ndarray] = []
+    while True:
+        ok, frame_bgr = cap.read()
+        if not ok or frame_bgr is None:
+            break
+        frames.append(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
+    cap.release()
+    return frames, fps
+
+
+def _frames_to_tensor(frames_rgb: List[np.ndarray], target_h: int, target_w: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
+    # -> [1,3,F,H,W], values in [-1,1]
+    if len(frames_rgb) == 0:
+        raise RuntimeError("Empty frames")
+    arr = []
+    for fr in frames_rgb:
+        fr = cv2.resize(fr, (target_w, target_h), interpolation=cv2.INTER_AREA)
+        arr.append(fr.astype(np.float32))
+    x = np.stack(arr, axis=0)  # [F,H,W,3]
+    x = torch.from_numpy(x).permute(3, 0, 1, 2).contiguous()  # [3,F,H,W]
+    x = (x / 255.0) * 2.0 - 1.0
+    return x.unsqueeze(0).to(device=device, dtype=dtype)
+
+
+@torch.no_grad()
+def _encode_video_to_latents_norm(server: VA_Server, frames_rgb: List[np.ndarray]) -> torch.Tensor:
+    """
+    Returns normalized latents: [1,48,F',H',W'], normalized the same way as VA_Server._encode_obs.
+    """
+    device = server.device
+    dtype = server.dtype
+
+    video_tensor = _frames_to_tensor(frames_rgb, server.job_config.height, server.job_config.width, device=device, dtype=dtype)
+
+    # with enable_offload the VAE may live on the CPU; temporarily move it to the GPU for faster encoding
+    vae_was_on_cpu = next(server.vae.parameters()).device.type == "cpu"
+    if vae_was_on_cpu:
+        server.vae = server.vae.to(device).to(dtype)
+
+    enc = server.vae.encode(video_tensor)
+    mu = enc.latent_dist.mean  # [1,48,4,16,16]
+
+    latents_mean = torch.tensor(server.vae.config.latents_mean, device=mu.device, dtype=mu.dtype).view(1, -1, 1, 1, 1)
+    latents_std = torch.tensor(server.vae.config.latents_std, device=mu.device, dtype=mu.dtype).view(1, -1, 1, 1, 1)
+    mu_norm = (mu.float() - latents_mean.float()) * (1.0 / latents_std.float())
+    mu_norm = mu_norm.to(dtype)
+
+    if vae_was_on_cpu and server.enable_offload:
+        server.vae = server.vae.to("cpu")
+        torch.cuda.empty_cache()
+
+    return mu_norm
+
+
+def _load_history_joint_cf1(test_ep_dir: Path, history_len: int) -> np.ndarray:
+    df = pd.read_csv(test_ep_dir / "joint.txt")
+    # drop first index col
+    data = df.iloc[:, 1:].to_numpy(dtype=np.float32)  # [T,26]
+    hist = data[: min(history_len, data.shape[0])]
+    state_cf1 = hist.T[:, :, None]  # [26,F,1]
+    # pad to 29 dims so preprocess_action's +1 padding yields 30 dims
+    state_cf1 = np.pad(state_cf1, ((0, 3), (0, 0), (0, 0)), mode="constant", constant_values=0.0)
+    return state_cf1
+
+
+def _write_csv_like_sample(path: Path, header_cols: List[str], idxs: List[int], data_t26: np.ndarray) -> None:
+    import csv
+
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with path.open("w", encoding="utf-8", newline="") as f:
+        writer = csv.writer(f)
+        writer.writerow(["Unnamed: 0"] + header_cols)
+        for i, t in enumerate(idxs):
+            row = [int(t)] + [float(x) for x in data_t26[i].tolist()]
+            writer.writerow(row)
+
+
+def main() -> None:
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--test-root", type=str, default="arms/test")
+    ap.add_argument("--model-root", type=str, required=True)
+    ap.add_argument("--out-root", type=str, default="arms/test_generated")
+    ap.add_argument("--device", type=str, default="cuda:0")
+    ap.add_argument("--history-len", type=int, default=16)
+    ap.add_argument("--start-idx", type=int, default=80)
+    ap.add_argument("--end-idx", type=int, default=130)
+    ap.add_argument("--max-episodes", type=int, default=-1)
+    ap.add_argument("--skip-existing", action="store_true")
+    args = ap.parse_args()
+
+    test_root = Path(args.test_root)
+    out_root = Path(args.out_root)
+    out_root.mkdir(parents=True, exist_ok=True)
+
+    # header columns from any action.txt (same schema)
+    any_action = next(iter(sorted(test_root.glob("*/action.txt"))))
+    header = any_action.read_text(encoding="utf-8").splitlines()[0].strip().split(",")
+    header_cols = header[1:]  # drop "Unnamed: 0"
+
+    # norm_stat from arms_lerobot (train stats)
+    norm_stat_path = Path("arms_lerobot/norm_stat.json")
+    if not norm_stat_path.exists():
+        raise FileNotFoundError("arms_lerobot/norm_stat.json must be generated first (with compute_arms_norm_stat.py)")
+    norm_stat = json.loads(norm_stat_path.read_text(encoding="utf-8"))
+    for k in ["q01", "q99"]:
+        if k in norm_stat and len(norm_stat[k]) < 30:
+            pad_val = 0.0 if k == "q01" else 1.0
+            norm_stat[k] = list(norm_stat[k]) + [pad_val] * (30 - len(norm_stat[k]))
+
+    cfg = EasyDict(va_arms_cfg)
+    cfg.wan22_pretrained_model_name_or_path = str(Path(args.model_root).resolve())
+    cfg.save_root = str(out_root.resolve())
+    cfg.enable_offload = True
+    cfg.param_dtype = torch.float16
+    cfg.local_rank = int(str(args.device).split(":")[-1]) if "cuda" in args.device else 0
+    cfg.norm_stat = norm_stat
+
+    if "cuda" in args.device:
+        torch.cuda.empty_cache()
+    server = VA_Server(cfg)
+    server.video_processor = VideoProcessor(vae_scale_factor=1)
+
+    ep_dirs = sorted([p for p in test_root.iterdir() if p.is_dir()])
+    if args.max_episodes > 0:
+        ep_dirs = ep_dirs[: int(args.max_episodes)]
+
+    idxs = list(range(int(args.start_idx), int(args.end_idx) + 1))
+    need_steps = len(idxs)
+
+    for ep_dir in ep_dirs:
+        ep_id = ep_dir.name
+        t0 = time.time()
+        try:
+            ep_out = out_root / ep_id
+            if args.skip_existing and (ep_out / "action.txt").exists() and (ep_out / "joint.txt").exists() and (ep_out / "video.mp4").exists():
+                print(f"skip {ep_id}", flush=True)
+                continue
+
+            prompt = (ep_dir / "instruction.txt").read_text(encoding="utf-8").strip()
+            video_path = ep_dir / "video.mp4"
+            frames_rgb, fps = _read_video_frames_rgb(video_path)
+            if len(frames_rgb) == 0:
+                raise RuntimeError(f"Empty video: {video_path}")
video: {video_path}") + + # reset and prepare caches + server.infer({"reset": True, "prompt": prompt}) + + # encode 16 frames -> latents (normalized) + latents_hist_full = _encode_video_to_latents_norm(server, frames_rgb) # [1,48,4,16,16] + hist_latent_frames = int(np.ceil(min(args.history_len, len(frames_rgb)) / 4.0)) + latents_hist = latents_hist_full[:, :, :hist_latent_frames] + + server.init_latent = latents_hist[:, :, :1].to(server.dtype) + latent_model_input = latents_hist[:, :, 1:].to(server.dtype) if hist_latent_frames > 1 else None + + state_cf1 = _load_history_joint_cf1(ep_dir, history_len=int(args.history_len)) + action_model_input = server.preprocess_action(state_cf1).to(device=server.device, dtype=server.dtype) + + server.transformer.clear_pred_cache(server.cache_name) + server.frame_st_id = 0 + if server.init_latent is not None and latent_model_input is not None: + latent_full = torch.cat([server.init_latent, latent_model_input], dim=2) + else: + latent_full = server.init_latent if latent_model_input is None else latent_model_input + + input_dict = server._prepare_latent_input(latent_full, action_model_input, frame_st_id=server.frame_st_id) + with torch.no_grad(): + server.transformer( + server._repeat_input_for_cfg(input_dict["latent_res_lst"]), + update_cache=2, + cache_name=server.cache_name, + action_mode=False, + ) + server.transformer( + server._repeat_input_for_cfg(input_dict["action_res_lst"]), + update_cache=2, + cache_name=server.cache_name, + action_mode=True, + ) + server.frame_st_id += int(latent_full.shape[2]) + + # rollout actions + latents + action_chunks: List[np.ndarray] = [] + latent_chunks: List[torch.Tensor] = [] + while sum(x.shape[0] for x in action_chunks) < need_steps: + actions_chunk, latents_chunk = server._infer({"obs": [], "state": state_cf1}, frame_st_id=server.frame_st_id) + a = np.asarray(actions_chunk, dtype=np.float32) + # actions from server are [C, frame_chunk, 1] + while a.ndim > 2 and a.shape[-1] == 1: + a = a[..., 0] + if a.ndim == 2 and a.shape[0] == 26 and a.shape[1] == int(cfg.frame_chunk_size): + a = a.T + if a.ndim != 2: + raise RuntimeError(f"Unexpected action chunk shape: {a.shape}") + if a.shape[1] > 26: + a = a[:, :26] + action_chunks.append(a) + latent_chunks.append(latents_chunk.detach().to("cpu")) + server.frame_st_id += int(cfg.frame_chunk_size) + + action_out = np.concatenate(action_chunks, axis=0)[:need_steps] + joint_out = action_out.copy() + + # decode video (51 frames, take first need_steps frames) + pred_latents = torch.cat(latent_chunks, dim=2)[:, :, :need_steps] + # decode on CPU to avoid OOM (transformer already occupies most GPU memory) + if next(server.vae.parameters()).device.type != "cpu": + server.vae = server.vae.to("cpu") + decoded_video = server.decode_one_video(pred_latents.to("cpu", dtype=server.dtype), "np")[0] + + ep_out.mkdir(parents=True, exist_ok=True) + (ep_out / "instruction.txt").write_text(prompt + "\n", encoding="utf-8") + _write_csv_like_sample(ep_out / "action.txt", header_cols, idxs, action_out) + _write_csv_like_sample(ep_out / "joint.txt", header_cols, idxs, joint_out) + export_to_video(decoded_video, str(ep_out / "video.mp4"), fps=10) + + print(f"done {ep_id} in {time.time()-t0:.1f}s (in_fps={fps})", flush=True) + torch.cuda.empty_cache() + + except Exception as e: + print(f"ERROR {ep_id}: {repr(e)}", flush=True) + torch.cuda.empty_cache() + continue + + print(f"All done. 
Output: {out_root}") + + +if __name__ == "__main__": + main() + diff --git a/arms/run_arms_test_infer_latents_only.py b/arms/run_arms_test_infer_latents_only.py new file mode 100644 index 0000000..4463844 --- /dev/null +++ b/arms/run_arms_test_infer_latents_only.py @@ -0,0 +1,348 @@ +""" +对 arms/test/ 批量推理(不解码视频): +- 输入:每个 episode 目录包含 video.mp4(16帧)、instruction.txt、joint.txt(16行) +- 输出:每个 episode 输出 + - action.txt / joint.txt(80-130 共 51 行) + - pred_latents.pt(预测视频 latents,供后续单独解码为 video.mp4) + +这样把“推理”和“视频解码”拆开,可以避免显存 OOM(transformer + VAE 同时上 GPU) +并且比 CPU 解码视频快很多。 +""" + +from __future__ import annotations + +import argparse +import json +import sys +import time +from pathlib import Path +from typing import List, Tuple + +import cv2 +import numpy as np +import pandas as pd +import torch +from easydict import EasyDict + +# allow running as a standalone script +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from wan_va.configs.va_arms_cfg import va_arms_cfg +from wan_va.wan_va_server import VA_Server + + +def _read_video_frames_rgb(video_path: Path) -> Tuple[List[np.ndarray], float]: + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + raise RuntimeError(f"Cannot open video: {video_path}") + fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0) + frames: List[np.ndarray] = [] + while True: + ok, frame_bgr = cap.read() + if not ok or frame_bgr is None: + break + frames.append(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)) + cap.release() + return frames, fps + + +def _frames_to_tensor(frames_rgb: List[np.ndarray], target_h: int, target_w: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + if len(frames_rgb) == 0: + raise RuntimeError("Empty frames") + arr = [] + for fr in frames_rgb: + fr = cv2.resize(fr, (target_w, target_h), interpolation=cv2.INTER_AREA) + arr.append(fr.astype(np.float32)) + x = np.stack(arr, axis=0) # [F,H,W,3] + x = torch.from_numpy(x).permute(3, 0, 1, 2).contiguous() # [3,F,H,W] + x = (x / 255.0) * 2.0 - 1.0 + return x.unsqueeze(0).to(device=device, dtype=dtype) + + +@torch.no_grad() +def _encode_video_to_latents_norm(server: VA_Server, frames_rgb: List[np.ndarray]) -> torch.Tensor: + device = server.device + dtype = server.dtype + video_tensor = _frames_to_tensor(frames_rgb, server.job_config.height, server.job_config.width, device=device, dtype=dtype) + + # temporarily move VAE to GPU for encoding + vae_was_on_cpu = next(server.vae.parameters()).device.type == "cpu" + if vae_was_on_cpu: + server.vae = server.vae.to(device).to(dtype) + + enc = server.vae.encode(video_tensor) + mu = enc.latent_dist.mean + latents_mean = torch.tensor(server.vae.config.latents_mean, device=mu.device, dtype=mu.dtype).view(1, -1, 1, 1, 1) + latents_std = torch.tensor(server.vae.config.latents_std, device=mu.device, dtype=mu.dtype).view(1, -1, 1, 1, 1) + mu_norm = (mu.float() - latents_mean.float()) * (1.0 / latents_std.float()) + mu_norm = mu_norm.to(dtype) + + if vae_was_on_cpu and server.enable_offload: + server.vae = server.vae.to("cpu") + torch.cuda.empty_cache() + + return mu_norm + + +def _load_history_joint_cf1(test_ep_dir: Path, history_len: int) -> np.ndarray: + df = pd.read_csv(test_ep_dir / "joint.txt") + data = df.iloc[:, 1:].to_numpy(dtype=np.float32) # [T,26] + hist = data[: min(history_len, data.shape[0])] + state_cf1 = hist.T[:, :, None] # [26,F,1] + state_cf1 = np.pad(state_cf1, ((0, 3), (0, 0), (0, 0)), mode="constant", constant_values=0.0) # -> [29,F,1] + return state_cf1 + 
+ +def _load_history_joint_26(test_ep_dir: Path, history_len: int) -> np.ndarray: + df = pd.read_csv(test_ep_dir / "joint.txt") + data = df.iloc[:, 1:].to_numpy(dtype=np.float32) # [T,26] + return data[: min(history_len, data.shape[0])] + + +def _apply_finger_hold_then_grasp( + out_t26: np.ndarray, + hist_t26: np.ndarray, + grasp_steps: int, +) -> np.ndarray: + """ + Scheme A: + - finger dims (14:26) hold the last observed pose for most steps + - last `grasp_steps` steps switch to a "closed" template (max over history) + """ + if out_t26.shape[1] != 26: + return out_t26 + if hist_t26.size == 0: + return out_t26 + + grasp_steps = int(max(0, min(grasp_steps, out_t26.shape[0]))) + fingers_last = hist_t26[-1, 14:26].copy() + fingers_closed = hist_t26[:, 14:26].max(axis=0).copy() + + out = out_t26.copy() + out[:, 14:26] = fingers_last[None] + if grasp_steps > 0: + out[-grasp_steps:, 14:26] = fingers_closed[None] + return out + + +def _write_csv_like_sample(path: Path, header_cols: List[str], idxs: List[int], data_t26: np.ndarray) -> None: + import csv + + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8", newline="") as f: + writer = csv.writer(f) + writer.writerow(["Unnamed: 0"] + header_cols) + for i, t in enumerate(idxs): + writer.writerow([int(t)] + [float(x) for x in data_t26[i].tolist()]) + + +def _action16_to_action26(action_cf: np.ndarray, hist_t26: np.ndarray) -> np.ndarray: + """ + VA_Server.postprocess_action returns only used_action_channel_ids (16 dims) for arms config: + - 14 joint dims (left7+right7) + - 2 gripper dims (not the 12 finger dims in ARMS csv schema) + + For submission/action.txt schema (26 dims), we: + - take the first 14 dims as joints + - fill finger dims (14:26) using history template (handled later by _apply_finger_hold_then_grasp) + """ + if action_cf.ndim != 2: + return action_cf + if action_cf.shape[1] != 16: + return action_cf + out = np.zeros((action_cf.shape[0], 26), dtype=np.float32) + out[:, :14] = action_cf[:, :14].astype(np.float32) + if hist_t26.size > 0 and hist_t26.shape[1] == 26: + out[:, 14:26] = hist_t26[-1, 14:26][None].astype(np.float32) + return out + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--test-root", type=str, default="arms/test") + ap.add_argument("--model-root", type=str, required=True) + ap.add_argument("--out-root", type=str, default="arms/test_generated_latents") + ap.add_argument("--device", type=str, default="cuda:0") + ap.add_argument("--history-len", type=int, default=16) + ap.add_argument("--start-idx", type=int, default=80) + ap.add_argument("--end-idx", type=int, default=130) + ap.add_argument("--max-episodes", type=int, default=-1) + ap.add_argument("--skip-existing", action="store_true") + ap.add_argument("--overwrite-existing", action="store_true", help="重跑并覆盖已存在的输出(action/joint/pred_latents)") + ap.add_argument("--no-pred-latents", action="store_true", help="不保存 pred_latents.pt(加速,且不做 latents 拼接)") + ap.add_argument( + "--start-from-test-end", + action="store_true", + help="从 arms/test//action.txt 的最后一行 idx+1 开始连续预测,并输出 --predict-steps 行(覆盖 --start-idx/--end-idx)", + ) + ap.add_argument("--predict-steps", type=int, default=51, help="--start-from-test-end 时预测多少步(默认 51)") + ap.add_argument("--grasp-steps", type=int, default=10, help="方案A:最后多少步把手指切到闭合姿态") + args = ap.parse_args() + + test_root = Path(args.test_root) + out_root = Path(args.out_root) + out_root.mkdir(parents=True, exist_ok=True) + + # header columns from any action.txt (same schema) + 
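For reference, the "Scheme A" helper `_apply_finger_hold_then_grasp` defined above reduces to two assignments on the finger block. A toy run, re-implemented inline under the same 26-dim layout (illustrative, not part of the patch):

```python
import numpy as np

rng = np.random.default_rng(0)
hist = rng.random((3, 26)).astype(np.float32)   # toy 3-step history
out = np.zeros((4, 26), dtype=np.float32)       # toy 4-step rollout

grasp_steps = 2
fingers_last = hist[-1, 14:26]                  # last observed finger pose
fingers_closed = hist[:, 14:26].max(axis=0)     # "closed" template: history max

out[:, 14:26] = fingers_last[None]              # hold the last pose ...
out[-grasp_steps:, 14:26] = fingers_closed[None]  # ... then close at the end

assert np.allclose(out[0, 14:26], fingers_last)
assert np.allclose(out[-1, 14:26], fingers_closed)
```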
any_action = next(iter(sorted(test_root.glob("*/action.txt")))) + header = any_action.read_text(encoding="utf-8").splitlines()[0].strip().split(",") + header_cols = header[1:] + + # norm_stat from arms_lerobot (train stats) + norm_stat_path = Path("arms_lerobot/norm_stat.json") + if not norm_stat_path.exists(): + raise FileNotFoundError("需要先生成 arms_lerobot/norm_stat.json(用 compute_arms_norm_stat.py)") + norm_stat = json.loads(norm_stat_path.read_text(encoding="utf-8")) + for k in ["q01", "q99"]: + if k in norm_stat and len(norm_stat[k]) < 30: + pad_val = 0.0 if k == "q01" else 1.0 + norm_stat[k] = list(norm_stat[k]) + [pad_val] * (30 - len(norm_stat[k])) + + cfg = EasyDict(va_arms_cfg) + cfg.wan22_pretrained_model_name_or_path = str(Path(args.model_root).resolve()) + cfg.save_root = str(out_root.resolve()) + cfg.enable_offload = True + cfg.param_dtype = torch.float16 + cfg.local_rank = int(str(args.device).split(":")[-1]) if "cuda" in args.device else 0 + cfg.norm_stat = norm_stat + + if "cuda" in args.device: + torch.cuda.empty_cache() + server = VA_Server(cfg) + + ep_dirs = sorted([p for p in test_root.iterdir() if p.is_dir()]) + if args.max_episodes > 0: + ep_dirs = ep_dirs[: int(args.max_episodes)] + + fixed_start = int(args.start_idx) + fixed_end = int(args.end_idx) + fixed_idxs = list(range(fixed_start, fixed_end + 1)) + fixed_need_steps = len(fixed_idxs) + + for ep_dir in ep_dirs: + ep_id = ep_dir.name + t0 = time.time() + try: + ep_out = out_root / ep_id + if ( + args.skip_existing + and not args.overwrite_existing + and (ep_out / "action.txt").exists() + and (ep_out / "joint.txt").exists() + and (args.no_pred_latents or (ep_out / "pred_latents.pt").exists()) + ): + print(f"skip {ep_id}", flush=True) + continue + + prompt = (ep_dir / "instruction.txt").read_text(encoding="utf-8").strip() + frames_rgb, _fps = _read_video_frames_rgb(ep_dir / "video.mp4") + if len(frames_rgb) == 0: + raise RuntimeError("Empty video") + + server.infer({"reset": True, "prompt": prompt}) + + latents_hist_full = _encode_video_to_latents_norm(server, frames_rgb) # [1,48,4,16,16] + hist_latent_frames = int(np.ceil(min(args.history_len, len(frames_rgb)) / 4.0)) + latents_hist = latents_hist_full[:, :, :hist_latent_frames] + + server.init_latent = latents_hist[:, :, :1].to(server.dtype) + latent_model_input = latents_hist[:, :, 1:].to(server.dtype) if hist_latent_frames > 1 else None + + state_cf1 = _load_history_joint_cf1(ep_dir, history_len=int(args.history_len)) + hist_26 = _load_history_joint_26(ep_dir, history_len=int(args.history_len)) + action_model_input = server.preprocess_action(state_cf1).to(device=server.device, dtype=server.dtype) + + if args.start_from_test_end: + df_hist = pd.read_csv(ep_dir / "action.txt") + if df_hist.shape[0] == 0: + raise RuntimeError("Empty test action.txt") + last_idx = int(df_hist.iloc[-1, 0]) + start_idx = last_idx + 1 + need_steps = int(args.predict_steps) + idxs = list(range(start_idx, start_idx + need_steps)) + else: + idxs = fixed_idxs + need_steps = fixed_need_steps + + server.transformer.clear_pred_cache(server.cache_name) + server.frame_st_id = 0 + if server.init_latent is not None and latent_model_input is not None: + latent_full = torch.cat([server.init_latent, latent_model_input], dim=2) + else: + latent_full = server.init_latent if latent_model_input is None else latent_model_input + + input_dict = server._prepare_latent_input(latent_full, action_model_input, frame_st_id=server.frame_st_id) + with torch.no_grad(): + server.transformer( + 
server._repeat_input_for_cfg(input_dict["latent_res_lst"]), + update_cache=2, + cache_name=server.cache_name, + action_mode=False, + ) + server.transformer( + server._repeat_input_for_cfg(input_dict["action_res_lst"]), + update_cache=2, + cache_name=server.cache_name, + action_mode=True, + ) + server.frame_st_id += int(latent_full.shape[2]) + + action_chunks: List[np.ndarray] = [] + latent_chunks: List[torch.Tensor] = [] + while sum(x.shape[0] for x in action_chunks) < need_steps: + actions_chunk, latents_chunk = server._infer({"obs": [], "state": state_cf1}, frame_st_id=server.frame_st_id) + a = np.asarray(actions_chunk, dtype=np.float32) + while a.ndim > 2 and a.shape[-1] == 1: + a = a[..., 0] + # Expected shapes: + # - (16, F) from VA_Server.postprocess_action (C,F) + # - (26, F) from older pipelines + if a.ndim == 2 and a.shape[0] in (16, 26) and a.shape[1] == int(cfg.frame_chunk_size): + a = a.T + # If action is still channel-first, transpose. + if a.ndim == 2 and a.shape[0] in (16, 26) and a.shape[1] != int(cfg.frame_chunk_size): + # already (T,C) or something else; keep + pass + + # Convert 16-d outputs to 26-d csv schema + if a.ndim == 2 and a.shape[1] == 16: + a = _action16_to_action26(a, hist_26) + + if a.ndim == 2 and a.shape[1] > 26: + a = a[:, :26] + action_chunks.append(a) + if not args.no_pred_latents: + latent_chunks.append(latents_chunk.detach().to("cpu")) + server.frame_st_id += int(cfg.frame_chunk_size) + + action_out = np.concatenate(action_chunks, axis=0)[:need_steps] + if action_out.ndim != 2 or action_out.shape[1] != 26: + raise RuntimeError(f"Bad action_out shape {getattr(action_out,'shape',None)}, expected (T,26)") + # Scheme A: fingers hold then grasp + action_out = _apply_finger_hold_then_grasp(action_out, hist_26, grasp_steps=int(args.grasp_steps)) + joint_out = action_out.copy() + + ep_out.mkdir(parents=True, exist_ok=True) + (ep_out / "instruction.txt").write_text(prompt + "\n", encoding="utf-8") + _write_csv_like_sample(ep_out / "action.txt", header_cols, idxs, action_out) + _write_csv_like_sample(ep_out / "joint.txt", header_cols, idxs, joint_out) + if not args.no_pred_latents: + pred_latents = torch.cat(latent_chunks, dim=2)[:, :, :need_steps].contiguous() # [1,48,T,16,16] + torch.save({"latents": pred_latents, "fps": 10}, ep_out / "pred_latents.pt") + + print(f"done {ep_id} in {time.time()-t0:.1f}s", flush=True) + torch.cuda.empty_cache() + + except Exception as e: + print(f"ERROR {ep_id}: {repr(e)}", flush=True) + torch.cuda.empty_cache() + continue + + print(f"All done. 
Output: {out_root}") + + +if __name__ == "__main__": + main() + diff --git a/docs/ARMS_MI300X_POSTTRAIN.md b/docs/ARMS_MI300X_POSTTRAIN.md new file mode 100644 index 0000000..2d4c829 --- /dev/null +++ b/docs/ARMS_MI300X_POSTTRAIN.md @@ -0,0 +1,206 @@ +# Arms 数据 + MI300X 服务器 Post-training / 训练排障指南 + +本文把从「把工程与数据搬到服务器」到「`arms_train` 能真正开训」这条链路里**高频踩坑**整理成一份可执行清单。默认你在服务器上的工程路径类似: + +- 代码:`/root/lingbot-va` +- 数据集:`/root/lingbot-va/arms_lerobot`(注意:本仓库 `.gitignore` 忽略了 `arms_lerobot/`,**数据集不会随 git 推送**) +- Base 模型:`/root/lingbot-va/models/lingbot-va-base`(HuggingFace:`robbyant/lingbot-va-base`) +- Wan Diffusers(如需本地 VAE/权重目录):`/root/lingbot-va/models/Wan2.2-Animate-14B-Diffusers`(示例 HF:`Wan-AI/Wan2.2-Animate-14B-Diffusers`) + +--- + +## 1) 代码:GitHub `git pull` 或 rsync + +### 推荐:GitHub 拉代码 + +```bash +git clone https://github.com/<你的组织>/<你的仓库>.git +cd lingbot-va +git checkout arms-ROCM-posttraining +git pull +``` + +### 备选:从本机 rsync 整个工程(注意 exclude 大目录) + +常用排除:`models/`、`example/`、`assets/`、`libero_10/`、压缩包等。 + +--- + +## 2) 数据:`arms_lerobot` 单独同步(不要指望 git) + +`arms_lerobot/` 在本仓库 `.gitignore` 中,因此请用 `rsync`/`scp` 同步到服务器固定路径,例如: + +```bash +rsync -avP --partial --info=progress2 \ + /path/to/local/arms_lerobot/ \ + root@:/root/lingbot-va/arms_lerobot/ +``` + +### 训练走 latents 时,至少要包含 + +- `meta/episodes.jsonl` +- `data/chunk-000/episode_*.parquet` +- `latents/chunk-000//*.pth` +- `empty_emb.pt`、`norm_stat.json`(你当前流程里会用到) +- `manifest.json`(如果你们工具链依赖) + +`videos/` 是否必须取决于你是否还要走“读 mp4”的路径;**latents 训练通常不强依赖 videos**。 + +--- + +## 3) 模型权重:服务器下载(HF / 镜像) + +### `lingbot-va-base` + +仓库页:`https://huggingface.co/robbyant/lingbot-va-base` + +```bash +pip install -U "huggingface_hub[cli]" +export HF_ENDPOINT="https://hf-mirror.com" # 不能直连 huggingface.co 时用镜像;能直连则改成 https://huggingface.co + +huggingface-cli download "robbyant/lingbot-va-base" \ + --local-dir "/root/lingbot-va/models/lingbot-va-base" \ + --local-dir-use-symlinks False +``` + +### Wan Diffusers(示例) + +```bash +huggingface-cli download "Wan-AI/Wan2.2-Animate-14B-Diffusers" \ + --local-dir "/root/lingbot-va/models/Wan2.2-Animate-14B-Diffusers" \ + --include "vae/*" +``` + +--- + +## 4) Python 依赖:不要只装 torch + +你至少会遇到过这些 import 缺失(按报错逐个装): + +```bash +python3 -m pip install -U wandb datasets jsonlines av +``` + +`av`(PyAV)即使训练主要读 latents,`lerobot` 在 import 阶段也会加载视频工具链,**没装 `av` 会直接起不来**。 + +> 若 `pip install av` 失败,优先用系统包补齐 ffmpeg 相关开发库后再装(不同发行版包名略有差异)。 + +--- + +## 5) 配置:三条路径必须对齐 + +1) `wan_va/configs/va_arms_cfg.py` + +- `wan22_pretrained_model_name_or_path = "/root/lingbot-va/models/lingbot-va-base"` + +2) `wan_va/configs/va_arms_train_cfg.py` + +- `dataset_path = "/root/lingbot-va/arms_lerobot"`(**不要用 `./arms_lerobot`**) +- `empty_emb_path` 会拼接在 `dataset_path` 下,确保 `empty_emb.pt` 存在 + +3) 启动参数 + +```bash +torchrun --standalone --nproc_per_node=1 wan_va/train.py \ + --config-name arms_train \ + --save-root /root/lingbot-va/outputs/arms_train +``` + +--- + +## 6) W&B:不要用不存在的 CLI 参数 + +`wan_va/train.py` **不支持** `--logger wandb` 这类参数;是否启用看配置: + +- `wan_va/configs/va_arms_train_cfg.py` → `enable_wandb` + +若开启 W&B,需要环境变量(示例): + +```bash +export WANDB_BASE_URL="https://api.wandb.ai" +export WANDB_API_KEY="..." +export WANDB_TEAM_NAME="..." +export WANDB_PROJECT="..." +``` + +--- + +## 7) LeRobot 元数据:你这次训练失败的“关键三连” + +`lerobot==0.3.x` 的 `LeRobotDatasetMetadata` 会读取本地 `meta/` 下多份文件。自定义导出的 `arms_lerobot` 往往缺其中几份,表现为: + +- `num_samples=0`(数据集根本没被注册进来) +- `HFValidationError: repo id ... 
+
+### 7.1 Required: `meta/info.json`
+
+Declares `codebase_version`, `fps`, `features`, the data/video path templates, and so on.
+
+### 7.2 Required: `meta/tasks.jsonl`
+
+Format example (one JSON object per line):
+
+```json
+{"task_index": 0, "task": "Pick up ..."}
+```
+
+### 7.3 Required: `meta/episodes_stats.jsonl`
+
+This was the last blocker in this round of debugging: `LeRobotDatasetMetadata` loads episode-level statistics, and a missing file triggers the error branch.
+
+> This repo now ships a fallback in `wan_va/dataset/lerobot_latent_dataset.py`: when these files are missing, it synthesizes `tasks.jsonl` / `episodes_stats.jsonl` from `episodes.jsonl` + `data/chunk-000/*.parquet`, and pins the `repo_id` passed to `LeRobotDatasetMetadata` to the directory name (e.g. `arms_lerobot`) so an absolute path is never treated as an HF repo id.
+
+---
+
+## 8) SSH / rsync: `Permission denied (publickey)`
+
+This means the target machine only accepts public-key login. You need to:
+
+- generate a key pair on the initiating machine with `ssh-keygen`
+- append the public key to `~/.ssh/authorized_keys` on the target machine
+- or point at an existing private key: `-i /path/to/key.pem`
+
+---
+
+## 9) Recommended launch (consistent with the upstream README)
+
+The upstream README's pattern is:
+
+```bash
+NGPU=1 CONFIG_NAME='arms_train' bash script/run_va_posttrain.sh
+```
+
+Note: the `WANDB_*` values in this repo's scripts may be placeholders; if you are not using W&B, set `enable_wandb=False` first so startup doesn't go through W&B initialization.
+
+---
+
+## 10) Quick self-check commands (on the server)
+
+```bash
+python3 - <<'PY'
+from wan_va.configs import VA_CONFIGS
+print("dataset_path:", VA_CONFIGS["arms_train"].dataset_path)
+print("wan22_pretrained_model_name_or_path:", VA_CONFIGS["arms_train"].wan22_pretrained_model_name_or_path)
+PY
+
+ls -lah /root/lingbot-va/arms_lerobot/meta/info.json
+ls -lah /root/lingbot-va/arms_lerobot/meta/tasks.jsonl
+ls -lah /root/lingbot-va/arms_lerobot/meta/episodes_stats.jsonl
+ls -1 /root/lingbot-va/arms_lerobot/latents/chunk-000/observation.images.front | head
+```
+
+---
+
+## 11) If it still fails: the most useful information to share
+
+Paste these three pieces of information (from the bottom of the output up):
+
+1) `python3 - </norm_stat.json 覆盖;这里给默认占位(30 维)
+va_arms_cfg.norm_stat = {"q01": [0.0] * va_arms_cfg.action_dim, "q99": [1.0] * va_arms_cfg.action_dim}
+
diff --git a/wan_va/configs/va_arms_train_cfg.py b/wan_va/configs/va_arms_train_cfg.py
new file mode 100644
index 0000000..c20b36b
--- /dev/null
+++ b/wan_va/configs/va_arms_train_cfg.py
@@ -0,0 +1,43 @@
+# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
+import os
+from easydict import EasyDict
+
+from .va_arms_cfg import va_arms_cfg
+
+
+va_arms_train_cfg = EasyDict(__name__="Config: VA arms train")
+va_arms_train_cfg.update(va_arms_cfg)
+
+# dataset root containing meta/data/videos/latents/empty_emb.pt
+va_arms_train_cfg.dataset_path = os.environ.get("ARMS_LEROBOT_PATH", "/path/to/arms_lerobot")
+va_arms_train_cfg.empty_emb_path = os.path.join(va_arms_train_cfg.dataset_path, "empty_emb.pt")
+
+# logging / io
+va_arms_train_cfg.enable_wandb = False
+# NOTE: After FSDP/CUDA init, fork-based DataLoader workers can deadlock on some stacks.
+# Keep this small (0 is safest for "first run green"); increase only if you validate spawn/fork safety.
+va_arms_train_cfg.load_worker = 0
+
+# Dataset construction uses a multiprocessing Pool inside `MultiLatentLeRobotDataset`.
+# Default used to be 128 workers which looks like "hundreds of train.py processes" in pgrep.
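For context on the comment above: the clamp added in `MultiLatentLeRobotDataset` bounds whatever `dataset_init_workers` (set just below) requests by the machine's CPU count. A standalone sketch of the same rule, with an illustrative function name:

```python
import os

def clamp_init_workers(requested: int) -> int:
    """Upper-bound dataset-init workers by the CPU count (sketch)."""
    cpu = os.cpu_count() or 1
    return max(1, min(int(requested), cpu))

# e.g. the legacy default of 128 on a 16-core machine -> 16 processes, not 128
print(clamp_init_workers(128))
```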
+va_arms_train_cfg.dataset_init_workers = 16 +va_arms_train_cfg.save_interval = 200 +va_arms_train_cfg.gc_interval = 50 +va_arms_train_cfg.cfg_prob = 0.1 + +# training parameters (MI300X 单卡可以适当加大 batch/accum) +va_arms_train_cfg.learning_rate = 1e-5 +va_arms_train_cfg.beta1 = 0.9 +va_arms_train_cfg.beta2 = 0.95 +va_arms_train_cfg.weight_decay = 1e-1 +va_arms_train_cfg.warmup_steps = 50 +va_arms_train_cfg.batch_size = 1 +va_arms_train_cfg.gradient_accumulation_steps = 8 +va_arms_train_cfg.num_steps = 10000 + +# loss weights (两阶段训练用) +# 阶段1(动作优先):latent_loss_weight=0, action_loss_weight=1 +# 阶段2(联合):latent_loss_weight=1, action_loss_weight=1 +va_arms_train_cfg.latent_loss_weight = 0.0 +va_arms_train_cfg.action_loss_weight = 1.0 + diff --git a/wan_va/dataset/lerobot_latent_dataset.py b/wan_va/dataset/lerobot_latent_dataset.py index f63c3fb..ce2fe99 100644 --- a/wan_va/dataset/lerobot_latent_dataset.py +++ b/wan_va/dataset/lerobot_latent_dataset.py @@ -2,12 +2,13 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata from lerobot.datasets.utils import get_episode_data_index from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats +import json import numpy as np from pathlib import Path from collections.abc import Callable import os from tqdm import tqdm -from multiprocessing import Pool +import multiprocessing as mp from functools import partial import torch from einops import rearrange @@ -28,6 +29,171 @@ def recursive_find_file(directory, filename='info.json'): print(f"Error: {e}") return result + +def _ensure_tasks_jsonl(dataset_root: Path) -> None: + """ + LeRobotDatasetMetadata expects `meta/tasks.jsonl`. + Some custom exports only ship `meta/episodes.jsonl`; synthesize tasks from it. + """ + tasks_path = dataset_root / "meta" / "tasks.jsonl" + if tasks_path.exists(): + return + + episodes_path = dataset_root / "meta" / "episodes.jsonl" + if not episodes_path.exists(): + raise FileNotFoundError( + f"Missing {tasks_path} and cannot synthesize without {episodes_path}" + ) + + tasks: list[str] = [] + seen: set[str] = set() + with episodes_path.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + obj = json.loads(line) + t = None + if isinstance(obj.get("tasks"), list) and obj["tasks"]: + t = str(obj["tasks"][0]) + elif isinstance(obj.get("task"), str): + t = obj["task"] + if not t: + continue + if t in seen: + continue + seen.add(t) + tasks.append(t) + + tasks_path.parent.mkdir(parents=True, exist_ok=True) + with tasks_path.open("w", encoding="utf-8") as f: + for i, t in enumerate(tasks): + f.write(json.dumps({"task_index": i, "task": t}, ensure_ascii=False) + "\n") + + +def _load_task_to_index(dataset_root: Path) -> dict[str, int]: + tasks_path = dataset_root / "meta" / "tasks.jsonl" + if not tasks_path.exists(): + raise FileNotFoundError(f"Missing {tasks_path}; run training after tasks.jsonl exists") + m: dict[str, int] = {} + with tasks_path.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + obj = json.loads(line) + m[str(obj["task"])] = int(obj["task_index"]) + return m + + +def _stat_1d(x: "np.ndarray") -> dict: + # x: [T, D] + if x.size == 0: + d = int(x.shape[1]) if x.ndim == 2 else 0 + z = [0.0] * d + o = [1.0] * d + return { + "min": z, + "max": o, + "mean": z, + "std": z, + "count": [0], + } + mn = np.min(x, axis=0).astype(float).tolist() + mx = np.max(x, axis=0).astype(float).tolist() + mu = np.mean(x, 
axis=0).astype(float).tolist() + sd = np.std(x, axis=0).astype(float).tolist() + return {"min": mn, "max": mx, "mean": mu, "std": sd, "count": [int(x.shape[0])]} + + +def _dummy_video_stat(count: int) -> dict: + # Shape matches common LeRobot exports (3 nested levels), values are placeholders. + return { + "min": [[[0.0]], [[0.0]], [[0.0]]], + "max": [[[1.0]], [[1.0]], [[1.0]]], + "mean": [[[0.5]], [[0.5]], [[0.5]]], + "std": [[[0.1]], [[0.1]], [[0.1]]], + "count": [int(count)], + } + + +def _ensure_episodes_stats_jsonl(dataset_root: Path) -> None: + """ + LeRobotDatasetMetadata expects `meta/episodes_stats.jsonl` for v2.1 datasets. + If missing, compute lightweight stats from parquet + episodes.jsonl. + """ + out_path = dataset_root / "meta" / "episodes_stats.jsonl" + if out_path.exists(): + return + + info_path = dataset_root / "meta" / "info.json" + if not info_path.exists(): + raise FileNotFoundError(f"Missing {info_path}") + info = json.loads(info_path.read_text(encoding="utf-8")) + fps = float(info.get("fps", 30) or 30) + + episodes_path = dataset_root / "meta" / "episodes.jsonl" + if not episodes_path.exists(): + raise FileNotFoundError(f"Missing {episodes_path}") + + task_to_idx = _load_task_to_index(dataset_root) + + # global row index base (approximate LeRobot's global `index`) + cum_rows = 0 + lines_out: list[str] = [] + + with episodes_path.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + ep = json.loads(line) + ep_idx = int(ep["episode_index"]) + length = int(ep["length"]) + tasks = ep.get("tasks") or [] + task_str = str(tasks[0]) if tasks else "" + task_index = int(task_to_idx[task_str]) if task_str in task_to_idx else 0 + + pq = dataset_root / "data" / "chunk-000" / f"episode_{ep_idx:06d}.parquet" + if not pq.exists(): + raise FileNotFoundError(f"Missing parquet for episode {ep_idx}: {pq}") + + # Optional dependency: pandas/pyarrow are already required by lerobot workflows. + import pandas as pd # local import to keep import graph lighter + + df = pd.read_parquet(pq, columns=["action", "observation.state"]) + act = np.stack(df["action"].to_numpy()).astype(np.float32) + st = np.stack(df["observation.state"].to_numpy()).astype(np.float32) + + T = int(act.shape[0]) + if T != length: + # Keep going, but prefer parquet truth for stats. + length = T + + idx = np.arange(cum_rows, cum_rows + length, dtype=np.int64)[:, None] + ep_col = np.full((length, 1), ep_idx, dtype=np.int64) + fi = np.arange(0, length, dtype=np.int64)[:, None] + ti = np.full((length, 1), task_index, dtype=np.int64) + ts = (fi.astype(np.float32) / float(fps)) + + stats = { + "episode_index": _stat_1d(ep_col.astype(np.float32)), + "index": _stat_1d(idx.astype(np.float32)), + "frame_index": _stat_1d(fi.astype(np.float32)), + "task_index": _stat_1d(ti.astype(np.float32)), + "timestamp": _stat_1d(ts), + "action": _stat_1d(act), + "observation.state": _stat_1d(st), + # We don't decode mp4 here; latent training doesn't need accurate video stats. 
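The per-feature entries assembled in this dict follow LeRobot's aggregate layout; a toy illustration of what `_stat_1d` computes for a `[T, D]` array (illustrative, not part of the patch):

```python
import numpy as np

# Per-dimension aggregates over a toy [T, D] array, matching the fields
# `_stat_1d` records for each feature.
x = np.array([[0.0, 1.0], [2.0, 3.0]], dtype=np.float32)
stats = {
    "min": x.min(axis=0).tolist(),    # [0.0, 1.0]
    "max": x.max(axis=0).tolist(),    # [2.0, 3.0]
    "mean": x.mean(axis=0).tolist(),  # [1.0, 2.0]
    "std": x.std(axis=0).tolist(),    # [1.0, 1.0]
    "count": [int(x.shape[0])],       # [2]
}
assert stats["mean"] == [1.0, 2.0]
```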
+ "observation.images.front": _dummy_video_stat(length), + } + + lines_out.append(json.dumps({"episode_index": ep_idx, "stats": stats}, ensure_ascii=False)) + cum_rows += length + + out_path.write_text("\n".join(lines_out) + ("\n" if lines_out else ""), encoding="utf-8") + + def construct_lerobot( repo_id, config, @@ -45,9 +211,17 @@ def construct_lerobot_multi_processor(config, construct_lerobot, config=config, ) - repo_list = recursive_find_file(config.dataset_path, 'info.json') + # Always resolve dataset_path to an absolute path. Relative paths like + # "./arms_lerobot" would otherwise be treated as an invalid HF repo id. + dataset_root = Path(config.dataset_path).expanduser().resolve() + repo_list = recursive_find_file(str(dataset_root), 'info.json') repo_list = [v.split('/meta/info.json')[0] for v in repo_list] - with Pool(num_init_worker) as pool: + for root in repo_list: + _ensure_tasks_jsonl(Path(root)) + _ensure_episodes_stats_jsonl(Path(root)) + # Use spawn context to avoid fork-related crashes/hangs with torch + GPU init. + ctx = mp.get_context("spawn") + with ctx.Pool(num_init_worker) as pool: datasets_out_lst = pool.map(construct_func, repo_list) return datasets_out_lst @@ -71,8 +245,13 @@ class MultiLatentLeRobotDataset(torch.utils.data.Dataset): def __init__( self, config, - num_init_worker=128, + num_init_worker=None, ): + if num_init_worker is None: + num_init_worker = int(getattr(config, "dataset_init_workers", 16) or 16) + cpu = os.cpu_count() or 1 + # Avoid spawning hundreds of short-lived workers during dataset init. + num_init_worker = int(max(1, min(int(num_init_worker), max(1, cpu)))) self._datasets = construct_lerobot_multi_processor(config, num_init_worker, ) @@ -111,8 +290,16 @@ def __init__( repo_id, config=None, ): - self.repo_id = repo_id - self.root = HF_LEROBOT_HOME / repo_id + # `repo_id` here is the local dataset root directory passed from + # `recursive_find_file(...)`. LeRobotDatasetMetadata expects a HF-style + # repo id string (no slashes), while the on-disk dataset lives at + # `dataset_root`. + dataset_root = Path(repo_id).expanduser().resolve() + _ensure_tasks_jsonl(dataset_root) + _ensure_episodes_stats_jsonl(dataset_root) + + self.repo_id = dataset_root.name # e.g. "arms_lerobot" + self.root = dataset_root self.image_transforms = None self.delta_timestamps = None self.episodes = None @@ -141,7 +328,7 @@ def __init__( self.hf_dataset = self.load_hf_dataset() self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) - self.latent_path = Path(repo_id) / 'latents' + self.latent_path = dataset_root / 'latents' self.empty_emb = torch.load(config.empty_emb_path, weights_only=False) self.config = config self.cfg_prob = config.cfg_prob @@ -270,9 +457,39 @@ def _action_post_process(self, local_start_frame, local_end_frame, latent_frame_ action_mask = np.ones_like(action, dtype='bool') assert action.shape[0] == required_action_num - - action_paded = np.pad(action, ((0, 0), (0, 1)), mode='constant', constant_values=0) - action_mask_padded = np.pad(action_mask, ((0, 0), (0, 1)), mode='constant', constant_values=0) + # If dataset provides 26-dim (dual-arm joints 14 + fingers 12), map to + # 30-dim (EEF_L7 + EEF_R7 + joints_L7 + joints_R7 + grip_L1 + grip_R1). + # Without URDF we cannot compute EEF; we keep EEF dims at 0 and mask them out. + # Finger joints are aggregated into a single gripper scalar per hand (mean). 
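A toy check of the layout described in the comment block above, mirroring the remapping code that follows (illustrative only):

```python
import numpy as np

a26 = np.arange(26, dtype=np.float32)[None, :]      # toy [1, 26] action row
left_j, right_j = a26[:, 0:7], a26[:, 7:14]
grip_l = a26[:, 14:20].mean(axis=1, keepdims=True)  # 6 left fingers -> 1 scalar
grip_r = a26[:, 20:26].mean(axis=1, keepdims=True)  # 6 right fingers -> 1 scalar
eef = np.zeros((1, 14), dtype=np.float32)           # unknown EEF, kept at zero

a30 = np.concatenate([eef, left_j, right_j, grip_l, grip_r], axis=1)
assert a30.shape == (1, 30)

# Mask: the 14 EEF placeholders are invalid; the 16 joint/gripper dims valid.
mask = np.concatenate([np.zeros((1, 14), bool), np.ones((1, 16), bool)], axis=1)
assert int(mask.sum()) == 16
```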
+ if int(action.shape[1]) == 26 and int(getattr(self.config, "action_dim", 26)) == 30: + # action layout (26): + # 0:7 left arm joints + # 7:14 right arm joints + # 14:20 left fingers (6) + # 20:26 right fingers (6) + left_j = action[:, 0:7] + right_j = action[:, 7:14] + left_f = action[:, 14:20] + right_f = action[:, 20:26] + + grip_l = left_f.mean(axis=1, keepdims=True) + grip_r = right_f.mean(axis=1, keepdims=True) + + eef_zeros = np.zeros((action.shape[0], 14), dtype=action.dtype) + action = np.concatenate([eef_zeros, left_j, right_j, grip_l, grip_r], axis=1) # [T,30] + + # mask: EEF dims invalid (0), joints+gripper valid (1) + eef_mask = np.zeros((action_mask.shape[0], 14), dtype=bool) + jr_mask = np.ones((action_mask.shape[0], 16), dtype=bool) # 14 joints + 2 grippers + action_mask = np.concatenate([eef_mask, jr_mask], axis=1) + + + # Align action dim to model action_dim (lingbot-va-base expects 30). + # Some datasets store 26-dim actions; pad with zeros to config.action_dim. + target_action_dim = int(getattr(self.config, "action_dim", action.shape[1])) + pad_dim = max(0, target_action_dim - int(action.shape[1])) + action_paded = np.pad(action, ((0, 0), (0, pad_dim)), mode='constant', constant_values=0) + action_mask_padded = np.pad(action_mask, ((0, 0), (0, pad_dim)), mode='constant', constant_values=0) action_aligned = action_paded[:, self.config.inverse_used_action_channel_ids] action_mask_aligned = action_mask_padded[:, self.config.inverse_used_action_channel_ids] diff --git a/wan_va/modules/model.py b/wan_va/modules/model.py index 25b45e5..b55c011 100644 --- a/wan_va/modules/model.py +++ b/wan_va/modules/model.py @@ -26,10 +26,16 @@ ) from functools import partial +flash_attn_func = None try: - from flash_attn_interface import flash_attn_func -except: - from flash_attn import flash_attn_func + from flash_attn_interface import flash_attn_func as _flash_attn_func # type: ignore + flash_attn_func = _flash_attn_func +except Exception: + try: + from flash_attn import flash_attn_func as _flash_attn_func # type: ignore + flash_attn_func = _flash_attn_func + except Exception: + flash_attn_func = None __all__ = ['WanTransformer3DModel'] @@ -302,6 +308,11 @@ def __init__( if attn_mode == 'torch': self.attn_op = custom_sdpa elif attn_mode == 'flashattn': + if flash_attn_func is None: + raise ImportError( + "attn_mode='flashattn' requires flash-attn, but it is not installed. " + "Install flash-attn or set attn_mode='torch'." + ) self.attn_op = flash_attn_func elif attn_mode == 'flex': self.attn_op = FlexAttnFunc(cross_attention_dim_head is not None) diff --git a/wan_va/train.py b/wan_va/train.py index fff03f0..66c1651 100644 --- a/wan_va/train.py +++ b/wan_va/train.py @@ -126,11 +126,21 @@ def __init__(self, config): shuffle=True, seed=42 ) if config.world_size > 1 else None + # NOTE: Avoid fork-based DataLoader workers after CUDA init / FSDP shard. + # This commonly deadlocks (low CPU, 0% GPU) when num_workers>0. + num_workers = int(getattr(config, "load_worker", 0) or 0) + if torch.cuda.is_initialized() and num_workers > 0: + if int(getattr(config, "rank", 0) or 0) == 0: + logger.warning( + "CUDA is already initialized; forcing DataLoader num_workers=0 " + f"(config.load_worker was {num_workers}) to avoid fork/CUDA deadlocks." 
+ ) + num_workers = 0 self.train_loader = DataLoader( train_dataset, batch_size=config.batch_size, shuffle=(train_sampler is None), - num_workers=config.load_worker, + num_workers=num_workers, sampler=train_sampler, ) @@ -234,8 +244,11 @@ def _prepare_input_dict(self, batch_dict): action_mode=True, noisy_cond_prob=0.0) - latent_dict['text_emb'] = batch_dict['text_emb'] - action_dict['text_emb'] = batch_dict['text_emb'] + # Ensure text embedding dtype matches model weights (param_dtype). + # Otherwise diffusers text projection may error: Half vs BFloat16. + text_emb = batch_dict['text_emb'].to(self.dtype) + latent_dict['text_emb'] = text_emb + action_dict['text_emb'] = text_emb action_dict['actions_mask'] = batch_dict['actions_mask'] input_dict = { @@ -307,7 +320,9 @@ def _train_step(self, batch, batch_idx): output = self.transformer(input_dict, train_mode=True) latent_loss, action_loss = self.compute_loss(input_dict, output) - loss = latent_loss + action_loss + latent_w = float(getattr(self.config, "latent_loss_weight", 1.0)) + action_w = float(getattr(self.config, "action_loss_weight", 1.0)) + loss = latent_w * latent_loss + action_w * action_loss loss.backward() @@ -519,6 +534,20 @@ def run(args): if args.save_root is not None: config.save_root = args.save_root + # Optional: auto resume from latest checkpoint under /checkpoints. + if getattr(args, "resume_latest", False) and getattr(config, "rank", 0) == 0: + ckpt_root = Path(config.save_root) / "checkpoints" + if ckpt_root.exists(): + cands = sorted(ckpt_root.glob("checkpoint_step_*")) + if cands: + latest = cands[-1] + config.resume_from = str(latest) + logger.info(f"Auto-resume enabled. Using latest checkpoint: {latest}") + else: + logger.info(f"Auto-resume enabled but no checkpoints under: {ckpt_root}") + else: + logger.info(f"Auto-resume enabled but checkpoints dir missing: {ckpt_root}") + if rank == 0: logger.info(f"Using config: {args.config_name}") logger.info(f"World size: {world_size}, Local rank: {local_rank}") @@ -542,6 +571,11 @@ def main(): default=None, help="Root directory for saving checkpoints", ) + parser.add_argument( + "--resume-latest", + action="store_true", + help="Auto resume from latest checkpoint_step_* under /checkpoints", + ) args = parser.parse_args() run(args) diff --git a/wan_va/utils/__init__.py b/wan_va/utils/__init__.py index 875b26e..cb9f981 100644 --- a/wan_va/utils/__init__.py +++ b/wan_va/utils/__init__.py @@ -1,9 +1,14 @@ # Copyright 2024-2025 The Robbyant Team Authors. All rights reserved. 
from .logging import init_logger, logger from .scheduler import FlowMatchScheduler -from .sever_utils import run_async_server_mode from .utils import data_seq_to_patch, get_mesh_id, save_async, sample_timestep_id, warmup_constant_lambda + +def run_async_server_mode(*args, **kwargs): + # lazy import to avoid pulling server-only deps (websockets/msgpack) in offline usage + from .sever_utils import run_async_server_mode as _run + return _run(*args, **kwargs) + __all__ = [ 'logger', 'init_logger', 'get_mesh_id', 'save_async', 'data_seq_to_patch', 'FlowMatchScheduler', 'run_async_server_mode', 'sample_timestep_id', 'warmup_constant_lambda' diff --git a/wan_va/wan_va_server.py b/wan_va/wan_va_server.py index 4c95ace..3573580 100644 --- a/wan_va/wan_va_server.py +++ b/wan_va/wan_va_server.py @@ -27,15 +27,9 @@ load_transformer, load_vae, ) -from utils import ( - FlowMatchScheduler, - data_seq_to_patch, - get_mesh_id, - init_logger, - logger, - run_async_server_mode, - save_async, -) +from utils.logging import init_logger, logger +from utils.scheduler import FlowMatchScheduler +from utils.utils import data_seq_to_patch, get_mesh_id, save_async class VA_Server: @@ -693,6 +687,8 @@ def run(args): model.generate() elif config.infer_mode == 'server': logger.info(f"******************************USE Server mode******************************") + # lazy import to avoid pulling server-only deps in offline usage + from utils.sever_utils import run_async_server_mode run_async_server_mode(model, local_rank, config.host, port) else: raise ValueError(f"Unknown infer mode: {config.infer_mode}")
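One editorial caveat on the auto-resume added in `wan_va/train.py`: `sorted(ckpt_root.glob("checkpoint_step_*"))` compares names lexicographically, so `checkpoint_step_900` would be picked over `checkpoint_step_1000` unless step numbers are zero-padded in the checkpoint names. A numeric key avoids this; a minimal sketch, assuming the `checkpoint_step_<N>` naming used above:

```python
from pathlib import Path
from typing import Optional

def latest_checkpoint(ckpt_root: Path) -> Optional[Path]:
    """Pick the checkpoint with the highest numeric step.

    A plain sorted() over names is lexicographic, so checkpoint_step_900
    would beat checkpoint_step_1000; keying on int(step) is robust.
    """
    cands = list(ckpt_root.glob("checkpoint_step_*"))
    if not cands:
        return None
    return max(cands, key=lambda p: int(p.name.rsplit("_", 1)[-1]))
```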