 from tensordict.utils import NestedKey
 from torch.utils._pytree import tree_map
 from torchrl._extension import EXTENSION_WARNING
-from torchrl._utils import _replace_last, logger, RL_WARNINGS
+from torchrl._utils import _replace_last, logger, rl_warnings
 from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
 from torchrl.data.replay_buffers.utils import _auto_device, _is_int, unravel_index

@@ -373,7 +373,7 @@ class PrioritizedSampler(Sampler):
             device=cpu,
             is_shared=False)
         >>> print(info)
-        {'_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11,
+        {'priority_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11,
                1.e-11, 1.e-11], dtype=float32), 'index': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}

     .. note:: Using a :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer` can smoothen the
@@ -423,7 +423,7 @@ def __init__(
         self.dtype = dtype
         self._max_priority_within_buffer = max_priority_within_buffer
         self._init()
-        if RL_WARNINGS and SumSegmentTreeFp32 is None:
+        if rl_warnings() and SumSegmentTreeFp32 is None:
             logger.warning(EXTENSION_WARNING)

     def __repr__(self):
@@ -588,7 +588,7 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
         weight = torch.pow(weight / p_min, -self._beta)
         if storage.ndim > 1:
             index = unravel_index(index, storage.shape)
-        return index, {"_weight": weight}
+        return index, {"priority_weight": weight}

     def add(self, index: torch.Tensor | int) -> None:
         super().add(index)
@@ -2068,7 +2068,7 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
        episode [2, 2, 2, 2, 1, 1]
         >>> print("steps", sample["steps"].tolist())
         steps [1, 2, 0, 1, 1, 2]
-        >>> print("weight", info["_weight"].tolist())
+        >>> print("weight", info["priority_weight"].tolist())
         weight [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
         >>> priority = torch.tensor([0,3,3,0,0,0,1,1,1])
         >>> rb.update_priority(torch.arange(0,9,1), priority=priority)
@@ -2077,7 +2077,7 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
        episode [2, 2, 2, 2, 2, 2]
         >>> print("steps", sample["steps"].tolist())
         steps [1, 2, 0, 1, 0, 1]
-        >>> print("weight", info["_weight"].tolist())
+        >>> print("weight", info["priority_weight"].tolist())
         weight [9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06]
     """

@@ -2294,15 +2294,19 @@ def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict]
         if isinstance(starts, tuple):
             starts = torch.stack(starts, -1)
             # starts = torch.as_tensor(starts, device=lengths.device)
-        info["_weight"] = torch.as_tensor(info["_weight"], device=lengths.device)
+        info["priority_weight"] = torch.as_tensor(
+            info["priority_weight"], device=lengths.device
+        )

         # extends starting indices of each slice with sequence_length to get indices of all steps
         index = self._tensor_slices_from_startend(
             seq_length, starts, storage_length=storage.shape[0]
         )

         # repeat the weight of each slice to match the number of steps
-        info["_weight"] = torch.repeat_interleave(info["_weight"], seq_length)
+        info["priority_weight"] = torch.repeat_interleave(
+            info["priority_weight"], seq_length
+        )

         if self.truncated_key is not None:
             # following logics borrowed from SliceSampler
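Since this commit renames the key under which the sampler returns its importance-sampling weights, any caller that reads the info dict must switch from `info["_weight"]` to `info["priority_weight"]`. A minimal sketch of the new usage follows; the buffer size, `alpha`/`beta` values, stored data, and the updated priorities are illustrative choices, not part of the commit:

```python
import torch

from torchrl.data.replay_buffers import LazyTensorStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import PrioritizedSampler

# A small prioritized buffer; sample() now exposes the
# importance-sampling weights under "priority_weight".
rb = ReplayBuffer(
    storage=LazyTensorStorage(100),
    sampler=PrioritizedSampler(max_capacity=100, alpha=0.7, beta=0.9),
)
rb.extend(torch.arange(10))

data, info = rb.sample(batch_size=5, return_info=True)
weights = info["priority_weight"]  # previously info["_weight"]
indices = info["index"]

# Priority updates are unchanged by this commit; the values here
# are placeholders for e.g. a TD-error-based priority.
rb.update_priority(torch.as_tensor(indices), priority=torch.full((5,), 2.0))
```

Note that with a plain `ReplayBuffer`, the info entries may come back as NumPy arrays rather than tensors, as in the `PrioritizedSampler` docstring example above; the key rename applies either way.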