Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,7 @@ to be able to create this other composition:
InitTracker
KLRewardTransform
LineariseRewards
ModuleTransform
MultiAction
NoopResetEnv
ObservationNorm
Expand Down
131 changes: 129 additions & 2 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
TensorDictBase,
unravel_key,
)
from tensordict.nn import TensorDictSequential, WrapModule
from tensordict.nn import TensorDictModule, TensorDictSequential, WrapModule
from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td
from torch import multiprocessing as mp, nn, Tensor
from torchrl._utils import _replace_last, prod, set_auto_unwrap_transformed_env
Expand Down Expand Up @@ -122,8 +122,9 @@
from torchrl.envs.libs.dm_control import _has_dm_control
from torchrl.envs.libs.gym import _has_gym, GymEnv, set_gym_backend
from torchrl.envs.libs.unity_mlagents import _has_unity_mlagents
from torchrl.envs.transforms import VecNorm
from torchrl.envs.transforms import ModuleTransform, VecNorm
from torchrl.envs.transforms.llm import KLRewardTransform
from torchrl.envs.transforms.module import RayModuleTransform
from torchrl.envs.transforms.r3m import _R3MNet
from torchrl.envs.transforms.transforms import (
_has_tv,
Expand Down Expand Up @@ -198,6 +199,8 @@
StateLessCountingEnv,
)

_has_ray = importlib.util.find_spec("ray") is not None

IS_WIN = platform == "win32"
if IS_WIN:
mp_ctx = "spawn"
Expand Down Expand Up @@ -14888,6 +14891,130 @@ def test_transform_inverse(self):
return


class TestModuleTransform(TransformBase):
    """Test suite for ``ModuleTransform``: env specs, replay buffers, models, inverse
    application, and the Ray-backed remote variant."""

    @property
    def _module_factory_samespec(self):
        # Factory for a module mapping "observation" -> "observation" with a
        # fixed output width of 7; LazyLinear infers the input size lazily.
        return partial(
            TensorDictModule,
            nn.LazyLinear(7),
            in_keys=["observation"],
            out_keys=["observation"],
        )

    @property
    def _module_factory_samespec_inverse(self):
        # Same idea as above but acting on "action" — used for inverse transforms.
        return partial(
            TensorDictModule, nn.LazyLinear(7), in_keys=["action"], out_keys=["action"]
        )

    def _single_env_maker(self):
        # Build a single mock env with the forward ModuleTransform appended.
        base_env = ContinuousActionVecMockEnv()
        t = ModuleTransform(module_factory=self._module_factory_samespec)
        return base_env.append_transform(t)

    def test_single_trans_env_check(self):
        env = self._single_env_maker()
        env.check_env_specs()

    def test_serial_trans_env_check(self):
        env = SerialEnv(2, self._single_env_maker)
        try:
            env.check_env_specs()
        finally:
            env.close(raise_if_closed=False)

    def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv):
        env = maybe_fork_ParallelEnv(2, self._single_env_maker)
        try:
            env.check_env_specs()
        finally:
            env.close(raise_if_closed=False)

    def test_trans_serial_env_check(self):
        # Transform appended on top of the batched (serial) env, not per-worker.
        env = SerialEnv(2, ContinuousActionVecMockEnv)
        try:
            env = env.append_transform(
                ModuleTransform(module_factory=self._module_factory_samespec)
            )
            env.check_env_specs()
        finally:
            env.close(raise_if_closed=False)

    def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
        # Transform appended on top of the batched (parallel) env, not per-worker.
        env = maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv)
        try:
            env = env.append_transform(
                ModuleTransform(module_factory=self._module_factory_samespec)
            )
            env.check_env_specs()
        finally:
            env.close(raise_if_closed=False)

    def test_transform_no_env(self):
        # Standalone call: the transform should behave like a plain module.
        t = ModuleTransform(module_factory=self._module_factory_samespec)
        td = t(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
        assert td["observation"].shape == (2, 7)

    def test_transform_compose(self):
        t = Compose(ModuleTransform(module_factory=self._module_factory_samespec))
        td = t(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
        assert td["observation"].shape == (2, 7)

    def test_transform_env(self):
        # TODO: We should give users the opportunity to modify the specs
        env = self._single_env_maker()
        env.check_env_specs()

    def test_transform_model(self):
        t = nn.Sequential(
            Compose(ModuleTransform(module_factory=self._module_factory_samespec))
        )
        td = t(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
        assert td["observation"].shape == (2, 7)

    def test_transform_rb(self):
        # Forward transform: applied at sampling time, storage keeps raw data.
        t = ModuleTransform(module_factory=self._module_factory_samespec)
        rb = ReplayBuffer(transform=t)
        rb.extend(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
        assert rb._storage._storage[0]["observation"].shape == (3,)
        s = rb.sample(2)
        assert s["observation"].shape == (2, 7)

        # Inverted transform: applied at write time, storage holds transformed data.
        rb = ReplayBuffer()
        rb.append_transform(t, invert=True)
        rb.extend(TensorDict(observation=torch.randn(2, 3), batch_size=[2]))
        assert rb._storage._storage[0]["observation"].shape == (7,)
        s = rb.sample(2)
        assert s["observation"].shape == (2, 7)

    def test_transform_inverse(self):
        t = ModuleTransform(
            module_factory=self._module_factory_samespec_inverse, inverse=True
        )
        env = ContinuousActionVecMockEnv().append_transform(t)
        env.check_env_specs()

    @pytest.mark.skipif(not _has_ray, reason="ray required")
    def test_ray_extension(self):
        import ray

        # Remember whether Ray was already running so we only tear down an
        # instance this test itself started.
        # NOTE: `ray.is_initialized` must be *called* — the bare attribute is a
        # function object and therefore always truthy.
        ray_was_initialized = ray.is_initialized()
        try:
            t = ModuleTransform(
                module_factory=self._module_factory_samespec,
                use_ray_service=True,
                actor_name="my_transform",
            )
            env = ContinuousActionVecMockEnv().append_transform(t)
            # use_ray_service=True should swap in the Ray-backed subclass.
            assert isinstance(t, RayModuleTransform)
            env.check_env_specs()
            assert ray.get_actor("my_transform") is not None
        finally:
            if not ray_was_initialized:
                # ray.shutdown() is the Python API; `ray stop` is CLI-only.
                ray.shutdown()


if __name__ == "__main__":
    # Allow running this test file directly; any CLI options that argparse does
    # not recognize are forwarded straight to pytest.
    _, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *unknown])
10 changes: 3 additions & 7 deletions torchrl/envs/llm/transforms/dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@

from torchrl.data.tensor_specs import Composite, DEVICE_TYPING, TensorSpec
from torchrl.envs.common import EnvBase
from torchrl.envs.transforms import TensorDictPrimer, Transform

# Import ray service components
from torchrl.envs.llm.transforms.ray_service import (
from torchrl.envs.transforms.ray_service import (
_map_input_output_device,
_RayServiceMetaClass,
RayTransform,
)
from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform
from torchrl.envs.utils import make_composite_from_td

T = TypeVar("T")
Expand Down Expand Up @@ -259,7 +259,7 @@ def primers(self):
@primers.setter
def primers(self, value: TensorSpec):
"""Set primers property."""
self._ray.get(self._actor.set_attr.remote("primers", value))
self._ray.get(self._actor._set_attr.remote("primers", value))

# TensorDictPrimer methods
def init(self, tensordict: TensorDictBase | None):
Expand Down Expand Up @@ -857,7 +857,3 @@ def _update_primers_batch_size(self, parent_batch_size):
def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}(primers={self.primers}, dataloader={self.dataloader})"

def set_attr(self, name, value):
"""Set attribute on the remote actor or locally."""
setattr(self, name, value)
2 changes: 1 addition & 1 deletion torchrl/envs/llm/transforms/kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torchrl.data import Composite, Unbounded
from torchrl.data.tensor_specs import DEVICE_TYPING
from torchrl.envs import EnvBase, Transform
from torchrl.envs.llm.transforms.ray_service import _RayServiceMetaClass, RayTransform
from torchrl.envs.transforms.ray_service import _RayServiceMetaClass, RayTransform
from torchrl.envs.transforms.transforms import Compose
from torchrl.envs.transforms.utils import _set_missing_tolerance
from torchrl.modules.llm.policies.common import LLMWrapperBase
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/llm/transforms/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch

from tensordict import lazy_stack, TensorDictBase
from torchrl import torchrl_logger
from torchrl._utils import logger as torchrl_logger
from torchrl.data.llm import History

from torchrl.envs import Transform
Expand Down
7 changes: 5 additions & 2 deletions torchrl/envs/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

from .gym_transforms import EndOfLifeTransform
from .llm import KLRewardTransform
from .module import ModuleTransform
from .r3m import R3MTransform
from .ray_service import RayTransform
from .rb_transforms import MultiStepTransform

from .transforms import (
ActionDiscretizer,
ActionMask,
Expand Down Expand Up @@ -85,9 +86,9 @@
"CatFrames",
"CatTensors",
"CenterCrop",
"ConditionalPolicySwitch",
"ClipTransform",
"Compose",
"ConditionalPolicySwitch",
"ConditionalSkip",
"Crop",
"DTypeCastTransform",
Expand All @@ -104,6 +105,7 @@
"InitTracker",
"KLRewardTransform",
"LineariseRewards",
"ModuleTransform",
"MultiAction",
"MultiStepTransform",
"NoopResetEnv",
Expand All @@ -113,6 +115,7 @@
"PinMemoryTransform",
"R3MTransform",
"RandomCropTensorDict",
"RayTransform",
"RemoveEmptySpecs",
"RenameTransform",
"Resize",
Expand Down
Loading
Loading