Commit a1d810e
fixed empty dgrad buffer dtype at initialization
Signed-off-by: Alp Dener <[email protected]>
denera committed Jan 16, 2025
1 parent 296c6fa commit a1d810e
Showing 1 changed file with 11 additions and 7 deletions.
transformer_engine/pytorch/module/linear.py
@@ -506,13 +506,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
             dgrad = ub_obj_wgrad.get_ubuf_output(1)
 
-        if dgrad is None:
-            if ctx.parallel_mode == "column" and ctx.sequence_parallel:
-                dgrad_shape[0] = dgrad_shape[0] * tp_world_size
-            dgrad = torch.empty(
-                dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
-            )
-
         (
             grad_output,
             grad_output_c,
@@ -550,6 +543,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
             fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
 
+        output_dtype = ctx.activation_dtype
         if ctx.requires_dgrad:
             if ctx.fp8:
                 if ctx.is_input_fp8 or (
@@ -570,6 +564,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                         None,
                         ctx.activation_dtype,
                     )
+
+        if dgrad is None:
+            if ctx.parallel_mode == "column" and ctx.sequence_parallel:
+                dgrad_shape[0] = dgrad_shape[0] * tp_world_size
+            dgrad = torch.empty(
+                dgrad_shape, dtype=output_dtype, device=grad_output.device
+            )
+
+        if ctx.requires_dgrad:
+            if ctx.fp8:
                 _ = fp8_gemm(
                     weight_fp8.transpose_2d(),
                     weight_fp8._scale_inv,
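The commit defers the fallback dgrad allocation until after output_dtype has been resolved (it now defaults to ctx.activation_dtype before the FP8 grad-output branch), so the empty buffer is created with the dtype of the backward output rather than unconditionally with the forward activation dtype. A minimal sketch of the same pattern as a standalone helper; allocate_dgrad and its parameters are hypothetical illustrations, not the Transformer Engine API:

import torch

def allocate_dgrad(dgrad, dgrad_shape, output_dtype, device,
                   column_parallel=False, sequence_parallel=False, tp_world_size=1):
    """Create the fallback dgrad buffer only when none was provided, using the
    dtype resolved for the backward output instead of the forward activation dtype."""
    if dgrad is None:
        shape = list(dgrad_shape)
        if column_parallel and sequence_parallel:
            # Sequence-parallel column linear: dgrad covers the full sequence
            # gathered across the tensor-parallel group.
            shape[0] *= tp_world_size
        dgrad = torch.empty(shape, dtype=output_dtype, device=device)
    return dgrad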
