File tree Expand file tree Collapse file tree 1 file changed +3
-6
lines changed
torchrl/modules/distributions Expand file tree Collapse file tree 1 file changed +3
-6
lines changed Original file line number Diff line number Diff line change @@ -251,10 +251,7 @@ def __init__(
251251 sparse_mask = True
252252 else :
253253 sparse_mask = False
254- if !mask .any ():
255- raise ValueError (
256- f"Provided ``mask`` must contain a value True"
257- )
254+
258255 if probs is not None :
259256 if logits is not None :
260257 raise ValueError (
@@ -363,7 +360,7 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
363360 original_value_shape = value .shape
364361 value = value .flatten ()
365362 logits = logits .unsqueeze (0 ).expand (value .shape + logits .shape )
366- result = - torch .nn .functional .cross_entropy (logits , value , reduction = 'none' )
363+ result = - torch .nn .functional .cross_entropy (logits , value , reduce = False )
367364 if original_value_shape is not None :
368365 result = result .unflatten (0 , original_value_shape )
369366 else :
@@ -394,7 +391,7 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
394391 original_idx_shape = idx .shape
395392 idx = idx .flatten ()
396393 logits = logits .unsqueeze (0 ).expand (idx .shape + logits .shape )
397- ret = - torch .nn .functional .cross_entropy (logits , idx , reduction = 'none' )
394+ ret = - torch .nn .functional .cross_entropy (logits , idx , reduce = False )
398395 if original_idx_shape is not None :
399396 ret = ret .unflatten (0 , original_idx_shape )
400397 else :
You can’t perform that action at this time.
0 commit comments