From: Roman Lebedev Date: Wed, 17 Jun 2020 10:09:22 +0000 (+0300) Subject: [NFCI] SCEV: promote ScalarEvolutionDivision into an publicly usable class X-Git-Tag: llvmorg-12-init~2037 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1e2691fe238b9b47dd3fa2c7a8874826ead45760;p=platform%2Fupstream%2Fllvm.git [NFCI] SCEV: promote ScalarEvolutionDivision into an publicly usable class This makes it usable from outside of SCEV, while previously it was internal to the ScalarEvolution.cpp In particular, i want to use it in an WIP alloca promotion helper pass, to analyze if some SCEV is a multiple of some other SCEV. --- diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h b/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h new file mode 100644 index 0000000..480f92c --- /dev/null +++ b/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h @@ -0,0 +1,69 @@ +//===- llvm/Analysis/ScalarEvolutionDivision.h - See below ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the class that knows how to divide SCEV's. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ANALYSIS_SCALAREVOLUTIONDIVISION_H +#define LLVM_ANALYSIS_SCALAREVOLUTIONDIVISION_H + +#include "llvm/Analysis/ScalarEvolutionExpressions.h" + +namespace llvm { + +class SCEV; + +class ScalarEvolution; + +struct SCEVCouldNotCompute; + +struct SCEVDivision : public SCEVVisitor { +public: + // Computes the Quotient and Remainder of the division of Numerator by + // Denominator. + static void divide(ScalarEvolution &SE, const SCEV *Numerator, + const SCEV *Denominator, const SCEV **Quotient, + const SCEV **Remainder); + + // Except in the trivial case described above, we do not know how to divide + // Expr by Denominator for the following functions with empty implementation. + void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {} + void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {} + void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {} + void visitUDivExpr(const SCEVUDivExpr *Numerator) {} + void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {} + void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {} + void visitSMinExpr(const SCEVSMinExpr *Numerator) {} + void visitUMinExpr(const SCEVUMinExpr *Numerator) {} + void visitUnknown(const SCEVUnknown *Numerator) {} + void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {} + + void visitConstant(const SCEVConstant *Numerator); + + void visitAddRecExpr(const SCEVAddRecExpr *Numerator); + + void visitAddExpr(const SCEVAddExpr *Numerator); + + void visitMulExpr(const SCEVMulExpr *Numerator); + +private: + SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, + const SCEV *Denominator); + + // Convenience function for giving up on the division. We set the quotient to + // be equal to zero and the remainder to be equal to the numerator. + void cannotDivide(const SCEV *Numerator); + + ScalarEvolution &SE; + const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One; +}; + +} // end namespace llvm + +#endif // LLVM_ANALYSIS_SCALAREVOLUTIONDIVISION_H diff --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt index eaf9670..9cc2576 100644 --- a/llvm/lib/Analysis/CMakeLists.txt +++ b/llvm/lib/Analysis/CMakeLists.txt @@ -101,6 +101,7 @@ add_llvm_component_library(LLVMAnalysis RegionPrinter.cpp ScalarEvolution.cpp ScalarEvolutionAliasAnalysis.cpp + ScalarEvolutionDivision.cpp ScalarEvolutionNormalization.cpp StackLifetime.cpp StackSafetyAnalysis.cpp diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index eb7fd1d..e2afbb7 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -79,6 +79,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolutionDivision.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -847,30 +848,6 @@ static void GroupByComplexity(SmallVectorImpl &Ops, } } -// Returns the size of the SCEV S. -static inline int sizeOfSCEV(const SCEV *S) { - struct FindSCEVSize { - int Size = 0; - - FindSCEVSize() = default; - - bool follow(const SCEV *S) { - ++Size; - // Keep looking at all operands of S. - return true; - } - - bool isDone() const { - return false; - } - }; - - FindSCEVSize F; - SCEVTraversal ST(F); - ST.visitAll(S); - return F.Size; -} - /// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at /// least HugeExprThreshold nodes). static bool hasHugeExpression(ArrayRef Ops) { @@ -879,238 +856,6 @@ static bool hasHugeExpression(ArrayRef Ops) { }); } -namespace { - -struct SCEVDivision : public SCEVVisitor { -public: - // Computes the Quotient and Remainder of the division of Numerator by - // Denominator. - static void divide(ScalarEvolution &SE, const SCEV *Numerator, - const SCEV *Denominator, const SCEV **Quotient, - const SCEV **Remainder) { - assert(Numerator && Denominator && "Uninitialized SCEV"); - - SCEVDivision D(SE, Numerator, Denominator); - - // Check for the trivial case here to avoid having to check for it in the - // rest of the code. - if (Numerator == Denominator) { - *Quotient = D.One; - *Remainder = D.Zero; - return; - } - - if (Numerator->isZero()) { - *Quotient = D.Zero; - *Remainder = D.Zero; - return; - } - - // A simple case when N/1. The quotient is N. - if (Denominator->isOne()) { - *Quotient = Numerator; - *Remainder = D.Zero; - return; - } - - // Split the Denominator when it is a product. - if (const SCEVMulExpr *T = dyn_cast(Denominator)) { - const SCEV *Q, *R; - *Quotient = Numerator; - for (const SCEV *Op : T->operands()) { - divide(SE, *Quotient, Op, &Q, &R); - *Quotient = Q; - - // Bail out when the Numerator is not divisible by one of the terms of - // the Denominator. - if (!R->isZero()) { - *Quotient = D.Zero; - *Remainder = Numerator; - return; - } - } - *Remainder = D.Zero; - return; - } - - D.visit(Numerator); - *Quotient = D.Quotient; - *Remainder = D.Remainder; - } - - // Except in the trivial case described above, we do not know how to divide - // Expr by Denominator for the following functions with empty implementation. - void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {} - void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {} - void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {} - void visitUDivExpr(const SCEVUDivExpr *Numerator) {} - void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {} - void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {} - void visitSMinExpr(const SCEVSMinExpr *Numerator) {} - void visitUMinExpr(const SCEVUMinExpr *Numerator) {} - void visitUnknown(const SCEVUnknown *Numerator) {} - void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {} - - void visitConstant(const SCEVConstant *Numerator) { - if (const SCEVConstant *D = dyn_cast(Denominator)) { - APInt NumeratorVal = Numerator->getAPInt(); - APInt DenominatorVal = D->getAPInt(); - uint32_t NumeratorBW = NumeratorVal.getBitWidth(); - uint32_t DenominatorBW = DenominatorVal.getBitWidth(); - - if (NumeratorBW > DenominatorBW) - DenominatorVal = DenominatorVal.sext(NumeratorBW); - else if (NumeratorBW < DenominatorBW) - NumeratorVal = NumeratorVal.sext(DenominatorBW); - - APInt QuotientVal(NumeratorVal.getBitWidth(), 0); - APInt RemainderVal(NumeratorVal.getBitWidth(), 0); - APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal); - Quotient = SE.getConstant(QuotientVal); - Remainder = SE.getConstant(RemainderVal); - return; - } - } - - void visitAddRecExpr(const SCEVAddRecExpr *Numerator) { - const SCEV *StartQ, *StartR, *StepQ, *StepR; - if (!Numerator->isAffine()) - return cannotDivide(Numerator); - divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); - divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); - // Bail out if the types do not match. - Type *Ty = Denominator->getType(); - if (Ty != StartQ->getType() || Ty != StartR->getType() || - Ty != StepQ->getType() || Ty != StepR->getType()) - return cannotDivide(Numerator); - Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), - Numerator->getNoWrapFlags()); - Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), - Numerator->getNoWrapFlags()); - } - - void visitAddExpr(const SCEVAddExpr *Numerator) { - SmallVector Qs, Rs; - Type *Ty = Denominator->getType(); - - for (const SCEV *Op : Numerator->operands()) { - const SCEV *Q, *R; - divide(SE, Op, Denominator, &Q, &R); - - // Bail out if types do not match. - if (Ty != Q->getType() || Ty != R->getType()) - return cannotDivide(Numerator); - - Qs.push_back(Q); - Rs.push_back(R); - } - - if (Qs.size() == 1) { - Quotient = Qs[0]; - Remainder = Rs[0]; - return; - } - - Quotient = SE.getAddExpr(Qs); - Remainder = SE.getAddExpr(Rs); - } - - void visitMulExpr(const SCEVMulExpr *Numerator) { - SmallVector Qs; - Type *Ty = Denominator->getType(); - - bool FoundDenominatorTerm = false; - for (const SCEV *Op : Numerator->operands()) { - // Bail out if types do not match. - if (Ty != Op->getType()) - return cannotDivide(Numerator); - - if (FoundDenominatorTerm) { - Qs.push_back(Op); - continue; - } - - // Check whether Denominator divides one of the product operands. - const SCEV *Q, *R; - divide(SE, Op, Denominator, &Q, &R); - if (!R->isZero()) { - Qs.push_back(Op); - continue; - } - - // Bail out if types do not match. - if (Ty != Q->getType()) - return cannotDivide(Numerator); - - FoundDenominatorTerm = true; - Qs.push_back(Q); - } - - if (FoundDenominatorTerm) { - Remainder = Zero; - if (Qs.size() == 1) - Quotient = Qs[0]; - else - Quotient = SE.getMulExpr(Qs); - return; - } - - if (!isa(Denominator)) - return cannotDivide(Numerator); - - // The Remainder is obtained by replacing Denominator by 0 in Numerator. - ValueToValueMap RewriteMap; - RewriteMap[cast(Denominator)->getValue()] = - cast(Zero)->getValue(); - Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); - - if (Remainder->isZero()) { - // The Quotient is obtained by replacing Denominator by 1 in Numerator. - RewriteMap[cast(Denominator)->getValue()] = - cast(One)->getValue(); - Quotient = - SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); - return; - } - - // Quotient is (Numerator - Remainder) divided by Denominator. - const SCEV *Q, *R; - const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); - // This SCEV does not seem to simplify: fail the division here. - if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) - return cannotDivide(Numerator); - divide(SE, Diff, Denominator, &Q, &R); - if (R != Zero) - return cannotDivide(Numerator); - Quotient = Q; - } - -private: - SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, - const SCEV *Denominator) - : SE(S), Denominator(Denominator) { - Zero = SE.getZero(Denominator->getType()); - One = SE.getOne(Denominator->getType()); - - // We generally do not know how to divide Expr by Denominator. We - // initialize the division to a "cannot divide" state to simplify the rest - // of the code. - cannotDivide(Numerator); - } - - // Convenience function for giving up on the division. We set the quotient to - // be equal to zero and the remainder to be equal to the numerator. - void cannotDivide(const SCEV *Numerator) { - Quotient = Zero; - Remainder = Numerator; - } - - ScalarEvolution &SE; - const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One; -}; - -} // end anonymous namespace - //===----------------------------------------------------------------------===// // Simple SCEV method implementations //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Analysis/ScalarEvolutionDivision.cpp b/llvm/lib/Analysis/ScalarEvolutionDivision.cpp new file mode 100644 index 0000000..19bf576 --- /dev/null +++ b/llvm/lib/Analysis/ScalarEvolutionDivision.cpp @@ -0,0 +1,259 @@ +//===- ScalarEvolutionDivision.h - See below --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the class that knows how to divide SCEV's. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ScalarEvolutionDivision.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/IR/Constants.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include +#include + +namespace llvm { +class Type; +} + +using namespace llvm; + +namespace { + +static inline int sizeOfSCEV(const SCEV *S) { + struct FindSCEVSize { + int Size = 0; + + FindSCEVSize() = default; + + bool follow(const SCEV *S) { + ++Size; + // Keep looking at all operands of S. + return true; + } + + bool isDone() const { return false; } + }; + + FindSCEVSize F; + SCEVTraversal ST(F); + ST.visitAll(S); + return F.Size; +} + +} // namespace + +// Computes the Quotient and Remainder of the division of Numerator by +// Denominator. +void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator, + const SCEV *Denominator, const SCEV **Quotient, + const SCEV **Remainder) { + assert(Numerator && Denominator && "Uninitialized SCEV"); + + SCEVDivision D(SE, Numerator, Denominator); + + // Check for the trivial case here to avoid having to check for it in the + // rest of the code. + if (Numerator == Denominator) { + *Quotient = D.One; + *Remainder = D.Zero; + return; + } + + if (Numerator->isZero()) { + *Quotient = D.Zero; + *Remainder = D.Zero; + return; + } + + // A simple case when N/1. The quotient is N. + if (Denominator->isOne()) { + *Quotient = Numerator; + *Remainder = D.Zero; + return; + } + + // Split the Denominator when it is a product. + if (const SCEVMulExpr *T = dyn_cast(Denominator)) { + const SCEV *Q, *R; + *Quotient = Numerator; + for (const SCEV *Op : T->operands()) { + divide(SE, *Quotient, Op, &Q, &R); + *Quotient = Q; + + // Bail out when the Numerator is not divisible by one of the terms of + // the Denominator. + if (!R->isZero()) { + *Quotient = D.Zero; + *Remainder = Numerator; + return; + } + } + *Remainder = D.Zero; + return; + } + + D.visit(Numerator); + *Quotient = D.Quotient; + *Remainder = D.Remainder; +} + +void SCEVDivision::visitConstant(const SCEVConstant *Numerator) { + if (const SCEVConstant *D = dyn_cast(Denominator)) { + APInt NumeratorVal = Numerator->getAPInt(); + APInt DenominatorVal = D->getAPInt(); + uint32_t NumeratorBW = NumeratorVal.getBitWidth(); + uint32_t DenominatorBW = DenominatorVal.getBitWidth(); + + if (NumeratorBW > DenominatorBW) + DenominatorVal = DenominatorVal.sext(NumeratorBW); + else if (NumeratorBW < DenominatorBW) + NumeratorVal = NumeratorVal.sext(DenominatorBW); + + APInt QuotientVal(NumeratorVal.getBitWidth(), 0); + APInt RemainderVal(NumeratorVal.getBitWidth(), 0); + APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal); + Quotient = SE.getConstant(QuotientVal); + Remainder = SE.getConstant(RemainderVal); + return; + } +} + +void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) { + const SCEV *StartQ, *StartR, *StepQ, *StepR; + if (!Numerator->isAffine()) + return cannotDivide(Numerator); + divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); + divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); + // Bail out if the types do not match. + Type *Ty = Denominator->getType(); + if (Ty != StartQ->getType() || Ty != StartR->getType() || + Ty != StepQ->getType() || Ty != StepR->getType()) + return cannotDivide(Numerator); + Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), + Numerator->getNoWrapFlags()); + Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), + Numerator->getNoWrapFlags()); +} + +void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) { + SmallVector Qs, Rs; + Type *Ty = Denominator->getType(); + + for (const SCEV *Op : Numerator->operands()) { + const SCEV *Q, *R; + divide(SE, Op, Denominator, &Q, &R); + + // Bail out if types do not match. + if (Ty != Q->getType() || Ty != R->getType()) + return cannotDivide(Numerator); + + Qs.push_back(Q); + Rs.push_back(R); + } + + if (Qs.size() == 1) { + Quotient = Qs[0]; + Remainder = Rs[0]; + return; + } + + Quotient = SE.getAddExpr(Qs); + Remainder = SE.getAddExpr(Rs); +} + +void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) { + SmallVector Qs; + Type *Ty = Denominator->getType(); + + bool FoundDenominatorTerm = false; + for (const SCEV *Op : Numerator->operands()) { + // Bail out if types do not match. + if (Ty != Op->getType()) + return cannotDivide(Numerator); + + if (FoundDenominatorTerm) { + Qs.push_back(Op); + continue; + } + + // Check whether Denominator divides one of the product operands. + const SCEV *Q, *R; + divide(SE, Op, Denominator, &Q, &R); + if (!R->isZero()) { + Qs.push_back(Op); + continue; + } + + // Bail out if types do not match. + if (Ty != Q->getType()) + return cannotDivide(Numerator); + + FoundDenominatorTerm = true; + Qs.push_back(Q); + } + + if (FoundDenominatorTerm) { + Remainder = Zero; + if (Qs.size() == 1) + Quotient = Qs[0]; + else + Quotient = SE.getMulExpr(Qs); + return; + } + + if (!isa(Denominator)) + return cannotDivide(Numerator); + + // The Remainder is obtained by replacing Denominator by 0 in Numerator. + ValueToValueMap RewriteMap; + RewriteMap[cast(Denominator)->getValue()] = + cast(Zero)->getValue(); + Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); + + if (Remainder->isZero()) { + // The Quotient is obtained by replacing Denominator by 1 in Numerator. + RewriteMap[cast(Denominator)->getValue()] = + cast(One)->getValue(); + Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); + return; + } + + // Quotient is (Numerator - Remainder) divided by Denominator. + const SCEV *Q, *R; + const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); + // This SCEV does not seem to simplify: fail the division here. + if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) + return cannotDivide(Numerator); + divide(SE, Diff, Denominator, &Q, &R); + if (R != Zero) + return cannotDivide(Numerator); + Quotient = Q; +} + +SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, + const SCEV *Denominator) + : SE(S), Denominator(Denominator) { + Zero = SE.getZero(Denominator->getType()); + One = SE.getOne(Denominator->getType()); + + // We generally do not know how to divide Expr by Denominator. We initialize + // the division to a "cannot divide" state to simplify the rest of the code. + cannotDivide(Numerator); +} + +// Convenience function for giving up on the division. We set the quotient to +// be equal to zero and the remainder to be equal to the numerator. +void SCEVDivision::cannotDivide(const SCEV *Numerator) { + Quotient = Zero; + Remainder = Numerator; +}