DID loop split for SDPA #3711

Open · wants to merge 5 commits into main
Conversation

@Priya2698 (Collaborator) commented Jan 15, 2025

In this PR, I explicitly parallelize the outputs attn and log_sumexp of sdpfa_fwd. Sharding propagation for loop split does not currently handle this case correctly.
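As context, "explicitly parallelize" here means scheduling the sdpfa_fwd outputs by hand in multidevice_schedule rather than relying on propagation. A minimal sketch, extracted from the test added in this PR (d, mesh, and the tensors come from that test):

# Only the lines relevant to attn/log_sumexp; see the full test below.
for t in [self.attn, self.log_sumexp]:
    self.sched._set_device_mesh(t, mesh)
    self.sched.split(t, 1, d, False)  # loop-split the head dimension across devices
    self.sched.parallelize(t, 1, nvfuser.ParallelType.mesh_x)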

@Priya2698 changed the title from "DID loop split for sdpa forward" to "DID loop split for SDPA" on Jan 15, 2025
@Priya2698 (Collaborator, Author) commented:

!test

@Priya2698 Priya2698 marked this pull request as ready for review January 15, 2025 19:54
@Priya2698 Priya2698 requested a review from wujingyue January 15, 2025 20:50
self.define_tensor(
    shape=[b, h, s, e // h],
    dtype=DataType.BFloat16,
    stride_order=stride_order,
)
Collaborator (review comment on the define_tensor call above):

Is this needed? I believe it'll get overwritten by set_allocation_as_loop so we can simply start with the default stride_order.
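A minimal sketch of the suggested simplification, assuming the comment above is right that set_allocation_as_loop later overrides whatever stride_order implies (sketch only, not part of this PR's diff):

self.q, self.k, self.v, self.out_grad = [
    self.define_tensor(
        shape=[b, h, s, e // h],
        dtype=DataType.BFloat16,
        # No stride_order: start from the default layout; the allocation
        # domain is set later by set_allocation_as_loop in multidevice_schedule.
    )
    for _ in range(4)
]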

Two further review comments on tests/python/test_multidevice.py were marked outdated and resolved.
github-actions bot commented Jan 16, 2025

PR Reviewer Guide 🔍

(Review updated until commit 026229d)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
⚡ Recommended focus areas for review

Potential Logic Change

The introduction of the QkvFormat parameter in the Model class, and its use in the definition method, may change the expected behavior of the SDPA loop split. Reviewers should verify that the added logic for QkvFormat.BHSE and QkvFormat.BSHE handles the stride ordering and the allocation-domain reordering correctly.

@pytest.mark.skipif(
    utils.is_pre_ampere(),
    reason="Flash Attention is only supported on Ampere and newer devices.",
)
@pytest.mark.parametrize("qkv_format", [QkvFormat.BHSE, QkvFormat.BSHE])
@pytest.mark.mpi
def test_sdpa_loop_split(multidevice_test, qkv_format: QkvFormat):
    d, b, s, h, e = multidevice_test.size, 2, 1024, 12, 768

    if h % d != 0:
        pytest.skip(f"We only support even split, so {h} has to be divisible by {d}.")
    mesh = nvfuser.DeviceMesh(range(d))

    class Model(FusionDefinition):
        def __init__(self, qkv_format: QkvFormat):
            super().__init__()
            self._qkv_format = qkv_format

        def definition(self) -> None:
            match self._qkv_format:
                case QkvFormat.BHSE:
                    stride_order = [3, 2, 1, 0]
                case QkvFormat.BSHE:
                    stride_order = [3, 1, 2, 0]

            self.q, self.k, self.v, self.out_grad = [
                self.define_tensor(
                    shape=[b, h, s, e // h],
                    dtype=DataType.BFloat16,
                    stride_order=stride_order,
                )
                for _ in range(4)
            ]

            # TODO(#3123): support sharded dropout and change this to a
            # positive probability.
            dropout_p = self.define_scalar(0.0, dtype=DataType.Double)
            is_causal = self.define_scalar(True, dtype=DataType.Bool)
            self.attn, self.log_sumexp, seed, offset = self.ops.sdpfa_fwd(
                self.q, self.k, self.v, dropout_p, is_causal, scale=None
            )

            self.q_grad, self.k_grad, self.v_grad = self.ops.sdpfa_bwd(
                self.out_grad,
                self.q,
                self.k,
                self.v,
                self.attn,
                self.log_sumexp,
                dropout_p,
                is_causal,
                seed,
                offset,
                scale=None,
            )

            self.add_output(self.attn)
            for grad in [self.q_grad, self.k_grad, self.v_grad]:
                self.add_output(grad)

        def multidevice_schedule(self) -> None:
            for t in [
                self.q,
                self.k,
                self.v,
                self.attn,
                self.log_sumexp,
                self.out_grad,
                self.q_grad,
                self.k_grad,
                self.v_grad,
            ]:
                self.sched._set_device_mesh(t, mesh)
                self.sched.split(t, 1, d, False)
                self.sched.parallelize(t, 1, nvfuser.ParallelType.mesh_x)
                if self._qkv_format == QkvFormat.BSHE:
                    # The loop domain is {i{B}, i{DIDx}, i{H//D}, i{S}, i{E//H}}.
                    # For BSHE, swap i{H//D} and i{S} so that the allocation
                    # domain becomes {i{B}, i{DIDx}, i{S}, i{H//D}, i{E//H}}.
                    self.sched.reorder(t, {2: 3, 3: 2})
                self.sched.set_allocation_as_loop(t)
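For a concrete sense of what the head-dimension loop split does, the per-device shapes work out as follows (d = 2 is just an illustrative device count, not a value fixed by the PR):

# Illustrative shapes only; b, s, h, e match the test above.
b, s, h, e, d = 2, 1024, 12, 768, 2
unsharded_shape = (b, h, s, e // h)        # (2, 12, 1024, 64)
per_device_shape = (b, h // d, s, e // h)  # (2, 6, 1024, 64) after the DID split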
Device Compatibility

The test is skipped on pre-Ampere devices because Flash Attention is only supported on Ampere and newer GPUs. Reviewers should check that this restriction is implemented and documented correctly and that the test provides adequate coverage on Ampere and newer devices.

@pytest.mark.skipif(
    utils.is_pre_ampere(),
    reason="Flash Attention is only supported on Ampere and newer devices.",
)
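utils.is_pre_ampere is the repository's own helper; as a rough sketch of what such a capability check typically looks like (an assumption, not the helper's actual implementation):

import torch

def is_pre_ampere() -> bool:
    # Ampere GPUs report CUDA compute capability 8.x; earlier architectures
    # lack the Flash Attention support this test relies on.
    major, _ = torch.cuda.get_device_capability()
    return major < 8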
