Commit 8ea9c7d
fix: support FSDP compatibility in LigerTiledSwiGLUMLP backward
The previous backward implementation called torch.autograd.backward()
inside the tiling loop, triggering FSDP's post-backward hook (reshard)
once per shard. This caused FSDP1 to reshard parameters mid-loop,
leading to errors on subsequent shard iterations.
Fix: replace torch.autograd.backward() with torch.autograd.grad()
inside the tiling loop. This computes gradients locally without
accumulating into .grad or triggering any hooks. Param gradients
are accumulated manually across shards and written to .grad exactly
once after the loop — FSDP sees a single gradient event, as expected.
This fix is FSDP-agnostic: LigerTiledSwiGLUMLP requires no knowledge
of FSDP. Verified with FSDP1 (FullyShardedDataParallel) and FSDP2
(fully_shard) on 2x RTX 3060.
- FSDP1: previously errored, now passes
- FSDP2: passes
- Non-distributed: unaffected

1 parent 3343eee, commit 8ea9c7d
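The pattern the commit describes can be sketched as follows. This is a minimal illustration, not the actual LigerTiledSwiGLUMLP code: `tiled_mlp_backward`, `forward_fn`, and the tile splitting are hypothetical stand-ins for the real kernel, and the SwiGLU math is elided behind `forward_fn`. The key idea is that `torch.autograd.grad()` returns gradients without writing to `.grad` or firing accumulation hooks, so FSDP's post-backward reshard hook is not triggered inside the loop; `.grad` is written exactly once afterwards.

```python
import torch

def tiled_mlp_backward(x_tiles, params, grad_out_tiles, forward_fn):
    """Sketch: accumulate parameter grads across tiles with
    torch.autograd.grad, then write .grad once so any post-backward
    hook (e.g. FSDP's reshard) sees a single gradient event."""
    grad_accum = [torch.zeros_like(p) for p in params]
    x_grads = []
    for x, g in zip(x_tiles, grad_out_tiles):
        x = x.detach().requires_grad_(True)
        out = forward_fn(x, params)
        # autograd.grad computes gradients locally: it does not touch
        # .grad and does not fire gradient-accumulation hooks.
        grads = torch.autograd.grad(out, [x, *params], grad_outputs=g)
        x_grads.append(grads[0])
        for acc, gp in zip(grad_accum, grads[1:]):
            acc.add_(gp)
    # Single write to .grad after the loop: one gradient event per param.
    for p, acc in zip(params, grad_accum):
        p.grad = acc if p.grad is None else p.grad + acc
    return torch.cat(x_grads)
```

For a simple linear `forward_fn`, the tiled result matches an ordinary untiled backward, while `.grad` is only assigned once per parameter.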
1 file changed: +23 −13 lines changed