Refactor ForStmt: having it contain a StmtBlock instead of subclassing
authorChris Lattner <clattner@google.com>
Sun, 23 Dec 2018 16:17:48 +0000 (08:17 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 21:35:19 +0000 (14:35 -0700)
StmtBlock.  This is more consistent with IfStmt and also conceptually makes
more sense - a forstmt "isn't" its body, it contains its body.

This is step 1/N towards merging BasicBlock and StmtBlock.  This is required
because in the new regime StmtBlock will have a use list (just like BasicBlock
does) of operands, and ForStmt already has a use list for its induction
variable.

This is a mechanical patch, NFC.

PiperOrigin-RevId: 226684158

22 files changed:
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/Instructions.h
mlir/include/mlir/IR/Statements.h
mlir/include/mlir/IR/StmtBlock.h
mlir/include/mlir/IR/StmtVisitor.h
mlir/lib/Analysis/AffineAnalysis.cpp
mlir/lib/Analysis/LoopAnalysis.cpp
mlir/lib/Analysis/Utils.cpp
mlir/lib/Analysis/Verifier.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Instructions.cpp
mlir/lib/IR/Statement.cpp
mlir/lib/IR/StmtBlock.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Transforms/ConvertToCFG.cpp
mlir/lib/Transforms/DmaGeneration.cpp
mlir/lib/Transforms/LoopTiling.cpp
mlir/lib/Transforms/LoopUnroll.cpp
mlir/lib/Transforms/LoopUnrollAndJam.cpp
mlir/lib/Transforms/LowerVectorTransfers.cpp
mlir/lib/Transforms/PipelineDataTransfer.cpp
mlir/lib/Transforms/Utils/LoopUtils.cpp

index 4598896c31313197c2c69d14681bad443ccc222c..301e7853c3a48bb97f350e7684b21f760cff9da1 100644 (file)
@@ -345,7 +345,7 @@ public:
 
   /// Returns a builder for the body of a for Stmt.
   static MLFuncBuilder getForStmtBodyBuilder(ForStmt *forStmt) {
-    return MLFuncBuilder(forStmt, forStmt->end());
+    return MLFuncBuilder(forStmt->getBody(), forStmt->getBody()->end());
   }
 
   /// Returns the current insertion point of the builder.
index 2c89722f2c8cf31bcce2af56babafb048d4c5566..5885cbd38868cd7b5b0b7bad16d1c3d5722074ba 100644 (file)
@@ -228,7 +228,7 @@ public:
     assert(index < getNumSuccessors());
     return getBasicBlockOperands()[index].get();
   }
-  BasicBlock *getSuccessor(unsigned index) const {
+  const BasicBlock *getSuccessor(unsigned index) const {
     return const_cast<Instruction *>(this)->getSuccessor(index);
   }
   void setSuccessor(BasicBlock *block, unsigned index);
index 9b3914773368963817ecbb330ac29a5c71aed1ae..28f5a14540dc14852fc6eb1f56c105795b42b03d 100644 (file)
@@ -216,8 +216,31 @@ private:
   }
 };
 
+/// A ForStmtBody represents statements contained within a ForStmt.
+class ForStmtBody : public StmtBlock {
+public:
+  explicit ForStmtBody(ForStmt *stmt)
+      : StmtBlock(StmtBlockKind::ForBody), forStmt(stmt) {
+    assert(stmt != nullptr && "ForStmtBody must have non-null parent");
+  }
+
+  ~ForStmtBody() {}
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast
+  static bool classof(const StmtBlock *block) {
+    return block->getStmtBlockKind() == StmtBlockKind::ForBody;
+  }
+
+  /// Returns the 'for' statement that contains this body.
+  ForStmt *getFor() { return forStmt; }
+  const ForStmt *getFor() const { return forStmt; }
+
+private:
+  ForStmt *forStmt;
+};
+
 /// For statement represents an affine loop nest.
-class ForStmt : public Statement, public MLValue, public StmtBlock {
+class ForStmt : public Statement, public MLValue {
 public:
   static ForStmt *create(Location location, ArrayRef<MLValue *> lbOperands,
                          AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
@@ -228,7 +251,7 @@ public:
     // since child statements need to be destroyed before the MLValue that this
     // for stmt represents is destroyed. Affine maps are immortal objects and
     // don't need to be deleted.
-    clear();
+    getBody()->clear();
   }
 
   /// Resolve base class ambiguity.
@@ -242,6 +265,12 @@ public:
   using operand_range = llvm::iterator_range<operand_iterator>;
   using const_operand_range = llvm::iterator_range<const_operand_iterator>;
 
+  /// Get the body of the ForStmt.
+  ForStmtBody *getBody() { return &body; }
+
+  /// Get the body of the ForStmt.
+  const ForStmtBody *getBody() const { return &body; }
+
   //===--------------------------------------------------------------------===//
   // Bounds and step
   //===--------------------------------------------------------------------===//
@@ -359,10 +388,6 @@ public:
     return ptr->getKind() == IROperandOwner::Kind::ForStmt;
   }
 
-  static bool classof(const StmtBlock *block) {
-    return block->getStmtBlockKind() == StmtBlockKind::For;
-  }
-
   // For statement represents implicitly represents induction variable by
   // inheriting from MLValue class. Whenever you need to refer to the loop
   // induction variable, just use the for statement itself.
@@ -371,6 +396,9 @@ public:
   }
 
 private:
+  // The StmtBlock for the body.
+  ForStmtBody body;
+
   // Affine map for the lower bound.
   AffineMap lbMap;
   // Affine map for the upper bound. The upper bound is exclusive.
@@ -456,7 +484,9 @@ public:
   ~IfClause() {}
 
   /// Returns the if statement that contains this clause.
-  IfStmt *getIf() const { return ifStmt; }
+  const IfStmt *getIf() const { return ifStmt; }
+
+  IfStmt *getIf() { return ifStmt; }
 
 private:
   IfStmt *ifStmt;
index 6ee37d5472b4f6ef01c6bb328644f546c90c5df2..a9a35e564c77acc7324a428a041eb9115e173a47 100644 (file)
@@ -36,7 +36,7 @@ class StmtBlock {
 public:
   enum class StmtBlockKind {
     MLFunc,  // MLFunction
-    For,     // ForStmt
+    ForBody, // ForStmtBody
     IfClause // IfClause
   };
 
@@ -53,7 +53,11 @@ public:
 
   /// Returns the closest surrounding statement that contains this block or
   /// nullptr if this is a top-level statement block.
-  Statement *getContainingStmt() const;
+  Statement *getContainingStmt();
+
+  const Statement *getContainingStmt() const {
+    return const_cast<StmtBlock *>(this)->getContainingStmt();
+  }
 
   /// Returns the function that this statement block is part of.
   /// The function is determined by traversing the chain of parent statements.
index 21d98f14aa0b58284608200aa6639300b7687139..94bc0b0cdc13dcc331bb2decdcb93c5575d2c6c9 100644 (file)
@@ -146,12 +146,13 @@ public:
 
   void walkForStmt(ForStmt *forStmt) {
     static_cast<SubClass *>(this)->visitForStmt(forStmt);
-    static_cast<SubClass *>(this)->walk(forStmt->begin(), forStmt->end());
+    auto *body = forStmt->getBody();
+    static_cast<SubClass *>(this)->walk(body->begin(), body->end());
   }
 
   void walkForStmtPostOrder(ForStmt *forStmt) {
-    static_cast<SubClass *>(this)->walkPostOrder(forStmt->begin(),
-                                                 forStmt->end());
+    auto *body = forStmt->getBody();
+    static_cast<SubClass *>(this)->walkPostOrder(body->begin(), body->end());
     static_cast<SubClass *>(this)->visitForStmt(forStmt);
   }
 
index 91f4ccf480464ef1c8c7dfa178d6173fb236073f..bdc2c7ec2863aeeb13110defdbfcc4e82c328567 100644 (file)
@@ -905,7 +905,7 @@ static StmtBlock *getCommonStmtBlock(const MemRefAccess &srcAccess,
   }
   auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
   assert(isa<ForStmt>(commonForValue));
-  return dyn_cast<ForStmt>(commonForValue);
+  return cast<ForStmt>(commonForValue)->getBody();
 }
 
 // Returns true if the ancestor operation statement of 'srcAccess' properly
index 5e6bd7fa59bd512e534fd26bbd48a3e133dc0a37..3ee62bb2c42088ef7b8e0941b22c577312d970af 100644 (file)
@@ -305,9 +305,10 @@ bool mlir::isVectorizableLoop(const ForStmt &loop) {
 // violation when we have the support.
 bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
                                 ArrayRef<uint64_t> shifts) {
-  assert(shifts.size() == forStmt.getStatements().size());
+  auto *forBody = forStmt.getBody();
+  assert(shifts.size() == forBody->getStatements().size());
   unsigned s = 0;
-  for (const auto &stmt : forStmt) {
+  for (const auto &stmt : *forBody) {
     // A for or if stmt does not produce any def/results (that are used
     // outside).
     if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
@@ -319,8 +320,8 @@ bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
           // This is a naive way. If performance becomes an issue, a map can
           // be used to store 'shifts' - to look up the shift for a statement in
           // constant time.
-          if (auto *ancStmt = forStmt.findAncestorStmtInBlock(*use.getOwner()))
-            if (shifts[s] != shifts[forStmt.findStmtPosInBlock(*ancStmt)])
+          if (auto *ancStmt = forBody->findAncestorStmtInBlock(*use.getOwner()))
+            if (shifts[s] != shifts[forBody->findStmtPosInBlock(*ancStmt)])
               return false;
         }
       }
index cc30cfffb063629793a21bf37674528c760f7382..2428265acdb8a2f4438e003264a59af3d574977a 100644 (file)
@@ -362,7 +362,7 @@ static Statement *getStmtAtPosition(ArrayRef<unsigned> positions,
     if (level == positions.size() - 1)
       return &stmt;
     if (auto *childForStmt = dyn_cast<ForStmt>(&stmt))
-      return getStmtAtPosition(positions, level + 1, childForStmt);
+      return getStmtAtPosition(positions, level + 1, childForStmt->getBody());
 
     if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
       auto *ret = getStmtAtPosition(positions, level + 1, ifStmt->getThen());
@@ -453,13 +453,13 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
   // Clone src loop nest and insert it a the beginning of the statement block
   // of the loop at 'dstLoopDepth' in 'dstLoopNest'.
   auto *dstForStmt = dstLoopNest[dstLoopDepth - 1];
-  MLFuncBuilder b(dstForStmt, dstForStmt->begin());
+  MLFuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin());
   DenseMap<const MLValue *, MLValue *> operandMap;
   auto *sliceLoopNest = cast<ForStmt>(b.clone(*srcLoopNest[0], operandMap));
 
   // Lookup stmt in cloned 'sliceLoopNest' at 'positions'.
   Statement *sliceStmt =
-      getStmtAtPosition(positions, /*level=*/0, sliceLoopNest);
+      getStmtAtPosition(positions, /*level=*/0, sliceLoopNest->getBody());
   // Get loop nest surrounding 'sliceStmt'.
   SmallVector<ForStmt *, 4> sliceSurroundingLoops;
   getLoopIVs(*sliceStmt, &sliceSurroundingLoops);
index d955bcd5edb66abcf731c10ed5d6db0063f5c4ec..6e1522a656f56e266f6e5f792ffafa5782eee1d9 100644 (file)
@@ -288,8 +288,8 @@ bool MLFuncVerifier::verifyDominance() {
     HashTable::ScopeTy blockScope(liveValues);
 
     // The induction variable of a for statement is live within its body.
-    if (auto *forStmt = dyn_cast<ForStmt>(&block))
-      liveValues.insert(forStmt, true);
+    if (auto *forStmtBody = dyn_cast<ForStmtBody>(&block))
+      liveValues.insert(forStmtBody->getFor(), true);
 
     for (auto &stmt : block) {
       // Verify that each of the operands are live.
@@ -322,7 +322,7 @@ bool MLFuncVerifier::verifyDominance() {
             return true;
       }
       if (auto *forStmt = dyn_cast<ForStmt>(&stmt))
-        if (walkBlock(*forStmt))
+        if (walkBlock(*forStmt->getBody()))
           return true;
     }
 
index b798e3890a028e31f495cc1c84ac425df436e507..58f34af60f50a7ca1d7f709025d597d18875566a 100644 (file)
@@ -206,7 +206,7 @@ void ModuleState::visitForStmt(const ForStmt *forStmt) {
   if (!hasShorthandForm(ubMap))
     recordAffineMapReference(ubMap);
 
-  for (auto &childStmt : *forStmt)
+  for (auto &childStmt : *forStmt->getBody())
     visitStatement(&childStmt);
 }
 
@@ -1447,7 +1447,7 @@ void MLFunctionPrinter::print(const ForStmt *stmt) {
     os << " step " << stmt->getStep();
 
   os << " {\n";
-  print(static_cast<const StmtBlock *>(stmt));
+  print(stmt->getBody());
   os.indent(numSpaces) << "}";
 }
 
index 9d65f4376b3ef81fbac8a788ddb70bb9f53c8ede..de73f3a96d391b7e59837ca082c4fd1df41b289a 100644 (file)
@@ -147,7 +147,7 @@ Instruction *Instruction::clone() const {
     int cloneOperandIt = operands.size() - 1, operandIt = getNumOperands() - 1;
     for (int succIt = getNumSuccessors() - 1, succE = 0; succIt >= succE;
          --succIt) {
-      successors[succIt] = getSuccessor(succIt);
+      successors[succIt] = const_cast<BasicBlock *>(getSuccessor(succIt));
 
       // Add the successor operands in-place in reverse order.
       for (unsigned i = 0, e = getNumSuccessorOperands(succIt); i != e;
index 69afc5c1e98852c5bf4f92d6803a2d74076ae681..f63c76605de865216668b9ef1dfff35054ea3d2d 100644 (file)
@@ -338,7 +338,7 @@ ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap,
     : Statement(Kind::For, location),
       MLValue(MLValueKind::ForStmt,
               Type::getIndex(lbMap.getResult(0).getContext())),
-      StmtBlock(StmtBlockKind::For), lbMap(lbMap), ubMap(ubMap), step(step) {
+      body(this), lbMap(lbMap), ubMap(ubMap), step(step) {
   operands.reserve(numOperands);
 }
 
@@ -544,8 +544,8 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
     operandMap[forStmt] = newFor;
 
     // Recursively clone the body of the for loop.
-    for (auto &subStmt : *forStmt)
-      newFor->push_back(subStmt.clone(operandMap, context));
+    for (auto &subStmt : *forStmt->getBody())
+      newFor->getBody()->push_back(subStmt.clone(operandMap, context));
 
     return newFor;
   }
index 40a31f6c3b99b6474f53eaf0c66d1467811cd979..898dd7bc337b9ffab8ce65f355a5ea6dd3ea5836 100644 (file)
@@ -24,18 +24,19 @@ using namespace mlir;
 // Statement block
 //===----------------------------------------------------------------------===//
 
-Statement *StmtBlock::getContainingStmt() const {
+Statement *StmtBlock::getContainingStmt() {
   switch (kind) {
   case StmtBlockKind::MLFunc:
     return nullptr;
-  case StmtBlockKind::For:
-    return cast<ForStmt>(const_cast<StmtBlock *>(this));
+  case StmtBlockKind::ForBody:
+    return cast<ForStmtBody>(this)->getFor();
   case StmtBlockKind::IfClause:
     return cast<IfClause>(this)->getIf();
   }
 }
 
 MLFunction *StmtBlock::findFunction() const {
+  // FIXME: const incorrect.
   StmtBlock *block = const_cast<StmtBlock *>(this);
 
   while (block->getContainingStmt()) {
index 46dd35682fd60562b496975f2603783555fb9411..781ec461b626cb0a3096e9221527ab2b2526fa58 100644 (file)
@@ -2876,7 +2876,7 @@ ParseResult MLFunctionParser::parseForStmt() {
   // If parsing of the for statement body fails,
   // MLIR contains for statement with those nested statements that have been
   // successfully parsed.
-  if (parseStmtBlock(forStmt))
+  if (parseStmtBlock(forStmt->getBody()))
     return ParseFailure;
 
   // Reset insertion point to the current block.
index 247a264cd5ce4fc9f792815afea52d443a7a7341..0ed803db64dd75ebd0f7e4871ca2d44c7a5a72af 100644 (file)
@@ -242,7 +242,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
   // Walking manually because we need custom logic before and after traversing
   // the list of children.
   builder.setInsertionPoint(loopBodyFirstBlock);
-  visitStmtBlock(forStmt);
+  visitStmtBlock(forStmt->getBody());
 
   // Builder point is currently at the last block of the loop body.  Append the
   // induction variable stepping to this block and branch back to the exit
index bd7cad7fd3d406c191b93a0f49792a3a9b8edfe4..2b79064e53fb829781a0e954389fd28e34a9eb51 100644 (file)
@@ -365,7 +365,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
   replaceAllMemRefUsesWith(memref, cast<MLValue>(fastMemRef),
                            /*extraIndices=*/{}, indexRemap,
                            /*extraOperands=*/outerIVs,
-                           /*domStmtFilter=*/&*forStmt->begin());
+                           /*domStmtFilter=*/&*forStmt->getBody()->begin());
   return true;
 }
 
@@ -391,7 +391,7 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) {
   // the pass has to be instantiated with additional information that we aren't
   // provided with at the moment.
   if (forStmt->getStep() != 1) {
-    if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->begin())) {
+    if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->getBody()->begin())) {
       runOnForStmt(innerFor);
     }
     return;
index 85c88f785d16baef45e581726e9015bee7ec20f5..847db83aebc40da0d32ffc83fdf70c3c6c97ed83 100644 (file)
@@ -59,12 +59,12 @@ FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); }
 // destination's body.
 static inline void moveLoopBody(ForStmt *src, ForStmt *dest,
                                 StmtBlock::iterator loc) {
-  dest->getStatements().splice(loc, src->getStatements());
+  dest->getBody()->getStatements().splice(loc, src->getBody()->getStatements());
 }
 
 // Move the loop body of ForStmt 'src' from 'src' to the start of dest's body.
 static inline void moveLoopBody(ForStmt *src, ForStmt *dest) {
-  moveLoopBody(src, dest, dest->begin());
+  moveLoopBody(src, dest, dest->getBody()->begin());
 }
 
 /// Constructs and sets new loop bounds after tiling for the case of
@@ -167,8 +167,9 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
     MLFuncBuilder b(topLoop);
     // Loop bounds will be set later.
     auto *pointLoop = b.createFor(loc, 0, 0);
-    pointLoop->getStatements().splice(
-        pointLoop->begin(), topLoop->getBlock()->getStatements(), topLoop);
+    pointLoop->getBody()->getStatements().splice(
+        pointLoop->getBody()->begin(), topLoop->getBlock()->getStatements(),
+        topLoop);
     newLoops[2 * width - 1 - i] = pointLoop;
     topLoop = pointLoop;
     if (i == 0)
@@ -180,8 +181,9 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
     MLFuncBuilder b(topLoop);
     // Loop bounds will be set later.
     auto *tileSpaceLoop = b.createFor(loc, 0, 0);
-    tileSpaceLoop->getStatements().splice(
-        tileSpaceLoop->begin(), topLoop->getBlock()->getStatements(), topLoop);
+    tileSpaceLoop->getBody()->getStatements().splice(
+        tileSpaceLoop->getBody()->begin(), topLoop->getBlock()->getStatements(),
+        topLoop);
     newLoops[2 * width - i - 1] = tileSpaceLoop;
     topLoop = tileSpaceLoop;
   }
@@ -223,8 +225,8 @@ static void getTileableBands(MLFunction *f,
     ForStmt *currStmt = root;
     do {
       band.push_back(currStmt);
-    } while (currStmt->getStatements().size() == 1 &&
-             (currStmt = dyn_cast<ForStmt>(&*currStmt->begin())));
+    } while (currStmt->getBody()->getStatements().size() == 1 &&
+             (currStmt = dyn_cast<ForStmt>(&*currStmt->getBody()->begin())));
     bands->push_back(band);
   };
 
index a43087bd2e1572947d31226a7704829838984f51..183613a2f69a4a754d1bb746c13622bb957bfbe4 100644 (file)
@@ -104,7 +104,8 @@ PassResult LoopUnroll::runOnMLFunction(MLFunction *f) {
     }
 
     bool walkForStmtPostOrder(ForStmt *forStmt) {
-      bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
+      bool hasInnerLoops =
+          walkPostOrder(forStmt->getBody()->begin(), forStmt->getBody()->end());
       if (!hasInnerLoops)
         loops.push_back(forStmt);
       return true;
index 45ca9dd98dfd7590ee8a036d88dfb7c5992fb382..dd491f8119b98e89543798c1d2d18b0cc79db3cb 100644 (file)
@@ -152,7 +152,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
 
   assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
 
-  if (unrollJamFactor == 1 || forStmt->getStatements().empty())
+  if (unrollJamFactor == 1 || forStmt->getBody()->empty())
     return false;
 
   Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
index df30a7794614cd47c9ef065807832fdb87c7fdea..d4069eaa638ba454e6e44834d4f20993a7164042 100644 (file)
@@ -147,7 +147,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
     auto *forStmt = b.createFor(transfer->getLoc(), 0, it.value());
     loops.insert(forStmt);
     // Setting the insertion point to the innermost loop achieves nesting.
-    b.setInsertionPointToStart(loops.back());
+    b.setInsertionPointToStart(loops.back()->getBody());
     if (composed == getAffineConstantExpr(0, b.getContext())) {
       transfer->emitWarning(
           "Redundant copy can be implemented as a vector broadcast");
index b656af0d69d63d5f9ebd443d26e3046d25e69caf..8d75bfbd7ae5acd7b39e335cca51f25907453f7a 100644 (file)
@@ -81,8 +81,9 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
 /// the loop IV of the specified 'for' statement modulo 2. Returns false if such
 /// a replacement cannot be performed.
 static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
-  MLFuncBuilder bInner(forStmt, forStmt->begin());
-  bInner.setInsertionPoint(forStmt, forStmt->begin());
+  auto *forBody = forStmt->getBody();
+  MLFuncBuilder bInner(forBody, forBody->begin());
+  bInner.setInsertionPoint(forBody, forBody->begin());
 
   // Doubles the shape with a leading dimension extent of 2.
   auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
@@ -127,7 +128,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
   // non-deferencing uses of the memref.
   if (!replaceAllMemRefUsesWith(oldMemRef, cast<MLValue>(newMemRef),
                                 ivModTwoOp->getResult(0), AffineMap::Null(), {},
-                                &*forStmt->begin())) {
+                                &*forStmt->getBody()->begin())) {
     LLVM_DEBUG(llvm::dbgs()
                    << "memref replacement for double buffering failed\n";);
     ivModTwoOp->getOperation()->erase();
@@ -184,7 +185,7 @@ static void findMatchingStartFinishStmts(
 
   // Collect outgoing DMA statements - needed to check for dependences below.
   SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps;
-  for (auto &stmt : *forStmt) {
+  for (auto &stmt : *forStmt->getBody()) {
     auto *opStmt = dyn_cast<OperationStmt>(&stmt);
     if (!opStmt)
       continue;
@@ -195,7 +196,7 @@ static void findMatchingStartFinishStmts(
   }
 
   SmallVector<OperationStmt *, 4> dmaStartStmts, dmaFinishStmts;
-  for (auto &stmt : *forStmt) {
+  for (auto &stmt : *forStmt->getBody()) {
     auto *opStmt = dyn_cast<OperationStmt>(&stmt);
     if (!opStmt)
       continue;
@@ -228,7 +229,7 @@ static void findMatchingStartFinishStmts(
         cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()));
     bool escapingUses = false;
     for (const auto &use : memref->getUses()) {
-      if (!dominates(*forStmt->begin(), *use.getOwner())) {
+      if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) {
         LLVM_DEBUG(llvm::dbgs()
                        << "can't pipeline: buffer is live out of loop\n";);
         escapingUses = true;
@@ -339,16 +340,16 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
     }
   }
   // Everything else (including compute ops and dma finish) are shifted by one.
-  for (const auto &stmt : *forStmt) {
+  for (const auto &stmt : *forStmt->getBody()) {
     if (stmtShiftMap.find(&stmt) == stmtShiftMap.end()) {
       stmtShiftMap[&stmt] = 1;
     }
   }
 
   // Get shifts stored in map.
-  std::vector<uint64_t> shifts(forStmt->getStatements().size());
+  std::vector<uint64_t> shifts(forStmt->getBody()->getStatements().size());
   unsigned s = 0;
-  for (auto &stmt : *forStmt) {
+  for (auto &stmt : *forStmt->getBody()) {
     assert(stmtShiftMap.find(&stmt) != stmtShiftMap.end());
     shifts[s++] = stmtShiftMap[&stmt];
     LLVM_DEBUG(
index 791997e7ff166b227ad47b256bf7f83605ee1b8a..4d75f7c0835524afbe6023830b5e5eb75eddf590 100644 (file)
@@ -119,7 +119,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
   // Move the loop body statements to the loop's containing block.
   auto *block = forStmt->getBlock();
   block->getStatements().splice(StmtBlock::iterator(forStmt),
-                                forStmt->getStatements());
+                                forStmt->getBody()->getStatements());
   forStmt->erase();
   return true;
 }
@@ -181,7 +181,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
       operandMap[srcForStmt] = loopChunk;
     }
     for (auto *stmt : stmts) {
-      loopChunk->push_back(stmt->clone(operandMap, b->getContext()));
+      loopChunk->getBody()->push_back(stmt->clone(operandMap, b->getContext()));
     }
   }
   if (promoteIfSingleIteration(loopChunk))
@@ -206,7 +206,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
 // method.
 UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
                               bool unrollPrologueEpilogue) {
-  if (forStmt->getStatements().empty())
+  if (forStmt->getBody()->empty())
     return UtilResult::Success;
 
   // If the trip counts aren't constant, we would need versioning and
@@ -225,7 +225,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
 
   int64_t step = forStmt->getStep();
 
-  unsigned numChildStmts = forStmt->getStatements().size();
+  unsigned numChildStmts = forStmt->getBody()->getStatements().size();
 
   // Do a linear time (counting) sort for the shifts.
   uint64_t maxShift = 0;
@@ -243,7 +243,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
   // body of the 'for' stmt.
   std::vector<std::vector<Statement *>> sortedStmtGroups(maxShift + 1);
   unsigned pos = 0;
-  for (auto &stmt : *forStmt) {
+  for (auto &stmt : *forStmt->getBody()) {
     auto shift = shifts[pos++];
     sortedStmtGroups[shift].push_back(&stmt);
   }
@@ -352,7 +352,7 @@ bool mlir::loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor) {
 bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
   assert(unrollFactor >= 1 && "unroll factor should be >= 1");
 
-  if (unrollFactor == 1 || forStmt->getStatements().empty())
+  if (unrollFactor == 1 || forStmt->getBody()->empty())
     return false;
 
   auto lbMap = forStmt->getLowerBoundMap();
@@ -406,11 +406,11 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
 
   // Builder to insert unrolled bodies right after the last statement in the
   // body of 'forStmt'.
-  MLFuncBuilder builder(forStmt, StmtBlock::iterator(forStmt->end()));
+  MLFuncBuilder builder(forStmt->getBody(), forStmt->getBody()->end());
 
   // Keep a pointer to the last statement in the original block so that we know
   // what to clone (since we are doing this in-place).
-  StmtBlock::iterator srcBlockEnd = std::prev(forStmt->end());
+  StmtBlock::iterator srcBlockEnd = std::prev(forStmt->getBody()->end());
 
   // Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies).
   for (unsigned i = 1; i < unrollFactor; i++) {
@@ -429,7 +429,8 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
     }
 
     // Clone the original body of 'forStmt'.
-    for (auto it = forStmt->begin(); it != std::next(srcBlockEnd); it++) {
+    for (auto it = forStmt->getBody()->begin(); it != std::next(srcBlockEnd);
+         it++) {
       builder.clone(*it, operandMap);
     }
   }