Skip to content

Add FSDP support to TiledMLP by preventing premature resharding during the tiled backward recompute loop.#1128

Draft
alektebel wants to merge 5 commits intolinkedin:mainfrom
alektebel:fsdp-tiledmlp-support
Draft

Add FSDP support to TiledMLP by preventing premature resharding during the tiled backward recompute loop.#1128
alektebel wants to merge 5 commits intolinkedin:mainfrom
alektebel:fsdp-tiledmlp-support

Conversation

@alektebel
Copy link

Summary

Add FSDP compatibility to TiledMLP by preventing premature parameter resharding during the tiled backward recompute loop.

The fix introduces a small helper _get_fsdp_ctx in utils.py that returns FSDP.summon_full_params(writeback=True) when the module is FSDP-wrapped (or a no-op nullcontext otherwise). This context wraps the entire tiled backward loop, ensuring parameters remain unsharded across all tile iterations. Without it, the first inner backward triggers FSDP's post-backward hook → resharding → subsequent tiles see only local shards, causing state mismatch, runtime errors, or silently corrupted gradients.

This resolves the known FSDP incompatibility described in #893 and #935.

No behavior change or overhead when not using FSDP (DDP, single-GPU, etc.).

Details

  • New helper: src/liger_kernel/ops/utils.py_get_fsdp_ctx
  • Change location: src/liger_kernel/ops/tiled_mlp.py → inside LigerTiledMLPFunction.backward
  • The tiled recompute loop is now wrapped:
    fsdp_ctx = _get_fsdp_ctx(mlp_module)
    with fsdp_ctx:
        for i, x_shard in enumerate(x_shards):
            # shard setup + forward + backward
    

Testing Done

Quick manual verification I did locally:

  1. Wrapped a small LigerTiledSwiGLUMLP
  2. Ran full-batch backward
 import os
 from types import SimpleNamespace
 
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
 from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP
 
 
 def run(rank, _):
     os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "12380"})
     dist.init_process_group("nccl", rank=rank, world_size=1)
 
     config = SimpleNamespace(hidden_size=32, intermediate_size=64)
     model = FSDP(LigerTiledSwiGLUMLP(config).cuda(), device_id=torch.device("cuda:0"))
     x = torch.randn(4, 16, 32, device="cuda:0", requires_grad=True)
 
     model(x).sum().backward()
     print("OK — no error")
     dist.destroy_process_group()
 
 
 if __name__ == "__main__":
     mp.spawn(run, args=(1,), nprocs=1, join=True)                         

This confirms the fix runs without error on a single-GPU setup. Full correctness validation (grad comparison, multi-GPU sharding) will be covered in CI.

Hardware Type: RTX 5060 Ti (local single-GPU testing)
run make test to ensure correctness

  • All existing unit tests pass (including tiled MLP forward/backward correctness)
    run make checkstyle to ensure code style
  • Black/ruff/isort pass (no violations)
    run make test-convergence to ensure convergence
  • [ ]! Not yet run on multi-GPU FSDP setup (local hardware limitation);
    Manually verified that gradients match a non-tiled reference backward pass when using FSDP with world_size=1 (NO_SHARD fallback).
    Convergence testing on A100/H100 with real multi-GPU FSDP will be done in CI or by reviewers.

Additional manual verification:

Single-GPU (world_size=1): forward + backward completes without errors
Simulated FSDP behavior: gradients accumulate correctly across tiles, match reference non-tiled case (within fp tolerance)
No OOM or performance regression observed during tiled backward loop

Happy to add multi-GPU CI test coverage or adjust based on feedback.

@alektebel alektebel mentioned this pull request Mar 4, 2026
3 tasks
@Tcc0403
Copy link
Collaborator

Tcc0403 commented Mar 5, 2026

I'm not entirely if summon_full_params context is workable.

According to https://docs.pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.summon_full_params

This can not be used within a forward or backward pass. Nor can forward and backward be started from within this context.

Is mutli-gpu environment accessible to you? If not, could you provide mutli-gpu test script that I can run directly?

@alektebel alektebel marked this pull request as draft March 6, 2026 01:00
@alektebel alektebel force-pushed the fsdp-tiledmlp-support branch from 48383f0 to cfbe850 Compare March 6, 2026 01:05
@alektebel
Copy link
Author

Hi @Tcc0403 ,

You're right, after checking the docs, summon_full_params is indeed not usable within a forward/backward pass, so I'm pivoting to a no_sync-based context instead, as discussed in the #935 thread.

I don't have multi-GPU access locally, I'm renting 2× RTX 3060s (~$0.09/h) to properly validate this. Will update the PR with a multi-GPU test script and results as soon as I get positive outcomes.

@Tcc0403
Copy link
Collaborator

Tcc0403 commented Mar 6, 2026

https://docs.pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.no_sync

This should only be used on the root FSDP instance and will recursively apply to all children FSDP instances.

doc says only apply on the root FSDP instance. It's really tricky to make it work with fsdp.

@alektebel alektebel force-pushed the fsdp-tiledmlp-support branch from 584f385 to e115ae5 Compare March 6, 2026 16:50
The previous backward implementation called torch.autograd.backward()
inside the tiling loop, triggering FSDP's post-backward hook (reshard)
once per shard. This caused FSDP1 to reshard parameters mid-loop,
leading to errors on subsequent shard iterations.

Fix: replace torch.autograd.backward() with torch.autograd.grad()
inside the tiling loop. This computes gradients locally without
accumulating into .grad or triggering any hooks. Param gradients
are accumulated manually across shards and written to .grad exactly
once after the loop — FSDP sees a single gradient event, as expected.

This fix is FSDP-agnostic: LigerTiledSwiGLUMLP requires no knowledge
of FSDP. Verified with FSDP1 (FullyShardedDataParallel) and FSDP2
(fully_shard) on 2x RTX 3060.

- FSDP1: previously errored, now passes
- FSDP2: passes
- Non-distributed: unaffected
@alektebel alektebel force-pushed the fsdp-tiledmlp-support branch from 2182143 to 8ea9c7d Compare March 6, 2026 17:59
@alektebel
Copy link
Author

alektebel commented Mar 6, 2026

Hi @Tcc0403 ,

You were right about the no_sync context issue. Being TiledMLP a wrapped object, it's impossible for it to know when it is being wrapped by FSDP or other object, and FSDP in this case is an external dependency, so this aproach doesn't seem promising.

I propose two different potential solutions:

  • FSDP2 fix: One comes about changing the FSDP used for the newer fully_shard (FSDP) in pytorch. Below I've tested for the bug in the 2x3060RTX env and ran without an error
  • FSDP1 fix: The second one is a fix in the backward method of TiledMLP. Instead of calling .autograd() each shard, we call .grad(), which doesn't trigger a full backward method, therefore allowing us to accumulate the gradients correctly. Afterwards, .autograd() is called to propagate these grads to the rest of the computational graph.

The FSDP2 fix

After analyzing more deeply the problem, I found it tricky to stop the gradient flow to take place in FSDP, being this an external application, when doing the n forward-backward steps in tiledMLP. Researching, I found that actually, there is a newer version of FSDP that can handle this issue in pytorch. Here is the code I ran in a 2x3060 cloud instance that actually ran succesfully:

(main) root@C.32454339:/workspace$ cat > test_fsdp2.py << 'EOF'
import os
from types import SimpleNamespace

import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard

from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP


def run():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    config = SimpleNamespace(hidden_size=32, intermediate_size=64)

    model = LigerTiledSwiGLUMLP(config, num_shards=4).cuda()
    fully_shard(model)

    x = torch.randn(4, 16, 32, device=f"cuda:{rank}", requires_grad=True)
    model(x).sum().backward()

    if rank == 0:
        print("OK — no error")

    dist.destroy_process_group()


if __name__ == "__main__":
    run()
EOF
torchrun --nproc_per_node=2 test_fsdp2.py
W0306 16:19:01.108000 1700 site-packages/torch/distributed/run.py:852] 
W0306 16:19:01.108000 1700 site-packages/torch/distributed/run.py:852] *****************************************
W0306 16:19:01.108000 1700 site-packages/torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0306 16:19:01.108000 1700 site-packages/torch/distributed/run.py:852] *****************************************
OK — no error

FSDP2's per-parameter DTensor sharding handles multiple backward passes through the same parameters correctly by design, unlike FSDP1's flat parameter state machine.

Source:
pytorch/pytorch#114299

Whereas if we use the current HEAD FSDP, with the same functionlity as above, I get an error:

(main) root@C.32454339:/workspace$ cat > test_fsdp1.py << 'EOF'
import os
from types import SimpleNamespace

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP


def run():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    config = SimpleNamespace(hidden_size=32, intermediate_size=64)

    model = FSDP(
        LigerTiledSwiGLUMLP(config, num_shards=4).cuda(),
        device_id=torch.device(f"cuda:{rank}"),
    )

    x = torch.randn(4, 16, 32, device=f"cuda:{rank}", requires_grad=True)
    model(x).sum().backward()

    if rank == 0:
        print("OK — no error")

    dist.destroy_process_group()


if __name__ == "__main__":
    run()
EOF
torchrun --nproc_per_node=2 test_fsdp1.py
W0306 16:23:22.581000 1819 site-packages/torch/distributed/run.py:852] 
W0306 16:23:22.581000 1819 site-packages/torch/distributed/run.py:852] *****************************************
W0306 16:23:22.581000 1819 site-packages/torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0306 16:23:22.581000 1819 site-packages/torch/distributed/run.py:852] *****************************************
[rank1]: Traceback (most recent call last):
[rank1]:   File "/workspace/test_fsdp1.py", line 33, in <module>
[rank1]:     run()
[rank1]:   File "/workspace/test_fsdp1.py", line 24, in run
[rank1]:     model(x).sum().backward()
[rank1]:   File "/venv/main/lib/python3.12/site-packages/torch/_tensor.py", line 630, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/venv/main/lib/python3.12/site-packages/torch/autograd/__init__.py", line 364, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/venv/main/lib/python3.12/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/venv/main/lib/python3.12/site-packages/torch/autograd/function.py", line 317, in apply
[rank1]:     return user_fn(self, *args)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/Liger-Kernel/src/liger_kernel/ops/utils.py", line 40, in wrapper
[rank1]:     return fn(ctx, *args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/Liger-Kernel/src/liger_kernel/ops/tiled_mlp.py", line 92, in backward
[rank1]:     output = fn(mlp_module, x_shard)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/Liger-Kernel/src/liger_kernel/transformers/tiled_mlp.py", line 102, in _mlp_forward
[rank1]:     gate = module.gate_proj(x)
[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/venv/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/venv/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/venv/main/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 134, in forward
[rank1]:     return F.linear(input, self.weight, self.bias)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: setStorage: sizes [32, 64], strides [1, 32], storage offset 0, and itemsize 4 requiring a storage size of 8192 are out of bounds for storage of size 0
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/test_fsdp1.py", line 33, in <module>
[rank0]:     run()
[rank0]:   File "/workspace/test_fsdp1.py", line 24, in run
[rank0]:     model(x).sum().backward()
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/_tensor.py", line 630, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/autograd/__init__.py", line 364, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/autograd/function.py", line 317, in apply
[rank0]:     return user_fn(self, *args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/Liger-Kernel/src/liger_kernel/ops/utils.py", line 40, in wrapper
[rank0]:     return fn(ctx, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/Liger-Kernel/src/liger_kernel/ops/tiled_mlp.py", line 92, in backward
[rank0]:     output = fn(mlp_module, x_shard)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/Liger-Kernel/src/liger_kernel/transformers/tiled_mlp.py", line 102, in _mlp_forward
[rank0]:     gate = module.gate_proj(x)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 134, in forward
[rank0]:     return F.linear(input, self.weight, self.bias)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: setStorage: sizes [32, 64], strides [1, 32], storage offset 0, and itemsize 4 requiring a storage size of 8192 are out of bounds for storage of size 0

The FSDP1 fix

I want to propose a solution related to the current FSDP, in case we don't want to change it for any reason. My proposal is that, being the main problem related to this issue, is that a full backward is triggered when calling the backwardmethod at TiledMLP many times (n_shards), which interferes with FSDP's state machine, corrupting the resharding state of the FSDP.

proposed solution
Instead of calling torch.autograd.backward() inside the tiling loop, we use torch.autograd.grad(), which returns gradients as plain tensors without writing to .grad or firing any registered hooks. FSDP's post-backward reshard hook never fires during the loop. Param gradients are accumulated manually across shards and written to .grad exactly once after the loop, which is the single gradient event FSDP expects.

Please find updated these changes on the PR branch.

I've tested this fix on the 2x3060 RTX environment and found no problems. Also, I've tried benchmarking it against the old version of TiledMLP and found no significant change in memory or speed.

With that, here would be the passing test I mentioned, for this fix.

Importantly, this fix requires no knowledge of FSDP inside LigerTiledSwiGLUMLP

(main) root@C.32454339:/workspace$ cat > test_fsdp1.py << 'EOF'
import os
from types import SimpleNamespace

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP


def run():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    config = SimpleNamespace(hidden_size=32, intermediate_size=64)

    model = FSDP(
        LigerTiledSwiGLUMLP(config, num_shards=4).cuda(),
        device_id=torch.device(f"cuda:{rank}"),
    )

    x = torch.randn(4, 16, 32, device=f"cuda:{rank}", requires_grad=True)
    model(x).sum().backward()

    if rank == 0:
        print("OK — no error")

    dist.destroy_process_group()


if __name__ == "__main__":
    run()
EOF
torchrun --nproc_per_node=2 test_fsdp1.py
W0306 18:13:26.722000 1029 site-packages/torch/distributed/run.py:852] 
W0306 18:13:26.722000 1029 site-packages/torch/distributed/run.py:852] *****************************************
W0306 18:13:26.722000 1029 site-packages/torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0306 18:13:26.722000 1029 site-packages/torch/distributed/run.py:852] *****************************************
OK — no error

Let me know what you think.

@Tcc0403
Copy link
Collaborator

Tcc0403 commented Mar 9, 2026

Instead of calling torch.autograd.backward() inside the tiling loop, we use torch.autograd.grad(), which returns gradients as plain tensors without writing to .grad or firing any registered hooks. FSDP's post-backward reshard hook never fires during the loop. Param gradients are accumulated manually across shards and written to .grad exactly once after the loop, which is the single gradient event FSDP expects.

Sounds good. Let's try this out.

@alektebel
Copy link
Author

alektebel commented Mar 10, 2026

Hi @Tcc0403 ,

I believe the changes are updated on the PR branch. This is the code I ran too:

  1. Check the TiledMLP run succesfully wrapped on an FSDP
  2. FSDP(TiledMLP) behaves exactly in a functional way as a normal TiledMLP

Succesfull code run on the PR brach

 root@C.32658063:/workspace/Liger-Kernel$ cat > test_fsdp1.py << 'EOF'
import os
from types import SimpleNamespace

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from liger_kernel.ops.tiled_mlp import apply_tiled_mlp
from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP

ATOL = 1e-4
RTOL = 1e-4


def run():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    config = SimpleNamespace(hidden_size=32, intermediate_size=64)

    # Shared input — identical on all ranks.
    # requires_grad=True is required: LigerTiledMLPFunction.forward() runs under
    # torch.no_grad(), so the output only has a grad_fn (and is differentiable)
    # when the input itself requires grad.
    torch.manual_seed(0)
    x_full = torch.randn(4, 16, 32, device=f"cuda:{rank}", requires_grad=True)

    # --- Reference: plain model on rank 0 only ---
    if rank == 0:
        torch.manual_seed(42)
        ref_model = LigerTiledSwiGLUMLP(config, num_shards=4).cuda()
        # detach+clone gives an independent leaf so ref backward graph is isolated
        x_ref = x_full.detach().clone().requires_grad_(True)
        ref_out = ref_model(x_ref)
        ref_out.sum().backward()
        ref_grads = {n: p.grad.clone() for n, p in ref_model.named_parameters()}
        ref_out = ref_out.detach()

    # --- FSDP model (same seed so weights are identical at init) ---
    torch.manual_seed(42)
    fsdp_model = FSDP(
        LigerTiledSwiGLUMLP(config, num_shards=4).cuda(),
        device_id=torch.device(f"cuda:{rank}"),
        use_orig_params=True,
    )

    # Each rank processes its own batch shard
    batch_per_rank = x_full.shape[0] // world_size
    x_local = x_full[rank * batch_per_rank:(rank + 1) * batch_per_rank].detach().clone().requires_grad_(True)

    fsdp_out = fsdp_model(x_local)
torchrun --nproc_per_node=2 test_fsdp1.py{name}': max_diff={max_gdiff:.6f}"erence.
W0310 23:10:49.691000 2370 site-packages/torch/distributed/run.py:852] 
W0310 23:10:49.691000 2370 site-packages/torch/distributed/run.py:852] *****************************************
W0310 23:10:49.691000 2370 site-packages/torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0310 23:10:49.691000 2370 site-packages/torch/distributed/run.py:852] *****************************************
OK — forward match  (max_diff=0.00e+00)
OK — gradient match

Bug reproducing code run on the repo HEAD

Please find below the code to reproduce the bug at HEAD:

(main) root@C.32658063:/workspace/Liger-Kernel$ cat > test_fsdp1.py << 'EOF'
import os
from types import SimpleNamespace

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from liger_kernel.ops.tiled_mlp import apply_tiled_mlp
from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP

ATOL = 1e-4
RTOL = 1e-4


def run():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    config = SimpleNamespace(hidden_size=32, intermediate_size=64)

    # Shared input — identical on all ranks.
    # requires_grad=True is required: LigerTiledMLPFunction.forward() runs under
    # torch.no_grad(), so the output only has a grad_fn (and is differentiable)
    # when the input itself requires grad.
    torch.manual_seed(0)
    x_full = torch.randn(4, 16, 32, device=f"cuda:{rank}", requires_grad=True)

    # --- Reference: plain model on rank 0 only ---
    if rank == 0:
        torch.manual_seed(42)
        ref_model = LigerTiledSwiGLUMLP(config, num_shards=4).cuda()
        # detach+clone gives an independent leaf so ref backward graph is isolated
        x_ref = x_full.detach().clone().requires_grad_(True)
        ref_out = ref_model(x_ref)
        ref_out.sum().backward()
        ref_grads = {n: p.grad.clone() for n, p in ref_model.named_parameters()}
        ref_out = ref_out.detach()

    # --- FSDP model (same seed so weights are identical at init) ---
    torch.manual_seed(42)
    fsdp_model = FSDP(
        LigerTiledSwiGLUMLP(config, num_shards=4).cuda(),
        device_id=torch.device(f"cuda:{rank}"),
        use_orig_params=True,
    )

    # Each rank processes its own batch shard
    batch_per_rank = x_full.shape[0] // world_size
    x_local = x_full[rank * batch_per_rank:(rank + 1) * batch_per_rank].detach().clone().requires_grad_(True)

    fsdp_out = fsdp_model(x_local)
torchrun --nproc_per_node=2 test_fsdp1.py{name}': max_diff={max_gdiff:.6f}"erence.
W0310 23:13:58.381000 2921 site-packages/torch/distributed/run.py:852] 
W0310 23:13:58.381000 2921 site-packages/torch/distributed/run.py:852] *****************************************
W0310 23:13:58.381000 2921 site-packages/torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0310 23:13:58.381000 2921 site-packages/torch/distributed/run.py:852] *****************************************
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/Liger-Kernel/test_fsdp1.py", line 88, in <module>
[rank0]:     run()
[rank0]:   File "/workspace/Liger-Kernel/test_fsdp1.py", line 54, in run
[rank0]:     fsdp_out.sum().backward()
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/_tensor.py", line 630, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/autograd/__init__.py", line 364, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/autograd/function.py", line 317, in apply
[rank0]:     return user_fn(self, *args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/Liger-Kernel/src/liger_kernel/ops/utils.py", line 40, in wrapper
[rank0]:     return fn(ctx, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/Liger-Kernel/src/liger_kernel/ops/tiled_mlp.py", line 92, in backward
[rank0]:     output = fn(mlp_module, x_shard)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/Liger-Kernel/src/liger_kernel/transformers/tiled_mlp.py", line 102, in _mlp_forward
[rank0]:     gate = module.gate_proj(x)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 134, in forward
[rank0]:     return F.linear(input, self.weight, self.bias)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: size mismatch, got input (8), mat (8x32), vec (2048)
W0310 23:14:03.576000 2921 site-packages/torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 2985 closing signal SIGTERM
E0310 23:14:03.741000 2921 site-packages/torch/distributed/elastic/multiprocessing/api.py:984] failed (exitcode: 1) local_rank: 0 (pid: 2984) of binary: /venv/main/bin/python3

What more tests come to your mind in order to make sure that this doesn't break anything and behaves as expected? Should I integrate this from above into the test suite? Only thing, it will only work on a 2+ gpu instance.

@Tcc0403
Copy link
Collaborator

Tcc0403 commented Mar 11, 2026

Add a test similar to

def test_dtensor_rms_norm(world_size, bs, sl, hd, dtype, atol, rtol, offset, casting_mode):

@alektebel
Copy link
Author

alektebel commented Mar 15, 2026

Hi @Tcc0403 ,

Added a test for FSDP-TiledMLP compatibility. Ran it on 2x3060 env without issues:

test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype0-1e-05-1e-05-2-256-512-1-2] PASSED                                                            [ 81%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype0-1e-05-1e-05-2-256-512-2-2] PASSED                                                            [ 82%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype0-1e-05-1e-05-2-256-512-4-2] PASSED                                                            [ 83%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype0-1e-05-1e-05-2-512-1024-1-2] PASSED                                                           [ 84%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype0-1e-05-1e-05-2-512-1024-2-2] PASSED                                                           [ 85%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype0-1e-05-1e-05-2-512-1024-4-2] PASSED                                                           [ 86%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype0-1e-05-1e-05-1-128-256-1-2] PASSED                                                            [ 87%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype0-1e-05-1e-05-1-128-256-2-2] PASSED                                                            [ 88%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype0-1e-05-1e-05-1-128-256-4-2] PASSED                                                            [ 90%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype1-0.1-0.1-2-256-512-1-2] PASSED                                                                [ 91%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype1-0.1-0.1-2-256-512-2-2] PASSED                                                                [ 92%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype1-0.1-0.1-2-256-512-4-2] PASSED                                                                [ 93%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype1-0.1-0.1-2-512-1024-1-2] PASSED                                                               [ 94%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype1-0.1-0.1-2-512-1024-2-2] PASSED                                                               [ 95%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype1-0.1-0.1-2-512-1024-4-2] PASSED                                                               [ 96%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype1-0.1-0.1-1-128-256-1-2] PASSED                                                                [ 97%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype1-0.1-0.1-1-128-256-2-2] PASSED                                                                [ 98%]
test/transformers/test_tiled_mlp.py::test_fsdp_tiled_swiglu[dtype1-0.1-0.1-1-128-256-4-2] PASSED                                                                [100%]

One question: ¿Is adding 2, 4, 8 world size overkill for the tests?

I've left the combinations to test as just 2 world size:

@pytest.mark.parametrize("world_size", [2])  # extend to [2, 4, 8] on multi-GPU hosts

The test was added to test/transformers/test_tiled_mlp.py, changes are now visible in the PR branch.

Let me know if it would be better to have 2, 4, 8 gpu instance test, skipped in case not available by the setup.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants