2 #include "polly/Support/SCEVValidator.h"
3 #include "polly/ScopInfo.h"
4 #include "llvm/Analysis/RegionInfo.h"
5 #include "llvm/Analysis/ScalarEvolution.h"
6 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
7 #include "llvm/Support/Debug.h"
10 using namespace polly
;
12 #define DEBUG_TYPE "polly-scev-validator"
15 /// The type of a SCEV
17 /// To check for the validity of a SCEV we assign to each SCEV a type. The
18 /// possible types are INT, PARAM, IV and INVALID. The order of the types is
19 /// important. The subexpressions of a SCEV with a type X can only have a type
20 /// that is smaller than or equal to X.
25 // An expression that is constant during the execution of the SCoP,
26 // but that may depend on parameters unknown at compile time.
29 // An expression that may change during the execution of the SCoP.
32 // An invalid expression.
35 } // namespace SCEVType
37 /// The result the validator returns for a SCEV expression.
38 class ValidatorResult
{
39 /// The type of the expression
42 /// The set of Parameters in the expression.
43 ParameterSetTy Parameters
;
46 /// The copy constructor
47 ValidatorResult(const ValidatorResult
&Source
) {
49 Parameters
= Source
.Parameters
;
52 /// Construct a result with a certain type and no parameters.
53 ValidatorResult(SCEVType::TYPE Type
) : Type(Type
) {
54 assert(Type
!= SCEVType::PARAM
&& "Did you forget to pass the parameter");
57 /// Construct a result with a certain type and a single parameter.
58 ValidatorResult(SCEVType::TYPE Type
, const SCEV
*Expr
) : Type(Type
) {
59 Parameters
.insert(Expr
);
62 /// Get the type of the ValidatorResult.
63 SCEVType::TYPE
getType() { return Type
; }
65 /// Is the analyzed SCEV constant during the execution of the SCoP.
66 bool isConstant() { return Type
== SCEVType::INT
|| Type
== SCEVType::PARAM
; }
68 /// Is the analyzed SCEV valid.
69 bool isValid() { return Type
!= SCEVType::INVALID
; }
71 /// Is the analyzed SCEV of Type IV.
72 bool isIV() { return Type
== SCEVType::IV
; }
74 /// Is the analyzed SCEV of Type INT.
75 bool isINT() { return Type
== SCEVType::INT
; }
77 /// Is the analyzed SCEV of Type PARAM.
78 bool isPARAM() { return Type
== SCEVType::PARAM
; }
80 /// Get the parameters of this validator result.
81 const ParameterSetTy
&getParameters() { return Parameters
; }
83 /// Add the parameters of Source to this result.
84 void addParamsFrom(const ValidatorResult
&Source
) {
85 Parameters
.insert(Source
.Parameters
.begin(), Source
.Parameters
.end());
90 /// This means to merge the parameters and to set the Type to the most
91 /// specific Type that matches both.
92 void merge(const ValidatorResult
&ToMerge
) {
93 Type
= std::max(Type
, ToMerge
.Type
);
94 addParamsFrom(ToMerge
);
97 void print(raw_ostream
&OS
) {
100 OS
<< "SCEVType::INT";
102 case SCEVType::PARAM
:
103 OS
<< "SCEVType::PARAM";
106 OS
<< "SCEVType::IV";
108 case SCEVType::INVALID
:
109 OS
<< "SCEVType::INVALID";
115 raw_ostream
&operator<<(raw_ostream
&OS
, class ValidatorResult
&VR
) {
120 /// Check if a SCEV is valid in a SCoP.
122 : public SCEVVisitor
<SCEVValidator
, class ValidatorResult
> {
127 InvariantLoadsSetTy
*ILS
;
130 SCEVValidator(const Region
*R
, Loop
*Scope
, ScalarEvolution
&SE
,
131 InvariantLoadsSetTy
*ILS
)
132 : R(R
), Scope(Scope
), SE(SE
), ILS(ILS
) {}
134 class ValidatorResult
visitConstant(const SCEVConstant
*Constant
) {
135 return ValidatorResult(SCEVType::INT
);
138 class ValidatorResult
visitZeroExtendOrTruncateExpr(const SCEV
*Expr
,
139 const SCEV
*Operand
) {
140 ValidatorResult Op
= visit(Operand
);
141 auto Type
= Op
.getType();
143 // If unsigned operations are allowed return the operand, otherwise
144 // check if we can model the expression without unsigned assumptions.
145 if (PollyAllowUnsignedOperations
|| Type
== SCEVType::INVALID
)
148 if (Type
== SCEVType::IV
)
149 return ValidatorResult(SCEVType::INVALID
);
150 return ValidatorResult(SCEVType::PARAM
, Expr
);
153 class ValidatorResult
visitTruncateExpr(const SCEVTruncateExpr
*Expr
) {
154 return visitZeroExtendOrTruncateExpr(Expr
, Expr
->getOperand());
157 class ValidatorResult
visitZeroExtendExpr(const SCEVZeroExtendExpr
*Expr
) {
158 return visitZeroExtendOrTruncateExpr(Expr
, Expr
->getOperand());
161 class ValidatorResult
visitSignExtendExpr(const SCEVSignExtendExpr
*Expr
) {
162 return visit(Expr
->getOperand());
165 class ValidatorResult
visitAddExpr(const SCEVAddExpr
*Expr
) {
166 ValidatorResult
Return(SCEVType::INT
);
168 for (int i
= 0, e
= Expr
->getNumOperands(); i
< e
; ++i
) {
169 ValidatorResult Op
= visit(Expr
->getOperand(i
));
173 if (!Return
.isValid())
180 class ValidatorResult
visitMulExpr(const SCEVMulExpr
*Expr
) {
181 ValidatorResult
Return(SCEVType::INT
);
183 bool HasMultipleParams
= false;
185 for (int i
= 0, e
= Expr
->getNumOperands(); i
< e
; ++i
) {
186 ValidatorResult Op
= visit(Expr
->getOperand(i
));
191 if (Op
.isPARAM() && Return
.isPARAM()) {
192 HasMultipleParams
= true;
196 if ((Op
.isIV() || Op
.isPARAM()) && !Return
.isINT()) {
197 DEBUG(dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
198 << "\tExpr: " << *Expr
<< "\n"
199 << "\tPrevious expression type: " << Return
<< "\n"
200 << "\tNext operand (" << Op
201 << "): " << *Expr
->getOperand(i
) << "\n");
203 return ValidatorResult(SCEVType::INVALID
);
209 if (HasMultipleParams
&& Return
.isValid())
210 return ValidatorResult(SCEVType::PARAM
, Expr
);
215 class ValidatorResult
visitAddRecExpr(const SCEVAddRecExpr
*Expr
) {
216 if (!Expr
->isAffine()) {
217 DEBUG(dbgs() << "INVALID: AddRec is not affine");
218 return ValidatorResult(SCEVType::INVALID
);
221 ValidatorResult Start
= visit(Expr
->getStart());
222 ValidatorResult Recurrence
= visit(Expr
->getStepRecurrence(SE
));
224 if (!Start
.isValid())
227 if (!Recurrence
.isValid())
230 auto *L
= Expr
->getLoop();
231 if (R
->contains(L
) && (!Scope
|| !L
->contains(Scope
))) {
232 DEBUG(dbgs() << "INVALID: Loop of AddRec expression boxed in an a "
233 "non-affine subregion or has a non-synthesizable exit "
235 return ValidatorResult(SCEVType::INVALID
);
238 if (R
->contains(L
)) {
239 if (Recurrence
.isINT()) {
240 ValidatorResult
Result(SCEVType::IV
);
241 Result
.addParamsFrom(Start
);
245 DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
247 return ValidatorResult(SCEVType::INVALID
);
250 assert(Recurrence
.isConstant() && "Expected 'Recurrence' to be constant");
252 // Directly generate ValidatorResult for Expr if 'start' is zero.
253 if (Expr
->getStart()->isZero())
254 return ValidatorResult(SCEVType::PARAM
, Expr
);
256 // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
257 // if 'start' is not zero.
258 const SCEV
*ZeroStartExpr
= SE
.getAddRecExpr(
259 SE
.getConstant(Expr
->getStart()->getType(), 0),
260 Expr
->getStepRecurrence(SE
), Expr
->getLoop(), Expr
->getNoWrapFlags());
262 ValidatorResult ZeroStartResult
=
263 ValidatorResult(SCEVType::PARAM
, ZeroStartExpr
);
264 ZeroStartResult
.addParamsFrom(Start
);
266 return ZeroStartResult
;
269 class ValidatorResult
visitSMaxExpr(const SCEVSMaxExpr
*Expr
) {
270 ValidatorResult
Return(SCEVType::INT
);
272 for (int i
= 0, e
= Expr
->getNumOperands(); i
< e
; ++i
) {
273 ValidatorResult Op
= visit(Expr
->getOperand(i
));
284 class ValidatorResult
visitUMaxExpr(const SCEVUMaxExpr
*Expr
) {
285 // We do not support unsigned max operations. If 'Expr' is constant during
286 // Scop execution we treat this as a parameter, otherwise we bail out.
287 for (int i
= 0, e
= Expr
->getNumOperands(); i
< e
; ++i
) {
288 ValidatorResult Op
= visit(Expr
->getOperand(i
));
290 if (!Op
.isConstant()) {
291 DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
292 return ValidatorResult(SCEVType::INVALID
);
296 return ValidatorResult(SCEVType::PARAM
, Expr
);
299 ValidatorResult
visitGenericInst(Instruction
*I
, const SCEV
*S
) {
300 if (R
->contains(I
)) {
301 DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
302 "within the region\n");
303 return ValidatorResult(SCEVType::INVALID
);
306 return ValidatorResult(SCEVType::PARAM
, S
);
309 ValidatorResult
visitLoadInstruction(Instruction
*I
, const SCEV
*S
) {
310 if (R
->contains(I
) && ILS
) {
311 ILS
->insert(cast
<LoadInst
>(I
));
312 return ValidatorResult(SCEVType::PARAM
, S
);
315 return visitGenericInst(I
, S
);
318 ValidatorResult
visitDivision(const SCEV
*Dividend
, const SCEV
*Divisor
,
320 Instruction
*SDiv
= nullptr) {
322 // First check if we might be able to model the division, thus if the
323 // divisor is constant. If so, check the dividend, otherwise check if
324 // the whole division can be seen as a parameter.
325 if (isa
<SCEVConstant
>(Divisor
) && !Divisor
->isZero())
326 return visit(Dividend
);
328 // For signed divisions use the SDiv instruction to check for a parameter
329 // division, for unsigned divisions check the operands.
331 return visitGenericInst(SDiv
, DivExpr
);
333 ValidatorResult LHS
= visit(Dividend
);
334 ValidatorResult RHS
= visit(Divisor
);
335 if (LHS
.isConstant() && RHS
.isConstant())
336 return ValidatorResult(SCEVType::PARAM
, DivExpr
);
338 DEBUG(dbgs() << "INVALID: unsigned division of non-constant expressions");
339 return ValidatorResult(SCEVType::INVALID
);
342 ValidatorResult
visitUDivExpr(const SCEVUDivExpr
*Expr
) {
343 if (!PollyAllowUnsignedOperations
)
344 return ValidatorResult(SCEVType::INVALID
);
346 auto *Dividend
= Expr
->getLHS();
347 auto *Divisor
= Expr
->getRHS();
348 return visitDivision(Dividend
, Divisor
, Expr
);
351 ValidatorResult
visitSDivInstruction(Instruction
*SDiv
, const SCEV
*Expr
) {
352 assert(SDiv
->getOpcode() == Instruction::SDiv
&&
353 "Assumed SDiv instruction!");
355 auto *Dividend
= SE
.getSCEV(SDiv
->getOperand(0));
356 auto *Divisor
= SE
.getSCEV(SDiv
->getOperand(1));
357 return visitDivision(Dividend
, Divisor
, Expr
, SDiv
);
360 ValidatorResult
visitSRemInstruction(Instruction
*SRem
, const SCEV
*S
) {
361 assert(SRem
->getOpcode() == Instruction::SRem
&&
362 "Assumed SRem instruction!");
364 auto *Divisor
= SRem
->getOperand(1);
365 auto *CI
= dyn_cast
<ConstantInt
>(Divisor
);
366 if (!CI
|| CI
->isZeroValue())
367 return visitGenericInst(SRem
, S
);
369 auto *Dividend
= SRem
->getOperand(0);
370 auto *DividendSCEV
= SE
.getSCEV(Dividend
);
371 return visit(DividendSCEV
);
374 ValidatorResult
visitUnknown(const SCEVUnknown
*Expr
) {
375 Value
*V
= Expr
->getValue();
377 if (!Expr
->getType()->isIntegerTy() && !Expr
->getType()->isPointerTy()) {
378 DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer");
379 return ValidatorResult(SCEVType::INVALID
);
382 if (isa
<UndefValue
>(V
)) {
383 DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
384 return ValidatorResult(SCEVType::INVALID
);
387 if (Instruction
*I
= dyn_cast
<Instruction
>(Expr
->getValue())) {
388 switch (I
->getOpcode()) {
389 case Instruction::IntToPtr
:
390 return visit(SE
.getSCEVAtScope(I
->getOperand(0), Scope
));
391 case Instruction::PtrToInt
:
392 return visit(SE
.getSCEVAtScope(I
->getOperand(0), Scope
));
393 case Instruction::Load
:
394 return visitLoadInstruction(I
, Expr
);
395 case Instruction::SDiv
:
396 return visitSDivInstruction(I
, Expr
);
397 case Instruction::SRem
:
398 return visitSRemInstruction(I
, Expr
);
400 return visitGenericInst(I
, Expr
);
404 return ValidatorResult(SCEVType::PARAM
, Expr
);
408 /// Check whether a SCEV refers to an SSA name defined inside a region.
409 class SCEVInRegionDependences
{
413 bool HasInRegionDeps
= false;
416 SCEVInRegionDependences(const Region
*R
, Loop
*Scope
, bool AllowLoops
)
417 : R(R
), Scope(Scope
), AllowLoops(AllowLoops
) {}
419 bool follow(const SCEV
*S
) {
420 if (auto Unknown
= dyn_cast
<SCEVUnknown
>(S
)) {
421 Instruction
*Inst
= dyn_cast
<Instruction
>(Unknown
->getValue());
423 // Return true when Inst is defined inside the region R.
424 if (!Inst
|| !R
->contains(Inst
))
427 HasInRegionDeps
= true;
431 if (auto AddRec
= dyn_cast
<SCEVAddRecExpr
>(S
)) {
436 HasInRegionDeps
= true;
439 auto *L
= AddRec
->getLoop();
440 if (R
->contains(L
) && !L
->contains(Scope
)) {
441 HasInRegionDeps
= true;
448 bool isDone() { return false; }
449 bool hasDependences() { return HasInRegionDeps
; }
453 /// Find all loops referenced in SCEVAddRecExprs.
454 class SCEVFindLoops
{
455 SetVector
<const Loop
*> &Loops
;
458 SCEVFindLoops(SetVector
<const Loop
*> &Loops
) : Loops(Loops
) {}
460 bool follow(const SCEV
*S
) {
461 if (const SCEVAddRecExpr
*AddRec
= dyn_cast
<SCEVAddRecExpr
>(S
))
462 Loops
.insert(AddRec
->getLoop());
465 bool isDone() { return false; }
468 void findLoops(const SCEV
*Expr
, SetVector
<const Loop
*> &Loops
) {
469 SCEVFindLoops
FindLoops(Loops
);
470 SCEVTraversal
<SCEVFindLoops
> ST(FindLoops
);
474 /// Find all values referenced in SCEVUnknowns.
475 class SCEVFindValues
{
477 SetVector
<Value
*> &Values
;
480 SCEVFindValues(ScalarEvolution
&SE
, SetVector
<Value
*> &Values
)
481 : SE(SE
), Values(Values
) {}
483 bool follow(const SCEV
*S
) {
484 const SCEVUnknown
*Unknown
= dyn_cast
<SCEVUnknown
>(S
);
488 Values
.insert(Unknown
->getValue());
489 Instruction
*Inst
= dyn_cast
<Instruction
>(Unknown
->getValue());
490 if (!Inst
|| (Inst
->getOpcode() != Instruction::SRem
&&
491 Inst
->getOpcode() != Instruction::SDiv
))
494 auto *Dividend
= SE
.getSCEV(Inst
->getOperand(1));
495 if (!isa
<SCEVConstant
>(Dividend
))
498 auto *Divisor
= SE
.getSCEV(Inst
->getOperand(0));
499 SCEVFindValues
FindValues(SE
, Values
);
500 SCEVTraversal
<SCEVFindValues
> ST(FindValues
);
501 ST
.visitAll(Dividend
);
502 ST
.visitAll(Divisor
);
506 bool isDone() { return false; }
509 void findValues(const SCEV
*Expr
, ScalarEvolution
&SE
,
510 SetVector
<Value
*> &Values
) {
511 SCEVFindValues
FindValues(SE
, Values
);
512 SCEVTraversal
<SCEVFindValues
> ST(FindValues
);
516 bool hasScalarDepsInsideRegion(const SCEV
*Expr
, const Region
*R
,
517 llvm::Loop
*Scope
, bool AllowLoops
) {
518 SCEVInRegionDependences
InRegionDeps(R
, Scope
, AllowLoops
);
519 SCEVTraversal
<SCEVInRegionDependences
> ST(InRegionDeps
);
521 return InRegionDeps
.hasDependences();
524 bool isAffineExpr(const Region
*R
, llvm::Loop
*Scope
, const SCEV
*Expr
,
525 ScalarEvolution
&SE
, InvariantLoadsSetTy
*ILS
) {
526 if (isa
<SCEVCouldNotCompute
>(Expr
))
529 SCEVValidator
Validator(R
, Scope
, SE
, ILS
);
532 dbgs() << "Expr: " << *Expr
<< "\n";
533 dbgs() << "Region: " << R
->getNameStr() << "\n";
537 ValidatorResult Result
= Validator
.visit(Expr
);
540 if (Result
.isValid())
545 return Result
.isValid();
548 static bool isAffineExpr(Value
*V
, const Region
*R
, Loop
*Scope
,
549 ScalarEvolution
&SE
, ParameterSetTy
&Params
) {
550 auto *E
= SE
.getSCEV(V
);
551 if (isa
<SCEVCouldNotCompute
>(E
))
554 SCEVValidator
Validator(R
, Scope
, SE
, nullptr);
555 ValidatorResult Result
= Validator
.visit(E
);
556 if (!Result
.isValid())
559 auto ResultParams
= Result
.getParameters();
560 Params
.insert(ResultParams
.begin(), ResultParams
.end());
565 bool isAffineConstraint(Value
*V
, const Region
*R
, llvm::Loop
*Scope
,
566 ScalarEvolution
&SE
, ParameterSetTy
&Params
,
568 if (auto *ICmp
= dyn_cast
<ICmpInst
>(V
)) {
569 return isAffineConstraint(ICmp
->getOperand(0), R
, Scope
, SE
, Params
,
571 isAffineConstraint(ICmp
->getOperand(1), R
, Scope
, SE
, Params
, true);
572 } else if (auto *BinOp
= dyn_cast
<BinaryOperator
>(V
)) {
573 auto Opcode
= BinOp
->getOpcode();
574 if (Opcode
== Instruction::And
|| Opcode
== Instruction::Or
)
575 return isAffineConstraint(BinOp
->getOperand(0), R
, Scope
, SE
, Params
,
577 isAffineConstraint(BinOp
->getOperand(1), R
, Scope
, SE
, Params
,
585 return isAffineExpr(V
, R
, Scope
, SE
, Params
);
588 ParameterSetTy
getParamsInAffineExpr(const Region
*R
, Loop
*Scope
,
589 const SCEV
*Expr
, ScalarEvolution
&SE
) {
590 if (isa
<SCEVCouldNotCompute
>(Expr
))
591 return ParameterSetTy();
593 InvariantLoadsSetTy ILS
;
594 SCEVValidator
Validator(R
, Scope
, SE
, &ILS
);
595 ValidatorResult Result
= Validator
.visit(Expr
);
596 assert(Result
.isValid() && "Requested parameters for an invalid SCEV!");
598 return Result
.getParameters();
601 std::pair
<const SCEVConstant
*, const SCEV
*>
602 extractConstantFactor(const SCEV
*S
, ScalarEvolution
&SE
) {
603 auto *ConstPart
= cast
<SCEVConstant
>(SE
.getConstant(S
->getType(), 1));
605 if (auto *Constant
= dyn_cast
<SCEVConstant
>(S
))
606 return std::make_pair(Constant
, SE
.getConstant(S
->getType(), 1));
608 auto *AddRec
= dyn_cast
<SCEVAddRecExpr
>(S
);
610 auto *StartExpr
= AddRec
->getStart();
611 if (StartExpr
->isZero()) {
612 auto StepPair
= extractConstantFactor(AddRec
->getStepRecurrence(SE
), SE
);
613 auto *LeftOverAddRec
=
614 SE
.getAddRecExpr(StartExpr
, StepPair
.second
, AddRec
->getLoop(),
615 AddRec
->getNoWrapFlags());
616 return std::make_pair(StepPair
.first
, LeftOverAddRec
);
618 return std::make_pair(ConstPart
, S
);
621 if (auto *Add
= dyn_cast
<SCEVAddExpr
>(S
)) {
622 SmallVector
<const SCEV
*, 4> LeftOvers
;
623 auto Op0Pair
= extractConstantFactor(Add
->getOperand(0), SE
);
624 auto *Factor
= Op0Pair
.first
;
625 if (SE
.isKnownNegative(Factor
)) {
626 Factor
= cast
<SCEVConstant
>(SE
.getNegativeSCEV(Factor
));
627 LeftOvers
.push_back(SE
.getNegativeSCEV(Op0Pair
.second
));
629 LeftOvers
.push_back(Op0Pair
.second
);
632 for (unsigned u
= 1, e
= Add
->getNumOperands(); u
< e
; u
++) {
633 auto OpUPair
= extractConstantFactor(Add
->getOperand(u
), SE
);
634 // TODO: Use something smarter than equality here, e.g., gcd.
635 if (Factor
== OpUPair
.first
)
636 LeftOvers
.push_back(OpUPair
.second
);
637 else if (Factor
== SE
.getNegativeSCEV(OpUPair
.first
))
638 LeftOvers
.push_back(SE
.getNegativeSCEV(OpUPair
.second
));
640 return std::make_pair(ConstPart
, S
);
643 auto *NewAdd
= SE
.getAddExpr(LeftOvers
, Add
->getNoWrapFlags());
644 return std::make_pair(Factor
, NewAdd
);
647 auto *Mul
= dyn_cast
<SCEVMulExpr
>(S
);
649 return std::make_pair(ConstPart
, S
);
651 SmallVector
<const SCEV
*, 4> LeftOvers
;
652 for (auto *Op
: Mul
->operands())
653 if (isa
<SCEVConstant
>(Op
))
654 ConstPart
= cast
<SCEVConstant
>(SE
.getMulExpr(ConstPart
, Op
));
656 LeftOvers
.push_back(Op
);
658 return std::make_pair(ConstPart
, SE
.getMulExpr(LeftOvers
));