
FSDP Fails with FP8 FLUX.2 Model #201

@munikera

Description

FSDP fails when using FP8-quantized FLUX.2 models because the `foreach_tensor_copy` kernel, which FSDP's all-gather reaches through `torch._foreach_copy_`, is not implemented for the `Float8_e4m3fn` dtype.

Environment

  • Hardware: Multi-GPU 4x B60 system
  • Docker Image: intel/llm-scaler-omni:0.1.0-b4
  • Model: FLUX.2 Dev FP8 Mixed Precision (flux2_dev_fp8mixed.safetensors)
  • Distributed Framework: Ray + XFuser + FSDP

Reproduction Steps

  1. Load the image_flux2_multi_xpu.json workflow into ComfyUI.
  2. Download the model dependencies:

```shell
BASE_DIR="/llm/ComfyUI/models"

wget https://huggingface.co/Comfy-Org/flux2-dev/resolve/main/split_files/diffusion_models/flux2_dev_fp8mixed.safetensors -O "$BASE_DIR/diffusion_models/flux2_dev_fp8mixed.safetensors"
wget https://huggingface.co/Comfy-Org/flux2-dev/resolve/main/split_files/text_encoders/mistral_3_small_flux2_bf16.safetensors -O "$BASE_DIR/text_encoders/mistral_3_small_flux2_bf16.safetensors"
wget https://huggingface.co/ostris/flux2_berthe_morisot/resolve/main/flux2_berthe_morisot.safetensors -O "$BASE_DIR/loras/flux2_berthe_morisot.safetensors"
wget https://huggingface.co/Comfy-Org/flux2-dev/resolve/main/split_files/vae/flux2-vae.safetensors -O "$BASE_DIR/vae/flux2-vae.safetensors"
```

  3. Set FSDP=True and run the workflow.
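Before running the workflow, it can help to confirm that all four downloads landed where the workflow expects them, since `wget -O` can leave an empty file behind when a download fails. A minimal sketch; `missing_model_files` is a hypothetical helper, not part of ComfyUI, and the relative paths simply mirror the `wget` destinations in step 2.

```python
from pathlib import Path

# Destination paths from the wget commands, relative to BASE_DIR.
REQUIRED_FILES = [
    "diffusion_models/flux2_dev_fp8mixed.safetensors",
    "text_encoders/mistral_3_small_flux2_bf16.safetensors",
    "loras/flux2_berthe_morisot.safetensors",
    "vae/flux2-vae.safetensors",
]


def missing_model_files(base_dir: str = "/llm/ComfyUI/models") -> list[str]:
    """Return the required files that are absent or empty under base_dir."""
    base = Path(base_dir)
    missing = []
    for rel in REQUIRED_FILES:
        path = base / rel
        # Treat a zero-byte file as missing: it usually indicates a failed download.
        if not path.is_file() or path.stat().st_size == 0:
            missing.append(rel)
    return missing


if __name__ == "__main__":
    gaps = missing_model_files()
    print("all model files present" if not gaps else f"missing: {gaps}")
```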

Error Traceback

```text
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1827, in inner
    result = forward_call(*args, **kwargs)
  File "/llm/ComfyUI/comfy/ldm/flux/model.py", line 288, in forward
    return comfy.patcher_extension.WrapperExecutor.new_class_executor(
  File "/llm/ComfyUI/comfy/patcher_extension.py", line 112, in execute
    return self.original(*args, **kwargs)
  File "/llm/ComfyUI/comfy/ldm/flux/model.py", line 338, in _forward
    out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
  File "/tmp/ray/session_2025-12-19_17-18-54_412269_13375/runtime_resources/py_modules_files/_ray_pkg_8450ce77068aa83f/raylight/diffusion_models/flux/xdit_context_parallel.py", line 167, in usp_dit_forward
    img, txt = block(img=img,
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1806, in inner
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 62, in fsdp_hook_wrapper
    return torch._dynamo.disable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 248, in _pre_forward
    args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 351, in pre_forward
    self.unshard(self.unshard_async_op)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 279, in unshard
    self._all_gather_result = foreach_all_gather(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py", line 190, in foreach_all_gather
    all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1243, in __call__
    return self._op(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py", line 115, in all_gather_copy_in_cuda
    torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs)
NotImplementedError: "foreach_tensor_copy" not implemented for 'Float8_e4m3fn'
```