FSDP fails when loading FP8-quantized FLUX.2 models because the foreach_tensor_copy operation is not implemented for the Float8_e4m3fn dtype.
Environment
- Hardware: Multi-GPU 4x B60 system
- Docker Image: intel/llm-scaler-omni:0.1.0-b4
- Model: FLUX.2 Dev FP8 Mixed Precision (flux2_dev_fp8mixed.safetensors)
- Distributed Framework: Ray + XFuser + FSDP
Reproduction Steps
- Load the image_flux2_multi_xpu.json workflow into ComfyUI
- Download the model dependencies
```shell
BASE_DIR="/llm/ComfyUI/models"
wget https://huggingface.co/Comfy-Org/flux2-dev/resolve/main/split_files/diffusion_models/flux2_dev_fp8mixed.safetensors -O $BASE_DIR/diffusion_models/flux2_dev_fp8mixed.safetensors
wget https://huggingface.co/Comfy-Org/flux2-dev/resolve/main/split_files/text_encoders/mistral_3_small_flux2_bf16.safetensors -O $BASE_DIR/text_encoders/mistral_3_small_flux2_bf16.safetensors
wget https://huggingface.co/ostris/flux2_berthe_morisot/resolve/main/flux2_berthe_morisot.safetensors -O $BASE_DIR/loras/flux2_berthe_morisot.safetensors
wget https://huggingface.co/Comfy-Org/flux2-dev/resolve/main/split_files/vae/flux2-vae.safetensors -O $BASE_DIR/vae/flux2-vae.safetensors
```
- Set FSDP=True and run the workflow
Error Traceback
```
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1827, in inner
result = forward_call(*args, **kwargs)
File "/llm/ComfyUI/comfy/ldm/flux/model.py", line 288, in forward
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
File "/llm/ComfyUI/comfy/patcher_extension.py", line 112, in execute
return self.original(*args, **kwargs)
File "/llm/ComfyUI/comfy/ldm/flux/model.py", line 338, in _forward
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
File "/tmp/ray/session_2025-12-19_17-18-54_412269_13375/runtime_resources/py_modules_files/_ray_pkg_8450ce77068aa83f/raylight/diffusion_models/flux/xdit_context_parallel.py", line 167, in usp_dit_forward
img, txt = block(img=img,
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1879, in _call_impl
return inner()
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1806, in inner
args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 62, in fsdp_hook_wrapper
return torch._dynamo.disable(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 248, in _pre_forward
args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 351, in pre_forward
self.unshard(self.unshard_async_op)
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 279, in unshard
self._all_gather_result = foreach_all_gather(
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py", line 190, in foreach_all_gather
all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(
File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1243, in __call__
return self._op(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py", line 115, in all_gather_copy_in_cuda
torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs)
NotImplementedError: "foreach_tensor_copy" not implemented for 'Float8_e4m3fn'
```
