From 67c3c289936401ac239c7b684ba86bf088da8189 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Jan 2025 18:06:42 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/linear.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 88e98d531e..5893c4ea3c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -568,9 +568,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if dgrad is None: if ctx.parallel_mode == "column" and ctx.sequence_parallel: dgrad_shape[0] = dgrad_shape[0] * tp_world_size - dgrad = torch.empty( - dgrad_shape, dtype=output_dtype, device=grad_output.device - ) + dgrad = torch.empty(dgrad_shape, dtype=output_dtype, device=grad_output.device) if ctx.requires_dgrad: if ctx.fp8: