Skip to content

Commit 28c3489

Browse files
committed
Update (base update)
[ghstack-poisoned]
1 parent 463785c commit 28c3489

File tree

2 files changed

+56
-1
lines changed

2 files changed

+56
-1
lines changed

test/test_libs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2790,7 +2790,7 @@ class TestVmas:
27902790
@pytest.mark.parametrize("scenario_name", VmasWrapper.available_envs)
27912791
@pytest.mark.parametrize("continuous_actions", [True, False])
27922792
def test_all_vmas_scenarios(self, scenario_name, continuous_actions):
2793-
2793+
27942794
env = VmasEnv(
27952795
scenario=scenario_name,
27962796
continuous_actions=continuous_actions,

test/test_rb.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4514,6 +4514,61 @@ def test_compressed_storage_memory_efficiency(self):
45144514
), f"Compression ratio {compression_ratio} is too low"
45154515

45164516

4517+
class TestRBLazyInit:
4518+
def test_lazy_init(self):
4519+
def transform(td):
4520+
return td
4521+
4522+
rb = ReplayBuffer(
4523+
storage=partial(ListStorage),
4524+
writer=partial(RoundRobinWriter),
4525+
sampler=partial(RandomSampler),
4526+
transform_factory=lambda: transform,
4527+
)
4528+
assert not rb.initialized
4529+
assert not hasattr(rb, "_storage")
4530+
assert rb._init_storage is not None
4531+
assert not hasattr(rb, "_sampler")
4532+
assert rb._init_sampler is not None
4533+
assert not hasattr(rb, "_writer")
4534+
assert rb._init_writer is not None
4535+
rb.extend(TensorDict(batch_size=[2]))
4536+
assert rb.initialized
4537+
assert rb._storage is not None
4538+
assert rb._init_storage is None
4539+
assert rb._sampler is not None
4540+
assert rb._init_sampler is None
4541+
assert rb._writer is not None
4542+
assert rb._init_writer is None
4543+
4544+
rb = ReplayBuffer(
4545+
storage=partial(ListStorage),
4546+
writer=partial(RoundRobinWriter),
4547+
sampler=partial(RandomSampler),
4548+
)
4549+
assert rb.initialized
4550+
assert rb._storage is not None
4551+
assert rb._init_storage is None
4552+
assert rb._sampler is not None
4553+
assert rb._init_sampler is None
4554+
assert rb._writer is not None
4555+
assert rb._init_writer is None
4556+
4557+
rb = ReplayBuffer(
4558+
storage=partial(ListStorage),
4559+
writer=partial(RoundRobinWriter),
4560+
sampler=partial(RandomSampler),
4561+
delayed_init=False,
4562+
)
4563+
assert rb.initialized
4564+
assert rb._storage is not None
4565+
assert rb._init_storage is None
4566+
assert rb._sampler is not None
4567+
assert rb._init_sampler is None
4568+
assert rb._writer is not None
4569+
assert rb._init_writer is None
4570+
4571+
45174572
if __name__ == "__main__":
45184573
args, unknown = argparse.ArgumentParser().parse_known_args()
45194574
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

0 commit comments

Comments
 (0)