Feature request: Extend the uop forwarding in the fusion segmenter to include other single-input trivial ops #3647

Open
naoyam opened this issue Dec 25, 2024 · 10 comments

naoyam (Collaborator) commented Dec 25, 2024

The current fusion segmenter forwards a straight-line sequence of UnaryOps. Should it be extended to include other single-input ops like BroadcastOp, ExpandOp, and reshape (LoadStoreOp)? That would help generate better segmentation for some of the RoPE modules. Would there be any side effects?
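For concreteness, here is a small, hypothetical sketch in the nvFuser python frontend (shapes and op choices are made up, and it assumes the usual FusionDefinition define/execute pattern) of the kind of prologue this is about. Only the cast is forwarded today; the proposal is to also forward the reshape and broadcast so the segmenter effectively reasons about just the reduction behind them:

    import torch
    from nvfuser import FusionDefinition, DataType

    x = torch.randn(8, 16, 64, device="cuda", dtype=torch.float16)

    with FusionDefinition() as fd:
        t0 = fd.from_pytorch(x)
        t1 = fd.ops.cast(t0, DataType.Float)                        # UnaryOp: forwarded today
        t2 = fd.ops.reshape(t1, [8 * 16, 64])                       # reshape (LoadStoreOp): proposed
        t3 = fd.ops.broadcast_in_dim(t2, [8 * 16, 64, 1], [0, 1])   # BroadcastOp: proposed
        t4 = fd.ops.sum(t3, dim=1)                                  # the "real" computation
        fd.add_output(t4)

    out = fd.execute([x])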

@naoyam naoyam added the rope label Dec 25, 2024
@naoyam naoyam changed the title Feature request: Extend the uop forwarding to include other single-input trivial ops Feature request: Extend the uop forwarding in the fusion segmenter to include other single-input trivial ops Dec 25, 2024
jacobhinkle (Collaborator) commented:

> single-input ops like BroadcastOp, ExpandOp, and reshape (LoadStoreOp)?

Does this mean ViewOp would be included?

naoyam (Collaborator, Author) commented Dec 26, 2024

Why not?

jacobhinkle (Collaborator) commented:

> Why not?

I guess I just assumed that not all the schedulers can handle ViewOp well enough to support this automatic forwarding, but that concern might not be well-founded. I don't know of a specific reason this wouldn't work.

naoyam (Collaborator, Author) commented Dec 31, 2024

Would it be a problem for the matmul scheduler? I don't think any other scheduler would have an issue with a reshape at the beginning of a fusion, except perhaps when an expanded ID becomes a concrete ID through a reshape. A reshape at the end of a fusion would need more attention, as it might interfere with scheduling heuristics.
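As a hedged, PyTorch-level illustration of that corner case (this is not nvFuser code, but the iteration-domain situation is analogous):

    import torch

    x = torch.randn(1, 4).expand(3, 4)  # dim 0 is an expanded broadcast (stride 0, no extra storage)
    y = x.reshape(3 * 4)                # merging it with a real dim forces a concrete 12-element copy
    assert y.is_contiguous() and y.numel() == 12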

jacobhinkle (Collaborator) commented:

> Would it be a problem for the matmul scheduler? I don't think any other scheduler would have an issue with a reshape at the beginning of a fusion, except perhaps when an expanded ID becomes a concrete ID through a reshape. A reshape at the end of a fusion would need more attention, as it might interfere with scheduling heuristics.

For Hopper, we will need to reject the segment if there are any ops before the matmul. I think we probably need to do a better job of checking for cases like this on Ampere as well, since I think we might actually accept a case like

    tv2 = reshape(tv0)
    tv3 = matmul(tv2, tv1)

In that case we will not be able to schedule the fusion, but I think we might actually accept it (I will try this out). In any case, I think it would probably be alright to include these ops in forwarding, and we will just take care to reject those segments in the matmul scheduler.

naoyam (Collaborator, Author) commented Dec 31, 2024

That seems to mean that uop forwarding is already broken with matmul. IIRC, forwarded ops are not considered during segmentation; they are just added back to the first segments after segmentation is completed. Another computeHeuristics call happens after that with the forwarded ops, so I suspect we would see an error there for a matmul segment, regardless of whether the forwarded ops are reshapes or not. Am I understanding correctly?

jacobhinkle (Collaborator) commented:

> Am I understanding correctly?

Yes, I think that's right. I think we discussed this issue in the context of slice and pad about a year ago: #1541 (comment). At that time, I think we chose to just move the pad and slice closer to the inputs as a workaround. I believe the issue still remains that we do not actually inspect the real segment, with the forwarded ops included, during segmentation. I will try to create a repro for a view + matmul or a uop + matmul that segments but then fails to schedule, and I will open a separate issue to track that.

naoyam (Collaborator, Author) commented Dec 31, 2024

Thanks. Setting aside those "trivial" ops at the beginning of a fusion would generally help the segmenter, so I think we should do this more aggressively. However, we could back off when there's a potential matmul pattern. This is an optimization, so we can always safely disable it.

jacobhinkle (Collaborator) commented:

> I will try to create a repro for a view + matmul or a uop + matmul

Actually, the matmul scheduler does currently seem to check for this implicitly, because we strictly require that all IterDomain types of the fusion inputs are known. When there is a slice or reshape in the prologue, we will not consider the input dims to be one of M, N, K, or Bias, so the matmul scheduler rejects the segment.

That said, once these ops are forwarded we won't hit that check until computing heuristics, and we will probably see an error then. The repro I used is here:

    def test_mulsum_reshape_prologue(self):
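        # Note: this method lives in the nvFuser python frontend test class, so it
        # relies on self.exec_nvfuser and on torch, FusionDefinition, and DataType
        # being provided by that test harness.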
        m = 24
        n = 16
        k = 8
        inputs = [
            # This input can be used to check reshape
            #torch.randn(m // 2, 2, k, device="cuda", dtype=torch.float16),
            torch.randn(m + 1, k, device="cuda", dtype=torch.float16),
            torch.randn(k, n, device="cuda", dtype=torch.float16),
        ]

        def fusion_func(fd: FusionDefinition) -> None:
            t0 = fd.from_pytorch(inputs[0])
            t1 = fd.from_pytorch(inputs[1])
            #t2 = fd.ops.reshape(t0, [m, k])
        t2 = fd.ops.slice(
            t0,
            start_indices=[1, 0],
            end_indices=[m + 1, k],
            strides=[1, 1],
        )
            t3 = fd.ops.broadcast_in_dim(t2, [m, k, 1], [0, 1])
            t4 = fd.ops.broadcast_in_dim(t1, [1, k, n], [1, 2])
            t5 = fd.ops.mul(t3, t4)
            t6 = fd.ops.sum(t5, dim=1)
            t7 = fd.ops.cast(t6, DataType.Half)
            fd.add_output(t7)

        nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
        # inputs[0][1:] matches the slice prologue above; use inputs[0].view(m, k)
        # with the commented-out input for the reshape variant instead.
        eager_out = torch.matmul(inputs[0][1:], inputs[1])
        fp16_nvf_out = nvf_out[0]
        self.assertEqual(eager_out, fp16_nvf_out)

naoyam (Collaborator, Author) commented Dec 31, 2024

Thanks for checking. I don't think we would want to forward slice/pad/concat, as they are not "trivial" for nvFuser (yet). Also, for reshape, I realized we should probably not forward past a splitting reshape, since that would mean we would not be able to find the connection between the split output IDs. A merge-only reshape should be fine, I believe.
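To make the merge-only vs. splitting distinction concrete, here is a small, hypothetical python-frontend sketch (shapes and the trailing sums are made up; it assumes the same FusionDefinition define/execute pattern as the repro above):

    import torch
    from nvfuser import FusionDefinition

    x = torch.randn(8, 16, 64, device="cuda")   # input for the merge-only reshape
    y = torch.randn(128, 64, device="cuda")     # input for the splitting reshape

    with FusionDefinition() as fd:
        t0 = fd.from_pytorch(x)
        t1 = fd.from_pytorch(y)
        # Merge-only reshape: [8, 16, 64] -> [128, 64]. Each output ID is a merge
        # of input IDs, so forwarding past it keeps the ID connections intact.
        t2 = fd.ops.reshape(t0, [128, 64])
        # Splitting reshape: [128, 64] -> [8, 16, 64]. The first input ID is split
        # into two output IDs; forwarding past it would hide how those split IDs
        # connect to the rest of the fusion.
        t3 = fd.ops.reshape(t1, [8, 16, 64])
        fd.add_output(fd.ops.sum(t2, dim=1))
        fd.add_output(fd.ops.sum(t3, dim=1))

    outs = fd.execute([x, y])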
