
[X86] Elect to tail call when sret ptr is passed to the callee #146575

Open · wants to merge 2 commits into main
Conversation

antoniofrighetto
Contributor

We may be able to allow the callee to be tail-called when the caller expects an sret pointer argument, as long as this pointer is forwarded to the callee.

Fixes: #146303.

@llvmbot
Member

llvmbot commented Jul 1, 2025

@llvm/pr-subscribers-backend-x86

Author: Antonio Frighetto (antoniofrighetto)

Changes

We may be able to allow the callee to be tail-called when the caller expects an sret pointer argument, as long as this pointer is forwarded to the callee.

Fixes: #146303.


Full diff: https://github.com/llvm/llvm-project/pull/146575.diff

2 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelLoweringCall.cpp (+27-8)
  • (modified) llvm/test/CodeGen/X86/sibcall.ll (+13-86)
diff --git a/llvm/lib/Target/X86/X86ISelLoweringCall.cpp b/llvm/lib/Target/X86/X86ISelLoweringCall.cpp
index cb38a39ff991d..dd7e32b665c0d 100644
--- a/llvm/lib/Target/X86/X86ISelLoweringCall.cpp
+++ b/llvm/lib/Target/X86/X86ISelLoweringCall.cpp
@@ -2788,6 +2788,7 @@ bool X86TargetLowering::IsEligibleForTailCallOptimization(
 
   // If -tailcallopt is specified, make fastcc functions tail-callable.
   MachineFunction &MF = DAG.getMachineFunction();
+  X86MachineFunctionInfo *FuncInfo = MF.getInfo<X86MachineFunctionInfo>();
   const Function &CallerF = MF.getFunction();
 
   // If the function return type is x86_fp80 and the callee return type is not,
@@ -2824,14 +2825,33 @@ bool X86TargetLowering::IsEligibleForTailCallOptimization(
   if (RegInfo->hasStackRealignment(MF))
     return false;
 
-  // Also avoid sibcall optimization if we're an sret return fn and the callee
-  // is incompatible. See comment in LowerReturn about why hasStructRetAttr is
-  // insufficient.
-  if (MF.getInfo<X86MachineFunctionInfo>()->getSRetReturnReg()) {
+  // Avoid sibcall optimization if we are an sret return function and the callee
+  // is incompatible, unless we can prove our sret pointer is forwarded to it.
+  if (Register SRetReg = FuncInfo->getSRetReturnReg()) {
     // For a compatible tail call the callee must return our sret pointer. So it
     // needs to be (a) an sret function itself and (b) we pass our sret as its
-    // sret. Condition #b is harder to determine.
-    return false;
+    // sret. We know the caller has a sret ptr argument (SRetReg). Now locate
+    // the operand index within the callee which has a sret ptr as well.
+    auto It = find_if(Outs, [](const auto &OA) { return OA.Flags.isSRet(); });
+    // Bail out if the callee does not take any sret argument.
+    if (It == Outs.end())
+      return false;
+
+    bool Found = false;
+    unsigned Pos = std::distance(Outs.begin(), It);
+    SDValue SRetArgVal = OutVals[Pos];
+    // We are looking for a CopyToReg, where the callee sret argument is written
+    // into a new vreg (which should later be %rax/%eax, if this is returned).
+    for (SDNode *User : SRetArgVal->users()) {
+      if (User->getOpcode() != ISD::CopyToReg)
+        continue;
+      Register Reg = cast<RegisterSDNode>(User->getOperand(1))->getReg();
+      if (Reg == SRetReg && User->getOperand(2) == SRetArgVal)
+        Found = true;
+    }
+
+    if (!Found)
+      return false;
   } else if (IsCalleePopSRet)
     // The callee pops an sret, so we cannot tail-call, as our caller doesn't
     // expect that.
@@ -2953,8 +2973,7 @@ bool X86TargetLowering::IsEligibleForTailCallOptimization(
       X86::isCalleePop(CalleeCC, Subtarget.is64Bit(), isVarArg,
                        MF.getTarget().Options.GuaranteedTailCallOpt);
 
-  if (unsigned BytesToPop =
-          MF.getInfo<X86MachineFunctionInfo>()->getBytesToPopOnReturn()) {
+  if (unsigned BytesToPop = FuncInfo->getBytesToPopOnReturn()) {
     // If we have bytes to pop, the callee must pop them.
     bool CalleePopMatches = CalleeWillPop && BytesToPop == StackArgsSize;
     if (!CalleePopMatches)
diff --git a/llvm/test/CodeGen/X86/sibcall.ll b/llvm/test/CodeGen/X86/sibcall.ll
index 4a0a68ee32243..51e0fea92453e 100644
--- a/llvm/test/CodeGen/X86/sibcall.ll
+++ b/llvm/test/CodeGen/X86/sibcall.ll
@@ -444,21 +444,11 @@ define dso_local void @t15(ptr noalias sret(%struct.foo) %agg.result) nounwind
 ;
 ; X64-LABEL: t15:
 ; X64:       # %bb.0:
-; X64-NEXT:    pushq %rbx
-; X64-NEXT:    movq %rdi, %rbx
-; X64-NEXT:    callq f
-; X64-NEXT:    movq %rbx, %rax
-; X64-NEXT:    popq %rbx
-; X64-NEXT:    retq
+; X64-NEXT:    jmp f # TAILCALL
 ;
 ; X32-LABEL: t15:
 ; X32:       # %bb.0:
-; X32-NEXT:    pushq %rbx
-; X32-NEXT:    movq %rdi, %rbx
-; X32-NEXT:    callq f
-; X32-NEXT:    movl %ebx, %eax
-; X32-NEXT:    popq %rbx
-; X32-NEXT:    retq
+; X32-NEXT:    jmp f # TAILCALL
   tail call fastcc void @f(ptr noalias sret(%struct.foo) %agg.result) nounwind
   ret void
 }
@@ -607,32 +597,15 @@ declare dso_local fastcc double @foo20(double) nounwind
 define fastcc void @t21_sret_to_sret(ptr noalias sret(%struct.foo) %agg.result) nounwind  {
 ; X86-LABEL: t21_sret_to_sret:
 ; X86:       # %bb.0:
-; X86-NEXT:    pushl %esi
-; X86-NEXT:    subl $8, %esp
-; X86-NEXT:    movl %ecx, %esi
-; X86-NEXT:    calll t21_f_sret
-; X86-NEXT:    movl %esi, %eax
-; X86-NEXT:    addl $8, %esp
-; X86-NEXT:    popl %esi
-; X86-NEXT:    retl
+; X86-NEXT:    jmp t21_f_sret # TAILCALL
 ;
 ; X64-LABEL: t21_sret_to_sret:
 ; X64:       # %bb.0:
-; X64-NEXT:    pushq %rbx
-; X64-NEXT:    movq %rdi, %rbx
-; X64-NEXT:    callq t21_f_sret
-; X64-NEXT:    movq %rbx, %rax
-; X64-NEXT:    popq %rbx
-; X64-NEXT:    retq
+; X64-NEXT:    jmp t21_f_sret # TAILCALL
 ;
 ; X32-LABEL: t21_sret_to_sret:
 ; X32:       # %bb.0:
-; X32-NEXT:    pushq %rbx
-; X32-NEXT:    movq %rdi, %rbx
-; X32-NEXT:    callq t21_f_sret
-; X32-NEXT:    movl %ebx, %eax
-; X32-NEXT:    popq %rbx
-; X32-NEXT:    retq
+; X32-NEXT:    jmp t21_f_sret # TAILCALL
   tail call fastcc void @t21_f_sret(ptr noalias sret(%struct.foo) %agg.result) nounwind
   ret void
 }
@@ -640,34 +613,15 @@ define fastcc void @t21_sret_to_sret(ptr noalias sret(%struct.foo) %agg.result)
 define fastcc void @t21_sret_to_sret_more_args(ptr noalias sret(%struct.foo) %agg.result, i32 %a, i32 %b) nounwind  {
 ; X86-LABEL: t21_sret_to_sret_more_args:
 ; X86:       # %bb.0:
-; X86-NEXT:    pushl %esi
-; X86-NEXT:    subl $8, %esp
-; X86-NEXT:    movl %ecx, %esi
-; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
-; X86-NEXT:    movl %eax, (%esp)
-; X86-NEXT:    calll f_sret@PLT
-; X86-NEXT:    movl %esi, %eax
-; X86-NEXT:    addl $8, %esp
-; X86-NEXT:    popl %esi
-; X86-NEXT:    retl
+; X86-NEXT:    jmp f_sret@PLT # TAILCALL
 ;
 ; X64-LABEL: t21_sret_to_sret_more_args:
 ; X64:       # %bb.0:
-; X64-NEXT:    pushq %rbx
-; X64-NEXT:    movq %rdi, %rbx
-; X64-NEXT:    callq f_sret@PLT
-; X64-NEXT:    movq %rbx, %rax
-; X64-NEXT:    popq %rbx
-; X64-NEXT:    retq
+; X64-NEXT:    jmp f_sret@PLT # TAILCALL
 ;
 ; X32-LABEL: t21_sret_to_sret_more_args:
 ; X32:       # %bb.0:
-; X32-NEXT:    pushq %rbx
-; X32-NEXT:    movq %rdi, %rbx
-; X32-NEXT:    callq f_sret@PLT
-; X32-NEXT:    movl %ebx, %eax
-; X32-NEXT:    popq %rbx
-; X32-NEXT:    retq
+; X32-NEXT:    jmp f_sret@PLT # TAILCALL
   tail call fastcc void @f_sret(ptr noalias sret(%struct.foo) %agg.result, i32 %a, i32 %b) nounwind
   ret void
 }
@@ -675,35 +629,18 @@ define fastcc void @t21_sret_to_sret_more_args(ptr noalias sret(%struct.foo) %ag
 define fastcc void @t21_sret_to_sret_second_arg_sret(ptr noalias %agg.result, ptr noalias sret(%struct.foo) %ret) nounwind  {
 ; X86-LABEL: t21_sret_to_sret_second_arg_sret:
 ; X86:       # %bb.0:
-; X86-NEXT:    pushl %esi
-; X86-NEXT:    subl $8, %esp
-; X86-NEXT:    movl %edx, %esi
 ; X86-NEXT:    movl %edx, %ecx
-; X86-NEXT:    calll t21_f_sret
-; X86-NEXT:    movl %esi, %eax
-; X86-NEXT:    addl $8, %esp
-; X86-NEXT:    popl %esi
-; X86-NEXT:    retl
+; X86-NEXT:    jmp t21_f_sret # TAILCALL
 ;
 ; X64-LABEL: t21_sret_to_sret_second_arg_sret:
 ; X64:       # %bb.0:
-; X64-NEXT:    pushq %rbx
-; X64-NEXT:    movq %rsi, %rbx
 ; X64-NEXT:    movq %rsi, %rdi
-; X64-NEXT:    callq t21_f_sret
-; X64-NEXT:    movq %rbx, %rax
-; X64-NEXT:    popq %rbx
-; X64-NEXT:    retq
+; X64-NEXT:    jmp t21_f_sret # TAILCALL
 ;
 ; X32-LABEL: t21_sret_to_sret_second_arg_sret:
 ; X32:       # %bb.0:
-; X32-NEXT:    pushq %rbx
-; X32-NEXT:    movq %rsi, %rbx
 ; X32-NEXT:    movq %rsi, %rdi
-; X32-NEXT:    callq t21_f_sret
-; X32-NEXT:    movl %ebx, %eax
-; X32-NEXT:    popq %rbx
-; X32-NEXT:    retq
+; X32-NEXT:    jmp t21_f_sret # TAILCALL
   tail call fastcc void @t21_f_sret(ptr noalias sret(%struct.foo) %ret) nounwind
   ret void
 }
@@ -725,27 +662,17 @@ define fastcc void @t21_sret_to_sret_more_args2(ptr noalias sret(%struct.foo) %a
 ;
 ; X64-LABEL: t21_sret_to_sret_more_args2:
 ; X64:       # %bb.0:
-; X64-NEXT:    pushq %rbx
 ; X64-NEXT:    movl %esi, %eax
-; X64-NEXT:    movq %rdi, %rbx
 ; X64-NEXT:    movl %edx, %esi
 ; X64-NEXT:    movl %eax, %edx
-; X64-NEXT:    callq f_sret@PLT
-; X64-NEXT:    movq %rbx, %rax
-; X64-NEXT:    popq %rbx
-; X64-NEXT:    retq
+; X64-NEXT:    jmp f_sret@PLT # TAILCALL
 ;
 ; X32-LABEL: t21_sret_to_sret_more_args2:
 ; X32:       # %bb.0:
-; X32-NEXT:    pushq %rbx
 ; X32-NEXT:    movl %esi, %eax
-; X32-NEXT:    movq %rdi, %rbx
 ; X32-NEXT:    movl %edx, %esi
 ; X32-NEXT:    movl %eax, %edx
-; X32-NEXT:    callq f_sret@PLT
-; X32-NEXT:    movl %ebx, %eax
-; X32-NEXT:    popq %rbx
-; X32-NEXT:    retq
+; X32-NEXT:    jmp f_sret@PLT # TAILCALL
   tail call fastcc void @f_sret(ptr noalias sret(%struct.foo) %agg.result, i32 %b, i32 %a) nounwind
   ret void
 }

Collaborator

@rnk rnk left a comment


Thanks, the test cases suggest that this would be a major improvement.

// For a compatible tail call the callee must return our sret pointer. So it
// needs to be (a) an sret function itself and (b) we pass our sret as its
// sret. Condition #b is harder to determine.
return false;
// sret. We know the caller has a sret ptr argument (SRetReg). Now locate
Collaborator


Please move this logic to a helper function and keep the comment as it was so that the TCO eligibility check remains easy to read. The original comment captures the properties we're looking for well.

return false;
// sret. We know the caller has a sret ptr argument (SRetReg). Now locate
// the operand index within the callee which has a sret ptr as well.
auto It = find_if(Outs, [](const auto &OA) { return OA.Flags.isSRet(); });
Collaborator


Given the need for an index from distance, I would just use the indexed for loop search to make this more readable.

; X32-NEXT: movl %ebx, %eax
; X32-NEXT: popq %rbx
; X32-NEXT: retq
; X32-NEXT: jmp f_sret@PLT # TAILCALL
Collaborator


Please add new whitebox test cases that try to fool your heuristic. The C code that comes to mind for me looks like:

typedef struct SRet {
  long x, y, z;
} SRet;
SRet foo();
SRet bar(SRet *p) {
  SRet r = {0, 0, 0};
  *p = foo();
  return r;
}

The call to foo could end up in the tail position, but the implicit sret arg is not forwarded to foo here.

Contributor Author


Added this one, and variants thereof, thanks (the one above was already restrained by these lines:

// If we have an explicit sret argument that is an Instruction, (i.e., it
// might point to function-local memory), we can't meaningfully tail-call.
if (Entry.IsSRet && isa<Instruction>(V))
  isTailCall = false;

).

We may be able to allow the callee to be tail-called when the caller
expects a `sret` pointer argument, as long as this pointer is forwarded
to the callee.

Fixes: llvm#146303.
@antoniofrighetto antoniofrighetto force-pushed the feature/let-x86-sret-tailcall branch from 8c32647 to 4342698 Compare July 3, 2025 14:39
Successfully merging this pull request may close these issues.

missing tail call optimization on RVO
3 participants