[CMake] Link unittests only against libLLVM.so, if available.
[polly-mirror.git] / lib / Support / SCEVValidator.cpp
blob848357f5afb597e6fa43d4b4fad6d6400e677020
2 #include "polly/Support/SCEVValidator.h"
3 #include "polly/ScopInfo.h"
4 #include "llvm/Analysis/RegionInfo.h"
5 #include "llvm/Analysis/ScalarEvolution.h"
6 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
7 #include "llvm/Support/Debug.h"
9 using namespace llvm;
10 using namespace polly;
12 #define DEBUG_TYPE "polly-scev-validator"
namespace SCEVType {
/// Classification of a SCEV expression with respect to SCoP validity.
///
/// Every SCEV is assigned one of the types below. The enumerators are ordered
/// on purpose: a sub-expression of a SCEV of type X may only have a type that
/// is less than or equal to X, which is why combining two classifications can
/// simply take the larger of the two.
enum TYPE {
  /// An integer value.
  INT = 0,

  /// An expression that is constant during the execution of the SCoP, but
  /// that may depend on parameters unknown at compile time.
  PARAM = 1,

  /// An expression that may change during the execution of the SCoP.
  IV = 2,

  /// An invalid expression.
  INVALID = 3
};
} // namespace SCEVType
37 /// The result the validator returns for a SCEV expression.
38 class ValidatorResult {
39 /// The type of the expression
40 SCEVType::TYPE Type;
42 /// The set of Parameters in the expression.
43 ParameterSetTy Parameters;
45 public:
46 /// The copy constructor
47 ValidatorResult(const ValidatorResult &Source) {
48 Type = Source.Type;
49 Parameters = Source.Parameters;
52 /// Construct a result with a certain type and no parameters.
53 ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
54 assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
57 /// Construct a result with a certain type and a single parameter.
58 ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
59 Parameters.insert(Expr);
62 /// Get the type of the ValidatorResult.
63 SCEVType::TYPE getType() { return Type; }
65 /// Is the analyzed SCEV constant during the execution of the SCoP.
66 bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
68 /// Is the analyzed SCEV valid.
69 bool isValid() { return Type != SCEVType::INVALID; }
71 /// Is the analyzed SCEV of Type IV.
72 bool isIV() { return Type == SCEVType::IV; }
74 /// Is the analyzed SCEV of Type INT.
75 bool isINT() { return Type == SCEVType::INT; }
77 /// Is the analyzed SCEV of Type PARAM.
78 bool isPARAM() { return Type == SCEVType::PARAM; }
80 /// Get the parameters of this validator result.
81 const ParameterSetTy &getParameters() { return Parameters; }
83 /// Add the parameters of Source to this result.
84 void addParamsFrom(const ValidatorResult &Source) {
85 Parameters.insert(Source.Parameters.begin(), Source.Parameters.end());
88 /// Merge a result.
89 ///
90 /// This means to merge the parameters and to set the Type to the most
91 /// specific Type that matches both.
92 void merge(const ValidatorResult &ToMerge) {
93 Type = std::max(Type, ToMerge.Type);
94 addParamsFrom(ToMerge);
97 void print(raw_ostream &OS) {
98 switch (Type) {
99 case SCEVType::INT:
100 OS << "SCEVType::INT";
101 break;
102 case SCEVType::PARAM:
103 OS << "SCEVType::PARAM";
104 break;
105 case SCEVType::IV:
106 OS << "SCEVType::IV";
107 break;
108 case SCEVType::INVALID:
109 OS << "SCEVType::INVALID";
110 break;
115 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) {
116 VR.print(OS);
117 return OS;
120 /// Check if a SCEV is valid in a SCoP.
121 struct SCEVValidator
122 : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
123 private:
124 const Region *R;
125 Loop *Scope;
126 ScalarEvolution &SE;
127 InvariantLoadsSetTy *ILS;
129 public:
130 SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE,
131 InvariantLoadsSetTy *ILS)
132 : R(R), Scope(Scope), SE(SE), ILS(ILS) {}
134 class ValidatorResult visitConstant(const SCEVConstant *Constant) {
135 return ValidatorResult(SCEVType::INT);
138 class ValidatorResult visitZeroExtendOrTruncateExpr(const SCEV *Expr,
139 const SCEV *Operand) {
140 ValidatorResult Op = visit(Operand);
141 auto Type = Op.getType();
143 // If unsigned operations are allowed return the operand, otherwise
144 // check if we can model the expression without unsigned assumptions.
145 if (PollyAllowUnsignedOperations || Type == SCEVType::INVALID)
146 return Op;
148 if (Type == SCEVType::IV)
149 return ValidatorResult(SCEVType::INVALID);
150 return ValidatorResult(SCEVType::PARAM, Expr);
153 class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
154 return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
157 class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
158 return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
161 class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
162 return visit(Expr->getOperand());
165 class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
166 ValidatorResult Return(SCEVType::INT);
168 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
169 ValidatorResult Op = visit(Expr->getOperand(i));
170 Return.merge(Op);
172 // Early exit.
173 if (!Return.isValid())
174 break;
177 return Return;
180 class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
181 ValidatorResult Return(SCEVType::INT);
183 bool HasMultipleParams = false;
185 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
186 ValidatorResult Op = visit(Expr->getOperand(i));
188 if (Op.isINT())
189 continue;
191 if (Op.isPARAM() && Return.isPARAM()) {
192 HasMultipleParams = true;
193 continue;
196 if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
197 DEBUG(dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
198 << "\tExpr: " << *Expr << "\n"
199 << "\tPrevious expression type: " << Return << "\n"
200 << "\tNext operand (" << Op
201 << "): " << *Expr->getOperand(i) << "\n");
203 return ValidatorResult(SCEVType::INVALID);
206 Return.merge(Op);
209 if (HasMultipleParams && Return.isValid())
210 return ValidatorResult(SCEVType::PARAM, Expr);
212 return Return;
215 class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
216 if (!Expr->isAffine()) {
217 DEBUG(dbgs() << "INVALID: AddRec is not affine");
218 return ValidatorResult(SCEVType::INVALID);
221 ValidatorResult Start = visit(Expr->getStart());
222 ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
224 if (!Start.isValid())
225 return Start;
227 if (!Recurrence.isValid())
228 return Recurrence;
230 auto *L = Expr->getLoop();
231 if (R->contains(L) && (!Scope || !L->contains(Scope))) {
232 DEBUG(dbgs() << "INVALID: Loop of AddRec expression boxed in an a "
233 "non-affine subregion or has a non-synthesizable exit "
234 "value.");
235 return ValidatorResult(SCEVType::INVALID);
238 if (R->contains(L)) {
239 if (Recurrence.isINT()) {
240 ValidatorResult Result(SCEVType::IV);
241 Result.addParamsFrom(Start);
242 return Result;
245 DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
246 "recurrence part");
247 return ValidatorResult(SCEVType::INVALID);
250 assert(Recurrence.isConstant() && "Expected 'Recurrence' to be constant");
252 // Directly generate ValidatorResult for Expr if 'start' is zero.
253 if (Expr->getStart()->isZero())
254 return ValidatorResult(SCEVType::PARAM, Expr);
256 // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
257 // if 'start' is not zero.
258 const SCEV *ZeroStartExpr = SE.getAddRecExpr(
259 SE.getConstant(Expr->getStart()->getType(), 0),
260 Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags());
262 ValidatorResult ZeroStartResult =
263 ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
264 ZeroStartResult.addParamsFrom(Start);
266 return ZeroStartResult;
269 class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
270 ValidatorResult Return(SCEVType::INT);
272 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
273 ValidatorResult Op = visit(Expr->getOperand(i));
275 if (!Op.isValid())
276 return Op;
278 Return.merge(Op);
281 return Return;
284 class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
285 // We do not support unsigned max operations. If 'Expr' is constant during
286 // Scop execution we treat this as a parameter, otherwise we bail out.
287 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
288 ValidatorResult Op = visit(Expr->getOperand(i));
290 if (!Op.isConstant()) {
291 DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
292 return ValidatorResult(SCEVType::INVALID);
296 return ValidatorResult(SCEVType::PARAM, Expr);
299 ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) {
300 if (R->contains(I)) {
301 DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
302 "within the region\n");
303 return ValidatorResult(SCEVType::INVALID);
306 return ValidatorResult(SCEVType::PARAM, S);
309 ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) {
310 if (R->contains(I) && ILS) {
311 ILS->insert(cast<LoadInst>(I));
312 return ValidatorResult(SCEVType::PARAM, S);
315 return visitGenericInst(I, S);
318 ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor,
319 const SCEV *DivExpr,
320 Instruction *SDiv = nullptr) {
322 // First check if we might be able to model the division, thus if the
323 // divisor is constant. If so, check the dividend, otherwise check if
324 // the whole division can be seen as a parameter.
325 if (isa<SCEVConstant>(Divisor) && !Divisor->isZero())
326 return visit(Dividend);
328 // For signed divisions use the SDiv instruction to check for a parameter
329 // division, for unsigned divisions check the operands.
330 if (SDiv)
331 return visitGenericInst(SDiv, DivExpr);
333 ValidatorResult LHS = visit(Dividend);
334 ValidatorResult RHS = visit(Divisor);
335 if (LHS.isConstant() && RHS.isConstant())
336 return ValidatorResult(SCEVType::PARAM, DivExpr);
338 DEBUG(dbgs() << "INVALID: unsigned division of non-constant expressions");
339 return ValidatorResult(SCEVType::INVALID);
342 ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
343 if (!PollyAllowUnsignedOperations)
344 return ValidatorResult(SCEVType::INVALID);
346 auto *Dividend = Expr->getLHS();
347 auto *Divisor = Expr->getRHS();
348 return visitDivision(Dividend, Divisor, Expr);
351 ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) {
352 assert(SDiv->getOpcode() == Instruction::SDiv &&
353 "Assumed SDiv instruction!");
355 auto *Dividend = SE.getSCEV(SDiv->getOperand(0));
356 auto *Divisor = SE.getSCEV(SDiv->getOperand(1));
357 return visitDivision(Dividend, Divisor, Expr, SDiv);
360 ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) {
361 assert(SRem->getOpcode() == Instruction::SRem &&
362 "Assumed SRem instruction!");
364 auto *Divisor = SRem->getOperand(1);
365 auto *CI = dyn_cast<ConstantInt>(Divisor);
366 if (!CI || CI->isZeroValue())
367 return visitGenericInst(SRem, S);
369 auto *Dividend = SRem->getOperand(0);
370 auto *DividendSCEV = SE.getSCEV(Dividend);
371 return visit(DividendSCEV);
374 ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
375 Value *V = Expr->getValue();
377 if (!Expr->getType()->isIntegerTy() && !Expr->getType()->isPointerTy()) {
378 DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer");
379 return ValidatorResult(SCEVType::INVALID);
382 if (isa<UndefValue>(V)) {
383 DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
384 return ValidatorResult(SCEVType::INVALID);
387 if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) {
388 switch (I->getOpcode()) {
389 case Instruction::IntToPtr:
390 return visit(SE.getSCEVAtScope(I->getOperand(0), Scope));
391 case Instruction::PtrToInt:
392 return visit(SE.getSCEVAtScope(I->getOperand(0), Scope));
393 case Instruction::Load:
394 return visitLoadInstruction(I, Expr);
395 case Instruction::SDiv:
396 return visitSDivInstruction(I, Expr);
397 case Instruction::SRem:
398 return visitSRemInstruction(I, Expr);
399 default:
400 return visitGenericInst(I, Expr);
404 return ValidatorResult(SCEVType::PARAM, Expr);
408 /// Check whether a SCEV refers to an SSA name defined inside a region.
409 class SCEVInRegionDependences {
410 const Region *R;
411 Loop *Scope;
412 bool AllowLoops;
413 bool HasInRegionDeps = false;
415 public:
416 SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops)
417 : R(R), Scope(Scope), AllowLoops(AllowLoops) {}
419 bool follow(const SCEV *S) {
420 if (auto Unknown = dyn_cast<SCEVUnknown>(S)) {
421 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
423 // Return true when Inst is defined inside the region R.
424 if (!Inst || !R->contains(Inst))
425 return true;
427 HasInRegionDeps = true;
428 return false;
431 if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
432 if (AllowLoops)
433 return true;
435 if (!Scope) {
436 HasInRegionDeps = true;
437 return false;
439 auto *L = AddRec->getLoop();
440 if (R->contains(L) && !L->contains(Scope)) {
441 HasInRegionDeps = true;
442 return false;
446 return true;
448 bool isDone() { return false; }
449 bool hasDependences() { return HasInRegionDeps; }
452 namespace polly {
453 /// Find all loops referenced in SCEVAddRecExprs.
454 class SCEVFindLoops {
455 SetVector<const Loop *> &Loops;
457 public:
458 SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
460 bool follow(const SCEV *S) {
461 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
462 Loops.insert(AddRec->getLoop());
463 return true;
465 bool isDone() { return false; }
468 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
469 SCEVFindLoops FindLoops(Loops);
470 SCEVTraversal<SCEVFindLoops> ST(FindLoops);
471 ST.visitAll(Expr);
474 /// Find all values referenced in SCEVUnknowns.
475 class SCEVFindValues {
476 ScalarEvolution &SE;
477 SetVector<Value *> &Values;
479 public:
480 SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values)
481 : SE(SE), Values(Values) {}
483 bool follow(const SCEV *S) {
484 const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
485 if (!Unknown)
486 return true;
488 Values.insert(Unknown->getValue());
489 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
490 if (!Inst || (Inst->getOpcode() != Instruction::SRem &&
491 Inst->getOpcode() != Instruction::SDiv))
492 return false;
494 auto *Dividend = SE.getSCEV(Inst->getOperand(1));
495 if (!isa<SCEVConstant>(Dividend))
496 return false;
498 auto *Divisor = SE.getSCEV(Inst->getOperand(0));
499 SCEVFindValues FindValues(SE, Values);
500 SCEVTraversal<SCEVFindValues> ST(FindValues);
501 ST.visitAll(Dividend);
502 ST.visitAll(Divisor);
504 return false;
506 bool isDone() { return false; }
509 void findValues(const SCEV *Expr, ScalarEvolution &SE,
510 SetVector<Value *> &Values) {
511 SCEVFindValues FindValues(SE, Values);
512 SCEVTraversal<SCEVFindValues> ST(FindValues);
513 ST.visitAll(Expr);
516 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R,
517 llvm::Loop *Scope, bool AllowLoops) {
518 SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops);
519 SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps);
520 ST.visitAll(Expr);
521 return InRegionDeps.hasDependences();
524 bool isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr,
525 ScalarEvolution &SE, InvariantLoadsSetTy *ILS) {
526 if (isa<SCEVCouldNotCompute>(Expr))
527 return false;
529 SCEVValidator Validator(R, Scope, SE, ILS);
530 DEBUG({
531 dbgs() << "\n";
532 dbgs() << "Expr: " << *Expr << "\n";
533 dbgs() << "Region: " << R->getNameStr() << "\n";
534 dbgs() << " -> ";
537 ValidatorResult Result = Validator.visit(Expr);
539 DEBUG({
540 if (Result.isValid())
541 dbgs() << "VALID\n";
542 dbgs() << "\n";
545 return Result.isValid();
548 static bool isAffineExpr(Value *V, const Region *R, Loop *Scope,
549 ScalarEvolution &SE, ParameterSetTy &Params) {
550 auto *E = SE.getSCEV(V);
551 if (isa<SCEVCouldNotCompute>(E))
552 return false;
554 SCEVValidator Validator(R, Scope, SE, nullptr);
555 ValidatorResult Result = Validator.visit(E);
556 if (!Result.isValid())
557 return false;
559 auto ResultParams = Result.getParameters();
560 Params.insert(ResultParams.begin(), ResultParams.end());
562 return true;
565 bool isAffineConstraint(Value *V, const Region *R, llvm::Loop *Scope,
566 ScalarEvolution &SE, ParameterSetTy &Params,
567 bool OrExpr) {
568 if (auto *ICmp = dyn_cast<ICmpInst>(V)) {
569 return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params,
570 true) &&
571 isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true);
572 } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) {
573 auto Opcode = BinOp->getOpcode();
574 if (Opcode == Instruction::And || Opcode == Instruction::Or)
575 return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params,
576 false) &&
577 isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params,
578 false);
579 /* Fall through */
582 if (!OrExpr)
583 return false;
585 return isAffineExpr(V, R, Scope, SE, Params);
588 ParameterSetTy getParamsInAffineExpr(const Region *R, Loop *Scope,
589 const SCEV *Expr, ScalarEvolution &SE) {
590 if (isa<SCEVCouldNotCompute>(Expr))
591 return ParameterSetTy();
593 InvariantLoadsSetTy ILS;
594 SCEVValidator Validator(R, Scope, SE, &ILS);
595 ValidatorResult Result = Validator.visit(Expr);
596 assert(Result.isValid() && "Requested parameters for an invalid SCEV!");
598 return Result.getParameters();
601 std::pair<const SCEVConstant *, const SCEV *>
602 extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
603 auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1));
605 if (auto *Constant = dyn_cast<SCEVConstant>(S))
606 return std::make_pair(Constant, SE.getConstant(S->getType(), 1));
608 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
609 if (AddRec) {
610 auto *StartExpr = AddRec->getStart();
611 if (StartExpr->isZero()) {
612 auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE);
613 auto *LeftOverAddRec =
614 SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(),
615 AddRec->getNoWrapFlags());
616 return std::make_pair(StepPair.first, LeftOverAddRec);
618 return std::make_pair(ConstPart, S);
621 if (auto *Add = dyn_cast<SCEVAddExpr>(S)) {
622 SmallVector<const SCEV *, 4> LeftOvers;
623 auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE);
624 auto *Factor = Op0Pair.first;
625 if (SE.isKnownNegative(Factor)) {
626 Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor));
627 LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second));
628 } else {
629 LeftOvers.push_back(Op0Pair.second);
632 for (unsigned u = 1, e = Add->getNumOperands(); u < e; u++) {
633 auto OpUPair = extractConstantFactor(Add->getOperand(u), SE);
634 // TODO: Use something smarter than equality here, e.g., gcd.
635 if (Factor == OpUPair.first)
636 LeftOvers.push_back(OpUPair.second);
637 else if (Factor == SE.getNegativeSCEV(OpUPair.first))
638 LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second));
639 else
640 return std::make_pair(ConstPart, S);
643 auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags());
644 return std::make_pair(Factor, NewAdd);
647 auto *Mul = dyn_cast<SCEVMulExpr>(S);
648 if (!Mul)
649 return std::make_pair(ConstPart, S);
651 SmallVector<const SCEV *, 4> LeftOvers;
652 for (auto *Op : Mul->operands())
653 if (isa<SCEVConstant>(Op))
654 ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op));
655 else
656 LeftOvers.push_back(Op);
658 return std::make_pair(ConstPart, SE.getMulExpr(LeftOvers));
660 } // namespace polly