Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/unilab/algos/torch/appo/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,11 @@ def __init__(
sim_backend: str = "mujoco",
num_envs: int = 1024,
steps_per_env: int = 24,
num_workers: int = 1, # kept for API compat, but only 1 collector used
replay_queue_size: int = 3,
seed: int | None = None,
resume_path: str | None = None,
nan_guard_cfg: NanGuardCfg | None = None,
):
del num_workers
super().__init__(
env_name=env_name,
env_cfg_overrides=env_cfg_overrides,
Expand Down
93 changes: 71 additions & 22 deletions src/unilab/ipc/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,44 +104,67 @@ def add(
if self._critic_dim > 0 and (critic is None or next_critic is None):
raise ValueError("ReplayBuffer with critic_dim > 0 requires critic and next_critic")

parts = [
obs,
next_obs,
actions,
rewards.unsqueeze(1),
dones.unsqueeze(1),
truncated.unsqueeze(1),
]
if has_critic:
assert next_critic is not None
parts.extend([critic, next_critic])
row = torch.cat(parts, dim=1)

if idx + n <= self.capacity:
self._storage[idx : idx + n] = row
target = self._storage[idx : idx + n]
self._write_transition_rows(
target,
obs,
actions,
rewards,
next_obs,
dones,
truncated,
critic,
next_critic,
has_critic=has_critic,
)
self._patch_terminal_next_observations(
self._storage[idx : idx + n, self._nobs_sl],
target[:, self._nobs_sl],
terminal_mask,
terminal_next_obs,
self._storage[idx : idx + n, self._ncritic_sl] if has_critic else None,
target[:, self._ncritic_sl] if has_critic else None,
terminal_next_critic,
)
else:
split = self.capacity - idx
self._storage[idx:] = row[:split]
self._storage[: n - split] = row[split:]
first = self._storage[idx:]
second = self._storage[: n - split]
self._write_transition_rows(
first,
obs[:split],
actions[:split],
rewards[:split],
next_obs[:split],
dones[:split],
truncated[:split],
critic[:split] if critic is not None else None,
next_critic[:split] if next_critic is not None else None,
has_critic=has_critic,
)
self._write_transition_rows(
second,
obs[split:],
actions[split:],
rewards[split:],
next_obs[split:],
dones[split:],
truncated[split:],
critic[split:] if critic is not None else None,
next_critic[split:] if next_critic is not None else None,
has_critic=has_critic,
)
self._patch_terminal_next_observations(
self._storage[idx:, self._nobs_sl],
first[:, self._nobs_sl],
terminal_mask[:split] if terminal_mask is not None else None,
terminal_next_obs[:split] if terminal_next_obs is not None else None,
self._storage[idx:, self._ncritic_sl] if has_critic else None,
first[:, self._ncritic_sl] if has_critic else None,
terminal_next_critic[:split] if terminal_next_critic is not None else None,
)
self._patch_terminal_next_observations(
self._storage[: n - split, self._nobs_sl],
second[:, self._nobs_sl],
terminal_mask[split:] if terminal_mask is not None else None,
terminal_next_obs[split:] if terminal_next_obs is not None else None,
self._storage[: n - split, self._ncritic_sl] if has_critic else None,
second[:, self._ncritic_sl] if has_critic else None,
terminal_next_critic[split:] if terminal_next_critic is not None else None,
)

Expand All @@ -156,6 +179,32 @@ def add(
args={"batch_size": int(n), "device": self.device},
)

def _write_transition_rows(
self,
target,
obs,
actions,
rewards,
next_obs,
dones,
truncated,
critic,
next_critic,
*,
has_critic: bool,
) -> None:
target[:, self._obs_sl] = obs
target[:, self._nobs_sl] = next_obs
target[:, self._act_sl] = actions
target[:, self._rew_col] = rewards
target[:, self._done_col] = dones
target[:, self._trunc_col] = truncated
if has_critic:
assert critic is not None
assert next_critic is not None
target[:, self._critic_sl] = critic
target[:, self._ncritic_sl] = next_critic

@staticmethod
def _patch_terminal_next_observations(
target_next_obs,
Expand Down
32 changes: 32 additions & 0 deletions tests/ipc/test_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import multiprocessing as mp

import numpy as np
import pytest
import torch

from unilab.ipc.replay_buffer import ReplayBuffer
Expand Down Expand Up @@ -190,6 +191,37 @@ def test_add_stores_combined_dones_and_truncated_contract():
torch.testing.assert_close(buf._storage[:3, buf._trunc_col], torch.tensor([0.0, 1.0, 0.0]))


def test_add_writes_packed_columns_without_cat(monkeypatch: pytest.MonkeyPatch):
"""Collector hot path should not allocate a full concatenated transition batch."""
buf = ReplayBuffer(
capacity=8,
obs_dim=_OBS_DIM,
action_dim=_ACTION_DIM,
critic_dim=5,
device=_DEVICE,
)
obs = torch.randn(4, _OBS_DIM)
act = torch.randn(4, _ACTION_DIM)
rew = torch.randn(4)
nobs = torch.randn(4, _OBS_DIM)
done = torch.zeros(4)
trunc = torch.zeros(4)
critic = torch.randn(4, 5)
ncritic = torch.randn(4, 5)

def _fail_cat(*args, **kwargs):
del args, kwargs
raise AssertionError("ReplayBuffer.add should write packed columns directly")

monkeypatch.setattr(torch, "cat", _fail_cat)

buf.add(obs, act, rew, nobs, done, trunc, critic=critic, next_critic=ncritic)

torch.testing.assert_close(buf._storage[:4, buf._obs_sl], obs)
torch.testing.assert_close(buf._storage[:4, buf._act_sl], act)
torch.testing.assert_close(buf._storage[:4, buf._critic_sl], critic)


# ---------------------------------------------------------------------------
# Multiprocess test
# ---------------------------------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions tests/scripts/test_train_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,8 @@ def test_build_appo_runner_kwargs_forwards_sim_backend():
assert runner_kwargs["collector_device"] == "cpu"
assert runner_kwargs["num_envs"] == cfg.algo.num_envs
assert runner_kwargs["steps_per_env"] == cfg.algo.steps_per_env
assert "num_workers" not in runner_kwargs
assert "num_collectors" not in runner_kwargs
assert runner_kwargs["env_cfg_overrides"]["reward_config"]["scales"] == {}


Expand Down
Loading