Skip to content

Commit 80bfa6e

Browse files
authored
Revert "replace reduce to reduction, better error message for invalid mask" (#3182)
1 parent 45bd3c0 commit 80bfa6e

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

torchrl/modules/distributions/discrete.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -251,10 +251,7 @@ def __init__(
251251
sparse_mask = True
252252
else:
253253
sparse_mask = False
254-
if !mask.any():
255-
raise ValueError(
256-
f"Provided ``mask`` must contain a value True"
257-
)
254+
258255
if probs is not None:
259256
if logits is not None:
260257
raise ValueError(
@@ -363,7 +360,7 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
363360
original_value_shape = value.shape
364361
value = value.flatten()
365362
logits = logits.unsqueeze(0).expand(value.shape + logits.shape)
366-
result = -torch.nn.functional.cross_entropy(logits, value, reduction='none')
363+
result = -torch.nn.functional.cross_entropy(logits, value, reduce=False)
367364
if original_value_shape is not None:
368365
result = result.unflatten(0, original_value_shape)
369366
else:
@@ -394,7 +391,7 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
394391
original_idx_shape = idx.shape
395392
idx = idx.flatten()
396393
logits = logits.unsqueeze(0).expand(idx.shape + logits.shape)
397-
ret = -torch.nn.functional.cross_entropy(logits, idx, reduction='none')
394+
ret = -torch.nn.functional.cross_entropy(logits, idx, reduce=False)
398395
if original_idx_shape is not None:
399396
ret = ret.unflatten(0, original_idx_shape)
400397
else:

0 commit comments

Comments
 (0)