Skip to content

Commit f5ba4a4

Browse files
committed
Add pragma(LDC_musttail)
1 parent 0d4d711 commit f5ba4a4

21 files changed

+138
-68
lines changed

dmd/expression.d

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5077,6 +5077,10 @@ extern (C++) final class CallExp : UnaExp
50775077
bool directcall; // true if a virtual call is devirtualized
50785078
bool inDebugStatement; /// true if this was in a debug statement
50795079
bool ignoreAttributes; /// don't enforce attributes (e.g. call @gc function in @nogc code)
5080+
version (IN_LLVM)
5081+
{
5082+
bool isMustTail; // If marked with pragma(musttail)
5083+
}
50805084
VarDeclaration vthis2; // container for multi-context
50815085

50825086
extern (D) this(const ref Loc loc, Expression e, Expressions* exps)

dmd/expression.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,9 @@ class CallExp final : public UnaExp
863863
bool directcall; // true if a virtual call is devirtualized
864864
bool inDebugStatement; // true if this was in a debug statement
865865
bool ignoreAttributes; // don't enforce attributes (e.g. call @gc function in @nogc code)
866+
#if IN_LLVM
867+
bool isMustTail; // If marked with pragma(musttail)
868+
#endif
866869
VarDeclaration *vthis2; // container for multi-context
867870

868871
static CallExp *create(const Loc &loc, Expression *e, Expressions *exps);

dmd/id.d

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,7 @@ immutable Msgtable[] msgtable =
555555
{ "LDC_global_crt_dtor" },
556556
{ "LDC_extern_weak" },
557557
{ "LDC_profile_instr" },
558+
{ "musttail" },
558559

559560
// IN_LLVM: LDC-specific traits
560561
{ "targetCPU" },

dmd/id.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ struct Id
7676
static Identifier *LDC_inline_ir;
7777
static Identifier *LDC_extern_weak;
7878
static Identifier *LDC_profile_instr;
79+
static Identifier *musttail;
7980
static Identifier *dcReflect;
8081
static Identifier *opencl;
8182
static Identifier *criticalenter;

dmd/statementsem.d

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2135,6 +2135,10 @@ else
21352135
return setError();
21362136
}
21372137
}
2138+
else if (ps.ident == Id.musttail)
2139+
{
2140+
pragmaMustTailSemantic(ps);
2141+
}
21382142
else if (!global.params.ignoreUnsupportedPragmas)
21392143
{
21402144
ps.error("unrecognized `pragma(%s)`", ps.ident.toChars());
@@ -2153,6 +2157,37 @@ else
21532157
result = ps._body;
21542158
}
21552159

2160+
/// Checks a `pragma(musttail)` statement and, when it is well-formed, marks
/// the call expression being returned by setting its `isMustTail` flag.
///
/// The pragma is accepted only when its body is a return statement whose
/// expression is a call expression. In every other case an error is reported
/// on `ps` and the function exits through `setError()`.
///
/// Params:
///     ps = the pragma statement being analysed; `ps._body` is the statement
///          the pragma is attached to
private void pragmaMustTailSemantic(PragmaStatement ps)
2161+
{
2162+
// A pragma without a body has nothing to apply to.
if (!ps._body)
2163+
{
2164+
ps.error("`pragma(musttail)` must be attached to a return statement");
2165+
return setError();
2166+
}
2167+
2168+
// The body itself must be a return statement.
auto rs = ps._body.isReturnStatement();
2169+
if (!rs)
2170+
{
2171+
ps.error("`pragma(musttail)` must be attached to a return statement");
2172+
return setError();
2173+
}
2174+
2175+
// A return statement without an expression cannot be a tail call.
if (!rs.exp)
2176+
{
2177+
ps.error("`pragma(musttail)` must be attached to a return statement returning result of a function call");
2178+
return setError();
2179+
}
2180+
2181+
// The returned expression must be a call expression.
auto ce = rs.exp.isCallExp();
2182+
if (!ce)
2183+
{
2184+
ps.error("`pragma(musttail)` must be attached to a return statement returning result of a function call");
2185+
return setError();
2186+
}
2187+
2188+
// All checks passed: record the request on the call expression.
ce.isMustTail = true;
2189+
}
2190+
21562191
override void visit(StaticAssertStatement s)
21572192
{
21582193
s.sa.semantic2(sc);

gen/aa.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@ DLValue *DtoAAIndex(const Loc &loc, Type *type, DValue *aa, DValue *key,
6464
DtoTypeInfoOf(loc, aa->type->unSharedOf()->mutableOf(), /*base=*/false);
6565
LLValue *castedAATI = DtoBitCast(rawAATI, funcTy->getParamType(1));
6666
LLValue *valsize = DtoConstSize_t(getTypeAllocSize(DtoType(type)));
67-
ret = gIR->CreateCallOrInvoke(func, aaval, castedAATI, valsize, pkey,
67+
ret = gIR->CreateCallOrInvoke(loc, func, aaval, castedAATI, valsize, pkey,
6868
"aa.index");
6969
} else {
7070
LLValue *keyti = to_keyti(loc, aa, funcTy->getParamType(1));
71-
ret = gIR->CreateCallOrInvoke(func, aaval, keyti, pkey, "aa.index");
71+
ret = gIR->CreateCallOrInvoke(loc, func, aaval, keyti, pkey, "aa.index");
7272
}
7373

7474
// cast return value
@@ -130,7 +130,7 @@ DValue *DtoAAIn(const Loc &loc, Type *type, DValue *aa, DValue *key) {
130130
pkey = DtoBitCast(pkey, getVoidPtrType());
131131

132132
// call runtime
133-
LLValue *ret = gIR->CreateCallOrInvoke(func, aaval, keyti, pkey, "aa.in");
133+
LLValue *ret = gIR->CreateCallOrInvoke(loc, func, aaval, keyti, pkey, "aa.in");
134134

135135
// cast return value
136136
LLType *targettype = DtoType(type);
@@ -174,7 +174,7 @@ DValue *DtoAARemove(const Loc &loc, DValue *aa, DValue *key) {
174174
pkey = DtoBitCast(pkey, funcTy->getParamType(2));
175175

176176
// call runtime
177-
LLValue *res = gIR->CreateCallOrInvoke(func, aaval, keyti, pkey);
177+
LLValue *res = gIR->CreateCallOrInvoke(loc, func, aaval, keyti, pkey);
178178

179179
return new DImValue(Type::tbool, res);
180180
}
@@ -192,7 +192,7 @@ LLValue *DtoAAEquals(const Loc &loc, EXP op, DValue *l, DValue *r) {
192192
LLValue *abval = DtoBitCast(DtoRVal(r), funcTy->getParamType(2));
193193
LLValue *aaTypeInfo = DtoTypeInfoOf(loc, t);
194194
LLValue *res =
195-
gIR->CreateCallOrInvoke(func, aaTypeInfo, aaval, abval, "aaEqRes");
195+
gIR->CreateCallOrInvoke(loc, func, aaTypeInfo, aaval, abval, "aaEqRes");
196196

197197
const auto predicate = eqTokToICmpPred(op, /* invert = */ true);
198198
res = gIR->ir->CreateICmp(predicate, res, DtoConstInt(0));

gen/arrays.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ static void copySlice(const Loc &loc, LLValue *dstarr, LLValue *dstlen,
168168
if (checksEnabled && !knownInBounds) {
169169
LLFunction *fn = getRuntimeFunction(loc, gIR->module, "_d_array_slice_copy");
170170
gIR->CreateCallOrInvoke(
171-
fn, {dstarr, dstlen, srcarr, srclen, DtoConstSize_t(elementSize)}, "",
171+
loc, fn, {dstarr, dstlen, srcarr, srclen, DtoConstSize_t(elementSize)}, "",
172172
/*isNothrow=*/true);
173173
} else {
174174
// We might have dstarr == srcarr at compile time, but as long as
@@ -271,7 +271,7 @@ void DtoArrayAssign(const Loc &loc, DValue *lhs, DValue *rhs, EXP op,
271271
loc, gIR->module,
272272
!canSkipPostblit ? "_d_arrayassign_l" : "_d_arrayassign_r");
273273
gIR->CreateCallOrInvoke(
274-
fn, DtoTypeInfoOf(loc, elemType), DtoSlice(rhsPtr, rhsLength, getI8Type()),
274+
loc, fn, DtoTypeInfoOf(loc, elemType), DtoSlice(rhsPtr, rhsLength, getI8Type()),
275275
DtoSlice(lhsPtr, lhsLength, getI8Type()), DtoBitCast(tmpSwap, getVoidPtrType()));
276276
}
277277
} else {
@@ -305,7 +305,7 @@ void DtoArrayAssign(const Loc &loc, DValue *lhs, DValue *rhs, EXP op,
305305
LLFunction *fn =
306306
getRuntimeFunction(loc, gIR->module, "_d_arraysetassign");
307307
gIR->CreateCallOrInvoke(
308-
fn, lhsPtr, DtoBitCast(makeLValue(loc, rhs), getVoidPtrType()),
308+
loc, fn, lhsPtr, DtoBitCast(makeLValue(loc, rhs), getVoidPtrType()),
309309
gIR->ir->CreateTruncOrBitCast(lhsLength,
310310
LLType::getInt32Ty(gIR->context())),
311311
DtoTypeInfoOf(loc, stripModifiers(t2)));
@@ -672,7 +672,7 @@ DSliceValue *DtoNewDynArray(const Loc &loc, Type *arrayType, DValue *dim,
672672

673673
// call allocator
674674
LLValue *newArray =
675-
gIR->CreateCallOrInvoke(fn, arrayTypeInfo, arrayLen, ".gc_mem");
675+
gIR->CreateCallOrInvoke(loc, fn, arrayTypeInfo, arrayLen, ".gc_mem");
676676

677677
// return a DSliceValue with the well-known length for better optimizability
678678
auto ptr =
@@ -741,7 +741,7 @@ DSliceValue *DtoNewMulDimDynArray(const Loc &loc, Type *arrayType,
741741

742742
// call allocator
743743
LLValue *newptr =
744-
gIR->CreateCallOrInvoke(fn, arrayTypeInfo, DtoLoad(dtype, darray), ".gc_mem");
744+
gIR->CreateCallOrInvoke(loc, fn, arrayTypeInfo, DtoLoad(dtype, darray), ".gc_mem");
745745

746746
IF_LOG Logger::cout() << "final ptr = " << *newptr << '\n';
747747

@@ -769,7 +769,7 @@ DSliceValue *DtoResizeDynArray(const Loc &loc, Type *arrayType, DValue *array,
769769
: "_d_arraysetlengthiT");
770770

771771
LLValue *newArray = gIR->CreateCallOrInvoke(
772-
fn, DtoTypeInfoOf(loc, arrayType), newdim,
772+
loc, fn, DtoTypeInfoOf(loc, arrayType), newdim,
773773
DtoBitCast(DtoLVal(array), fn->getFunctionType()->getParamType(2)),
774774
".gc_mem");
775775

@@ -871,7 +871,7 @@ DSliceValue *DtoCatArrays(const Loc &loc, Type *arrayType, Expression *exp1,
871871
args.push_back(loadArray(exp2,2));
872872
}
873873

874-
auto newArray = gIR->CreateCallOrInvoke(fn, args, ".appendedArray");
874+
auto newArray = gIR->CreateCallOrInvoke(loc, fn, args, ".appendedArray");
875875
return getSlice(arrayType, newArray);
876876
}
877877

@@ -886,7 +886,7 @@ DSliceValue *DtoAppendDChar(const Loc &loc, DValue *arr, Expression *exp,
886886

887887
// Call function (ref string x, dchar c)
888888
LLValue *newArray = gIR->CreateCallOrInvoke(
889-
fn, DtoBitCast(DtoLVal(arr), fn->getFunctionType()->getParamType(0)),
889+
loc, fn, DtoBitCast(DtoLVal(arr), fn->getFunctionType()->getParamType(0)),
890890
DtoBitCast(valueToAppend, fn->getFunctionType()->getParamType(1)),
891891
".appendedArray");
892892

@@ -942,7 +942,7 @@ LLValue *DtoArrayEqCmp_impl(const Loc &loc, const char *func, DValue *l,
942942
args.push_back(DtoBitCast(tival, fn->getFunctionType()->getParamType(2)));
943943
}
944944

945-
return gIR->CreateCallOrInvoke(fn, args);
945+
return gIR->CreateCallOrInvoke(loc, fn, args);
946946
}
947947

948948
/// When `true` is returned, the type can be compared using `memcmp`.
@@ -1324,7 +1324,7 @@ static void emitRangeErrorImpl(IRState *irs, const Loc &loc,
13241324
args.push_back(DtoModuleFileName(module, loc));
13251325
args.push_back(DtoConstUint(loc.linnum));
13261326
args.insert(args.end(), extraArgs.begin(), extraArgs.end());
1327-
irs->CreateCallOrInvoke(fn, args);
1327+
irs->CreateCallOrInvoke(loc, fn, args);
13281328
irs->ir->CreateUnreachable();
13291329
break;
13301330
}

gen/classes.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ DValue *DtoNewClass(const Loc &loc, TypeClass *tc, NewExp *newexp) {
9090
LLConstant *ci =
9191
DtoBitCast(irClass->getClassInfoSymbol(), DtoType(getClassInfoType()));
9292
mem = gIR->CreateCallOrInvoke(
93-
fn, ci, useEHAlloc ? ".newthrowable_alloc" : ".newclass_gc_alloc");
93+
loc, fn, ci, useEHAlloc ? ".newthrowable_alloc" : ".newclass_gc_alloc");
9494
mem = DtoBitCast(mem, DtoType(tc),
9595
useEHAlloc ? ".newthrowable" : ".newclass_gc");
9696
doInit = !useEHAlloc;
@@ -183,7 +183,7 @@ void DtoFinalizeClass(const Loc &loc, LLValue *inst) {
183183
getRuntimeFunction(loc, gIR->module, "_d_callfinalizer");
184184

185185
gIR->CreateCallOrInvoke(
186-
fn, DtoBitCast(inst, fn->getFunctionType()->getParamType(0)), "");
186+
loc, fn, DtoBitCast(inst, fn->getFunctionType()->getParamType(0)), "");
187187
}
188188

189189
////////////////////////////////////////////////////////////////////////////////
@@ -378,7 +378,7 @@ DValue *DtoDynamicCastObject(const Loc &loc, DValue *val, Type *_to) {
378378
assert(funcTy->getParamType(1) == cinfo->getType());
379379

380380
// call it
381-
LLValue *ret = gIR->CreateCallOrInvoke(func, obj, cinfo);
381+
LLValue *ret = gIR->CreateCallOrInvoke(loc, func, obj, cinfo);
382382

383383
// cast return value
384384
ret = DtoBitCast(ret, DtoType(_to));
@@ -412,7 +412,7 @@ DValue *DtoDynamicCastInterface(const Loc &loc, DValue *val, Type *_to) {
412412
cinfo = DtoBitCast(cinfo, funcTy->getParamType(1));
413413

414414
// call it
415-
LLValue *ret = gIR->CreateCallOrInvoke(func, ptr, cinfo);
415+
LLValue *ret = gIR->CreateCallOrInvoke(loc, func, ptr, cinfo);
416416

417417
// cast return value
418418
ret = DtoBitCast(ret, DtoType(_to));

gen/dpragma.d

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ extern (C++) enum LDCPragma : int {
4444
LLVMbitop_bts,
4545
LLVMbitop_vld,
4646
LLVMbitop_vst,
47-
LLVMextern_weak
47+
LLVMextern_weak,
48+
LLVMmusttail,
4849
};
4950

5051
extern (C++) LDCPragma DtoGetPragma(Scope* sc, PragmaDeclaration decl, ref const(char)* arg1str);

gen/funcgenstate.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include "gen/funcgenstate.h"
1111

12+
#include "dmd/errors.h"
1213
#include "dmd/identifier.h"
1314
#include "gen/llvm.h"
1415
#include "gen/llvmhelpers.h"
@@ -103,10 +104,10 @@ FuncGenState::FuncGenState(IrFunction &irFunc, IRState &irs)
103104
: irFunc(irFunc), scopes(irs), jumpTargets(scopes), switchTargets(),
104105
irs(irs) {}
105106

106-
LLCallBasePtr FuncGenState::callOrInvoke(llvm::Value *callee,
107+
LLCallBasePtr FuncGenState::callOrInvoke(const Loc &loc, llvm::Value *callee,
107108
llvm::FunctionType *calleeType,
108109
llvm::ArrayRef<llvm::Value *> args,
109-
const char *name, bool isNothrow) {
110+
const char *name, bool isNothrow, bool isMustTail) {
110111
// If this is a direct call, we might be able to use the callee attributes
111112
// to our advantage.
112113
llvm::Function *calleeFn = llvm::dyn_cast<llvm::Function>(callee);
@@ -135,9 +136,17 @@ LLCallBasePtr FuncGenState::callOrInvoke(llvm::Value *callee,
135136
if (calleeFn) {
136137
call->setAttributes(calleeFn->getAttributes());
137138
}
139+
if (isMustTail) {
140+
call->setTailCallKind(llvm::CallInst::TCK_MustTail);
141+
}
138142
return call;
139143
}
140144

145+
if (isMustTail) {
146+
error(loc, "cannot perform tail-call, there is code after call");
147+
fatal();
148+
}
149+
141150
llvm::BasicBlock *landingPad = scopes.getLandingPad();
142151

143152
llvm::BasicBlock *postinvoke = irs.insertBB("postinvoke");

gen/funcgenstate.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,11 @@ class FuncGenState {
199199

200200
/// Emits a call or invoke to the given callee, depending on whether there
201201
/// are catches/cleanups active or not.
202-
LLCallBasePtr callOrInvoke(llvm::Value *callee,
202+
LLCallBasePtr callOrInvoke(const Loc &loc, llvm::Value *callee,
203203
llvm::FunctionType *calleeType,
204204
llvm::ArrayRef<llvm::Value *> args,
205-
const char *name = "", bool isNothrow = false);
205+
const char *name = "", bool isNothrow = false,
206+
bool isMustTail = false);
206207

207208
private:
208209
IRState &irs;

gen/irstate.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -79,43 +79,44 @@ llvm::BasicBlock *IRState::insertBB(const llvm::Twine &name) {
7979
return insertBBAfter(scopebb(), name);
8080
}
8181

82-
llvm::Instruction *IRState::CreateCallOrInvoke(LLFunction *Callee,
82+
llvm::Instruction *IRState::CreateCallOrInvoke(const Loc &loc, LLFunction *Callee,
8383
const char *Name) {
84-
return CreateCallOrInvoke(Callee, {}, Name);
84+
return CreateCallOrInvoke(loc, Callee, {}, Name);
8585
}
8686

87-
llvm::Instruction *IRState::CreateCallOrInvoke(LLFunction *Callee,
87+
llvm::Instruction *IRState::CreateCallOrInvoke(const Loc &loc, LLFunction *Callee,
8888
llvm::ArrayRef<LLValue *> Args,
8989
const char *Name,
9090
bool isNothrow) {
91-
return funcGen().callOrInvoke(Callee, Callee->getFunctionType(), Args, Name,
92-
isNothrow);
91+
return funcGen().callOrInvoke(loc, Callee, Callee->getFunctionType(), Args,
92+
Name, isNothrow);
9393
}
9494

95-
llvm::Instruction *IRState::CreateCallOrInvoke(LLFunction *Callee,
95+
llvm::Instruction *IRState::CreateCallOrInvoke(const Loc &loc,
96+
LLFunction *Callee,
9697
LLValue *Arg1,
9798
const char *Name) {
98-
return CreateCallOrInvoke(Callee, llvm::ArrayRef<LLValue *>(Arg1), Name);
99+
return CreateCallOrInvoke(loc, Callee, llvm::ArrayRef<LLValue *>(Arg1), Name);
99100
}
100101

101-
llvm::Instruction *IRState::CreateCallOrInvoke(LLFunction *Callee,
102+
llvm::Instruction *IRState::CreateCallOrInvoke(const Loc &loc, LLFunction *Callee,
102103
LLValue *Arg1, LLValue *Arg2,
103104
const char *Name) {
104-
return CreateCallOrInvoke(Callee, {Arg1, Arg2}, Name);
105+
return CreateCallOrInvoke(loc, Callee, {Arg1, Arg2}, Name);
105106
}
106107

107-
llvm::Instruction *IRState::CreateCallOrInvoke(LLFunction *Callee,
108+
llvm::Instruction *IRState::CreateCallOrInvoke(const Loc &loc, LLFunction *Callee,
108109
LLValue *Arg1, LLValue *Arg2,
109110
LLValue *Arg3,
110111
const char *Name) {
111-
return CreateCallOrInvoke(Callee, {Arg1, Arg2, Arg3}, Name);
112+
return CreateCallOrInvoke(loc, Callee, {Arg1, Arg2, Arg3}, Name);
112113
}
113114

114-
llvm::Instruction *IRState::CreateCallOrInvoke(LLFunction *Callee,
115+
llvm::Instruction *IRState::CreateCallOrInvoke(const Loc &loc, LLFunction *Callee,
115116
LLValue *Arg1, LLValue *Arg2,
116117
LLValue *Arg3, LLValue *Arg4,
117118
const char *Name) {
118-
return CreateCallOrInvoke(Callee, {Arg1, Arg2, Arg3, Arg4}, Name);
119+
return CreateCallOrInvoke(loc, Callee, {Arg1, Arg2, Arg3, Arg4}, Name);
119120
}
120121

121122
bool IRState::emitArrayBoundsChecks() {

0 commit comments

Comments
 (0)