diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..10df512 --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +prepared_arms/ +arms/ + +# local outputs / artifacts +0_put_both_the_alphabet_soup_and_the_tomato_sauce_in_the_basket/ +outputs/ + +# python cache +__pycache__/ +**/__pycache__/ +*.pyc + diff --git a/Cursor b/Cursor new file mode 100644 index 0000000..2046f23 --- /dev/null +++ b/Cursor @@ -0,0 +1,2 @@ +Placeholder file to satisfy tooling argument parsing. + diff --git a/README.md b/README.md index 1d3a63c..bde8797 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,170 @@ pip install websockets einops diffusers==0.36.0 transformers==4.55.2 accelerate pip install flash-attn --no-build-isolation ``` +--- + +## 我们的改动(arms + ROCm) + +本仓库在上游基础上,增加了适配 **AMD ROCm(MI300X)** 与自定义双臂单相机数据集 **`arms/`** 的训练/数据准备流程,并把远程 LIBERO client 的落盘能力补齐(mp4/png/npz/action/joint)。 + +### 快速开始(MI300X / ROCm 7.2) + +#### 1) 安装依赖 + +```bash +python3 -m venv ~/venvs/lingbot-va +source ~/venvs/lingbot-va/bin/activate +python -m pip install -U pip + +# 大 wheel 建议关缓存,避免 pip 报 Memoryview is too large +PIP_NO_CACHE_DIR=1 pip install -r requirements.txt +``` + +#### 2) 准备 arms 数据 + +```bash +python scripts/prepare_arms_dataset.py --arms-root ./arms --split train --out ./prepared_arms +``` + +#### 3) 下载 checkpoint(lingbot-va-base) + +```bash +pip install -U huggingface_hub +hf download --repo-type model robbyant/lingbot-va-base --local-dir /root/checkpoints/lingbot-va-base +``` + +#### 4) 提取 VAE latents + +```bash +python scripts/extract_arms_latents.py \ + --dataset-root ./prepared_arms \ + --ckpt-dir /root/checkpoints/lingbot-va-base \ + --device cuda \ + --dtype bfloat16 \ + --height 256 --width 256 +``` + +#### 5) 单卡微调训练(post-training) + +```bash +export TORCHDYNAMO_DISABLE=1 # 更稳(先跑通) +python -m wan_va.train --config-name arms_train --save-root ./train_out_arms +``` + +建议用 `tmux` 运行,防止断网中断: + +```bash +tmux new -s arms_train +# 运行训练命令后,Ctrl+b 再按 d 退出但继续跑 +``` + +### WandB(可选) + +`arms_train` 默认启用 wandb;若缺少 `WANDB_*` 会自动降级关闭。 +要开启请设置: + +```bash +export WANDB_BASE_URL="https://api.wandb.ai" +export WANDB_API_KEY="..." +export WANDB_TEAM_NAME="..." +export WANDB_PROJECT="va_arms" +``` + +### LIBERO(录制 mp4/png/npz) + +详细见 `ROCM_LIBERO_SETUP.md`(包含 EGL/osmesa、以及 client 输出落盘路径说明)。 + +### LIBERO post-training(LeRobot 格式 + latents) + +与 `arms_train` 不同,**官方管线走 LeRobot latent 数据集**(`MultiLatentLeRobotDataset`),配置名为 **`libero_train`**(`wan_va/configs/va_libero_train_cfg.py`,相机与分辨率见 `va_libero_cfg`:`128×128`、双相机 `agentview_rgb` + `eye_in_hand_rgb`)。 + +1. **安装**(ROCm 上 `lerobot` 建议 `--no-deps`,避免动 torch): + +```bash +pip install --no-deps -r requirements_posttrain.txt +pip install scipy wandb +``` + +2. **准备数据**:目录下要有标准 LeRobot 数据集(递归能找到 `meta/info.json`),`episodes` 里带 **`action_config`** 分段;并在**每个相机 key** 下准备好与 `episodes` 对齐的 latent: + +`latents/chunk-XXX//episode_XXXXXX_start_end.pth` + +(字段需与 `ArmsLatentDataset`/现有提取脚本一致:`latent`、`frame_ids`、`text_emb`、`latent_num_frames`、`latent_height`、`latent_width`。本仓库目前只自带 `scripts/extract_arms_latents.py`,LIBERO 双相机 + 128 分辨率需按 `va_libero_cfg.obs_cam_keys` 与 LeRobot 视频路径**自行对齐提取**或从上游/社区找现成 LeRobot+latents。) + +3. **`empty_emb.pt`**:放在你指定的 `empty_emb_path`(与 `va_libero_train_cfg` 中一致),形状与 `text_emb` 相同(可用任意一条 latent 里的 `text_emb` 做 `zeros_like` 生成)。 + +4. **改配置**:编辑 `va_libero_train_cfg.py` 里的 `dataset_path`、`empty_emb_path`、`wan22_pretrained_model_name_or_path`(与 arms 相同即可)。 + +5. 
**开训**: + +```bash +export TORCHDYNAMO_DISABLE=1 +python -m wan_va.train --config-name libero_train --save-root ./train_out_libero +``` + +--- + +## 今日问题与修复对照表(按实际发生顺序) + +这一节专门记录你今天在服务器上跑流程时遇到的报错、当时哪里跑错、以及我们最终改了哪些代码把它跑通。 + +### A. LIBERO / 录制落盘相关 + +- **`ModuleNotFoundError: No module named 'wan_va'`(跑 client)** + - **触发方式**:用 `python evaluation/libero/client.py` 直接跑文件,Python 没把仓库根目录当包路径。 + - **正确方式**:在仓库根目录用 `python -m evaluation.libero.client ...` 或临时 `PYTHONPATH=.`。 + +- **`pip install libero` 报 “inconsistent version”** + - **原因**:PyPI 的同名包元数据不一致(文件名 0.1.1 / metadata 0.1.0)。 + - **正确安装**:装 LIBERO 官方仓库源码(`pip install -e ~/LIBERO`),见 `ROCM_LIBERO_SETUP.md` 第 3 节。 + +- **`AttributeError: 'NoneType' object has no attribute 'eglQueryString'`(robosuite/mujoco)** + - **原因**:无头渲染 EGL 没配置好(系统 EGL/Mesa 依赖或环境变量缺失)。 + - **修复**:安装 EGL/Mesa 依赖 + `PyOpenGL-accelerate`,并在启动前 `export MUJOCO_GL=egl` / `export PYOPENGL_PLATFORM=egl`。 + - **兜底**:EGL 真不可用时走 `--mujoco-gl osmesa`(更慢但能跑)。 + +- **“远端写不出 mp4 / ffmpeg backend”** + - **修复**:`pip install "imageio[ffmpeg]"`;并保留 PNG 关键帧,可在本地 `ffmpeg` 合成 mp4。 + +- **“我想保存 mp4/png/npz(含 action + joint)”** + - **我们改的代码**:`evaluation/libero/client.py` + - **新增**:关键帧 PNG、轨迹 `.npz`(actions + joint/EEF/gripper + policy_chunks)、视频 mp4(失败回退 gif)。 + - **落盘位置**:`--out-dir` 指定目录下(默认 `outputs/libero/...`),详见 `ROCM_LIBERO_SETUP.md` 的 7.3。 + +### B. arms 数据集(双臂单相机)训练相关 + +- **为什么要先提 latents 再训练?** + - **原因**:训练输入不是原始 RGB,而是 Wan2.2 VAE 编码后的 latent(省显存/加速/与预训练对齐)。 + - **对应脚本**:`scripts/extract_arms_latents.py` + +- **`ModuleNotFoundError: No module named 'wan_va'`(提 latents 脚本)** + - **原因**:脚本直接运行时 import 路径不包含 repo root。 + - **修复**:在脚本中加入 `sys.path.insert(0, repo_root)`(已在 `scripts/extract_arms_latents.py` 里做)。 + +- **VAE temporal shape 报错(例如 conv3d kernel > input / T 不匹配)** + - **原因**:Wan VAE 的时间维对齐很敏感,chunk/步长/偶数长度都会影响。 + - **修复策略(已实现)**:优先走非 streaming `vae.encode(x)`;失败时自动重试(`::2` 下采样、裁掉/补齐一帧等)。 + +- **`ValueError: environment variable MASTER_ADDR expected, but not set`(单卡训练)** + - **原因**:训练入口无条件 init distributed。 + - **修复**:`wan_va/distributed/util.py`:`world_size<=1` 时跳过 `dist.init_process_group`。 + +- **`KeyError` / wandb 环境变量缺失导致启动失败** + - **修复**:`wan_va/train.py`:检测缺少 `WANDB_*` 时自动关闭 wandb(即使 config 里 True)。 + +- **`FileNotFoundError: ./prepared_arms/empty_emb.pt`** + - **修复**:`wan_va/dataset/arms_latent_dataset.py`:自动从已有 latent 文件推断 `text_emb` 形状并生成 `empty_emb.pt`。 + +- **`Cannot re-initialize CUDA in forked subprocess`(DataLoader worker)** + - **原因**:latent `.pth` 里可能存了 CUDA tensor 或加载时映射到 CUDA,worker fork 后触发 CUDA re-init。 + - **修复**: + - `scripts/extract_arms_latents.py`:保存前 `.cpu()`; + - `wan_va/dataset/arms_latent_dataset.py`:`torch.load(..., map_location="cpu")`; + - `wan_va/configs/va_arms_train_cfg.py`:默认 `load_worker=0`。 + +- **`flex_attention` 的 `block_mask` 长度不匹配** + - **修复**:`wan_va/modules/model.py`:对 mask 调用 `_adjust(q_len, kv_len)` 做裁剪对齐。 + ## ⚠️ Important: `attn_mode` Configuration diff --git a/ROCM_LIBERO_SETUP.md b/ROCM_LIBERO_SETUP.md new file mode 100644 index 0000000..d6772ea --- /dev/null +++ b/ROCM_LIBERO_SETUP.md @@ -0,0 +1,403 @@ +## MI300X / ROCm 7.2 跑通 LingBot-VA + LIBERO(Server/Client)指南 + +```bash +wget https://repo.radeon.com/amdgpu-install/7.2/ubuntu/jammy/amdgpu-install_7.2.70200-1_all.deb +sudo apt install ./amdgpu-install_7.2.70200-1_all.deb +sudo apt update +sudo apt install python3-setuptools python3-wheel +sudo usermod -a -G render,video $LOGNAME # Add the current user to the render and video groups +sudo apt install rocm +# +wget 
https://repo.radeon.com/amdgpu-install/7.2/ubuntu/jammy/amdgpu-install_7.2.70200-1_all.deb +sudo apt install ./amdgpu-install_7.2.70200-1_all.deb +sudo apt update +sudo apt install "linux-headers-$(uname -r)" "linux-modules-extra-$(uname -r)" +sudo apt install amdgpu-dkms +# +reboot +# +pip3 install --no-cache-dir --pre torch torchvision --index-url https://download.pytorch.org/whl/rocm7.2 +apt update && apt install -y git-lfs +git lfs install + + +``` +适用场景: +- **AMD Instinct MI300X**(ROCm 7.2) +- **远端 Ubuntu 22.04**(SSH 上去跑) +- **无显示器/无 GUI**,需要 **MuJoCo + robosuite 离屏渲染(EGL)** +- 目标是先把 `evaluation/libero` 的 **server/client 流程跑通** + +这份指南重点解决你遇到的两个坑: +- **谨慎安装 `lerobot`**:在 ROCm 环境中不建议让 `pip` 自动解析/升级它的依赖链,以免把 ROCm 的 torch 组合替换掉。推荐在装好 ROCm torch 后使用 `pip install --no-deps lerobot==0.3.3`,并把额外依赖(如 `scipy`、`wandb`)单独安装。 +- LIBERO 仿真依赖链较长,按下面“一次装齐最小集合”做。 +- 如果你选择装 `flash-attn`:**在 AMD/ROCm 上应走 Triton 后端**(aiter JIT),并避免“在错误目录执行 pip install .”把项目打成 `UNKNOWN` 包。 + +--- + +## 0. 强烈建议:用独立 venv(避免把系统 Python 搞崩) + +在远端服务器执行: + +```bash +python3 -m venv ~/venvs/lingbot-va +source ~/venvs/lingbot-va/bin/activate +python -m pip install -U pip +``` + +后面所有 `pip` / `python` 都在这个 venv 里执行(不再用系统的 `pip`)。 + +--- + +## 1) 安装 ROCm 版 PyTorch(按你的 ROCm 版本选) + +你机器显示 ROCm 7.2(`amd-smi`),但 PyTorch 官方 wheel 可能不提供 rocm7.2 的 index。 +如果你平台/镜像已经自带可用的 torch(ROCm),你可以跳过此步;否则按平台文档安装 ROCm PyTorch。 + +检查 torch 是否可用: + +```bash +python3 -c "import torch; print(torch.__version__); print('hip:', torch.version.hip); print('cuda available:', torch.cuda.is_available())" +``` + +> 注意:ROCm 下 `torch.cuda.is_available()` 也可能为 True(PyTorch 沿用了 cuda API 名称)。 + +--- + +## 2) 安装 LingBot-VA 运行所需 Python 包(不要装 lerobot) + +在 `~/lingbot-va` 目录,激活 venv 后执行: + +```bash +pip install websockets msgpack opencv-python "imageio[ffmpeg]" matplotlib ftfy easydict einops tqdm +pip install "diffusers==0.36.0" "transformers==4.55.2" accelerate +``` + +### 2.1 关于 flash-attn +在 AMD/ROCm 环境下,`flash-attn` 可能没有对应内核或会回退实现。**能跑通优先**,不必强行追求 flash-attn CUDA 内核。 + +如果你确实想在 MI300X / ROCm 7.2 上安装 `flash-attn`(用于 `"attn_mode": "flashattn"`),推荐按 FlashAttention 官方说明走 **Triton AMD 后端**(ROCm 6.0+,7.2 也适用): +- 需要依赖:`packaging psutil ninja` +- 需要在 **flash-attention 仓库目录**执行安装(不要在 `lingbot-va/` 根目录执行 `pip install .`,否则会出现打包成 `UNKNOWN` 的情况) + +示例(在 venv 内): + +```bash +pip install packaging psutil ninja +cd ~ +git clone https://github.com/Dao-AILab/flash-attention +cd flash-attention +git submodule update --init --recursive + +# 启用 Triton AMD 后端 +FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pip install --no-build-isolation . 
+ +# 可选:打开 autotune(首次运行会预热更久) +export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" +export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" +``` + +验证: + +```bash +python3 -c "import flash_attn; print('flash_attn version:', flash_attn.__version__)" +``` + +> 你可能会看到 `flash_attn_2_cuda not found, falling back to Triton implementation`,这在 AMD 上通常是正常的(代表走 Triton/aiter 路线)。 + +--- + +## 3) 安装 LIBERO(benchmark 源码版) + +不要 `pip install libero`(PyPI 上同名包不对,且版本元数据有问题)。 + +```bash +git clone https://github.com/Lifelong-Robot-Learning/LIBERO.git ~/LIBERO +pip install -e ~/LIBERO +python3 -c "from libero.libero import benchmark; print('libero import ok')" +``` + +首次安装会引导你生成 `~/.libero/config.yaml`(数据路径随便填一个能写的目录即可)。 + +--- + +## 3.1(可选)下载数据:HuggingFace / Google Drive / tgz 解压 + +### A) HuggingFace 下载 checkpoints(推荐 huggingface-cli) + +```bash +pip install -U huggingface_hub +mkdir -p ~/lingbot-va/checkpoints +cd ~/lingbot-va/checkpoints +huggingface-cli download --repo-type model robbyant/lingbot-va-base --local-dir lingbot-va-base +``` + +国内可选镜像: + +```bash +export HF_ENDPOINT=https://hf-mirror.com +``` + +### B) Google Drive 下载(gdown) + +```bash +pip install -U gdown +``` + +注意:带 `&` 的 URL 一定要加引号,否则 shell 会把它拆成后台任务,最终只下载到一个几 KB 的 HTML 页面。 + +```bash +gdown --fuzzy "https://drive.google.com/file/d/1QGNkvsb1hlRmRkKCgFlyWitv17sRuagS/view" +``` + +### C) `.tgz` 解压 + +```bash +mkdir -p ~/lingbot-va/data/libero +tar -xzvf libero_10.tgz -C ~/lingbot-va/data/libero +``` + +## 4) 安装 LIBERO 仿真依赖(robosuite + mujoco + bddl + 其他) + +### 4.1 系统 EGL/Mesa 依赖(无头渲染需要) + +```bash +sudo apt-get update +sudo apt-get install -y libegl1 libgles2 libgl1-mesa-dri libgl1-mesa-glx mesa-utils +``` + +### 4.2 Python 依赖 + +```bash +pip install "mujoco==3.1.5" "robosuite==1.4.1" +pip install PyOpenGL PyOpenGL-accelerate +pip install bddl cloudpickle gymnasium +pip install "imageio[ffmpeg]" opencv-python tqdm +pip install gym==0.26.2 +``` + +> 说明:LIBERO 的 `venv.py` 里写的是 `import gym`,所以需要装 `gym`(即使你也装了 `gymnasium`)。 + +--- + +## 5) 配置 MuJoCo 无头渲染(EGL) + +每次跑 client 前都建议先 export: + +```bash +export MUJOCO_GL=egl +export PYOPENGL_PLATFORM=egl +``` + +验证 robosuite+mujoco 是否能 import(会有一些 warning,能打印 ok 就行): + +```bash +python3 -c "import robosuite, mujoco; print('robosuite+mujoco ok')" +``` + +如果报 `eglQueryString` 类错误,通常是环境变量没生效或 EGL/Mesa 组件没装齐。 + +### 5.1 EGL 仍失败(`amdgpu_dri` / `PLATFORM_DEVICE` / `EGL_BAD_DISPLAY`) + +部分云主机 / 容器装了 ROCm 计算栈,但**没有**完整的 OpenGL/EGL 用户态驱动(日志里常见 `failed to open amdgpu_dri.so`、`Cannot initialize a EGL device display`)。此时 **EGL 离屏不可用**,可改用 **OSMesa 软件光栅**(较慢,但能跑通仿真与录屏): + +```bash +sudo apt-get update +sudo apt-get install -y libosmesa6 +export MUJOCO_GL=osmesa +unset PYOPENGL_PLATFORM # 若曾设为 egl,先取消,避免 PyOpenGL 仍走 EGL +``` + +或直接让 client 在 **导入 MuJoCo 之前** 设好后端(等价于上面 `export`): + +```bash +python3 -m evaluation.libero.client \ + --mujoco-gl osmesa \ + --libero-benchmark libero_10 \ + --port 29056 \ + --test-num 1 \ + --task-range 0 1 \ + --out-dir outputs/libero +``` + +> `--mujoco-gl` 必须在进程启动时生效;本仓库在 `client.py` 最前面解析该参数,无需改 LIBERO/robosuite 源码。 + +--- + +## 6) 配置 LingBot-VA checkpoints 路径 + 推理 attn_mode + +### 6.1 改 checkpoints 路径 +编辑 `wan_va/configs/va_libero_cfg.py`: + +- `wan22_pretrained_model_name_or_path = "./checkpoints/libero-va-base"` + +改成你的本地模型目录(建议用正斜杠),例如: + +`/root/lingbot-va/lingbot-va-base` + +该目录下必须包含: +`vae/ tokenizer/ text_encoder/ transformer/` + +### 6.2 改 attn_mode(推理必须) +编辑: +`<模型目录>/transformer/config.json` + +把 `"attn_mode"` 设为: +- `"torch"` 或 `"flashattn"` + +不要用 `"flex"`(训练用,推理会报错)。 + 
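+A quick way to apply this edit non-interactively (a minimal sketch; the path below reuses the example model directory from 6.1, substitute your own):
+
+```bash
+python3 - <<'PY'
+import json, pathlib
+# Example path; replace with your actual model directory.
+cfg_path = pathlib.Path("/root/lingbot-va/lingbot-va-base/transformer/config.json")
+cfg = json.loads(cfg_path.read_text())
+cfg["attn_mode"] = "torch"  # or "flashattn"; do not use "flex" for inference
+cfg_path.write_text(json.dumps(cfg, indent=2))
+print("attn_mode ->", cfg["attn_mode"])
+PY
+```
+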
+--- + +## 7) 启动 LIBERO Server / Client(两个终端/两个 tmux pane) + +### 7.1 终端 A:启动 server + +```bash +cd ~/lingbot-va +source ~/venvs/lingbot-va/bin/activate +bash evaluation/libero/launch_server.sh +``` + +看到类似: +`server listening on 0.0.0.0:29056` +说明启动成功。 + +### 7.2 终端 B:启动 client + +```bash +cd ~/lingbot-va +source ~/venvs/lingbot-va/bin/activate +export MUJOCO_GL=egl +export PYOPENGL_PLATFORM=egl + +# 先用最小任务量跑通流程(推荐) +python3 -m evaluation.libero.client \ + --libero-benchmark libero_10 \ + --port 29056 \ + --test-num 1 \ + --task-range 0 1 \ + --out-dir outputs/libero +``` + +跑通后再把 `--test-num`、`--task-range` 调大。 + +> 如果你的系统没有 `python` 命令(只有 `python3`),请改用 `python3 -m ...`,或安装 `python-is-python3`: +> +> ```bash +> sudo apt-get install -y python-is-python3 +> ``` + +输出视频在: +`outputs/libero/.../*.mp4` + +--- + +## 7.3 Client 输出落盘说明(MP4 / PNG / NPZ:action + joint) + +`evaluation/libero/client.py` 会按 episode 落盘到 `--out-dir` 指定的根目录(默认 `outputs/libero`,相对路径是相对于你启动命令时的当前目录)。 + +### 7.3.1 输出目录结构 + +每个 episode 的输出路径形如: + +`{out_dir}/{libero_benchmark}/{task_idx}_{prompt(空格->下划线)}/` + +在该目录下会生成(同一个 episode 前缀为 `{episode_idx}_{done}`,其中 done 为 True/False): + +- **视频**:`{episode_idx}_{done}.mp4` + - 若写 mp4 失败(通常是缺 FFMPEG 后端),会自动回退写成同名 `.gif` +- **关键帧 PNG**:`{episode_idx}_{done}_png/frame_000000.png` ...(与视频帧一致,左右相机横向拼接) +- **轨迹 NPZ**:`{episode_idx}_{done}.npz` + +另有每个 task 的成功率统计: + +`{out_dir}/{libero_benchmark}_{task_idx}.json`(仅 `succ_num/total_num/succ_rate`) + +### 7.3.2 `npz` 里保存了哪些量 + +`{episode}.npz` 中的主要键: + +- `actions`: 每个 env step 实际执行的动作向量,形状通常为 `(T, act_dim)` +- `robot0_joint_pos`, `robot0_joint_vel`, `robot0_gripper_qpos`, `robot0_eef_pos`, `robot0_eef_quat` + - 仅当 LIBERO/robosuite 的原始 `obs` 中存在对应键时才会写入 + - 行数会与 `actions` 的 \(T\) 对齐(逐步记录) +- `policy_chunks`: 策略每次 `model.infer(...)` 返回的整块 action(为了复盘 chunk 级输出),长度为策略前向次数 + - 读取时需要 `allow_pickle=True` + +本地快速检查(示例): + +```bash +python3 - <<'PY' +import numpy as np +p = "outputs/libero/<...>/.npz" +d = np.load(p, allow_pickle=True) +print("keys:", d.files) +for k in d.files: + v = d[k] + print(k, v.dtype, v.shape) +PY +``` + +### 7.3.3 远程写不出 MP4 的处理 + +若出现 “无法写入 MP4 / backend / ffmpeg” 相关报错,优先安装: + +```bash +pip install "imageio[ffmpeg]" +``` + +即使 MP4 临时写不了,也可以只用 PNG 在本地合成视频(在 `*_png/` 目录中执行): + +```bash +ffmpeg -y -framerate 60 -i frame_%06d.png -c:v libx264 -pix_fmt yuv420p out.mp4 +``` + +### 7.3.4 常见日志提示含义 + +你可能会看到: + +`[info] using task orders [0, 1, 2, ..., 9]` + +这表示 benchmark 内部的默认任务顺序(例如 `libero_10` 通常是 10 个任务 0~9)。实际运行哪些任务仍由 `--task-range start end` 决定(常见为左闭右开,即 `0 1` 只跑 task 0)。 + +--- + +## 8) 重要:谨慎安装 lerobot(避免把 ROCm 环境弄崩) + +本仓库的 `evaluation/libero/client.py` 我们已改为使用标准库 `json` 写结果文件, +因此 **跑 LIBERO 不再需要 `lerobot`**。 + +如果你已经装过 `lerobot` 并导致出现: +- `No module named 'flash_attn_2_cuda'` + +这通常意味着环境被换成了 CUDA/NVIDIA 组合。 +最稳的恢复方式是:**丢弃当前环境,重新建一个干净 venv**,按本指南重装最小依赖。 + +如果你确实需要 `lerobot`(用于 post-training 的某些工具链),建议: + +```bash +pip install --no-deps lerobot==0.3.3 +pip install scipy wandb +``` + +核心原则是:**不要让 pip 因为 lerobot 去自动升级/替换你的 torch/diffusers/triton 组合**。 +--- + +## 9) 今日常见报错速查(LIBERO/ROCm) + +- **`ModuleNotFoundError: No module named 'wan_va'`** + - **原因**:没用 `-m` 启动(直接跑文件导致包路径丢失)。 + - **做法**:在 repo 根目录运行 `python3 -m evaluation.libero.client ...`(见 7.2)。 + +- **`pip install libero` 报 “inconsistent version”** + - **原因**:PyPI 同名包元数据问题。 + - **做法**:安装 LIBERO 官方仓库源码(见 3)。 + +- **`AttributeError: 'NoneType' object has no attribute 'eglQueryString'`** + - **原因**:EGL/Mesa 或环境变量没配好。 + - **做法**:按 5) 配 EGL;不行就按 5.1 用 OSMesa。 + +- 
**写不出 `.mp4`** + - **做法**:按 7.3.3 安装 `imageio[ffmpeg]`;或者只用 PNG 在本地 `ffmpeg` 合成。 + diff --git a/evaluation/libero/client.py b/evaluation/libero/client.py new file mode 100644 index 0000000..f7a9c0f --- /dev/null +++ b/evaluation/libero/client.py @@ -0,0 +1,381 @@ +import os +import sys + + +def _apply_mujoco_gl_argv(): + """MuJoCo reads MUJOCO_GL at native init; must run before importing libero/mujoco.""" + i = 0 + while i < len(sys.argv): + a = sys.argv[i] + if a == "--mujoco-gl" and i + 1 < len(sys.argv): + os.environ["MUJOCO_GL"] = sys.argv[i + 1] + i += 2 + continue + if a.startswith("--mujoco-gl="): + os.environ["MUJOCO_GL"] = a.partition("=")[2] + i += 1 + + +_apply_mujoco_gl_argv() + +import numpy as np +import torch + +# PyTorch 2.6+ defaults torch.load(..., weights_only=True). LIBERO benchmark init *.pt +# files are trusted pickles (numpy + tensors); allow full unpickle for those loads. +_torch_load_orig = torch.load + + +def _torch_load_compat(*args, **kwargs): + if "weights_only" not in kwargs: + try: + return _torch_load_orig(*args, weights_only=False, **kwargs) + except TypeError: + return _torch_load_orig(*args, **kwargs) + return _torch_load_orig(*args, **kwargs) + + +torch.load = _torch_load_compat + +from wan_va.utils.Simple_Remote_Infer.deploy.websocket_client_policy import WebsocketClientPolicy +import argparse +from libero.libero import benchmark +import time +from libero.libero.envs import OffScreenRenderEnv +from pathlib import Path +from tqdm import tqdm +import json +import imageio +import cv2 + + +def save_video(real_obs_list, save_path, fps=15, video_names=["observation.images.agentview_rgb", "observation.images.eye_in_hand_rgb"]): + if not real_obs_list: + print("❌ No real observation frames") + return + + first_obs = real_obs_list[0] + base_h, width_base = first_obs[video_names[0]].shape[:2] + target_size = (width_base, base_h) + + print(f"Saving video: {len(real_obs_list)} frames...") + + final_frames = [_stack_frame(obs, video_names, target_size) for obs in real_obs_list] + + path = Path(save_path) + try: + imageio.mimsave(str(path), final_frames, fps=fps) + print(f"✅ Video saved to: {path}") + except ValueError as e: + msg = str(e).lower() + if path.suffix.lower() == ".mp4" and ( + "backend" in msg or "ffmpeg" in msg or "ffm" in msg or "wI" in str(e) + ): + gif_path = path.with_suffix(".gif") + print( + "⚠️ 无法写入 MP4:缺少 imageio 的 FFMPEG 插件。请先执行:pip install 'imageio[ffmpeg]'" + f"\n 本次已改用 GIF:{gif_path}" + ) + imageio.mimsave(str(gif_path), final_frames, fps=fps) + print(f"✅ Video saved to: {gif_path}") + else: + raise + + +def construct_single_env(env_args): + last_exc = None + for _ in range(5): + try: + return OffScreenRenderEnv(**env_args) + except Exception as e: + last_exc = e + print(f"Error!!! construct env failed: {e}") + time.sleep(5) + raise RuntimeError( + "OffScreenRenderEnv failed after 5 attempts. Headless EGL often needs a working GPU GL stack " + "(e.g. amdgpu_dri / EGL PLATFORM_DEVICE). On minimal cloud images, use CPU software rendering: " + "`sudo apt install -y libosmesa6` then `export MUJOCO_GL=osmesa` or " + "`python3 -m evaluation.libero.client --mujoco-gl osmesa ...`. " + "See ROCM_LIBERO_SETUP.md §5.1." + ) from last_exc + + +def _extract_obs(obs): + """ + Extract agentview and eye_in_hand images from raw env obs dict. + + Avoids torch round-trip: the env already returns uint8 numpy arrays [H, W, C]. + We just flip the vertical axis ([::-1]) and make a contiguous copy once. 
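+    Returns a dict keyed by "observation.images.agentview_rgb" and "observation.images.eye_in_hand_rgb" (uint8 [H, W, 3]).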
+ """ + agentview = np.ascontiguousarray(obs["agentview_image"][::-1]) + eye_in_hand = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1]) + return {"observation.images.agentview_rgb": agentview, "observation.images.eye_in_hand_rgb": eye_in_hand} + + +# Robosuite / LIBERO 常见本体观测键(存在则写入 trajectory npz) +_PROPRIO_KEYS = ( + "robot0_joint_pos", + "robot0_joint_vel", + "robot0_gripper_qpos", + "robot0_eef_pos", + "robot0_eef_quat", +) + + +def _to_numpy(x): + if hasattr(x, "detach"): + return x.detach().cpu().numpy() + return np.asarray(x) + + +def _proprio_from_raw(raw_obs): + """从 env 原始 obs 中取出关节 / 末端等向量,便于落盘。""" + out = {} + if raw_obs is None: + return out + for k in _PROPRIO_KEYS: + if k in raw_obs: + out[k] = np.asarray(raw_obs[k], dtype=np.float32).reshape(-1).copy() + return out + + +def init_single_env(env_in, init_state): + env_in.reset() + env_in.set_init_state(init_state) + for _ in range(5): + obs, _, _, _ = env_in.step([0.] * 7) + return _extract_obs(obs) + + +def env_one_step(env_in, action): + obs, _, done, _ = env_in.step(action) + return _extract_obs(obs), done, obs + + +def _stack_frame(obs, video_names, target_size): + return np.hstack([cv2.resize(obs[name], target_size) for name in video_names]).astype(np.uint8) + + +def run_one( + model, + libero_benchmark, + task_idx, + out_dir, + episode_idx, +): + vnames = ["observation.images.agentview_rgb", "observation.images.eye_in_hand_rgb"] + benchmark_dict = benchmark.get_benchmark_dict() + benchmark_instance = benchmark_dict[libero_benchmark]() + num_tasks = benchmark_instance.get_num_tasks() + assert task_idx < num_tasks, f"Error: error id must smaller than {num_tasks}" + prompt = benchmark_instance.get_task(task_idx).language + env_args = { + "bddl_file_name": benchmark_instance.get_task_bddl_file_path(task_idx), + "camera_heights": 128, + "camera_widths": 128, + } + init_states = benchmark_instance.get_task_init_states(task_idx) + + cur_env = construct_single_env(env_args) + first_obs = init_single_env(cur_env, init_states[episode_idx % init_states.shape[0]]) + + ret = model.infer(dict(reset=True, prompt=prompt)) + + full_obs_list = [] + action_rows = [] + proprio_rows = {k: [] for k in _PROPRIO_KEYS} + policy_chunks = [] + done = False + first = True + while cur_env.env.timestep < 800: + ret = model.infer(dict(obs=first_obs, prompt=prompt)) + action = _to_numpy(ret["action"]) + policy_chunks.append(np.asarray(action, dtype=np.float32).copy()) + + key_frame_list = [] + assert action.shape[2] % 4 == 0 + action_per_frame = action.shape[2] // 4 + start_idx = 1 if first else 0 + for i in range(start_idx, action.shape[1]): + for j in range(action.shape[2]): + ee_action = _to_numpy(action[:, i, j]).reshape(-1) + action_rows.append(ee_action.astype(np.float32, copy=False)) + observes, done, raw_obs = env_one_step(cur_env, ee_action) + prop = _proprio_from_raw(raw_obs) + for k in proprio_rows: + if k in prop: + proprio_rows[k].append(prop[k]) + if done: + break + if (j + 1) % action_per_frame == 0: + full_obs_list.append(observes) + key_frame_list.append(observes) + + if done: + break + + first = False + + if done: + break + else: + model.infer(dict(obs=key_frame_list, compute_kv_cache=True, imagine=False, state=action)) + + out_file = Path(out_dir) / libero_benchmark / f"{task_idx}_{prompt.replace(' ', '_')}" / f"{episode_idx}_{done}.mp4" + out_file.parent.mkdir(exist_ok=True, parents=True) + episode_base = out_file.parent / out_file.stem + + if full_obs_list: + png_dir = Path(str(episode_base) + "_png") + 
png_dir.mkdir(parents=True, exist_ok=True) + fo = full_obs_list[0] + base_h, width_base = fo[vnames[0]].shape[:2] + target_size = (width_base, base_h) + for fi, obs in enumerate(full_obs_list): + frame = _stack_frame(obs, vnames, target_size) + cv2.imwrite( + str(png_dir / f"frame_{fi:06d}.png"), + cv2.cvtColor(frame, cv2.COLOR_RGB2BGR), + ) + print(f"✅ PNG 关键帧: {png_dir} ({len(full_obs_list)} 帧)") + + if action_rows: + n = len(action_rows) + for k, rows in proprio_rows.items(): + if rows and len(rows) != n: + raise RuntimeError( + f"proprio / action 步数不一致: {k} len={len(rows)}, actions len={n}" + ) + traj = {"actions": np.stack(action_rows, axis=0)} + for k, rows in proprio_rows.items(): + if rows: + traj[k] = np.stack(rows, axis=0) + if policy_chunks: + traj["policy_chunks"] = np.empty(len(policy_chunks), dtype=object) + for i, c in enumerate(policy_chunks): + traj["policy_chunks"][i] = c + npz_path = episode_base.with_suffix(".npz") + np.savez_compressed(npz_path, **traj) + msg = f"actions {traj['actions'].shape}" + if len(traj) > 1: + msg += " +" + ",".join(f" {k}{traj[k].shape}" for k in traj if k != "actions") + print(f"✅ 轨迹 npz: {npz_path} ({msg})") + + save_video( + real_obs_list=full_obs_list, + save_path=out_file, + fps=60, + video_names=vnames, + ) + + cur_env.close() + return done + + +def run( + libero_benchmark, + port, + out_dir, + test_num, + task_range=None, +): + ''' + task_range: [start, end) for splitting tasks + ''' + if task_range is None: + benchmark_dict = benchmark.get_benchmark_dict() + benchmark_instance = benchmark_dict[libero_benchmark]() + num_tasks = benchmark_instance.get_num_tasks() + progress_bar = tqdm(range(num_tasks), total=num_tasks) + else: + assert len(task_range) == 2, f'task_range: [start, end) for splitting tasks, however, task_range: {task_range}' + num_tasks = task_range[1] - task_range[0] + progress_bar = tqdm(range(task_range[0], task_range[1]), total=num_tasks) + + print(f"#################### Use benchmark: {libero_benchmark}, num_tasks: {num_tasks} #############") + model = WebsocketClientPolicy(port=port) + + video_save_root_dict = None + + episode_list = range(test_num) + for task_idx in progress_bar: + if video_save_root_dict is not None and task_idx in video_save_root_dict: + video_save_list = os.listdir(os.path.join(out_dir, libero_benchmark, video_save_root_dict[task_idx])) + video_states = [1 for file in video_save_list if file.split('_')[1].split('.')[0] == 'True'] + succ_num = float(len(video_states)) + episode_list = range(len(video_save_list), test_num) + else: + succ_num = 0. 
+ + for episode_idx in tqdm(episode_list, total=len(episode_list)): + res_i = run_one(model, libero_benchmark, task_idx, out_dir, episode_idx) + succ_num += res_i + succ_rate = succ_num / (episode_idx + 1) + print(f"Success rate: {succ_rate}, success num: {succ_num}, total num: {episode_idx + 1}") + out_file = Path(out_dir) / f"{libero_benchmark}_{task_idx}.json" + out_file.parent.mkdir(exist_ok=True, parents=True) + with open(out_file, "w", encoding="utf-8") as f: + json.dump( + { + "succ_num": succ_num, + "total_num": float(episode_idx + 1), + "succ_rate": succ_rate, + }, + f, + ensure_ascii=False, + indent=2, + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--libero-benchmark", + type=str, + default="libero_10", + choices=["libero_10", "libero_goal", "libero_spatial", "libero_object"], + help="Benchmark name", + ) + parser.add_argument( + "--task-range", + type=int, + nargs="+", + default=[0, 10], + help="Task range [start, end) for splitting tasks", + ) + parser.add_argument( + "--port", + type=int, + default=23908, + help="WebSocket port", + ) + parser.add_argument( + "--test-num", + type=int, + default=50, + help="Number of test episodes", + ) + parser.add_argument( + "--out-dir", + type=str, + default="outputs/libero", + help="Output directory for results", + ) + parser.add_argument( + "--mujoco-gl", + type=str, + default=None, + choices=["egl", "glfw", "osmesa"], + help="Set MUJOCO_GL before MuJoCo loads (must match early argv parse; use when EGL is unavailable).", + ) + args = parser.parse_args() + if args.mujoco_gl is not None: + os.environ["MUJOCO_GL"] = args.mujoco_gl + kw = vars(args) + kw.pop("mujoco_gl", None) + run(**kw) + print("Finish all process!!!!!!!!!!!!") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/evaluation/libero/launch_client.sh b/evaluation/libero/launch_client.sh new file mode 100644 index 0000000..81d17f1 --- /dev/null +++ b/evaluation/libero/launch_client.sh @@ -0,0 +1,12 @@ +START=0 +END=10 + +# 若 EGL 报错,可先: sudo apt install -y libosmesa6 ,并取消下行注释或 export MUJOCO_GL=osmesa +# MUJOCO_GL=osmesa + +python3 -m evaluation.libero.client \ + --libero-benchmark libero_10 \ + --port 29056 \ + --test-num 50 \ + --task-range $START $END \ + --out-dir outputs/libero \ No newline at end of file diff --git a/evaluation/libero/launch_server.sh b/evaluation/libero/launch_server.sh new file mode 100644 index 0000000..ae5a86d --- /dev/null +++ b/evaluation/libero/launch_server.sh @@ -0,0 +1,11 @@ + +save_root='visualization/' +mkdir -p $save_root + +python3 -m torch.distributed.run \ + --nproc_per_node 1 \ + --master_port 29061 \ + wan_va/wan_va_server.py \ + --config-name libero \ + --port 29056 \ + --save_root $save_root \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 860b14a..cfdd822 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,22 @@ -torch==2.9.0 -torchvision==0.24.0 -torchaudio==2.9.0 +# ROCm (AMD) install tip: +# Use ROCm wheels first, then fall back to PyPI for the rest. +# If your ROCm version differs (e.g. rocm6.3/7.0), change the URL accordingly. +--index-url https://download.pytorch.org/whl/rocm7.2 +--extra-index-url https://pypi.org/simple + +torch +torchvision +# torchaudio wheels are not always available for ROCm; install separately if needed. 
+# torchaudio diffusers==0.36.0 transformers==4.55.2 +# transformers==4.55.2 expects huggingface-hub < 1.0 +huggingface-hub>=0.34.0,<1.0 accelerate einops easydict -flash_attn +# flash_attn is optional; on ROCm it typically falls back to Triton or may fail to build. +# flash_attn numpy==1.26.4 tqdm imageio[ffmpeg] @@ -18,7 +28,16 @@ ftfy safetensors Pillow +# LIBERO / MuJoCo (headless simulation) +mujoco==3.1.5 +robosuite==1.4.1 +PyOpenGL +PyOpenGL-accelerate +bddl +cloudpickle +gym==0.26.2 +gymnasium + # Post-training -lerobot==0.3.3 scipy wandb diff --git a/requirements_posttrain.txt b/requirements_posttrain.txt new file mode 100644 index 0000000..1d814fe --- /dev/null +++ b/requirements_posttrain.txt @@ -0,0 +1,15 @@ +# Post-training extras (safe on ROCm) +# +# IMPORTANT (ROCm/AMD): +# - Install ROCm PyTorch first (per requirements.txt). +# - Then install lerobot WITHOUT deps to avoid pip swapping your torch build. +# +# Recommended: +# pip install -r requirements.txt +# pip install --no-deps -r requirements_posttrain.txt +# +# NOTE: requirements files do NOT support pip CLI flags like --no-deps inline. +lerobot==0.3.3 +scipy +wandb + diff --git a/scripts/extract_arms_latents.py b/scripts/extract_arms_latents.py new file mode 100644 index 0000000..3233733 --- /dev/null +++ b/scripts/extract_arms_latents.py @@ -0,0 +1,228 @@ +# Extract Wan2.2 VAE latents for prepared_arms dataset. +# +# Input: +# prepared_arms/meta/episodes.jsonl +# prepared_arms/videos/chunk-000/observation.images.cam_high/episode_000000.mp4 +# +# Output: +# prepared_arms/latents/chunk-000/observation.images.cam_high/episode_000000_0_T.pth +# +# Each .pth is a dict matching repo README (latent, latent_num_frames, frame_ids, text_emb, ...). +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + +import cv2 +import numpy as np +import torch +from einops import rearrange + +# Allow running as a script from repo root or elsewhere. +_REPO_ROOT = Path(__file__).resolve().parents[1] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from wan_va.modules.utils import WanVAEStreamingWrapper, load_text_encoder, load_tokenizer, load_vae, patchify + + +def _read_video_rgb(path: Path) -> tuple[np.ndarray, float]: + cap = cv2.VideoCapture(str(path)) + if not cap.isOpened(): + raise RuntimeError(f"Failed to open video: {path}") + fps = cap.get(cv2.CAP_PROP_FPS) or 0.0 + frames = [] + while True: + ok, frame_bgr = cap.read() + if not ok: + break + frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB) + frames.append(frame_rgb) + cap.release() + if not frames: + raise RuntimeError(f"No frames read from: {path}") + return np.stack(frames, axis=0), float(fps) + + +@torch.no_grad() +def _encode_video_to_latent( + vae, + streaming_vae: WanVAEStreamingWrapper, + frames_rgb: np.ndarray, + *, + height: int, + width: int, + device: torch.device, + dtype: torch.dtype, + chunk_size: int, + streaming: bool, +) -> torch.Tensor: + """ + Returns normalized mean latents: [1, C, F, H', W'] (H'/W' ~ height//16,width//16). 
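+    Also returns the temporal stride actually used: 1 normally, or 2 when the fallback path subsampled frames.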
+ """ + # frames_rgb: [F, H, W, 3] uint8 + x = torch.from_numpy(frames_rgb).float().permute(3, 0, 1, 2) # 3,F,H,W + x = torch.nn.functional.interpolate(x, size=(height, width), mode="bilinear", align_corners=False) + x = (x / 255.0) * 2.0 - 1.0 + x = x.unsqueeze(0).to(device=device, dtype=dtype) # 1,3,F,H,W + + F = x.shape[2] + def _try_encode(x_in: torch.Tensor) -> torch.Tensor: + streaming_vae.clear_cache() + # Prefer native VAE encode path if available; it handles temporal shapes internally. + if hasattr(vae, "encode"): + posterior = vae.encode(x_in) + if hasattr(posterior, "latent_dist") and hasattr(posterior.latent_dist, "mean"): + mu = posterior.latent_dist.mean + # Match streaming_vae.encode_chunk output convention: quant_conv output (2C) split later. + # Here we return a fake "enc_out" by concatenating mu/logvar-like tensors. + # logvar is not used downstream, so zeros is fine. + zeros = torch.zeros_like(mu) + return torch.cat([mu, zeros], dim=1) + return streaming_vae.encode_chunk(x_in) + + # Default: encode the full clip in one call. + # If WAN VAE hits temporal shape mismatch, retry with safer temporal layouts. + used_stride = 1 + used_frame_ids = None # filled by caller + try: + enc_out = _try_encode(x) + except RuntimeError as e: + msg = str(e) + if "must match the size of tensor" not in msg or "at non-singleton dimension 2" not in msg: + raise + + # Retry set (in order): keep stride=2, then make temporal length even (drop/pad). + used_stride = 2 + x2 = x[:, :, ::2].contiguous() + tries = [x2] + + # Make even length by dropping last if needed. + if x2.shape[2] % 2 == 1 and x2.shape[2] > 1: + tries.append(x2[:, :, :-1].contiguous()) + + # Or pad one frame to make even length. + if x2.shape[2] % 2 == 1: + tries.append(torch.cat([x2, x2[:, :, -1:, :, :]], dim=2)) + + last_err = e + for t in tries: + try: + enc_out = _try_encode(t) + x = t + break + except RuntimeError as e2: + last_err = e2 + else: + raise last_err + + mu, _logvar = torch.chunk(enc_out, 2, dim=1) + latents_mean = torch.tensor(vae.config.latents_mean, device=mu.device, dtype=mu.dtype).view(1, -1, 1, 1, 1) + latents_std = torch.tensor(vae.config.latents_std, device=mu.device, dtype=mu.dtype).view(1, -1, 1, 1, 1) + mu_norm = ((mu.float() - latents_mean.float()) * (1.0 / latents_std.float())).to(mu.dtype) + # Crop back to original frame count (streaming may alter length). + return mu_norm[:, :, : x.shape[2]], used_stride + + +@torch.no_grad() +def _encode_text(text_encoder, tokenizer, text: str, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + tokens = tokenizer([text], padding="max_length", truncation=True, max_length=256, return_tensors="pt") + tokens = {k: v.to(device) for k, v in tokens.items()} + out = text_encoder(**tokens).last_hidden_state # [1, L, D] + return out.to(dtype) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--dataset-root", type=str, default="./prepared_arms", help="prepared_arms root") + ap.add_argument("--ckpt-dir", type=str, required=True, help="Wan2.2 checkpoint dir containing vae/, tokenizer/, text_encoder/") + ap.add_argument("--device", type=str, default="cuda", help="torch device (e.g. 
cuda)") + ap.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16", "float32"]) + ap.add_argument("--height", type=int, default=256) + ap.add_argument("--width", type=int, default=256) + ap.add_argument("--chunk-size", type=int, default=2, help="frames per streaming VAE chunk (AutoencoderKLWan streaming is stable with 2)") + ap.add_argument("--streaming", action="store_true", help="Use streaming VAE encode (only needed for very long videos).") + args = ap.parse_args() + + dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[args.dtype] + device = torch.device(args.device) + + root = Path(args.dataset_root) + meta_path = root / "meta" / "episodes.jsonl" + video_dir = root / "videos" / "chunk-000" / "observation.images.cam_high" + out_dir = root / "latents" / "chunk-000" / "observation.images.cam_high" + out_dir.mkdir(parents=True, exist_ok=True) + + ckpt = Path(args.ckpt_dir) + vae = load_vae(str(ckpt / "vae"), torch_dtype=dtype, torch_device=device) + streaming_vae = WanVAEStreamingWrapper(vae) + tokenizer = load_tokenizer(str(ckpt / "tokenizer")) + text_encoder = load_text_encoder(str(ckpt / "text_encoder"), torch_dtype=dtype, torch_device=device) + + lines = meta_path.read_text(encoding="utf-8").splitlines() + for line in lines: + if not line.strip(): + continue + d = json.loads(line) + ep = int(d["episode_index"]) + instruction = d["tasks"][0] if d.get("tasks") else "" + length = int(d["length"]) + start_frame = 0 + end_frame = length + + video_path = video_dir / f"episode_{ep:06d}.mp4" + frames_rgb, ori_fps = _read_video_rgb(video_path) + if frames_rgb.shape[0] != length: + # allow minor mismatch, but keep end_frame consistent with actual frames + end_frame = min(end_frame, frames_rgb.shape[0]) + frames_rgb = frames_rgb[:end_frame] + + lat, used_stride = _encode_video_to_latent( + vae, + streaming_vae, + frames_rgb, + height=args.height, + width=args.width, + device=device, + dtype=dtype, + chunk_size=args.chunk_size, + streaming=args.streaming, + ) # [1,C,F,h,w] + + # Flatten to [N,C] as repo README expects. + lat_fhwc = lat[0].permute(1, 2, 3, 0).contiguous() # F,h,w,C + latent_num_frames, latent_height, latent_width = lat_fhwc.shape[:3] + latent_flat = rearrange(lat_fhwc, "f h w c -> (f h w) c").to(torch.bfloat16).cpu() + + text_emb = _encode_text(text_encoder, tokenizer, instruction, device=device, dtype=torch.bfloat16)[0].cpu() + # When we fallback to 2x downsample in VAE encode, frame_ids should still reflect the sampled frames. + # For simplicity we always use every frame here; if fallback triggers, the latent's frame_ids will be subsampled. 
+ frame_ids = list(range(start_frame, end_frame, used_stride))[: int(lat.shape[2])] + + out = { + "latent": latent_flat, + "latent_num_frames": int(latent_num_frames), + "latent_height": int(latent_height), + "latent_width": int(latent_width), + "video_num_frames": int(frames_rgb.shape[0]), + "video_height": int(frames_rgb.shape[1]), + "video_width": int(frames_rgb.shape[2]), + "text_emb": text_emb, + "text": instruction, + "frame_ids": frame_ids, + "start_frame": int(start_frame), + "end_frame": int(end_frame), + "fps": int(round(ori_fps)) if ori_fps > 0 else 30, + "ori_fps": float(ori_fps), + } + + out_path = out_dir / f"episode_{ep:06d}_{start_frame}_{end_frame}.pth" + torch.save(out, out_path) + print(f"✅ wrote {out_path}") + + +if __name__ == "__main__": + main() + diff --git a/scripts/prepare_arms_dataset.py b/scripts/prepare_arms_dataset.py new file mode 100644 index 0000000..274e60c --- /dev/null +++ b/scripts/prepare_arms_dataset.py @@ -0,0 +1,142 @@ +# Prepare repo-local ./arms (dual-arm single-cam) into a training-ready folder. +# +# Output structure (dataset_root): +# meta/episodes.jsonl +# videos/chunk-000/observation.images.cam_high/episode_000000.mp4 +# actions/episode_000000.npy # float32 [T,30] +# norm_stat.json # q01/q99 over 30 dims +# +# NOTE: You still need to extract Wan2.2 VAE latents into: +# latents/chunk-000/observation.images.cam_high/episode_000000_0_T.pth +# +from __future__ import annotations + +import argparse +import csv +import json +import os +from pathlib import Path + +import numpy as np + + +LEFT_JOINT_COLS = [f"idx{13+i}_left_arm_joint{i+1}_position" for i in range(7)] +RIGHT_JOINT_COLS = [f"idx{20+i}_right_arm_joint{i+1}_position" for i in range(7)] + + +def _read_csv_rows(path: Path) -> tuple[list[str], np.ndarray]: + with open(path, "r", encoding="utf-8") as f: + reader = csv.reader(f) + header = next(reader) + header = [h.strip() for h in header] + rows = [] + for r in reader: + if not r: + continue + rows.append([float(x) for x in r]) + arr = np.asarray(rows, dtype=np.float32) + return header, arr + + +def _map_release_action_to_30(header: list[str], arr: np.ndarray) -> np.ndarray: + """ + Build 30-dim action following repo README standard: + [left_eef(7), right_eef(7), left_joints(7), right_joints(7), left_gripper(1), right_gripper(1)] + + Arms data provides joint positions for both arms (and many finger joints). + We fill joint channels and set eef/grippers to 0 by default. 
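+    Column layout: left_eef 0-6, right_eef 7-13, left_joints 14-20, right_joints 21-27, left_gripper 28, right_gripper 29.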
+ """ + col_to_idx = {c: i for i, c in enumerate(header)} + + out = np.zeros((arr.shape[0], 30), dtype=np.float32) + + # left joints -> dims 14..20 + for k, col in enumerate(LEFT_JOINT_COLS): + if col in col_to_idx: + out[:, 14 + k] = arr[:, col_to_idx[col]] + else: + raise KeyError(f"Missing column in action.txt: {col}") + + # right joints -> dims 21..27 + for k, col in enumerate(RIGHT_JOINT_COLS): + if col in col_to_idx: + out[:, 21 + k] = arr[:, col_to_idx[col]] + else: + raise KeyError(f"Missing column in action.txt: {col}") + + return out + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--arms-root", type=str, default="./arms", help="Path to arms/ directory") + ap.add_argument("--split", type=str, default="train", choices=["train", "test"], help="Which split to prepare") + ap.add_argument("--out", type=str, default="./prepared_arms", help="Output dataset root") + args = ap.parse_args() + + arms_root = Path(args.arms_root) + split_root = arms_root / args.split + out_root = Path(args.out) + + episodes = sorted([p for p in split_root.iterdir() if p.is_dir()]) + assert episodes, f"No episode folders under: {split_root}" + + (out_root / "meta").mkdir(parents=True, exist_ok=True) + video_dir = out_root / "videos" / "chunk-000" / "observation.images.cam_high" + video_dir.mkdir(parents=True, exist_ok=True) + actions_dir = out_root / "actions" + actions_dir.mkdir(parents=True, exist_ok=True) + instructions_dir = out_root / "instructions" + instructions_dir.mkdir(parents=True, exist_ok=True) + + ep_jsonl = out_root / "meta" / "episodes.jsonl" + all_actions = [] + + with open(ep_jsonl, "w", encoding="utf-8") as f: + for ep_idx, ep in enumerate(episodes): + action_txt = ep / "action.txt" + instr_txt = ep / "instruction.txt" + video_mp4 = ep / "video.mp4" + + header, arr = _read_csv_rows(action_txt) + actions30 = _map_release_action_to_30(header, arr) + np.save(actions_dir / f"episode_{ep_idx:06d}.npy", actions30) + all_actions.append(actions30) + + # copy / link video + dst_video = video_dir / f"episode_{ep_idx:06d}.mp4" + if not dst_video.exists(): + # prefer hardlink when possible + try: + os.link(video_mp4, dst_video) + except OSError: + import shutil + shutil.copy2(video_mp4, dst_video) + + instruction = instr_txt.read_text(encoding="utf-8").strip() + (instructions_dir / f"episode_{ep_idx:06d}.txt").write_text(instruction + "\n", encoding="utf-8") + length = int(actions30.shape[0]) + + line = { + "episode_index": ep_idx, + "tasks": [instruction], + "length": length, + "action_config": [ + {"start_frame": 0, "end_frame": length, "action_text": instruction} + ], + } + f.write(json.dumps(line, ensure_ascii=False) + "\n") + + all_actions = np.concatenate(all_actions, axis=0) + q01 = np.quantile(all_actions, 0.01, axis=0).astype(float).tolist() + q99 = np.quantile(all_actions, 0.99, axis=0).astype(float).tolist() + with open(out_root / "norm_stat.json", "w", encoding="utf-8") as f: + json.dump({"q01": q01, "q99": q99}, f, ensure_ascii=False, indent=2) + + print(f"Prepared {len(episodes)} episodes to: {out_root}") + print("Next: extract Wan2.2 VAE latents into out_root/latents/ mirroring videos/.") + + +if __name__ == "__main__": + main() + diff --git a/wan_va/configs/__init__.py b/wan_va/configs/__init__.py index 87964ae..eff0b78 100644 --- a/wan_va/configs/__init__.py +++ b/wan_va/configs/__init__.py @@ -4,6 +4,8 @@ from .va_franka_i2va import va_franka_i2va_cfg from .va_robotwin_i2va import va_robotwin_i2va_cfg from .va_robotwin_train_cfg import va_robotwin_train_cfg +from 
.va_arms_train_cfg import va_arms_train_cfg +from .va_libero_train_cfg import va_libero_train_cfg from .va_demo_train_cfg import va_demo_train_cfg from .va_demo_cfg import va_demo_cfg from .va_demo_i2va import va_demo_i2va_cfg @@ -14,6 +16,8 @@ 'robotwin_i2av': va_robotwin_i2va_cfg, 'franka_i2av': va_franka_i2va_cfg, 'robotwin_train': va_robotwin_train_cfg, + 'arms_train': va_arms_train_cfg, + 'libero_train': va_libero_train_cfg, 'demo': va_demo_cfg, 'demo_train': va_demo_train_cfg, 'demo_i2av': va_demo_i2va_cfg, diff --git a/wan_va/configs/va_arms_cfg.py b/wan_va/configs/va_arms_cfg.py new file mode 100644 index 0000000..5cc8e39 --- /dev/null +++ b/wan_va/configs/va_arms_cfg.py @@ -0,0 +1,50 @@ +# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved. +from easydict import EasyDict + +from .shared_config import va_shared_cfg + +va_arms_cfg = EasyDict(__name__="Config: VA arms (dual-arm, single-cam)") +va_arms_cfg.update(va_shared_cfg) + +# dataset_format is used by wan_va/train.py to select dataset implementation. +va_arms_cfg.dataset_format = "arms" + +# Single front-facing camera in ./arms videos. +va_arms_cfg.env_type = "arms" +va_arms_cfg.obs_cam_keys = ["observation.images.cam_high"] + +# Video size used for VAE latent extraction / training. +va_arms_cfg.height = 256 +va_arms_cfg.width = 256 + +# Transformer temporal settings (match your data / sampling later if needed). +va_arms_cfg.attn_window = 72 +va_arms_cfg.frame_chunk_size = 4 + +# Training-time clip length control. +# ArmsLatentDataset will, by default, randomly crop a fixed number of *latent frames* +# from each episode to keep sequence length bounded (helps avoid OOM on long episodes). +# Set to None / 0 to disable cropping and use full episodes. +va_arms_cfg.train_latent_frames = 24 + +# FlowMatch schedulers +va_arms_cfg.snr_shift = 5.0 +va_arms_cfg.action_snr_shift = 1.0 + +# Action format follows repo README "30 dims" standard. For release we mostly fill joint channels. +va_arms_cfg.action_dim = 30 +va_arms_cfg.action_per_frame = 4 + +# Use dual-arm channels (16 total) like robotwin, but semantics come from your mapping. +# We keep the same idea: 7 + 1 + 7 + 1 = 16 channels selected from 30-dim action. +# Here we use JOINT channels: left joints [14:21), right joints [21:28), and grippers [28,29]. +va_arms_cfg.used_action_channel_ids = list(range(14, 21)) + [28] + list(range(21, 28)) + [29] +inverse_used_action_channel_ids = [len(va_arms_cfg.used_action_channel_ids)] * va_arms_cfg.action_dim +for i, j in enumerate(va_arms_cfg.used_action_channel_ids): + inverse_used_action_channel_ids[j] = i +va_arms_cfg.inverse_used_action_channel_ids = inverse_used_action_channel_ids + +# Placeholder stats; ArmsLatentDataset will prefer /norm_stat.json if present. +va_arms_cfg.action_norm_method = "quantiles" +va_arms_cfg.norm_stat = {"q01": [0.0] * 30, "q99": [1.0] * 30} + diff --git a/wan_va/configs/va_arms_train_cfg.py b/wan_va/configs/va_arms_train_cfg.py new file mode 100644 index 0000000..570f365 --- /dev/null +++ b/wan_va/configs/va_arms_train_cfg.py @@ -0,0 +1,34 @@ +# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved. 
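+# Training config for the prepared_arms dataset; builds on va_arms_cfg.py and the output of scripts/prepare_arms_dataset.py.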
+import os +from easydict import EasyDict + +from .va_arms_cfg import va_arms_cfg + +va_arms_train_cfg = EasyDict(__name__="Config: VA arms train") +va_arms_train_cfg.update(va_arms_cfg) + +va_arms_train_cfg.dataset_path = "./prepared_arms" +va_arms_train_cfg.empty_emb_path = os.path.join(va_arms_train_cfg.dataset_path, "empty_emb.pt") + +# Pretrained checkpoint root containing transformer/vae/tokenizer/text_encoder. +# Set this to your downloaded lingbot-va-base path on the training machine. +va_arms_train_cfg.wan22_pretrained_model_name_or_path = "/root/checkpoints/lingbot-va-base" + +# Enable by default; training will auto-disable if WANDB_* env vars are missing. +va_arms_train_cfg.enable_wandb = True +# Single-GPU default: avoid multiprocessing CUDA issues when loading .pth files. +va_arms_train_cfg.load_worker = 0 +va_arms_train_cfg.save_interval = 200 +va_arms_train_cfg.gc_interval = 50 +va_arms_train_cfg.cfg_prob = 0.1 + +# Training parameters (start conservative) +va_arms_train_cfg.learning_rate = 1e-5 +va_arms_train_cfg.beta1 = 0.9 +va_arms_train_cfg.beta2 = 0.95 +va_arms_train_cfg.weight_decay = 1e-1 +va_arms_train_cfg.warmup_steps = 10 +va_arms_train_cfg.batch_size = 1 +va_arms_train_cfg.gradient_accumulation_steps = 10 +va_arms_train_cfg.num_steps = 5000 + diff --git a/wan_va/configs/va_libero_cfg.py b/wan_va/configs/va_libero_cfg.py new file mode 100644 index 0000000..aa868e5 --- /dev/null +++ b/wan_va/configs/va_libero_cfg.py @@ -0,0 +1,58 @@ +# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved. +from easydict import EasyDict + +from .shared_config import va_shared_cfg + +va_libero_cfg = EasyDict(__name__='Config: VA libero') +va_libero_cfg.update(va_shared_cfg) +va_shared_cfg.infer_mode = 'server' + +va_libero_cfg.wan22_pretrained_model_name_or_path = "/path/to/pretrained/model" + +va_libero_cfg.attn_window = 30 +va_libero_cfg.frame_chunk_size = 4 +va_libero_cfg.env_type = 'none' + +va_libero_cfg.height = 128 +va_libero_cfg.width = 128 +va_libero_cfg.action_dim = 30 +va_libero_cfg.action_per_frame = 4 +va_libero_cfg.obs_cam_keys = [ + 'observation.images.agentview_rgb', 'observation.images.eye_in_hand_rgb' +] +va_libero_cfg.guidance_scale = 5 +va_libero_cfg.action_guidance_scale = 1 + +va_libero_cfg.num_inference_steps = 20 +va_libero_cfg.video_exec_step = -1 +va_libero_cfg.action_num_inference_steps = 50 + +va_libero_cfg.snr_shift = 5.0 +va_libero_cfg.action_snr_shift = 1.0 + +va_libero_cfg.used_action_channel_ids = list(range(0, 6)) + list(range(28, 29)) +inverse_used_action_channel_ids = [len(va_libero_cfg.used_action_channel_ids) + ] * va_libero_cfg.action_dim +for i, j in enumerate(va_libero_cfg.used_action_channel_ids): + inverse_used_action_channel_ids[j] = i +va_libero_cfg.inverse_used_action_channel_ids = inverse_used_action_channel_ids + +va_libero_cfg.action_norm_method = 'quantiles' +va_libero_cfg.norm_stat = { + "q01": [ + -0.6589285731315613, + -0.84375, + -0.9375, + -0.12107142806053162, + -0.15964286029338837, + -0.26571428775787354, + ] + [0.] * 22 + [-1.0, 0.], + "q99": [ + 0.8999999761581421, + 0.8544642925262451, + 0.9375, + 0.17142857611179352, + 0.1842857152223587, + 0.34392857551574707, + ] + [0.] * 22 + [1.0, 0.], +} \ No newline at end of file diff --git a/wan_va/configs/va_libero_i2va.py b/wan_va/configs/va_libero_i2va.py new file mode 100644 index 0000000..b9c903d --- /dev/null +++ b/wan_va/configs/va_libero_i2va.py @@ -0,0 +1,12 @@ + +# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved. 
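+# Offline i2va inference config for LIBERO; reuses va_libero_cfg and reads input frames from example/libero.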
+from easydict import EasyDict +from .va_libero_cfg import va_libero_cfg + +va_libero_i2va_cfg = EasyDict(__name__='Config: VA libero i2va') +va_libero_i2va_cfg.update(va_libero_cfg) + +va_libero_i2va_cfg.input_img_path = 'example/libero' +va_libero_i2va_cfg.num_chunks_to_infer = 10 +va_libero_i2va_cfg.prompt = "put both the alphabet soup and the tomato sauce in the basket" +va_libero_i2va_cfg.infer_mode = 'i2va' diff --git a/wan_va/configs/va_libero_train_cfg.py b/wan_va/configs/va_libero_train_cfg.py new file mode 100644 index 0000000..9279394 --- /dev/null +++ b/wan_va/configs/va_libero_train_cfg.py @@ -0,0 +1,35 @@ +# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved. +from easydict import EasyDict +from .va_libero_cfg import va_libero_cfg +import os + +va_libero_train_cfg = EasyDict(__name__='Config: VA libero train') +va_libero_train_cfg.update(va_libero_cfg) + +# Use LeRobot latent dataset (default in train.py when dataset_format != "arms"). +va_libero_train_cfg.dataset_format = "lerobot" + +# Parent directory that *contains* one or more LeRobot dataset roots (each has meta/info.json). +# Latents must live under /latents/... mirroring videos (see README). +va_libero_train_cfg.dataset_path = "/path/to/lerobot_datasets_parent" +# Any one repo root is fine; create e.g. `torch.zeros_like(sample_text_emb)` once you have a .pth. +va_libero_train_cfg.empty_emb_path = "/path/to/some_lerobot_repo/empty_emb.pt" + +va_libero_train_cfg.wan22_pretrained_model_name_or_path = "/root/checkpoints/lingbot-va-base" + +va_libero_train_cfg.enable_wandb = True +# Single-GPU / ROCm: avoid DataLoader workers loading CUDA tensors in forked children. +va_libero_train_cfg.load_worker = 0 +va_libero_train_cfg.save_interval = 200 +va_libero_train_cfg.gc_interval = 50 +va_libero_train_cfg.cfg_prob = 0.1 + +# Training parameters +va_libero_train_cfg.learning_rate = 1e-5 +va_libero_train_cfg.beta1 = 0.9 +va_libero_train_cfg.beta2 = 0.95 +va_libero_train_cfg.weight_decay = 1e-1 +va_libero_train_cfg.warmup_steps = 10 +va_libero_train_cfg.batch_size = 1 +va_libero_train_cfg.gradient_accumulation_steps = 10 +va_libero_train_cfg.num_steps = 5000 \ No newline at end of file diff --git a/wan_va/dataset/__init__.py b/wan_va/dataset/__init__.py index 362c191..4824ddf 100644 --- a/wan_va/dataset/__init__.py +++ b/wan_va/dataset/__init__.py @@ -1,6 +1,7 @@ # Copyright 2024-2025 The Robbyant Team Authors. All rights reserved. -from .lerobot_latent_dataset import MultiLatentLeRobotDataset +# +# Keep imports lightweight: LeRobot dataset pulls extra dependencies (datasets, etc.). +# Import it at call sites when needed. +from .arms_latent_dataset import ArmsLatentDataset -__all__ = [ - 'MultiLatentLeRobotDataset' -] \ No newline at end of file +__all__ = ["ArmsLatentDataset"] \ No newline at end of file diff --git a/wan_va/dataset/arms_latent_dataset.py b/wan_va/dataset/arms_latent_dataset.py new file mode 100644 index 0000000..979d538 --- /dev/null +++ b/wan_va/dataset/arms_latent_dataset.py @@ -0,0 +1,213 @@ +# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved. 
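+# Dataset over Wan2.2 VAE latents for the prepared_arms layout (see scripts/prepare_arms_dataset.py and scripts/extract_arms_latents.py).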
+from __future__ import annotations + +import json +import os +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +import torch +from einops import rearrange + + +@dataclass(frozen=True) +class ArmsSample: + episode_index: int + start_frame: int + end_frame: int + action_text: str + + +def _load_norm_stat(dataset_root: Path, fallback_norm_stat: dict) -> tuple[np.ndarray, np.ndarray]: + """ + Priority: + 1) /norm_stat.json (written by prep script) + 2) config.norm_stat + """ + stat_path = dataset_root / "norm_stat.json" + if stat_path.exists(): + with open(stat_path, "r", encoding="utf-8") as f: + d = json.load(f) + q01 = np.array(d["q01"], dtype=float)[None] + q99 = np.array(d["q99"], dtype=float)[None] + return q01, q99 + + q01 = np.array(fallback_norm_stat["q01"], dtype=float)[None] + q99 = np.array(fallback_norm_stat["q99"], dtype=float)[None] + return q01, q99 + + +class ArmsLatentDataset(torch.utils.data.Dataset): + """ + Dataset for the repo-local ./arms data after running scripts/prepare_arms_dataset.py. + + Expected directory: + dataset_path/ + meta/episodes.jsonl + videos/chunk-000/observation.images.cam_high/episode_000000.mp4 + actions/episode_000000.npy # [T, 30] float32 (mapped to 30-dim standard) + latents/chunk-000/observation.images.cam_high/episode_000000_0_T.pth (optional; required for training) + empty_emb.pt + norm_stat.json (optional) + """ + + def __init__(self, config): + self.config = config + self.root = Path(config.dataset_path) + self.meta_path = self.root / "meta" / "episodes.jsonl" + assert self.meta_path.exists(), f"episodes.jsonl not found: {self.meta_path}" + + self.used_video_keys = list(config.obs_cam_keys) + assert len(self.used_video_keys) == 1, "release dataset expects a single camera key" + + empty_emb_path = Path(config.empty_emb_path) + if not empty_emb_path.exists(): + # Create a compatible empty embedding from the first latent file's text_emb. + # This avoids requiring users to manually provide empty_emb.pt. + latent_dir = self.root / "latents" / "chunk-000" / self.used_video_keys[0] + first_pth = next(iter(sorted(latent_dir.glob("episode_*.pth"))), None) + if first_pth is None: + raise FileNotFoundError( + f"empty_emb.pt not found at {empty_emb_path} and no latent files under {latent_dir} " + "to infer embedding shape. Please run latent extraction first." 
+ ) + sample = torch.load(first_pth, weights_only=False) + text_emb = sample.get("text_emb", None) + if text_emb is None: + raise KeyError(f"'text_emb' missing in latent file: {first_pth}") + empty_emb_path.parent.mkdir(parents=True, exist_ok=True) + torch.save(torch.zeros_like(text_emb), empty_emb_path) + self.empty_emb = torch.load(empty_emb_path, weights_only=False) + self.cfg_prob = getattr(config, "cfg_prob", 0.0) + + self.q01, self.q99 = _load_norm_stat(self.root, config.norm_stat) + + self.latent_path = self.root / "latents" + self.actions_path = self.root / "actions" + + self.samples: list[ArmsSample] = [] + with open(self.meta_path, "r", encoding="utf-8") as f: + for line in f: + if not line.strip(): + continue + d = json.loads(line) + ep = int(d["episode_index"]) + tasks = d.get("tasks", []) + action_config = d.get("action_config", []) + if not action_config: + # single segment fallback + self.samples.append( + ArmsSample(ep, 0, int(d["length"]), tasks[0] if tasks else "") + ) + else: + for acfg in action_config: + self.samples.append( + ArmsSample( + ep, + int(acfg["start_frame"]), + int(acfg["end_frame"]), + str(acfg.get("action_text", tasks[0] if tasks else "")), + ) + ) + + # inverse_used_action_channel_ids is defined on config; keep behavior consistent with LeRobot loader. + self.inverse_used_action_channel_ids = np.asarray(config.inverse_used_action_channel_ids, dtype=int) + + def __len__(self) -> int: + return len(self.samples) + + def _load_latent_segment(self, episode_index: int, start_frame: int, end_frame: int) -> dict: + # We keep the same naming convention as README: episode_{idx}_{start}_{end}.pth + episode_chunk = 0 + key = self.used_video_keys[0] + latent_file = ( + self.latent_path + / f"chunk-{episode_chunk:03d}" + / key + / f"episode_{episode_index:06d}_{start_frame}_{end_frame}.pth" + ) + assert latent_file.exists(), ( + f"latent file not found: {latent_file}\n" + "You need to extract Wan2.2 VAE latents into dataset_root/latents/ mirroring videos/." + ) + return torch.load(latent_file, weights_only=False, map_location="cpu") + + def _load_actions(self, episode_index: int) -> np.ndarray: + ap = self.actions_path / f"episode_{episode_index:06d}.npy" + assert ap.exists(), f"actions file not found: {ap}" + a = np.load(ap) + assert a.ndim == 2 and a.shape[1] == 30, f"actions must be [T, 30], got {a.shape}" + return a + + def _action_post_process(self, local_start_frame: int, local_end_frame: int, latent_frame_ids, action: np.ndarray): + # Keep same logic as lerobot_latent_dataset.py + act_shift = int(latent_frame_ids[0] - local_start_frame) + frame_stride = latent_frame_ids[1] - latent_frame_ids[0] + action = action[act_shift:] + + action = np.pad(action, pad_width=((frame_stride * 4, 0), (0, 0)), mode="constant", constant_values=0) + + latent_frame_num = (len(latent_frame_ids) - 1) // 4 + 1 + required_action_num = latent_frame_num * frame_stride * 4 + action = action[:required_action_num] + action_mask = np.ones_like(action, dtype="bool") + assert action.shape[0] == required_action_num + + # Extra mask channel, same as existing pipeline. 
+ action_paded = np.pad(action, ((0, 0), (0, 1)), mode="constant", constant_values=0) + action_mask_padded = np.pad(action_mask, ((0, 0), (0, 1)), mode="constant", constant_values=0) + + action_aligned = action_paded[:, self.inverse_used_action_channel_ids] + action_mask_aligned = action_mask_padded[:, self.inverse_used_action_channel_ids] + action_aligned = (action_aligned - self.q01) / (self.q99 - self.q01 + 1e-6) * 2.0 - 1.0 + + action_aligned = rearrange(action_aligned, "(f n) c -> c f n 1", f=latent_frame_num) + action_mask_aligned = rearrange(action_mask_aligned, "(f n) c -> c f n 1", f=latent_frame_num) + action_aligned *= action_mask_aligned + return torch.from_numpy(action_aligned).float(), torch.from_numpy(action_mask_aligned).bool() + + def __getitem__(self, idx: int) -> dict: + s = self.samples[idx % len(self.samples)] + latent_dict = self._load_latent_segment(s.episode_index, s.start_frame, s.end_frame) + + # Latent dict fields follow README table. + latent = latent_dict["latent"] + latent_num_frames = int(latent_dict["latent_num_frames"]) + latent_height = int(latent_dict["latent_height"]) + latent_width = int(latent_dict["latent_width"]) + frame_ids = latent_dict["frame_ids"] + + lat = rearrange(latent, "(f h w) c -> f h w c", f=latent_num_frames, h=latent_height, w=latent_width) + # Optional training-time random crop in *latent frame* space to bound sequence length and memory. + train_latent_frames = int(getattr(self.config, "train_latent_frames", 0) or 0) + if train_latent_frames > 0 and latent_num_frames > train_latent_frames: + max_start = latent_num_frames - train_latent_frames + start_l = int(torch.randint(0, max_start + 1, (1,)).item()) + end_l = start_l + train_latent_frames + lat = lat[start_l:end_l] + frame_ids = frame_ids[start_l:end_l] + latent_num_frames = train_latent_frames + + lat = lat.permute(3, 0, 1, 2) # C,F,H,W + + text_emb = latent_dict["text_emb"] + if torch.rand(1).item() < self.cfg_prob: + text_emb = self.empty_emb + + actions_full = self._load_actions(s.episode_index) + # Align actions to the (possibly cropped) latent frame ids. + seg_start = int(frame_ids[0]) + seg_end = int(frame_ids[-1]) + 1 + seg_start = max(seg_start, 0) + seg_end = min(seg_end, actions_full.shape[0]) + actions_seg = actions_full[seg_start:seg_end] + actions_aligned, actions_mask = self._action_post_process(seg_start, seg_end, frame_ids, actions_seg) + + return { + "latents": lat, + "text_emb": text_emb, + "actions": actions_aligned, + "actions_mask": actions_mask, + } + diff --git a/wan_va/distributed/util.py b/wan_va/distributed/util.py index 8ad7c29..a1c8fb4 100644 --- a/wan_va/distributed/util.py +++ b/wan_va/distributed/util.py @@ -22,12 +22,18 @@ def _configure_model(model, shard_fn, param_dtype, device, eval_mode=True): def init_distributed(world_size, local_rank, rank): - # if world_size > 1: + # Single-process training: do not require env:// rendezvous variables. 
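+    # With init_method="env://", dist.init_process_group requires MASTER_ADDR/MASTER_PORT
+    # in the environment, which a plain single-GPU `python -m wan_va.train` launch does not set.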
+ if world_size is None or int(world_size) <= 1: + torch.cuda.set_device(local_rank) + return + torch.cuda.set_device(local_rank) - dist.init_process_group(backend="nccl", - init_method="env://", - rank=rank, - world_size=world_size) + dist.init_process_group( + backend="nccl", + init_method="env://", + rank=rank, + world_size=world_size, + ) def dist_mean(local_tensor): if dist.is_initialized(): diff --git a/wan_va/modules/model.py b/wan_va/modules/model.py index 25b45e5..0bbdb58 100644 --- a/wan_va/modules/model.py +++ b/wan_va/modules/model.py @@ -77,6 +77,16 @@ def half(x): k_varlen = k_varlen.to(v_varlen.dtype) block_mask = FlexAttnFunc.cross_attention_mask if self.is_cross else FlexAttnFunc.attention_mask + # The precomputed block mask may be larger than the actual q/kv lengths + # for the current batch/sequence. Crop it to the current lengths. + q_len = q_varlen.shape[2] + kv_len = k_varlen.shape[2] + if block_mask is not None and hasattr(block_mask, "_adjust"): + try: + block_mask = block_mask._adjust(q_len, kv_len) + except Exception: + # Fall back to original mask if adjust fails. + pass x_out = FlexAttnFunc.flex_attn(q_varlen, k_varlen, v_varlen, block_mask=block_mask, kernel_options = { "BLOCK_M": 64, diff --git a/wan_va/train.py b/wan_va/train.py index 3d85ca2..b5152ba 100644 --- a/wan_va/train.py +++ b/wan_va/train.py @@ -43,14 +43,21 @@ FlowMatchScheduler ) -from dataset import MultiLatentLeRobotDataset +from wan_va.dataset.arms_latent_dataset import ArmsLatentDataset import gc class Trainer: def __init__(self, config): if config.enable_wandb and config.rank == 0: - wandb.login(host=os.environ['WANDB_BASE_URL'], key=os.environ['WANDB_API_KEY']) + # Make wandb optional in minimal server environments. + base_url = os.environ.get("WANDB_BASE_URL") + api_key = os.environ.get("WANDB_API_KEY") + if not base_url or not api_key: + logger.warning("WANDB_* env vars missing; disabling wandb logging.") + config.enable_wandb = False + else: + wandb.login(host=base_url, key=api_key) self.wandb = wandb self.wandb.init( entity=os.environ["WANDB_TEAM_NAME"], @@ -61,6 +68,10 @@ def __init__(self, config): name='test_lln' # name=os.path.basename(os.path.normpath(job_config.job.dump_folder)) ) + # Make sure charts use the same global step. + # Also log "flat" metric names in addition to nested keys so W&B auto-panels pick them up. + self.wandb.define_metric("step") + self.wandb.define_metric("*", step_metric="step") logger.info("WandB logging enabled") self.step = 0 self.config = config @@ -118,7 +129,14 @@ def __init__(self, config): # Setup dataloaders logger.info("Setting up datasets...") - train_dataset = MultiLatentLeRobotDataset(config=config) + dataset_format = getattr(config, "dataset_format", "lerobot") + if dataset_format == "arms": + train_dataset = ArmsLatentDataset(config=config) + else: + # Import LeRobot-based dataset only when needed to avoid forcing + # the heavy `lerobot/datasets` dependency for arms-only training. 
+ from wan_va.dataset.lerobot_latent_dataset import MultiLatentLeRobotDataset + train_dataset = MultiLatentLeRobotDataset(config=config) train_sampler = DistributedSampler( train_dataset, num_replicas=config.world_size, @@ -144,8 +162,8 @@ def __init__(self, config): self.gradient_accumulation_steps = getattr(config, 'gradient_accumulation_steps', 1) self.train_loader_iter = None - # if hasattr(config, 'resume_from') and config.resume_from: - # self._load_training_state(config.resume_from) + if hasattr(config, "resume_from") and config.resume_from: + self._load_training_state(config.resume_from) def _get_next_batch(self): """Get next batch from iterator, reset if epoch is finished.""" @@ -262,9 +280,14 @@ def compute_loss(self, self.patch_size, latent_pred, input_dict['latent_dict']['targets'].shape[-3], input_dict['latent_dict']['targets'].shape[-2], input_dict['latent_dict']['targets'].shape[-1], batch_size=latent_pred.shape[0]) - Bn, Fn = input_dict['latent_dict']['timesteps'].shape - latent_loss_weight = self.train_scheduler_latent.training_weight(input_dict['latent_dict']['timesteps'].flatten()).reshape(Bn, Fn) - action_loss_weight = self.train_scheduler_action.training_weight(input_dict['action_dict']['timesteps'].flatten()).reshape(Bn, Fn) + Bn_l, Fn_l = input_dict["latent_dict"]["timesteps"].shape + Bn_a, Fn_a = input_dict["action_dict"]["timesteps"].shape + latent_loss_weight = self.train_scheduler_latent.training_weight( + input_dict["latent_dict"]["timesteps"].flatten() + ).reshape(Bn_l, Fn_l) + action_loss_weight = self.train_scheduler_action.training_weight( + input_dict["action_dict"]["timesteps"].flatten() + ).reshape(Bn_a, Fn_a) # Frame-wise video loss calculation latent_loss = F.mse_loss(latent_pred.float(), input_dict['latent_dict']['targets'].float().detach(), reduction='none') @@ -300,10 +323,12 @@ def _train_step(self, batch, batch_idx): should_sync = (batch_idx + 1) % self.gradient_accumulation_steps == 0 - if not should_sync: - self.transformer.set_requires_gradient_sync(False) - else: - self.transformer.set_requires_gradient_sync(True) + # Only FSDP-wrapped models have set_requires_gradient_sync. + if hasattr(self.transformer, "set_requires_gradient_sync"): + if not should_sync: + self.transformer.set_requires_gradient_sync(False) + else: + self.transformer.set_requires_gradient_sync(True) output = self.transformer(input_dict, train_mode=True) latent_loss, action_loss = self.compute_loss(input_dict, output) @@ -335,10 +360,11 @@ def save_checkpoint(self,): options=StateDictOptions(full_state_dict=True, cpu_offload=True), ) state_dict_bf16 = {k: v.to(torch.bfloat16) for k, v in state_dict.items()} - # optim_state = get_optimizer_state_dict( - # self.transformer, self.optimizer, - # options=StateDictOptions(full_state_dict=True, cpu_offload=True), - # ) + optim_state = get_optimizer_state_dict( + self.transformer, + self.optimizer, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) # Only rank 0 saves the checkpoint if self.config.rank == 0: @@ -363,14 +389,18 @@ def save_checkpoint(self,): with open(config_file, 'w') as f: json.dump(config_dict, f, indent=2) - # # Save optimizer state and training metadata in PyTorch format - # training_state_path = checkpoint_dir / "training_state.pt" - # logger.info(f"Saving training state to {training_state_path}") - # torch.save({ - # 'step': self.step, - # 'optimizer_state_dict': optim_state, - # 'config': vars(self.config), - # }, training_state_path) + # Save optimizer/lr_scheduler/step for resume. 
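+        # optim_state above was gathered with full_state_dict=True / cpu_offload=True, so this
+        # file holds the complete unsharded optimizer state on CPU. For Adam-style optimizers
+        # that is roughly two extra fp32 copies of the parameters (an estimate, not measured
+        # in this repo), so expect training_state.pt to be large.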
+        training_state_path = checkpoint_dir / "training_state.pt"
+        logger.info(f"Saving training state to {training_state_path}")
+        torch.save(
+            {
+                "step": self.step,
+                "optimizer_state_dict": optim_state,
+                "lr_scheduler_state_dict": self.lr_scheduler.state_dict(),
+                "config": dict(self.config),
+            },
+            training_state_path,
+        )
 
         logger.info(f"Checkpoint saved successfully at step {self.step}")
 
@@ -409,6 +439,12 @@ def _load_training_state(self, checkpoint_path):
             optim_state_dict=training_state['optimizer_state_dict'],
             options=StateDictOptions(full_state_dict=True, strict=False)
         )
+        if "lr_scheduler_state_dict" in training_state:
+            try:
+                self.lr_scheduler.load_state_dict(training_state["lr_scheduler_state_dict"])
+            except Exception as e:
+                if self.config.rank == 0:
+                    logger.warning(f"Failed to load lr_scheduler state, continuing. Error: {e}")
 
         self.step = training_state.get('step', 0)
         if self.config.rank == 0:
@@ -470,7 +506,10 @@ def train(self):
 
                 if self.config.rank == 0:
                     total_norm = losses['total_norm']
-                    progress_bar.n += self.gradient_accumulation_steps
+                    # One entry here == one optimizer step (self.step). Adding
+                    # gradient_accumulation_steps instead made tqdm's iteration count
+                    # read as 10× the real step and confused monitoring.
+                    progress_bar.update(1)
                     progress_bar.set_postfix({
                         'latent_loss': f'{latent_loss_show:.4f}',
                         'action_loss': f'{action_loss_show:.4f}',
@@ -480,10 +519,14 @@
                     })
                     if self.config.enable_wandb:
                         self.wandb.log({
+                            'step': self.step,
                             'loss_metrics/global_avg_video_loss': latent_loss_show,
                             'loss_metrics/global_avg_action_loss': action_loss_show,
                             'loss_metrics/global_max_video_loss': max_latent_loss_show,
                             'loss_metrics/global_max_action_loss': max_action_loss_show,
+                            # Flat keys (easier to find in W&B UI)
+                            'latent_loss': latent_loss_show,
+                            'action_loss': action_loss_show,
                             'grad_norm': total_norm.item(),
                             'lr': lr,
                         }, step=self.step)
@@ -519,6 +562,37 @@ def run(args):
     if args.save_root is not None:
         config.save_root = args.save_root
 
+    def _find_latest_checkpoint(ckpt_root: Path):
+        if not ckpt_root.exists():
+            return None
+        best_step = None
+        best_path = None
+        for p in ckpt_root.glob("checkpoint_step_*"):
+            try:
+                step = int(p.name.split("_")[-1])
+            except Exception:
+                continue
+            if best_step is None or step > best_step:
+                best_step = step
+                best_path = p
+        return str(best_path) if best_path is not None else None
+
+    if getattr(args, "resume", False):
+        ckpt_root = Path(config.save_root) / "checkpoints"
+        latest = _find_latest_checkpoint(ckpt_root)
+        if latest:
+            config.resume_from = latest
+            if rank == 0:
+                logger.info(f"Auto-resume enabled. Using latest checkpoint: {latest}")
+        else:
+            if rank == 0:
+                logger.warning(f"--resume was set but no checkpoints found under {ckpt_root}. Starting fresh.")
+
+    if getattr(args, "resume_from", None):
+        config.resume_from = args.resume_from
+        if rank == 0:
+            logger.info(f"Resuming from checkpoint (explicit): {config.resume_from}")
+
     if rank == 0:
         logger.info(f"Using config: {args.config_name}")
         logger.info(f"World size: {world_size}, Local rank: {local_rank}")
@@ -542,6 +616,18 @@ def main():
         default=None,
         help="Root directory for saving checkpoints",
     )
+    parser.add_argument(
+        "--resume",
+        action="store_true",
+        help="Resume training from the latest checkpoint under <save_root>/checkpoints/",
+    )
+    parser.add_argument(
+        "--resume-from",
+        dest="resume_from",
+        type=str,
+        default=None,
+        help="Resume training from a specific checkpoint directory (e.g. .../checkpoints/checkpoint_step_2000)",
+    )
 
     args = parser.parse_args()
     run(args)
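
---

Usage sketch for the new resume flags. Paths and the config name are examples only; `save_checkpoint` writes checkpoints under `<save_root>/checkpoints/checkpoint_step_<N>`, which is what `--resume` scans:

```bash
# Pick up the newest checkpoint_step_* directory automatically
python -m wan_va.train --config-name arms_train --save-root ./train_out_arms --resume

# Or resume from one specific checkpoint directory
python -m wan_va.train --config-name arms_train --save-root ./train_out_arms \
  --resume-from ./train_out_arms/checkpoints/checkpoint_step_2000
```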