[LAA] Move runtime-check generation to Transforms/Utils/loopUtils (NFC)
authorFlorian Hahn <flo@fhahn.com>
Sun, 10 May 2020 15:48:06 +0000 (16:48 +0100)
committerFlorian Hahn <flo@fhahn.com>
Sun, 10 May 2020 16:39:26 +0000 (17:39 +0100)
Currently LAA's uses of ScalarEvolutionExpander blocks moving the
expander from Analysis to Transforms. Conceptually the expander does not
fit into Analysis (it is only used for code generation) and
runtime-check generation also seems to be better suited as a
transformation utility.

Reviewers: Ayal, anemet

Reviewed By: Ayal

Differential Revision: https://reviews.llvm.org/D78460

llvm/include/llvm/Analysis/LoopAccessAnalysis.h
llvm/include/llvm/Transforms/Utils/LoopUtils.h
llvm/lib/Analysis/LoopAccessAnalysis.cpp
llvm/lib/Transforms/Utils/LoopUtils.cpp
llvm/lib/Transforms/Utils/LoopVersioning.cpp
llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

index cea7523..276edc8 100644 (file)
@@ -465,6 +465,8 @@ public:
     return Pointers[PtrIdx];
   }
 
+  ScalarEvolution *getSE() const { return SE; }
+
 private:
   /// Groups pointers such that a single memcheck is required
   /// between two different groups. This will clear the CheckingGroups vector
@@ -541,16 +543,6 @@ public:
   unsigned getNumStores() const { return NumStores; }
   unsigned getNumLoads() const { return NumLoads;}
 
-  /// Add code that checks at runtime if the accessed arrays \in p PointerChecks
-  /// overlap.
-  ///
-  /// Returns a pair of instructions where the first element is the first
-  /// instruction generated in possibly a sequence of instructions and the
-  /// second value is the final comparator value or NULL if no check is needed.
-  std::pair<Instruction *, Instruction *> addRuntimeChecks(
-      Instruction *Loc,
-      const SmallVectorImpl<RuntimePointerCheck> &PointerChecks) const;
-
   /// The diagnostics report generated for the analysis.  E.g. why we
   /// couldn't analyze the loop.
   const OptimizationRemarkAnalysis *getReport() const { return Report.get(); }
index 56b5698..12c023f 100644 (file)
@@ -41,6 +41,10 @@ class TargetLibraryInfo;
 class TargetTransformInfo;
 class LPPassManager;
 class Instruction;
+struct RuntimeCheckingPtrGroup;
+typedef std::pair<const RuntimeCheckingPtrGroup *,
+                  const RuntimeCheckingPtrGroup *>
+    RuntimePointerCheck;
 
 template <typename T> class Optional;
 template <typename T, unsigned N> class SmallSetVector;
@@ -428,6 +432,17 @@ void appendLoopsToWorklist(LoopInfo &, SmallPriorityWorklist<Loop *, 4> &);
 Loop *cloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM,
                 LoopInfo *LI, LPPassManager *LPM);
 
+/// Add code that checks at runtime if the accessed arrays in \p PointerChecks
+/// overlap.
+///
+/// Returns a pair of instructions where the first element is the first
+/// instruction generated in possibly a sequence of instructions and the
+/// second value is the final comparator value or NULL if no check is needed.
+std::pair<Instruction *, Instruction *>
+addRuntimeChecks(Instruction *Loc, Loop *TheLoop,
+                 const SmallVectorImpl<RuntimePointerCheck> &PointerChecks,
+                 ScalarEvolution *SE);
+
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_UTILS_LOOPUTILS_H
index 35780bd..b0a5e0b 100644 (file)
@@ -43,7 +43,6 @@
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Function.h"
-#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
@@ -2123,158 +2122,6 @@ bool LoopAccessInfo::isUniform(Value *V) const {
   return (SE->isLoopInvariant(SE->getSCEV(V), TheLoop));
 }
 
-// FIXME: this function is currently a duplicate of the one in
-// LoopVectorize.cpp.
-static Instruction *getFirstInst(Instruction *FirstInst, Value *V,
-                                 Instruction *Loc) {
-  if (FirstInst)
-    return FirstInst;
-  if (Instruction *I = dyn_cast<Instruction>(V))
-    return I->getParent() == Loc->getParent() ? I : nullptr;
-  return nullptr;
-}
-
-namespace {
-
-/// IR Values for the lower and upper bounds of a pointer evolution.  We
-/// need to use value-handles because SCEV expansion can invalidate previously
-/// expanded values.  Thus expansion of a pointer can invalidate the bounds for
-/// a previous one.
-struct PointerBounds {
-  TrackingVH<Value> Start;
-  TrackingVH<Value> End;
-};
-
-} // end anonymous namespace
-
-/// Expand code for the lower and upper bound of the pointer group \p CG
-/// in \p TheLoop.  \return the values for the bounds.
-static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG,
-                                  Loop *TheLoop, Instruction *Loc,
-                                  SCEVExpander &Exp, ScalarEvolution *SE) {
-  Value *Ptr = CG->RtCheck.Pointers[CG->Members[0]].PointerValue;
-
-  const SCEV *Sc = SE->getSCEV(Ptr);
-
-  unsigned AS = Ptr->getType()->getPointerAddressSpace();
-  LLVMContext &Ctx = Loc->getContext();
-
-  // Use this type for pointer arithmetic.
-  Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS);
-
-  if (SE->isLoopInvariant(Sc, TheLoop)) {
-    LLVM_DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:"
-                      << *Ptr << "\n");
-    // Ptr could be in the loop body. If so, expand a new one at the correct
-    // location.
-    Instruction *Inst = dyn_cast<Instruction>(Ptr);
-    Value *NewPtr = (Inst && TheLoop->contains(Inst))
-                        ? Exp.expandCodeFor(Sc, PtrArithTy, Loc)
-                        : Ptr;
-    // We must return a half-open range, which means incrementing Sc.
-    const SCEV *ScPlusOne = SE->getAddExpr(Sc, SE->getOne(PtrArithTy));
-    Value *NewPtrPlusOne = Exp.expandCodeFor(ScPlusOne, PtrArithTy, Loc);
-    return {NewPtr, NewPtrPlusOne};
-  } else {
-    Value *Start = nullptr, *End = nullptr;
-    LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n");
-    Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc);
-    End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc);
-    LLVM_DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High
-                      << "\n");
-    return {Start, End};
-  }
-}
-
-/// Turns a collection of checks into a collection of expanded upper and
-/// lower bounds for both pointers in the check.
-static SmallVector<std::pair<PointerBounds, PointerBounds>, 4>
-expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L,
-             Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp) {
-  SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds;
-
-  // Here we're relying on the SCEV Expander's cache to only emit code for the
-  // same bounds once.
-  transform(
-      PointerChecks, std::back_inserter(ChecksWithBounds),
-      [&](const RuntimePointerCheck &Check) {
-        PointerBounds
-          First = expandBounds(Check.first, L, Loc, Exp, SE),
-          Second = expandBounds(Check.second, L, Loc, Exp, SE);
-        return std::make_pair(First, Second);
-      });
-
-  return ChecksWithBounds;
-}
-
-std::pair<Instruction *, Instruction *> LoopAccessInfo::addRuntimeChecks(
-    Instruction *Loc,
-    const SmallVectorImpl<RuntimePointerCheck> &PointerChecks) const {
-  const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout();
-  auto *SE = PSE->getSE();
-  SCEVExpander Exp(*SE, DL, "induction");
-  auto ExpandedChecks = expandBounds(PointerChecks, TheLoop, Loc, SE, Exp);
-
-  LLVMContext &Ctx = Loc->getContext();
-  Instruction *FirstInst = nullptr;
-  IRBuilder<> ChkBuilder(Loc);
-  // Our instructions might fold to a constant.
-  Value *MemoryRuntimeCheck = nullptr;
-
-  for (const auto &Check : ExpandedChecks) {
-    const PointerBounds &A = Check.first, &B = Check.second;
-    // Check if two pointers (A and B) conflict where conflict is computed as:
-    // start(A) <= end(B) && start(B) <= end(A)
-    unsigned AS0 = A.Start->getType()->getPointerAddressSpace();
-    unsigned AS1 = B.Start->getType()->getPointerAddressSpace();
-
-    assert((AS0 == B.End->getType()->getPointerAddressSpace()) &&
-           (AS1 == A.End->getType()->getPointerAddressSpace()) &&
-           "Trying to bounds check pointers with different address spaces");
-
-    Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0);
-    Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1);
-
-    Value *Start0 = ChkBuilder.CreateBitCast(A.Start, PtrArithTy0, "bc");
-    Value *Start1 = ChkBuilder.CreateBitCast(B.Start, PtrArithTy1, "bc");
-    Value *End0 =   ChkBuilder.CreateBitCast(A.End,   PtrArithTy1, "bc");
-    Value *End1 =   ChkBuilder.CreateBitCast(B.End,   PtrArithTy0, "bc");
-
-    // [A|B].Start points to the first accessed byte under base [A|B].
-    // [A|B].End points to the last accessed byte, plus one.
-    // There is no conflict when the intervals are disjoint:
-    // NoConflict = (B.Start >= A.End) || (A.Start >= B.End)
-    //
-    // bound0 = (B.Start < A.End)
-    // bound1 = (A.Start < B.End)
-    //  IsConflict = bound0 & bound1
-    Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0");
-    FirstInst = getFirstInst(FirstInst, Cmp0, Loc);
-    Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1");
-    FirstInst = getFirstInst(FirstInst, Cmp1, Loc);
-    Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
-    FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
-    if (MemoryRuntimeCheck) {
-      IsConflict =
-          ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx");
-      FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
-    }
-    MemoryRuntimeCheck = IsConflict;
-  }
-
-  if (!MemoryRuntimeCheck)
-    return std::make_pair(nullptr, nullptr);
-
-  // We have to do this trickery because the IRBuilder might fold the check to a
-  // constant expression in which case there is no Instruction anchored in a
-  // the block.
-  Instruction *Check = BinaryOperator::CreateAnd(MemoryRuntimeCheck,
-                                                 ConstantInt::getTrue(Ctx));
-  ChkBuilder.Insert(Check, "memcheck.conflict");
-  FirstInst = getFirstInst(FirstInst, Check, Loc);
-  return std::make_pair(FirstInst, Check);
-}
-
 void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
   Value *Ptr = nullptr;
   if (LoadInst *LI = dyn_cast<LoadInst>(MemAccess))
index a5fbdb5..5f54685 100644 (file)
@@ -1534,3 +1534,152 @@ Loop *llvm::cloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM,
 
   return &New;
 }
+
+/// IR Values for the lower and upper bounds of a pointer evolution.  We
+/// need to use value-handles because SCEV expansion can invalidate previously
+/// expanded values.  Thus expansion of a pointer can invalidate the bounds for
+/// a previous one.
+struct PointerBounds {
+  TrackingVH<Value> Start;
+  TrackingVH<Value> End;
+};
+
+/// Expand code for the lower and upper bound of the pointer group \p CG
+/// in \p TheLoop.  \return the values for the bounds.
+static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG,
+                                  Loop *TheLoop, Instruction *Loc,
+                                  SCEVExpander &Exp, ScalarEvolution *SE) {
+  // TODO: Add helper to retrieve pointers to CG.
+  Value *Ptr = CG->RtCheck.Pointers[CG->Members[0]].PointerValue;
+  const SCEV *Sc = SE->getSCEV(Ptr);
+
+  unsigned AS = Ptr->getType()->getPointerAddressSpace();
+  LLVMContext &Ctx = Loc->getContext();
+
+  // Use this type for pointer arithmetic.
+  Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS);
+
+  if (SE->isLoopInvariant(Sc, TheLoop)) {
+    LLVM_DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:"
+                      << *Ptr << "\n");
+    // Ptr could be in the loop body. If so, expand a new one at the correct
+    // location.
+    Instruction *Inst = dyn_cast<Instruction>(Ptr);
+    Value *NewPtr = (Inst && TheLoop->contains(Inst))
+                        ? Exp.expandCodeFor(Sc, PtrArithTy, Loc)
+                        : Ptr;
+    // We must return a half-open range, which means incrementing Sc.
+    const SCEV *ScPlusOne = SE->getAddExpr(Sc, SE->getOne(PtrArithTy));
+    Value *NewPtrPlusOne = Exp.expandCodeFor(ScPlusOne, PtrArithTy, Loc);
+    return {NewPtr, NewPtrPlusOne};
+  } else {
+    Value *Start = nullptr, *End = nullptr;
+    LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n");
+    Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc);
+    End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc);
+    LLVM_DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High
+                      << "\n");
+    return {Start, End};
+  }
+}
+
+/// Turns a collection of checks into a collection of expanded upper and
+/// lower bounds for both pointers in the check.
+static SmallVector<std::pair<PointerBounds, PointerBounds>, 4>
+expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L,
+             Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp) {
+  SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds;
+
+  // Here we're relying on the SCEV Expander's cache to only emit code for the
+  // same bounds once.
+  transform(PointerChecks, std::back_inserter(ChecksWithBounds),
+            [&](const RuntimePointerCheck &Check) {
+              PointerBounds First = expandBounds(Check.first, L, Loc, Exp, SE),
+                            Second =
+                                expandBounds(Check.second, L, Loc, Exp, SE);
+              return std::make_pair(First, Second);
+            });
+
+  return ChecksWithBounds;
+}
+
+std::pair<Instruction *, Instruction *> llvm::addRuntimeChecks(
+    Instruction *Loc, Loop *TheLoop,
+    const SmallVectorImpl<RuntimePointerCheck> &PointerChecks,
+    ScalarEvolution *SE) {
+  // TODO: Move noalias annotation code from LoopVersioning here and share with LV if possible.
+  // TODO: Pass  RtPtrChecking instead of PointerChecks and SE separately, if possible
+  const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout();
+  SCEVExpander Exp(*SE, DL, "induction");
+  auto ExpandedChecks = expandBounds(PointerChecks, TheLoop, Loc, SE, Exp);
+
+  LLVMContext &Ctx = Loc->getContext();
+  Instruction *FirstInst = nullptr;
+  IRBuilder<> ChkBuilder(Loc);
+  // Our instructions might fold to a constant.
+  Value *MemoryRuntimeCheck = nullptr;
+
+  // FIXME: this helper is currently a duplicate of the one in
+  // LoopVectorize.cpp.
+  auto GetFirstInst = [](Instruction *FirstInst, Value *V,
+                         Instruction *Loc) -> Instruction * {
+    if (FirstInst)
+      return FirstInst;
+    if (Instruction *I = dyn_cast<Instruction>(V))
+      return I->getParent() == Loc->getParent() ? I : nullptr;
+    return nullptr;
+  };
+
+  for (const auto &Check : ExpandedChecks) {
+    const PointerBounds &A = Check.first, &B = Check.second;
+    // Check if two pointers (A and B) conflict where conflict is computed as:
+    // start(A) <= end(B) && start(B) <= end(A)
+    unsigned AS0 = A.Start->getType()->getPointerAddressSpace();
+    unsigned AS1 = B.Start->getType()->getPointerAddressSpace();
+
+    assert((AS0 == B.End->getType()->getPointerAddressSpace()) &&
+           (AS1 == A.End->getType()->getPointerAddressSpace()) &&
+           "Trying to bounds check pointers with different address spaces");
+
+    Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0);
+    Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1);
+
+    Value *Start0 = ChkBuilder.CreateBitCast(A.Start, PtrArithTy0, "bc");
+    Value *Start1 = ChkBuilder.CreateBitCast(B.Start, PtrArithTy1, "bc");
+    Value *End0 = ChkBuilder.CreateBitCast(A.End, PtrArithTy1, "bc");
+    Value *End1 = ChkBuilder.CreateBitCast(B.End, PtrArithTy0, "bc");
+
+    // [A|B].Start points to the first accessed byte under base [A|B].
+    // [A|B].End points to the last accessed byte, plus one.
+    // There is no conflict when the intervals are disjoint:
+    // NoConflict = (B.Start >= A.End) || (A.Start >= B.End)
+    //
+    // bound0 = (B.Start < A.End)
+    // bound1 = (A.Start < B.End)
+    //  IsConflict = bound0 & bound1
+    Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0");
+    FirstInst = GetFirstInst(FirstInst, Cmp0, Loc);
+    Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1");
+    FirstInst = GetFirstInst(FirstInst, Cmp1, Loc);
+    Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
+    FirstInst = GetFirstInst(FirstInst, IsConflict, Loc);
+    if (MemoryRuntimeCheck) {
+      IsConflict =
+          ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx");
+      FirstInst = GetFirstInst(FirstInst, IsConflict, Loc);
+    }
+    MemoryRuntimeCheck = IsConflict;
+  }
+
+  if (!MemoryRuntimeCheck)
+    return std::make_pair(nullptr, nullptr);
+
+  // We have to do this trickery because the IRBuilder might fold the check to a
+  // constant expression in which case there is no Instruction anchored in a
+  // the block.
+  Instruction *Check =
+      BinaryOperator::CreateAnd(MemoryRuntimeCheck, ConstantInt::getTrue(Ctx));
+  ChkBuilder.Insert(Check, "memcheck.conflict");
+  FirstInst = GetFirstInst(FirstInst, Check, Loc);
+  return std::make_pair(FirstInst, Check);
+}
index ce0fb62..30a783c 100644 (file)
@@ -62,8 +62,10 @@ void LoopVersioning::versionLoop(
 
   // Add the memcheck in the original preheader (this is empty initially).
   BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader();
+  const auto &RtPtrChecking = *LAI.getRuntimePointerChecking();
   std::tie(FirstCheckInst, MemRuntimeCheck) =
-      LAI.addRuntimeChecks(RuntimeCheckBB->getTerminator(), AliasChecks);
+      addRuntimeChecks(RuntimeCheckBB->getTerminator(), VersionedLoop,
+                       AliasChecks, RtPtrChecking.getSE());
 
   const SCEVUnionPredicate &Pred = LAI.getPSE().getUnionPredicate();
   SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
index ce9641f..04726f5 100644 (file)
@@ -2791,8 +2791,9 @@ void InnerLoopVectorizer::emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass) {
     return;
   Instruction *FirstCheckInst;
   Instruction *MemRuntimeCheck;
-  std::tie(FirstCheckInst, MemRuntimeCheck) = LAI->addRuntimeChecks(
-      MemCheckBlock->getTerminator(), RtPtrChecking.getChecks());
+  std::tie(FirstCheckInst, MemRuntimeCheck) =
+      addRuntimeChecks(MemCheckBlock->getTerminator(), OrigLoop,
+                       RtPtrChecking.getChecks(), RtPtrChecking.getSE());
   if (!MemRuntimeCheck)
     return;