@@ -4514,6 +4514,61 @@ def test_compressed_storage_memory_efficiency(self):
45144514 ), f"Compression ratio { compression_ratio } is too low"
45154515
45164516
4517+ class TestRBLazyInit :
4518+ def test_lazy_init (self ):
4519+ def transform (td ):
4520+ return td
4521+
4522+ rb = ReplayBuffer (
4523+ storage = partial (ListStorage ),
4524+ writer = partial (RoundRobinWriter ),
4525+ sampler = partial (RandomSampler ),
4526+ transform_factory = lambda : transform ,
4527+ )
4528+ assert not rb .initialized
4529+ assert not hasattr (rb , "_storage" )
4530+ assert rb ._init_storage is not None
4531+ assert not hasattr (rb , "_sampler" )
4532+ assert rb ._init_sampler is not None
4533+ assert not hasattr (rb , "_writer" )
4534+ assert rb ._init_writer is not None
4535+ rb .extend (TensorDict (batch_size = [2 ]))
4536+ assert rb .initialized
4537+ assert rb ._storage is not None
4538+ assert rb ._init_storage is None
4539+ assert rb ._sampler is not None
4540+ assert rb ._init_sampler is None
4541+ assert rb ._writer is not None
4542+ assert rb ._init_writer is None
4543+
4544+ rb = ReplayBuffer (
4545+ storage = partial (ListStorage ),
4546+ writer = partial (RoundRobinWriter ),
4547+ sampler = partial (RandomSampler ),
4548+ )
4549+ assert rb .initialized
4550+ assert rb ._storage is not None
4551+ assert rb ._init_storage is None
4552+ assert rb ._sampler is not None
4553+ assert rb ._init_sampler is None
4554+ assert rb ._writer is not None
4555+ assert rb ._init_writer is None
4556+
4557+ rb = ReplayBuffer (
4558+ storage = partial (ListStorage ),
4559+ writer = partial (RoundRobinWriter ),
4560+ sampler = partial (RandomSampler ),
4561+ delayed_init = False ,
4562+ )
4563+ assert rb .initialized
4564+ assert rb ._storage is not None
4565+ assert rb ._init_storage is None
4566+ assert rb ._sampler is not None
4567+ assert rb ._init_sampler is None
4568+ assert rb ._writer is not None
4569+ assert rb ._init_writer is None
4570+
4571+
45174572if __name__ == "__main__" :
45184573 args , unknown = argparse .ArgumentParser ().parse_known_args ()
45194574 pytest .main ([__file__ , "--capture" , "no" , "--exitfirst" ] + unknown )
0 commit comments