Skip to content

Conversation

@TmacAaron
Copy link
Contributor

@TmacAaron TmacAaron commented Jan 7, 2026

What does this PR do?

The original npu attention backend in diffusers does not support ulysses parallel yet. This PR is to implement the ulysses parallel attention for npu attention backend.

Note: Only implement forward op now, the backward op is not supported now.

Test

Hardware

Atlas 800T A2

Repro Code

import os
import time

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

from diffusers import FluxPipeline, ContextParallelConfig


def launched_with_torchrun() -> bool:
    return (
        "RANK" in os.environ
        and "WORLD_SIZE" in os.environ
        and "LOCAL_RANK" in os.environ
    )

warm_up = True

model_id = "black-forest-labs/FLUX.1-dev"
height = 1024
width = 1024
steps = 50
prompt = "A cat holding a sign that says hello world"


try:
    if launched_with_torchrun():
        torch.distributed.init_process_group("nccl")
        rank = torch.distributed.get_rank()
        device = torch.device("cuda", rank % torch.cuda.device_count())
        world_size = torch.distributed.get_world_size()
    else:
        rank = 0
        device = torch.device("cuda")
        world_size = 1
    torch.cuda.set_device(device)

    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
    pipe.transformer.set_attention_backend("_native_npu")

    if launched_with_torchrun():
        print(f"{world_size=}")
        pipe.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=world_size))
    
    # Warm Up
    if warm_up:
        _ = pipe(
            prompt,
            height=1024,
            width=1024,
            guidance_scale=3.5,
            num_inference_steps=2,
            max_sequence_length=512,
            generator=torch.Generator("cpu").manual_seed(0)
        ).images[0]

    # Inference
    start_time = time.time()
    image = pipe(
        prompt,
        height=height,
        width=width,
        guidance_scale=3.5,
        num_inference_steps=steps,
        max_sequence_length=512,
        generator=torch.Generator("cpu").manual_seed(0)
    ).images[0]
    end_time = time.time()

    if rank == 0:
        print(f"Inference Time: {end_time - start_time} s")
        image.save(f"sp{world_size}-flux-dev.png")

except Exception as e:
    print(f"An error occurred: {e}")
    raise e

finally:
    if launched_with_torchrun():
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

Result

1. no ulysses attention

Command

python ./flux_infer.py

Inference Time:
24.44s

Result Image:
sp1-flux-dev

2. ulysses attention with degree 2

Command:

torchrun --nproc_per_node=2 ./flux_infer.py

Inference Time:
18.83s

Result Image:
sp2-flux-dev

3. ulysses attention with degree 4

Command:

torchrun --nproc_per_node=4 ./flux_infer.py

Inference Time:
12.69s

Result Image:
sp4-flux-dev

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@TmacAaron
Copy link
Contributor Author

TmacAaron commented Jan 7, 2026

@yiyixuxu @sayakpaul Hello, could you please review this pr, tks.

@sayakpaul
Copy link
Member

@bot /style

@github-actions
Copy link
Contributor

github-actions bot commented Jan 8, 2026

Style fix is beginning .... View the workflow run here.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@TmacAaron
Copy link
Contributor Author

@bot /style

can you take a look again?

@sayakpaul sayakpaul merged commit be38f41 into huggingface:main Jan 9, 2026
10 of 11 checks passed
@sayakpaul
Copy link
Member

Thanks for your contribution!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants