Skip to content

Commit

Permalink
[BugFix] select_out_keys for Prob sequential
Browse files Browse the repository at this point in the history
ghstack-source-id: a566ae225c54f07a680b4bf380b16d8e797f62ea
Pull Request resolved: #1103
  • Loading branch information
vmoens committed Nov 23, 2024
1 parent c95a703 commit df61d64
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 12 deletions.
29 changes: 20 additions & 9 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,6 @@ def __init__(
in_keys: NestedKey | List[NestedKey] | Dict[str, NestedKey],
out_keys: NestedKey | List[NestedKey] | None = None,
*,
default_interaction_mode: str | None = None,
default_interaction_type: InteractionType = InteractionType.DETERMINISTIC,
distribution_class: type = Delta,
distribution_kwargs: dict | None = None,
Expand Down Expand Up @@ -332,11 +331,7 @@ def __init__(
log_prob_key = "sample_log_prob"
self.log_prob_key = log_prob_key

if default_interaction_mode is not None:
raise ValueError(
"default_interaction_mode is deprecated, use default_interaction_type instead."
)
self.default_interaction_type = default_interaction_type
self.default_interaction_type = InteractionType(default_interaction_type)

if isinstance(distribution_class, str):
distribution_class = distributions_maps.get(distribution_class.lower())
Expand All @@ -356,7 +351,7 @@ def get_dist(self, tensordict: TensorDictBase) -> D.Distribution:
for dist_key, td_key in _zip_strict(self.dist_keys, self.in_keys):
if isinstance(dist_key, tuple):
dist_key = dist_key[-1]
dist_kwargs[dist_key] = tensordict.get(td_key)
dist_kwargs[dist_key] = tensordict.get(td_key, None)
dist = self.distribution_class(
**dist_kwargs, **_dynamo_friendly_to_dict(self.distribution_kwargs)
)
Expand Down Expand Up @@ -630,8 +625,24 @@ def forward(
tensordict_out: TensorDictBase | None = None,
**kwargs,
) -> TensorDictBase:
tensordict_out = self.get_dist_params(tensordict, tensordict_out, **kwargs)
return self.module[-1](tensordict_out, _requires_sample=self._requires_sample)
if (tensordict_out is None and self._select_before_return) or (
tensordict_out is not None
):
tensordict_exec = tensordict.copy()
else:
tensordict_exec = tensordict
tensordict_exec = self.get_dist_params(tensordict_exec, **kwargs)
tensordict_exec = self.module[-1](
tensordict_exec, _requires_sample=self._requires_sample
)
if tensordict_out is not None:
result = tensordict_out
result.update(tensordict_exec, keys_to_update=self.out_keys)
else:
result = tensordict_exec
if self._select_before_return:
return tensordict.update(result, keys_to_update=self.out_keys)
return result


def _dynamo_friendly_to_dict(data):
Expand Down
6 changes: 3 additions & 3 deletions tensordict/nn/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,13 +470,13 @@ def forward(
tensordict_out: TensorDictBase | None = None,
**kwargs: Any,
) -> TensorDictBase:
if tensordict_out is None and self._select_before_return:
if (tensordict_out is None and self._select_before_return) or (
tensordict_out is not None
):
tensordict_exec = tensordict.copy()
else:
tensordict_exec = tensordict
if not len(kwargs):
if tensordict_out is not None:
tensordict_exec = tensordict_exec.copy()
for module in self.module:
tensordict_exec = self._run_module(module, tensordict_exec, **kwargs)
else:
Expand Down
53 changes: 53 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,59 @@ def test_stateful_probabilistic_deprec(self, lazy):
dist = tdmodule.get_dist(td)
assert dist.rsample().shape[: td.ndimension()] == td.shape

@pytest.mark.parametrize("return_log_prob", [True, False])
@pytest.mark.parametrize("td_out", [True, False])
def test_probtdseq(self, return_log_prob, td_out):
    """Check which keys a ProbabilisticTensorDictSequential writes.

    The module is called twice: once as built, and once after
    ``select_out_keys("f")``. Each call is made either in place
    (``tensordict_out=None``) or into a fresh, empty ``TensorDict``
    depending on ``td_out``.
    """
    mod = ProbabilisticTensorDictSequential(
        TensorDictModule(lambda x: x + 2, in_keys=["a"], out_keys=["c"]),
        TensorDictModule(lambda x: (x + 2, x), in_keys=["b"], out_keys=["d", "e"]),
        ProbabilisticTensorDictModule(
            in_keys={"loc": "d", "scale": "e"},
            out_keys=["f"],
            distribution_class=Normal,
            return_log_prob=return_log_prob,
            default_interaction_type="random",
        ),
    )
    inp = TensorDict({"a": 0.0, "b": 1.0})
    # Pristine copy of the input, reused for the second call below.
    inp_clone = inp.clone()

    # First pass: no out-key selection is active.
    destination = TensorDict() if td_out else None
    returned = mod(inp, tensordict_out=destination)
    assert not mod._select_before_return
    # The returned object must be the destination when one is given,
    # otherwise the input itself.
    assert returned is (destination if td_out else inp)
    assert set(returned.keys()) - {"a", "b"} == set(mod.out_keys), (
        td_out,
        return_log_prob,
    )

    # Second pass: restrict the outputs to "f" only.
    inp = inp_clone.clone()
    mod.select_out_keys("f")
    destination = TensorDict() if td_out else None
    returned = mod(inp, tensordict_out=destination)
    assert mod._select_before_return
    assert returned is (destination if td_out else inp)

    expected = {"f"}
    written_keys = set(returned.keys())
    if not td_out:
        # In-place call: the input's own keys are still present, so
        # discount them before comparing.
        written_keys = written_keys - set(inp_clone.keys())
    assert written_keys == set(mod.out_keys) == expected

@pytest.mark.parametrize("lazy", [True, False])
def test_stateful_probabilistic(self, lazy):
torch.manual_seed(0)
Expand Down

1 comment on commit df61d64

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
The benchmark result of this commit is worse than the previous benchmark result by more than the threshold ratio of 2.

| Benchmark suite | Current: df61d64 | Previous: c95a703 | Ratio |
| --- | --- | --- | --- |
| benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last | 76158.11921615894 iter/sec (stddev: 9.84018451671747e-7) | 169405.24380341786 iter/sec (stddev: 7.340222570973414e-7) | 2.22 |
| benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last | 77114.26150616663 iter/sec (stddev: 8.255069601043964e-7) | 168723.80396478748 iter/sec (stddev: 5.101773355119075e-7) | 2.19 |

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.