Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 0 additions & 101 deletions mmfreelm/modules/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,107 +725,6 @@ def backward(ctx, dout, *args):
None,
)

@staticmethod
@contiguous
def forward(
ctx,
x,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual=None,
eps=1e-6,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
if residual is not None:
assert residual.shape == x_shape_og
residual = residual.reshape(-1, residual.shape[-1])
residual_dtype = (
residual.dtype
if residual is not None
else (torch.float32 if residual_in_fp32 else None)
)
y, mean, rstd, residual_out = _layer_norm_fwd(
x,
norm_weight,
norm_bias,
eps,
residual,
out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
residual_dtype=residual_dtype,
is_rms_norm=is_rms_norm,
)
y = y.reshape(x_shape_og)
dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
linear_shaweight = linear_weight.to(dtype)
linear_bias = linear_bias.to(
dtype) if linear_bias is not None else None

linear_weight_a = weight_quant(linear_weight)
y = activation_quant(y)
out = F.linear(y.to(linear_weight.dtype), linear_weight_a, linear_bias)
ctx.save_for_backward(residual_out, norm_weight,
norm_bias, linear_weight, mean, rstd)
# We don't store y, will be recomputed in the backward pass to save memory

ctx.x_shape_og = x_shape_og
ctx.eps = eps
ctx.is_rms_norm = is_rms_norm
ctx.has_residual = residual is not None
ctx.prenorm = prenorm
ctx.x_dtype = x.dtype
ctx.linear_bias_is_none = linear_bias is None
return out if not prenorm else (out, residual_out.reshape(x_shape_og))

@staticmethod
@contiguous
def backward(ctx, dout, *args):
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
dout = dout.reshape(-1, dout.shape[-1])
linear_weight = weight_quant(linear_weight.t())
dy = F.linear(dout, linear_weight)
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
assert dy.shape == x.shape
if ctx.prenorm:
dresidual = args[0]
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
assert dresidual.shape == x.shape
else:
dresidual = None
dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
dy,
x,
norm_weight,
norm_bias,
ctx.eps,
mean,
rstd,
dresidual,
ctx.has_residual,
ctx.is_rms_norm,
x_dtype=ctx.x_dtype,
recompute_output=True,
)
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
return (
dx.reshape(ctx.x_shape_og),
dnorm_weight,
dnorm_bias,
dlinear_weight,
dlinear_bias,
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
None,
None,
None,
None,
)


def layer_norm_linear_fn(
x,
Expand Down