Extend loop unroll/unroll-and-jam to affine bounds + refactor related code.
authorUday Bondhugula <bondhugula@google.com>
Tue, 18 Sep 2018 17:22:03 +0000 (10:22 -0700)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 20:15:06 +0000 (13:15 -0700)
- extend loop unroll-jam similar to loop unroll for affine bounds
- extend both loop unroll/unroll-jam to deal with cleanup loop for non multiple
  of unroll factor.
- extend promotion of single iteration loops to work with affine bounds
- fix typo bugs in loop unroll
- refactor common code b/w loop unroll and loop unroll-jam
- move prototypes of non-pass transforms to LoopUtils.h
- add additional builder methods.
- introduce loopUnrollUpTo(factor) to unroll by either factor or trip count,
  whichever is less.
- remove Statement::isInnermost (not used for now - will come back at the right
  place/in right form later)

PiperOrigin-RevId: 213471227

14 files changed:
mlir/include/mlir/Analysis/LoopAnalysis.h
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/Statement.h
mlir/include/mlir/IR/Statements.h
mlir/include/mlir/Transforms/LoopUtils.h [new file with mode: 0644]
mlir/include/mlir/Transforms/Passes.h
mlir/lib/Analysis/LoopAnalysis.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/Statement.cpp
mlir/lib/Transforms/LoopUnroll.cpp
mlir/lib/Transforms/LoopUnrollAndJam.cpp
mlir/lib/Transforms/LoopUtils.cpp
mlir/test/Transforms/unroll-jam.mlir
mlir/test/Transforms/unroll.mlir

index 482a74c3de0bfb9f82af4f4023ee895154078243..6bd55bb2c92ba25506ef5b80b803aea747765c8a 100644 (file)
@@ -32,7 +32,7 @@ class ForStmt;
 /// Returns the trip count of the loop as an affine expression if the latter is
 /// expressible as an affine expression, and nullptr otherwise. The trip count
 /// expression is simplified before returning.
-AffineExpr *getTripCount(const ForStmt &forStmt);
+AffineExpr *getTripCountExpr(const ForStmt &forStmt);
 
 /// Returns the trip count of the loop if it's a constant, None otherwise. This
 /// uses affine expression analysis and is able to determine constant trip count
index dc602563f642406adb12fa6700ab9826eda19bfa..4fb61573b3096e0506986314fcf53c78867c2117 100644 (file)
@@ -101,7 +101,9 @@ public:
   AffineSymbolExpr *getSymbolExpr(unsigned position);
   AffineConstantExpr *getConstantExpr(int64_t constant);
   AffineExpr *getAddExpr(AffineExpr *lhs, AffineExpr *rhs);
+  AffineExpr *getAddExpr(AffineExpr *lhs, int64_t rhs);
   AffineExpr *getSubExpr(AffineExpr *lhs, AffineExpr *rhs);
+  AffineExpr *getSubExpr(AffineExpr *lhs, int64_t rhs);
   AffineExpr *getMulExpr(AffineExpr *lhs, AffineExpr *rhs);
   AffineExpr *getMulExpr(AffineExpr *lhs, int64_t rhs);
   AffineExpr *getModExpr(AffineExpr *lhs, AffineExpr *rhs);
index 3b2bbeb3d079236b84d5915458736e584ae4ee9f..2af989f6c90ece11834f0106a82513658f5d8d08 100644 (file)
@@ -82,9 +82,6 @@ public:
   /// Returns nullptr if the statement is unlinked.
   MLFunction *findFunction() const;
 
-  /// Returns true if there are no more loops nested under this stmt.
-  bool isInnermost() const;
-
   /// Destroys this statement and its subclass data.
   void destroy();
 
index c49ca2678d0ff9c66f8781f6951b20c8493503b3..5eb262ce0ba14403b9cdc60df235b16213acb59c 100644 (file)
@@ -275,6 +275,10 @@ public:
   /// Sets the upper bound to the given constant value.
   void setConstantUpperBound(int64_t value);
 
+  /// Returns true if both the lower and upper bound have the same operand lists
+  /// (same operands in the same order).
+  bool matchingBoundOperandList() const;
+
   //===--------------------------------------------------------------------===//
   // Operands
   //===--------------------------------------------------------------------===//
@@ -343,7 +347,9 @@ private:
   AffineMap *ubMap;
   // Constant step.
   int64_t step;
-  // Operands for the lower and upper bounds.
+  // Operands for the lower and upper bounds, with the former followed by the
+  // latter. Dimensional operands are followed by symbolic operands for each
+  // bound.
   std::vector<StmtOperand> operands;
 
   explicit ForStmt(Location *location, unsigned numOperands, AffineMap *lbMap,
diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h
new file mode 100644 (file)
index 0000000..82f3f53
--- /dev/null
@@ -0,0 +1,77 @@
+//===- LoopUtils.h - Loop transformation utilities --------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes for various loop transformation utility
+// methods: these are not passes by themselves but are used either by passes,
+// optimization sequences, or in turn by other transformation utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_LOOP_UTILS_H
+#define MLIR_TRANSFORMS_LOOP_UTILS_H
+
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+class AffineMap;
+class ForStmt;
+class MLFunction;
+class MLFuncBuilder;
+
+/// Unrolls this for statement completely if the trip count is known to be
+/// constant. Returns false otherwise.
+bool loopUnrollFull(ForStmt *forStmt);
+/// Unrolls this for statement by the specified unroll factor. Returns false if
+/// the loop cannot be unrolled either due to restrictions or due to invalid
+/// unroll factors.
+bool loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor);
+/// Unrolls this loop by the specified unroll factor or its trip count,
+/// whichever is lower.
+bool loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor);
+
+/// Unrolls and jams this loop by the specified factor. Returns true if the loop
+/// is successfully unroll-jammed.
+bool loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
+
+/// Unrolls and jams this loop by the specified factor or by the trip count (if
+/// constant), whichever is lower.
+bool loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
+
+/// Promotes the loop body of a ForStmt to its containing block if the ForStmt
+/// was known to have a single iteration. Returns false otherwise.
+bool promoteIfSingleIteration(ForStmt *forStmt);
+
+/// Promotes all single iteration ForStmt's in the MLFunction, i.e., moves
+/// their body into the containing StmtBlock.
+void promoteSingleIterationLoops(MLFunction *f);
+
+/// Returns the lower bound of the cleanup loop when unrolling a loop
+/// with the specified unroll factor.
+AffineMap *getCleanupLoopLowerBound(const ForStmt &forStmt,
+                                    unsigned unrollFactor,
+                                    MLFuncBuilder *builder);
+
+/// Returns the upper bound of an unrolled loop when unrolling with
+/// the specified trip count, stride, and unroll factor.
+AffineMap *getUnrolledLoopUpperBound(const ForStmt &forStmt,
+                                     unsigned unrollFactor,
+                                     MLFuncBuilder *builder);
+
+} // end namespace mlir
+
+#endif // MLIR_TRANSFORMS_LOOP_UTILS_H
index e79c6f70085f7045cd1bfa3fae24276dbc1bbdf2..68ad0b109fbdbf85c3c1b00a8adfa1a9d3834045 100644 (file)
@@ -27,9 +27,7 @@
 
 namespace mlir {
 
-class ForStmt;
 class FunctionPass;
-class MLFunction;
 class MLFunctionPass;
 class ModulePass;
 
@@ -38,19 +36,11 @@ class ModulePass;
 MLFunctionPass *createLoopUnrollPass(int unrollFactor = -1,
                                      int unrollFull = -1);
 
-/// Unrolls this loop completely.
-bool loopUnrollFull(ForStmt *forStmt);
-/// Unrolls this loop by the specified unroll factor.
-bool loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor);
-
 /// Creates a loop unroll jam pass to unroll jam by the specified factor. A
 /// factor of -1 lets the pass use the default factor or the one on the command
 /// line if provided.
 MLFunctionPass *createLoopUnrollAndJamPass(int unrollJamFactor = -1);
 
-/// Unrolls and jams this loop by the specified factor.
-bool loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
-
 /// Creates an affine expression simplification pass.
 FunctionPass *createSimplifyAffineExprPass();
 
@@ -59,14 +49,6 @@ FunctionPass *createSimplifyAffineExprPass();
 /// generated CFG functions.
 ModulePass *createConvertToCFGPass();
 
-/// Promotes the loop body of a ForStmt to its containing block if the ForStmt
-/// was known to have a single iteration. Returns false otherwise.
-bool promoteIfSingleIteration(ForStmt *forStmt);
-
-/// Promotes all single iteration ForStmt's in the MLFunction, i.e., moves
-/// their body into the containing StmtBlock.
-void promoteSingleIterationLoops(MLFunction *f);
-
 } // end namespace mlir
 
-#endif // MLIR_TRANSFORMS_LOOP_H
+#endif // MLIR_TRANSFORMS_PASSES_H
index dc56d62b107f4bac1d95deb3ab89e5699f2f78d2..fa283ae65fc514c705995cf55b56c0a792a85589 100644 (file)
@@ -31,7 +31,7 @@ using mlir::AffineExpr;
 /// Returns the trip count of the loop as an affine expression if the latter is
 /// expressible as an affine expression, and nullptr otherwise. The trip count
 /// expression is simplified before returning.
-AffineExpr *mlir::getTripCount(const ForStmt &forStmt) {
+AffineExpr *mlir::getTripCountExpr(const ForStmt &forStmt) {
   // upper_bound - lower_bound + 1
   int64_t loopSpan;
 
@@ -43,32 +43,22 @@ AffineExpr *mlir::getTripCount(const ForStmt &forStmt) {
     int64_t ub = forStmt.getConstantUpperBound();
     loopSpan = ub - lb + 1;
   } else {
-    const AffineBound lb = forStmt.getLowerBound();
-    const AffineBound ub = forStmt.getUpperBound();
-    auto lbMap = lb.getMap();
-    auto ubMap = ub.getMap();
+    auto *lbMap = forStmt.getLowerBoundMap();
+    auto *ubMap = forStmt.getUpperBoundMap();
     // TODO(bondhugula): handle max/min of multiple expressions.
-    if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1 ||
-        lbMap->getNumDims() != ubMap->getNumDims() ||
-        lbMap->getNumSymbols() != ubMap->getNumSymbols()) {
+    if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
       return nullptr;
-    }
 
     // TODO(bondhugula): handle bounds with different operands.
-    unsigned i, e = lb.getNumOperands();
-    for (i = 0; i < e; i++) {
-      if (lb.getStmtOperand(i).get() != ub.getStmtOperand(i).get())
-        break;
-    }
     // Bounds have different operands, unhandled for now.
-    if (i != e)
+    if (!forStmt.matchingBoundOperandList())
       return nullptr;
 
     // ub_expr - lb_expr + 1
+    auto *lbExpr = lbMap->getResult(0);
+    auto *ubExpr = ubMap->getResult(0);
     auto *loopSpanExpr = AffineBinaryOpExpr::getAdd(
-        AffineBinaryOpExpr::getSub(ubMap->getResult(0), lbMap->getResult(0),
-                                   context),
-        1, context);
+        AffineBinaryOpExpr::getSub(ubExpr, lbExpr, context), 1, context);
 
     if (auto *expr = simplifyAffineExpr(loopSpanExpr, lbMap->getNumDims(),
                                         lbMap->getNumSymbols(), context))
@@ -95,7 +85,7 @@ AffineExpr *mlir::getTripCount(const ForStmt &forStmt) {
 /// method uses affine expression analysis (in turn using getTripCount) and is
 /// able to determine constant trip count in non-trivial cases.
 llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
-  AffineExpr *tripCountExpr = getTripCount(forStmt);
+  AffineExpr *tripCountExpr = getTripCountExpr(forStmt);
 
   if (auto *constExpr = dyn_cast_or_null<AffineConstantExpr>(tripCountExpr))
     return constExpr->getValue();
@@ -107,7 +97,7 @@ llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
 /// expression analysis is used (indirectly through getTripCount), and
 /// this method is thus able to determine non-trivial divisors.
 uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
-  AffineExpr *tripCountExpr = getTripCount(forStmt);
+  AffineExpr *tripCountExpr = getTripCountExpr(forStmt);
 
   if (!tripCountExpr)
     return 1;
index b18a843e707bcbb0a4d8dc087f134c127b336c85..5ded7e8a6522891bfd7a1d33eb267b09a78dfb5f 100644 (file)
@@ -157,6 +157,10 @@ AffineExpr *Builder::getAddExpr(AffineExpr *lhs, AffineExpr *rhs) {
   return AffineBinaryOpExpr::get(AffineExpr::Kind::Add, lhs, rhs, context);
 }
 
+AffineExpr *Builder::getAddExpr(AffineExpr *lhs, int64_t rhs) {
+  return AffineBinaryOpExpr::getAdd(lhs, rhs, context);
+}
+
 AffineExpr *Builder::getMulExpr(AffineExpr *lhs, AffineExpr *rhs) {
   return AffineBinaryOpExpr::get(AffineExpr::Kind::Mul, lhs, rhs, context);
 }
@@ -171,6 +175,10 @@ AffineExpr *Builder::getSubExpr(AffineExpr *lhs, AffineExpr *rhs) {
   return getAddExpr(lhs, getMulExpr(rhs, getConstantExpr(-1)));
 }
 
+AffineExpr *Builder::getSubExpr(AffineExpr *lhs, int64_t rhs) {
+  return AffineBinaryOpExpr::getAdd(lhs, -rhs, context);
+}
+
 AffineExpr *Builder::getModExpr(AffineExpr *lhs, AffineExpr *rhs) {
   return AffineBinaryOpExpr::get(AffineExpr::Kind::Mod, lhs, rhs, context);
 }
index e6681c5e6ff189aab99b78a2c46da730af30265e..c4eb5b82c5dd0a914436845fab94c2970bf43cfe 100644 (file)
@@ -84,18 +84,6 @@ MLFunction *Statement::findFunction() const {
   return block ? block->findFunction() : nullptr;
 }
 
-bool Statement::isInnermost() const {
-  struct NestedLoopCounter : public StmtWalker<NestedLoopCounter> {
-    unsigned numNestedLoops;
-    NestedLoopCounter() : numNestedLoops(0) {}
-    void walkForStmt(const ForStmt *fs) { numNestedLoops++; }
-  };
-
-  NestedLoopCounter nlc;
-  nlc.walk(const_cast<Statement *>(this));
-  return nlc.numNestedLoops == 1;
-}
-
 MLValue *Statement::getOperand(unsigned idx) {
   return getStmtOperand(idx).get();
 }
@@ -361,6 +349,20 @@ void ForStmt::setConstantUpperBound(int64_t value) {
   setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
 }
 
+bool ForStmt::matchingBoundOperandList() const {
+  if (lbMap->getNumDims() != ubMap->getNumDims() ||
+      lbMap->getNumSymbols() != ubMap->getNumSymbols())
+    return false;
+
+  unsigned numOperands = lbMap->getNumInputs();
+  for (unsigned i = 0, e = lbMap->getNumInputs(); i < e; i++) {
+    // Compare MLValue *'s.
+    if (getOperand(i) != getOperand(numOperands + i))
+      return false;
+  }
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // IfStmt
 //===----------------------------------------------------------------------===//
index bc467bc9100777570b58508f88160ead29c0d9e1..cfbf5059659435fcefd05464d0f4fc8fbb89efa2 100644 (file)
@@ -27,6 +27,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/StandardOps.h"
 #include "mlir/IR/StmtVisitor.h"
+#include "mlir/Transforms/LoopUtils.h"
 #include "mlir/Transforms/Pass.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/CommandLine.h"
@@ -176,76 +177,41 @@ bool mlir::loopUnrollFull(ForStmt *forStmt) {
   return false;
 }
 
-/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with
-/// the specified trip count, stride, and unroll factor.
-static AffineMap *getUnrolledLoopUpperBound(AffineMap *lbMap,
-                                            uint64_t tripCount,
-                                            unsigned unrollFactor, int64_t step,
-                                            MLFuncBuilder *builder) {
-  assert(lbMap->getNumResults() == 1);
-  auto *lbExpr = lbMap->getResult(0);
-  // lbExpr + (count - count % unrollFactor - 1) * step).
-  auto *expr = builder->getAddExpr(
-      lbExpr, builder->getConstantExpr(
-                  (tripCount - tripCount % unrollFactor - 1) * step));
-  return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
-                               {expr}, {});
-}
+/// Unrolls and jams this loop by the specified factor or by the trip count (if
+/// constant) whichever is lower.
+bool mlir::loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor) {
+  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
 
-/// Returns the lower bound of the cleanup loop when unrolling a loop with lower
-/// bound 'lb' and with the specified trip count, stride, and unroll factor.
-static AffineMap *getCleanupLoopLowerBound(AffineMap *lbMap, uint64_t tripCount,
-                                           unsigned unrollFactor, int64_t step,
-                                           MLFuncBuilder *builder) {
-  assert(lbMap->getNumResults() == 1);
-  auto *lbExpr = lbMap->getResult(0);
-  // lbExpr + (count - count % unrollFactor) * step);
-  auto *expr = builder->getAddExpr(
-      lbExpr,
-      builder->getConstantExpr((tripCount - tripCount % unrollFactor) * step));
-  return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
-                               {expr}, {});
+  if (mayBeConstantTripCount.hasValue() &&
+      mayBeConstantTripCount.getValue() < unrollFactor)
+    return loopUnrollByFactor(forStmt, mayBeConstantTripCount.getValue());
+  return loopUnrollByFactor(forStmt, unrollFactor);
 }
 
-/// Unrolls this loop by the specified unroll factor.
+/// Unrolls this loop by the specified factor. Returns true if the loop
+/// is successfully unrolled.
 bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
-  assert(unrollFactor >= 1 && "unroll factor shoud be >= 1");
+  assert(unrollFactor >= 1 && "unroll factor should be >= 1");
 
   if (unrollFactor == 1 || forStmt->getStatements().empty())
     return false;
 
-  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
-
-  if (!mayBeConstantTripCount.hasValue() &&
-      getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0)
-    return false;
-
-  const AffineBound &lb = forStmt->getLowerBound();
-  const AffineBound &ub = forStmt->getLowerBound();
-  auto lbMap = lb.getMap();
-  auto ubMap = lb.getMap();
+  auto *lbMap = forStmt->getLowerBoundMap();
+  auto *ubMap = forStmt->getUpperBoundMap();
 
   // Loops with max/min expressions won't be unrolled here (the output can't be
   // expressed as an MLFunction in the general case). However, the right way to
   // do such unrolling for an MLFunction would be to specialize the loop for the
-  // 'hotspot' case and unroll that hotspot case.
+  // 'hotspot' case and unroll that hotspot.
   if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
     return false;
 
-  // TODO(bondhugula): handle bounds with different sets of operands.
-  // Same operand list for now.
-  if (lbMap->getNumDims() != ubMap->getNumDims() ||
-      lbMap->getNumSymbols() != ubMap->getNumSymbols())
-    return false;
-  unsigned i, e = lb.getNumOperands();
-  for (i = 0; i < e; i++) {
-    if (lb.getStmtOperand(i).get() != ub.getStmtOperand(i).get())
-      break;
-  }
-  if (i != e)
+  // Same operand list for lower and upper bound for now.
+  // TODO(bondhugula): handle bounds with different operand lists.
+  if (!forStmt->matchingBoundOperandList())
     return false;
 
-  int64_t step = forStmt->getStep();
+  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
 
   // If the trip count is lower than the unroll factor, no unrolled body.
   // TODO(bondhugula): option to specify cleanup loop unrolling.
@@ -254,43 +220,29 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
     return false;
 
   // Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
-  // If the trip count is unknown, we currently unroll only when the unknown
-  // trip count is known to be a multiple of unroll factor - hence, no cleanup
-  // loop will be necessary in those cases.
-  // TODO(bondhugula): handle generation of cleanup loop for unknown trip count
-  // when it's not known to be a multiple of unroll factor (still for single
-  // result / same operands case).
-  if (mayBeConstantTripCount.hasValue() &&
-      mayBeConstantTripCount.getValue() % unrollFactor != 0) {
-    uint64_t tripCount = mayBeConstantTripCount.getValue();
+  if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) {
     DenseMap<const MLValue *, MLValue *> operandMap;
     MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
     auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
-    if (forStmt->hasConstantLowerBound()) {
-      cleanupForStmt->setConstantLowerBound(
-          forStmt->getConstantLowerBound() +
-          (tripCount - tripCount % unrollFactor) * step);
-    } else {
-      cleanupForStmt->setLowerBoundMap(
-          getCleanupLoopLowerBound(forStmt->getLowerBoundMap(), tripCount,
-                                   unrollFactor, step, &builder));
-    }
+    auto *clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
+    assert(clLbMap &&
+           "cleanup loop lower bound map for single result bound maps can "
+           "always be determined");
+    cleanupForStmt->setLowerBoundMap(clLbMap);
     // Promote the loop body up if this has turned into a single iteration loop.
     promoteIfSingleIteration(cleanupForStmt);
 
-    // The upper bound needs to be adjusted.
-    if (forStmt->hasConstantUpperBound()) {
-      forStmt->setConstantUpperBound(
-          forStmt->getConstantLowerBound() +
-          (tripCount - tripCount % unrollFactor - 1) * step);
-    } else {
-      forStmt->setUpperBoundMap(
-          getUnrolledLoopUpperBound(forStmt->getLowerBoundMap(), tripCount,
-                                    unrollFactor, step, &builder));
-    }
+    // Adjust upper bound.
+    auto *unrolledUbMap =
+        getUnrolledLoopUpperBound(*forStmt, unrollFactor, &builder);
+    assert(unrolledUbMap &&
+           "upper bound map can alwayys be determined for an unrolled loop "
+           "with single result bounds");
+    forStmt->setUpperBoundMap(unrolledUbMap);
   }
 
   // Scale the step of loop being unrolled by unroll factor.
+  int64_t step = forStmt->getStep();
   forStmt->setStep(step * unrollFactor);
 
   // Builder to insert unrolled bodies right after the last statement in the
index 96125edcde7be352ed058376af255b23239e108d..9d515ab6c8d3561e2f020312528c42815bf71968 100644 (file)
 //
 // Note: 'if/else' blocks are not jammed. So, if there are loops inside if
 // stmt's, bodies of those loops will not be jammed.
-//
 //===----------------------------------------------------------------------===//
 #include "mlir/Transforms/Passes.h"
 
 #include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/StandardOps.h"
 #include "mlir/IR/StmtVisitor.h"
+#include "mlir/Transforms/LoopUtils.h"
 #include "mlir/Transforms/Pass.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/CommandLine.h"
@@ -108,6 +109,15 @@ bool LoopUnrollAndJam::runOnForStmt(ForStmt *forStmt) {
   return loopUnrollJamByFactor(forStmt, kDefaultUnrollJamFactor);
 }
 
+bool mlir::loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
+  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
+
+  if (mayBeConstantTripCount.hasValue() &&
+      mayBeConstantTripCount.getValue() < unrollJamFactor)
+    return loopUnrollJamByFactor(forStmt, mayBeConstantTripCount.getValue());
+  return loopUnrollJamByFactor(forStmt, unrollJamFactor);
+}
+
 /// Unrolls and jams this loop by the specified factor.
 bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
   // Gathers all maximal sub-blocks of statements that do not themselves include
@@ -140,19 +150,32 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
   if (unrollJamFactor == 1 || forStmt->getStatements().empty())
     return false;
 
-  Optional<uint64_t> mayTripCount = getConstantTripCount(*forStmt).getValue();
+  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
 
-  if (!mayTripCount.hasValue())
+  if (!mayBeConstantTripCount.hasValue() &&
+      getLargestDivisorOfTripCount(*forStmt) % unrollJamFactor != 0)
     return false;
 
-  uint64_t tripCount = mayTripCount.getValue();
-  int64_t lb = forStmt->getConstantLowerBound();
-  int64_t step = forStmt->getStep();
+  auto *lbMap = forStmt->getLowerBoundMap();
+  auto *ubMap = forStmt->getUpperBoundMap();
+
+  // Loops with max/min expressions won't be unrolled here (the output can't be
+  // expressed as an MLFunction in the general case). However, the right way to
+  // do such unrolling for an MLFunction would be to specialize the loop for the
+  // 'hotspot' case and unroll that hotspot.
+  if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
+    return false;
+
+  // Same operand list for lower and upper bound for now.
+  // TODO(bondhugula): handle bounds with different sets of operands.
+  if (!forStmt->matchingBoundOperandList())
+    return false;
 
-  // If the trip count is lower than the unroll jam factor, no unrolled body.
+  // If the trip count is lower than the unroll jam factor, no unroll jam.
   // TODO(bondhugula): option to specify cleanup loop unrolling.
-  if (tripCount < unrollJamFactor)
-    return true;
+  if (mayBeConstantTripCount.hasValue() &&
+      mayBeConstantTripCount.getValue() < unrollJamFactor)
+    return false;
 
   // Gather all sub-blocks to jam upon the loop being unrolled.
   JamBlockGatherer jbg;
@@ -161,23 +184,27 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
 
   // Generate the cleanup loop if trip count isn't a multiple of
   // unrollJamFactor.
-  if (tripCount % unrollJamFactor) {
+  if (mayBeConstantTripCount.hasValue() &&
+      mayBeConstantTripCount.getValue() % unrollJamFactor != 0) {
     DenseMap<const MLValue *, MLValue *> operandMap;
     // Insert the cleanup loop right after 'forStmt'.
     MLFuncBuilder builder(forStmt->getBlock(),
                           std::next(StmtBlock::iterator(forStmt)));
     auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
-    cleanupForStmt->setConstantLowerBound(
-        lb + (tripCount - tripCount % unrollJamFactor) * step);
+    cleanupForStmt->setLowerBoundMap(
+        getCleanupLoopLowerBound(*forStmt, unrollJamFactor, &builder));
+
+    // The upper bound needs to be adjusted.
+    forStmt->setUpperBoundMap(
+        getUnrolledLoopUpperBound(*forStmt, unrollJamFactor, &builder));
 
     // Promote the loop body up if this has turned into a single iteration loop.
     promoteIfSingleIteration(cleanupForStmt);
   }
 
-  MLFuncBuilder b(forStmt);
+  // Scale the step of loop being unroll-jammed by the unroll-jam factor.
+  int64_t step = forStmt->getStep();
   forStmt->setStep(step * unrollJamFactor);
-  forStmt->setConstantUpperBound(
-      lb + (tripCount - tripCount % unrollJamFactor - 1) * step);
 
   for (auto &subBlock : subBlocks) {
     // Builder to insert unroll-jammed bodies. Insert right at the end of
index 2eecb6f114fa458492ffc3419307c46d05655dc0..1261afef6d749bce5dff92e7d9712b1160b4384e 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/LoopUtils.h"
 
 #include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/StandardOps.h"
 #include "mlir/IR/Statements.h"
 #include "mlir/IR/StmtVisitor.h"
 
+using namespace mlir;
+
+/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with
+/// the specified trip count, stride, and unroll factor. Returns nullptr when
+/// the trip count can't be expressed as an affine expression.
+AffineMap *mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
+                                           unsigned unrollFactor,
+                                           MLFuncBuilder *builder) {
+  auto *lbMap = forStmt.getLowerBoundMap();
+
+  // Single result lower bound map only.
+  if (lbMap->getNumResults() != 1)
+    return nullptr;
+
+  // Sometimes, the trip count cannot be expressed as an affine expression.
+  auto *tripCountExpr = getTripCountExpr(forStmt);
+  if (!tripCountExpr)
+    return nullptr;
+
+  AffineExpr *newUbExpr;
+  auto *lbExpr = lbMap->getResult(0);
+  int64_t step = forStmt.getStep();
+  // lbExpr + (count - count % unrollFactor - 1) * step).
+  if (auto *cTripCountExpr = dyn_cast<AffineConstantExpr>(tripCountExpr)) {
+    uint64_t tripCount = static_cast<uint64_t>(cTripCountExpr->getValue());
+    newUbExpr = builder->getAddExpr(
+        lbExpr, builder->getConstantExpr(
+                    (tripCount - tripCount % unrollFactor - 1) * step));
+  } else {
+    newUbExpr = builder->getAddExpr(
+        lbExpr, builder->getMulExpr(
+                    builder->getSubExpr(
+                        builder->getSubExpr(
+                            tripCountExpr,
+                            builder->getModExpr(tripCountExpr, unrollFactor)),
+                        1),
+                    step));
+  }
+  return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
+                               {newUbExpr}, {});
+}
+
+/// Returns the lower bound of the cleanup loop when unrolling a loop with lower
+/// bound 'lb' and with the specified trip count, stride, and unroll factor.
+/// Returns nullptr when the trip count can't be expressed as an affine
+/// expression.
+AffineMap *mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
+                                          unsigned unrollFactor,
+                                          MLFuncBuilder *builder) {
+  auto *lbMap = forStmt.getLowerBoundMap();
+
+  // Single result lower bound map only.
+  if (lbMap->getNumResults() != 1)
+    return nullptr;
+
+  // Sometimes the trip count cannot be expressed as an affine expression.
+  auto *tripCountExpr = getTripCountExpr(forStmt);
+  if (!tripCountExpr)
+    return nullptr;
+
+  AffineExpr *newLbExpr;
+  auto *lbExpr = lbMap->getResult(0);
+  int64_t step = forStmt.getStep();
+
+  // lbExpr + (count - count % unrollFactor) * step);
+  if (auto *cTripCountExpr = dyn_cast<AffineConstantExpr>(tripCountExpr)) {
+    uint64_t tripCount = static_cast<uint64_t>(cTripCountExpr->getValue());
+    newLbExpr = builder->getAddExpr(
+        lbExpr, builder->getConstantExpr(
+                    (tripCount - tripCount % unrollFactor) * step));
+  } else {
+    newLbExpr = builder->getAddExpr(
+        lbExpr, builder->getMulExpr(
+                    builder->getSubExpr(
+                        tripCountExpr,
+                        builder->getModExpr(tripCountExpr, unrollFactor)),
+                    step));
+  }
+  return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
+                               {newLbExpr}, {});
+}
+
 /// Promotes the loop body of a forStmt to its containing block if the forStmt
 /// was known to have a single iteration. Returns false otherwise.
 // TODO(bondhugula): extend this for arbitrary affine bounds.
 bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
   Optional<uint64_t> tripCount = getConstantTripCount(*forStmt);
-  if (!tripCount.hasValue() || !forStmt->hasConstantLowerBound())
+  if (!tripCount.hasValue() || tripCount.getValue() != 1)
     return false;
 
-  if (tripCount.getValue() != 1)
+  // TODO(mlir-team): there is no builder for a max.
+  if (forStmt->getLowerBoundMap()->getNumResults() != 1)
     return false;
 
   // Replaces all IV uses to its single iteration value.
-  auto *mlFunc = forStmt->findFunction();
-  MLFuncBuilder topBuilder(&mlFunc->front());
-  auto constOp = topBuilder.create<ConstantAffineIntOp>(
-      forStmt->getLoc(), forStmt->getConstantLowerBound());
-  forStmt->replaceAllUsesWith(constOp->getResult());
-  // Move the statements to the containing block.
+  if (!forStmt->use_empty()) {
+    if (forStmt->hasConstantLowerBound()) {
+      auto *mlFunc = forStmt->findFunction();
+      MLFuncBuilder topBuilder(&mlFunc->front());
+      auto constOp = topBuilder.create<ConstantAffineIntOp>(
+          forStmt->getLoc(), forStmt->getConstantLowerBound());
+      forStmt->replaceAllUsesWith(constOp->getResult());
+    } else {
+      const AffineBound lb = forStmt->getLowerBound();
+      SmallVector<SSAValue *, 4> lbOperands(lb.operand_begin(),
+                                            lb.operand_end());
+      MLFuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt));
+      auto affineApplyOp = builder.create<AffineApplyOp>(
+          forStmt->getLoc(), lb.getMap(), lbOperands);
+      forStmt->replaceAllUsesWith(affineApplyOp->getResult(0));
+    }
+  }
+  // Move the loop body statements to the loop's containing block.
   auto *block = forStmt->getBlock();
   block->getStatements().splice(StmtBlock::iterator(forStmt),
                                 forStmt->getStatements());
index 19cd1df6088438c7ec5f71cf017b2fe966d6caee..be5ccbe84feb345430de2b47f60eff89c9418aa0 100644 (file)
@@ -1,6 +1,8 @@
 // RUN: mlir-opt %s -o - -loop-unroll-jam -unroll-jam-factor=2 | FileCheck %s
 
 // CHECK: #map0 = (d0) -> (d0 + 1)
+// This should be matched to M1, but M1 is defined later.
+// CHECK: {{#map[0-9]+}} = ()[s0] -> (s0 + 8)
 
 // CHECK-LABEL: mlfunc @unroll_jam_imperfect_nest() {
 mlfunc @unroll_jam_imperfect_nest() {
@@ -34,3 +36,55 @@ mlfunc @unroll_jam_imperfect_nest() {
   // CHECK-NEXT: %14 = "addi32"(%c100, %c100) : (affineint, affineint) -> i32
   return
 }
+
+// UNROLL-BY-4-LABEL: mlfunc @loop_nest_unknown_count_1(%arg0 : affineint) {
+mlfunc @loop_nest_unknown_count_1(%N : affineint) {
+  // UNROLL-BY-4-NEXT: for %i0 = 1 to  #map{{[0-9]+}}()[%arg0] step 4 {
+    // UNROLL-BY-4-NEXT: for %i1 = 1 to 100 {
+      // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32
+      // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
+      // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
+      // UNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32
+    // UNROLL-BY-4-NEXT: }
+  // UNROLL-BY-4-NEXT: }
+  // A cleanup loop should be generated here.
+  // UNROLL-BY-4-NEXT: for %i2 = #map{{[0-9]+}}()[%arg0] to %arg0 {
+    // UNROLL-BY-4-NEXT: for %i3 = 1 to 100 {
+      // UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32
+    // UNROLL-BY-4_NEXT: }
+  // UNROLL-BY-4_NEXT: }
+  // Specify the lower bound in a form so that both lb and ub operands match.
+  for %i = ()[s0] -> (1)()[%N] to %N {
+    for %j = 1 to 100 {
+      %x = "foo"() : () -> i32
+    }
+  }
+  return
+}
+
+// UNROLL-BY-4-LABEL: mlfunc @loop_nest_unknown_count_2(%arg0 : affineint) {
+mlfunc @loop_nest_unknown_count_2(%arg : affineint) {
+  // UNROLL-BY-4-NEXT: for %i0 = %arg0 to  #map{{[0-9]+}}()[%arg0] step 4 {
+    // UNROLL-BY-4-NEXT: for %i1 = 1 to 100 {
+      // UNROLL-BY-4-NEXT: %0 = "foo"(%i0) : (affineint) -> i32
+      // UNROLL-BY-4-NEXT: %1 = affine_apply #map{{[0-9]+}}(%i0)
+      // UNROLL-BY-4-NEXT: %2 = "foo"(%1) : (affineint) -> i32
+      // UNROLL-BY-4-NEXT: %3 = affine_apply #map{{[0-9]+}}(%i0)
+      // UNROLL-BY-4-NEXT: %4 = "foo"(%3) : (affineint) -> i32
+      // UNROLL-BY-4-NEXT: %5 = affine_apply #map{{[0-9]+}}(%i0)
+      // UNROLL-BY-4-NEXT: %6 = "foo"(%5) : (affineint) -> i32
+    // UNROLL-BY-4-NEXT: }
+  // UNROLL-BY-4-NEXT: }
+  // The cleanup loop is a single iteration one and is promoted.
+  // UNROLL-BY-4-NEXT: %7 = affine_apply [[M1:#map{{[0-9]+}}]]()[%arg0]
+  // UNROLL-BY-4-NEXT: for %i3 = 1 to 100 {
+    // UNROLL-BY-4-NEXT: %8 = "foo"() : () -> i32
+  // UNROLL-BY-4_NEXT: }
+  // Specify the lower bound in a form so that both lb and ub operands match.
+  for %i = ()[s0] -> (s0) ()[%arg] to ()[s0] -> (s0+8) ()[%arg] {
+    for %j = 1 to 100 {
+      %x = "foo"(%i) : (affineint) -> i32
+    }
+  }
+  return
+}
index dbda50b104cc93603770f0532c1ca45350c54b07..a678fcd80e460c27d0ffd4699af01b8efc0cc156 100644 (file)
@@ -462,7 +462,7 @@ mlfunc @loop_nest_operand2() {
 }
 
 // Difference between loop bounds is constant, but not a multiple of unroll
-// factor. A cleanup loop is generated.
+// factor. The cleanup loop happens to be a single iteration one and is promoted.
 // UNROLL-BY-4-LABEL: mlfunc @loop_nest_operand3() {
 mlfunc @loop_nest_operand3() {
   // UNROLL-BY-4: for %i0 = 1 to 100 step 2 {
@@ -473,30 +473,30 @@ mlfunc @loop_nest_operand3() {
     // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
     // UNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32
     // UNROLL-BY-4-NEXT: }
-    // UNROLL-BY-4-NEXT: for %i2 = #map{{[0-9]+}}(%i0) to #map{{[0-9]+}}(%i0) {
     // UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32
-    // UNROLL-BY-4-NEXT: }
-    for %j = (d0) -> (d0) (%i) to (d0) -> (d0 + 4) (%i) {
+    for %j = (d0) -> (d0) (%i) to (d0) -> (d0 + 8) (%i) {
       %x = "foo"() : () -> i32
     }
   } // UNROLL-BY-4: }
   return
 }
 
-// Will not be unrolled for now. TODO(bondhugula): handle this.
-// xUNROLL-BY-4-LABEL: mlfunc @loop_nest_operand4(%arg0 : affineint) {
+// UNROLL-BY-4-LABEL: mlfunc @loop_nest_operand4(%arg0 : affineint) {
 mlfunc @loop_nest_operand4(%N : affineint) {
-  // UNROLL-BY-4: for %i0 = 1 to 100 step 2 {
-  for %i = 1 to 100 step 2 {
-    // UNROLL-BY-4: for %i1 = 0 to %arg0 {
-    // xUNROLL-BY-4: for %i1 = 0 to #map{{[0-9]+}}(%N) step 4 {
-    // xUNROLL-BY-4: %0 = "foo"() : () -> i32
-    // xUNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
-    // xUNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
-    // xUNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32
-    // xUNROLL-BY-4-NEXT: }
-    // a cleanup loop should be generated here.
-    for %j = (d0) -> (0) (%N) to %N {
+  // UNROLL-BY-4: for %i0 = 1 to 100 {
+  for %i = 1 to 100 {
+    // UNROLL-BY-4: for %i1 = 1 to #map{{[0-9]+}}()[%arg0] step 4 {
+    // UNROLL-BY-4: %0 = "foo"() : () -> i32
+    // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
+    // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
+    // UNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32
+    // UNROLL-BY-4-NEXT: }
+    // A cleanup loop will be be generated here.
+    // UNROLL-BY-4-NEXT: for %i2 = #map{{[0-9]+}}()[%arg0] to %arg0 {
+    // UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32
+    // UNROLL-BY-4_NEXT: }
+    // Specify the lower bound so that both lb and ub operands match.
+    for %j = ()[s0] -> (1)()[%N] to %N {
       %x = "foo"() : () -> i32
     }
   }