Feature request: Extend the uop forwarding in the fusion segmenter to include other single-input trivial ops #3647
Comments
Does this mean ViewOp would be included?
Why not?
I guess I just assumed not all the schedulers are able to handle ViewOp.
Would it be a problem for the matmul scheduler? I don't think any other scheduler would have an issue with a reshape at the beginning of a fusion, except perhaps when an expanded ID becomes a concrete ID through a reshape. A reshape at the end of a fusion would need more attention, as it might interfere with scheduling heuristics.
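For illustration, here is a minimal sketch of that corner case using the nvfuser Python frontend (shapes are made up; whether this fusion is actually schedulable is exactly what's in question):

```python
import torch
from nvfuser import FusionDefinition

inputs = [torch.randn(8, device="cuda", dtype=torch.float32)]

with FusionDefinition() as fd:
    t0 = fd.from_pytorch(inputs[0])                # shape [8]
    # broadcast_in_dim adds a broadcast dim and expands it to extent 4
    t1 = fd.ops.broadcast_in_dim(t0, [8, 4], [0])  # shape [8, 4], dim 1 expanded
    # Merging the expanded ID into a concrete ID via reshape makes it
    # concrete; this is the pattern that may trip up some schedulers.
    t2 = fd.ops.reshape(t1, [32])
    fd.add_output(t2)

outputs = fd.execute(inputs)
```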
For Hopper, we will need to reject the fusion if there are any ops before the matmul. I think we probably need to do a better job of checking for cases like this on Ampere too, since I think we might actually accept such a case. We would then not be able to schedule the fusion, but I think we might actually accept it (I will try this out). In any case, I think it would probably be alright to include these ops in forwarding, and we will just take care to reject those segments in the matmul scheduler.
That seems to mean the uop forwarding is already broken with matmul. IIRC, forwarded ops are not considered at segmentation time; they are just added back to the first segments after segmentation is completed.
Yes, I think that's right. I think we discussed this issue in the context of slice and pad about a year ago: #1541 (comment). At that time we chose to just move the pad and slice closer to the inputs as a workaround. I believe the issue remains that we do not actually inspect the real segment with forwarded ops during segmentation. I will try to create a repro of a view + matmul or a uop + matmul that segments and then fails to schedule, and I will open a separate issue to track that.
Thanks. Setting those "trivial" beginning ops aside would generally help the segmenter, so I think we should do this more aggressively. However, we could back off when there's a potential matmul pattern. This is an optimization, so we can always safely disable it.
Actually, the matmul scheduler currently does seem to check for this implicitly, because we strictly require that fusion inputs have all known IterDomain types. When there is a slice or reshape in the prologue, the input dims will not be classified as one of M, N, K, or Bias, so the matmul scheduler rejects the segment. That said, once these ops are forwarded, we won't hit that check until computing heuristics, and will probably hit an error there. The repro I used is here:

```python
# Runs inside the nvFuser Python test suite (self.exec_nvfuser and
# self.assertEqual come from the test harness base class).
def test_mulsum_reshape_prologue(self):
    m = 24
    n = 16
    k = 8
    inputs = [
        # This input can be used to check reshape instead of slice:
        # torch.randn(m // 2, 2, k, device="cuda", dtype=torch.float16),
        torch.randn(m + 1, k, device="cuda", dtype=torch.float16),
        torch.randn(k, n, device="cuda", dtype=torch.float16),
    ]

    def fusion_func(fd: FusionDefinition) -> None:
        t0 = fd.from_pytorch(inputs[0])
        t1 = fd.from_pytorch(inputs[1])
        # t2 = fd.ops.reshape(t0, [m, k])
        t2 = fd.ops.slice(
            t0,
            start_indices=[1, 0],
            end_indices=[m + 1, k],
            strides=[1, 1],
        )
        t3 = fd.ops.broadcast_in_dim(t2, [m, k, 1], [0, 1])
        t4 = fd.ops.broadcast_in_dim(t1, [1, k, n], [1, 2])
        t5 = fd.ops.mul(t3, t4)
        t6 = fd.ops.sum(t5, dim=1)
        t7 = fd.ops.cast(t6, DataType.Half)
        fd.add_output(t7)

    nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
    # The eager reference must slice off the first row to match t2
    # (for the reshape variant, inputs[0].view(m, k) would be used).
    eager_out = torch.matmul(inputs[0][1:], inputs[1])
    fp16_nvf_out = nvf_out[0]
    self.assertEqual(eager_out, fp16_nvf_out)
```
Thanks for checking. I don't think we would want to forward slice/pad/concat, as they are not "trivial" for nvFuser (yet). Also, for reshape, I realized we should probably not skip a splitting reshape, since that would mean we would not be able to find the connection between the split output IDs. A merge-only reshape should be fine, I believe.
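To illustrate the distinction, here is a hedged sketch with made-up shapes, again using the nvfuser Python frontend: a merge-only reshape only combines IDs, while a splitting reshape creates output IDs whose connection to the input would be hidden if the op were skipped during segmentation:

```python
import torch
from nvfuser import FusionDefinition

m, k = 24, 8
inputs = [torch.randn(m // 2, 2, k, device="cuda", dtype=torch.float16)]

with FusionDefinition() as fd:
    t0 = fd.from_pytorch(inputs[0])
    # Merge-only reshape: [m//2, 2, k] -> [m, k] only merges IDs,
    # so forwarding (skipping) it should be safe.
    t1 = fd.ops.reshape(t0, [m, k])
    # A splitting reshape such as [m, k] -> [m, k//2, 2] would create
    # split output IDs; skipping it would lose the ID connection,
    # so it should probably not be forwarded.
    fd.add_output(t1)

outputs = fd.execute(inputs)
```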
The current fusion segmenter forwards a straight-line sequence of UnaryOps. Should it be extended to include other single-input ops like BroadcastOp, ExpandOp, and reshape (LoadStoreOp)? That would help generate better segmentation with some of the RoPE modules. Would there be any side effects?