
Commit 8ea9c7d

fix: support FSDP compatibility in LigerTiledSwiGLUMLP backward
The previous backward implementation called torch.autograd.backward() inside the tiling loop, triggering FSDP's post-backward hook (reshard) once per shard. This caused FSDP1 to reshard parameters mid-loop, leading to errors on subsequent shard iterations.

Fix: replace torch.autograd.backward() with torch.autograd.grad() inside the tiling loop. This computes gradients locally without accumulating into .grad or triggering any hooks. Param gradients are accumulated manually across shards and written to .grad exactly once after the loop, so FSDP sees a single gradient event, as expected.

This fix is FSDP-agnostic: LigerTiledSwiGLUMLP requires no knowledge of FSDP.

Verified with FSDP1 (FullyShardedDataParallel) and FSDP2 (fully_shard) on 2x RTX 3060:

- FSDP1: previously errored, now passes
- FSDP2: passes
- Non-distributed: unaffected
1 parent 3343eee commit 8ea9c7d

File tree

1 file changed: 23 additions, 13 deletions


src/liger_kernel/ops/tiled_mlp.py

Lines changed: 23 additions & 13 deletions
```diff
@@ -76,6 +76,9 @@ def backward(ctx, *grads) -> tuple:
         incoming_grad = grads[0].view(-1, hidden_size)
         x_grad = torch.zeros_like(x)
 
+        # initialize param grad accumulators
+        param_grads = {p: None for p in mlp_module.parameters()}
+
         x_shards = list(torch.chunk(x, chunks=shards, dim=0))
 
         for i, x_shard in enumerate(x_shards):
@@ -84,22 +87,29 @@ def backward(ctx, *grads) -> tuple:
             # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step
             shard_step = x_shards[i].shape[0]
             shard_offset = i * x_shards[0].shape[0]
-
-            x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
             incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
 
-            all_outputs = []
-            all_incoming_grads = []
             with torch.enable_grad():
-                all_outputs.append(fn(mlp_module, x_shard))
-                all_incoming_grads.append(
-                    incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
-                )
-
-        # AccumulateGrad fires once here, after all shards are computed
-        torch.autograd.backward(all_outputs, all_incoming_grads)
-
+                output = fn(mlp_module, x_shard)
+            local_grads = torch.autograd.grad(
+                outputs=output,
+                inputs=[x_shard] + list(mlp_module.parameters()),
+                grad_outputs=incoming_grad_shard,
+            )
+
+            x_grad.narrow(0, shard_offset, shard_step).copy_(local_grads[0])
+
+            for p, g in zip(mlp_module.parameters(), local_grads[1:]):
+                if param_grads[p] is None:
+                    param_grads[p] = g
+                else:
+                    param_grads[p] += g
+
+        for p, g in param_grads.items():
+            if p.grad is None:
+                p.grad = g
+            else:
+                p.grad += g
 
         # unflatten
         x_grad = x_grad.view(x_shape_orig)
```
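The accumulation pattern in the diff can be checked end to end on a stand-in module. This is a hedged sketch using a toy torch.nn.Linear rather than LigerTiledSwiGLUMLP and an assumed upstream gradient of ones: per-shard gradients come from torch.autograd.grad, param grads are summed manually, .grad is written once, and the result matches a single untiled backward pass.

```python
import torch

# Toy stand-ins for the real pieces: mlp ~ mlp_module, incoming_grad ~ the
# flattened upstream gradient, shards = 3 tiles over the batch dimension.
torch.manual_seed(0)
mlp = torch.nn.Linear(4, 4)
x = torch.randn(6, 4, requires_grad=True)
incoming_grad = torch.ones(6, 4)

x_grad = torch.zeros_like(x)
param_grads = {p: None for p in mlp.parameters()}

x_shards = list(torch.chunk(x, chunks=3, dim=0))
for i, x_shard in enumerate(x_shards):
    shard_step = x_shard.shape[0]
    shard_offset = i * x_shards[0].shape[0]
    incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step)

    with torch.enable_grad():
        output = mlp(x_shard)
    # local gradients for this shard only; no .grad writes, no hooks
    local_grads = torch.autograd.grad(
        outputs=output,
        inputs=[x_shard] + list(mlp.parameters()),
        grad_outputs=incoming_grad_shard,
    )

    x_grad.narrow(0, shard_offset, shard_step).copy_(local_grads[0])
    for p, g in zip(mlp.parameters(), local_grads[1:]):
        param_grads[p] = g if param_grads[p] is None else param_grads[p] + g

# exactly one .grad write per parameter -> one gradient event for FSDP
for p, g in param_grads.items():
    p.grad = g if p.grad is None else p.grad + g

# reference: a single untiled gradient computation over the full batch
ref = torch.autograd.grad(
    mlp(x), [x] + list(mlp.parameters()), grad_outputs=incoming_grad
)
assert torch.allclose(x_grad, ref[0])
for p, g_ref in zip(mlp.parameters(), ref[1:]):
    assert torch.allclose(p.grad, g_ref)
```

Because gradients are additive across disjoint batch shards, summing per-shard param grads reproduces the untiled result exactly, which is what the commit's FSDP1/FSDP2 verification relies on.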

0 commit comments
