diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 88e98d531e..5893c4ea3c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -568,9 +568,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if dgrad is None: if ctx.parallel_mode == "column" and ctx.sequence_parallel: dgrad_shape[0] = dgrad_shape[0] * tp_world_size - dgrad = torch.empty( - dgrad_shape, dtype=output_dtype, device=grad_output.device - ) + dgrad = torch.empty(dgrad_shape, dtype=output_dtype, device=grad_output.device) if ctx.requires_dgrad: if ctx.fp8: