Allow noinline functions to be called with correct argument types. #3963

Draft: wants to merge 4 commits into base: main

Contributor

@etiotto etiotto commented Apr 17, 2025

This PR is currently blocked by:

#3974
#3612

Signed-off-by: Tiotto, Ettore <[email protected]>
@etiotto etiotto self-assigned this Apr 17, 2025
@etiotto etiotto requested a review from whitneywhtsang April 17, 2025 21:24
Contributor

@whitneywhtsang whitneywhtsang left a comment

Can we resolve the problem in rewrite_stack_ptr pass introduced in #3497 instead of common pass?

@whitneywhtsang whitneywhtsang requested a review from ESI-SYD April 18, 2025 12:40
@etiotto etiotto changed the title from "Fix issue #3805" to "Allow noinline functions to be called with correct argument types." Apr 22, 2025
@etiotto
Contributor Author

etiotto commented Apr 22, 2025

Can we resolve the problem in rewrite_stack_ptr pass introduced in #3497 instead of common pass?

I tried the new pass and found that it doesn't help in this case. Consider this simple test case:

module attributes {triton_intel_gpu.target_arch = "spir64", "ttg.num-warps" = 4 : i32, ttg.shared = 0 : i32} {  
  tt.func public @kernel(%arg0: !tt.ptr<f32>, %extra: i32) {
    tt.call @noinline_simple_fn(%arg0) : (!tt.ptr<f32>) -> ()
    tt.return
  }
  tt.func private @noinline_simple_fn(%arg0: !tt.ptr<f32>) attributes {noinline = true} {
    tt.return
  }  
}

Here, if the extra kernel argument is dropped, compilation succeeds and that pass works. However, just adding the extra parameter to the kernel (or any parameter that is not a tt.ptr) causes this simple test to fail compilation.

Essentially, when the last parameter of the kernel is not a tt.ptr, lowering the code to the LLVM dialect (convert-triton-intel-gpu-to-llvm) fails, and the tritonintelgpu-rewrite-stack-ptr pass never gets a chance to run. Adding tritonintelgpu-rewrite-stack-ptr to our pass pipeline is therefore not sufficient to fix the problem; we need to fix the code while it is being lowered to the LLVM dialect.

I have made a copy of CallOpConversion so that we can put a fix in it instead of fixing common code.
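
The ordering problem described above can be illustrated with a toy pass pipeline. This is only a sketch, not Triton or MLIR code: `ToyModule`, `runPipeline`, `toyPipeline`, and the failure condition inside `lowerToLLVM` are hypothetical stand-ins. It assumes, as the comment states, that lowering fails when a kernel containing a noinline call ends in a non-pointer parameter, and that the rewrite pass is scheduled after the lowering.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a module: only the facts relevant to the failure.
struct ToyModule {
  std::vector<std::string> kernelArgTypes; // e.g. {"!tt.ptr<f32>", "i32"}
  bool hasNoinlineCall;
};

// Returns true when a type string starts with the Triton pointer prefix.
inline bool isPtrType(const std::string &type) {
  return type.rfind("!tt.ptr", 0) == 0;
}

// Toy model of convert-triton-intel-gpu-to-llvm. ASSUMPTION: it fails when a
// noinline call is present and the kernel's last parameter is not a pointer.
inline bool lowerToLLVM(const ToyModule &m) {
  if (!m.hasNoinlineCall || m.kernelArgTypes.empty())
    return true;
  return isPtrType(m.kernelArgTypes.back());
}

// Toy model of tritonintelgpu-rewrite-stack-ptr: always succeeds in this sketch.
inline bool rewriteStackPtr(const ToyModule &) { return true; }

using ToyPass = std::pair<std::string, std::function<bool(const ToyModule &)>>;

// Runs passes in order, stops at the first failure, and returns the names of
// the passes that were actually invoked.
inline std::vector<std::string> runPipeline(const ToyModule &m,
                                            const std::vector<ToyPass> &passes) {
  std::vector<std::string> invoked;
  for (const auto &pass : passes) {
    invoked.push_back(pass.first);
    if (!pass.second(m))
      break; // a failed pass aborts the rest of the pipeline
  }
  return invoked;
}

// Pass order assumed from the comment above (lowering first, rewrite second).
inline std::vector<ToyPass> toyPipeline() {
  return {{"convert-triton-intel-gpu-to-llvm", lowerToLLVM},
          {"tritonintelgpu-rewrite-stack-ptr", rewriteStackPtr}};
}
```

For a kernel with arguments `(!tt.ptr<f32>, i32)` and a noinline call, `runPipeline` returns only the lowering pass, so in this model a fix placed in the later pass can never take effect.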

@etiotto etiotto marked this pull request as ready for review April 22, 2025 21:04
@etiotto etiotto requested a review from chengjunlu April 22, 2025 21:04
%0 = tt.load %arg0 : !tt.ptr<f32>
%1 = tt.load %arg1 : !tt.ptr<f32>
// CHECK: llvm.call spir_funccc @noinline_shared_fn__fp32_fp32_Pfp32__(%8, %17, %arg2, %arg3, %arg2)
tt.call @noinline_shared_fn__fp32_fp32_Pfp32__(%0, %1, %arg2) {allocation.offset = 0 : i32} : (f32, f32, !tt.ptr<f32>) -> ()
// CHECK: llvm.call spir_funccc @noinline_shared_fn(%8, %17, %arg2, %arg3, %arg2)
Contributor

@chengjunlu chengjunlu Apr 23, 2025

Is this a backend issue in how we convert the calling convention? Why do we duplicate %arg2 as the last parameter to the callee? It seems the callee doesn't require any global scratch space.

Contributor

We reuse the upstream CallOpConversion after the pass. Its promoteOperands appends the GlobalScratchPtr:

promotedOperands.push_back(LLVM::getGlobalScratchPtr(
loc, rewriter, targetInfo, caller, opOffsetVal));

%arg2 is appended because it is the caller's last argument:

auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
if (!allocOffset) {
return gmemBase;
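
To make the mechanism in the two snippets above concrete, here is a toy model of what happens when the caller's last argument is taken as the global scratch base. It is not the upstream code: `ToyValue`, `ToyFunc`, `toyGlobalScratchPtr`, and `toyPromoteOperands` are hypothetical names, and only the "use the last argument" rule is taken from the quoted snippets.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for an SSA value and a function signature.
struct ToyValue {
  std::string name; // e.g. "%arg2"
  std::string type; // e.g. "!tt.ptr<f32>"
};

struct ToyFunc {
  std::vector<ToyValue> args;
};

// Mirrors the quoted rule: the scratch base is whatever comes last in the
// caller's argument list, regardless of what that argument really is.
inline ToyValue toyGlobalScratchPtr(const ToyFunc &caller) {
  assert(!caller.args.empty() && "caller needs at least one argument");
  return caller.args.back();
}

// Appends the scratch base to the operands of a call made from `caller`.
inline std::vector<ToyValue> toyPromoteOperands(const ToyFunc &caller,
                                                std::vector<ToyValue> operands) {
  operands.push_back(toyGlobalScratchPtr(caller));
  return operands;
}
```

In this model, a caller whose last argument is `%arg2` ends up passing `%arg2` twice when `%arg2` is also a call operand, which is the duplication asked about above. If the caller's last argument is an `i32`, the appended value is not a pointer at all, which is consistent with the compilation failure reported earlier in the thread.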

Contributor

The lowering pattern should transform the kernel entry function's signature as well, right?
The global scratch space base address should be appended to the kernel entry function's arguments.
The triton::CallOp lowering pattern should then use that global scratch base instead of reusing the user buffer that happens to be passed as the last argument (%arg2 in this case).
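
The suggestion above could be sketched roughly as follows. This is an illustration of the idea only, not what this PR implements: `SketchArg`, `SketchFunc`, `withScratchArg`, `promoteCallOperands`, the argument name `%global_scratch`, and the type string `ptr` are all hypothetical, and the host-side change needed to actually supply the pointer is not modeled.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for an argument and a function signature.
struct SketchArg {
  std::string name;
  std::string type;
};

struct SketchFunc {
  std::vector<SketchArg> args;
  bool hasScratchArg; // true once a dedicated scratch base has been appended
};

// Appends a dedicated scratch-base argument to the entry signature.
// Calling it again on an already transformed signature is a no-op.
inline SketchFunc withScratchArg(SketchFunc entry) {
  if (!entry.hasScratchArg) {
    entry.args.push_back({"%global_scratch", "ptr"});
    entry.hasScratchArg = true;
  }
  return entry;
}

// Call lowering forwards the dedicated scratch argument, never a user buffer.
inline std::vector<SketchArg> promoteCallOperands(const SketchFunc &caller,
                                                  std::vector<SketchArg> operands) {
  assert(caller.hasScratchArg && "signature must be transformed first");
  operands.push_back(caller.args.back());
  return operands;
}
```

With this shape, the type of the kernel's last user-visible parameter no longer matters, because the value forwarded to the callee always comes from the appended argument.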

Contributor Author

Is this a backend issue in how we convert the calling convention? Why do we duplicate %arg2 as the last parameter to the callee? It seems the callee doesn't require any global scratch space.

No, it is not related to the calling convention. We fix the calling convention in another transformation pattern (FixCallCConv).

Contributor Author

The lowering pattern should transform the kernel entry function's signature as well, right? The global scratch space base address should be appended to the kernel entry function's arguments. The triton::CallOp lowering pattern should then use that global scratch base instead of reusing the user buffer that happens to be passed as the last argument (%arg2 in this case).

I do not know which transformation pattern is supposed to change the kernel signature. The host would have to pass a pointer to the global scratch space as well. @ESI-SYD do you know what part of the code is supposed to append the pointer to the scratch space to the kernel?

Contributor

Let's use these two issues to track the problem here, so that it does not block this PR:
#3974
#3612

Contributor Author

I am putting this PR back into draft mode because the Triton FE/driver does not yet pass a pointer to the global scratch space to the kernel.

Contributor

No, the global scratch space is not supported by Intel GPU yet. The #3612

@etiotto etiotto marked this pull request as draft April 28, 2025 18:40