[Feature] Make tensordict not incompatible with torch.compile #629

Open. Wants to merge 6 commits into base: main.
tensordict/_td.py (4 additions, 0 deletions)
@@ -190,6 +190,7 @@ class TensorDict(TensorDictBase):
_is_shared = False
_is_memmap = False

+ @torch.compiler.disable()
def __init__(
self,
source: T | dict[str, CompatibleType],
@@ -313,6 +314,7 @@ def is_empty(self):
return False
return True

+ @torch.compiler.disable()
def _to_module(
self,
module,
@@ -432,6 +434,7 @@ def _quick_set(swap_dict, swap_td):
else:
return TensorDict(_swap, batch_size=[], _run_checks=False)

+ @torch.compiler.disable()
def __ne__(self, other: object) -> T | bool:
if _is_tensorclass(other):
return other != self
@@ -498,6 +501,7 @@ def __or__(self, other: object) -> T | bool:
)
return False

+ @torch.compiler.disable()
def __eq__(self, other: object) -> T | bool:
if is_tensorclass(other):
return other == self
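For reviewers unfamiliar with the decorator: `torch.compiler.disable()` tells Dynamo to skip the wrapped function and run it eagerly, so hitting these `TensorDict` dunders from inside a compiled region produces a graph break instead of a hard compile error. A minimal standalone sketch of the mechanism (illustrative names, not part of this diff):

```python
import torch

@torch.compiler.disable()
def make_payload(x):
    # Dynamo never traces this body; it always runs eagerly.
    return {"doubled": x * 2}

@torch.compile
def fn(x):
    payload = make_payload(x)  # graph break here, not a compile error
    return payload["doubled"] + 1

print(fn(torch.ones(3)))  # tensor([3., 3., 3.])
```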
tensordict/base.py (4 additions, 0 deletions)
@@ -204,6 +204,7 @@ def __contains__(self, key: NestedKey) -> bool:
"`key in tensordict.keys()` instead."
)

+ @torch.compiler.disable()
def __getitem__(self, index: IndexType) -> T:
"""Indexes all tensors according to the provided index.

@@ -2025,6 +2026,7 @@ def entry_class(self, key: NestedKey) -> type:
"""
...

+ @torch.compiler.disable()
def set(
self, key: NestedKey, item: CompatibleType, inplace: bool = False, **kwargs: Any
) -> T:
@@ -2290,6 +2292,7 @@ def _default_get(self, key: NestedKey, default: Any = NO_DEFAULT) -> CompatibleT
_KEY_ERROR.format(key, self.__class__.__name__, sorted(self.keys()))
)

+ @torch.compiler.disable()
def get(self, key: NestedKey, default: Any = NO_DEFAULT) -> CompatibleType:
"""Gets the value stored with the input key.

@@ -2697,6 +2700,7 @@ def copy_at_(self, tensordict: T, idx: IndexType) -> T:
"""See :obj:`TensorDictBase.update_at_`."""
return self.update_at_(tensordict, idx)

+ @torch.compiler.disable()
def is_empty(self) -> bool:
"""Checks if the tensordict contains any leaf."""
for _ in self.keys(True, True):
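Taken together with the `_td.py` changes, a compiled function that reads and writes a `TensorDict` should now degrade gracefully to eager fallbacks rather than failing to compile. A rough usage sketch, assuming this branch of tensordict is installed:

```python
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.randn(3)}, batch_size=[3])

@torch.compile
def step(td):
    # get() and set() are wrapped in torch.compiler.disable(), so each
    # call runs eagerly as a graph break instead of raising inside Dynamo.
    x = td.get("a")
    td.set("b", x.relu())
    return td

step(td)
```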
tensordict/nn/common.py (1 addition, 0 deletions)
@@ -841,6 +841,7 @@ def reset_parameters_recursive(
lambda x: x.detach().requires_grad_(), inplace=False
)

+ is_stateless = False
if _auto_make_functional() and not is_functional(self):
make_functional(self, keep_params=True)
is_stateless = self._is_stateless
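This one-liner reads like a latent-bug fix: `is_stateless` used to be assigned only inside the `if _auto_make_functional(...)` branch, so any later read of it would raise `UnboundLocalError` whenever auto-make-functional is off. The pattern, reduced to a toy example (not the actual tensordict code):

```python
def before(cond):
    if cond:
        flag = True
    return flag  # UnboundLocalError when cond is False

def after(cond):
    flag = False  # pre-initialized, mirroring `is_stateless = False`
    if cond:
        flag = True
    return flag  # always defined
```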
tensordict/nn/probabilistic.py (0 additions, 1 deletion)
@@ -210,7 +210,6 @@ class ProbabilisticTensorDictModule(TensorDictModuleBase):
... TensorDictModule,
... )
>>> from tensordict.nn.distributions import NormalParamExtractor
- >>> from tensordict.nn.functional_modules import make_functional
>>> from torch.distributions import Normal
>>> td = TensorDict(
... {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3]
tensordict/nn/sequence.py (0 additions, 1 deletion)
@@ -88,7 +88,6 @@ class TensorDictSequential(TensorDictModule):
... TensorDictSequential,
... )
>>> from tensordict.nn.distributions import NormalParamExtractor
- >>> from tensordict.nn.functional_modules import make_functional
>>> from torch.distributions import Normal
>>> td = TensorDict({"input": torch.randn(3, 4)}, [3,])
>>> net1 = torch.nn.Linear(4, 8)
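This docstring edit and the matching one in `probabilistic.py` above drop the legacy `make_functional` import from the examples. For reference, a sketch of the decorator-free functional pattern the library has been moving toward, assuming a tensordict version that provides `TensorDict.from_module`/`to_module`:

```python
import torch
from tensordict import TensorDict

module = torch.nn.Linear(4, 8)
params = TensorDict.from_module(module)  # snapshot the module's parameters

with params.to_module(module):  # temporarily bind them back to the module
    out = module(torch.randn(3, 4))
```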
test/test_nn.py (4 additions, 0 deletions)
@@ -172,6 +172,7 @@ def test_reset(self):
nn.Sequential(nn.Tanh(), nn.Linear(1, 1), nn.Linear(2, 1)),
],
)
+ @_set_auto_make_functional(True)
def test_reset_functional(self, net):
torch.manual_seed(0)
module = TensorDictModule(net, in_keys=["in"], out_keys=["out"])
@@ -204,6 +205,7 @@ def test_reset_functional(self, net):
p.all()
), f"Discrepancy between returned weights and those in-place updated {p}"

+ @_set_auto_make_functional(True)
def test_reset_functional_called_once(self):
import unittest.mock

@@ -403,6 +405,7 @@ def test_stateful_probabilistic(self, lazy, interaction_type, out_keys):
@pytest.mark.skipif(
not _has_functorch, reason=f"functorch not found: err={FUNCTORCH_ERR}"
)
+ @_set_auto_make_functional(True)
def test_functional_before(self):
torch.manual_seed(0)
param_multiplier = 1
@@ -569,6 +572,7 @@ def test_functional_probabilistic(self):
@pytest.mark.skipif(
not _has_functorch, reason=f"functorch not found: err={FUNCTORCH_ERR}"
)
+ @_set_auto_make_functional(True)
def test_functional_with_buffer(self):
torch.manual_seed(0)
param_multiplier = 1
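The added `@_set_auto_make_functional(True)` decorators pin the legacy auto-make-functional behavior for the tests that still depend on it, so they keep exercising that path even if the default changes. A hypothetical sketch of how such a toggle decorator is commonly written (the flag name and layout here are assumptions, not the actual tensordict helper):

```python
import functools

AUTO_MAKE_FUNCTIONAL = False  # stand-in for the real module-level flag

def _set_auto_make_functional(mode: bool):
    """Set the flag for the test's duration, then restore the old value."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            global AUTO_MAKE_FUNCTIONAL
            prev, AUTO_MAKE_FUNCTIONAL = AUTO_MAKE_FUNCTIONAL, mode
            try:
                return fn(*args, **kwargs)
            finally:
                AUTO_MAKE_FUNCTIONAL = prev
        return wrapper
    return decorator
```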