diff --git a/docs/sphinx/source/en/2-user_guide/1-training/4-multi_gpu.md b/docs/sphinx/source/en/2-user_guide/1-training/4-multi_gpu.md index 56de4d942..198639b25 100644 --- a/docs/sphinx/source/en/2-user_guide/1-training/4-multi_gpu.md +++ b/docs/sphinx/source/en/2-user_guide/1-training/4-multi_gpu.md @@ -1,8 +1,10 @@ # Multi-GPU -The currently validated multi-GPU training path is SAC in replay-buffer mode. -Use the unified CLI as usual, and enable multiple GPUs with the shared -off-policy field `training.num_gpus`. +The currently validated multi-GPU training paths are SAC/FastSAC and FlashSAC +in replay-buffer mode. Use the unified CLI as usual, and enable multiple GPUs +with the shared off-policy field `training.num_gpus`. The multi-GPU runner is a +generic off-policy orchestration layer, but each learner must explicitly opt +into the distributed learner contract. The multi-GPU runner keeps algorithm code separate from IPC: a collector fills the CPU replay buffer on the host, the runner packs batches for each learner @@ -27,7 +29,8 @@ this avoids extending communication to AdamW momentum state. When `algo.obs_normalization=true`, each learner rank updates its observation normalizer from cross-rank global batch moments; rank 0 publishes the matching mean/std to the CPU collector at the same synchronization point as actor -weights. +weights. FlashSAC reward normalization keeps the replay-order update on rank 0 +and broadcasts the normalizer state to other ranks before learner updates. For strict per-update gradient averaging, set `training.multi_gpu_sync_mode=sync_sgd`. That mode is closer to single-GPU @@ -53,7 +56,9 @@ single-GPU `algo.batch_size=8192` corresponds to two-GPU ## Preconditions -- SAC only: `training.num_gpus > 1` rejects TD3, FlashSAC, PPO, MLX PPO, and APPO. +- FastSAC and FlashSAC learners support this path; `training.num_gpus > 1` + rejects TD3, PPO, MLX PPO, APPO, and custom SAC runtimes until their learners + declare support. 
- CUDA is required; select physical cards with `CUDA_VISIBLE_DEVICES`. - SAC symmetry augmentation is not supported in multi-GPU mode. If the task owner enables it by default, set `algo.use_symmetry=false`. @@ -95,6 +100,16 @@ CUDA_VISIBLE_DEVICES=0,7 uv run train --algo sac --task g1_walk_flat --sim mujoc Logs still use SAC's default directory: `logs/fast_sac//`. +FlashSAC uses the same knobs: + +```bash +uv run train --algo flashsac --task g1_walk_flat --sim mujoco \ + training.num_gpus=2 \ + training.multi_gpu_sync_mode=local_sgd +``` + +FlashSAC logs still use `logs/flash_sac//`. + ## Performance Checks Multi-GPU mainly targets learner update bottlenecks. The collector is still one @@ -116,7 +131,8 @@ compare steady-state `perf/iter_ms`, `timing/learner_train_ms`, ## Common Errors -- `Only SAC supports training.num_gpus > 1`: only SAC is validated right now. +- `<learner> does not support training.num_gpus > 1`: that learner has not + declared and validated the multi-GPU contract yet. - `SAC multi-GPU training requires a CUDA device`: CUDA is unavailable, or `training.device` was set to CPU. - `set training.num_gpus=1 or algo.use_symmetry=false`: multi-GPU SAC does not diff --git a/docs/sphinx/source/en/2-user_guide/2-algorithms/3-sac.md b/docs/sphinx/source/en/2-user_guide/2-algorithms/3-sac.md index c72b5492b..f1011be43 100644 --- a/docs/sphinx/source/en/2-user_guide/2-algorithms/3-sac.md +++ b/docs/sphinx/source/en/2-user_guide/2-algorithms/3-sac.md @@ -11,10 +11,12 @@ The off-policy runner decouples CPU simulation from GPU learning through shared memory: a collector subprocess fills a CPU-resident replay buffer while the learner trains on the GPU. -SAC is also the currently validated replay-buffer multi-GPU algorithm. Enable it -with `training.num_gpus > 1`; the host side packs and distributes batches in -parallel, while the GPU learners default to delayed parameter averaging via -`training.multi_gpu_sync_mode=local_sgd`. 
See +The default FastSAC learner is also the currently validated replay-buffer +multi-GPU SAC implementation. Enable it with `training.num_gpus > 1`; the host +side packs and distributes batches in parallel, while the GPU learners default +to delayed parameter averaging via `training.multi_gpu_sync_mode=local_sgd`. +Custom SAC runtimes must explicitly declare the distributed learner contract +before they can use this path. See {doc}`../1-training/4-multi_gpu` for the full command, strict-sync fallback, and constraints. diff --git a/docs/sphinx/source/en/2-user_guide/2-algorithms/5-flash_sac.md b/docs/sphinx/source/en/2-user_guide/2-algorithms/5-flash_sac.md index f2a3ad719..c88d72cd3 100644 --- a/docs/sphinx/source/en/2-user_guide/2-algorithms/5-flash_sac.md +++ b/docs/sphinx/source/en/2-user_guide/2-algorithms/5-flash_sac.md @@ -30,7 +30,17 @@ playback video. See {doc}`/en/1-getting_started/3-evaluation_and_playback`. - `algo.algo_params.actor_num_blocks=2` - `algo.algo_params.critic_num_blocks=2` -`scripts/train_offpolicy.py` rejects `training.num_gpus > 1` for FlashSAC, so -keep the default single-GPU path unless the implementation changes. +FlashSAC supports the shared off-policy multi-GPU runner. Enable it with: + +```bash +uv run train --algo flashsac --task g1_walk_flat --sim mujoco \ + training.num_gpus=2 \ + training.multi_gpu_sync_mode=local_sgd +``` + +Multi-GPU FlashSAC requires CUDA and synchronized collection. The learner owns +its distributed synchronization hooks: gradients are averaged in `sync_sgd`, +parameters and persistent normalization buffers are averaged in `local_sgd`, and +reward normalizer state is updated on rank 0 then broadcast to the other ranks. The log root is `logs/flash_sac//`. 
diff --git a/docs/sphinx/source/zh_CN/2-user_guide/1-training/4-multi_gpu.md b/docs/sphinx/source/zh_CN/2-user_guide/1-training/4-multi_gpu.md index b6bcc41ee..e6a010890 100644 --- a/docs/sphinx/source/zh_CN/2-user_guide/1-training/4-multi_gpu.md +++ b/docs/sphinx/source/zh_CN/2-user_guide/1-training/4-multi_gpu.md @@ -1,8 +1,9 @@ # 多 GPU -当前已验证的多 GPU 训练路径是 SAC 的 replay-buffer 模式。入口仍然是统一 CLI: -`uv run train --algo sac ...`,多卡由共享 off-policy 配置字段 -`training.num_gpus` 打开。 +当前已验证的多 GPU 训练路径是 SAC/FastSAC 和 FlashSAC 的 replay-buffer 模式。入口 +仍然是统一 CLI,多卡由共享 off-policy 配置字段 `training.num_gpus` 打开。多 GPU +runner 是通用的 off-policy 编排层,但 learner 必须通过分布式 learner contract 显式声 +明支持。 多 GPU runner 保持算法与 IPC 隔离:collector 在 host 侧填充 CPU replay buffer, runner 根据各 learner rank 的请求打包 batch,并通过 pinned-memory pipeline 并行分 @@ -22,7 +23,8 @@ learner iteration 同步一次;增大该值可以进一步减少 4 卡、8 卡 开启 `algo.obs_normalization=true` 时,每个 learner rank 使用跨 rank 聚合后的全局 batch moments 更新 observation normalizer;rank0 在发布 actor 权重给 CPU -collector 的同一同步点发布对应 mean/std。 +collector 的同一同步点发布对应 mean/std。FlashSAC 的 reward normalization 保持由 +rank0 按 replay 写入顺序更新,并在 learner update 前广播 normalizer 状态给其它 rank。 如需严格的每次 update 梯度平均,可显式设置 `training.multi_gpu_sync_mode=sync_sgd`。该模式更接近单卡 global batch 的同步 @@ -44,7 +46,8 @@ batch**,不是跨所有 GPU 的 global batch。`training.num_gpus=N` 时,每 ## 前置条件 -- 只支持 SAC:`training.num_gpus > 1` 会拒绝 TD3、FlashSAC、PPO、MLX PPO 和 APPO。 +- FastSAC 与 FlashSAC learner 支持该路径;`training.num_gpus > 1` 会拒绝尚未声明该 + 能力的 TD3、PPO、MLX PPO、APPO 和 custom SAC runtime。 - 必须使用 CUDA 设备;用 `CUDA_VISIBLE_DEVICES` 选择物理卡。 - SAC 的对称增强当前不支持多卡;若任务 owner 默认开启,需要设置 `algo.use_symmetry=false`。 @@ -85,6 +88,16 @@ CUDA_VISIBLE_DEVICES=0,7 uv run train --algo sac --task g1_walk_flat --sim mujoc 日志仍写入 SAC 的默认目录:`logs/fast_sac//`。 +FlashSAC 使用同一组多卡参数: + +```bash +uv run train --algo flashsac --task g1_walk_flat --sim mujoco \ + training.num_gpus=2 \ + training.multi_gpu_sync_mode=local_sgd +``` + +FlashSAC 日志仍写入 `logs/flash_sac//`。 + ## 性能检查 多 
GPU 主要减少 learner 更新瓶颈。collector 仍是单个 CPU 进程,所以 @@ -103,7 +116,8 @@ ring-buffer 窗口,让 CPU 随机 gather 与下一次 env step 重叠,同时 ## 常见错误 -- `Only SAC supports training.num_gpus > 1`:当前只验证 SAC。 +- `<learner> does not support training.num_gpus > 1`:该 learner 尚未声明并验证多 + GPU contract。 - `SAC multi-GPU training requires a CUDA device`:没有可用 CUDA,或 `training.device` 被设成了 CPU。 - `set training.num_gpus=1 or algo.use_symmetry=false`:多卡 SAC 暂不支持对称增 diff --git a/docs/sphinx/source/zh_CN/2-user_guide/2-algorithms/3-sac.md b/docs/sphinx/source/zh_CN/2-user_guide/2-algorithms/3-sac.md index 126e5cfe5..7afdbf602 100644 --- a/docs/sphinx/source/zh_CN/2-user_guide/2-algorithms/3-sac.md +++ b/docs/sphinx/source/zh_CN/2-user_guide/2-algorithms/3-sac.md @@ -9,10 +9,11 @@ SAC 通过共享的 off-policy 入口 `scripts/train_offpolicy.py` 选择,TD3 off-policy runner 通过 shared memory 把 CPU 仿真与 GPU 学习解耦:collector 子进程 填充驻留在 CPU 上的 replay buffer,learner 在 GPU 上训练。 -SAC 也是当前已验证的 replay-buffer 多 GPU 训练算法。多卡模式通过 +默认 FastSAC learner 也是当前已验证的 replay-buffer 多 GPU SAC 实现。多卡模式通过 `training.num_gpus > 1` 打开,host 侧并行打包并分发 batch,多张 GPU 上的 learner -默认使用 `training.multi_gpu_sync_mode=local_sgd` 做 delayed-sync 参数平均。完整命 -令、严格同步回退和限制见 {doc}`../1-training/4-multi_gpu`。 +默认使用 `training.multi_gpu_sync_mode=local_sgd` 做 delayed-sync 参数平均。custom +SAC runtime 必须显式声明 distributed learner contract 后才能使用这条路径。完整命令、 +严格同步回退和限制见 {doc}`../1-training/4-multi_gpu`。 ## 快速开始 diff --git a/docs/sphinx/source/zh_CN/2-user_guide/2-algorithms/5-flash_sac.md b/docs/sphinx/source/zh_CN/2-user_guide/2-algorithms/5-flash_sac.md index c5687e3f9..82a160e16 100644 --- a/docs/sphinx/source/zh_CN/2-user_guide/2-algorithms/5-flash_sac.md +++ b/docs/sphinx/source/zh_CN/2-user_guide/2-algorithms/5-flash_sac.md @@ -28,7 +28,16 @@ uv run train --algo flashsac --task go2_joystick_flat --sim mujoco training.no_p - `algo.algo_params.actor_num_blocks=2` - `algo.algo_params.critic_num_blocks=2` -`scripts/train_offpolicy.py` 会拒绝 FlashSAC 的 `training.num_gpus > 1`,因此除非实 -现发生变化,否则请保持默认的单 
GPU 路径。 +FlashSAC 支持共享的 off-policy 多 GPU runner。使用方式: + +```bash +uv run train --algo flashsac --task g1_walk_flat --sim mujoco \ + training.num_gpus=2 \ + training.multi_gpu_sync_mode=local_sgd +``` + +多 GPU FlashSAC 要求 CUDA 和同步采集。learner 自己拥有分布式同步 hook: +`sync_sgd` 下同步梯度,`local_sgd` 下平均参数和 persistent normalization buffer, +reward normalizer 由 rank0 按 replay 写入顺序更新后广播给其它 rank。 日志根目录为 `logs/flash_sac//`。 diff --git a/scripts/train_offpolicy.py b/scripts/train_offpolicy.py index f52d9fc27..daf7d98de 100644 --- a/scripts/train_offpolicy.py +++ b/scripts/train_offpolicy.py @@ -178,8 +178,6 @@ def build_runner(algo_name: str, cfg: DictConfig): num_gpus = int(getattr(cfg.training, "num_gpus", 1)) multi_gpu_sync_mode = str(getattr(cfg.training, "multi_gpu_sync_mode", "local_sgd")) multi_gpu_sync_interval = int(getattr(cfg.training, "multi_gpu_sync_interval", 1)) - if num_gpus > 1 and algo_name != "sac": - raise ValueError("Only SAC supports training.num_gpus > 1 in this validation round") sync_collection = not bool(cfg.training.no_sync_collection) @@ -270,15 +268,25 @@ def build_runner(algo_name: str, cfg: DictConfig): "critic_obs_dim": _critic_dim, **_learner_extra_kwargs, } - _learner = _learner_cls(device=_device, **_learner_kwargs) if num_gpus > 1: + from unilab.algos.torch.offpolicy.distributed import ( + validate_distributed_learner_capability, + ) from unilab.algos.torch.offpolicy.multi_gpu_runner import MultiGPUOffPolicyRunner if not str(_device).startswith("cuda"): - raise ValueError("SAC multi-GPU training requires a CUDA device") + raise ValueError(f"{_algo_type} multi-GPU training requires a CUDA device") if not sync_collection: raise ValueError("Multi-GPU off-policy replay requires synchronized collection") + validate_distributed_learner_capability( + algo_type=_algo_type, + learner_cls=_learner_cls, + learner_kwargs=_learner_kwargs, + num_gpus=num_gpus, + sync_mode=multi_gpu_sync_mode, + ) + _learner = _learner_cls(device=_device, **_learner_kwargs) return 
MultiGPUOffPolicyRunner( learner=_learner, env_name=cfg.training.task_name, @@ -312,6 +320,7 @@ def build_runner(algo_name: str, cfg: DictConfig): nan_guard_cfg=_nan_guard_cfg, ) + _learner = _learner_cls(device=_device, **_learner_kwargs) return DoubleBufferOffPolicyRunner( learner=_learner, env_name=cfg.training.task_name, @@ -344,11 +353,21 @@ def build_runner(algo_name: str, cfg: DictConfig): if algo_name == "td3": from unilab.algos.torch.common.device import get_env_dims from unilab.algos.torch.fast_td3.learner import FastTD3Learner + from unilab.algos.torch.offpolicy.distributed import ( + validate_distributed_learner_capability, + ) from unilab.algos.torch.offpolicy.double_buffer_runner import ( DoubleBufferOffPolicyRunner, ) from unilab.utils.device import get_default_device + validate_distributed_learner_capability( + learner_cls=FastTD3Learner, + algo_type="td3", + learner_kwargs={}, + num_gpus=num_gpus, + sync_mode=multi_gpu_sync_mode, + ) _device = cfg.training.device or get_default_device() _obs_dim, _action_dim, _critic_dim = get_env_dims( cfg.training.task_name, diff --git a/src/unilab/algos/torch/fast_sac/learner.py b/src/unilab/algos/torch/fast_sac/learner.py index 045f6800b..9841eab07 100644 --- a/src/unilab/algos/torch/fast_sac/learner.py +++ b/src/unilab/algos/torch/fast_sac/learner.py @@ -378,6 +378,10 @@ class FastSACLearner: - Distributional critic (C51, num_atoms=101) """ + supports_multi_gpu = True + supports_multi_gpu_symmetry = False + supported_multi_gpu_sync_modes = frozenset({"sync_sgd", "local_sgd"}) + def __init__( self, obs_dim: int, diff --git a/src/unilab/algos/torch/fast_td3/learner.py b/src/unilab/algos/torch/fast_td3/learner.py index 3c13e22d6..03e6f79a2 100644 --- a/src/unilab/algos/torch/fast_td3/learner.py +++ b/src/unilab/algos/torch/fast_td3/learner.py @@ -134,6 +134,10 @@ class FastTD3Learner: - Observation normalization """ + supports_multi_gpu = False + supports_multi_gpu_symmetry = False + 
supported_multi_gpu_sync_modes: frozenset[str] = frozenset() + def __init__( self, obs_dim: int, diff --git a/src/unilab/algos/torch/flash_sac/double_buffer.py b/src/unilab/algos/torch/flash_sac/double_buffer.py index 8ba79d00e..178906698 100644 --- a/src/unilab/algos/torch/flash_sac/double_buffer.py +++ b/src/unilab/algos/torch/flash_sac/double_buffer.py @@ -7,6 +7,7 @@ from omegaconf import DictConfig from unilab.algos.torch.flash_sac.learner import FlashSACLearner +from unilab.algos.torch.offpolicy.distributed import validate_distributed_learner_capability from unilab.algos.torch.offpolicy.double_buffer_runner import DoubleBufferOffPolicyRunner from unilab.training import create_env, ensure_registries from unilab.training.seed import apply_training_seed @@ -20,9 +21,16 @@ def _validate_flashsac_double_buffer_runtime( device: str, replay_prefetch_mode: str, ) -> None: - _ = device if cfg.training.num_gpus > 1: - raise ValueError("FlashSAC-B cpu_pinned_double_buffer is single-GPU only") + if not str(device).startswith("cuda"): + raise ValueError("FlashSAC multi-GPU training requires a CUDA device") + validate_distributed_learner_capability( + learner_cls=FlashSACLearner, + algo_type="flashsac", + learner_kwargs={}, + num_gpus=int(cfg.training.num_gpus), + sync_mode=str(getattr(cfg.training, "multi_gpu_sync_mode", "local_sgd")), + ) if cfg.training.no_sync_collection: raise ValueError("FlashSAC-B cpu_pinned_double_buffer requires synchronized collection") if replay_prefetch_mode != "one_tick": @@ -40,7 +48,7 @@ def build_flashsac_double_buffer_runner( replay_prefetch_mode: str, verbose_metrics: bool, nan_guard_cfg: NanGuardCfg | None = None, -) -> DoubleBufferOffPolicyRunner: +) -> Any: """Build FlashSAC with the opt-in CPU-pinned double-buffer replay pipeline.""" from unilab.base.observations import get_obs_dims @@ -62,41 +70,91 @@ def build_flashsac_double_buffer_runner( finally: env.close() - learner = FlashSACLearner( - obs_dim=obs_dim, - action_dim=action_dim, 
- critic_obs_dim=critic_obs_dim, - device=device, - gamma=cfg.algo.gamma, - tau=cfg.algo.tau, - actor_lr=cfg.algo.actor_lr, - critic_lr=cfg.algo.critic_lr, - actor_hidden_dim=cfg.algo.actor_hidden_dim, - critic_hidden_dim=cfg.algo.critic_hidden_dim, - actor_num_blocks=cfg.algo.algo_params.actor_num_blocks, - critic_num_blocks=cfg.algo.algo_params.critic_num_blocks, - num_atoms=cfg.algo.num_atoms, - critic_min_v=cfg.algo.algo_params.critic_min_v, - critic_max_v=cfg.algo.algo_params.critic_max_v, - temp_initial_value=cfg.algo.algo_params.temp_initial_value, - temp_target_sigma=cfg.algo.algo_params.temp_target_sigma, - temp_target_entropy=cfg.algo.algo_params.temp_target_entropy, - actor_bc_alpha=cfg.algo.algo_params.actor_bc_alpha, - actor_noise_zeta_mu=cfg.algo.algo_params.actor_noise_zeta_mu, - actor_noise_zeta_max=cfg.algo.algo_params.actor_noise_zeta_max, - learning_rate_init=cfg.algo.algo_params.learning_rate_init, - learning_rate_peak=cfg.algo.algo_params.learning_rate_peak, - learning_rate_end=cfg.algo.algo_params.learning_rate_end, - learning_rate_warmup_steps=cfg.algo.algo_params.learning_rate_warmup_steps, - learning_rate_decay_steps=cfg.algo.algo_params.learning_rate_decay_steps, - normalize_reward=cfg.algo.algo_params.normalize_reward, - normalized_g_max=cfg.algo.algo_params.normalized_g_max, - n_step=cfg.algo.algo_params.n_step, - obs_normalization=cfg.algo.obs_normalization, - use_amp=cfg.training.use_amp, - amp_dtype=cfg.algo.algo_params.amp_dtype, - use_compile=cfg.algo.algo_params.use_compile, - ) + learner_kwargs = { + "obs_dim": obs_dim, + "action_dim": action_dim, + "critic_obs_dim": critic_obs_dim, + "gamma": cfg.algo.gamma, + "tau": cfg.algo.tau, + "actor_lr": cfg.algo.actor_lr, + "critic_lr": cfg.algo.critic_lr, + "actor_hidden_dim": cfg.algo.actor_hidden_dim, + "critic_hidden_dim": cfg.algo.critic_hidden_dim, + "actor_num_blocks": cfg.algo.algo_params.actor_num_blocks, + "critic_num_blocks": cfg.algo.algo_params.critic_num_blocks, + 
"num_atoms": cfg.algo.num_atoms, + "critic_min_v": cfg.algo.algo_params.critic_min_v, + "critic_max_v": cfg.algo.algo_params.critic_max_v, + "temp_initial_value": cfg.algo.algo_params.temp_initial_value, + "temp_target_sigma": cfg.algo.algo_params.temp_target_sigma, + "temp_target_entropy": cfg.algo.algo_params.temp_target_entropy, + "actor_bc_alpha": cfg.algo.algo_params.actor_bc_alpha, + "actor_noise_zeta_mu": cfg.algo.algo_params.actor_noise_zeta_mu, + "actor_noise_zeta_max": cfg.algo.algo_params.actor_noise_zeta_max, + "learning_rate_init": cfg.algo.algo_params.learning_rate_init, + "learning_rate_peak": cfg.algo.algo_params.learning_rate_peak, + "learning_rate_end": cfg.algo.algo_params.learning_rate_end, + "learning_rate_warmup_steps": cfg.algo.algo_params.learning_rate_warmup_steps, + "learning_rate_decay_steps": cfg.algo.algo_params.learning_rate_decay_steps, + "normalize_reward": cfg.algo.algo_params.normalize_reward, + "normalized_g_max": cfg.algo.algo_params.normalized_g_max, + "n_step": cfg.algo.algo_params.n_step, + "obs_normalization": cfg.algo.obs_normalization, + "use_amp": cfg.training.use_amp, + "amp_dtype": cfg.algo.algo_params.amp_dtype, + "use_compile": cfg.algo.algo_params.use_compile, + } + actor_kwargs = { + "actor_num_blocks": cfg.algo.algo_params.actor_num_blocks, + "actor_noise_zeta_mu": cfg.algo.algo_params.actor_noise_zeta_mu, + "actor_noise_zeta_max": cfg.algo.algo_params.actor_noise_zeta_max, + } + + if cfg.training.num_gpus > 1: + from unilab.algos.torch.offpolicy.multi_gpu_runner import MultiGPUOffPolicyRunner + + validate_distributed_learner_capability( + learner_cls=FlashSACLearner, + algo_type="flashsac", + learner_kwargs=learner_kwargs, + num_gpus=int(cfg.training.num_gpus), + sync_mode=str(getattr(cfg.training, "multi_gpu_sync_mode", "local_sgd")), + ) + learner = FlashSACLearner(device=device, **learner_kwargs) + return MultiGPUOffPolicyRunner( + learner=learner, + env_name=cfg.training.task_name, + algo_type="flashsac", + 
learner_cls=FlashSACLearner, + learner_kwargs=learner_kwargs, + num_gpus=int(cfg.training.num_gpus), + distributed_backend="nccl", + multi_gpu_sync_mode=str(getattr(cfg.training, "multi_gpu_sync_mode", "local_sgd")), + multi_gpu_sync_interval=int(getattr(cfg.training, "multi_gpu_sync_interval", 1)), + num_envs=cfg.algo.num_envs, + replay_buffer_n=cfg.algo.replay_buffer_n, + batch_size=cfg.algo.batch_size, + learning_starts=cfg.algo.learning_starts, + updates_per_step=cfg.algo.updates_per_step, + policy_frequency=cfg.algo.policy_frequency, + sync_collection=not cfg.training.no_sync_collection, + env_steps_per_sync=cfg.training.env_steps_per_sync, + device=device, + actor_hidden_dim=cfg.algo.actor_hidden_dim, + use_layer_norm=False, + obs_normalization=cfg.algo.obs_normalization, + sim_backend=cfg.training.sim_backend, + env_cfg_override=env_cfg_override, + actor_kwargs=actor_kwargs, + seed=cfg.algo.seed, + trace_enabled=cfg.training.trace_enabled, + trace_output_dir=cfg.training.trace_output_dir, + trace_thread_time=cfg.training.trace_thread_time, + trace_cuda_events=cfg.training.trace_cuda_events, + nan_guard_cfg=nan_guard_cfg, + ) + + learner = FlashSACLearner(device=device, **learner_kwargs) return DoubleBufferOffPolicyRunner( learner=learner, @@ -116,11 +174,7 @@ def build_flashsac_double_buffer_runner( obs_normalization=cfg.algo.obs_normalization, sim_backend=cfg.training.sim_backend, env_cfg_override=env_cfg_override, - actor_kwargs={ - "actor_num_blocks": cfg.algo.algo_params.actor_num_blocks, - "actor_noise_zeta_mu": cfg.algo.algo_params.actor_noise_zeta_mu, - "actor_noise_zeta_max": cfg.algo.algo_params.actor_noise_zeta_max, - }, + actor_kwargs=actor_kwargs, seed=cfg.algo.seed, trace_enabled=cfg.training.trace_enabled, trace_output_dir=cfg.training.trace_output_dir, diff --git a/src/unilab/algos/torch/flash_sac/learner.py b/src/unilab/algos/torch/flash_sac/learner.py index f28abee1e..744bde3f0 100644 --- a/src/unilab/algos/torch/flash_sac/learner.py +++ 
b/src/unilab/algos/torch/flash_sac/learner.py @@ -7,6 +7,7 @@ from typing import Any, cast import torch +import torch.distributed as dist import torch.nn as nn import torch.optim as optim @@ -22,6 +23,7 @@ resolve_target_entropy, select_min_q_log_probs, ) +from unilab.algos.torch.offpolicy.distributed import normalize_distributed_sync_mode @dataclass @@ -132,8 +134,29 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.g_r = state_dict["g_r"] self.g_r_max = state_dict["g_r_max"] + def broadcast(self, src: int = 0) -> None: + if not (dist.is_available() and dist.is_initialized()): + return + g_r_size = torch.tensor([self.g_r.numel()], device=self.device, dtype=torch.int64) + dist.broadcast(g_r_size, src=src) + num_envs = int(g_r_size.item()) + self._ensure_g_r_shape(num_envs) + for tensor in ( + self.rms.mean, + self.rms.var, + self.rms.count, + self.g_r_max, + ): + dist.broadcast(tensor, src=src) + if num_envs > 0: + dist.broadcast(self.g_r, src=src) + class FlashSACLearner: + supports_multi_gpu = True + supports_multi_gpu_symmetry = False + supported_multi_gpu_sync_modes = frozenset({"sync_sgd", "local_sgd"}) + def __init__( self, obs_dim: int, @@ -169,6 +192,8 @@ def __init__( use_amp: bool = False, amp_dtype: str = "auto", use_compile: bool = False, + world_size: int = 1, + distributed_sync_mode: str = "sync_sgd", ): self.device = torch.device(device) self.gamma = gamma @@ -185,6 +210,8 @@ def __init__( self.use_compile = bool( use_compile and get_torch_compile_for_cuda(self.device, warn=True) is not None ) + self.world_size = int(world_size) + self.distributed_sync_mode = normalize_distributed_sync_mode(distributed_sync_mode) self.actor = FlashSACActor( num_blocks=actor_num_blocks, @@ -294,13 +321,126 @@ def _should_use_grad_scaler( def _maybe_normalize_obs(self, obs: torch.Tensor, *, update: bool) -> torch.Tensor: if isinstance(self.obs_normalizer, nn.Identity): return obs - return cast(torch.Tensor, self.obs_normalizer(obs, update=update)) 
+ normalizer = cast(EmpiricalNormalization, self.obs_normalizer) + if update: + self._update_obs_normalizer(obs) + return cast(torch.Tensor, normalizer(obs, update=False)) + return cast(torch.Tensor, normalizer(obs, update=False)) def _autocast(self): return torch.autocast( device_type=self.device.type, dtype=self._amp_dtype, enabled=self.use_amp ) + def _distributed_normalization_ready(self) -> bool: + return self.world_size > 1 and dist.is_available() and dist.is_initialized() + + @torch.no_grad() + def _update_obs_normalizer(self, obs: torch.Tensor) -> None: + if isinstance(self.obs_normalizer, nn.Identity): + return + normalizer = cast(EmpiricalNormalization, self.obs_normalizer) + if not self._distributed_normalization_ready(): + normalizer.update(obs) + return + + obs_for_stats = obs.detach().to(dtype=normalizer._mean.dtype) + obs_dim = int(obs_for_stats.shape[-1]) + moment_payload = torch.cat( + [ + obs_for_stats.sum(dim=0), + obs_for_stats.square().sum(dim=0), + torch.tensor( + [obs_for_stats.shape[0]], + device=obs_for_stats.device, + dtype=obs_for_stats.dtype, + ), + ] + ) + dist.all_reduce(moment_payload, op=dist.ReduceOp.SUM) + batch_count = moment_payload[-1].clamp_min(1.0) + batch_mean = (moment_payload[:obs_dim] / batch_count).view_as(normalizer._mean) + batch_var = ( + moment_payload[obs_dim : 2 * obs_dim] / batch_count - batch_mean.view(-1).square() + ).clamp_min(0.0) + normalizer.update_from_moments( + batch_mean, + batch_var.view_as(normalizer._var), + batch_count.round().to(dtype=normalizer.count.dtype), + ) + + def _reduce_gradients(self, module: nn.Module) -> bool: + if self.world_size <= 1 or self.distributed_sync_mode != "sync_sgd": + return True + grads = [param.grad.reshape(-1) for param in module.parameters() if param.grad is not None] + if not grads: + return True + flat = torch.cat(grads) + dist.all_reduce(flat, op=dist.ReduceOp.SUM) + flat /= self.world_size + if not bool(torch.isfinite(flat).all().item()): + return False + offset = 0 
+ for param in module.parameters(): + if param.grad is None: + continue + n = param.grad.numel() + param.grad.copy_(flat[offset : offset + n].view_as(param.grad)) + offset += n + return True + + def _backoff_grad_scaler(self) -> None: + if self.scaler is None: + return + self.scaler.update(self.scaler.get_scale() * self.scaler.get_backoff_factor()) + + def _parameter_sync_tensors(self) -> list[torch.Tensor]: + tensors: list[torch.Tensor] = [] + for module in (self.actor, self.critic, self.target_critic, self.temperature): + tensors.extend( + tensor + for tensor in module.state_dict().values() + if torch.is_tensor(tensor) and tensor.is_floating_point() + ) + return tensors + + @torch.no_grad() + def average_distributed_parameters(self) -> None: + if self.world_size <= 1: + return + tensors = self._parameter_sync_tensors() + if not tensors: + return + flat = torch.cat([tensor.reshape(-1) for tensor in tensors]) + dist.all_reduce(flat, op=dist.ReduceOp.SUM) + flat /= self.world_size + offset = 0 + for tensor in tensors: + n = tensor.numel() + tensor.copy_(flat[offset : offset + n].view_as(tensor)) + offset += n + self.actor.normalize_parameters() + self.critic.normalize_parameters() + self.target_critic.normalize_parameters() + + def sync_initial_parameters(self, src: int = 0) -> None: + if self.world_size <= 1: + return + for module in (self.actor, self.critic, self.target_critic, self.temperature): + for tensor in module.state_dict().values(): + if torch.is_tensor(tensor): + dist.broadcast(tensor, src=src) + if not isinstance(self.obs_normalizer, nn.Identity): + for tensor in self.obs_normalizer.state_dict().values(): + if torch.is_tensor(tensor): + dist.broadcast(tensor, src=src) + self.sync_reward_normalizer(src=src) + + def sync_reward_normalizer(self, src: int = 0) -> None: + if self.world_size <= 1 or self.reward_normalizer is None: + return + self.reward_normalizer.broadcast(src=src) + def update_reward_stats( self, rewards: torch.Tensor, @@ -415,15 +555,24 @@ 
def update_critic(self, batch: dict[str, torch.Tensor]) -> dict[str, float]: ) self.critic_optimizer.zero_grad(set_to_none=True) + critic_step_ok = True if self.scaler is not None: self.scaler.scale(critic_loss).backward() - self.scaler.step(self.critic_optimizer) - self.scaler.update() + self.scaler.unscale_(self.critic_optimizer) + critic_step_ok = self._reduce_gradients(self.critic) + if critic_step_ok: + self.scaler.step(self.critic_optimizer) + self.scaler.update() + else: + self._backoff_grad_scaler() else: critic_loss.backward() - self.critic_optimizer.step() - self.critic_scheduler.step() - self.critic.normalize_parameters() + critic_step_ok = self._reduce_gradients(self.critic) + if critic_step_ok: + self.critic_optimizer.step() + if critic_step_ok: + self.critic_scheduler.step() + self.critic.normalize_parameters() return { "critic_loss": float(critic_loss.detach().cpu()), @@ -458,22 +607,32 @@ def update_actor(self, batch: dict[str, torch.Tensor]) -> dict[str, float]: ) self.actor_optimizer.zero_grad(set_to_none=True) + actor_step_ok = True if self.scaler is not None: self.scaler.scale(actor_loss).backward() - self.scaler.step(self.actor_optimizer) - self.scaler.update() + self.scaler.unscale_(self.actor_optimizer) + actor_step_ok = self._reduce_gradients(self.actor) + if actor_step_ok: + self.scaler.step(self.actor_optimizer) + self.scaler.update() + else: + self._backoff_grad_scaler() else: actor_loss.backward() - self.actor_optimizer.step() - self.actor_scheduler.step() - self.actor.normalize_parameters() + actor_step_ok = self._reduce_gradients(self.actor) + if actor_step_ok: + self.actor_optimizer.step() + if actor_step_ok: + self.actor_scheduler.step() + self.actor.normalize_parameters() temp_value = self.temperature() temp_loss = temp_value * (entropy - self.target_entropy) self.temperature_optimizer.zero_grad(set_to_none=True) temp_loss.backward() - self.temperature_optimizer.step() - self.temperature_scheduler.step() + if 
self._reduce_gradients(self.temperature): + self.temperature_optimizer.step() + self.temperature_scheduler.step() return { "actor_loss": float(actor_loss.detach().cpu()), diff --git a/src/unilab/algos/torch/hora/sac_learner.py b/src/unilab/algos/torch/hora/sac_learner.py index 8f270058c..2968a1e9e 100644 --- a/src/unilab/algos/torch/hora/sac_learner.py +++ b/src/unilab/algos/torch/hora/sac_learner.py @@ -32,6 +32,8 @@ def derive_priv_info_from_critic_obs( class HoraSACLearner(FastSACLearner): """FastSAC learner variant whose actor consumes HORA privileged info.""" + supports_multi_gpu = False + def __init__( self, *, diff --git a/src/unilab/algos/torch/offpolicy/__init__.py b/src/unilab/algos/torch/offpolicy/__init__.py index 96d6afd3e..9991471f6 100644 --- a/src/unilab/algos/torch/offpolicy/__init__.py +++ b/src/unilab/algos/torch/offpolicy/__init__.py @@ -1,5 +1,10 @@ """Off-policy RL unified infrastructure.""" +from unilab.algos.torch.offpolicy.distributed import ( + DistributedLearnerHooks, + DistributedOffPolicyLearner, + validate_distributed_learner_capability, +) from unilab.algos.torch.offpolicy.multi_gpu_runner import MultiGPUOffPolicyRunner from unilab.algos.torch.offpolicy.runner import OffPolicyRunner from unilab.algos.torch.offpolicy.worker import off_policy_collector_fn @@ -9,5 +14,8 @@ "OffPolicyLogger", "OffPolicyRunner", "MultiGPUOffPolicyRunner", + "DistributedLearnerHooks", + "DistributedOffPolicyLearner", + "validate_distributed_learner_capability", "off_policy_collector_fn", ] diff --git a/src/unilab/algos/torch/offpolicy/distributed.py b/src/unilab/algos/torch/offpolicy/distributed.py new file mode 100644 index 000000000..759791bf9 --- /dev/null +++ b/src/unilab/algos/torch/offpolicy/distributed.py @@ -0,0 +1,145 @@ +"""Distributed learner contract for off-policy multi-GPU training. + +The runner owns process, replay, and IPC orchestration. Algorithm learners own +which model states participate in distributed synchronization. 
Keeping that +boundary explicit avoids generic DDP wrappers in the hot path. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Protocol, cast + +MULTIGPU_SYNC_MODES = frozenset({"sync_sgd", "local_sgd"}) + + +class DistributedOffPolicyLearner(Protocol): + """Protocol implemented by learners that opt into multi-GPU off-policy training.""" + + supports_multi_gpu: bool + supports_multi_gpu_symmetry: bool + supported_multi_gpu_sync_modes: frozenset[str] + actor: Any + update_count: int + + def update_critic(self, batch: dict[str, Any]) -> dict[str, float]: ... + + def update_actor(self, batch: dict[str, Any]) -> dict[str, float]: ... + + def soft_update_target(self) -> None: ... + + def get_state_dict(self) -> dict[str, Any]: ... + + def sync_initial_parameters(self, src: int = 0) -> None: ... + + def average_distributed_parameters(self) -> None: ... + + +@dataclass(frozen=True, slots=True) +class DistributedLearnerHooks: + """Bound distributed hooks resolved once per learner worker.""" + + sync_initial_parameters: Callable[..., None] + average_distributed_parameters: Callable[[], None] + + +def _learner_name(learner_cls: type[Any]) -> str: + return getattr(learner_cls, "__name__", str(learner_cls)) + + +def normalize_distributed_sync_mode(mode: str) -> str: + """Return a validated distributed learner synchronization mode.""" + normalized = str(mode).strip().lower() + if normalized not in MULTIGPU_SYNC_MODES: + supported = ", ".join(sorted(MULTIGPU_SYNC_MODES)) + raise ValueError(f"training.multi_gpu_sync_mode must be one of: {supported}; got {mode!r}") + return normalized + + +def validate_distributed_learner_capability( + *, + learner_cls: type[Any], + algo_type: str, + learner_kwargs: dict[str, Any], + num_gpus: int, + sync_mode: str, +) -> None: + """Validate that a learner class has explicitly opted into multi-GPU training.""" + if int(num_gpus) <= 1: + return + + normalized_sync_mode = 
normalize_distributed_sync_mode(sync_mode) + + learner_name = _learner_name(learner_cls) + if not bool(learner_cls.__dict__.get("supports_multi_gpu", False)): + raise ValueError( + f"{algo_type} learner {learner_name} does not support training.num_gpus > 1" + ) + + supported_modes = set(getattr(learner_cls, "supported_multi_gpu_sync_modes", ())) + if normalized_sync_mode not in supported_modes: + supported = ", ".join(sorted(supported_modes)) or "" + raise ValueError( + f"{algo_type} learner {learner_name} does not support " + f"training.multi_gpu_sync_mode={sync_mode!r}; supported modes: {supported}" + ) + + if bool(learner_kwargs.get("use_symmetry", False)) and not bool( + getattr(learner_cls, "supports_multi_gpu_symmetry", False) + ): + raise ValueError( + "Off-policy symmetry augmentation does not support training.num_gpus > 1; " + "set training.num_gpus=1 or algo.use_symmetry=false" + ) + + sync_initial_parameters = getattr(learner_cls, "sync_initial_parameters", None) + if not callable(sync_initial_parameters): + raise ValueError( + f"{algo_type} learner {learner_name} must implement " + "sync_initial_parameters(src=0) for multi-GPU training" + ) + if normalized_sync_mode == "local_sgd": + average_distributed_parameters = getattr( + learner_cls, + "average_distributed_parameters", + None, + ) + if not callable(average_distributed_parameters): + raise ValueError( + f"{algo_type} learner {learner_name} must implement " + "average_distributed_parameters() for local_sgd multi-GPU training" + ) + + +def _noop_average_parameters() -> None: + return None + + +def resolve_distributed_learner_hooks( + learner: Any, + *, + sync_mode: str, +) -> DistributedLearnerHooks: + """Resolve distributed learner hooks once, outside the learner update loop.""" + sync_initial_parameters = getattr(learner, "sync_initial_parameters", None) + if not callable(sync_initial_parameters): + raise ValueError( + "Multi-GPU off-policy learner must implement sync_initial_parameters(src=0)" + ) 
+ sync_initial_parameters = cast(Callable[..., None], sync_initial_parameters) + + average_distributed_parameters = getattr(learner, "average_distributed_parameters", None) + average_parameters: Callable[[], None] + if sync_mode == "local_sgd": + if not callable(average_distributed_parameters): + raise ValueError( + "Multi-GPU local_sgd requires learner.average_distributed_parameters()" + ) + average_parameters = cast(Callable[[], None], average_distributed_parameters) + else: + average_parameters = _noop_average_parameters + + return DistributedLearnerHooks( + sync_initial_parameters=sync_initial_parameters, + average_distributed_parameters=average_parameters, + ) diff --git a/src/unilab/algos/torch/offpolicy/multi_gpu_runner.py b/src/unilab/algos/torch/offpolicy/multi_gpu_runner.py index 5e0d0826f..377b4f035 100644 --- a/src/unilab/algos/torch/offpolicy/multi_gpu_runner.py +++ b/src/unilab/algos/torch/offpolicy/multi_gpu_runner.py @@ -25,12 +25,18 @@ import torch.distributed as dist import torch.multiprocessing as tmp # torch.multiprocessing for spawn +from unilab.algos.torch.offpolicy.distributed import ( + normalize_distributed_sync_mode, + resolve_distributed_learner_hooks, + validate_distributed_learner_capability, +) from unilab.algos.torch.offpolicy.runner import ( OffPolicyRunner, build_offpolicy_sample_info, build_reward_comparison_metrics, compute_train_start_threshold, replay_buffer_ready_for_learning, + update_reward_stats_from_replay, ) from unilab.algos.torch.offpolicy.worker import off_policy_collector_fn from unilab.ipc import SharedObsNormStats, SharedWeightSync @@ -41,7 +47,6 @@ from unilab.training.seed import apply_training_seed, derive_worker_seed MULTIGPU_REPLAY_READY_POLL_SEC = 0.001 -MULTIGPU_SYNC_MODES = {"sync_sgd", "local_sgd"} class _CollectorDiedError(RuntimeError): @@ -56,11 +61,7 @@ def _find_free_port() -> int: def normalize_multi_gpu_sync_mode(mode: str) -> str: """Return a validated multi-GPU learner synchronization mode.""" - 
normalized = str(mode).strip().lower() - if normalized not in MULTIGPU_SYNC_MODES: - supported = ", ".join(sorted(MULTIGPU_SYNC_MODES)) - raise ValueError(f"training.multi_gpu_sync_mode must be one of: {supported}; got {mode!r}") - return normalized + return normalize_distributed_sync_mode(mode) def normalize_multi_gpu_sync_interval(interval: int) -> int: @@ -141,6 +142,12 @@ def _publish_obs_normalizer_stats(learner: Any, shared_obs_normalizer_stats: Any shared_obs_normalizer_stats.put((mean.detach().cpu().numpy(), std.detach().cpu().numpy())) +def _format_multi_gpu_algo_name(algo_type: str, world_size: int) -> str: + if algo_type == "flashsac": + return f"FlashSAC_x{world_size}GPU" + return f"Fast{algo_type.upper()}_x{world_size}GPU" + + def _learner_worker( rank: int, world_size: int, @@ -183,19 +190,23 @@ def _learner_worker( replay_buffer.device = device # 2. Create learner on this device - learner_kwargs = dict(learner_kwargs) - learner_kwargs["distributed_sync_mode"] = normalize_multi_gpu_sync_mode( + sync_mode = normalize_multi_gpu_sync_mode( str(runner_kwargs.get("multi_gpu_sync_mode", "local_sgd")) ) + learner_kwargs = dict(learner_kwargs) + learner_kwargs["distributed_sync_mode"] = sync_mode learner = learner_cls(device=device, world_size=world_size, **learner_kwargs) # 3. Broadcast rank-0 params so all workers start identically. 
- sync_initial_parameters = getattr(learner, "sync_initial_parameters", None) - if not callable(sync_initial_parameters): - raise ValueError( - "Multi-GPU off-policy learner must implement sync_initial_parameters(src=0)" - ) - sync_initial_parameters(src=0) + distributed_hooks = resolve_distributed_learner_hooks(learner, sync_mode=sync_mode) + distributed_hooks.sync_initial_parameters(src=0) + sync_reward_normalizer = getattr(learner, "sync_reward_normalizer", None) + if not callable(sync_reward_normalizer): + sync_reward_normalizer = None + has_reward_stats = ( + hasattr(learner, "update_reward_stats") + and getattr(learner, "reward_normalizer", None) is not None + ) # 4. Reconnect to the shared weight-sync buffer weight_sync = SharedWeightSync( @@ -216,9 +227,6 @@ def _learner_worker( obs_dim: int = runner_kwargs["obs_dim"] action_dim: int = runner_kwargs["action_dim"] logger_type: str = runner_kwargs.get("logger_type", "tensorboard") - sync_mode = normalize_multi_gpu_sync_mode( - str(runner_kwargs.get("multi_gpu_sync_mode", "local_sgd")) - ) sync_interval = normalize_multi_gpu_sync_interval( int(runner_kwargs.get("multi_gpu_sync_interval", 1)) ) @@ -244,7 +252,10 @@ def _learner_worker( if rank == 0: os.makedirs(log_dir, exist_ok=True) logger = OffPolicyLogger( - algo_name=f"Fast{str(runner_kwargs.get('algo_type', 'offpolicy')).upper()}_x{world_size}GPU", + algo_name=_format_multi_gpu_algo_name( + str(runner_kwargs.get("algo_type", "offpolicy")), + world_size, + ), max_iterations=max_iterations, num_envs=num_envs, env_name=env_name, @@ -276,6 +287,7 @@ def _learner_worker( write_read_ema = 0.0 last_buf_log = 0 prepared_tick: int | None = None + reward_stats_ptr = 0 # 7. 
Training loop for it in range(1, max_iterations + 1): @@ -349,6 +361,19 @@ def _learner_worker( # --- Training: each rank independently samples a different mini-batch --- iter_metrics: dict = defaultdict(list) ptr_before = int(replay_buffer.ptr[0]) if rank == 0 else 0 + replay_ptr = int(replay_buffer.ptr[0]) + if sync_reward_normalizer is None or rank == 0: + reward_stats_ptr = update_reward_stats_from_replay( + learner, + replay_buffer, + start_ptr=reward_stats_ptr, + end_ptr=replay_ptr, + num_envs=num_envs, + ) + else: + reward_stats_ptr = replay_ptr + if has_reward_stats and sync_reward_normalizer is not None: + sync_reward_normalizer(src=0) if prepared_tick != it: min_prepare_ptr = train_start_threshold if it == 1 else int(replay_buffer.ptr[0]) @@ -438,12 +463,7 @@ def _learner_worker( did_param_sync = False if should_param_sync: param_sync_start = time.perf_counter() - average_parameters = getattr(learner, "average_distributed_parameters", None) - if not callable(average_parameters): - raise ValueError( - "Multi-GPU local_sgd requires learner.average_distributed_parameters()" - ) - average_parameters() + distributed_hooks.average_distributed_parameters() param_sync_time = time.perf_counter() - param_sync_start did_param_sync = True @@ -565,16 +585,18 @@ class MultiGPUOffPolicyRunner(OffPolicyRunner): def validate_capabilities( *, algo_type: str, + learner_cls: Any, learner_kwargs: Dict[str, Any], num_gpus: int, + sync_mode: str = "local_sgd", ) -> None: - if num_gpus <= 1: - return - if algo_type == "sac" and bool(learner_kwargs.get("use_symmetry", False)): - raise ValueError( - "Off-policy symmetry augmentation does not support training.num_gpus > 1; " - "set training.num_gpus=1 or algo.use_symmetry=false" - ) + validate_distributed_learner_capability( + learner_cls=learner_cls, + algo_type=algo_type, + learner_kwargs=learner_kwargs, + num_gpus=num_gpus, + sync_mode=sync_mode, + ) def __init__( self, @@ -589,10 +611,13 @@ def __init__( 
multi_gpu_sync_interval: int = 1, **kwargs: Any, ) -> None: + normalized_sync_mode = normalize_multi_gpu_sync_mode(multi_gpu_sync_mode) self.validate_capabilities( algo_type=algo_type, + learner_cls=learner_cls, learner_kwargs=learner_kwargs, num_gpus=num_gpus, + sync_mode=normalized_sync_mode, ) super().__init__(learner=learner, env_name=env_name, algo_type=algo_type, **kwargs) self.num_gpus = num_gpus @@ -600,7 +625,7 @@ def __init__( self._learner_cls = learner_cls self._learner_kwargs = learner_kwargs self.distributed_backend = distributed_backend - self.multi_gpu_sync_mode = normalize_multi_gpu_sync_mode(multi_gpu_sync_mode) + self.multi_gpu_sync_mode = normalized_sync_mode self.multi_gpu_sync_interval = normalize_multi_gpu_sync_interval( int(multi_gpu_sync_interval) ) diff --git a/src/unilab/algos/torch/offpolicy/runner.py b/src/unilab/algos/torch/offpolicy/runner.py index b0c3fe533..e9bd7a5fa 100644 --- a/src/unilab/algos/torch/offpolicy/runner.py +++ b/src/unilab/algos/torch/offpolicy/runner.py @@ -87,6 +87,66 @@ def build_reward_comparison_metrics( return {"mean_ep100": float(reward_history[-1])} +def read_recent_replay_field( + replay_buffer: Any, + field_name: str, + start_ptr: int, + count: int, +) -> torch.Tensor: + idx = start_ptr % replay_buffer.capacity + + if hasattr(replay_buffer, field_name): + source = getattr(replay_buffer, field_name) + else: + packed_key = { + "rewards": "_rew_col", + "dones": "_done_col", + "truncated": "_trunc_col", + }[field_name] + source = replay_buffer._storage[:, getattr(replay_buffer, packed_key)] + + if idx + count <= replay_buffer.capacity: + return cast(torch.Tensor, source[idx : idx + count].clone()) + + split = replay_buffer.capacity - idx + return cast(torch.Tensor, torch.cat([source[idx:], source[: count - split]], dim=0).clone()) + + +def update_reward_stats_from_replay( + learner: Any, + replay_buffer: Any, + *, + start_ptr: int, + end_ptr: int, + num_envs: int, +) -> int: + if not hasattr(learner, 
"update_reward_stats"): + return end_ptr + if getattr(learner, "reward_normalizer", None) is None: + return end_ptr + + count = end_ptr - start_ptr + if count <= 0: + return end_ptr + if count > replay_buffer.capacity: + count = replay_buffer.capacity + start_ptr = end_ptr - count + if count % num_envs != 0: + count -= count % num_envs + start_ptr = end_ptr - count + if count <= 0: + return end_ptr + + rewards = read_recent_replay_field(replay_buffer, "rewards", start_ptr, count) + dones = read_recent_replay_field(replay_buffer, "dones", start_ptr, count) + num_steps = count // num_envs + learner.update_reward_stats( + rewards.view(num_steps, num_envs), + dones.view(num_steps, num_envs), + ) + return end_ptr + + class OffPolicyRunner(AsyncRunner): """Unified runner for SAC and TD3.""" @@ -179,50 +239,16 @@ def _sync_logger_replay_counters(logger, replay_buffer) -> None: def _read_recent_replay_field( replay_buffer, field_name: str, start_ptr: int, count: int ) -> torch.Tensor: - idx = start_ptr % replay_buffer.capacity - - if hasattr(replay_buffer, field_name): - source = getattr(replay_buffer, field_name) - else: - packed_key = { - "rewards": "_rew_col", - "dones": "_done_col", - "truncated": "_trunc_col", - }[field_name] - source = replay_buffer._storage[:, getattr(replay_buffer, packed_key)] - - if idx + count <= replay_buffer.capacity: - return cast(torch.Tensor, source[idx : idx + count].clone()) - - split = replay_buffer.capacity - idx - return cast(torch.Tensor, torch.cat([source[idx:], source[: count - split]], dim=0).clone()) + return read_recent_replay_field(replay_buffer, field_name, start_ptr, count) def _update_reward_stats_from_replay(self, replay_buffer, start_ptr: int, end_ptr: int) -> int: - if not hasattr(self.learner, "update_reward_stats"): - return end_ptr - if getattr(self.learner, "reward_normalizer", None) is None: - return end_ptr - - count = end_ptr - start_ptr - if count <= 0: - return end_ptr - if count > replay_buffer.capacity: - count 
= replay_buffer.capacity - start_ptr = end_ptr - count - if count % self.num_envs != 0: - count -= count % self.num_envs - start_ptr = end_ptr - count - if count <= 0: - return end_ptr - - rewards = self._read_recent_replay_field(replay_buffer, "rewards", start_ptr, count) - dones = self._read_recent_replay_field(replay_buffer, "dones", start_ptr, count) - num_steps = count // self.num_envs - self.learner.update_reward_stats( - rewards.view(num_steps, self.num_envs), - dones.view(num_steps, self.num_envs), + return update_reward_stats_from_replay( + self.learner, + replay_buffer, + start_ptr=start_ptr, + end_ptr=end_ptr, + num_envs=self.num_envs, ) - return end_ptr def learn( self, diff --git a/tests/algos/test_fast_sac_symmetry_contract.py b/tests/algos/test_fast_sac_symmetry_contract.py index cebae81a6..344daf5b0 100644 --- a/tests/algos/test_fast_sac_symmetry_contract.py +++ b/tests/algos/test_fast_sac_symmetry_contract.py @@ -194,6 +194,7 @@ def fake_all_reduce(tensor, op=None): def test_multi_gpu_offpolicy_runner_rejects_sac_symmetry_capability(): + from unilab.algos.torch.fast_sac.learner import FastSACLearner from unilab.algos.torch.offpolicy.multi_gpu_runner import MultiGPUOffPolicyRunner with pytest.raises( @@ -202,33 +203,108 @@ def test_multi_gpu_offpolicy_runner_rejects_sac_symmetry_capability(): ): MultiGPUOffPolicyRunner.validate_capabilities( algo_type="sac", + learner_cls=FastSACLearner, learner_kwargs={"use_symmetry": True}, num_gpus=2, ) @pytest.mark.parametrize( - ("algo_type", "learner_kwargs", "num_gpus"), + ("learner_kwargs", "num_gpus"), [ - ("sac", {"use_symmetry": False}, 2), - ("sac", {"use_symmetry": True}, 1), - ("td3", {"use_symmetry": True}, 2), + ({"use_symmetry": False}, 2), + ({"use_symmetry": True}, 1), ], ) def test_multi_gpu_offpolicy_runner_allows_supported_capabilities( - algo_type: str, learner_kwargs: dict[str, bool], num_gpus: int, ): + from unilab.algos.torch.fast_sac.learner import FastSACLearner from 
unilab.algos.torch.offpolicy.multi_gpu_runner import MultiGPUOffPolicyRunner MultiGPUOffPolicyRunner.validate_capabilities( - algo_type=algo_type, + algo_type="sac", + learner_cls=FastSACLearner, learner_kwargs=learner_kwargs, num_gpus=num_gpus, ) +def test_multi_gpu_offpolicy_runner_rejects_unsupported_learner_capability(): + from unilab.algos.torch.fast_td3.learner import FastTD3Learner + from unilab.algos.torch.offpolicy.multi_gpu_runner import MultiGPUOffPolicyRunner + + with pytest.raises(ValueError, match="FastTD3Learner.*does not support training.num_gpus"): + MultiGPUOffPolicyRunner.validate_capabilities( + algo_type="td3", + learner_cls=FastTD3Learner, + learner_kwargs={}, + num_gpus=2, + ) + + +def test_multi_gpu_offpolicy_runner_rejects_unsupported_sync_mode(): + from unilab.algos.torch.fast_sac.learner import FastSACLearner + from unilab.algos.torch.offpolicy.multi_gpu_runner import MultiGPUOffPolicyRunner + + with pytest.raises(ValueError, match="training.multi_gpu_sync_mode must be one of"): + MultiGPUOffPolicyRunner.validate_capabilities( + algo_type="sac", + learner_cls=FastSACLearner, + learner_kwargs={"use_symmetry": False}, + num_gpus=2, + sync_mode="bogus", + ) + + +def test_multi_gpu_offpolicy_runner_normalizes_sync_mode_before_capability_check(): + from unilab.algos.torch.fast_sac.learner import FastSACLearner + from unilab.algos.torch.offpolicy.multi_gpu_runner import MultiGPUOffPolicyRunner + + MultiGPUOffPolicyRunner.validate_capabilities( + algo_type="sac", + learner_cls=FastSACLearner, + learner_kwargs={"use_symmetry": False}, + num_gpus=2, + sync_mode="LOCAL_SGD", + ) + + +def test_multi_gpu_offpolicy_runner_requires_direct_learner_opt_in(): + from unilab.algos.torch.fast_sac.learner import FastSACLearner + from unilab.algos.torch.offpolicy.multi_gpu_runner import MultiGPUOffPolicyRunner + + class CustomSAC(FastSACLearner): + pass + + with pytest.raises(ValueError, match="CustomSAC.*does not support training.num_gpus"): + 
MultiGPUOffPolicyRunner.validate_capabilities( + algo_type="custom_sac", + learner_cls=CustomSAC, + learner_kwargs={"use_symmetry": False}, + num_gpus=2, + ) + + +def test_multi_gpu_offpolicy_runner_rejects_missing_distributed_hooks_before_spawn(): + from unilab.algos.torch.offpolicy.multi_gpu_runner import MultiGPUOffPolicyRunner + + class IncompleteLearner: + supports_multi_gpu = True + supports_multi_gpu_symmetry = False + supported_multi_gpu_sync_modes = frozenset({"local_sgd"}) + + with pytest.raises(ValueError, match="IncompleteLearner.*sync_initial_parameters"): + MultiGPUOffPolicyRunner.validate_capabilities( + algo_type="incomplete", + learner_cls=IncompleteLearner, + learner_kwargs={}, + num_gpus=2, + sync_mode="local_sgd", + ) + + def test_fast_sac_local_sgd_skips_per_update_gradient_all_reduce( monkeypatch: pytest.MonkeyPatch, ): diff --git a/tests/algos/test_flash_sac_learner.py b/tests/algos/test_flash_sac_learner.py index 2d033b789..1df91d949 100644 --- a/tests/algos/test_flash_sac_learner.py +++ b/tests/algos/test_flash_sac_learner.py @@ -10,6 +10,7 @@ from unilab.algos.torch.flash_sac.learner import FlashSACLearner, RewardNormalizer from unilab.algos.torch.flash_sac.update import compute_categorical_td_target +from unilab.algos.torch.offpolicy.distributed import validate_distributed_learner_capability def _make_batch(batch_size: int = 32) -> dict[str, torch.Tensor]: @@ -33,6 +34,23 @@ def _make_batch(batch_size: int = 32) -> dict[str, torch.Tensor]: } +def _make_small_learner(**kwargs: Any) -> FlashSACLearner: + defaults = { + "obs_dim": 4, + "action_dim": 2, + "critic_obs_dim": 6, + "actor_hidden_dim": 8, + "critic_hidden_dim": 8, + "actor_num_blocks": 1, + "critic_num_blocks": 1, + "num_atoms": 5, + "device": "cpu", + "use_compile": False, + } + defaults.update(kwargs) + return FlashSACLearner(**defaults) + + def test_flashsac_learner_exposes_expected_dims(): learner = FlashSACLearner(obs_dim=98, action_dim=29, critic_obs_dim=101, device="cpu") 
@@ -41,6 +59,106 @@ def test_flashsac_learner_exposes_expected_dims(): assert learner.action_dim == 29 +def test_flashsac_learner_declares_multi_gpu_contract() -> None: + validate_distributed_learner_capability( + learner_cls=FlashSACLearner, + algo_type="flashsac", + learner_kwargs={}, + num_gpus=2, + sync_mode="LOCAL_SGD", + ) + + learner = _make_small_learner(world_size=2, distributed_sync_mode="LOCAL_SGD") + + assert learner.supports_multi_gpu is True + assert learner.supported_multi_gpu_sync_modes == frozenset({"sync_sgd", "local_sgd"}) + assert learner.world_size == 2 + assert learner.distributed_sync_mode == "local_sgd" + + +def test_flashsac_parameter_sync_tensors_include_temperature_and_persistent_buffers() -> None: + learner = _make_small_learner() + tensors = learner._parameter_sync_tensors() + ptrs = {tensor.data_ptr() for tensor in tensors} + + assert learner.temperature.log_temp.data_ptr() in ptrs + assert learner.actor.embedder.norm.running_mean.data_ptr() in ptrs + assert learner.critic.embedder.norm.running_mean.data_ptr() in ptrs + assert learner.target_critic.embedder.norm.running_mean.data_ptr() in ptrs + assert learner.critic.predictor.support.data_ptr() in ptrs + + +def test_flashsac_reduce_gradients_averages_flat_gradient_payload(monkeypatch) -> None: + learner = _make_small_learner(world_size=2, distributed_sync_mode="sync_sgd") + for param in learner.actor.parameters(): + param.grad = torch.ones_like(param) + + calls = [] + + def fake_all_reduce(tensor: torch.Tensor, op=None) -> None: + del op + calls.append(tensor.numel()) + tensor.mul_(4.0) + + monkeypatch.setattr("unilab.algos.torch.flash_sac.learner.dist.all_reduce", fake_all_reduce) + + assert learner._reduce_gradients(learner.actor) is True + + assert calls == [sum(param.numel() for param in learner.actor.parameters())] + for param in learner.actor.parameters(): + assert param.grad is not None + torch.testing.assert_close(param.grad, torch.full_like(param.grad, 2.0)) + + +def 
test_flashsac_reduce_gradients_reports_nonfinite_payload(monkeypatch) -> None: + learner = _make_small_learner(world_size=2, distributed_sync_mode="sync_sgd") + first_param = next(learner.actor.parameters()) + first_param.grad = torch.ones_like(first_param) + + def fake_all_reduce(tensor: torch.Tensor, op=None) -> None: + del op + tensor[0] = float("inf") + + monkeypatch.setattr("unilab.algos.torch.flash_sac.learner.dist.all_reduce", fake_all_reduce) + + assert learner._reduce_gradients(learner.actor) is False + assert first_param.grad is not None + torch.testing.assert_close(first_param.grad, torch.ones_like(first_param)) + + +def test_flashsac_obs_normalizer_uses_cross_rank_moments(monkeypatch) -> None: + learner = _make_small_learner(obs_normalization=True, world_size=2) + + monkeypatch.setattr("unilab.algos.torch.flash_sac.learner.dist.is_available", lambda: True) + monkeypatch.setattr("unilab.algos.torch.flash_sac.learner.dist.is_initialized", lambda: True) + + def fake_all_reduce(payload: torch.Tensor, op=None) -> None: + del op + obs_dim = 4 + payload[:obs_dim] += torch.tensor([10.0, 20.0, 30.0, 40.0]) + payload[obs_dim : 2 * obs_dim] += torch.tensor([50.0, 200.0, 450.0, 800.0]) + payload[-1] += 2.0 + + monkeypatch.setattr("unilab.algos.torch.flash_sac.learner.dist.all_reduce", fake_all_reduce) + + learner._update_obs_normalizer( + torch.tensor( + [ + [1.0, 2.0, 3.0, 4.0], + [3.0, 4.0, 5.0, 6.0], + ] + ) + ) + + normalizer = learner.obs_normalizer + assert not isinstance(normalizer, torch.nn.Identity) + torch.testing.assert_close(normalizer.count, torch.tensor(4)) + torch.testing.assert_close( + normalizer.mean, + torch.tensor([3.5, 6.5, 9.5, 12.5]), + ) + + def test_flashsac_compile_targets_training_hot_paths(monkeypatch) -> None: calls: list[tuple[str, dict[str, Any]]] = [] diff --git a/tests/algos/test_offpolicy_double_buffer_runner.py b/tests/algos/test_offpolicy_double_buffer_runner.py index d94c6cc65..ad8b538e0 100644 --- 
a/tests/algos/test_offpolicy_double_buffer_runner.py +++ b/tests/algos/test_offpolicy_double_buffer_runner.py @@ -169,6 +169,10 @@ def close(self): pass class _FakeLearner: + supports_multi_gpu = True + supports_multi_gpu_symmetry = False + supported_multi_gpu_sync_modes = frozenset({"sync_sgd", "local_sgd"}) + class actor: @staticmethod def state_dict(): @@ -180,6 +184,12 @@ def __init__(self, *args, **kwargs): del args self.kwargs = kwargs + def sync_initial_parameters(self, src=0): + del src + + def average_distributed_parameters(self): + pass + class _FakeRunner: def __init__(self, *args, **kwargs): self.kwargs = kwargs @@ -464,15 +474,90 @@ def __init__(self, *args, **kwargs): assert runner.kwargs["device"] == device -def test_flashsac_double_buffer_multi_gpu_rejected(): +def test_flashsac_double_buffer_multi_gpu_dispatches_to_multi_gpu_runner( + monkeypatch: pytest.MonkeyPatch, +): + import gymnasium as gym + + import unilab.algos.torch.flash_sac.double_buffer as flash_db_mod + import unilab.algos.torch.offpolicy.multi_gpu_runner as mg_mod + cfg = _offpolicy_cfg( [ "algo=flashsac", "training.device=cuda", "training.num_gpus=2", + "training.multi_gpu_sync_mode=sync_sgd", + "training.multi_gpu_sync_interval=3", + "algo.obs_normalization=true", + ] + ) + + class _FakeEnv: + obs_groups_spec = {"obs": 4, "critic": 6} + action_space = gym.spaces.Box(-1.0, 1.0, shape=(2,)) + + def close(self): + pass + + class _FakeLearner: + supports_multi_gpu = True + supports_multi_gpu_symmetry = False + supported_multi_gpu_sync_modes = frozenset({"sync_sgd", "local_sgd"}) + + class actor: + @staticmethod + def state_dict(): + return {"w": MagicMock(shape=(4,))} + + update_count = 0 + + def __init__(self, *args, **kwargs): + del args + self.kwargs = kwargs + + def sync_initial_parameters(self, src=0): + del src + + def average_distributed_parameters(self): + pass + + class _FakeRunner: + def __init__(self, *args, **kwargs): + del args + self.kwargs = kwargs + + 
monkeypatch.setattr(flash_db_mod, "ensure_registries", lambda: None) + monkeypatch.setattr(flash_db_mod, "create_env", lambda *args, **kwargs: _FakeEnv()) + monkeypatch.setattr(flash_db_mod, "FlashSACLearner", _FakeLearner) + monkeypatch.setattr(mg_mod, "MultiGPUOffPolicyRunner", _FakeRunner) + + runner = _offpolicy().build_runner("flashsac", cfg) + + assert isinstance(runner, _FakeRunner) + assert runner.kwargs["algo_type"] == "flashsac" + assert runner.kwargs["num_gpus"] == 2 + assert runner.kwargs["multi_gpu_sync_mode"] == "sync_sgd" + assert runner.kwargs["multi_gpu_sync_interval"] == 3 + assert runner.kwargs["obs_normalization"] is True + assert runner.kwargs["learner"].kwargs["obs_normalization"] is True + assert runner.kwargs["learner_kwargs"]["obs_normalization"] is True + assert runner.kwargs["actor_kwargs"] == { + "actor_num_blocks": 2, + "actor_noise_zeta_mu": 2.0, + "actor_noise_zeta_max": 16, + } + + +def test_flashsac_double_buffer_multi_gpu_requires_cuda_device(): + cfg = _offpolicy_cfg( + [ + "algo=flashsac", + "training.device=cpu", + "training.num_gpus=2", ] ) - with pytest.raises(ValueError, match="Only SAC supports training.num_gpus > 1"): + with pytest.raises(ValueError, match="requires a CUDA device"): _offpolicy().build_runner("flashsac", cfg) diff --git a/tests/algos/test_offpolicy_runner_unit.py b/tests/algos/test_offpolicy_runner_unit.py index 5062da5ce..952fd653c 100644 --- a/tests/algos/test_offpolicy_runner_unit.py +++ b/tests/algos/test_offpolicy_runner_unit.py @@ -34,6 +34,9 @@ def parameters(self): class _FakeLearner: last_instance: "_FakeLearner | None" = None + supports_multi_gpu = True + supports_multi_gpu_symmetry = False + supported_multi_gpu_sync_modes = frozenset({"sync_sgd", "local_sgd"}) def __init__(self, *args, **kwargs) -> None: del args @@ -70,6 +73,24 @@ def get_state_dict(self) -> dict[str, int]: return {"update_count": self.update_count} +class _RewardStatsLearner(_FakeLearner): + sync_reward_calls = 0 + 
reward_update_calls = 0 + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.reward_normalizer = object() + + def update_reward_stats(self, rewards: torch.Tensor, dones: torch.Tensor) -> None: + del rewards, dones + self.reward_update_calls += 1 + + def sync_reward_normalizer(self, src: int = 0) -> None: + del src + self.sync_reward_calls += 1 + type(self).sync_reward_calls += 1 + + class _FakeReplayBuffer: last_instance: "_FakeReplayBuffer | None" = None @@ -295,6 +316,8 @@ def _reset_fakes() -> None: _FakeWeightSync.last_instance = None _FakeLogger.last_instance = None _FakeLearner.last_instance = None + _RewardStatsLearner.sync_reward_calls = 0 + _RewardStatsLearner.reward_update_calls = 0 @pytest.mark.parametrize( @@ -1472,6 +1495,129 @@ def close(self) -> None: assert learner.average_parameter_calls == 1 +@pytest.mark.parametrize(("rank", "expected_reward_updates"), [(0, 1), (1, 0)]) +def test_multi_gpu_reward_stats_rank0_scans_replay_then_all_ranks_sync( + monkeypatch: pytest.MonkeyPatch, + tmp_path, + rank: int, + expected_reward_updates: int, +) -> None: + replay_buffer = _FakeReplayBuffer(capacity=16, obs_dim=4, action_dim=2, device="cpu") + replay_buffer.size[0] = 16 + replay_buffer.ptr[0] = 16 + replay_buffer._rew_col = 0 + replay_buffer._done_col = 1 + replay_buffer._storage[:, 0] = torch.arange(16, dtype=torch.float32) + logger = _FakeLogger() + + class _StopEvent: + def is_set(self) -> bool: + return False + + class _ReadyQueue: + def get(self, timeout: float | None = None) -> int: + del timeout + return 1 + + class _DoneQueue: + def put(self, item: int, timeout: float | None = None) -> None: + del item, timeout + + class _FakePipeline: + def __init__(self, *args, **kwargs) -> None: + del args, kwargs + self.last_incremental_h2d_time_s = 0.0 + + def start_prepare( + self, tick_id: int, sample_count: int, min_snapshot_ptr=None, **kwargs + ) -> bool: + del tick_id, sample_count, min_snapshot_ptr, kwargs + return 
True + + def batch_ready(self, tick_id: int, sample_count: int) -> bool: + del tick_id, sample_count + return True + + def sample_large_batch(self, tick_id: int, sample_count: int) -> dict[str, torch.Tensor]: + del tick_id + return { + "obs": torch.zeros(sample_count, 4), + "actions": torch.zeros(sample_count, 2), + "rewards": torch.zeros(sample_count), + "next_obs": torch.zeros(sample_count, 4), + "dones": torch.zeros(sample_count), + "truncated": torch.zeros(sample_count), + "critic": torch.zeros(sample_count, 4), + "next_critic": torch.zeros(sample_count, 4), + } + + def after_tick(self) -> None: + pass + + def close(self) -> None: + pass + + monkeypatch.setattr(multi_gpu_runner_module.torch.cuda, "set_device", lambda rank: None) + monkeypatch.setattr(multi_gpu_runner_module.dist, "init_process_group", lambda *a, **k: None) + monkeypatch.setattr(multi_gpu_runner_module.dist, "barrier", lambda: None) + monkeypatch.setattr(multi_gpu_runner_module.dist, "destroy_process_group", lambda: None) + monkeypatch.setattr(multi_gpu_runner_module.torch, "save", lambda *a, **k: None) + monkeypatch.setattr(multi_gpu_runner_module, "SharedWeightSync", _FakeWeightSync) + monkeypatch.setattr(multi_gpu_runner_module, "OffPolicyLogger", lambda **kwargs: logger) + monkeypatch.setattr( + multi_gpu_runner_module, + "MultiGPUCPUPinnedReplayPipeline", + _FakePipeline, + ) + monkeypatch.setattr(multi_gpu_runner_module.time, "perf_counter", lambda: 0.0) + + multi_gpu_runner_module._learner_worker( + rank=rank, + world_size=2, + learner_cls=_RewardStatsLearner, + learner_kwargs={}, + runner_kwargs={ + "max_iterations": 1, + "save_interval": 0, + "log_dir": str(tmp_path), + "batch_size": 4, + "updates_per_step": 1, + "policy_frequency": 1, + "sync_collection": True, + "env_steps_per_sync": 1, + "env_name": "DummyEnv", + "num_envs": 4, + "obs_dim": 4, + "action_dim": 2, + "logger_type": "none", + "learning_starts": 0, + "seed": 1, + "distributed_backend": "nccl", + "multi_gpu_sync_mode": 
"sync_sgd", + "multi_gpu_sync_interval": 1, + "algo_type": "flashsac", + }, + replay_buffer=replay_buffer, + weight_sync_name="fake-weight-sync", + weight_sync_lock=None, + weight_param_shapes={"weight": torch.Size([1])}, + stop_event=_StopEvent(), + collection_ready_queue=_ReadyQueue(), + trainer_done_queue=_DoneQueue(), + metrics_queue=queue.Queue(), + collector_pack_request_queue=[queue.Queue(), queue.Queue()], + collector_pack_ready_queue=[queue.Queue(), queue.Queue()], + collector_pack_shared_slots=[[torch.zeros(1)], [torch.zeros(1)]], + master_port=29500, + ) + + learner = _RewardStatsLearner.last_instance + assert isinstance(learner, _RewardStatsLearner) + assert learner.reward_update_calls == expected_reward_updates + assert learner.sync_reward_calls == 1 + assert _RewardStatsLearner.sync_reward_calls == 1 + + def test_multi_gpu_batch_ready_wait_is_reported_as_collector_wait( monkeypatch: pytest.MonkeyPatch, tmp_path ) -> None: diff --git a/tests/ipc/test_async_runner.py b/tests/ipc/test_async_runner.py index 9c4b4c11a..4f7afd90b 100644 --- a/tests/ipc/test_async_runner.py +++ b/tests/ipc/test_async_runner.py @@ -10,6 +10,7 @@ import pytest +import unilab.ipc.async_runner as async_runner_module from unilab.ipc.async_runner import AsyncRunner from unilab.ipc.collector_error import format_collector_death @@ -46,6 +47,36 @@ def _make_runner(rl_cfg=None, **kwargs) -> _StubRunner: ) +class _InlineProcess: + """Process test double for contract checks that do not need OS spawn.""" + + def __init__(self, target, args=(), kwargs=None, daemon=None): + self._target = target + self._args = args + self._kwargs = kwargs or {} + self.daemon = daemon + self.exitcode = None + self.started = False + + def start(self) -> None: + self.started = True + try: + self._target(*self._args, **self._kwargs) + except BaseException: + self.exitcode = 1 + raise + self.exitcode = 0 + + def is_alive(self) -> bool: + return False + + def join(self, timeout=None) -> None: + return None + + 
def terminate(self) -> None: + self.exitcode = -signal.SIGTERM + + # --------------------------------------------------------------------------- # Initialisation # --------------------------------------------------------------------------- @@ -161,6 +192,7 @@ def _worker_wait_for_stop(stop_event) -> None: stop_event.wait(timeout=30) +@pytest.mark.slow def test_close_joins_running_collector(): """close() must signal the stop event and reap the collector process. @@ -193,16 +225,16 @@ def _noop_collector(stop_event) -> None: stop_event.wait(timeout=30) -def _collector_report_kwargs( +def _collector_record_kwargs( stop_event, - report_queue, + record, token: str, sim_backend: str = "missing", ) -> None: - report_queue.put({"sim_backend": sim_backend, "token": token}) - stop_event.wait(timeout=30) + record.append({"sim_backend": sim_backend, "token": token}) +@pytest.mark.slow def test_start_collector_spawns_process(): """_start_collector() must create and start a subprocess.""" r = _make_runner() @@ -212,20 +244,24 @@ def test_start_collector_spawns_process(): r.close() -def test_start_collector_does_not_merge_runner_runtime_fields(): +def test_start_collector_does_not_merge_runner_runtime_fields(monkeypatch): r = _make_runner(sim_backend="motrix") - report_queue = _SPAWN_CTX.Queue() - r._start_collector( - target_fn=_collector_report_kwargs, - kwargs={ - "stop_event": r._stop_event, - "report_queue": report_queue, - "token": "ok", - }, - ) - payload = report_queue.get(timeout=5) - assert payload == {"sim_backend": "missing", "token": "ok"} - r.close() + record = [] + monkeypatch.setattr(async_runner_module._SPAWN_CTX, "Process", _InlineProcess) + try: + r._start_collector( + target_fn=_collector_record_kwargs, + kwargs={ + "stop_event": r._stop_event, + "record": record, + "token": "ok", + }, + ) + assert record == [{"sim_backend": "missing", "token": "ok"}] + assert r._collector_process is not None + assert r._collector_process.exitcode == 0 + finally: + r.close() 
def test_format_collector_death_reports_shell_style_sigbus(): diff --git a/tests/scripts/test_train_scripts.py b/tests/scripts/test_train_scripts.py index 3f8a65f72..49464695c 100644 --- a/tests/scripts/test_train_scripts.py +++ b/tests/scripts/test_train_scripts.py @@ -2427,16 +2427,17 @@ def test_offpolicy_g1_rough_terrain_task_composes() -> None: assert cfg.training.sim_backend == "mujoco" -def test_offpolicy_flashsac_rejects_multi_gpu(): +def test_offpolicy_flashsac_multi_gpu_requires_cuda_device(): cfg = _offpolicy_cfg( [ "algo=flashsac", "task=flashsac/g1_walk_flat/mujoco", "training.num_gpus=2", + "training.device=cpu", ] ) - with pytest.raises(ValueError, match="Only SAC supports training.num_gpus > 1"): + with pytest.raises(ValueError, match="requires a CUDA device"): _offpolicy().build_runner("flashsac", cfg)