/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning.
-AffineExpr *getTripCount(const ForStmt &forStmt);
+AffineExpr *getTripCountExpr(const ForStmt &forStmt);
/// Returns the trip count of the loop if it's a constant, None otherwise. This
/// uses affine expression analysis and is able to determine constant trip count
AffineSymbolExpr *getSymbolExpr(unsigned position);
AffineConstantExpr *getConstantExpr(int64_t constant);
AffineExpr *getAddExpr(AffineExpr *lhs, AffineExpr *rhs);
+ AffineExpr *getAddExpr(AffineExpr *lhs, int64_t rhs);
AffineExpr *getSubExpr(AffineExpr *lhs, AffineExpr *rhs);
+ AffineExpr *getSubExpr(AffineExpr *lhs, int64_t rhs);
AffineExpr *getMulExpr(AffineExpr *lhs, AffineExpr *rhs);
AffineExpr *getMulExpr(AffineExpr *lhs, int64_t rhs);
AffineExpr *getModExpr(AffineExpr *lhs, AffineExpr *rhs);
/// Returns nullptr if the statement is unlinked.
MLFunction *findFunction() const;
- /// Returns true if there are no more loops nested under this stmt.
- bool isInnermost() const;
-
/// Destroys this statement and its subclass data.
void destroy();
/// Sets the upper bound to the given constant value.
void setConstantUpperBound(int64_t value);
+ /// Returns true if both the lower and upper bound have the same operand lists
+ /// (same operands in the same order).
+ bool matchingBoundOperandList() const;
+
//===--------------------------------------------------------------------===//
// Operands
//===--------------------------------------------------------------------===//
AffineMap *ubMap;
// Constant step.
int64_t step;
- // Operands for the lower and upper bounds.
+ // Operands for the lower and upper bounds, with the former followed by the
+ // latter. Dimensional operands are followed by symbolic operands for each
+ // bound.
std::vector<StmtOperand> operands;
explicit ForStmt(Location *location, unsigned numOperands, AffineMap *lbMap,
--- /dev/null
+//===- LoopUtils.h - Loop transformation utilities --------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes for various loop transformation utility
+// methods: these are not passes by themselves but are used either by passes,
+// optimization sequences, or in turn by other transformation utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_LOOP_UTILS_H
+#define MLIR_TRANSFORMS_LOOP_UTILS_H
+
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+class AffineMap;
+class ForStmt;
+class MLFunction;
+class MLFuncBuilder;
+
+/// Unrolls this for statement completely if the trip count is known to be
+/// constant. Returns false otherwise.
+bool loopUnrollFull(ForStmt *forStmt);
+/// Unrolls this for statement by the specified unroll factor. Returns false if
+/// the loop cannot be unrolled either due to restrictions or due to invalid
+/// unroll factors.
+bool loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor);
+/// Unrolls this loop by the specified unroll factor or its trip count,
+/// whichever is lower.
+bool loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor);
+
+/// Unrolls and jams this loop by the specified factor. Returns true if the loop
+/// is successfully unroll-jammed.
+bool loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
+
+/// Unrolls and jams this loop by the specified factor or by the trip count (if
+/// constant), whichever is lower.
+bool loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
+
+/// Promotes the loop body of a ForStmt to its containing block if the ForStmt
+/// was known to have a single iteration. Returns false otherwise.
+bool promoteIfSingleIteration(ForStmt *forStmt);
+
+/// Promotes all single iteration ForStmt's in the MLFunction, i.e., moves
+/// their body into the containing StmtBlock.
+void promoteSingleIterationLoops(MLFunction *f);
+
+/// Returns the lower bound of the cleanup loop when unrolling a loop
+/// with the specified unroll factor.
+AffineMap *getCleanupLoopLowerBound(const ForStmt &forStmt,
+ unsigned unrollFactor,
+ MLFuncBuilder *builder);
+
+/// Returns the upper bound of the unrolled loop when unrolling the given loop
+/// with the specified unroll factor.
+AffineMap *getUnrolledLoopUpperBound(const ForStmt &forStmt,
+ unsigned unrollFactor,
+ MLFuncBuilder *builder);
+
+} // end namespace mlir
+
+#endif // MLIR_TRANSFORMS_LOOP_UTILS_H
namespace mlir {
-class ForStmt;
class FunctionPass;
-class MLFunction;
class MLFunctionPass;
class ModulePass;
MLFunctionPass *createLoopUnrollPass(int unrollFactor = -1,
int unrollFull = -1);
-/// Unrolls this loop completely.
-bool loopUnrollFull(ForStmt *forStmt);
-/// Unrolls this loop by the specified unroll factor.
-bool loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor);
-
/// Creates a loop unroll jam pass to unroll jam by the specified factor. A
/// factor of -1 lets the pass use the default factor or the one on the command
/// line if provided.
MLFunctionPass *createLoopUnrollAndJamPass(int unrollJamFactor = -1);
-/// Unrolls and jams this loop by the specified factor.
-bool loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
-
/// Creates an affine expression simplification pass.
FunctionPass *createSimplifyAffineExprPass();
/// generated CFG functions.
ModulePass *createConvertToCFGPass();
-/// Promotes the loop body of a ForStmt to its containing block if the ForStmt
-/// was known to have a single iteration. Returns false otherwise.
-bool promoteIfSingleIteration(ForStmt *forStmt);
-
-/// Promotes all single iteration ForStmt's in the MLFunction, i.e., moves
-/// their body into the containing StmtBlock.
-void promoteSingleIterationLoops(MLFunction *f);
-
} // end namespace mlir
-#endif // MLIR_TRANSFORMS_LOOP_H
+#endif // MLIR_TRANSFORMS_PASSES_H
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning.
-AffineExpr *mlir::getTripCount(const ForStmt &forStmt) {
+AffineExpr *mlir::getTripCountExpr(const ForStmt &forStmt) {
// upper_bound - lower_bound + 1
int64_t loopSpan;
int64_t ub = forStmt.getConstantUpperBound();
loopSpan = ub - lb + 1;
} else {
- const AffineBound lb = forStmt.getLowerBound();
- const AffineBound ub = forStmt.getUpperBound();
- auto lbMap = lb.getMap();
- auto ubMap = ub.getMap();
+ auto *lbMap = forStmt.getLowerBoundMap();
+ auto *ubMap = forStmt.getUpperBoundMap();
// TODO(bondhugula): handle max/min of multiple expressions.
- if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1 ||
- lbMap->getNumDims() != ubMap->getNumDims() ||
- lbMap->getNumSymbols() != ubMap->getNumSymbols()) {
+ if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
return nullptr;
- }
// TODO(bondhugula): handle bounds with different operands.
- unsigned i, e = lb.getNumOperands();
- for (i = 0; i < e; i++) {
- if (lb.getStmtOperand(i).get() != ub.getStmtOperand(i).get())
- break;
- }
// Bounds have different operands, unhandled for now.
- if (i != e)
+ if (!forStmt.matchingBoundOperandList())
return nullptr;
// ub_expr - lb_expr + 1
+ auto *lbExpr = lbMap->getResult(0);
+ auto *ubExpr = ubMap->getResult(0);
auto *loopSpanExpr = AffineBinaryOpExpr::getAdd(
- AffineBinaryOpExpr::getSub(ubMap->getResult(0), lbMap->getResult(0),
- context),
- 1, context);
+ AffineBinaryOpExpr::getSub(ubExpr, lbExpr, context), 1, context);
if (auto *expr = simplifyAffineExpr(loopSpanExpr, lbMap->getNumDims(),
lbMap->getNumSymbols(), context))
/// method uses affine expression analysis (in turn using getTripCountExpr) and
/// is able to determine constant trip count in non-trivial cases.
llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
- AffineExpr *tripCountExpr = getTripCount(forStmt);
+ AffineExpr *tripCountExpr = getTripCountExpr(forStmt);
if (auto *constExpr = dyn_cast_or_null<AffineConstantExpr>(tripCountExpr))
return constExpr->getValue();
/// expression analysis is used (indirectly through getTripCountExpr), and
/// this method is thus able to determine non-trivial divisors.
uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
- AffineExpr *tripCountExpr = getTripCount(forStmt);
+ AffineExpr *tripCountExpr = getTripCountExpr(forStmt);
if (!tripCountExpr)
return 1;
return AffineBinaryOpExpr::get(AffineExpr::Kind::Add, lhs, rhs, context);
}
+/// Builds the affine expression 'lhs + rhs' for a constant 'rhs' by forwarding
+/// to AffineBinaryOpExpr::getAdd together with 'context'.
+AffineExpr *Builder::getAddExpr(AffineExpr *lhs, int64_t rhs) {
+ return AffineBinaryOpExpr::getAdd(lhs, rhs, context);
+}
+
AffineExpr *Builder::getMulExpr(AffineExpr *lhs, AffineExpr *rhs) {
return AffineBinaryOpExpr::get(AffineExpr::Kind::Mul, lhs, rhs, context);
}
return getAddExpr(lhs, getMulExpr(rhs, getConstantExpr(-1)));
}
+/// Builds the affine expression 'lhs - rhs' for a constant 'rhs', expressed as
+/// the addition of the negated constant.
+// NOTE(review): '-rhs' overflows when rhs is INT64_MIN -- confirm callers never
+// pass that value.
+AffineExpr *Builder::getSubExpr(AffineExpr *lhs, int64_t rhs) {
+ return AffineBinaryOpExpr::getAdd(lhs, -rhs, context);
+}
+
AffineExpr *Builder::getModExpr(AffineExpr *lhs, AffineExpr *rhs) {
return AffineBinaryOpExpr::get(AffineExpr::Kind::Mod, lhs, rhs, context);
}
return block ? block->findFunction() : nullptr;
}
-bool Statement::isInnermost() const {
- struct NestedLoopCounter : public StmtWalker<NestedLoopCounter> {
- unsigned numNestedLoops;
- NestedLoopCounter() : numNestedLoops(0) {}
- void walkForStmt(const ForStmt *fs) { numNestedLoops++; }
- };
-
- NestedLoopCounter nlc;
- nlc.walk(const_cast<Statement *>(this));
- return nlc.numNestedLoops == 1;
-}
-
MLValue *Statement::getOperand(unsigned idx) {
return getStmtOperand(idx).get();
}
setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
}
+/// Reports whether the lower and upper bound take exactly the same operands in
+/// the same order. Bound maps with differing dimension or symbol counts never
+/// match.
+bool ForStmt::matchingBoundOperandList() const {
+  const bool sameShape = lbMap->getNumDims() == ubMap->getNumDims() &&
+                         lbMap->getNumSymbols() == ubMap->getNumSymbols();
+  if (!sameShape)
+    return false;
+
+  // Lower bound operands come first in the operand list; the upper bound
+  // operands follow, so the i-th pair is (i, count + i).
+  const unsigned count = lbMap->getNumInputs();
+  unsigned idx = 0;
+  // Stop at the first pair of differing MLValue pointers.
+  while (idx < count && getOperand(idx) == getOperand(count + idx))
+    ++idx;
+  return idx == count;
+}
+
//===----------------------------------------------------------------------===//
// IfStmt
//===----------------------------------------------------------------------===//
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardOps.h"
#include "mlir/IR/StmtVisitor.h"
+#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/CommandLine.h"
return false;
}
-/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with
-/// the specified trip count, stride, and unroll factor.
-static AffineMap *getUnrolledLoopUpperBound(AffineMap *lbMap,
- uint64_t tripCount,
- unsigned unrollFactor, int64_t step,
- MLFuncBuilder *builder) {
- assert(lbMap->getNumResults() == 1);
- auto *lbExpr = lbMap->getResult(0);
- // lbExpr + (count - count % unrollFactor - 1) * step).
- auto *expr = builder->getAddExpr(
- lbExpr, builder->getConstantExpr(
- (tripCount - tripCount % unrollFactor - 1) * step));
- return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
- {expr}, {});
-}
+/// Unrolls this loop by the specified unroll factor or by the trip count (if
+/// constant), whichever is lower.
+bool mlir::loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor) {
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
-/// Returns the lower bound of the cleanup loop when unrolling a loop with lower
-/// bound 'lb' and with the specified trip count, stride, and unroll factor.
-static AffineMap *getCleanupLoopLowerBound(AffineMap *lbMap, uint64_t tripCount,
- unsigned unrollFactor, int64_t step,
- MLFuncBuilder *builder) {
- assert(lbMap->getNumResults() == 1);
- auto *lbExpr = lbMap->getResult(0);
- // lbExpr + (count - count % unrollFactor) * step);
- auto *expr = builder->getAddExpr(
- lbExpr,
- builder->getConstantExpr((tripCount - tripCount % unrollFactor) * step));
- return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
- {expr}, {});
+ if (mayBeConstantTripCount.hasValue() &&
+ mayBeConstantTripCount.getValue() < unrollFactor)
+ return loopUnrollByFactor(forStmt, mayBeConstantTripCount.getValue());
+ return loopUnrollByFactor(forStmt, unrollFactor);
}
-/// Unrolls this loop by the specified unroll factor.
+/// Unrolls this loop by the specified factor. Returns true if the loop
+/// is successfully unrolled.
bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
- assert(unrollFactor >= 1 && "unroll factor shoud be >= 1");
+ assert(unrollFactor >= 1 && "unroll factor should be >= 1");
if (unrollFactor == 1 || forStmt->getStatements().empty())
return false;
- Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
-
- if (!mayBeConstantTripCount.hasValue() &&
- getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0)
- return false;
-
- const AffineBound &lb = forStmt->getLowerBound();
- const AffineBound &ub = forStmt->getLowerBound();
- auto lbMap = lb.getMap();
- auto ubMap = lb.getMap();
+ auto *lbMap = forStmt->getLowerBoundMap();
+ auto *ubMap = forStmt->getUpperBoundMap();
// Loops with max/min expressions won't be unrolled here (the output can't be
// expressed as an MLFunction in the general case). However, the right way to
// do such unrolling for an MLFunction would be to specialize the loop for the
- // 'hotspot' case and unroll that hotspot case.
+ // 'hotspot' case and unroll that hotspot.
if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
return false;
- // TODO(bondhugula): handle bounds with different sets of operands.
- // Same operand list for now.
- if (lbMap->getNumDims() != ubMap->getNumDims() ||
- lbMap->getNumSymbols() != ubMap->getNumSymbols())
- return false;
- unsigned i, e = lb.getNumOperands();
- for (i = 0; i < e; i++) {
- if (lb.getStmtOperand(i).get() != ub.getStmtOperand(i).get())
- break;
- }
- if (i != e)
+ // Same operand list for lower and upper bound for now.
+ // TODO(bondhugula): handle bounds with different operand lists.
+ if (!forStmt->matchingBoundOperandList())
return false;
- int64_t step = forStmt->getStep();
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
// If the trip count is lower than the unroll factor, no unrolled body.
// TODO(bondhugula): option to specify cleanup loop unrolling.
return false;
// Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
- // If the trip count is unknown, we currently unroll only when the unknown
- // trip count is known to be a multiple of unroll factor - hence, no cleanup
- // loop will be necessary in those cases.
- // TODO(bondhugula): handle generation of cleanup loop for unknown trip count
- // when it's not known to be a multiple of unroll factor (still for single
- // result / same operands case).
- if (mayBeConstantTripCount.hasValue() &&
- mayBeConstantTripCount.getValue() % unrollFactor != 0) {
- uint64_t tripCount = mayBeConstantTripCount.getValue();
+ if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) {
DenseMap<const MLValue *, MLValue *> operandMap;
MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
- if (forStmt->hasConstantLowerBound()) {
- cleanupForStmt->setConstantLowerBound(
- forStmt->getConstantLowerBound() +
- (tripCount - tripCount % unrollFactor) * step);
- } else {
- cleanupForStmt->setLowerBoundMap(
- getCleanupLoopLowerBound(forStmt->getLowerBoundMap(), tripCount,
- unrollFactor, step, &builder));
- }
+ auto *clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
+ assert(clLbMap &&
+ "cleanup loop lower bound map for single result bound maps can "
+ "always be determined");
+ cleanupForStmt->setLowerBoundMap(clLbMap);
// Promote the loop body up if this has turned into a single iteration loop.
promoteIfSingleIteration(cleanupForStmt);
- // The upper bound needs to be adjusted.
- if (forStmt->hasConstantUpperBound()) {
- forStmt->setConstantUpperBound(
- forStmt->getConstantLowerBound() +
- (tripCount - tripCount % unrollFactor - 1) * step);
- } else {
- forStmt->setUpperBoundMap(
- getUnrolledLoopUpperBound(forStmt->getLowerBoundMap(), tripCount,
- unrollFactor, step, &builder));
- }
+ // Adjust upper bound.
+ auto *unrolledUbMap =
+ getUnrolledLoopUpperBound(*forStmt, unrollFactor, &builder);
+ assert(unrolledUbMap &&
+ "upper bound map can alwayys be determined for an unrolled loop "
+ "with single result bounds");
+ forStmt->setUpperBoundMap(unrolledUbMap);
}
// Scale the step of loop being unrolled by unroll factor.
+ int64_t step = forStmt->getStep();
forStmt->setStep(step * unrollFactor);
// Builder to insert unrolled bodies right after the last statement in the
//
// Note: 'if/else' blocks are not jammed. So, if there are loops inside if
// stmt's, bodies of those loops will not be jammed.
-//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Passes.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardOps.h"
#include "mlir/IR/StmtVisitor.h"
+#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/CommandLine.h"
return loopUnrollJamByFactor(forStmt, kDefaultUnrollJamFactor);
}
+/// Unrolls and jams this loop by the specified factor or by the trip count (if
+/// constant), whichever is lower. Returns false for a loop whose constant trip
+/// count is zero, since there is nothing to unroll-jam; otherwise returns the
+/// result of loopUnrollJamByFactor.
+bool mlir::loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
+  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
+
+  if (mayBeConstantTripCount.hasValue()) {
+    uint64_t tripCount = mayBeConstantTripCount.getValue();
+    // A zero trip count must not be forwarded as the factor:
+    // loopUnrollJamByFactor computes 'tripCount % unrollJamFactor'.
+    if (tripCount == 0)
+      return false;
+    if (tripCount < unrollJamFactor)
+      return loopUnrollJamByFactor(forStmt, tripCount);
+  }
+  return loopUnrollJamByFactor(forStmt, unrollJamFactor);
+}
+
/// Unrolls and jams this loop by the specified factor.
bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// Gathers all maximal sub-blocks of statements that do not themselves include
if (unrollJamFactor == 1 || forStmt->getStatements().empty())
return false;
- Optional<uint64_t> mayTripCount = getConstantTripCount(*forStmt).getValue();
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
- if (!mayTripCount.hasValue())
+ if (!mayBeConstantTripCount.hasValue() &&
+ getLargestDivisorOfTripCount(*forStmt) % unrollJamFactor != 0)
return false;
- uint64_t tripCount = mayTripCount.getValue();
- int64_t lb = forStmt->getConstantLowerBound();
- int64_t step = forStmt->getStep();
+ auto *lbMap = forStmt->getLowerBoundMap();
+ auto *ubMap = forStmt->getUpperBoundMap();
+
+ // Loops with max/min expressions won't be unrolled here (the output can't be
+ // expressed as an MLFunction in the general case). However, the right way to
+ // do such unrolling for an MLFunction would be to specialize the loop for the
+ // 'hotspot' case and unroll that hotspot.
+ if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
+ return false;
+
+ // Same operand list for lower and upper bound for now.
+ // TODO(bondhugula): handle bounds with different sets of operands.
+ if (!forStmt->matchingBoundOperandList())
+ return false;
- // If the trip count is lower than the unroll jam factor, no unrolled body.
+ // If the trip count is lower than the unroll jam factor, no unroll jam.
// TODO(bondhugula): option to specify cleanup loop unrolling.
- if (tripCount < unrollJamFactor)
- return true;
+ if (mayBeConstantTripCount.hasValue() &&
+ mayBeConstantTripCount.getValue() < unrollJamFactor)
+ return false;
// Gather all sub-blocks to jam upon the loop being unrolled.
JamBlockGatherer jbg;
// Generate the cleanup loop if trip count isn't a multiple of
// unrollJamFactor.
- if (tripCount % unrollJamFactor) {
+ if (mayBeConstantTripCount.hasValue() &&
+ mayBeConstantTripCount.getValue() % unrollJamFactor != 0) {
DenseMap<const MLValue *, MLValue *> operandMap;
// Insert the cleanup loop right after 'forStmt'.
MLFuncBuilder builder(forStmt->getBlock(),
std::next(StmtBlock::iterator(forStmt)));
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
- cleanupForStmt->setConstantLowerBound(
- lb + (tripCount - tripCount % unrollJamFactor) * step);
+ cleanupForStmt->setLowerBoundMap(
+ getCleanupLoopLowerBound(*forStmt, unrollJamFactor, &builder));
+
+ // The upper bound needs to be adjusted.
+ forStmt->setUpperBoundMap(
+ getUnrolledLoopUpperBound(*forStmt, unrollJamFactor, &builder));
// Promote the loop body up if this has turned into a single iteration loop.
promoteIfSingleIteration(cleanupForStmt);
}
- MLFuncBuilder b(forStmt);
+ // Scale the step of loop being unroll-jammed by the unroll-jam factor.
+ int64_t step = forStmt->getStep();
forStmt->setStep(step * unrollJamFactor);
- forStmt->setConstantUpperBound(
- lb + (tripCount - tripCount % unrollJamFactor - 1) * step);
for (auto &subBlock : subBlocks) {
// Builder to insert unroll-jammed bodies. Insert right at the end of
//
//===----------------------------------------------------------------------===//
-#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
+using namespace mlir;
+
+/// Returns the upper bound map of the unrolled loop obtained when 'forStmt' is
+/// unrolled by 'unrollFactor'. Returns nullptr if the lower bound map has more
+/// than one result or the trip count isn't expressible as an affine expression.
+AffineMap *mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
+ unsigned unrollFactor,
+ MLFuncBuilder *builder) {
+ auto *lbMap = forStmt.getLowerBoundMap();
+
+ // Only a single result lower bound map is handled.
+ if (lbMap->getNumResults() != 1)
+ return nullptr;
+
+ // getTripCountExpr returns nullptr when the trip count isn't affine.
+ auto *tripCountExpr = getTripCountExpr(forStmt);
+ if (!tripCountExpr)
+ return nullptr;
+
+ AffineExpr *newUbExpr;
+ auto *lbExpr = lbMap->getResult(0);
+ int64_t step = forStmt.getStep();
+ // New upper bound: lbExpr + (count - count % unrollFactor - 1) * step.
+ if (auto *cTripCountExpr = dyn_cast<AffineConstantExpr>(tripCountExpr)) {
+ // Constant trip count: fold the whole offset into a single constant.
+ uint64_t tripCount = static_cast<uint64_t>(cTripCountExpr->getValue());
+ newUbExpr = builder->getAddExpr(
+ lbExpr, builder->getConstantExpr(
+ (tripCount - tripCount % unrollFactor - 1) * step));
+ } else {
+ // Non-constant trip count: build the same formula from affine expressions.
+ newUbExpr = builder->getAddExpr(
+ lbExpr, builder->getMulExpr(
+ builder->getSubExpr(
+ builder->getSubExpr(
+ tripCountExpr,
+ builder->getModExpr(tripCountExpr, unrollFactor)),
+ 1),
+ step));
+ }
+ // The new map keeps the dimension and symbol counts of the lower bound map.
+ return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
+ {newUbExpr}, {});
+}
+
+/// Returns the lower bound map of the cleanup loop generated when 'forStmt' is
+/// unrolled by 'unrollFactor'. Returns nullptr if the lower bound map has more
+/// than one result or the trip count can't be expressed as an affine
+/// expression.
+AffineMap *mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
+ unsigned unrollFactor,
+ MLFuncBuilder *builder) {
+ auto *lbMap = forStmt.getLowerBoundMap();
+
+ // Only a single result lower bound map is handled.
+ if (lbMap->getNumResults() != 1)
+ return nullptr;
+
+ // getTripCountExpr returns nullptr when the trip count isn't affine.
+ auto *tripCountExpr = getTripCountExpr(forStmt);
+ if (!tripCountExpr)
+ return nullptr;
+
+ AffineExpr *newLbExpr;
+ auto *lbExpr = lbMap->getResult(0);
+ int64_t step = forStmt.getStep();
+
+ // New lower bound: lbExpr + (count - count % unrollFactor) * step.
+ if (auto *cTripCountExpr = dyn_cast<AffineConstantExpr>(tripCountExpr)) {
+ // Constant trip count: fold the whole offset into a single constant.
+ uint64_t tripCount = static_cast<uint64_t>(cTripCountExpr->getValue());
+ newLbExpr = builder->getAddExpr(
+ lbExpr, builder->getConstantExpr(
+ (tripCount - tripCount % unrollFactor) * step));
+ } else {
+ // Non-constant trip count: build the same formula from affine expressions.
+ newLbExpr = builder->getAddExpr(
+ lbExpr, builder->getMulExpr(
+ builder->getSubExpr(
+ tripCountExpr,
+ builder->getModExpr(tripCountExpr, unrollFactor)),
+ step));
+ }
+ // The new map keeps the dimension and symbol counts of the lower bound map.
+ return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
+ {newLbExpr}, {});
+}
+
/// Promotes the loop body of a forStmt to its containing block if the forStmt
/// was known to have a single iteration. Returns false otherwise.
// TODO(bondhugula): extend this for arbitrary affine bounds.
bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
Optional<uint64_t> tripCount = getConstantTripCount(*forStmt);
- if (!tripCount.hasValue() || !forStmt->hasConstantLowerBound())
+ if (!tripCount.hasValue() || tripCount.getValue() != 1)
return false;
- if (tripCount.getValue() != 1)
+ // TODO(mlir-team): there is no builder for a max.
+ if (forStmt->getLowerBoundMap()->getNumResults() != 1)
return false;
// Replaces all IV uses to its single iteration value.
- auto *mlFunc = forStmt->findFunction();
- MLFuncBuilder topBuilder(&mlFunc->front());
- auto constOp = topBuilder.create<ConstantAffineIntOp>(
- forStmt->getLoc(), forStmt->getConstantLowerBound());
- forStmt->replaceAllUsesWith(constOp->getResult());
- // Move the statements to the containing block.
+ if (!forStmt->use_empty()) {
+ if (forStmt->hasConstantLowerBound()) {
+ auto *mlFunc = forStmt->findFunction();
+ MLFuncBuilder topBuilder(&mlFunc->front());
+ auto constOp = topBuilder.create<ConstantAffineIntOp>(
+ forStmt->getLoc(), forStmt->getConstantLowerBound());
+ forStmt->replaceAllUsesWith(constOp->getResult());
+ } else {
+ const AffineBound lb = forStmt->getLowerBound();
+ SmallVector<SSAValue *, 4> lbOperands(lb.operand_begin(),
+ lb.operand_end());
+ MLFuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt));
+ auto affineApplyOp = builder.create<AffineApplyOp>(
+ forStmt->getLoc(), lb.getMap(), lbOperands);
+ forStmt->replaceAllUsesWith(affineApplyOp->getResult(0));
+ }
+ }
+ // Move the loop body statements to the loop's containing block.
auto *block = forStmt->getBlock();
block->getStatements().splice(StmtBlock::iterator(forStmt),
forStmt->getStatements());
// RUN: mlir-opt %s -o - -loop-unroll-jam -unroll-jam-factor=2 | FileCheck %s
// CHECK: #map0 = (d0) -> (d0 + 1)
+// This should be matched to M1, but M1 is defined later.
+// CHECK: {{#map[0-9]+}} = ()[s0] -> (s0 + 8)
// CHECK-LABEL: mlfunc @unroll_jam_imperfect_nest() {
mlfunc @unroll_jam_imperfect_nest() {
// CHECK-NEXT: %14 = "addi32"(%c100, %c100) : (affineint, affineint) -> i32
return
}
+
+// UNROLL-BY-4-LABEL: mlfunc @loop_nest_unknown_count_1(%arg0 : affineint) {
+mlfunc @loop_nest_unknown_count_1(%N : affineint) {
+ // UNROLL-BY-4-NEXT: for %i0 = 1 to #map{{[0-9]+}}()[%arg0] step 4 {
+ // UNROLL-BY-4-NEXT: for %i1 = 1 to 100 {
+ // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32
+ // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
+ // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
+ // UNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32
+ // UNROLL-BY-4-NEXT: }
+ // UNROLL-BY-4-NEXT: }
+ // A cleanup loop should be generated here.
+ // UNROLL-BY-4-NEXT: for %i2 = #map{{[0-9]+}}()[%arg0] to %arg0 {
+ // UNROLL-BY-4-NEXT: for %i3 = 1 to 100 {
+ // UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32
+ // UNROLL-BY-4_NEXT: }
+ // UNROLL-BY-4_NEXT: }
+ // Specify the lower bound in a form so that both lb and ub operands match.
+ for %i = ()[s0] -> (1)()[%N] to %N {
+ for %j = 1 to 100 {
+ %x = "foo"() : () -> i32
+ }
+ }
+ return
+}
+
+// UNROLL-BY-4-LABEL: mlfunc @loop_nest_unknown_count_2(%arg0 : affineint) {
+mlfunc @loop_nest_unknown_count_2(%arg : affineint) {
+ // UNROLL-BY-4-NEXT: for %i0 = %arg0 to #map{{[0-9]+}}()[%arg0] step 4 {
+ // UNROLL-BY-4-NEXT: for %i1 = 1 to 100 {
+ // UNROLL-BY-4-NEXT: %0 = "foo"(%i0) : (affineint) -> i32
+ // UNROLL-BY-4-NEXT: %1 = affine_apply #map{{[0-9]+}}(%i0)
+ // UNROLL-BY-4-NEXT: %2 = "foo"(%1) : (affineint) -> i32
+ // UNROLL-BY-4-NEXT: %3 = affine_apply #map{{[0-9]+}}(%i0)
+ // UNROLL-BY-4-NEXT: %4 = "foo"(%3) : (affineint) -> i32
+ // UNROLL-BY-4-NEXT: %5 = affine_apply #map{{[0-9]+}}(%i0)
+ // UNROLL-BY-4-NEXT: %6 = "foo"(%5) : (affineint) -> i32
+ // UNROLL-BY-4-NEXT: }
+ // UNROLL-BY-4-NEXT: }
+ // The cleanup loop is a single iteration one and is promoted.
+ // UNROLL-BY-4-NEXT: %7 = affine_apply [[M1:#map{{[0-9]+}}]]()[%arg0]
+ // UNROLL-BY-4-NEXT: for %i3 = 1 to 100 {
+ // UNROLL-BY-4-NEXT: %8 = "foo"() : () -> i32
+ // UNROLL-BY-4_NEXT: }
+ // Specify the lower bound in a form so that both lb and ub operands match.
+ for %i = ()[s0] -> (s0) ()[%arg] to ()[s0] -> (s0+8) ()[%arg] {
+ for %j = 1 to 100 {
+ %x = "foo"(%i) : (affineint) -> i32
+ }
+ }
+ return
+}
}
// Difference between loop bounds is constant, but not a multiple of unroll
-// factor. A cleanup loop is generated.
+// factor. The cleanup loop happens to be a single iteration one and is promoted.
// UNROLL-BY-4-LABEL: mlfunc @loop_nest_operand3() {
mlfunc @loop_nest_operand3() {
// UNROLL-BY-4: for %i0 = 1 to 100 step 2 {
// UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: }
- // UNROLL-BY-4-NEXT: for %i2 = #map{{[0-9]+}}(%i0) to #map{{[0-9]+}}(%i0) {
// UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32
- // UNROLL-BY-4-NEXT: }
- for %j = (d0) -> (d0) (%i) to (d0) -> (d0 + 4) (%i) {
+ for %j = (d0) -> (d0) (%i) to (d0) -> (d0 + 8) (%i) {
%x = "foo"() : () -> i32
}
} // UNROLL-BY-4: }
return
}
-// Will not be unrolled for now. TODO(bondhugula): handle this.
-// xUNROLL-BY-4-LABEL: mlfunc @loop_nest_operand4(%arg0 : affineint) {
+// UNROLL-BY-4-LABEL: mlfunc @loop_nest_operand4(%arg0 : affineint) {
mlfunc @loop_nest_operand4(%N : affineint) {
- // UNROLL-BY-4: for %i0 = 1 to 100 step 2 {
- for %i = 1 to 100 step 2 {
- // UNROLL-BY-4: for %i1 = 0 to %arg0 {
- // xUNROLL-BY-4: for %i1 = 0 to #map{{[0-9]+}}(%N) step 4 {
- // xUNROLL-BY-4: %0 = "foo"() : () -> i32
- // xUNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
- // xUNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
- // xUNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32
- // xUNROLL-BY-4-NEXT: }
- // a cleanup loop should be generated here.
- for %j = (d0) -> (0) (%N) to %N {
+ // UNROLL-BY-4: for %i0 = 1 to 100 {
+ for %i = 1 to 100 {
+ // UNROLL-BY-4: for %i1 = 1 to #map{{[0-9]+}}()[%arg0] step 4 {
+ // UNROLL-BY-4: %0 = "foo"() : () -> i32
+ // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
+ // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
+ // UNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32
+ // UNROLL-BY-4-NEXT: }
+ // A cleanup loop will be generated here.
+ // UNROLL-BY-4-NEXT: for %i2 = #map{{[0-9]+}}()[%arg0] to %arg0 {
+ // UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32
+ // UNROLL-BY-4_NEXT: }
+ // Specify the lower bound so that both lb and ub operands match.
+ for %j = ()[s0] -> (1)()[%N] to %N {
%x = "foo"() : () -> i32
}
}