Skip to content

Commit 8e7c1ab

Browse files
author
Ryan Partridge
committed
fix(buffer): Fixed warm method bug when num_envs=1.
Vectorized environments fail when given a single environment. Added a condition to ensure the warm method requires a minimum of 2 envs.
1 parent 1da3c8e commit 8e7c1ab

3 files changed

Lines changed: 18 additions & 4 deletions

File tree

tests/test_buffer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,10 +264,19 @@ def test_buffer_warm(self, replay_buffer: ReplayBuffer):
264264
assert len(replay_buffer) == 0
265265

266266
n_samples = 10
267-
model.buffer.warm(model, n_samples, 2)
267+
model.buffer.warm(model, n_samples)
268268

269269
assert len(model.buffer) >= n_samples
270270

271+
def test_buffer_warm_single_env(self, replay_buffer: ReplayBuffer):
272+
model = NeuroFlowCT("InvertedPendulum-v5", 8, 16, device=torch.device("cpu"))
273+
assert len(replay_buffer) == 0
274+
275+
n_samples = 10
276+
277+
with pytest.raises(ValueError):
278+
model.buffer.warm(model, n_samples, 1)
279+
271280
def test_add_multi(self, replay_buffer: ReplayBuffer):
272281
"""Test adding multiple experiences at once using add_multi."""
273282
# Create test data: batch of 5 experiences

velora/buffer/replay.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,12 @@ def warm(self, agent: "RLModuleAgent", n_samples: int, num_envs: int = 8) -> Non
9595
Parameters:
9696
agent (Any): the agent to generate samples with
9797
n_samples (int): the maximum number of samples to generate
98-
num_envs (int, optional): number of vectorized environments
98+
num_envs (int, optional): number of vectorized environments. Cannot
99+
be smaller than `2`
99100
"""
101+
if num_envs < 2:
102+
raise ValueError(f"'{num_envs=}' cannot be smaller than 2.")
103+
100104
envs = gym.make_vec(
101105
agent.env.spec.id,
102106
num_envs=num_envs,
@@ -115,6 +119,7 @@ def warm(self, agent: "RLModuleAgent", n_samples: int, num_envs: int = 8) -> Non
115119
dones = terminated | truncated
116120

117121
self.add_multi(states, actions, rewards, next_states, dones, hidden)
122+
118123
states = next_states
119124

120125
envs.close()

velora/models/nf/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def train(
321321
)
322322

323323
if warmup_steps > 0:
324-
self.buffer.warm(self, warmup_steps, 1 if warmup_steps < 8 else 8)
324+
self.buffer.warm(self, warmup_steps, 2 if warmup_steps < 8 else 8)
325325

326326
with TrainHandler(
327327
self, n_episodes, max_steps, log_freq, window_size, callbacks
@@ -716,7 +716,7 @@ def train(
716716
)
717717

718718
if warmup_steps > 0:
719-
self.buffer.warm(self, warmup_steps, 1 if warmup_steps < 8 else 8)
719+
self.buffer.warm(self, warmup_steps, 2 if warmup_steps < 8 else 8)
720720

721721
with TrainHandler(
722722
self, n_episodes, max_steps, log_freq, window_size, callbacks

0 commit comments

Comments
 (0)