23
23
#include " llvm/ADT/SmallVector.h"
24
24
#include " llvm/Analysis/AliasAnalysis.h"
25
25
#include " llvm/Analysis/AliasSetTracker.h"
26
+ #include " llvm/Analysis/AssumeBundleQueries.h"
27
+ #include " llvm/Analysis/AssumptionCache.h"
26
28
#include " llvm/Analysis/LoopAnalysisManager.h"
27
29
#include " llvm/Analysis/LoopInfo.h"
28
30
#include " llvm/Analysis/LoopIterator.h"
@@ -208,28 +210,52 @@ static const SCEV *mulSCEVOverflow(const SCEV *A, const SCEV *B,
208
210
209
211
/// Return true if evaluating \p AR at \p MaxBTC cannot wrap, because \p AR at
210
212
/// \p MaxBTC is guaranteed inbounds of the accessed object.
211
- static bool evaluatePtrAddRecAtMaxBTCWillNotWrap ( const SCEVAddRecExpr *AR,
212
- const SCEV *MaxBTC ,
213
- const SCEV *EltSize,
214
- ScalarEvolution &SE ,
215
- const DataLayout &DL ) {
213
+ static bool
214
+ evaluatePtrAddRecAtMaxBTCWillNotWrap ( const SCEVAddRecExpr *AR ,
215
+ const SCEV *MaxBTC, const SCEV *EltSize,
216
+ ScalarEvolution &SE, const DataLayout &DL ,
217
+ AssumptionCache *AC, DominatorTree *DT ) {
216
218
auto *PointerBase = SE.getPointerBase (AR->getStart ());
217
219
auto *StartPtr = dyn_cast<SCEVUnknown>(PointerBase);
218
220
if (!StartPtr)
219
221
return false ;
222
+ const Loop *L = AR->getLoop ();
220
223
bool CheckForNonNull, CheckForFreed;
221
- uint64_t DerefBytes = StartPtr->getValue ()->getPointerDereferenceableBytes (
224
+ Value *StartPtrV = StartPtr->getValue ();
225
+ uint64_t DerefBytes = StartPtrV->getPointerDereferenceableBytes (
222
226
DL, CheckForNonNull, CheckForFreed);
223
227
224
- if (CheckForNonNull || CheckForFreed)
228
+ if (DerefBytes && ( CheckForNonNull || CheckForFreed) )
225
229
return false ;
226
230
227
231
const SCEV *Step = AR->getStepRecurrence (SE);
232
+ Type *WiderTy = SE.getWiderType (MaxBTC->getType (), Step->getType ());
233
+ const SCEV *DerefBytesSCEV = SE.getConstant (WiderTy, DerefBytes);
234
+
235
// Check if we have a suitable dereferenceable assumption we can use.
236
+ RetainedKnowledge DerefRK;
237
+ if (!StartPtrV->canBeFreed () &&
238
+ getKnowledgeForValue (
239
+ StartPtrV, {Attribute::Dereferenceable}, *AC,
240
+ [&](RetainedKnowledge RK, Instruction *Assume, auto ) {
241
+ if (!isValidAssumeForContext (
242
+ Assume, L->getLoopPredecessor ()->getTerminator (), DT))
243
+ return false ;
244
+ if (RK.AttrKind == Attribute::Dereferenceable) {
245
+ DerefRK = std::max (DerefRK, RK);
246
+ return true ;
247
+ }
248
+ return false ;
249
+ }) &&
250
+ DerefRK.ArgValue ) {
251
+ DerefBytesSCEV = SE.getUMaxExpr (DerefBytesSCEV,
252
+ SE.getConstant (WiderTy, DerefRK.ArgValue ));
253
+ }
254
+
228
255
bool IsKnownNonNegative = SE.isKnownNonNegative (Step);
229
256
if (!IsKnownNonNegative && !SE.isKnownNegative (Step))
230
257
return false ;
231
258
232
- Type *WiderTy = SE.getWiderType (MaxBTC->getType (), Step->getType ());
233
259
Step = SE.getNoopOrSignExtend (Step, WiderTy);
234
260
MaxBTC = SE.getNoopOrZeroExtend (MaxBTC, WiderTy);
235
261
@@ -256,24 +282,23 @@ static bool evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
256
282
const SCEV *EndBytes = addSCEVNoOverflow (StartOffset, OffsetEndBytes, SE);
257
283
if (!EndBytes)
258
284
return false ;
259
- return SE.isKnownPredicate (CmpInst::ICMP_ULE, EndBytes,
260
- SE.getConstant (WiderTy, DerefBytes));
285
+ return SE.isKnownPredicate (CmpInst::ICMP_ULE, EndBytes, DerefBytesSCEV);
261
286
}
262
287
263
288
// For negative steps check if
264
289
// * StartOffset >= (MaxBTC * Step + EltSize)
265
290
// * StartOffset <= DerefBytes.
266
291
assert (SE.isKnownNegative (Step) && " must be known negative" );
267
292
return SE.isKnownPredicate (CmpInst::ICMP_SGE, StartOffset, OffsetEndBytes) &&
268
- SE.isKnownPredicate (CmpInst::ICMP_ULE, StartOffset,
269
- SE.getConstant (WiderTy, DerefBytes));
293
+ SE.isKnownPredicate (CmpInst::ICMP_ULE, StartOffset, DerefBytesSCEV);
270
294
}
271
295
272
296
std::pair<const SCEV *, const SCEV *> llvm::getStartAndEndForAccess (
273
297
const Loop *Lp, const SCEV *PtrExpr, Type *AccessTy, const SCEV *BTC,
274
298
const SCEV *MaxBTC, ScalarEvolution *SE,
275
299
DenseMap<std::pair<const SCEV *, Type *>,
276
- std::pair<const SCEV *, const SCEV *>> *PointerBounds) {
300
+ std::pair<const SCEV *, const SCEV *>> *PointerBounds,
301
+ AssumptionCache *AC, DominatorTree *DT) {
277
302
std::pair<const SCEV *, const SCEV *> *PtrBoundsPair;
278
303
if (PointerBounds) {
279
304
auto [Iter, Ins] = PointerBounds->insert (
@@ -308,8 +333,8 @@ std::pair<const SCEV *, const SCEV *> llvm::getStartAndEndForAccess(
308
333
// sets ScEnd to the maximum unsigned value for the type. Note that LAA
309
334
// separately checks that accesses cannot wrap, so unsigned max
310
335
// represents an upper bound.
311
- if (evaluatePtrAddRecAtMaxBTCWillNotWrap (AR, MaxBTC, EltSizeSCEV, *SE,
312
- DL )) {
336
+ if (evaluatePtrAddRecAtMaxBTCWillNotWrap (AR, MaxBTC, EltSizeSCEV, *SE, DL,
337
+ AC, DT )) {
313
338
ScEnd = AR->evaluateAtIteration (MaxBTC, *SE);
314
339
} else {
315
340
ScEnd = SE->getAddExpr (
@@ -356,9 +381,9 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr,
356
381
bool NeedsFreeze) {
357
382
const SCEV *SymbolicMaxBTC = PSE.getSymbolicMaxBackedgeTakenCount ();
358
383
const SCEV *BTC = PSE.getBackedgeTakenCount ();
359
- const auto &[ScStart, ScEnd] =
360
- getStartAndEndForAccess ( Lp, PtrExpr, AccessTy, BTC, SymbolicMaxBTC,
361
- PSE. getSE (), & DC.getPointerBounds ());
384
+ const auto &[ScStart, ScEnd] = getStartAndEndForAccess (
385
+ Lp, PtrExpr, AccessTy, BTC, SymbolicMaxBTC, PSE. getSE () ,
386
+ &DC. getPointerBounds (), DC. getAC (), DC.getDT ());
362
387
assert (!isa<SCEVCouldNotCompute>(ScStart) &&
363
388
!isa<SCEVCouldNotCompute>(ScEnd) &&
364
389
" must be able to compute both start and end expressions" );
@@ -2011,10 +2036,10 @@ MemoryDepChecker::getDependenceDistanceStrideAndSize(
2011
2036
const SCEV *SymbolicMaxBTC = PSE.getSymbolicMaxBackedgeTakenCount ();
2012
2037
const auto &[SrcStart_, SrcEnd_] =
2013
2038
getStartAndEndForAccess (InnermostLoop, Src, ATy, BTC, SymbolicMaxBTC,
2014
- PSE.getSE (), &PointerBounds);
2039
+ PSE.getSE (), &PointerBounds, AC, DT );
2015
2040
const auto &[SinkStart_, SinkEnd_] =
2016
2041
getStartAndEndForAccess (InnermostLoop, Sink, BTy, BTC, SymbolicMaxBTC,
2017
- PSE.getSE (), &PointerBounds);
2042
+ PSE.getSE (), &PointerBounds, AC, DT );
2018
2043
if (!isa<SCEVCouldNotCompute>(SrcStart_) &&
2019
2044
!isa<SCEVCouldNotCompute>(SrcEnd_) &&
2020
2045
!isa<SCEVCouldNotCompute>(SinkStart_) &&
@@ -3015,7 +3040,7 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
3015
3040
const TargetTransformInfo *TTI,
3016
3041
const TargetLibraryInfo *TLI, AAResults *AA,
3017
3042
DominatorTree *DT, LoopInfo *LI,
3018
- bool AllowPartial)
3043
+ AssumptionCache *AC, bool AllowPartial)
3019
3044
: PSE(std::make_unique<PredicatedScalarEvolution>(*SE, *L)),
3020
3045
PtrRtChecking (nullptr ), TheLoop(L), AllowPartial(AllowPartial) {
3021
3046
unsigned MaxTargetVectorWidthInBits = std::numeric_limits<unsigned >::max ();
@@ -3025,8 +3050,8 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
3025
3050
MaxTargetVectorWidthInBits =
3026
3051
TTI->getRegisterBitWidth (TargetTransformInfo::RGK_FixedWidthVector) * 2 ;
3027
3052
3028
- DepChecker = std::make_unique<MemoryDepChecker>(*PSE, L, SymbolicStrides,
3029
- MaxTargetVectorWidthInBits);
3053
+ DepChecker = std::make_unique<MemoryDepChecker>(
3054
+ *PSE, AC, DT, L, SymbolicStrides, MaxTargetVectorWidthInBits);
3030
3055
PtrRtChecking = std::make_unique<RuntimePointerChecking>(*DepChecker, SE);
3031
3056
if (canAnalyzeLoop ())
3032
3057
CanVecMem = analyzeLoop (AA, LI, TLI, DT);
@@ -3095,7 +3120,7 @@ const LoopAccessInfo &LoopAccessInfoManager::getInfo(Loop &L,
3095
3120
// or if it was created with a different value of AllowPartial.
3096
3121
if (Inserted || It->second ->hasAllowPartial () != AllowPartial)
3097
3122
It->second = std::make_unique<LoopAccessInfo>(&L, &SE, TTI, TLI, &AA, &DT,
3098
- &LI, AllowPartial);
3123
+ &LI, AC, AllowPartial);
3099
3124
3100
3125
return *It->second ;
3101
3126
}
@@ -3138,7 +3163,8 @@ LoopAccessInfoManager LoopAccessAnalysis::run(Function &F,
3138
3163
auto &LI = FAM.getResult <LoopAnalysis>(F);
3139
3164
auto &TTI = FAM.getResult <TargetIRAnalysis>(F);
3140
3165
auto &TLI = FAM.getResult <TargetLibraryAnalysis>(F);
3141
- return LoopAccessInfoManager (SE, AA, DT, LI, &TTI, &TLI);
3166
+ auto &AC = FAM.getResult <AssumptionAnalysis>(F);
3167
+ return LoopAccessInfoManager (SE, AA, DT, LI, &TTI, &TLI, &AC);
3142
3168
}
3143
3169
3144
3170
AnalysisKey LoopAccessAnalysis::Key;
0 commit comments