Skip to content

Commit

Permalink
[Feature] TensorDictModule method and kwargs specification
Browse files Browse the repository at this point in the history
ghstack-source-id: a97fca4c78f5d5c2813d3396e3dd440e2d4e0a4a
Pull Request resolved: #1228
  • Loading branch information
vmoens committed Feb 20, 2025
1 parent 28fbea1 commit bbd5ba8
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 2 deletions.
38 changes: 36 additions & 2 deletions tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,7 @@ class TensorDictModule(TensorDictModuleBase):
"""A TensorDictModule, is a python wrapper around a :obj:`nn.Module` that reads and writes to a TensorDict.
Args:
module (Callable): a callable, typically a :class:`torch.nn.Module`,
module (Callable[[Any], Any]): a callable, typically a :class:`torch.nn.Module`,
used to map the input to the output parameter space. Its forward method
can return a single tensor, a tuple of tensors or even a dictionary.
In the latter case, the output keys of the :class:`TensorDictModule`
Expand Down Expand Up @@ -846,6 +846,8 @@ class TensorDictModule(TensorDictModuleBase):
:class:`~tensordict.TensorDictBase` subclass than :class:`~tensordict.TensorDict`, the output will still
be a :class:`~tensordict.TensorDict` instance.
method (str, optional): the method to be called in the module, if any. Defaults to `__call__`.
method_kwargs (Dict[str, Any], optional): additional keyword arguments to be passed to the module's method being called.
Embedding a neural network in a TensorDictModule only requires to specify the input
and output keys. TensorDictModule support functional and regular :obj:`nn.Module`
Expand Down Expand Up @@ -931,6 +933,30 @@ class TensorDictModule(TensorDictModuleBase):
>>> td['t']
tensor(5.)
We can specify the method to be called within a module. Compared to using a lambda function or similar around the
module's method, this has the advantage that the module attributes (params, buffers, submodules) will be exposed.
Examples:
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> from torch import nn
>>> import torch
>>>
>>> class MyNet(nn.Module):
... def my_func(self, tensor: torch.Tensor, *, an_integer: int):
... return tensor + an_integer
...
>>> s = Seq(
... {
... "a": lambda td: td+1,
... "b": lambda td: td * 2,
... "c": Mod(MyNet(), in_keys=["a"], out_keys=["b"], method="my_func", method_kwargs={"an_integer": 2}),
... }
... )
>>> td = s(TensorDict(a=0))
>>> print(td)
>>>
>>> assert td["b"] == 4
Functional calls to a tensordict module is easy:
Expand Down Expand Up @@ -986,6 +1012,8 @@ def __init__(
*,
out_to_in_map: bool | None = None,
inplace: bool | str = True,
method: str | None = None,
method_kwargs: dict | None = None,
) -> None:
super().__init__()

Expand Down Expand Up @@ -1059,6 +1087,8 @@ def __init__(
"instead."
)
self.inplace = inplace
self.method = method
self.method_kwargs = method_kwargs if method_kwargs is not None else {}

@property
def is_functional(self) -> bool:
Expand Down Expand Up @@ -1106,7 +1136,11 @@ def _write_to_tensordict(
def _call_module(
self, tensors: Sequence[Tensor], **kwargs: Any
) -> Tensor | Sequence[Tensor]:
out = self.module(*tensors, **kwargs)
kwargs.update(self.method_kwargs)
if self.method is None:
out = self.module(*tensors, **kwargs)
else:
out = getattr(self.module, self.method)(*tensors, **kwargs)
return out

@dispatch(auto_batch_size=False)
Expand Down
23 changes: 23 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,29 @@ def __len__(self):
def insert(self, index, value):
self._data.insert(index, value)

def test_module_method_and_kwargs(self):

class MyNet(nn.Module):
def my_func(self, tensor: torch.Tensor, *, an_integer: int):
return tensor + an_integer

s = TensorDictSequential(
{
"a": lambda td: td + 1,
"b": lambda td: td * 2,
"c": TensorDictModule(
MyNet(),
in_keys=["a"],
out_keys=["b"],
method="my_func",
method_kwargs={"an_integer": 2},
),
}
)
td = s(TensorDict(a=0))

assert td["b"] == 4

def test_mutable_sequence(self):
in_keys = self.MyMutableSequence(["a", "b", "c"])
out_keys = self.MyMutableSequence(["d", "e", "f"])
Expand Down

1 comment on commit bbd5ba8

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: bbd5ba8 Previous: ad0a8dd Ratio
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last 75380.7532361309 iter/sec (stddev: 0.000001321250127301775) 235391.76891241295 iter/sec (stddev: 3.2118570721111605e-7) 3.12
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last 75527.82174950559 iter/sec (stddev: 8.796865387755741e-7) 233546.6655456368 iter/sec (stddev: 4.175278496273709e-7) 3.09

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.