
Commit 8570c25

[Feature] Weight loss outputs when using prioritized sampler (#3235)
1 parent 3d5dd1a commit 8570c25

File tree

15 files changed: +760 −243 lines changed

test/test_cost.py

Lines changed: 574 additions & 60 deletions
Large diffs are not rendered by default.

test/test_rb.py

Lines changed: 2 additions & 2 deletions
@@ -1804,9 +1804,9 @@ def test_batch_errors():
 
 @pytest.mark.skipif(not torchrl._utils.RL_WARNINGS, reason="RL_WARNINGS is not set")
 def test_add_warning():
-    from torchrl._utils import RL_WARNINGS
+    from torchrl._utils import rl_warnings
 
-    if not RL_WARNINGS:
+    if not rl_warnings():
         return
     rb = ReplayBuffer(storage=ListStorage(10), batch_size=3)
     with pytest.warns(

torchrl/_utils.py

Lines changed: 5 additions & 0 deletions
@@ -1062,3 +1062,8 @@ def merge_ray_runtime_env(ray_init_config: dict[str, Any]) -> dict[str, Any]:
         runtime_env["env_vars"] = dict(runtime_env["env_vars"])
 
     return ray_init_config
+
+
+def rl_warnings():
+    """Checks the status of the RL_WARNINGS env variable."""
+    return RL_WARNINGS
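
Because `rl_warnings()` reads the module-level `RL_WARNINGS` flag at call time rather than binding its value at import, patching `torchrl._utils.RL_WARNINGS` now affects every call site at once. A minimal sketch of that behavior (the initial value shown assumes warnings are enabled in the environment):

    >>> from torchrl import _utils
    >>> _utils.rl_warnings()
    True
    >>> _utils.RL_WARNINGS = False  # e.g. monkeypatched in a test
    >>> _utils.rl_warnings()        # every call site now sees the patched value
    False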

torchrl/collectors/collectors.py

Lines changed: 6 additions & 6 deletions
@@ -47,7 +47,7 @@
     compile_with_warmup,
     logger as torchrl_logger,
     prod,
-    RL_WARNINGS,
+    rl_warnings,
     VERBOSE,
 )
 from torchrl.collectors.utils import split_trajectories
@@ -1218,7 +1218,7 @@ def _setup_total_frames(self, total_frames: int, frames_per_batch: int) -> None:
             total_frames = float("inf")
         else:
             remainder = total_frames % frames_per_batch
-            if remainder != 0 and RL_WARNINGS:
+            if remainder != 0 and rl_warnings():
                 warnings.warn(
                     f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). "
                     f"This means {frames_per_batch - remainder} additional frames will be collected."
@@ -1238,7 +1238,7 @@ def _setup_init_random_frames(
         if (
             init_random_frames not in (-1, None, 0)
             and init_random_frames % frames_per_batch != 0
-            and RL_WARNINGS
+            and rl_warnings()
         ):
             warnings.warn(
                 f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), "
@@ -1261,7 +1261,7 @@ def _setup_postproc(self, postproc: Callable | None) -> None:
 
     def _setup_frames_per_batch(self, frames_per_batch: int) -> None:
         """Calculate and validate frames per batch."""
-        if frames_per_batch % self.n_env != 0 and RL_WARNINGS:
+        if frames_per_batch % self.n_env != 0 and rl_warnings():
             warnings.warn(
                 f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), "
                 f" this results in more frames_per_batch per iteration than requested"
@@ -2809,7 +2809,7 @@ def _setup_multi_total_frames(
             total_frames = float("inf")
         else:
             remainder = total_frames % total_frames_per_batch
-            if remainder != 0 and RL_WARNINGS:
+            if remainder != 0 and rl_warnings():
                 warnings.warn(
                     f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({total_frames_per_batch}). "
                     f"This means {total_frames_per_batch - remainder} additional frames will be collected. "
@@ -3741,7 +3741,7 @@ def update_policy_weights_(
     def frames_per_batch_worker(self, worker_idx: int | None) -> int:
         if worker_idx is not None and isinstance(self._frames_per_batch, Sequence):
             return self._frames_per_batch[worker_idx]
-        if self.requested_frames_per_batch % self.num_workers != 0 and RL_WARNINGS:
+        if self.requested_frames_per_batch % self.num_workers != 0 and rl_warnings():
             warnings.warn(
                 f"frames_per_batch {self.requested_frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers},"
                 f" this results in more frames_per_batch per iteration than requested."

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 6 additions & 6 deletions
@@ -41,7 +41,7 @@
 from torch import Tensor
 from torch.utils._pytree import tree_map
 
-from torchrl._utils import accept_remote_rref_udf_invocation, RL_WARNINGS
+from torchrl._utils import accept_remote_rref_udf_invocation, rl_warnings
 from torchrl.data.replay_buffers.samplers import (
     PrioritizedSampler,
     RandomSampler,
@@ -871,7 +871,7 @@ def add(self, data: Any) -> int:
             data = None
         if data is None:
             return torch.zeros((0, self._storage.ndim), dtype=torch.long)
-        if RL_WARNINGS and is_tensor_collection(data) and data.ndim:
+        if rl_warnings() and is_tensor_collection(data) and data.ndim:
             warnings.warn(
                 f"Using `add()` with a TensorDict that has batch_size={data.batch_size}. "
                 f"Use `extend()` to add multiple elements, or `add()` with a single element (batch_size=torch.Size([])). "
@@ -1319,14 +1319,14 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         >>> # get the info to find what the indices are
         >>> sample, info = rb.sample(5, return_info=True)
         >>> print(sample, info)
-        tensor([2, 7, 4, 3, 5]) {'_weight': array([1., 1., 1., 1., 1.], dtype=float32), 'index': array([2, 7, 4, 3, 5])}
+        tensor([2, 7, 4, 3, 5]) {'priority_weight': array([1., 1., 1., 1., 1.], dtype=float32), 'index': array([2, 7, 4, 3, 5])}
         >>> # update priority
         >>> priority = torch.ones(5) * 5
         >>> rb.update_priority(info["index"], priority)
         >>> # and now a new sample, the weights should be updated
         >>> sample, info = rb.sample(5, return_info=True)
         >>> print(sample, info)
-        tensor([2, 5, 2, 2, 5]) {'_weight': array([0.36278465, 0.36278465, 0.36278465, 0.36278465, 0.36278465],
+        tensor([2, 5, 2, 2, 5]) {'priority_weight': array([0.36278465, 0.36278465, 0.36278465, 0.36278465, 0.36278465],
             dtype=float32), 'index': array([2, 5, 2, 2, 5])}
 
     """
@@ -1861,7 +1861,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
         >>> print(sample)
         TensorDict(
             fields={
-                _weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
+                priority_weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
                 a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                 b: TensorDict(
                     fields={
@@ -1884,7 +1884,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
         >>> print(sample)
         TensorDict(
             fields={
-                _weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
+                priority_weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
                 a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                 b: TensorDict(
                     fields={
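
These docstring updates reflect a breaking rename: samplers now publish importance weights under `priority_weight` instead of `_weight`. Downstream code that reads the info dict by the old key needs updating; a hypothetical shim for code that must run against both pre- and post-rename torchrl releases could be:

    >>> sample, info = rb.sample(5, return_info=True)
    >>> weights = info.get("priority_weight", info.get("_weight"))  # tolerate pre-rename releases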

torchrl/data/replay_buffers/samplers.py

Lines changed: 12 additions & 8 deletions
@@ -21,7 +21,7 @@
 from tensordict.utils import NestedKey
 from torch.utils._pytree import tree_map
 from torchrl._extension import EXTENSION_WARNING
-from torchrl._utils import _replace_last, logger, RL_WARNINGS
+from torchrl._utils import _replace_last, logger, rl_warnings
 from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
 from torchrl.data.replay_buffers.utils import _auto_device, _is_int, unravel_index
 
@@ -373,7 +373,7 @@ class PrioritizedSampler(Sampler):
             device=cpu,
             is_shared=False)
         >>> print(info)
-        {'_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11,
+        {'priority_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11,
             1.e-11, 1.e-11], dtype=float32), 'index': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}
 
     .. note:: Using a :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer` can smoothen the
@@ -423,7 +423,7 @@ def __init__(
         self.dtype = dtype
         self._max_priority_within_buffer = max_priority_within_buffer
         self._init()
-        if RL_WARNINGS and SumSegmentTreeFp32 is None:
+        if rl_warnings() and SumSegmentTreeFp32 is None:
             logger.warning(EXTENSION_WARNING)
 
     def __repr__(self):
@@ -588,7 +588,7 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
         weight = torch.pow(weight / p_min, -self._beta)
         if storage.ndim > 1:
             index = unravel_index(index, storage.shape)
-        return index, {"_weight": weight}
+        return index, {"priority_weight": weight}
 
     def add(self, index: torch.Tensor | int) -> None:
         super().add(index)
@@ -2068,7 +2068,7 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
        episode [2, 2, 2, 2, 1, 1]
        >>> print("steps", sample["steps"].tolist())
        steps [1, 2, 0, 1, 1, 2]
-       >>> print("weight", info["_weight"].tolist())
+       >>> print("weight", info["priority_weight"].tolist())
        weight [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
        >>> priority = torch.tensor([0,3,3,0,0,0,1,1,1])
        >>> rb.update_priority(torch.arange(0,9,1), priority=priority)
@@ -2077,7 +2077,7 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
        episode [2, 2, 2, 2, 2, 2]
        >>> print("steps", sample["steps"].tolist())
        steps [1, 2, 0, 1, 0, 1]
-       >>> print("weight", info["_weight"].tolist())
+       >>> print("weight", info["priority_weight"].tolist())
        weight [9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06]
     """
 
@@ -2294,15 +2294,19 @@ def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict]
         if isinstance(starts, tuple):
             starts = torch.stack(starts, -1)
             # starts = torch.as_tensor(starts, device=lengths.device)
-        info["_weight"] = torch.as_tensor(info["_weight"], device=lengths.device)
+        info["priority_weight"] = torch.as_tensor(
+            info["priority_weight"], device=lengths.device
+        )
 
         # extends starting indices of each slice with sequence_length to get indices of all steps
         index = self._tensor_slices_from_startend(
             seq_length, starts, storage_length=storage.shape[0]
         )
 
         # repeat the weight of each slice to match the number of steps
-        info["_weight"] = torch.repeat_interleave(info["_weight"], seq_length)
+        info["priority_weight"] = torch.repeat_interleave(
+            info["priority_weight"], seq_length
+        )
 
         if self.truncated_key is not None:
             # following logics borrowed from SliceSampler
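
As the unchanged `sample()` lines above show, the values under the renamed key are still the usual importance-sampling corrections `(p_i / p_min) ** -beta`; only the key they travel under changes. A short sketch of reading them from a prioritized buffer (constructor arguments are illustrative):

    >>> import torch
    >>> from torchrl.data import ListStorage, PrioritizedSampler, ReplayBuffer
    >>> sampler = PrioritizedSampler(max_capacity=10, alpha=0.7, beta=0.9)
    >>> rb = ReplayBuffer(storage=ListStorage(10), sampler=sampler)
    >>> rb.extend(torch.arange(10))
    >>> sample, info = rb.sample(5, return_info=True)
    >>> sorted(info.keys())  # weights now live under "priority_weight"
    ['index', 'priority_weight']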

torchrl/objectives/common.py

Lines changed: 23 additions & 3 deletions
@@ -19,7 +19,7 @@
 from torch import nn
 from torch.nn import Parameter
 
-from torchrl._utils import RL_WARNINGS
+from torchrl._utils import rl_warnings
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules.tensordict_module.rnn import set_recurrent_mode
 from torchrl.objectives.utils import ValueEstimators
@@ -34,7 +34,7 @@
 def _updater_check_forward_prehook(module, *args, **kwargs):
     if (
         not all(module._has_update_associated.values())
-        and RL_WARNINGS
+        and rl_warnings()
         and not is_compiling()
     ):
         warnings.warn(
@@ -128,6 +128,7 @@ class _AcceptedKeys:
     tensor_keys: _AcceptedKeys
     _vmap_randomness = None
     default_value_estimator: ValueEstimators = None
+    use_prioritized_weights: str | bool = "auto"
 
     deterministic_sampling_mode: ExplorationType = ExplorationType.DETERMINISTIC
 
@@ -449,7 +450,7 @@ def __getattr__(self, item):
             params = params.data
         elif (
             not self._has_update_associated[item[7:-7]]
-            and RL_WARNINGS
+            and rl_warnings()
             and not is_compiling()
         ):
             # no updater associated
@@ -491,6 +492,25 @@ def reset(self) -> None:
         # mainly used for PPO with KL target
         pass
 
+    def _maybe_get_priority_weight(
+        self, tensordict: TensorDictBase
+    ) -> torch.Tensor | None:
+        """Extract priority weights from tensordict if prioritized replay is enabled.
+
+        Args:
+            tensordict (TensorDictBase): The input tensordict that may contain priority weights.
+
+        Returns:
+            torch.Tensor | None: The priority weights if available and enabled, None otherwise.
+        """
+        weights = None
+        if (
+            self.use_prioritized_weights in (True, "auto")
+            and self.tensor_keys.priority_weight in tensordict.keys()
+        ):
+            weights = tensordict.get(self.tensor_keys.priority_weight)
+        return weights
+
     def _reset_module_parameters(self, module_name, module):
         params_name = f"{module_name}_params"
         target_name = f"target_{module_name}_params"
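
The helper is meant to be called at the top of a loss method and forwarded to `_reduce` (torchrl's internal reduction helper, which this PR extends with a `weights` argument). A minimal sketch of the intended pattern in a loss subclass, with a hypothetical per-sample loss helper standing in for the real computation:

    def loss_value(self, tensordict):
        # None unless use_prioritized_weights is enabled and the sampler put
        # a "priority_weight" entry into the sampled tensordict
        weights = self._maybe_get_priority_weight(tensordict)
        td_error = self._per_sample_value_loss(tensordict)  # hypothetical helper
        return _reduce(td_error, self.reduction, weights=weights)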

torchrl/objectives/ddpg.py

Lines changed: 9 additions & 2 deletions
@@ -171,6 +171,7 @@ class _AcceptedKeys:
         reward: NestedKey = "reward"
         done: NestedKey = "done"
         terminated: NestedKey = "terminated"
+        priority_weight: NestedKey = "priority_weight"
 
     tensor_keys: _AcceptedKeys
     default_keys = _AcceptedKeys
@@ -202,11 +203,13 @@ def __init__(
         gamma: float | None = None,
         separate_losses: bool = False,
         reduction: str | None = None,
+        use_prioritized_weights: str | bool = "auto",
     ) -> None:
         self._in_keys = None
         if reduction is None:
             reduction = "mean"
         super().__init__()
+        self.use_prioritized_weights = use_prioritized_weights
         self.delay_actor = delay_actor
         self.delay_value = delay_value
 
@@ -268,6 +271,8 @@ def _set_in_keys(self):
             *self.value_network.in_keys,
             *[unravel_key(("next", key)) for key in self.value_network.in_keys],
         }
+        if self.use_prioritized_weights:
+            in_keys.add(unravel_key(self.tensor_keys.priority_weight))
         self._in_keys = sorted(in_keys, key=str)
 
     @property
@@ -316,6 +321,7 @@ def loss_actor(
         self,
         tensordict: TensorDictBase,
     ) -> [torch.Tensor, dict]:
+        weights = self._maybe_get_priority_weight(tensordict)
         td_copy = tensordict.select(
             *self.actor_in_keys, *self.value_exclusive_keys, strict=False
         ).detach()
@@ -325,7 +331,7 @@ def loss_actor(
         td_copy = self.value_network(td_copy)
         loss_actor = -td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)
         metadata = {}
-        loss_actor = _reduce(loss_actor, self.reduction)
+        loss_actor = _reduce(loss_actor, self.reduction, weights=weights)
         self._clear_weakrefs(
             tensordict,
             loss_actor,
@@ -340,6 +346,7 @@ def loss_value(
         self,
         tensordict: TensorDictBase,
     ) -> tuple[torch.Tensor, dict]:
+        weights = self._maybe_get_priority_weight(tensordict)
         # value loss
         td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach()
         with self.value_network_params.to_module(self.value_network):
@@ -372,7 +379,7 @@ def loss_value(
             "target_value_max": target_value.max(),
             "pred_value_max": pred_val.max(),
         }
-        loss_value = _reduce(loss_value, self.reduction)
+        loss_value = _reduce(loss_value, self.reduction, weights=weights)
         self._clear_weakrefs(
             tensordict,
             "value_network_params",
