Move transform utilities out to their own TransformUtils library, instead of
authorChris Lattner <clattner@google.com>
Thu, 25 Oct 2018 20:59:51 +0000 (13:59 -0700)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 20:39:06 +0000 (13:39 -0700)
just having the pattern matcher in its own library.  At this point,
lib/Transforms/*.cpp are all actually passes themselves (and will probably
eventually be themselves move to a new subdirectory as we accrete more).

PiperOrigin-RevId: 218745193

mlir/lib/Transforms/GreedyPatternRewriteDriver.cpp [deleted file]
mlir/lib/Transforms/LoopUtils.cpp [deleted file]
mlir/lib/Transforms/Pass.cpp [deleted file]
mlir/lib/Transforms/PatternMatch.cpp [deleted file]
mlir/lib/Transforms/Utils.cpp [deleted file]
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp [new file with mode: 0644]
mlir/lib/Transforms/Utils/LoopUtils.cpp [new file with mode: 0644]
mlir/lib/Transforms/Utils/Pass.cpp [new file with mode: 0644]
mlir/lib/Transforms/Utils/PatternMatch.cpp [new file with mode: 0644]
mlir/lib/Transforms/Utils/Utils.cpp [new file with mode: 0644]

diff --git a/mlir/lib/Transforms/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/GreedyPatternRewriteDriver.cpp
deleted file mode 100644 (file)
index 5ed8eac..0000000
+++ /dev/null
@@ -1,343 +0,0 @@
-//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements mlir::applyPatternsGreedily.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/StandardOps/StandardOps.h"
-#include "mlir/Transforms/PatternMatch.h"
-#include "llvm/ADT/DenseMap.h"
-using namespace mlir;
-
-namespace {
-class WorklistRewriter;
-
-/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
-/// applies the locally optimal patterns in a roughly "bottom up" way.
-class GreedyPatternRewriteDriver {
-public:
-  explicit GreedyPatternRewriteDriver(OwningPatternList &&patterns)
-      : matcher(std::move(patterns)) {
-    worklist.reserve(64);
-  }
-
-  void simplifyFunction(Function *currentFunction, WorklistRewriter &rewriter);
-
-  void addToWorklist(Operation *op) {
-    worklistMap[op] = worklist.size();
-    worklist.push_back(op);
-  }
-
-  Operation *popFromWorklist() {
-    auto *op = worklist.back();
-    worklist.pop_back();
-
-    // This operation is no longer in the worklist, keep worklistMap up to date.
-    if (op)
-      worklistMap.erase(op);
-    return op;
-  }
-
-  /// If the specified operation is in the worklist, remove it.  If not, this is
-  /// a no-op.
-  void removeFromWorklist(Operation *op) {
-    auto it = worklistMap.find(op);
-    if (it != worklistMap.end()) {
-      assert(worklist[it->second] == op && "malformed worklist data structure");
-      worklist[it->second] = nullptr;
-    }
-  }
-
-private:
-  /// The low-level pattern matcher.
-  PatternMatcher matcher;
-
-  /// The worklist for this transformation keeps track of the operations that
-  /// need to be revisited, plus their index in the worklist.  This allows us to
-  /// efficiently remove operations from the worklist when they are removed even
-  /// if they aren't the root of a pattern.
-  std::vector<Operation *> worklist;
-  DenseMap<Operation *, unsigned> worklistMap;
-
-  /// As part of canonicalization, we move constants to the top of the entry
-  /// block of the current function and de-duplicate them.  This keeps track of
-  /// constants we have done this for.
-  DenseMap<std::pair<Attribute *, Type *>, Operation *> uniquedConstants;
-};
-}; // end anonymous namespace
-
-/// This is a listener object that updates our worklists and other data
-/// structures in response to operations being added and removed.
-namespace {
-class WorklistRewriter : public PatternRewriter {
-public:
-  WorklistRewriter(GreedyPatternRewriteDriver &driver, MLIRContext *context)
-      : PatternRewriter(context), driver(driver) {}
-
-  virtual void setInsertionPoint(Operation *op) = 0;
-
-  // If an operation is about to be removed, make sure it is not in our
-  // worklist anymore because we'd get dangling references to it.
-  void notifyOperationRemoved(Operation *op) override {
-    driver.removeFromWorklist(op);
-  }
-
-  GreedyPatternRewriteDriver &driver;
-};
-
-} // end anonymous namespace
-
-void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
-                                                  WorklistRewriter &rewriter) {
-  // These are scratch vectors used in the constant folding loop below.
-  SmallVector<Attribute *, 8> operandConstants, resultConstants;
-
-  while (!worklist.empty()) {
-    auto *op = popFromWorklist();
-
-    // Nulls get added to the worklist when operations are removed, ignore them.
-    if (op == nullptr)
-      continue;
-
-    // If we have a constant op, unique it into the entry block.
-    if (auto constant = op->dyn_cast<ConstantOp>()) {
-      // If this constant is dead, remove it, being careful to keep
-      // uniquedConstants up to date.
-      if (constant->use_empty()) {
-        auto it =
-            uniquedConstants.find({constant->getValue(), constant->getType()});
-        if (it != uniquedConstants.end() && it->second == op)
-          uniquedConstants.erase(it);
-        constant->erase();
-        continue;
-      }
-
-      // Check to see if we already have a constant with this type and value:
-      auto &entry = uniquedConstants[std::make_pair(constant->getValue(),
-                                                    constant->getType())];
-      if (entry) {
-        // If this constant is already our uniqued one, then leave it alone.
-        if (entry == op)
-          continue;
-
-        // Otherwise replace this redundant constant with the uniqued one.  We
-        // know this is safe because we move constants to the top of the
-        // function when they are uniqued, so we know they dominate all uses.
-        constant->replaceAllUsesWith(entry->getResult(0));
-        constant->erase();
-        continue;
-      }
-
-      // If we have no entry, then we should unique this constant as the
-      // canonical version.  To ensure safe dominance, move the operation to the
-      // top of the function.
-      entry = op;
-
-      // TODO: If we make terminators into Operations then we could turn this
-      // into a nice Operation::moveBefore(Operation*) method.  We just need the
-      // guarantee that a block is non-empty.
-      if (auto *cfgFunc = dyn_cast<CFGFunction>(currentFunction)) {
-        auto &entryBB = cfgFunc->front();
-        cast<OperationInst>(op)->moveBefore(&entryBB, entryBB.begin());
-      } else {
-        auto *mlFunc = cast<MLFunction>(currentFunction);
-        cast<OperationStmt>(op)->moveBefore(mlFunc, mlFunc->begin());
-      }
-
-      continue;
-    }
-
-    // If the operation has no side effects, and no users, then it is trivially
-    // dead - remove it.
-    if (op->hasNoSideEffect() && op->use_empty()) {
-      op->erase();
-      continue;
-    }
-
-    // Check to see if any operands to the instruction is constant and whether
-    // the operation knows how to constant fold itself.
-    operandConstants.clear();
-    for (auto *operand : op->getOperands()) {
-      Attribute *operandCst = nullptr;
-      if (auto *operandOp = operand->getDefiningOperation()) {
-        if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
-          operandCst = operandConstantOp->getValue();
-      }
-      operandConstants.push_back(operandCst);
-    }
-
-    // If constant folding was successful, create the result constants, RAUW the
-    // operation and remove it.
-    resultConstants.clear();
-    if (!op->constantFold(operandConstants, resultConstants)) {
-      rewriter.setInsertionPoint(op);
-
-      for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
-        auto *res = op->getResult(i);
-        if (res->use_empty()) // ignore dead uses.
-          continue;
-
-        // If we already have a canonicalized version of this constant, just
-        // reuse it.  Otherwise create a new one.
-        SSAValue *cstValue;
-        auto it = uniquedConstants.find({resultConstants[i], res->getType()});
-        if (it != uniquedConstants.end())
-          cstValue = it->second->getResult(0);
-        else
-          cstValue = rewriter.create<ConstantOp>(
-              op->getLoc(), resultConstants[i], res->getType());
-        res->replaceAllUsesWith(cstValue);
-      }
-
-      assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
-      op->erase();
-      continue;
-    }
-
-    // If this is an associative binary operation with a constant on the LHS,
-    // move it to the right side.
-    if (operandConstants.size() == 2 && operandConstants[0] &&
-        !operandConstants[1]) {
-      auto *newLHS = op->getOperand(1);
-      op->setOperand(1, op->getOperand(0));
-      op->setOperand(0, newLHS);
-    }
-
-    // Check to see if we have any patterns that match this node.
-    auto match = matcher.findMatch(op);
-    if (!match.first)
-      continue;
-
-    // Make sure that any new operations are inserted at this point.
-    rewriter.setInsertionPoint(op);
-    match.first->rewrite(op, std::move(match.second), rewriter);
-  }
-
-  uniquedConstants.clear();
-}
-
-static void processMLFunction(MLFunction *fn, OwningPatternList &&patterns) {
-  class MLFuncRewriter : public WorklistRewriter {
-  public:
-    MLFuncRewriter(GreedyPatternRewriteDriver &driver, MLFuncBuilder &builder)
-        : WorklistRewriter(driver, builder.getContext()), builder(builder) {}
-
-    // Implement the hook for creating operations, and make sure that newly
-    // created ops are added to the worklist for processing.
-    Operation *createOperation(const OperationState &state) override {
-      auto *result = builder.createOperation(state);
-      driver.addToWorklist(result);
-      return result;
-    }
-
-    // When the root of a pattern is about to be replaced, it can trigger
-    // simplifications to its users - make sure to add them to the worklist
-    // before the root is changed.
-    void notifyRootReplaced(Operation *op) override {
-      auto *opStmt = cast<OperationStmt>(op);
-      for (auto *result : opStmt->getResults())
-        // TODO: Add a result->getUsers() iterator.
-        for (auto &user : result->getUses()) {
-          if (auto *op = dyn_cast<OperationStmt>(user.getOwner()))
-            driver.addToWorklist(op);
-        }
-
-      // TODO: Walk the operand list dropping them as we go.  If any of them
-      // drop to zero uses, then add them to the worklist to allow them to be
-      // deleted as dead.
-    }
-
-    void setInsertionPoint(Operation *op) override {
-      // Any new operations should be added before this statement.
-      builder.setInsertionPoint(cast<OperationStmt>(op));
-    }
-
-  private:
-    MLFuncBuilder &builder;
-  };
-
-  GreedyPatternRewriteDriver driver(std::move(patterns));
-  fn->walk([&](OperationStmt *stmt) { driver.addToWorklist(stmt); });
-
-  MLFuncBuilder mlBuilder(fn);
-  MLFuncRewriter rewriter(driver, mlBuilder);
-  driver.simplifyFunction(fn, rewriter);
-}
-
-static void processCFGFunction(CFGFunction *fn, OwningPatternList &&patterns) {
-  class CFGFuncRewriter : public WorklistRewriter {
-  public:
-    CFGFuncRewriter(GreedyPatternRewriteDriver &driver, CFGFuncBuilder &builder)
-        : WorklistRewriter(driver, builder.getContext()), builder(builder) {}
-
-    // Implement the hook for creating operations, and make sure that newly
-    // created ops are added to the worklist for processing.
-    Operation *createOperation(const OperationState &state) override {
-      auto *result = builder.createOperation(state);
-      driver.addToWorklist(result);
-      return result;
-    }
-
-    // When the root of a pattern is about to be replaced, it can trigger
-    // simplifications to its users - make sure to add them to the worklist
-    // before the root is changed.
-    void notifyRootReplaced(Operation *op) override {
-      auto *opStmt = cast<OperationInst>(op);
-      for (auto *result : opStmt->getResults())
-        // TODO: Add a result->getUsers() iterator.
-        for (auto &user : result->getUses()) {
-          if (auto *op = dyn_cast<OperationInst>(user.getOwner()))
-            driver.addToWorklist(op);
-        }
-
-      // TODO: Walk the operand list dropping them as we go.  If any of them
-      // drop to zero uses, then add them to the worklist to allow them to be
-      // deleted as dead.
-    }
-
-    void setInsertionPoint(Operation *op) override {
-      // Any new operations should be added before this instruction.
-      builder.setInsertionPoint(cast<OperationInst>(op));
-    }
-
-  private:
-    CFGFuncBuilder &builder;
-  };
-
-  GreedyPatternRewriteDriver driver(std::move(patterns));
-  for (auto &bb : *fn)
-    for (auto &op : bb)
-      driver.addToWorklist(&op);
-
-  CFGFuncBuilder cfgBuilder(fn);
-  CFGFuncRewriter rewriter(driver, cfgBuilder);
-  driver.simplifyFunction(fn, rewriter);
-}
-
-/// Rewrite the specified function by repeatedly applying the highest benefit
-/// patterns in a greedy work-list driven manner.
-///
-void mlir::applyPatternsGreedily(Function *fn, OwningPatternList &&patterns) {
-  if (auto *cfg = dyn_cast<CFGFunction>(fn)) {
-    processCFGFunction(cfg, std::move(patterns));
-  } else {
-    processMLFunction(cast<MLFunction>(fn), std::move(patterns));
-  }
-}
diff --git a/mlir/lib/Transforms/LoopUtils.cpp b/mlir/lib/Transforms/LoopUtils.cpp
deleted file mode 100644 (file)
index a6a8502..0000000
+++ /dev/null
@@ -1,316 +0,0 @@
-//===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements miscellaneous loop transformation routines.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Transforms/LoopUtils.h"
-
-#include "mlir/Analysis/LoopAnalysis.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Statements.h"
-#include "mlir/IR/StmtVisitor.h"
-#include "mlir/StandardOps/StandardOps.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/Support/Debug.h"
-
-#define DEBUG_TYPE "LoopUtils"
-
-using namespace mlir;
-
-/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with
-/// the specified trip count, stride, and unroll factor. Returns nullptr when
-/// the trip count can't be expressed as an affine expression.
-AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
-                                          unsigned unrollFactor,
-                                          MLFuncBuilder *builder) {
-  auto lbMap = forStmt.getLowerBoundMap();
-
-  // Single result lower bound map only.
-  if (lbMap.getNumResults() != 1)
-    return AffineMap::Null();
-
-  // Sometimes, the trip count cannot be expressed as an affine expression.
-  auto tripCount = getTripCountExpr(forStmt);
-  if (!tripCount)
-    return AffineMap::Null();
-
-  AffineExpr lb(lbMap.getResult(0));
-  unsigned step = forStmt.getStep();
-  auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step;
-
-  return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
-                               {newUb}, {});
-}
-
-/// Returns the lower bound of the cleanup loop when unrolling a loop with lower
-/// bound 'lb' and with the specified trip count, stride, and unroll factor.
-/// Returns an AffinMap with nullptr storage (that evaluates to false)
-/// when the trip count can't be expressed as an affine expression.
-AffineMap mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
-                                         unsigned unrollFactor,
-                                         MLFuncBuilder *builder) {
-  auto lbMap = forStmt.getLowerBoundMap();
-
-  // Single result lower bound map only.
-  if (lbMap.getNumResults() != 1)
-    return AffineMap::Null();
-
-  // Sometimes the trip count cannot be expressed as an affine expression.
-  AffineExpr tripCount(getTripCountExpr(forStmt));
-  if (!tripCount)
-    return AffineMap::Null();
-
-  AffineExpr lb(lbMap.getResult(0));
-  unsigned step = forStmt.getStep();
-  auto newLb = lb + (tripCount - tripCount % unrollFactor) * step;
-  return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
-                               {newLb}, {});
-}
-
-/// Promotes the loop body of a forStmt to its containing block if the forStmt
-/// was known to have a single iteration. Returns false otherwise.
-// TODO(bondhugula): extend this for arbitrary affine bounds.
-bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
-  Optional<uint64_t> tripCount = getConstantTripCount(*forStmt);
-  if (!tripCount.hasValue() || tripCount.getValue() != 1)
-    return false;
-
-  // TODO(mlir-team): there is no builder for a max.
-  if (forStmt->getLowerBoundMap().getNumResults() != 1)
-    return false;
-
-  // Replaces all IV uses to its single iteration value.
-  if (!forStmt->use_empty()) {
-    if (forStmt->hasConstantLowerBound()) {
-      auto *mlFunc = forStmt->findFunction();
-      MLFuncBuilder topBuilder(&mlFunc->front());
-      auto constOp = topBuilder.create<ConstantIndexOp>(
-          forStmt->getLoc(), forStmt->getConstantLowerBound());
-      forStmt->replaceAllUsesWith(constOp);
-    } else {
-      const AffineBound lb = forStmt->getLowerBound();
-      SmallVector<SSAValue *, 4> lbOperands(lb.operand_begin(),
-                                            lb.operand_end());
-      MLFuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt));
-      auto affineApplyOp = builder.create<AffineApplyOp>(
-          forStmt->getLoc(), lb.getMap(), lbOperands);
-      forStmt->replaceAllUsesWith(affineApplyOp->getResult(0));
-    }
-  }
-  // Move the loop body statements to the loop's containing block.
-  auto *block = forStmt->getBlock();
-  block->getStatements().splice(StmtBlock::iterator(forStmt),
-                                forStmt->getStatements());
-  forStmt->erase();
-  return true;
-}
-
-/// Promotes all single iteration for stmt's in the MLFunction, i.e., moves
-/// their body into the containing StmtBlock.
-void mlir::promoteSingleIterationLoops(MLFunction *f) {
-  // Gathers all innermost loops through a post order pruned walk.
-  class LoopBodyPromoter : public StmtWalker<LoopBodyPromoter> {
-  public:
-    void visitForStmt(ForStmt *forStmt) { promoteIfSingleIteration(forStmt); }
-  };
-
-  LoopBodyPromoter fsw;
-  fsw.walkPostOrder(f);
-}
-
-/// Generates a for 'stmt' with the specified lower and upper bounds while
-/// generating the right IV remappings for the delayed statements. The
-/// statement blocks that go into the loop are specified in stmtGroupQueue
-/// starting from the specified offset, and in that order; the first element of
-/// the pair specifies the delay applied to that group of statements. Returns
-/// nullptr if the generated loop simplifies to a single iteration one.
-static ForStmt *
-generateLoop(AffineMap lb, AffineMap ub,
-             const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>>
-                 &stmtGroupQueue,
-             unsigned offset, ForStmt *srcForStmt, MLFuncBuilder *b) {
-  SmallVector<MLValue *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
-  SmallVector<MLValue *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
-
-  auto *loopChunk =
-      b->createFor(srcForStmt->getLoc(), lbOperands, lb, ubOperands, ub);
-  OperationStmt::OperandMapTy operandMap;
-
-  for (auto it = stmtGroupQueue.begin() + offset, e = stmtGroupQueue.end();
-       it != e; ++it) {
-    auto elt = *it;
-    // All 'same delay' statements get added with the operands being remapped
-    // (to results of cloned statements).
-    // Generate the remapping if the delay is not zero: oldIV = newIV - delay.
-    // TODO(bondhugula): check if srcForStmt is actually used in elt.second
-    // instead of just checking if it's used at all.
-    if (!srcForStmt->use_empty() && elt.first != 0) {
-      auto b = MLFuncBuilder::getForStmtBodyBuilder(loopChunk);
-      auto *oldIV =
-          b.create<AffineApplyOp>(
-               srcForStmt->getLoc(),
-               b.getSingleDimShiftAffineMap(-static_cast<int64_t>(elt.first)),
-               loopChunk)
-              ->getResult(0);
-      operandMap[srcForStmt] = cast<MLValue>(oldIV);
-    } else {
-      operandMap[srcForStmt] = static_cast<MLValue *>(loopChunk);
-    }
-    for (auto *stmt : elt.second) {
-      loopChunk->push_back(stmt->clone(operandMap, b->getContext()));
-    }
-  }
-  if (promoteIfSingleIteration(loopChunk))
-    return nullptr;
-  return loopChunk;
-}
-
-/// Skew the statements in the body of a 'for' statement with the specified
-/// statement-wise delays. The delays are with respect to the original execution
-/// order. A delay of zero for each statement will lead to no change.
-// The skewing of statements with respect to one another can be used for example
-// to allow overlap of asynchronous operations (such as DMA communication) with
-// computation, or just relative shifting of statements for better register
-// reuse, locality or parallelism. As such, the delays are typically expected to
-// be at most of the order of the number of statements. This method should not
-// be used as a substitute for loop distribution/fission.
-// This method uses an algorithm// in time linear in the number of statements in
-// the body of the for loop - (using the 'sweep line' paradigm). This method
-// asserts preservation of SSA dominance. A check for that as well as that for
-// memory-based depedence preservation check rests with the users of this
-// method.
-UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
-                              bool unrollPrologueEpilogue) {
-  if (forStmt->getStatements().empty())
-    return UtilResult::Success;
-
-  // If the trip counts aren't constant, we would need versioning and
-  // conditional guards (or context information to prevent such versioning). The
-  // better way to pipeline for such loops is to first tile them and extract
-  // constant trip count "full tiles" before applying this.
-  auto mayBeConstTripCount = getConstantTripCount(*forStmt);
-  if (!mayBeConstTripCount.hasValue()) {
-    LLVM_DEBUG(llvm::dbgs() << "non-constant trip count loop\n";);
-    return UtilResult::Success;
-  }
-  uint64_t tripCount = mayBeConstTripCount.getValue();
-
-  assert(isStmtwiseShiftValid(*forStmt, delays) &&
-         "shifts will lead to an invalid transformation\n");
-
-  unsigned numChildStmts = forStmt->getStatements().size();
-
-  // Do a linear time (counting) sort for the delays.
-  uint64_t maxDelay = 0;
-  for (unsigned i = 0; i < numChildStmts; i++) {
-    maxDelay = std::max(maxDelay, delays[i]);
-  }
-  // Such large delays are not the typical use case.
-  if (maxDelay >= numChildStmts) {
-    LLVM_DEBUG(llvm::dbgs() << "stmt delays too large - unexpected\n";);
-    return UtilResult::Success;
-  }
-
-  // An array of statement groups sorted by delay amount; each group has all
-  // statements with the same delay in the order in which they appear in the
-  // body of the 'for' stmt.
-  std::vector<std::vector<Statement *>> sortedStmtGroups(maxDelay + 1);
-  unsigned pos = 0;
-  for (auto &stmt : *forStmt) {
-    auto delay = delays[pos++];
-    sortedStmtGroups[delay].push_back(&stmt);
-  }
-
-  // Unless the shifts have a specific pattern (which actually would be the
-  // common use case), prologue and epilogue are not meaningfully defined.
-  // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
-  // loop generated as the prologue and the last as epilogue and unroll these
-  // fully.
-  ForStmt *prologue = nullptr;
-  ForStmt *epilogue = nullptr;
-
-  // Do a sweep over the sorted delays while storing open groups in a
-  // vector, and generating loop portions as necessary during the sweep. A block
-  // of statements is paired with its delay.
-  std::vector<std::pair<uint64_t, ArrayRef<Statement *>>> stmtGroupQueue;
-
-  auto origLbMap = forStmt->getLowerBoundMap();
-  uint64_t lbDelay = 0;
-  MLFuncBuilder b(forStmt);
-  for (uint64_t d = 0, e = sortedStmtGroups.size(); d < e; ++d) {
-    // If nothing is delayed by d, continue.
-    if (sortedStmtGroups[d].empty())
-      continue;
-    if (!stmtGroupQueue.empty()) {
-      assert(d >= 1 &&
-             "Queue expected to be empty when the first block is found");
-      // The interval for which the loop needs to be generated here is:
-      // ( lbDelay, min(lbDelay + tripCount - 1, d - 1) ] and the body of the
-      // loop needs to have all statements in stmtQueue in that order.
-      ForStmt *res;
-      if (lbDelay + tripCount - 1 < d - 1) {
-        res = generateLoop(
-            b.getShiftedAffineMap(origLbMap, lbDelay),
-            b.getShiftedAffineMap(origLbMap, lbDelay + tripCount - 1),
-            stmtGroupQueue, 0, forStmt, &b);
-        // Entire loop for the queued stmt groups generated, empty it.
-        stmtGroupQueue.clear();
-        lbDelay += tripCount;
-      } else {
-        res = generateLoop(b.getShiftedAffineMap(origLbMap, lbDelay),
-                           b.getShiftedAffineMap(origLbMap, d - 1),
-                           stmtGroupQueue, 0, forStmt, &b);
-        lbDelay = d;
-      }
-      if (!prologue && res)
-        prologue = res;
-      epilogue = res;
-    } else {
-      // Start of first interval.
-      lbDelay = d;
-    }
-    // Augment the list of statements that get into the current open interval.
-    stmtGroupQueue.push_back({d, sortedStmtGroups[d]});
-  }
-
-  // Those statements groups left in the queue now need to be processed (FIFO)
-  // and their loops completed.
-  for (unsigned i = 0, e = stmtGroupQueue.size(); i < e; ++i) {
-    uint64_t ubDelay = stmtGroupQueue[i].first + tripCount - 1;
-    epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbDelay),
-                            b.getShiftedAffineMap(origLbMap, ubDelay),
-                            stmtGroupQueue, i, forStmt, &b);
-    lbDelay = ubDelay + 1;
-    if (!prologue)
-      prologue = epilogue;
-  }
-
-  // Erase the original for stmt.
-  forStmt->erase();
-
-  if (unrollPrologueEpilogue && prologue)
-    loopUnrollFull(prologue);
-  if (unrollPrologueEpilogue && !epilogue && epilogue != prologue)
-    loopUnrollFull(epilogue);
-
-  return UtilResult::Success;
-}
diff --git a/mlir/lib/Transforms/Pass.cpp b/mlir/lib/Transforms/Pass.cpp
deleted file mode 100644 (file)
index 8b11107..0000000
+++ /dev/null
@@ -1,41 +0,0 @@
-//===- Pass.cpp - Pass infrastructure implementation ----------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements common pass infrastructure.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Transforms/Pass.h"
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLFunction.h"
-#include "mlir/IR/Module.h"
-
-using namespace mlir;
-
-/// Function passes walk a module and look at each function with their
-/// corresponding hooks and terminates upon error encountered.
-PassResult FunctionPass::runOnModule(Module *m) {
-  for (auto &fn : *m) {
-    if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
-      if (runOnMLFunction(mlFunc))
-        return failure();
-    if (auto *cfgFunc = dyn_cast<CFGFunction>(&fn))
-      if (runOnCFGFunction(cfgFunc))
-        return failure();
-  }
-  return success();
-}
diff --git a/mlir/lib/Transforms/PatternMatch.cpp b/mlir/lib/Transforms/PatternMatch.cpp
deleted file mode 100644 (file)
index 6cc3b43..0000000
+++ /dev/null
@@ -1,196 +0,0 @@
-//===- PatternMatch.cpp - Base classes for pattern match ------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/IR/SSAValue.h"
-#include "mlir/IR/Statements.h"
-#include "mlir/StandardOps/StandardOps.h"
-#include "mlir/Transforms/PatternMatch.h"
-using namespace mlir;
-
-PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
-  assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
-         "This pattern match benefit is too large to represent");
-}
-
-unsigned short PatternBenefit::getBenefit() const {
-  assert(representation != ImpossibleToMatchSentinel &&
-         "Pattern doesn't match");
-  return representation;
-}
-
-bool PatternBenefit::operator==(const PatternBenefit& other) {
-  if (isImpossibleToMatch())
-    return other.isImpossibleToMatch();
-  if (other.isImpossibleToMatch())
-    return false;
-  return getBenefit() == other.getBenefit();
-}
-
-bool PatternBenefit::operator!=(const PatternBenefit& other) {
-  return !(*this == other);
-}
-
-//===----------------------------------------------------------------------===//
-// Pattern implementation
-//===----------------------------------------------------------------------===//
-
-Pattern::Pattern(StringRef rootName, MLIRContext *context,
-                 Optional<PatternBenefit> staticBenefit)
-    : rootKind(OperationName(rootName, context)), staticBenefit(staticBenefit) {
-}
-
-Pattern::Pattern(StringRef rootName, MLIRContext *context,
-                 unsigned staticBenefit)
-    : rootKind(rootName, context), staticBenefit(staticBenefit) {}
-
-Optional<PatternBenefit> Pattern::getStaticBenefit() const {
-  return staticBenefit;
-}
-
-OperationName Pattern::getRootKind() const { return rootKind; }
-
-void Pattern::rewrite(Operation *op, std::unique_ptr<PatternState> state,
-                      PatternRewriter &rewriter) const {
-  rewrite(op, rewriter);
-}
-
-void Pattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
-  llvm_unreachable("need to implement one of the rewrite functions!");
-}
-
-/// This method indicates that no match was found.
-PatternMatchResult Pattern::matchFailure() {
-  return {PatternBenefit::impossibleToMatch(), std::unique_ptr<PatternState>()};
-}
-
-/// This method indicates that a match was found and has the specified cost.
-PatternMatchResult
-Pattern::matchSuccess(PatternBenefit benefit,
-                      std::unique_ptr<PatternState> state) const {
-  assert((!getStaticBenefit().hasValue() ||
-          getStaticBenefit().getValue() == benefit) &&
-         "This version of matchSuccess must be called with a benefit that "
-         "matches the static benefit if set!");
-
-  return {benefit, std::move(state)};
-}
-
-/// This method indicates that a match was found for patterns that have a
-/// known static benefit.
-PatternMatchResult
-Pattern::matchSuccess(std::unique_ptr<PatternState> state) const {
-  auto benefit = getStaticBenefit();
-  assert(benefit.hasValue() && "Pattern doesn't have a static benefit");
-  return matchSuccess(benefit.getValue(), std::move(state));
-}
-
-//===----------------------------------------------------------------------===//
-// PatternRewriter implementation
-//===----------------------------------------------------------------------===//
-
-PatternRewriter::~PatternRewriter() {
-  // Out of line to provide a vtable anchor for the class.
-}
-
-/// This method is used as the final replacement hook for patterns that match
-/// a single result value.  In addition to replacing and removing the
-/// specified operation, clients can specify a list of other nodes that this
-/// replacement may make (perhaps transitively) dead.  If any of those ops are
-/// dead, this will remove them as well.
-void PatternRewriter::replaceSingleResultOp(
-    Operation *op, SSAValue *newValue, ArrayRef<SSAValue *> opsToRemoveIfDead) {
-  // Notify the rewriter subclass that we're about to replace this root.
-  notifyRootReplaced(op);
-
-  assert(op->getNumResults() == 1 && "op isn't a SingleResultOp!");
-  op->getResult(0)->replaceAllUsesWith(newValue);
-
-  notifyOperationRemoved(op);
-  op->erase();
-
-  // TODO: Process the opsToRemoveIfDead list, removing things and calling the
-  // notifyOperationRemoved hook in the process.
-}
-
-/// This method is used as the final notification hook for patterns that end
-/// up modifying the pattern root in place, by changing its operands.  This is
-/// a minor efficiency win (it avoids creating a new instruction and removing
-/// the old one) but also often allows simpler code in the client.
-///
-/// The opsToRemoveIfDead list is an optional list of nodes that the rewriter
-/// should remove if they are dead at this point.
-///
-void PatternRewriter::updatedRootInPlace(
-    Operation *op, ArrayRef<SSAValue *> opsToRemoveIfDead) {
-  // Notify the rewriter subclass that we're about to replace this root.
-  notifyRootUpdated(op);
-
-  // TODO: Process the opsToRemoveIfDead list, removing things and calling the
-  // notifyOperationRemoved hook in the process.
-}
-
-//===----------------------------------------------------------------------===//
-// PatternMatcher implementation
-//===----------------------------------------------------------------------===//
-
-/// Find the highest benefit pattern available in the pattern set for the DAG
-/// rooted at the specified node.  This returns the pattern if found, or null
-/// if there are no matches.
-auto PatternMatcher::findMatch(Operation *op) -> MatchResult {
-  // TODO: This is a completely trivial implementation, expand this in the
-  // future.
-
-  // Keep track of the best match, the benefit of it, and any matcher specific
-  // state it is maintaining.
-  MatchResult bestMatch = {nullptr, nullptr};
-  Optional<PatternBenefit> bestBenefit;
-
-  for (auto &pattern : patterns) {
-    // Ignore patterns that are for the wrong root.
-    if (pattern->getRootKind() != op->getName())
-      continue;
-
-    // If we know the static cost of the pattern is worse than what we've
-    // already found then don't run it.
-    auto staticBenefit = pattern->getStaticBenefit();
-    if (staticBenefit.hasValue() && bestBenefit.hasValue() &&
-        staticBenefit.getValue().getBenefit() <
-            bestBenefit.getValue().getBenefit())
-      continue;
-
-    // Check to see if this pattern matches this node.
-    auto result = pattern->match(op);
-    auto benefit = result.first;
-
-    // If this pattern failed to match, ignore it.
-    if (benefit.isImpossibleToMatch())
-      continue;
-
-    // If it matched but had lower benefit than our best match so far, then
-    // ignore it.
-    if (bestBenefit.hasValue() &&
-        benefit.getBenefit() < bestBenefit.getValue().getBenefit())
-      continue;
-
-    // Okay we found a match that is better than our previous one, remember it.
-    bestBenefit = benefit;
-    bestMatch = {pattern.get(), std::move(result.second)};
-  }
-
-  // If we found any match, return it.
-  return bestMatch;
-}
diff --git a/mlir/lib/Transforms/Utils.cpp b/mlir/lib/Transforms/Utils.cpp
deleted file mode 100644 (file)
index 4432a80..0000000
+++ /dev/null
@@ -1,394 +0,0 @@
-//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements miscellaneous transformation routines for non-loop IR
-// structures.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Transforms/Utils.h"
-
-#include "mlir/Analysis/AffineAnalysis.h"
-#include "mlir/Analysis/AffineStructures.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/StandardOps/StandardOps.h"
-#include "mlir/Support/MathExtras.h"
-#include "llvm/ADT/DenseMap.h"
-
-using namespace mlir;
-
-/// Return true if this operation dereferences one or more memref's.
-// Temporary utility: will be replaced when this is modeled through
-// side-effects/op traits. TODO(b/117228571)
-static bool isMemRefDereferencingOp(const Operation &op) {
-  if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
-      op.isa<DmaWaitOp>())
-    return true;
-  return false;
-}
-
-/// Replaces all uses of oldMemRef with newMemRef while optionally remapping
-/// old memref's indices to the new memref using the supplied affine map
-/// and adding any additional indices. The new memref could be of a different
-/// shape or rank, but of the same elemental type. Additional indices are added
-/// at the start for now.
-// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
-// extended to add additional indices at any position.
-bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
-                                    MLValue *newMemRef,
-                                    ArrayRef<MLValue *> extraIndices,
-                                    AffineMap indexRemap) {
-  unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
-  (void)newMemRefRank; // unused in opt mode
-  unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
-  (void)newMemRefRank;
-  if (indexRemap) {
-    assert(indexRemap.getNumInputs() == oldMemRefRank);
-    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
-  } else {
-    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
-  }
-
-  // Assert same elemental type.
-  assert(cast<MemRefType>(oldMemRef->getType())->getElementType() ==
-         cast<MemRefType>(newMemRef->getType())->getElementType());
-
-  // Check if memref was used in a non-deferencing context.
-  for (const StmtOperand &use : oldMemRef->getUses()) {
-    auto *opStmt = cast<OperationStmt>(use.getOwner());
-    // Failure: memref used in a non-deferencing op (potentially escapes); no
-    // replacement in these cases.
-    if (!isMemRefDereferencingOp(*opStmt))
-      return false;
-  }
-
-  // Walk all uses of old memref. Statement using the memref gets replaced.
-  for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
-    StmtOperand &use = *(it++);
-    auto *opStmt = cast<OperationStmt>(use.getOwner());
-    assert(isMemRefDereferencingOp(*opStmt) &&
-           "memref deferencing op expected");
-
-    auto getMemRefOperandPos = [&]() -> unsigned {
-      unsigned i;
-      for (i = 0; i < opStmt->getNumOperands(); i++) {
-        if (opStmt->getOperand(i) == oldMemRef)
-          break;
-      }
-      assert(i < opStmt->getNumOperands() && "operand guaranteed to be found");
-      return i;
-    };
-    unsigned memRefOperandPos = getMemRefOperandPos();
-
-    // Construct the new operation statement using this memref.
-    SmallVector<MLValue *, 8> operands;
-    operands.reserve(opStmt->getNumOperands() + extraIndices.size());
-    // Insert the non-memref operands.
-    operands.insert(operands.end(), opStmt->operand_begin(),
-                    opStmt->operand_begin() + memRefOperandPos);
-    operands.push_back(newMemRef);
-
-    MLFuncBuilder builder(opStmt);
-    for (auto *extraIndex : extraIndices) {
-      // TODO(mlir-team): An operation/SSA value should provide a method to
-      // return the position of an SSA result in its defining
-      // operation.
-      assert(extraIndex->getDefiningStmt()->getNumResults() == 1 &&
-             "single result op's expected to generate these indices");
-      assert((cast<MLValue>(extraIndex)->isValidDim() ||
-              cast<MLValue>(extraIndex)->isValidSymbol()) &&
-             "invalid memory op index");
-      operands.push_back(cast<MLValue>(extraIndex));
-    }
-
-    // Construct new indices. The indices of a memref come right after it, i.e.,
-    // at position memRefOperandPos + 1.
-    SmallVector<SSAValue *, 4> indices(
-        opStmt->operand_begin() + memRefOperandPos + 1,
-        opStmt->operand_begin() + memRefOperandPos + 1 + oldMemRefRank);
-    if (indexRemap) {
-      auto remapOp =
-          builder.create<AffineApplyOp>(opStmt->getLoc(), indexRemap, indices);
-      // Remapped indices.
-      for (auto *index : remapOp->getOperation()->getResults())
-        operands.push_back(cast<MLValue>(index));
-    } else {
-      // No remapping specified.
-      for (auto *index : indices)
-        operands.push_back(cast<MLValue>(index));
-    }
-
-    // Insert the remaining operands unmodified.
-    operands.insert(operands.end(),
-                    opStmt->operand_begin() + memRefOperandPos + 1 +
-                        oldMemRefRank,
-                    opStmt->operand_end());
-
-    // Result types don't change. Both memref's are of the same elemental type.
-    SmallVector<Type *, 8> resultTypes;
-    resultTypes.reserve(opStmt->getNumResults());
-    for (const auto *result : opStmt->getResults())
-      resultTypes.push_back(result->getType());
-
-    // Create the new operation.
-    auto *repOp =
-        builder.createOperation(opStmt->getLoc(), opStmt->getName(), operands,
-                                resultTypes, opStmt->getAttrs());
-    // Replace old memref's deferencing op's uses.
-    unsigned r = 0;
-    for (auto *res : opStmt->getResults()) {
-      res->replaceAllUsesWith(repOp->getResult(r++));
-    }
-    opStmt->erase();
-  }
-  return true;
-}
-
-// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
-// its results equal to the number of 'operands, as a composition
-// of all other AffineApplyOps reachable from input parameter 'operands'. If the
-// operands were drawing results from multiple affine apply ops, this also leads
-// to a collapse into a single affine apply op. The final results of the
-// composed AffineApplyOp are returned in output parameter 'results'.
-OperationStmt *
-mlir::createComposedAffineApplyOp(MLFuncBuilder *builder, Location *loc,
-                                  ArrayRef<MLValue *> operands,
-                                  ArrayRef<OperationStmt *> affineApplyOps,
-                                  SmallVectorImpl<SSAValue *> &results) {
-  // Create identity map with same number of dimensions as number of operands.
-  auto map = builder->getMultiDimIdentityMap(operands.size());
-  // Initialize AffineValueMap with identity map.
-  AffineValueMap valueMap(map, operands);
-
-  for (auto *opStmt : affineApplyOps) {
-    assert(opStmt->isa<AffineApplyOp>());
-    auto affineApplyOp = opStmt->cast<AffineApplyOp>();
-    // Forward substitute 'affineApplyOp' into 'valueMap'.
-    valueMap.forwardSubstitute(*affineApplyOp);
-  }
-  // Compose affine maps from all ancestor AffineApplyOps.
-  // Create new AffineApplyOp from 'valueMap'.
-  unsigned numOperands = valueMap.getNumOperands();
-  SmallVector<SSAValue *, 4> outOperands(numOperands);
-  for (unsigned i = 0; i < numOperands; ++i) {
-    outOperands[i] = valueMap.getOperand(i);
-  }
-  // Create new AffineApplyOp based on 'valueMap'.
-  auto affineApplyOp =
-      builder->create<AffineApplyOp>(loc, valueMap.getAffineMap(), outOperands);
-  results.resize(operands.size());
-  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
-    results[i] = affineApplyOp->getResult(i);
-  }
-  return cast<OperationStmt>(affineApplyOp->getOperation());
-}
-
-/// Given an operation statement, inserts a new single affine apply operation,
-/// that is exclusively used by this operation statement, and that provides all
-/// operands that are results of an affine_apply as a function of loop iterators
-/// and program parameters and whose results are.
-///
-/// Before
-///
-/// for %i = 0 to #map(%N)
-///   %idx = affine_apply (d0) -> (d0 mod 2) (%i)
-///   "send"(%idx, %A, ...)
-///   "compute"(%idx)
-///
-/// After
-///
-/// for %i = 0 to #map(%N)
-///   %idx = affine_apply (d0) -> (d0 mod 2) (%i)
-///   "send"(%idx, %A, ...)
-///   %idx_ = affine_apply (d0) -> (d0 mod 2) (%i)
-///   "compute"(%idx_)
-///
-/// This allows applying different transformations on send and compute (for eg.
-/// different shifts/delays).
-///
-/// Returns nullptr either if none of opStmt's operands were the result of an
-/// affine_apply and thus there was no affine computation slice to create, or if
-/// all the affine_apply op's supplying operands to this opStmt do not have any
-/// uses besides this opStmt. Returns the new affine_apply operation statement
-/// otherwise.
-OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
-  // Collect all operands that are results of affine apply ops.
-  SmallVector<MLValue *, 4> subOperands;
-  subOperands.reserve(opStmt->getNumOperands());
-  for (auto *operand : opStmt->getOperands()) {
-    auto *defStmt = operand->getDefiningStmt();
-    if (defStmt && defStmt->isa<AffineApplyOp>()) {
-      subOperands.push_back(operand);
-    }
-  }
-
-  // Gather sequence of AffineApplyOps reachable from 'subOperands'.
-  SmallVector<OperationStmt *, 4> affineApplyOps;
-  getReachableAffineApplyOps(subOperands, affineApplyOps);
-  // Skip transforming if there are no affine maps to compose.
-  if (affineApplyOps.empty())
-    return nullptr;
-
-  // Check if all uses of the affine apply op's lie in this op stmt
-  // itself, in which case there would be nothing to do.
-  bool localized = true;
-  for (auto *op : affineApplyOps) {
-    for (auto *result : op->getResults()) {
-      for (auto &use : result->getUses()) {
-        if (use.getOwner() != opStmt) {
-          localized = false;
-          break;
-        }
-      }
-    }
-  }
-  if (localized)
-    return nullptr;
-
-  MLFuncBuilder builder(opStmt);
-  SmallVector<SSAValue *, 4> results;
-  auto *affineApplyStmt = createComposedAffineApplyOp(
-      &builder, opStmt->getLoc(), subOperands, affineApplyOps, results);
-  assert(results.size() == subOperands.size() &&
-         "number of results should be the same as the number of subOperands");
-
-  // Construct the new operands that include the results from the composed
-  // affine apply op above instead of existing ones (subOperands). So, they
-  // differ from opStmt's operands only for those operands in 'subOperands', for
-  // which they will be replaced by the corresponding one from 'results'.
-  SmallVector<MLValue *, 4> newOperands(opStmt->getOperands());
-  for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
-    // Replace the subOperands from among the new operands.
-    unsigned j, f;
-    for (j = 0, f = subOperands.size(); j < f; j++) {
-      if (newOperands[i] == subOperands[j])
-        break;
-    }
-    if (j < subOperands.size()) {
-      newOperands[i] = cast<MLValue>(results[j]);
-    }
-  }
-
-  for (unsigned idx = 0; idx < newOperands.size(); idx++) {
-    opStmt->setOperand(idx, newOperands[idx]);
-  }
-
-  return affineApplyStmt;
-}
-
-void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
-  if (affineApplyOp->getOperation()->getOperationFunction()->getKind() !=
-      Function::Kind::MLFunc) {
-    // TODO: Support forward substitution for CFGFunctions.
-    return;
-  }
-  auto *opStmt = cast<OperationStmt>(affineApplyOp->getOperation());
-  // Iterate through all uses of all results of 'opStmt', forward substituting
-  // into any uses which are AffineApplyOps.
-  for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e;
-       ++resultIndex) {
-    const MLValue *result = opStmt->getResult(resultIndex);
-    for (auto it = result->use_begin(); it != result->use_end();) {
-      StmtOperand &use = *(it++);
-      auto *useStmt = use.getOwner();
-      auto *useOpStmt = dyn_cast<OperationStmt>(useStmt);
-      // Skip if use is not AffineApplyOp.
-      if (useOpStmt == nullptr || !useOpStmt->isa<AffineApplyOp>())
-        continue;
-      // Advance iterator past 'opStmt' operands which also use 'result'.
-      while (it != result->use_end() && it->getOwner() == useStmt)
-        ++it;
-
-      MLFuncBuilder builder(useOpStmt);
-      // Initialize AffineValueMap with 'affineApplyOp' which uses 'result'.
-      auto oldAffineApplyOp = useOpStmt->cast<AffineApplyOp>();
-      AffineValueMap valueMap(*oldAffineApplyOp);
-      // Forward substitute 'result' at index 'i' into 'valueMap'.
-      valueMap.forwardSubstituteSingle(*affineApplyOp, resultIndex);
-
-      // Create new AffineApplyOp from 'valueMap'.
-      unsigned numOperands = valueMap.getNumOperands();
-      SmallVector<SSAValue *, 4> operands(numOperands);
-      for (unsigned i = 0; i < numOperands; ++i) {
-        operands[i] = valueMap.getOperand(i);
-      }
-      auto newAffineApplyOp = builder.create<AffineApplyOp>(
-          useOpStmt->getLoc(), valueMap.getAffineMap(), operands);
-
-      // Update all uses to use results from 'newAffineApplyOp'.
-      for (unsigned i = 0, e = useOpStmt->getNumResults(); i < e; ++i) {
-        oldAffineApplyOp->getResult(i)->replaceAllUsesWith(
-            newAffineApplyOp->getResult(i));
-      }
-      // Erase 'oldAffineApplyOp'.
-      oldAffineApplyOp->getOperation()->erase();
-    }
-  }
-}
-
-/// Folds the specified (lower or upper) bound to a constant if possible
-/// considering its operands. Returns false if the folding happens for any of
-/// the bounds, true otherwise.
-bool mlir::constantFoldBounds(ForStmt *forStmt) {
-  auto foldLowerOrUpperBound = [forStmt](bool lower) {
-    // Check if the bound is already a constant.
-    if (lower && forStmt->hasConstantLowerBound())
-      return true;
-    if (!lower && forStmt->hasConstantUpperBound())
-      return true;
-
-    // Check to see if each of the operands is the result of a constant.  If so,
-    // get the value.  If not, ignore it.
-    SmallVector<Attribute *, 8> operandConstants;
-    auto boundOperands = lower ? forStmt->getLowerBoundOperands()
-                               : forStmt->getUpperBoundOperands();
-    for (const auto *operand : boundOperands) {
-      Attribute *operandCst = nullptr;
-      if (auto *operandOp = operand->getDefiningOperation()) {
-        if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
-          operandCst = operandConstantOp->getValue();
-      }
-      operandConstants.push_back(operandCst);
-    }
-
-    AffineMap boundMap =
-        lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap();
-    assert(boundMap.getNumResults() >= 1 &&
-           "bound maps should have at least one result");
-    SmallVector<Attribute *, 4> foldedResults;
-    if (boundMap.constantFold(operandConstants, foldedResults))
-      return true;
-
-    // Compute the max or min as applicable over the results.
-    assert(!foldedResults.empty() && "bounds should have at least one result");
-    auto maxOrMin = cast<IntegerAttr>(foldedResults[0])->getValue();
-    for (unsigned i = 1; i < foldedResults.size(); i++) {
-      auto foldedResult = cast<IntegerAttr>(foldedResults[i])->getValue();
-      maxOrMin = lower ? std::max(maxOrMin, foldedResult)
-                       : std::min(maxOrMin, foldedResult);
-    }
-    lower ? forStmt->setConstantLowerBound(maxOrMin)
-          : forStmt->setConstantUpperBound(maxOrMin);
-
-    // Return false on success.
-    return false;
-  };
-
-  bool ret = foldLowerOrUpperBound(/*lower=*/true);
-  ret &= foldLowerOrUpperBound(/*lower=*/false);
-  return ret;
-}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
new file mode 100644 (file)
index 0000000..5ed8eac
--- /dev/null
@@ -0,0 +1,343 @@
+//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements mlir::applyPatternsGreedily.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/StandardOps/StandardOps.h"
+#include "mlir/Transforms/PatternMatch.h"
+#include "llvm/ADT/DenseMap.h"
+using namespace mlir;
+
+namespace {
+class WorklistRewriter;
+
+/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
+/// applies the locally optimal patterns in a roughly "bottom up" way.
+class GreedyPatternRewriteDriver {
+public:
+  explicit GreedyPatternRewriteDriver(OwningPatternList &&patterns)
+      : matcher(std::move(patterns)) {
+    worklist.reserve(64);
+  }
+
+  void simplifyFunction(Function *currentFunction, WorklistRewriter &rewriter);
+
+  void addToWorklist(Operation *op) {
+    worklistMap[op] = worklist.size();
+    worklist.push_back(op);
+  }
+
+  Operation *popFromWorklist() {
+    auto *op = worklist.back();
+    worklist.pop_back();
+
+    // This operation is no longer in the worklist, keep worklistMap up to date.
+    if (op)
+      worklistMap.erase(op);
+    return op;
+  }
+
+  /// If the specified operation is in the worklist, remove it.  If not, this is
+  /// a no-op.
+  void removeFromWorklist(Operation *op) {
+    auto it = worklistMap.find(op);
+    if (it != worklistMap.end()) {
+      assert(worklist[it->second] == op && "malformed worklist data structure");
+      worklist[it->second] = nullptr;
+    }
+  }
+
+private:
+  /// The low-level pattern matcher.
+  PatternMatcher matcher;
+
+  /// The worklist for this transformation keeps track of the operations that
+  /// need to be revisited, plus their index in the worklist.  This allows us to
+  /// efficiently remove operations from the worklist when they are removed even
+  /// if they aren't the root of a pattern.
+  std::vector<Operation *> worklist;
+  DenseMap<Operation *, unsigned> worklistMap;
+
+  /// As part of canonicalization, we move constants to the top of the entry
+  /// block of the current function and de-duplicate them.  This keeps track of
+  /// constants we have done this for.
+  DenseMap<std::pair<Attribute *, Type *>, Operation *> uniquedConstants;
+};
+}; // end anonymous namespace
+
+/// This is a listener object that updates our worklists and other data
+/// structures in response to operations being added and removed.
+namespace {
+class WorklistRewriter : public PatternRewriter {
+public:
+  WorklistRewriter(GreedyPatternRewriteDriver &driver, MLIRContext *context)
+      : PatternRewriter(context), driver(driver) {}
+
+  virtual void setInsertionPoint(Operation *op) = 0;
+
+  // If an operation is about to be removed, make sure it is not in our
+  // worklist anymore because we'd get dangling references to it.
+  void notifyOperationRemoved(Operation *op) override {
+    driver.removeFromWorklist(op);
+  }
+
+  GreedyPatternRewriteDriver &driver;
+};
+
+} // end anonymous namespace
+
+void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
+                                                  WorklistRewriter &rewriter) {
+  // These are scratch vectors used in the constant folding loop below.
+  SmallVector<Attribute *, 8> operandConstants, resultConstants;
+
+  while (!worklist.empty()) {
+    auto *op = popFromWorklist();
+
+    // Nulls get added to the worklist when operations are removed, ignore them.
+    if (op == nullptr)
+      continue;
+
+    // If we have a constant op, unique it into the entry block.
+    if (auto constant = op->dyn_cast<ConstantOp>()) {
+      // If this constant is dead, remove it, being careful to keep
+      // uniquedConstants up to date.
+      if (constant->use_empty()) {
+        auto it =
+            uniquedConstants.find({constant->getValue(), constant->getType()});
+        if (it != uniquedConstants.end() && it->second == op)
+          uniquedConstants.erase(it);
+        constant->erase();
+        continue;
+      }
+
+      // Check to see if we already have a constant with this type and value:
+      auto &entry = uniquedConstants[std::make_pair(constant->getValue(),
+                                                    constant->getType())];
+      if (entry) {
+        // If this constant is already our uniqued one, then leave it alone.
+        if (entry == op)
+          continue;
+
+        // Otherwise replace this redundant constant with the uniqued one.  We
+        // know this is safe because we move constants to the top of the
+        // function when they are uniqued, so we know they dominate all uses.
+        constant->replaceAllUsesWith(entry->getResult(0));
+        constant->erase();
+        continue;
+      }
+
+      // If we have no entry, then we should unique this constant as the
+      // canonical version.  To ensure safe dominance, move the operation to the
+      // top of the function.
+      entry = op;
+
+      // TODO: If we make terminators into Operations then we could turn this
+      // into a nice Operation::moveBefore(Operation*) method.  We just need the
+      // guarantee that a block is non-empty.
+      if (auto *cfgFunc = dyn_cast<CFGFunction>(currentFunction)) {
+        auto &entryBB = cfgFunc->front();
+        cast<OperationInst>(op)->moveBefore(&entryBB, entryBB.begin());
+      } else {
+        auto *mlFunc = cast<MLFunction>(currentFunction);
+        cast<OperationStmt>(op)->moveBefore(mlFunc, mlFunc->begin());
+      }
+
+      continue;
+    }
+
+    // If the operation has no side effects, and no users, then it is trivially
+    // dead - remove it.
+    if (op->hasNoSideEffect() && op->use_empty()) {
+      op->erase();
+      continue;
+    }
+
+    // Check to see if any operands to the instruction is constant and whether
+    // the operation knows how to constant fold itself.
+    operandConstants.clear();
+    for (auto *operand : op->getOperands()) {
+      Attribute *operandCst = nullptr;
+      if (auto *operandOp = operand->getDefiningOperation()) {
+        if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
+          operandCst = operandConstantOp->getValue();
+      }
+      operandConstants.push_back(operandCst);
+    }
+
+    // If constant folding was successful, create the result constants, RAUW the
+    // operation and remove it.
+    resultConstants.clear();
+    if (!op->constantFold(operandConstants, resultConstants)) {
+      rewriter.setInsertionPoint(op);
+
+      for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
+        auto *res = op->getResult(i);
+        if (res->use_empty()) // ignore dead uses.
+          continue;
+
+        // If we already have a canonicalized version of this constant, just
+        // reuse it.  Otherwise create a new one.
+        SSAValue *cstValue;
+        auto it = uniquedConstants.find({resultConstants[i], res->getType()});
+        if (it != uniquedConstants.end())
+          cstValue = it->second->getResult(0);
+        else
+          cstValue = rewriter.create<ConstantOp>(
+              op->getLoc(), resultConstants[i], res->getType());
+        res->replaceAllUsesWith(cstValue);
+      }
+
+      assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
+      op->erase();
+      continue;
+    }
+
+    // If this is an associative binary operation with a constant on the LHS,
+    // move it to the right side.
+    if (operandConstants.size() == 2 && operandConstants[0] &&
+        !operandConstants[1]) {
+      auto *newLHS = op->getOperand(1);
+      op->setOperand(1, op->getOperand(0));
+      op->setOperand(0, newLHS);
+    }
+
+    // Check to see if we have any patterns that match this node.
+    auto match = matcher.findMatch(op);
+    if (!match.first)
+      continue;
+
+    // Make sure that any new operations are inserted at this point.
+    rewriter.setInsertionPoint(op);
+    match.first->rewrite(op, std::move(match.second), rewriter);
+  }
+
+  uniquedConstants.clear();
+}
+
+static void processMLFunction(MLFunction *fn, OwningPatternList &&patterns) {
+  class MLFuncRewriter : public WorklistRewriter {
+  public:
+    MLFuncRewriter(GreedyPatternRewriteDriver &driver, MLFuncBuilder &builder)
+        : WorklistRewriter(driver, builder.getContext()), builder(builder) {}
+
+    // Implement the hook for creating operations, and make sure that newly
+    // created ops are added to the worklist for processing.
+    Operation *createOperation(const OperationState &state) override {
+      auto *result = builder.createOperation(state);
+      driver.addToWorklist(result);
+      return result;
+    }
+
+    // When the root of a pattern is about to be replaced, it can trigger
+    // simplifications to its users - make sure to add them to the worklist
+    // before the root is changed.
+    void notifyRootReplaced(Operation *op) override {
+      auto *opStmt = cast<OperationStmt>(op);
+      for (auto *result : opStmt->getResults())
+        // TODO: Add a result->getUsers() iterator.
+        for (auto &user : result->getUses()) {
+          if (auto *op = dyn_cast<OperationStmt>(user.getOwner()))
+            driver.addToWorklist(op);
+        }
+
+      // TODO: Walk the operand list dropping them as we go.  If any of them
+      // drop to zero uses, then add them to the worklist to allow them to be
+      // deleted as dead.
+    }
+
+    void setInsertionPoint(Operation *op) override {
+      // Any new operations should be added before this statement.
+      builder.setInsertionPoint(cast<OperationStmt>(op));
+    }
+
+  private:
+    MLFuncBuilder &builder;
+  };
+
+  GreedyPatternRewriteDriver driver(std::move(patterns));
+  fn->walk([&](OperationStmt *stmt) { driver.addToWorklist(stmt); });
+
+  MLFuncBuilder mlBuilder(fn);
+  MLFuncRewriter rewriter(driver, mlBuilder);
+  driver.simplifyFunction(fn, rewriter);
+}
+
+static void processCFGFunction(CFGFunction *fn, OwningPatternList &&patterns) {
+  class CFGFuncRewriter : public WorklistRewriter {
+  public:
+    CFGFuncRewriter(GreedyPatternRewriteDriver &driver, CFGFuncBuilder &builder)
+        : WorklistRewriter(driver, builder.getContext()), builder(builder) {}
+
+    // Implement the hook for creating operations, and make sure that newly
+    // created ops are added to the worklist for processing.
+    Operation *createOperation(const OperationState &state) override {
+      auto *result = builder.createOperation(state);
+      driver.addToWorklist(result);
+      return result;
+    }
+
+    // When the root of a pattern is about to be replaced, it can trigger
+    // simplifications to its users - make sure to add them to the worklist
+    // before the root is changed.
+    void notifyRootReplaced(Operation *op) override {
+      auto *opStmt = cast<OperationInst>(op);
+      for (auto *result : opStmt->getResults())
+        // TODO: Add a result->getUsers() iterator.
+        for (auto &user : result->getUses()) {
+          if (auto *op = dyn_cast<OperationInst>(user.getOwner()))
+            driver.addToWorklist(op);
+        }
+
+      // TODO: Walk the operand list dropping them as we go.  If any of them
+      // drop to zero uses, then add them to the worklist to allow them to be
+      // deleted as dead.
+    }
+
+    void setInsertionPoint(Operation *op) override {
+      // Any new operations should be added before this instruction.
+      builder.setInsertionPoint(cast<OperationInst>(op));
+    }
+
+  private:
+    CFGFuncBuilder &builder;
+  };
+
+  GreedyPatternRewriteDriver driver(std::move(patterns));
+  for (auto &bb : *fn)
+    for (auto &op : bb)
+      driver.addToWorklist(&op);
+
+  CFGFuncBuilder cfgBuilder(fn);
+  CFGFuncRewriter rewriter(driver, cfgBuilder);
+  driver.simplifyFunction(fn, rewriter);
+}
+
+/// Rewrite the specified function by repeatedly applying the highest benefit
+/// patterns in a greedy work-list driven manner.
+///
+void mlir::applyPatternsGreedily(Function *fn, OwningPatternList &&patterns) {
+  if (auto *cfg = dyn_cast<CFGFunction>(fn)) {
+    processCFGFunction(cfg, std::move(patterns));
+  } else {
+    processMLFunction(cast<MLFunction>(fn), std::move(patterns));
+  }
+}
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
new file mode 100644 (file)
index 0000000..a6a8502
--- /dev/null
@@ -0,0 +1,316 @@
+//===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements miscellaneous loop transformation routines.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/LoopUtils.h"
+
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Statements.h"
+#include "mlir/IR/StmtVisitor.h"
+#include "mlir/StandardOps/StandardOps.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "LoopUtils"
+
+using namespace mlir;
+
+/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with
+/// the specified trip count, stride, and unroll factor. Returns nullptr when
+/// the trip count can't be expressed as an affine expression.
+AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
+                                          unsigned unrollFactor,
+                                          MLFuncBuilder *builder) {
+  auto lbMap = forStmt.getLowerBoundMap();
+
+  // Single result lower bound map only.
+  if (lbMap.getNumResults() != 1)
+    return AffineMap::Null();
+
+  // Sometimes, the trip count cannot be expressed as an affine expression.
+  auto tripCount = getTripCountExpr(forStmt);
+  if (!tripCount)
+    return AffineMap::Null();
+
+  AffineExpr lb(lbMap.getResult(0));
+  unsigned step = forStmt.getStep();
+  auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step;
+
+  return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
+                               {newUb}, {});
+}
+
+/// Returns the lower bound of the cleanup loop when unrolling a loop with lower
+/// bound 'lb' and with the specified trip count, stride, and unroll factor.
+/// Returns an AffinMap with nullptr storage (that evaluates to false)
+/// when the trip count can't be expressed as an affine expression.
+AffineMap mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
+                                         unsigned unrollFactor,
+                                         MLFuncBuilder *builder) {
+  auto lbMap = forStmt.getLowerBoundMap();
+
+  // Single result lower bound map only.
+  if (lbMap.getNumResults() != 1)
+    return AffineMap::Null();
+
+  // Sometimes the trip count cannot be expressed as an affine expression.
+  AffineExpr tripCount(getTripCountExpr(forStmt));
+  if (!tripCount)
+    return AffineMap::Null();
+
+  AffineExpr lb(lbMap.getResult(0));
+  unsigned step = forStmt.getStep();
+  auto newLb = lb + (tripCount - tripCount % unrollFactor) * step;
+  return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
+                               {newLb}, {});
+}
+
+/// Promotes the loop body of a forStmt to its containing block if the forStmt
+/// was known to have a single iteration. Returns false otherwise.
+// TODO(bondhugula): extend this for arbitrary affine bounds.
+bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
+  Optional<uint64_t> tripCount = getConstantTripCount(*forStmt);
+  if (!tripCount.hasValue() || tripCount.getValue() != 1)
+    return false;
+
+  // TODO(mlir-team): there is no builder for a max.
+  if (forStmt->getLowerBoundMap().getNumResults() != 1)
+    return false;
+
+  // Replaces all IV uses to its single iteration value.
+  if (!forStmt->use_empty()) {
+    if (forStmt->hasConstantLowerBound()) {
+      auto *mlFunc = forStmt->findFunction();
+      MLFuncBuilder topBuilder(&mlFunc->front());
+      auto constOp = topBuilder.create<ConstantIndexOp>(
+          forStmt->getLoc(), forStmt->getConstantLowerBound());
+      forStmt->replaceAllUsesWith(constOp);
+    } else {
+      const AffineBound lb = forStmt->getLowerBound();
+      SmallVector<SSAValue *, 4> lbOperands(lb.operand_begin(),
+                                            lb.operand_end());
+      MLFuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt));
+      auto affineApplyOp = builder.create<AffineApplyOp>(
+          forStmt->getLoc(), lb.getMap(), lbOperands);
+      forStmt->replaceAllUsesWith(affineApplyOp->getResult(0));
+    }
+  }
+  // Move the loop body statements to the loop's containing block.
+  auto *block = forStmt->getBlock();
+  block->getStatements().splice(StmtBlock::iterator(forStmt),
+                                forStmt->getStatements());
+  forStmt->erase();
+  return true;
+}
+
+/// Promotes all single iteration for stmt's in the MLFunction, i.e., moves
+/// their body into the containing StmtBlock.
+void mlir::promoteSingleIterationLoops(MLFunction *f) {
+  // Gathers all innermost loops through a post order pruned walk.
+  class LoopBodyPromoter : public StmtWalker<LoopBodyPromoter> {
+  public:
+    void visitForStmt(ForStmt *forStmt) { promoteIfSingleIteration(forStmt); }
+  };
+
+  LoopBodyPromoter fsw;
+  fsw.walkPostOrder(f);
+}
+
+/// Generates a for 'stmt' with the specified lower and upper bounds while
+/// generating the right IV remappings for the delayed statements. The
+/// statement blocks that go into the loop are specified in stmtGroupQueue
+/// starting from the specified offset, and in that order; the first element of
+/// the pair specifies the delay applied to that group of statements. Returns
+/// nullptr if the generated loop simplifies to a single iteration one.
+static ForStmt *
+generateLoop(AffineMap lb, AffineMap ub,
+             const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>>
+                 &stmtGroupQueue,
+             unsigned offset, ForStmt *srcForStmt, MLFuncBuilder *b) {
+  SmallVector<MLValue *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
+  SmallVector<MLValue *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
+
+  auto *loopChunk =
+      b->createFor(srcForStmt->getLoc(), lbOperands, lb, ubOperands, ub);
+  OperationStmt::OperandMapTy operandMap;
+
+  for (auto it = stmtGroupQueue.begin() + offset, e = stmtGroupQueue.end();
+       it != e; ++it) {
+    auto elt = *it;
+    // All 'same delay' statements get added with the operands being remapped
+    // (to results of cloned statements).
+    // Generate the remapping if the delay is not zero: oldIV = newIV - delay.
+    // TODO(bondhugula): check if srcForStmt is actually used in elt.second
+    // instead of just checking if it's used at all.
+    if (!srcForStmt->use_empty() && elt.first != 0) {
+      auto b = MLFuncBuilder::getForStmtBodyBuilder(loopChunk);
+      auto *oldIV =
+          b.create<AffineApplyOp>(
+               srcForStmt->getLoc(),
+               b.getSingleDimShiftAffineMap(-static_cast<int64_t>(elt.first)),
+               loopChunk)
+              ->getResult(0);
+      operandMap[srcForStmt] = cast<MLValue>(oldIV);
+    } else {
+      operandMap[srcForStmt] = static_cast<MLValue *>(loopChunk);
+    }
+    for (auto *stmt : elt.second) {
+      loopChunk->push_back(stmt->clone(operandMap, b->getContext()));
+    }
+  }
+  if (promoteIfSingleIteration(loopChunk))
+    return nullptr;
+  return loopChunk;
+}
+
+/// Skew the statements in the body of a 'for' statement with the specified
+/// statement-wise delays. The delays are with respect to the original execution
+/// order. A delay of zero for each statement will lead to no change.
+// The skewing of statements with respect to one another can be used for example
+// to allow overlap of asynchronous operations (such as DMA communication) with
+// computation, or just relative shifting of statements for better register
+// reuse, locality or parallelism. As such, the delays are typically expected to
+// be at most of the order of the number of statements. This method should not
+// be used as a substitute for loop distribution/fission.
+// This method uses an algorithm// in time linear in the number of statements in
+// the body of the for loop - (using the 'sweep line' paradigm). This method
+// asserts preservation of SSA dominance. A check for that as well as that for
+// memory-based depedence preservation check rests with the users of this
+// method.
+UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
+                              bool unrollPrologueEpilogue) {
+  if (forStmt->getStatements().empty())
+    return UtilResult::Success;
+
+  // If the trip counts aren't constant, we would need versioning and
+  // conditional guards (or context information to prevent such versioning). The
+  // better way to pipeline for such loops is to first tile them and extract
+  // constant trip count "full tiles" before applying this.
+  auto mayBeConstTripCount = getConstantTripCount(*forStmt);
+  if (!mayBeConstTripCount.hasValue()) {
+    LLVM_DEBUG(llvm::dbgs() << "non-constant trip count loop\n";);
+    return UtilResult::Success;
+  }
+  uint64_t tripCount = mayBeConstTripCount.getValue();
+
+  assert(isStmtwiseShiftValid(*forStmt, delays) &&
+         "shifts will lead to an invalid transformation\n");
+
+  unsigned numChildStmts = forStmt->getStatements().size();
+
+  // Do a linear time (counting) sort for the delays.
+  uint64_t maxDelay = 0;
+  for (unsigned i = 0; i < numChildStmts; i++) {
+    maxDelay = std::max(maxDelay, delays[i]);
+  }
+  // Such large delays are not the typical use case.
+  if (maxDelay >= numChildStmts) {
+    LLVM_DEBUG(llvm::dbgs() << "stmt delays too large - unexpected\n";);
+    return UtilResult::Success;
+  }
+
+  // An array of statement groups sorted by delay amount; each group has all
+  // statements with the same delay in the order in which they appear in the
+  // body of the 'for' stmt.
+  std::vector<std::vector<Statement *>> sortedStmtGroups(maxDelay + 1);
+  unsigned pos = 0;
+  for (auto &stmt : *forStmt) {
+    auto delay = delays[pos++];
+    sortedStmtGroups[delay].push_back(&stmt);
+  }
+
+  // Unless the shifts have a specific pattern (which actually would be the
+  // common use case), prologue and epilogue are not meaningfully defined.
+  // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
+  // loop generated as the prologue and the last as epilogue and unroll these
+  // fully.
+  ForStmt *prologue = nullptr;
+  ForStmt *epilogue = nullptr;
+
+  // Do a sweep over the sorted delays while storing open groups in a
+  // vector, and generating loop portions as necessary during the sweep. A block
+  // of statements is paired with its delay.
+  std::vector<std::pair<uint64_t, ArrayRef<Statement *>>> stmtGroupQueue;
+
+  auto origLbMap = forStmt->getLowerBoundMap();
+  uint64_t lbDelay = 0;
+  MLFuncBuilder b(forStmt);
+  for (uint64_t d = 0, e = sortedStmtGroups.size(); d < e; ++d) {
+    // If nothing is delayed by d, continue.
+    if (sortedStmtGroups[d].empty())
+      continue;
+    if (!stmtGroupQueue.empty()) {
+      assert(d >= 1 &&
+             "Queue expected to be empty when the first block is found");
+      // The interval for which the loop needs to be generated here is:
+      // ( lbDelay, min(lbDelay + tripCount - 1, d - 1) ] and the body of the
+      // loop needs to have all statements in stmtQueue in that order.
+      ForStmt *res;
+      if (lbDelay + tripCount - 1 < d - 1) {
+        res = generateLoop(
+            b.getShiftedAffineMap(origLbMap, lbDelay),
+            b.getShiftedAffineMap(origLbMap, lbDelay + tripCount - 1),
+            stmtGroupQueue, 0, forStmt, &b);
+        // Entire loop for the queued stmt groups generated, empty it.
+        stmtGroupQueue.clear();
+        lbDelay += tripCount;
+      } else {
+        res = generateLoop(b.getShiftedAffineMap(origLbMap, lbDelay),
+                           b.getShiftedAffineMap(origLbMap, d - 1),
+                           stmtGroupQueue, 0, forStmt, &b);
+        lbDelay = d;
+      }
+      if (!prologue && res)
+        prologue = res;
+      epilogue = res;
+    } else {
+      // Start of first interval.
+      lbDelay = d;
+    }
+    // Augment the list of statements that get into the current open interval.
+    stmtGroupQueue.push_back({d, sortedStmtGroups[d]});
+  }
+
+  // Those statements groups left in the queue now need to be processed (FIFO)
+  // and their loops completed.
+  for (unsigned i = 0, e = stmtGroupQueue.size(); i < e; ++i) {
+    uint64_t ubDelay = stmtGroupQueue[i].first + tripCount - 1;
+    epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbDelay),
+                            b.getShiftedAffineMap(origLbMap, ubDelay),
+                            stmtGroupQueue, i, forStmt, &b);
+    lbDelay = ubDelay + 1;
+    if (!prologue)
+      prologue = epilogue;
+  }
+
+  // Erase the original for stmt.
+  forStmt->erase();
+
+  if (unrollPrologueEpilogue && prologue)
+    loopUnrollFull(prologue);
+  if (unrollPrologueEpilogue && !epilogue && epilogue != prologue)
+    loopUnrollFull(epilogue);
+
+  return UtilResult::Success;
+}
diff --git a/mlir/lib/Transforms/Utils/Pass.cpp b/mlir/lib/Transforms/Utils/Pass.cpp
new file mode 100644 (file)
index 0000000..8b11107
--- /dev/null
@@ -0,0 +1,41 @@
+//===- Pass.cpp - Pass infrastructure implementation ----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements common pass infrastructure.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Pass.h"
+#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Module.h"
+
+using namespace mlir;
+
+/// Function passes walk a module and look at each function with their
+/// corresponding hooks and terminates upon error encountered.
+PassResult FunctionPass::runOnModule(Module *m) {
+  for (auto &fn : *m) {
+    if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
+      if (runOnMLFunction(mlFunc))
+        return failure();
+    if (auto *cfgFunc = dyn_cast<CFGFunction>(&fn))
+      if (runOnCFGFunction(cfgFunc))
+        return failure();
+  }
+  return success();
+}
diff --git a/mlir/lib/Transforms/Utils/PatternMatch.cpp b/mlir/lib/Transforms/Utils/PatternMatch.cpp
new file mode 100644 (file)
index 0000000..6cc3b43
--- /dev/null
@@ -0,0 +1,196 @@
+//===- PatternMatch.cpp - Base classes for pattern match ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/SSAValue.h"
+#include "mlir/IR/Statements.h"
+#include "mlir/StandardOps/StandardOps.h"
+#include "mlir/Transforms/PatternMatch.h"
+using namespace mlir;
+
+PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
+  assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
+         "This pattern match benefit is too large to represent");
+}
+
+unsigned short PatternBenefit::getBenefit() const {
+  assert(representation != ImpossibleToMatchSentinel &&
+         "Pattern doesn't match");
+  return representation;
+}
+
+bool PatternBenefit::operator==(const PatternBenefit& other) {
+  if (isImpossibleToMatch())
+    return other.isImpossibleToMatch();
+  if (other.isImpossibleToMatch())
+    return false;
+  return getBenefit() == other.getBenefit();
+}
+
+bool PatternBenefit::operator!=(const PatternBenefit& other) {
+  return !(*this == other);
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern implementation
+//===----------------------------------------------------------------------===//
+
+Pattern::Pattern(StringRef rootName, MLIRContext *context,
+                 Optional<PatternBenefit> staticBenefit)
+    : rootKind(OperationName(rootName, context)), staticBenefit(staticBenefit) {
+}
+
+Pattern::Pattern(StringRef rootName, MLIRContext *context,
+                 unsigned staticBenefit)
+    : rootKind(rootName, context), staticBenefit(staticBenefit) {}
+
+Optional<PatternBenefit> Pattern::getStaticBenefit() const {
+  return staticBenefit;
+}
+
+OperationName Pattern::getRootKind() const { return rootKind; }
+
+void Pattern::rewrite(Operation *op, std::unique_ptr<PatternState> state,
+                      PatternRewriter &rewriter) const {
+  rewrite(op, rewriter);
+}
+
+void Pattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
+  llvm_unreachable("need to implement one of the rewrite functions!");
+}
+
+/// This method indicates that no match was found.
+PatternMatchResult Pattern::matchFailure() {
+  return {PatternBenefit::impossibleToMatch(), std::unique_ptr<PatternState>()};
+}
+
+/// This method indicates that a match was found and has the specified cost.
+PatternMatchResult
+Pattern::matchSuccess(PatternBenefit benefit,
+                      std::unique_ptr<PatternState> state) const {
+  assert((!getStaticBenefit().hasValue() ||
+          getStaticBenefit().getValue() == benefit) &&
+         "This version of matchSuccess must be called with a benefit that "
+         "matches the static benefit if set!");
+
+  return {benefit, std::move(state)};
+}
+
+/// This method indicates that a match was found for patterns that have a
+/// known static benefit.
+PatternMatchResult
+Pattern::matchSuccess(std::unique_ptr<PatternState> state) const {
+  auto benefit = getStaticBenefit();
+  assert(benefit.hasValue() && "Pattern doesn't have a static benefit");
+  return matchSuccess(benefit.getValue(), std::move(state));
+}
+
+//===----------------------------------------------------------------------===//
+// PatternRewriter implementation
+//===----------------------------------------------------------------------===//
+
+PatternRewriter::~PatternRewriter() {
+  // Out of line to provide a vtable anchor for the class.
+}
+
+/// This method is used as the final replacement hook for patterns that match
+/// a single result value.  In addition to replacing and removing the
+/// specified operation, clients can specify a list of other nodes that this
+/// replacement may make (perhaps transitively) dead.  If any of those ops are
+/// dead, this will remove them as well.
+void PatternRewriter::replaceSingleResultOp(
+    Operation *op, SSAValue *newValue, ArrayRef<SSAValue *> opsToRemoveIfDead) {
+  // Notify the rewriter subclass that we're about to replace this root.
+  notifyRootReplaced(op);
+
+  assert(op->getNumResults() == 1 && "op isn't a SingleResultOp!");
+  op->getResult(0)->replaceAllUsesWith(newValue);
+
+  notifyOperationRemoved(op);
+  op->erase();
+
+  // TODO: Process the opsToRemoveIfDead list, removing things and calling the
+  // notifyOperationRemoved hook in the process.
+}
+
+/// This method is used as the final notification hook for patterns that end
+/// up modifying the pattern root in place, by changing its operands.  This is
+/// a minor efficiency win (it avoids creating a new instruction and removing
+/// the old one) but also often allows simpler code in the client.
+///
+/// The opsToRemoveIfDead list is an optional list of nodes that the rewriter
+/// should remove if they are dead at this point.
+///
+void PatternRewriter::updatedRootInPlace(
+    Operation *op, ArrayRef<SSAValue *> opsToRemoveIfDead) {
+  // Notify the rewriter subclass that we're about to replace this root.
+  notifyRootUpdated(op);
+
+  // TODO: Process the opsToRemoveIfDead list, removing things and calling the
+  // notifyOperationRemoved hook in the process.
+}
+
+//===----------------------------------------------------------------------===//
+// PatternMatcher implementation
+//===----------------------------------------------------------------------===//
+
+/// Find the highest benefit pattern available in the pattern set for the DAG
+/// rooted at the specified node.  This returns the pattern if found, or null
+/// if there are no matches.
+auto PatternMatcher::findMatch(Operation *op) -> MatchResult {
+  // TODO: This is a completely trivial implementation, expand this in the
+  // future.
+
+  // Keep track of the best match, the benefit of it, and any matcher specific
+  // state it is maintaining.
+  MatchResult bestMatch = {nullptr, nullptr};
+  Optional<PatternBenefit> bestBenefit;
+
+  for (auto &pattern : patterns) {
+    // Ignore patterns that are for the wrong root.
+    if (pattern->getRootKind() != op->getName())
+      continue;
+
+    // If we know the static cost of the pattern is worse than what we've
+    // already found then don't run it.
+    auto staticBenefit = pattern->getStaticBenefit();
+    if (staticBenefit.hasValue() && bestBenefit.hasValue() &&
+        staticBenefit.getValue().getBenefit() <
+            bestBenefit.getValue().getBenefit())
+      continue;
+
+    // Check to see if this pattern matches this node.
+    auto result = pattern->match(op);
+    auto benefit = result.first;
+
+    // If this pattern failed to match, ignore it.
+    if (benefit.isImpossibleToMatch())
+      continue;
+
+    // If it matched but had lower benefit than our best match so far, then
+    // ignore it.
+    if (bestBenefit.hasValue() &&
+        benefit.getBenefit() < bestBenefit.getValue().getBenefit())
+      continue;
+
+    // Okay we found a match that is better than our previous one, remember it.
+    bestBenefit = benefit;
+    bestMatch = {pattern.get(), std::move(result.second)};
+  }
+
+  // If we found any match, return it.
+  return bestMatch;
+}
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
new file mode 100644 (file)
index 0000000..4432a80
--- /dev/null
@@ -0,0 +1,394 @@
+//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements miscellaneous transformation routines for non-loop IR
+// structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Utils.h"
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/StandardOps/StandardOps.h"
+#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/DenseMap.h"
+
+using namespace mlir;
+
+/// Return true if this operation dereferences one or more memref's.
+// Temporary utility: will be replaced when this is modeled through
+// side-effects/op traits. TODO(b/117228571)
+static bool isMemRefDereferencingOp(const Operation &op) {
+  if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
+      op.isa<DmaWaitOp>())
+    return true;
+  return false;
+}
+
+/// Replaces all uses of oldMemRef with newMemRef while optionally remapping
+/// old memref's indices to the new memref using the supplied affine map
+/// and adding any additional indices. The new memref could be of a different
+/// shape or rank, but of the same elemental type. Additional indices are added
+/// at the start for now.
+// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
+// extended to add additional indices at any position.
+bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
+                                    MLValue *newMemRef,
+                                    ArrayRef<MLValue *> extraIndices,
+                                    AffineMap indexRemap) {
+  unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
+  (void)newMemRefRank; // unused in opt mode
+  unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
+  (void)newMemRefRank;
+  if (indexRemap) {
+    assert(indexRemap.getNumInputs() == oldMemRefRank);
+    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
+  } else {
+    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
+  }
+
+  // Assert same elemental type.
+  assert(cast<MemRefType>(oldMemRef->getType())->getElementType() ==
+         cast<MemRefType>(newMemRef->getType())->getElementType());
+
+  // Check if memref was used in a non-deferencing context.
+  for (const StmtOperand &use : oldMemRef->getUses()) {
+    auto *opStmt = cast<OperationStmt>(use.getOwner());
+    // Failure: memref used in a non-deferencing op (potentially escapes); no
+    // replacement in these cases.
+    if (!isMemRefDereferencingOp(*opStmt))
+      return false;
+  }
+
+  // Walk all uses of old memref. Statement using the memref gets replaced.
+  for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
+    StmtOperand &use = *(it++);
+    auto *opStmt = cast<OperationStmt>(use.getOwner());
+    assert(isMemRefDereferencingOp(*opStmt) &&
+           "memref deferencing op expected");
+
+    auto getMemRefOperandPos = [&]() -> unsigned {
+      unsigned i;
+      for (i = 0; i < opStmt->getNumOperands(); i++) {
+        if (opStmt->getOperand(i) == oldMemRef)
+          break;
+      }
+      assert(i < opStmt->getNumOperands() && "operand guaranteed to be found");
+      return i;
+    };
+    unsigned memRefOperandPos = getMemRefOperandPos();
+
+    // Construct the new operation statement using this memref.
+    SmallVector<MLValue *, 8> operands;
+    operands.reserve(opStmt->getNumOperands() + extraIndices.size());
+    // Insert the non-memref operands.
+    operands.insert(operands.end(), opStmt->operand_begin(),
+                    opStmt->operand_begin() + memRefOperandPos);
+    operands.push_back(newMemRef);
+
+    MLFuncBuilder builder(opStmt);
+    for (auto *extraIndex : extraIndices) {
+      // TODO(mlir-team): An operation/SSA value should provide a method to
+      // return the position of an SSA result in its defining
+      // operation.
+      assert(extraIndex->getDefiningStmt()->getNumResults() == 1 &&
+             "single result op's expected to generate these indices");
+      assert((cast<MLValue>(extraIndex)->isValidDim() ||
+              cast<MLValue>(extraIndex)->isValidSymbol()) &&
+             "invalid memory op index");
+      operands.push_back(cast<MLValue>(extraIndex));
+    }
+
+    // Construct new indices. The indices of a memref come right after it, i.e.,
+    // at position memRefOperandPos + 1.
+    SmallVector<SSAValue *, 4> indices(
+        opStmt->operand_begin() + memRefOperandPos + 1,
+        opStmt->operand_begin() + memRefOperandPos + 1 + oldMemRefRank);
+    if (indexRemap) {
+      auto remapOp =
+          builder.create<AffineApplyOp>(opStmt->getLoc(), indexRemap, indices);
+      // Remapped indices.
+      for (auto *index : remapOp->getOperation()->getResults())
+        operands.push_back(cast<MLValue>(index));
+    } else {
+      // No remapping specified.
+      for (auto *index : indices)
+        operands.push_back(cast<MLValue>(index));
+    }
+
+    // Insert the remaining operands unmodified.
+    operands.insert(operands.end(),
+                    opStmt->operand_begin() + memRefOperandPos + 1 +
+                        oldMemRefRank,
+                    opStmt->operand_end());
+
+    // Result types don't change. Both memref's are of the same elemental type.
+    SmallVector<Type *, 8> resultTypes;
+    resultTypes.reserve(opStmt->getNumResults());
+    for (const auto *result : opStmt->getResults())
+      resultTypes.push_back(result->getType());
+
+    // Create the new operation.
+    auto *repOp =
+        builder.createOperation(opStmt->getLoc(), opStmt->getName(), operands,
+                                resultTypes, opStmt->getAttrs());
+    // Replace old memref's deferencing op's uses.
+    unsigned r = 0;
+    for (auto *res : opStmt->getResults()) {
+      res->replaceAllUsesWith(repOp->getResult(r++));
+    }
+    opStmt->erase();
+  }
+  return true;
+}
+
+// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
+// its results equal to the number of 'operands, as a composition
+// of all other AffineApplyOps reachable from input parameter 'operands'. If the
+// operands were drawing results from multiple affine apply ops, this also leads
+// to a collapse into a single affine apply op. The final results of the
+// composed AffineApplyOp are returned in output parameter 'results'.
+OperationStmt *
+mlir::createComposedAffineApplyOp(MLFuncBuilder *builder, Location *loc,
+                                  ArrayRef<MLValue *> operands,
+                                  ArrayRef<OperationStmt *> affineApplyOps,
+                                  SmallVectorImpl<SSAValue *> &results) {
+  // Create identity map with same number of dimensions as number of operands.
+  auto map = builder->getMultiDimIdentityMap(operands.size());
+  // Initialize AffineValueMap with identity map.
+  AffineValueMap valueMap(map, operands);
+
+  for (auto *opStmt : affineApplyOps) {
+    assert(opStmt->isa<AffineApplyOp>());
+    auto affineApplyOp = opStmt->cast<AffineApplyOp>();
+    // Forward substitute 'affineApplyOp' into 'valueMap'.
+    valueMap.forwardSubstitute(*affineApplyOp);
+  }
+  // Compose affine maps from all ancestor AffineApplyOps.
+  // Create new AffineApplyOp from 'valueMap'.
+  unsigned numOperands = valueMap.getNumOperands();
+  SmallVector<SSAValue *, 4> outOperands(numOperands);
+  for (unsigned i = 0; i < numOperands; ++i) {
+    outOperands[i] = valueMap.getOperand(i);
+  }
+  // Create new AffineApplyOp based on 'valueMap'.
+  auto affineApplyOp =
+      builder->create<AffineApplyOp>(loc, valueMap.getAffineMap(), outOperands);
+  results.resize(operands.size());
+  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+    results[i] = affineApplyOp->getResult(i);
+  }
+  return cast<OperationStmt>(affineApplyOp->getOperation());
+}
+
+/// Given an operation statement, inserts a new single affine apply operation,
+/// that is exclusively used by this operation statement, and that provides all
+/// operands that are results of an affine_apply as a function of loop iterators
+/// and program parameters and whose results are.
+///
+/// Before
+///
+/// for %i = 0 to #map(%N)
+///   %idx = affine_apply (d0) -> (d0 mod 2) (%i)
+///   "send"(%idx, %A, ...)
+///   "compute"(%idx)
+///
+/// After
+///
+/// for %i = 0 to #map(%N)
+///   %idx = affine_apply (d0) -> (d0 mod 2) (%i)
+///   "send"(%idx, %A, ...)
+///   %idx_ = affine_apply (d0) -> (d0 mod 2) (%i)
+///   "compute"(%idx_)
+///
+/// This allows applying different transformations on send and compute (for eg.
+/// different shifts/delays).
+///
+/// Returns nullptr either if none of opStmt's operands were the result of an
+/// affine_apply and thus there was no affine computation slice to create, or if
+/// all the affine_apply op's supplying operands to this opStmt do not have any
+/// uses besides this opStmt. Returns the new affine_apply operation statement
+/// otherwise.
+OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
+  // Collect all operands that are results of affine apply ops.
+  SmallVector<MLValue *, 4> subOperands;
+  subOperands.reserve(opStmt->getNumOperands());
+  for (auto *operand : opStmt->getOperands()) {
+    auto *defStmt = operand->getDefiningStmt();
+    if (defStmt && defStmt->isa<AffineApplyOp>()) {
+      subOperands.push_back(operand);
+    }
+  }
+
+  // Gather sequence of AffineApplyOps reachable from 'subOperands'.
+  SmallVector<OperationStmt *, 4> affineApplyOps;
+  getReachableAffineApplyOps(subOperands, affineApplyOps);
+  // Skip transforming if there are no affine maps to compose.
+  if (affineApplyOps.empty())
+    return nullptr;
+
+  // Check if all uses of the affine apply op's lie in this op stmt
+  // itself, in which case there would be nothing to do.
+  bool localized = true;
+  for (auto *op : affineApplyOps) {
+    for (auto *result : op->getResults()) {
+      for (auto &use : result->getUses()) {
+        if (use.getOwner() != opStmt) {
+          localized = false;
+          break;
+        }
+      }
+    }
+  }
+  if (localized)
+    return nullptr;
+
+  MLFuncBuilder builder(opStmt);
+  SmallVector<SSAValue *, 4> results;
+  auto *affineApplyStmt = createComposedAffineApplyOp(
+      &builder, opStmt->getLoc(), subOperands, affineApplyOps, results);
+  assert(results.size() == subOperands.size() &&
+         "number of results should be the same as the number of subOperands");
+
+  // Construct the new operands that include the results from the composed
+  // affine apply op above instead of existing ones (subOperands). So, they
+  // differ from opStmt's operands only for those operands in 'subOperands', for
+  // which they will be replaced by the corresponding one from 'results'.
+  SmallVector<MLValue *, 4> newOperands(opStmt->getOperands());
+  for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
+    // Replace the subOperands from among the new operands.
+    unsigned j, f;
+    for (j = 0, f = subOperands.size(); j < f; j++) {
+      if (newOperands[i] == subOperands[j])
+        break;
+    }
+    if (j < subOperands.size()) {
+      newOperands[i] = cast<MLValue>(results[j]);
+    }
+  }
+
+  for (unsigned idx = 0; idx < newOperands.size(); idx++) {
+    opStmt->setOperand(idx, newOperands[idx]);
+  }
+
+  return affineApplyStmt;
+}
+
+void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
+  if (affineApplyOp->getOperation()->getOperationFunction()->getKind() !=
+      Function::Kind::MLFunc) {
+    // TODO: Support forward substitution for CFGFunctions.
+    return;
+  }
+  auto *opStmt = cast<OperationStmt>(affineApplyOp->getOperation());
+  // Iterate through all uses of all results of 'opStmt', forward substituting
+  // into any uses which are AffineApplyOps.
+  for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e;
+       ++resultIndex) {
+    const MLValue *result = opStmt->getResult(resultIndex);
+    for (auto it = result->use_begin(); it != result->use_end();) {
+      StmtOperand &use = *(it++);
+      auto *useStmt = use.getOwner();
+      auto *useOpStmt = dyn_cast<OperationStmt>(useStmt);
+      // Skip if use is not AffineApplyOp.
+      if (useOpStmt == nullptr || !useOpStmt->isa<AffineApplyOp>())
+        continue;
+      // Advance iterator past 'opStmt' operands which also use 'result'.
+      while (it != result->use_end() && it->getOwner() == useStmt)
+        ++it;
+
+      MLFuncBuilder builder(useOpStmt);
+      // Initialize AffineValueMap with 'affineApplyOp' which uses 'result'.
+      auto oldAffineApplyOp = useOpStmt->cast<AffineApplyOp>();
+      AffineValueMap valueMap(*oldAffineApplyOp);
+      // Forward substitute 'result' at index 'i' into 'valueMap'.
+      valueMap.forwardSubstituteSingle(*affineApplyOp, resultIndex);
+
+      // Create new AffineApplyOp from 'valueMap'.
+      unsigned numOperands = valueMap.getNumOperands();
+      SmallVector<SSAValue *, 4> operands(numOperands);
+      for (unsigned i = 0; i < numOperands; ++i) {
+        operands[i] = valueMap.getOperand(i);
+      }
+      auto newAffineApplyOp = builder.create<AffineApplyOp>(
+          useOpStmt->getLoc(), valueMap.getAffineMap(), operands);
+
+      // Update all uses to use results from 'newAffineApplyOp'.
+      for (unsigned i = 0, e = useOpStmt->getNumResults(); i < e; ++i) {
+        oldAffineApplyOp->getResult(i)->replaceAllUsesWith(
+            newAffineApplyOp->getResult(i));
+      }
+      // Erase 'oldAffineApplyOp'.
+      oldAffineApplyOp->getOperation()->erase();
+    }
+  }
+}
+
+/// Folds the specified (lower or upper) bound to a constant if possible
+/// considering its operands. Returns false if the folding happens for any of
+/// the bounds, true otherwise.
+bool mlir::constantFoldBounds(ForStmt *forStmt) {
+  auto foldLowerOrUpperBound = [forStmt](bool lower) {
+    // Check if the bound is already a constant.
+    if (lower && forStmt->hasConstantLowerBound())
+      return true;
+    if (!lower && forStmt->hasConstantUpperBound())
+      return true;
+
+    // Check to see if each of the operands is the result of a constant.  If so,
+    // get the value.  If not, ignore it.
+    SmallVector<Attribute *, 8> operandConstants;
+    auto boundOperands = lower ? forStmt->getLowerBoundOperands()
+                               : forStmt->getUpperBoundOperands();
+    for (const auto *operand : boundOperands) {
+      Attribute *operandCst = nullptr;
+      if (auto *operandOp = operand->getDefiningOperation()) {
+        if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
+          operandCst = operandConstantOp->getValue();
+      }
+      operandConstants.push_back(operandCst);
+    }
+
+    AffineMap boundMap =
+        lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap();
+    assert(boundMap.getNumResults() >= 1 &&
+           "bound maps should have at least one result");
+    SmallVector<Attribute *, 4> foldedResults;
+    if (boundMap.constantFold(operandConstants, foldedResults))
+      return true;
+
+    // Compute the max or min as applicable over the results.
+    assert(!foldedResults.empty() && "bounds should have at least one result");
+    auto maxOrMin = cast<IntegerAttr>(foldedResults[0])->getValue();
+    for (unsigned i = 1; i < foldedResults.size(); i++) {
+      auto foldedResult = cast<IntegerAttr>(foldedResults[i])->getValue();
+      maxOrMin = lower ? std::max(maxOrMin, foldedResult)
+                       : std::min(maxOrMin, foldedResult);
+    }
+    lower ? forStmt->setConstantLowerBound(maxOrMin)
+          : forStmt->setConstantUpperBound(maxOrMin);
+
+    // Return false on success.
+    return false;
+  };
+
+  bool ret = foldLowerOrUpperBound(/*lower=*/true);
+  ret &= foldLowerOrUpperBound(/*lower=*/false);
+  return ret;
+}