Remove dependency of Scop::getStmtFor(Inst) on getStmtFor(BB). NFC.
authorMichael Kruse <llvm@meinersbur.de>
Wed, 9 Aug 2017 16:45:37 +0000 (16:45 +0000)
committerMichael Kruse <llvm@meinersbur.de>
Wed, 9 Aug 2017 16:45:37 +0000 (16:45 +0000)
We are working towards removing uses of Scop::getStmtFor(BB). In this
patch, we remove dependency of Scop::getStmtFor(Inst) on getStmtFor(BB).
To do so, we introduce a map of instructions to their corresponding scop
statements and use it to get the instructions' statement.

Contributed-by: Nandini Singhal <cs15mtech01004@iith.ac.in>
Differential Revision: https://reviews.llvm.org/D35663

llvm-svn: 310494

polly/include/polly/ScopInfo.h
polly/lib/Analysis/ScopInfo.cpp
polly/lib/Transform/ForwardOpTree.cpp

index a3e1bcd..524a186 100644 (file)
@@ -1400,6 +1400,9 @@ public:
   bool contains(Instruction *Inst) const {
     if (!Inst)
       return false;
+    if (isBlockStmt())
+      return std::find(Instructions.begin(), Instructions.end(), Inst) !=
+             Instructions.end();
     return represents(Inst->getParent());
   }
 
@@ -1748,6 +1751,9 @@ private:
   /// vector comprises only of a single statement.
   DenseMap<BasicBlock *, std::vector<ScopStmt *>> StmtMap;
 
+  /// A map from instructions to SCoP statements.
+  DenseMap<Instruction *, ScopStmt *> InstStmtMap;
+
   /// A map from basic blocks to their domains.
   DenseMap<BasicBlock *, isl::set> DomainMap;
 
@@ -2697,10 +2703,6 @@ public:
   /// Get an isl string representing the invalid context.
   std::string getInvalidContextStr() const;
 
-  /// Return the ScopStmt for the given @p BB or nullptr if there is
-  ///        none.
-  ScopStmt *getStmtFor(BasicBlock *BB) const;
-
   /// Return the list of ScopStmts that represent the given @p BB.
   ArrayRef<ScopStmt *> getStmtListFor(BasicBlock *BB) const;
 
@@ -2724,7 +2726,7 @@ public:
   /// Return the ScopStmt an instruction belongs to, or nullptr if it
   ///        does not belong to any statement in this Scop.
   ScopStmt *getStmtFor(Instruction *Inst) const {
-    return getStmtFor(Inst->getParent());
+    return InstStmtMap.lookup(Inst);
   }
 
   /// Return the number of statements in the SCoP.
index 4816cec..3d2ee79 100644 (file)
@@ -1284,7 +1284,6 @@ void ScopStmt::addAccess(MemoryAccess *Access, bool Prepend) {
     MAL.emplace_front(Access);
   } else if (Access->isValueKind() && Access->isWrite()) {
     Instruction *AccessVal = cast<Instruction>(Access->getAccessValue());
-    assert(Parent.getStmtFor(AccessVal) == this);
     assert(!ValueWrites.lookup(AccessVal));
 
     ValueWrites[AccessVal] = Access;
@@ -3768,10 +3767,16 @@ void Scop::assumeNoOutOfBounds() {
 
 void Scop::removeFromStmtMap(ScopStmt &Stmt) {
   if (Stmt.isRegionStmt())
-    for (BasicBlock *BB : Stmt.getRegion()->blocks())
+    for (BasicBlock *BB : Stmt.getRegion()->blocks()) {
       StmtMap.erase(BB);
-  else
+      for (Instruction &Inst : *BB)
+        InstStmtMap.erase(&Inst);
+    }
+  else {
     StmtMap.erase(Stmt.getBasicBlock());
+    for (Instruction *Inst : Stmt.getInstructions())
+      InstStmtMap.erase(Inst);
+  }
 }
 
 void Scop::removeStmts(std::function<bool(ScopStmt &)> ShouldDelete) {
@@ -4090,11 +4095,13 @@ void Scop::verifyInvariantLoads() {
   auto &RIL = getRequiredInvariantLoads();
   for (LoadInst *LI : RIL) {
     assert(LI && contains(LI));
-    ScopStmt *Stmt = getStmtFor(LI);
-    if (Stmt && Stmt->getArrayAccessOrNULLFor(LI)) {
-      invalidate(INVARIANTLOAD, LI->getDebugLoc(), LI->getParent());
-      return;
-    }
+    // If there exists a statement in the scop which has a memory access for
+    // @p LI, then mark this scop as infeasible for optimization.
+    for (ScopStmt &Stmt : Stmts)
+      if (Stmt.getArrayAccessOrNULLFor(LI)) {
+        invalidate(INVARIANTLOAD, LI->getDebugLoc(), LI->getParent());
+        return;
+      }
   }
 }
 
@@ -4837,14 +4844,25 @@ void Scop::addScopStmt(BasicBlock *BB, Loop *SurroundingLoop,
   Stmts.emplace_back(*this, *BB, SurroundingLoop, Instructions);
   auto *Stmt = &Stmts.back();
   StmtMap[BB].push_back(Stmt);
+  for (Instruction *Inst : Instructions) {
+    assert(!InstStmtMap.count(Inst) &&
+           "Unexpected statement corresponding to the instruction.");
+    InstStmtMap[Inst] = Stmt;
+  }
 }
 
 void Scop::addScopStmt(Region *R, Loop *SurroundingLoop) {
   assert(R && "Unexpected nullptr!");
   Stmts.emplace_back(*this, *R, SurroundingLoop);
   auto *Stmt = &Stmts.back();
-  for (BasicBlock *BB : R->blocks())
+  for (BasicBlock *BB : R->blocks()) {
     StmtMap[BB].push_back(Stmt);
+    for (Instruction &Inst : *BB) {
+      assert(!InstStmtMap.count(&Inst) &&
+             "Unexpected statement corresponding to the instruction.");
+      InstStmtMap[&Inst] = Stmt;
+    }
+  }
 }
 
 ScopStmt *Scop::addScopStmt(isl::map SourceRel, isl::map TargetRel,
@@ -4989,14 +5007,6 @@ void Scop::buildSchedule(RegionNode *RN, LoopStackTy &LoopStack, LoopInfo &LI) {
   }
 }
 
-ScopStmt *Scop::getStmtFor(BasicBlock *BB) const {
-  auto StmtMapIt = StmtMap.find(BB);
-  if (StmtMapIt == StmtMap.end())
-    return nullptr;
-  assert(StmtMapIt->second.size() == 1);
-  return StmtMapIt->second.front();
-}
-
 ArrayRef<ScopStmt *> Scop::getStmtListFor(BasicBlock *BB) const {
   auto StmtMapIt = StmtMap.find(BB);
   if (StmtMapIt == StmtMap.end())
index 27f9f61..0b916c5 100644 (file)
@@ -666,9 +666,12 @@ public:
     case VirtualUse::Inter:
       Instruction *Inst = cast<Instruction>(UseVal);
 
-      if (!DefStmt)
+      if (!DefStmt) {
         DefStmt = S->getStmtFor(Inst);
-      assert(DefStmt && "Value must be defined somewhere");
+        if (!DefStmt)
+          return FD_CannotForward;
+      }
+
       DefLoop = LI->getLoopFor(Inst->getParent());
 
       if (DefToTarget.is_null() && !Known.is_null()) {