Skip to content

Commit 5661db7

Browse files
authored
[Quality] replace reduce to reduction, better error message for invalid mask (#3179)
1 parent 8374bb4 commit 5661db7

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

torchrl/modules/distributions/discrete.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def rsample(self, sample_shape: torch.Size | Sequence = None) -> torch.Tensor:
172172
)
173173

174174

175-
class MaskedCategorical(D.Categorical):
175+
class (D.Categorical):
176176
"""MaskedCategorical distribution.
177177
178178
Reference:
@@ -251,7 +251,10 @@ def __init__(
251251
sparse_mask = True
252252
else:
253253
sparse_mask = False
254-
254+
if !mask.any():
255+
raise ValueError(
256+
f"Provided ``mask`` must contain a value True"
257+
)
255258
if probs is not None:
256259
if logits is not None:
257260
raise ValueError(
@@ -360,7 +363,7 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
360363
original_value_shape = value.shape
361364
value = value.flatten()
362365
logits = logits.unsqueeze(0).expand(value.shape + logits.shape)
363-
result = -torch.nn.functional.cross_entropy(logits, value, reduce=False)
366+
result = -torch.nn.functional.cross_entropy(logits, value, reduction='none')
364367
if original_value_shape is not None:
365368
result = result.unflatten(0, original_value_shape)
366369
else:
@@ -391,7 +394,7 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
391394
original_idx_shape = idx.shape
392395
idx = idx.flatten()
393396
logits = logits.unsqueeze(0).expand(idx.shape + logits.shape)
394-
ret = -torch.nn.functional.cross_entropy(logits, idx, reduce=False)
397+
ret = -torch.nn.functional.cross_entropy(logits, idx, reduction='none')
395398
if original_idx_shape is not None:
396399
ret = ret.unflatten(0, original_idx_shape)
397400
else:

0 commit comments

Comments
 (0)