@@ -3536,13 +3536,22 @@ def t_op():
3536
3536
def _pad_batch_rule (batched_args , batch_dims , * , padding_config ):
3537
3537
operand , padding_value = batched_args
3538
3538
operand_bdim , padding_value_bdim = batch_dims
3539
+ if operand_bdim is None :
3540
+ operand_bdim = 0
3541
+ operand = broadcast (operand , (padding_value .shape [padding_value_bdim ],))
3542
+
3543
+ padding_config = list (padding_config )
3544
+ padding_config .insert (operand_bdim , (0 , 0 , 0 ))
3539
3545
if padding_value_bdim is None :
3540
- assert operand_bdim is not None
3541
- padding_config = list (padding_config )
3542
- padding_config .insert (operand_bdim , (0 , 0 , 0 ))
3543
3546
return pad (operand , padding_value , padding_config ), operand_bdim
3544
- else :
3545
- raise NotImplementedError # loop and stack
3547
+
3548
+ assert padding_value_bdim == 0 , padding_value_bdim
3549
+
3550
+ x = pad (operand , _zero (operand ), padding_config )
3551
+ mask = pad (full_like (operand , True , np .bool_ ), False , padding_config )
3552
+ broadcasted_padding = broadcast_in_dim (padding_value , x .shape ,
3553
+ (operand_bdim ,))
3554
+ return select (mask , x , broadcasted_padding ), operand_bdim
3546
3555
3547
3556
def _pad_translation_rule (c , operand , padding_value , * , padding_config ):
3548
3557
return xops .Pad (operand , padding_value ,
0 commit comments