@@ -172,7 +172,7 @@ def rsample(self, sample_shape: torch.Size | Sequence = None) -> torch.Tensor:
172172 )
173173
174174
175- class MaskedCategorical (D .Categorical ):
175+ class MaskedCategorical(D.Categorical):
176176 """MaskedCategorical distribution.
177177
178178 Reference:
@@ -251,7 +251,10 @@ def __init__(
251251 sparse_mask = True
252252 else :
253253 sparse_mask = False
254-
254+ if not mask.any():
255+ raise ValueError (
256+ "Provided ``mask`` must contain at least one True value"
257+ )
255258 if probs is not None :
256259 if logits is not None :
257260 raise ValueError (
@@ -360,7 +363,7 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
360363 original_value_shape = value .shape
361364 value = value .flatten ()
362365 logits = logits .unsqueeze (0 ).expand (value .shape + logits .shape )
363- result = - torch .nn .functional .cross_entropy (logits , value , reduce = False )
366+ result = - torch .nn .functional .cross_entropy (logits , value , reduction = 'none' )
364367 if original_value_shape is not None :
365368 result = result .unflatten (0 , original_value_shape )
366369 else :
@@ -391,7 +394,7 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
391394 original_idx_shape = idx .shape
392395 idx = idx .flatten ()
393396 logits = logits .unsqueeze (0 ).expand (idx .shape + logits .shape )
394- ret = - torch .nn .functional .cross_entropy (logits , idx , reduce = False )
397+ ret = - torch .nn .functional .cross_entropy (logits , idx , reduction = 'none' )
395398 if original_idx_shape is not None :
396399 ret = ret .unflatten (0 , original_idx_shape )
397400 else :
0 commit comments