Add FSDP support to TiledMLP by preventing premature resharding during the tiled backward recompute loop.#1128
alektebel wants to merge 5 commits into linkedin:main from
Conversation
I'm not entirely sure about this, according to https://docs.pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.summon_full_params

Is a multi-GPU environment accessible to you? If not, could you provide a multi-GPU test script that I can run directly?
force-pushed from 48383f0 to cfbe850 (Compare)
Hi @Tcc0403, you're right — after checking the docs, I see it. I don't have multi-GPU access locally, so I'm renting 2× RTX 3060s (~$0.09/h) to properly validate this. I will update the PR with a multi-GPU test script and results as soon as I get positive outcomes.
The doc says to only apply it on the root FSDP instance. It's really tricky to make this work with FSDP.
force-pushed from 584f385 to e115ae5 (Compare)
The previous backward implementation called `torch.autograd.backward()` inside the tiling loop, triggering FSDP's post-backward hook (reshard) once per shard. This caused FSDP1 to reshard parameters mid-loop, leading to errors on subsequent shard iterations.

Fix: replace `torch.autograd.backward()` with `torch.autograd.grad()` inside the tiling loop. This computes gradients locally without accumulating into `.grad` or triggering any hooks. Param gradients are accumulated manually across shards and written to `.grad` exactly once after the loop — FSDP sees a single gradient event, as expected.

This fix is FSDP-agnostic: `LigerTiledSwiGLUMLP` requires no knowledge of FSDP. Verified with FSDP1 (`FullyShardedDataParallel`) and FSDP2 (`fully_shard`) on 2x RTX 3060.

- FSDP1: previously errored, now passes
- FSDP2: passes
- Non-distributed: unaffected
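As a self-contained illustration of the pattern this commit message describes (my sketch, not the PR diff — single process, no FSDP, and a plain `nn.Linear` standing in for the MLP; the shard split mirrors the tiling loop):

```python
import torch

lin = torch.nn.Linear(4, 4, bias=False)
x = torch.randn(8, 4)
shards = x.split(2, dim=0)  # tile the batch into 4 shards

# Accumulate parameter grads manually across shards...
acc = [torch.zeros_like(p) for p in lin.parameters()]
x_grads = []
for shard in shards:
    shard = shard.detach().requires_grad_(True)
    out = lin(shard).sum()
    # torch.autograd.grad() returns the grads instead of accumulating into
    # .grad, so no post-accumulate-grad hooks (e.g. FSDP's reshard) fire.
    grads = torch.autograd.grad(out, [shard, *lin.parameters()])
    x_grads.append(grads[0])
    for a, g in zip(acc, grads[1:]):
        a.add_(g)

# ...and write to .grad exactly once after the loop: a single gradient event.
for p, a in zip(lin.parameters(), acc):
    p.grad = a

# Reference: one ordinary backward over the full batch yields the same grads.
ref = torch.nn.Linear(4, 4, bias=False)
with torch.no_grad():
    ref.weight.copy_(lin.weight)
x_ref = x.clone().requires_grad_(True)
ref(x_ref).sum().backward()
assert torch.allclose(lin.weight.grad, ref.weight.grad, atol=1e-6)
```

The key point is that `torch.autograd.grad()` never touches `.grad`, so FSDP's gradient-hook machinery stays silent during the inner loop.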
force-pushed from 2182143 to 8ea9c7d (Compare)
Hi @Tcc0403, you were right about the `summon_full_params` limitation. I propose two different potential solutions:
**The FSDP2 fix**

After analyzing the problem more deeply, I found it tricky to stop FSDP's gradient machinery from kicking in during the n forward-backward steps TiledMLP performs, since FSDP is external to this code. Researching, I found that PyTorch's newer FSDP implementation handles this case. Here is the code I ran on a 2x RTX 3060 cloud instance, which ran successfully:

```
(main) root@C.32454339:/workspace$ cat > test_fsdp2.py << 'EOF'
import os
from types import SimpleNamespace

import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard

from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP


def run():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    config = SimpleNamespace(hidden_size=32, intermediate_size=64)

    model = LigerTiledSwiGLUMLP(config, num_shards=4).cuda()
    fully_shard(model)

    x = torch.randn(4, 16, 32, device=f"cuda:{rank}", requires_grad=True)
    model(x).sum().backward()

    if rank == 0:
        print("OK — no error")
    dist.destroy_process_group()


if __name__ == "__main__":
    run()
EOF
(main) root@C.32454339:/workspace$ torchrun --nproc_per_node=2 test_fsdp2.py
W0306 16:19:01.108000 1700 site-packages/torch/distributed/run.py:852]
W0306 16:19:01.108000 1700 site-packages/torch/distributed/run.py:852] *****************************************
W0306 16:19:01.108000 1700 site-packages/torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0306 16:19:01.108000 1700 site-packages/torch/distributed/run.py:852] *****************************************
OK — no error
```

FSDP2's per-parameter DTensor sharding handles multiple backward passes through the same parameters correctly by design, unlike FSDP1's flat-parameter state machine.

Whereas if we use the current FSDP1 wrapper with the same functionality as above, I get an error:

```
(main) root@C.32454339:/workspace$ cat > test_fsdp1.py << 'EOF'
import os
from types import SimpleNamespace

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP


def run():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    config = SimpleNamespace(hidden_size=32, intermediate_size=64)

    model = FSDP(
        LigerTiledSwiGLUMLP(config, num_shards=4).cuda(),
        device_id=torch.device(f"cuda:{rank}"),
    )

    x = torch.randn(4, 16, 32, device=f"cuda:{rank}", requires_grad=True)
    model(x).sum().backward()

    if rank == 0:
        print("OK — no error")
    dist.destroy_process_group()


if __name__ == "__main__":
    run()
EOF
(main) root@C.32454339:/workspace$ torchrun --nproc_per_node=2 test_fsdp1.py
W0306 16:23:22.581000 1819 site-packages/torch/distributed/run.py:852]
... (OMP_NUM_THREADS warning, as above) ...
```
```
[rank1]: Traceback (most recent call last):
[rank1]:   File "/workspace/test_fsdp1.py", line 33, in <module>
[rank1]:     run()
[rank1]:   File "/workspace/test_fsdp1.py", line 24, in run
[rank1]:     model(x).sum().backward()
[rank1]:   File "/venv/main/lib/python3.12/site-packages/torch/_tensor.py", line 630, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/venv/main/lib/python3.12/site-packages/torch/autograd/__init__.py", line 364, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/venv/main/lib/python3.12/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:   File "/venv/main/lib/python3.12/site-packages/torch/autograd/function.py", line 317, in apply
[rank1]:     return user_fn(self, *args)
[rank1]:   File "/workspace/Liger-Kernel/src/liger_kernel/ops/utils.py", line 40, in wrapper
[rank1]:     return fn(ctx, *args, **kwargs)
[rank1]:   File "/workspace/Liger-Kernel/src/liger_kernel/ops/tiled_mlp.py", line 92, in backward
[rank1]:     output = fn(mlp_module, x_shard)
[rank1]:   File "/workspace/Liger-Kernel/src/liger_kernel/transformers/tiled_mlp.py", line 102, in _mlp_forward
[rank1]:     gate = module.gate_proj(x)
[rank1]:   File "/venv/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/venv/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/venv/main/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 134, in forward
[rank1]:     return F.linear(input, self.weight, self.bias)
[rank1]: RuntimeError: setStorage: sizes [32, 64], strides [1, 32], storage offset 0, and itemsize 4 requiring a storage size of 8192 are out of bounds for storage of size 0
[rank0]: (identical traceback, ending in the same RuntimeError)
```

**The FSDP1 fix**

I also want to propose a solution building on the currently proposed one. Please find these changes updated on the PR branch. I've tested this fix on the 2x RTX 3060 environment and found no problems, and I've benchmarked it against the old version of the tiled backward. With that, here is the passing test I mentioned for this fix. Importantly, this fix requires no knowledge of FSDP. Re-running the same `test_fsdp1.py` script as above, now on the PR branch:

```
(main) root@C.32454339:/workspace$ torchrun --nproc_per_node=2 test_fsdp1.py
W0306 18:13:26.722000 1029 site-packages/torch/distributed/run.py:852]
... (OMP_NUM_THREADS warning, as above) ...
OK — no error
```

Let me know what you think.
Sounds good. Let's try this out.
Hi @Tcc0403, I believe the changes are updated on the PR branch. This is the code I ran too:

**Successful code run on the PR branch**

```
root@C.32658063:/workspace/Liger-Kernel$ cat > test_fsdp1.py << 'EOF'
import os
from types import SimpleNamespace

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from liger_kernel.ops.tiled_mlp import apply_tiled_mlp
from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP

ATOL = 1e-4
RTOL = 1e-4


def run():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)
    config = SimpleNamespace(hidden_size=32, intermediate_size=64)

    # Shared input — identical on all ranks.
    # requires_grad=True is required: LigerTiledMLPFunction.forward() runs under
    # torch.no_grad(), so the output only has a grad_fn (and is differentiable)
    # when the input itself requires grad.
    torch.manual_seed(0)
    x_full = torch.randn(4, 16, 32, device=f"cuda:{rank}", requires_grad=True)

    # --- Reference: plain model on rank 0 only ---
    if rank == 0:
        torch.manual_seed(42)
        ref_model = LigerTiledSwiGLUMLP(config, num_shards=4).cuda()
        # detach+clone gives an independent leaf so ref backward graph is isolated
        x_ref = x_full.detach().clone().requires_grad_(True)
        ref_out = ref_model(x_ref)
        ref_out.sum().backward()
        ref_grads = {n: p.grad.clone() for n, p in ref_model.named_parameters()}
        ref_out = ref_out.detach()

    # --- FSDP model (same seed so weights are identical at init) ---
    torch.manual_seed(42)
    fsdp_model = FSDP(
        LigerTiledSwiGLUMLP(config, num_shards=4).cuda(),
        device_id=torch.device(f"cuda:{rank}"),
        use_orig_params=True,
    )

    # Each rank processes its own batch shard
    batch_per_rank = x_full.shape[0] // world_size
    x_local = x_full[rank * batch_per_rank:(rank + 1) * batch_per_rank].detach().clone().requires_grad_(True)
    fsdp_out = fsdp_model(x_local)
    # ... (remainder of the script — the forward/gradient comparison against
    # the rank-0 reference — was truncated in the pasted log)
EOF
root@C.32658063:/workspace/Liger-Kernel$ torchrun --nproc_per_node=2 test_fsdp1.py
W0310 23:10:49.691000 2370 site-packages/torch/distributed/run.py:852]
... (OMP_NUM_THREADS warning, as above) ...
OK — forward match (max_diff=0.00e+00)
OK — gradient match
```

**Bug-reproducing code run on the repo HEAD**

Please find below the run reproducing the bug at HEAD (the same script as above, run against the current main):

```
(main) root@C.32658063:/workspace/Liger-Kernel$ torchrun --nproc_per_node=2 test_fsdp1.py
W0310 23:13:58.381000 2921 site-packages/torch/distributed/run.py:852]
... (OMP_NUM_THREADS warning, as above) ...
```
```
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/Liger-Kernel/test_fsdp1.py", line 88, in <module>
[rank0]:     run()
[rank0]:   File "/workspace/Liger-Kernel/test_fsdp1.py", line 54, in run
[rank0]:     fsdp_out.sum().backward()
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/_tensor.py", line 630, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/autograd/__init__.py", line 364, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/autograd/function.py", line 317, in apply
[rank0]:     return user_fn(self, *args)
[rank0]:   File "/workspace/Liger-Kernel/src/liger_kernel/ops/utils.py", line 40, in wrapper
[rank0]:     return fn(ctx, *args, **kwargs)
[rank0]:   File "/workspace/Liger-Kernel/src/liger_kernel/ops/tiled_mlp.py", line 92, in backward
[rank0]:     output = fn(mlp_module, x_shard)
[rank0]:   File "/workspace/Liger-Kernel/src/liger_kernel/transformers/tiled_mlp.py", line 102, in _mlp_forward
[rank0]:     gate = module.gate_proj(x)
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 134, in forward
[rank0]:     return F.linear(input, self.weight, self.bias)
[rank0]: RuntimeError: size mismatch, got input (8), mat (8x32), vec (2048)
W0310 23:14:03.576000 2921 site-packages/torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 2985 closing signal SIGTERM
E0310 23:14:03.741000 2921 site-packages/torch/distributed/elastic/multiprocessing/api.py:984] failed (exitcode: 1) local_rank: 0 (pid: 2984) of binary: /venv/main/bin/python3
```

What other tests come to mind to make sure this doesn't break anything and behaves as expected? Should I integrate the script above into the test suite? The only caveat is that it will only work on an instance with 2+ GPUs.
Add a test similar to
Hi @Tcc0403, I added a test for FSDP-TiledMLP compatibility and ran it on the 2x RTX 3060 environment without issues:

```
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype0-1e-05-1e-05-2-256-512-1-2] PASSED [ 81%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype0-1e-05-1e-05-2-256-512-2-2] PASSED [ 82%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype0-1e-05-1e-05-2-256-512-4-2] PASSED [ 83%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype0-1e-05-1e-05-2-512-1024-1-2] PASSED [ 84%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype0-1e-05-1e-05-2-512-1024-2-2] PASSED [ 85%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype0-1e-05-1e-05-2-512-1024-4-2] PASSED [ 86%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype0-1e-05-1e-05-1-128-256-1-2] PASSED [ 87%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype0-1e-05-1e-05-1-128-256-2-2] PASSED [ 88%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype0-1e-05-1e-05-1-128-256-4-2] PASSED [ 90%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype1-0.1-0.1-2-256-512-1-2] PASSED [ 91%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype1-0.1-0.1-2-256-512-2-2] PASSED [ 92%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype1-0.1-0.1-2-256-512-4-2] PASSED [ 93%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype1-0.1-0.1-2-512-1024-1-2] PASSED [ 94%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype1-0.1-0.1-2-512-1024-2-2] PASSED [ 95%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype1-0.1-0.1-2-512-1024-4-2] PASSED [ 96%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype1-0.1-0.1-1-128-256-1-2] PASSED [ 97%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype1-0.1-0.1-1-128-256-2-2] PASSED [ 98%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype1-0.1-0.1-1-128-256-4-2] PASSED [100%]
```

One question: is adding world sizes of 2, 4, and 8 overkill for the tests? I've left the combinations at just world size 2:

```python
@pytest.mark.parametrize("world_size", [2])  # extend to [2, 4, 8] on multi-GPU hosts
```

The test was added to `test/transformers/test_tiled_mlp.py`. Let me know if it would be better to also test on 2-, 4-, and 8-GPU instances, skipped when the setup doesn't have enough GPUs.
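A hypothetical sketch of that skip-when-unavailable approach (the 2/4/8 values come from the discussion above; the test name and skip logic here are illustrative, not the PR's actual test):

```python
import pytest

try:
    import torch
    NUM_GPUS = torch.cuda.device_count()
except ImportError:  # keep the sketch importable without torch
    NUM_GPUS = 0


@pytest.mark.parametrize("world_size", [2, 4, 8])
def test_fsdp_tiled_swiglu_world_size(world_size):
    # Skip instead of failing on hosts with fewer GPUs than the test needs.
    if NUM_GPUS < world_size:
        pytest.skip(f"requires {world_size} GPUs, found {NUM_GPUS}")
    # ... launch the distributed FSDP test here ...
```

This keeps the larger world sizes in the matrix while staying green on single- and dual-GPU CI runners.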
Summary
Add FSDP compatibility to `TiledMLP` by preventing premature parameter resharding during the tiled backward recompute loop.

The fix introduces a small helper `_get_fsdp_ctx` in `utils.py` that returns `FSDP.summon_full_params(writeback=True)` when the module is FSDP-wrapped (or a no-op `nullcontext` otherwise). This context wraps the entire tiled backward loop, ensuring parameters remain unsharded across all tile iterations. Without it, the first inner backward triggers FSDP's post-backward hook → resharding → subsequent tiles see only local shards, causing state mismatch, runtime errors, or silently corrupted gradients.

This resolves the known FSDP incompatibility described in #893 and #935.
No behavior change or overhead when not using FSDP (DDP, single-GPU, etc.).
Details
- `src/liger_kernel/ops/utils.py` → `_get_fsdp_ctx`
- `src/liger_kernel/ops/tiled_mlp.py` → inside `LigerTiledMLPFunction.backward`

Testing Done
Quick manual verification I did locally: a forward + backward pass through `LigerTiledSwiGLUMLP`. This confirms the fix runs without error on a single-GPU setup. Full correctness validation (grad comparison, multi-GPU sharding) will be covered in CI.
Hardware Type: RTX 5060 Ti (local single-GPU testing)
run make test to ensure correctness
run make checkstyle to ensure code style
run make test-convergence to ensure convergence
Manually verified that gradients match a non-tiled reference backward pass when using FSDP with world_size=1 (NO_SHARD fallback).
Convergence testing on A100/H100 with real multi-GPU FSDP will be done in CI or by reviewers.
Additional manual verification:
Single-GPU (world_size=1): forward + backward completes without errors
Simulated FSDP behavior: gradients accumulate correctly across tiles, match reference non-tiled case (within fp tolerance)
No OOM or performance regression observed during tiled backward loop
Happy to add multi-GPU CI test coverage or adjust based on feedback.