
Commit 257de0c

refactor(tweaks): Minor changes to DDPG and buffer methods.
- Removed 'flatten()' from the action return; it can cause issues with vector environments in the future.
- Changed 'buffer.push()' -> 'buffer.add()' for better UX.
- Changed 'state.unsqueeze(0)' in the 'predict' method to a conditional.
- Added a 'pragma: no cover' comment to 'except ImportError'; required for Python 3.11 coverage compatibility and not something that needs a dedicated test when running the suite under 'tox'.
1 parent 14c57e1 commit 257de0c
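To make the behavioural tweaks concrete, here is a minimal, hypothetical sketch of the conditional `unsqueeze(0)` and the dropped `flatten()` described above. It is not the actual `LiquidDDPG.predict` implementation; the `actor` module and the `predict` helper are stand-ins for illustration only.

```python
import torch

def predict(actor: torch.nn.Module, state: torch.Tensor) -> torch.Tensor:
    # Only add a batch dimension for a single, unbatched state so batched
    # inputs from vector environments keep their existing batch dimension.
    if state.dim() == 1:
        state = state.unsqueeze(0)
    # No .flatten() on the action: the batch shape is preserved.
    return actor(state)

actor = torch.nn.Linear(4, 2)
predict(actor, torch.zeros(4)).shape     # torch.Size([1, 2])
predict(actor, torch.zeros(8, 4)).shape  # torch.Size([8, 2])
```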

File tree: 8 files changed (+59, -49 lines)


docs/learn/customize/buffers.md

Lines changed: 10 additions & 10 deletions
@@ -26,9 +26,9 @@ buffer = ReplayBuffer(capacity=100_000, device=device)
 
 ???+ api "API Docs"
 
-    [`velora.buffer.ReplayBuffer.push(exp)`](../reference/buffer.md#velora.buffer.BufferBase.push)
+    [`velora.buffer.ReplayBuffer.add(exp)`](../reference/buffer.md#velora.buffer.BufferBase.add)
 
-To add an item, we `push()` a set of `Experience` to it:
+To add an item, we `add()` a set of `Experience` to it:
 
 ```python
 from velora.buffer import Experience
@@ -42,7 +42,7 @@ exp = Experience(
     done=False,
 )
 
-buffer.push(exp)
+buffer.add(exp)
 ```
 
 `Experience` is a simple dataclass that holds the information of a single environment `timestep`. We'll talk about them in more detail later.
@@ -126,7 +126,7 @@ exp = Experience(
     next_state=torch.zeros(state_dim, device=device),
     done=False,
 )
-buffer.push(exp)
+buffer.add(exp)
 
 # Get a batch
 batch = buffer.sample(batch_size=5)
@@ -158,11 +158,11 @@ buffer = RolloutBuffer(capacity=10, device=device)
 
 ???+ api "API Docs"
 
-    [`velora.buffer.RolloutBuffer.push(exp)`](../reference/buffer.md#velora.buffer.RolloutBuffer.push)
+    [`velora.buffer.RolloutBuffer.add(exp)`](../reference/buffer.md#velora.buffer.RolloutBuffer.add)
 
     [`velora.buffer.RolloutBuffer.empty()`](../reference/buffer.md#velora.buffer.RolloutBuffer.empty)
 
-To add an item, we `push()` a set of `Experience` to it:
+To add an item, we `add()` a set of `Experience` to it:
 
 ```python
 from velora.buffer import Experience
@@ -176,7 +176,7 @@ exp = Experience(
     done=False,
 )
 
-buffer.push(exp)
+buffer.add(exp)
 ```
 
 Once the buffer is full, we need to `empty` it before we can add new samples:
@@ -228,8 +228,8 @@ exp = Experience(
     done=False,
 )
 
-buffer.push(exp)
-# buffer.push(exp)  # BufferError
+buffer.add(exp)
+# buffer.add(exp)  # BufferError
 
 batch = buffer.sample()
 
@@ -374,7 +374,7 @@ for i_ep in range(n_episodes):
        done = terminated or truncated
 
        # Add it to the buffer
-       buffer.push(
+       buffer.add(
            Experience(state, action.item(), reward, next_state, done),
        )
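Put together, the documented `ReplayBuffer` workflow after this change looks roughly like the following sketch, composed from the snippets in the diff above. It assumes `ReplayBuffer` is importable from `velora.buffer` alongside `Experience`; field values are illustrative.

```python
import torch

from velora.buffer import Experience, ReplayBuffer

device = torch.device("cpu")
state_dim = 4

buffer = ReplayBuffer(capacity=100_000, device=device)

# Fill the buffer with the renamed add() method
for _ in range(10):
    buffer.add(
        Experience(
            state=torch.zeros(state_dim, device=device),
            action=1.0,
            reward=2.0,
            next_state=torch.zeros(state_dim, device=device),
            done=False,
        )
    )

# Get a batch
batch = buffer.sample(batch_size=5)
```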

tests/models/test_ddpg.py

Lines changed: 5 additions & 5 deletions
@@ -172,7 +172,7 @@ def test_train_step(self, ddpg: LiquidDDPG):
 
         # Create Experience object explicitly
         exp = Experience(state, action, reward, next_state, done)
-        ddpg.buffer.push(exp)
+        ddpg.buffer.add(exp)
 
         # Perform training step
         result = ddpg._train_step(batch_size, gamma)
@@ -190,7 +190,7 @@ def test_train_step_insufficient_buffer(self, ddpg: LiquidDDPG):
         for _ in range(batch_size - 1):
             state = torch.zeros(ddpg.state_dim)
             exp = Experience(state, 1.0, 2.0, state, False)
-            ddpg.buffer.push(exp)
+            ddpg.buffer.add(exp)
 
         # Should return None when buffer is insufficient
         result = ddpg._train_step(batch_size, gamma)
@@ -228,7 +228,7 @@ def test_save_load_with_buffer(self, ddpg: LiquidDDPG):
             next_state = torch.ones(ddpg.state_dim)
             done = i == 9
             exp = Experience(state, action, reward, next_state, done)
-            ddpg.buffer.push(exp)
+            ddpg.buffer.add(exp)
 
         with tempfile.TemporaryDirectory() as temp_dir:
             filepath = os.path.join(temp_dir, "model.pt")
@@ -415,7 +415,7 @@ def patched_init(self, dirname, **kwargs):
 
         # Mock buffer.push to prevent storing experiences
         with (
-            patch.object(ddpg.buffer, "push"),
+            patch.object(ddpg.buffer, "add"),
            patch.object(ddpg.buffer, "warm"),
         ):
             # Mock _train_step to avoid network operations
@@ -468,7 +468,7 @@ def test_early_stopping(self, ddpg: LiquidDDPG, env: gym.Env):
         # Mock necessary methods to avoid actual training
         with (
             patch.object(ddpg.buffer, "warm"),
-            patch.object(ddpg.buffer, "push"),
+            patch.object(ddpg.buffer, "add"),
             patch.object(ddpg, "_train_step", return_value=(0.1, 0.2)),
             patch.object(
                 ddpg,

tests/test_buffer.py

Lines changed: 20 additions & 20 deletions
@@ -86,7 +86,7 @@ def sample_experience(self) -> Experience:
     def filled_buffer(self, replay_buffer: ReplayBuffer) -> ReplayBuffer:
         """Fixture that returns a replay buffer with 10 experiences."""
         for i in range(10):
-            replay_buffer.push(
+            replay_buffer.add(
                 Experience(
                     state=torch.tensor([float(i), float(i + 1)]),
                     action=torch.tensor([i]),
@@ -109,7 +109,7 @@ def test_config(self, replay_buffer: ReplayBuffer):
     def test_push_experience(
         self, replay_buffer: ReplayBuffer, sample_experience: Experience
     ) -> None:
-        replay_buffer.push(sample_experience)
+        replay_buffer.add(sample_experience)
         assert len(replay_buffer) == 1
         assert isinstance(replay_buffer.buffer[0], Experience)
 
@@ -118,13 +118,13 @@ def test_buffer_capacity(
     ) -> None:
         # Fill buffer beyond capacity
         for _ in range(150):
-            replay_buffer.push(sample_experience)
+            replay_buffer.add(sample_experience)
         assert len(replay_buffer) == 100  # Should not exceed capacity
 
     def test_sample_insufficient_experiences(
         self, replay_buffer: ReplayBuffer, sample_experience: Experience
     ) -> None:
-        replay_buffer.push(sample_experience)
+        replay_buffer.add(sample_experience)
         with pytest.raises(ValueError):
             replay_buffer.sample(batch_size=2)
 
@@ -133,7 +133,7 @@ def test_sample_batch(
     ) -> None:
         # Fill buffer with multiple experiences
         for _ in range(10):
-            replay_buffer.push(sample_experience)
+            replay_buffer.add(sample_experience)
 
         batch_size = 5
         batch = replay_buffer.sample(batch_size)
@@ -149,7 +149,7 @@ def test_len_method(
         self, replay_buffer: ReplayBuffer, sample_experience: Experience
     ) -> None:
         assert len(replay_buffer) == 0
-        replay_buffer.push(sample_experience)
+        replay_buffer.add(sample_experience)
         assert len(replay_buffer) == 1
 
     def test_state_dict_empty_buffer(self, replay_buffer: ReplayBuffer) -> None:
@@ -276,7 +276,7 @@ def test_buffer_warm(self):
             next_state=torch.zeros(state_dim, device=device),
             done=False,
         )
-        buffer.push(exp)
+        buffer.add(exp)
 
         # Verify buffer length increases
         assert len(buffer) == n_samples + 1
@@ -318,7 +318,7 @@ def sample_experience(self) -> Experience:
     def filled_buffer(self, rollout_buffer: RolloutBuffer) -> RolloutBuffer:
         """Fixture that returns a filled rollout buffer with 3 experiences."""
         for i in range(3):
-            rollout_buffer.push(
+            rollout_buffer.add(
                 Experience(
                     state=torch.tensor([float(i), float(i + 1)]),
                     action=torch.tensor([i]),
@@ -341,7 +341,7 @@ def test_config(self, rollout_buffer: RolloutBuffer):
     def test_push_experience(
         self, rollout_buffer: RolloutBuffer, sample_experience: Experience
     ) -> None:
-        rollout_buffer.push(sample_experience)
+        rollout_buffer.add(sample_experience)
         assert len(rollout_buffer) == 1
         assert isinstance(rollout_buffer.buffer[0], Experience)
 
@@ -350,11 +350,11 @@ def test_buffer_capacity_error(
     ) -> None:
         # Fill buffer to capacity
         for _ in range(5):
-            rollout_buffer.push(sample_experience)
+            rollout_buffer.add(sample_experience)
 
         # Attempt to push when buffer is full
         with pytest.raises(BufferError):
-            rollout_buffer.push(sample_experience)
+            rollout_buffer.add(sample_experience)
 
     def test_sample_empty_buffer(self, rollout_buffer: RolloutBuffer) -> None:
         with pytest.raises(BufferError) as exc_info:
@@ -367,7 +367,7 @@ def test_sample_buffer(
         # Fill buffer with experiences
         num_experiences = 3
         for _ in range(num_experiences):
-            rollout_buffer.push(sample_experience)
+            rollout_buffer.add(sample_experience)
 
         batch = rollout_buffer.sample()
 
@@ -384,7 +384,7 @@ def test_clear_buffer(
     ) -> None:
         # Add some experiences
         for _ in range(3):
-            rollout_buffer.push(sample_experience)
+            rollout_buffer.add(sample_experience)
         assert len(rollout_buffer) == 3
 
         # Clear buffer
@@ -395,9 +395,9 @@ def test_len_method(
         self, rollout_buffer: RolloutBuffer, sample_experience: Experience
     ) -> None:
         assert len(rollout_buffer) == 0
-        rollout_buffer.push(sample_experience)
+        rollout_buffer.add(sample_experience)
         assert len(rollout_buffer) == 1
-        rollout_buffer.push(sample_experience)
+        rollout_buffer.add(sample_experience)
         assert len(rollout_buffer) == 2
         rollout_buffer.empty()
         assert len(rollout_buffer) == 0
@@ -514,7 +514,7 @@ def test_empty_after_save(self, filled_buffer: RolloutBuffer) -> None:
         assert len(loaded_buffer) == 3  # Original size before emptying
 
         # Add more experiences to emptied buffer
-        filled_buffer.push(
+        filled_buffer.add(
             Experience(
                 state=torch.tensor([10.0, 11.0]),
                 action=10.0,
@@ -534,7 +534,7 @@ def test_load_and_continue_filling(self) -> None:
         # Create and fill a buffer
        buffer = RolloutBuffer(capacity=5)
         for i in range(3):
-            buffer.push(
+            buffer.add(
                 Experience(
                     state=torch.tensor([float(i), float(i + 1)]),
                     action=torch.tensor([i]),
@@ -556,7 +556,7 @@ def test_load_and_continue_filling(self) -> None:
         assert len(loaded_buffer) == 3
 
         # Add more experiences
-        loaded_buffer.push(
+        loaded_buffer.add(
             Experience(
                 state=torch.tensor([10.0, 11.0]),
                 action=torch.tensor([10.0]),
@@ -569,7 +569,7 @@ def test_load_and_continue_filling(self) -> None:
         assert len(loaded_buffer) == 4
 
         # Try to add experiences up to capacity
-        loaded_buffer.push(
+        loaded_buffer.add(
             Experience(
                 state=torch.tensor([11.0, 12.0]),
                 action=torch.tensor([11.0]),
@@ -583,7 +583,7 @@ def test_load_and_continue_filling(self) -> None:
 
         # Should raise error on next push
         with pytest.raises(BufferError, match="Buffer full!"):
-            loaded_buffer.push(
+            loaded_buffer.add(
                 Experience(
                     state=torch.tensor([12.0, 13.0]),
                     action=torch.tensor([12.0]),

velora/buffer/base.py

Lines changed: 13 additions & 4 deletions
@@ -27,15 +27,24 @@ def __init__(self, capacity: int, *, device: torch.device | None = None) -> None
         self.buffer: Deque[Experience] = deque(maxlen=capacity)
         self.device = device
 
-    def push(self, exp: Experience) -> None:
+    def add(self, exp: Experience) -> None:
         """
-        Stores an experience in the buffer.
+        Adds a single experience to the buffer.
 
         Parameters:
-            exp (Experience): a single set of experience as an object
+            exp (Experience): a single set of experience
         """
         self.buffer.append(exp)
 
+    def add_multi(self, exp: List[Experience]) -> None:
+        """
+        Adds multiple experiences to the buffer.
+
+        Parameters:
+            exp (List[Experience]): a list of experience
+        """
+        self.buffer.extend(exp)
+
     def _batch(self, batch: List[Experience]) -> BatchExperience:
         """
         Helper method. Converts a `List[Experience]` into a `BatchExperience`.
@@ -163,7 +172,7 @@ def load(cls, filepath: str | Path) -> Self:
             data["next_states"],
             data["dones"],
         ):
-            buffer.push(
+            buffer.add(
                 Experience(
                     state=to_tensor(state, device=device),
                     action=to_tensor(action, device=device),
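The new `add_multi()` simply delegates to `deque.extend()`, so a whole list of experiences can be stored in one call rather than looping over `add()`. A minimal sketch, assuming the same `Experience` fields and `velora.buffer` import path used elsewhere in this commit, and using `ReplayBuffer`, which inherits the method from the base class:

```python
import torch

from velora.buffer import Experience, ReplayBuffer

buffer = ReplayBuffer(capacity=100, device=torch.device("cpu"))

# A short trajectory added in one call instead of one add() per item
trajectory = [
    Experience(
        state=torch.tensor([float(i), float(i + 1)]),
        action=float(i),
        reward=1.0,
        next_state=torch.tensor([float(i + 1), float(i + 2)]),
        done=False,
    )
    for i in range(3)
]
buffer.add_multi(trajectory)
assert len(buffer) == 3
```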

velora/buffer/replay.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 
 try:
     from typing import override
-except ImportError:
+except ImportError:  # pragma: no cover
     from typing_extensions import override  # pragma: no cover
 
 import gymnasium as gym
@@ -84,7 +84,7 @@ def warm(self, agent: RLAgent, env_name: str, n_samples: int) -> None:
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
 
-            self.push(Experience(state, action, reward, next_state, done))
+            self.add(Experience(state, action, reward, next_state, done))
 
            state = next_state

velora/buffer/rollout.py

Lines changed: 3 additions & 3 deletions
@@ -1,6 +1,6 @@
 try:
     from typing import override
-except ImportError:
+except ImportError:  # pragma: no cover
     from typing_extensions import override  # pragma: no cover
 
 import torch
@@ -36,7 +36,7 @@ def config(self) -> BufferConfig:
         return BufferConfig(type="RolloutBuffer", capacity=self.capacity)
 
     @override
-    def push(self, exp: Experience) -> None:
+    def add(self, exp: Experience) -> None:
         """
         Stores an experience in the buffer.
 
@@ -46,7 +46,7 @@ def push(self, exp: Experience) -> None:
         if len(self.buffer) == self.capacity:
             raise BufferError("Buffer full! Use the 'empty()' method first.")
 
-        super().push(exp)
+        super().add(exp)
 
     @override
     def sample(self) -> BatchExperience:
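Because the renamed `RolloutBuffer.add()` still raises the builtin `BufferError` once `capacity` is reached, callers must `empty()` the buffer before adding again. A small sketch of that behaviour, assuming the `velora.buffer` import path and positional `Experience` fields shown elsewhere in this commit:

```python
import torch

from velora.buffer import Experience, RolloutBuffer

buffer = RolloutBuffer(capacity=2, device=torch.device("cpu"))
exp = Experience(torch.zeros(2), 0.0, 1.0, torch.zeros(2), False)

buffer.add(exp)
buffer.add(exp)

try:
    buffer.add(exp)  # buffer is full, raises BufferError
except BufferError:
    buffer.empty()   # must empty before adding new samples

buffer.add(exp)
assert len(buffer) == 1
```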

velora/callbacks.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 try:
     from typing import override
-except ImportError:
+except ImportError:  # pragma: no cover
     from typing_extensions import override  # pragma: no cover
 
 import gymnasium as gym
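All three modules now share the same guarded import, with the `except` line itself excluded from coverage. `typing.override` exists on Python 3.12+, so the fallback branch never executes there; marking both lines keeps coverage reports consistent when the suite also runs on Python 3.11, for example via `tox`:

```python
try:
    from typing import override  # available on Python >= 3.12
except ImportError:  # pragma: no cover
    from typing_extensions import override  # pragma: no cover
```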
