+++ /dev/null
-//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements mlir::applyPatternsGreedily.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/StandardOps/StandardOps.h"
-#include "mlir/Transforms/PatternMatch.h"
-#include "llvm/ADT/DenseMap.h"
-using namespace mlir;
-
-namespace {
-class WorklistRewriter;
-
-/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
-/// applies the locally optimal patterns in a roughly "bottom up" way.
-class GreedyPatternRewriteDriver {
-public:
- explicit GreedyPatternRewriteDriver(OwningPatternList &&patterns)
- : matcher(std::move(patterns)) {
- worklist.reserve(64);
- }
-
- void simplifyFunction(Function *currentFunction, WorklistRewriter &rewriter);
-
- void addToWorklist(Operation *op) {
- worklistMap[op] = worklist.size();
- worklist.push_back(op);
- }
-
- Operation *popFromWorklist() {
- auto *op = worklist.back();
- worklist.pop_back();
-
- // This operation is no longer in the worklist, keep worklistMap up to date.
- if (op)
- worklistMap.erase(op);
- return op;
- }
-
- /// If the specified operation is in the worklist, remove it. If not, this is
- /// a no-op.
- void removeFromWorklist(Operation *op) {
- auto it = worklistMap.find(op);
- if (it != worklistMap.end()) {
- assert(worklist[it->second] == op && "malformed worklist data structure");
- worklist[it->second] = nullptr;
- }
- }
-
-private:
- /// The low-level pattern matcher.
- PatternMatcher matcher;
-
- /// The worklist for this transformation keeps track of the operations that
- /// need to be revisited, plus their index in the worklist. This allows us to
- /// efficiently remove operations from the worklist when they are removed even
- /// if they aren't the root of a pattern.
- std::vector<Operation *> worklist;
- DenseMap<Operation *, unsigned> worklistMap;
-
- /// As part of canonicalization, we move constants to the top of the entry
- /// block of the current function and de-duplicate them. This keeps track of
- /// constants we have done this for.
- DenseMap<std::pair<Attribute *, Type *>, Operation *> uniquedConstants;
-};
-}; // end anonymous namespace
-
-/// This is a listener object that updates our worklists and other data
-/// structures in response to operations being added and removed.
-namespace {
-class WorklistRewriter : public PatternRewriter {
-public:
- WorklistRewriter(GreedyPatternRewriteDriver &driver, MLIRContext *context)
- : PatternRewriter(context), driver(driver) {}
-
- virtual void setInsertionPoint(Operation *op) = 0;
-
- // If an operation is about to be removed, make sure it is not in our
- // worklist anymore because we'd get dangling references to it.
- void notifyOperationRemoved(Operation *op) override {
- driver.removeFromWorklist(op);
- }
-
- GreedyPatternRewriteDriver &driver;
-};
-
-} // end anonymous namespace
-
-void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
- WorklistRewriter &rewriter) {
- // These are scratch vectors used in the constant folding loop below.
- SmallVector<Attribute *, 8> operandConstants, resultConstants;
-
- while (!worklist.empty()) {
- auto *op = popFromWorklist();
-
- // Nulls get added to the worklist when operations are removed, ignore them.
- if (op == nullptr)
- continue;
-
- // If we have a constant op, unique it into the entry block.
- if (auto constant = op->dyn_cast<ConstantOp>()) {
- // If this constant is dead, remove it, being careful to keep
- // uniquedConstants up to date.
- if (constant->use_empty()) {
- auto it =
- uniquedConstants.find({constant->getValue(), constant->getType()});
- if (it != uniquedConstants.end() && it->second == op)
- uniquedConstants.erase(it);
- constant->erase();
- continue;
- }
-
- // Check to see if we already have a constant with this type and value:
- auto &entry = uniquedConstants[std::make_pair(constant->getValue(),
- constant->getType())];
- if (entry) {
- // If this constant is already our uniqued one, then leave it alone.
- if (entry == op)
- continue;
-
- // Otherwise replace this redundant constant with the uniqued one. We
- // know this is safe because we move constants to the top of the
- // function when they are uniqued, so we know they dominate all uses.
- constant->replaceAllUsesWith(entry->getResult(0));
- constant->erase();
- continue;
- }
-
- // If we have no entry, then we should unique this constant as the
- // canonical version. To ensure safe dominance, move the operation to the
- // top of the function.
- entry = op;
-
- // TODO: If we make terminators into Operations then we could turn this
- // into a nice Operation::moveBefore(Operation*) method. We just need the
- // guarantee that a block is non-empty.
- if (auto *cfgFunc = dyn_cast<CFGFunction>(currentFunction)) {
- auto &entryBB = cfgFunc->front();
- cast<OperationInst>(op)->moveBefore(&entryBB, entryBB.begin());
- } else {
- auto *mlFunc = cast<MLFunction>(currentFunction);
- cast<OperationStmt>(op)->moveBefore(mlFunc, mlFunc->begin());
- }
-
- continue;
- }
-
- // If the operation has no side effects, and no users, then it is trivially
- // dead - remove it.
- if (op->hasNoSideEffect() && op->use_empty()) {
- op->erase();
- continue;
- }
-
- // Check to see if any operands to the instruction is constant and whether
- // the operation knows how to constant fold itself.
- operandConstants.clear();
- for (auto *operand : op->getOperands()) {
- Attribute *operandCst = nullptr;
- if (auto *operandOp = operand->getDefiningOperation()) {
- if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
- operandCst = operandConstantOp->getValue();
- }
- operandConstants.push_back(operandCst);
- }
-
- // If constant folding was successful, create the result constants, RAUW the
- // operation and remove it.
- resultConstants.clear();
- if (!op->constantFold(operandConstants, resultConstants)) {
- rewriter.setInsertionPoint(op);
-
- for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
- auto *res = op->getResult(i);
- if (res->use_empty()) // ignore dead uses.
- continue;
-
- // If we already have a canonicalized version of this constant, just
- // reuse it. Otherwise create a new one.
- SSAValue *cstValue;
- auto it = uniquedConstants.find({resultConstants[i], res->getType()});
- if (it != uniquedConstants.end())
- cstValue = it->second->getResult(0);
- else
- cstValue = rewriter.create<ConstantOp>(
- op->getLoc(), resultConstants[i], res->getType());
- res->replaceAllUsesWith(cstValue);
- }
-
- assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
- op->erase();
- continue;
- }
-
- // If this is an associative binary operation with a constant on the LHS,
- // move it to the right side.
- if (operandConstants.size() == 2 && operandConstants[0] &&
- !operandConstants[1]) {
- auto *newLHS = op->getOperand(1);
- op->setOperand(1, op->getOperand(0));
- op->setOperand(0, newLHS);
- }
-
- // Check to see if we have any patterns that match this node.
- auto match = matcher.findMatch(op);
- if (!match.first)
- continue;
-
- // Make sure that any new operations are inserted at this point.
- rewriter.setInsertionPoint(op);
- match.first->rewrite(op, std::move(match.second), rewriter);
- }
-
- uniquedConstants.clear();
-}
-
-static void processMLFunction(MLFunction *fn, OwningPatternList &&patterns) {
- class MLFuncRewriter : public WorklistRewriter {
- public:
- MLFuncRewriter(GreedyPatternRewriteDriver &driver, MLFuncBuilder &builder)
- : WorklistRewriter(driver, builder.getContext()), builder(builder) {}
-
- // Implement the hook for creating operations, and make sure that newly
- // created ops are added to the worklist for processing.
- Operation *createOperation(const OperationState &state) override {
- auto *result = builder.createOperation(state);
- driver.addToWorklist(result);
- return result;
- }
-
- // When the root of a pattern is about to be replaced, it can trigger
- // simplifications to its users - make sure to add them to the worklist
- // before the root is changed.
- void notifyRootReplaced(Operation *op) override {
- auto *opStmt = cast<OperationStmt>(op);
- for (auto *result : opStmt->getResults())
- // TODO: Add a result->getUsers() iterator.
- for (auto &user : result->getUses()) {
- if (auto *op = dyn_cast<OperationStmt>(user.getOwner()))
- driver.addToWorklist(op);
- }
-
- // TODO: Walk the operand list dropping them as we go. If any of them
- // drop to zero uses, then add them to the worklist to allow them to be
- // deleted as dead.
- }
-
- void setInsertionPoint(Operation *op) override {
- // Any new operations should be added before this statement.
- builder.setInsertionPoint(cast<OperationStmt>(op));
- }
-
- private:
- MLFuncBuilder &builder;
- };
-
- GreedyPatternRewriteDriver driver(std::move(patterns));
- fn->walk([&](OperationStmt *stmt) { driver.addToWorklist(stmt); });
-
- MLFuncBuilder mlBuilder(fn);
- MLFuncRewriter rewriter(driver, mlBuilder);
- driver.simplifyFunction(fn, rewriter);
-}
-
-static void processCFGFunction(CFGFunction *fn, OwningPatternList &&patterns) {
- class CFGFuncRewriter : public WorklistRewriter {
- public:
- CFGFuncRewriter(GreedyPatternRewriteDriver &driver, CFGFuncBuilder &builder)
- : WorklistRewriter(driver, builder.getContext()), builder(builder) {}
-
- // Implement the hook for creating operations, and make sure that newly
- // created ops are added to the worklist for processing.
- Operation *createOperation(const OperationState &state) override {
- auto *result = builder.createOperation(state);
- driver.addToWorklist(result);
- return result;
- }
-
- // When the root of a pattern is about to be replaced, it can trigger
- // simplifications to its users - make sure to add them to the worklist
- // before the root is changed.
- void notifyRootReplaced(Operation *op) override {
- auto *opStmt = cast<OperationInst>(op);
- for (auto *result : opStmt->getResults())
- // TODO: Add a result->getUsers() iterator.
- for (auto &user : result->getUses()) {
- if (auto *op = dyn_cast<OperationInst>(user.getOwner()))
- driver.addToWorklist(op);
- }
-
- // TODO: Walk the operand list dropping them as we go. If any of them
- // drop to zero uses, then add them to the worklist to allow them to be
- // deleted as dead.
- }
-
- void setInsertionPoint(Operation *op) override {
- // Any new operations should be added before this instruction.
- builder.setInsertionPoint(cast<OperationInst>(op));
- }
-
- private:
- CFGFuncBuilder &builder;
- };
-
- GreedyPatternRewriteDriver driver(std::move(patterns));
- for (auto &bb : *fn)
- for (auto &op : bb)
- driver.addToWorklist(&op);
-
- CFGFuncBuilder cfgBuilder(fn);
- CFGFuncRewriter rewriter(driver, cfgBuilder);
- driver.simplifyFunction(fn, rewriter);
-}
-
-/// Rewrite the specified function by repeatedly applying the highest benefit
-/// patterns in a greedy work-list driven manner.
-///
-void mlir::applyPatternsGreedily(Function *fn, OwningPatternList &&patterns) {
- if (auto *cfg = dyn_cast<CFGFunction>(fn)) {
- processCFGFunction(cfg, std::move(patterns));
- } else {
- processMLFunction(cast<MLFunction>(fn), std::move(patterns));
- }
-}
+++ /dev/null
-//===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements miscellaneous loop transformation routines.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Transforms/LoopUtils.h"
-
-#include "mlir/Analysis/LoopAnalysis.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Statements.h"
-#include "mlir/IR/StmtVisitor.h"
-#include "mlir/StandardOps/StandardOps.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/Support/Debug.h"
-
-#define DEBUG_TYPE "LoopUtils"
-
-using namespace mlir;
-
-/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with
-/// the specified trip count, stride, and unroll factor. Returns nullptr when
-/// the trip count can't be expressed as an affine expression.
-AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
- unsigned unrollFactor,
- MLFuncBuilder *builder) {
- auto lbMap = forStmt.getLowerBoundMap();
-
- // Single result lower bound map only.
- if (lbMap.getNumResults() != 1)
- return AffineMap::Null();
-
- // Sometimes, the trip count cannot be expressed as an affine expression.
- auto tripCount = getTripCountExpr(forStmt);
- if (!tripCount)
- return AffineMap::Null();
-
- AffineExpr lb(lbMap.getResult(0));
- unsigned step = forStmt.getStep();
- auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step;
-
- return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
- {newUb}, {});
-}
-
-/// Returns the lower bound of the cleanup loop when unrolling a loop with lower
-/// bound 'lb' and with the specified trip count, stride, and unroll factor.
-/// Returns an AffinMap with nullptr storage (that evaluates to false)
-/// when the trip count can't be expressed as an affine expression.
-AffineMap mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
- unsigned unrollFactor,
- MLFuncBuilder *builder) {
- auto lbMap = forStmt.getLowerBoundMap();
-
- // Single result lower bound map only.
- if (lbMap.getNumResults() != 1)
- return AffineMap::Null();
-
- // Sometimes the trip count cannot be expressed as an affine expression.
- AffineExpr tripCount(getTripCountExpr(forStmt));
- if (!tripCount)
- return AffineMap::Null();
-
- AffineExpr lb(lbMap.getResult(0));
- unsigned step = forStmt.getStep();
- auto newLb = lb + (tripCount - tripCount % unrollFactor) * step;
- return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
- {newLb}, {});
-}
-
-/// Promotes the loop body of a forStmt to its containing block if the forStmt
-/// was known to have a single iteration. Returns false otherwise.
-// TODO(bondhugula): extend this for arbitrary affine bounds.
-bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
- Optional<uint64_t> tripCount = getConstantTripCount(*forStmt);
- if (!tripCount.hasValue() || tripCount.getValue() != 1)
- return false;
-
- // TODO(mlir-team): there is no builder for a max.
- if (forStmt->getLowerBoundMap().getNumResults() != 1)
- return false;
-
- // Replaces all IV uses to its single iteration value.
- if (!forStmt->use_empty()) {
- if (forStmt->hasConstantLowerBound()) {
- auto *mlFunc = forStmt->findFunction();
- MLFuncBuilder topBuilder(&mlFunc->front());
- auto constOp = topBuilder.create<ConstantIndexOp>(
- forStmt->getLoc(), forStmt->getConstantLowerBound());
- forStmt->replaceAllUsesWith(constOp);
- } else {
- const AffineBound lb = forStmt->getLowerBound();
- SmallVector<SSAValue *, 4> lbOperands(lb.operand_begin(),
- lb.operand_end());
- MLFuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt));
- auto affineApplyOp = builder.create<AffineApplyOp>(
- forStmt->getLoc(), lb.getMap(), lbOperands);
- forStmt->replaceAllUsesWith(affineApplyOp->getResult(0));
- }
- }
- // Move the loop body statements to the loop's containing block.
- auto *block = forStmt->getBlock();
- block->getStatements().splice(StmtBlock::iterator(forStmt),
- forStmt->getStatements());
- forStmt->erase();
- return true;
-}
-
-/// Promotes all single iteration for stmt's in the MLFunction, i.e., moves
-/// their body into the containing StmtBlock.
-void mlir::promoteSingleIterationLoops(MLFunction *f) {
- // Gathers all innermost loops through a post order pruned walk.
- class LoopBodyPromoter : public StmtWalker<LoopBodyPromoter> {
- public:
- void visitForStmt(ForStmt *forStmt) { promoteIfSingleIteration(forStmt); }
- };
-
- LoopBodyPromoter fsw;
- fsw.walkPostOrder(f);
-}
-
-/// Generates a for 'stmt' with the specified lower and upper bounds while
-/// generating the right IV remappings for the delayed statements. The
-/// statement blocks that go into the loop are specified in stmtGroupQueue
-/// starting from the specified offset, and in that order; the first element of
-/// the pair specifies the delay applied to that group of statements. Returns
-/// nullptr if the generated loop simplifies to a single iteration one.
-static ForStmt *
-generateLoop(AffineMap lb, AffineMap ub,
- const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>>
- &stmtGroupQueue,
- unsigned offset, ForStmt *srcForStmt, MLFuncBuilder *b) {
- SmallVector<MLValue *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
- SmallVector<MLValue *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
-
- auto *loopChunk =
- b->createFor(srcForStmt->getLoc(), lbOperands, lb, ubOperands, ub);
- OperationStmt::OperandMapTy operandMap;
-
- for (auto it = stmtGroupQueue.begin() + offset, e = stmtGroupQueue.end();
- it != e; ++it) {
- auto elt = *it;
- // All 'same delay' statements get added with the operands being remapped
- // (to results of cloned statements).
- // Generate the remapping if the delay is not zero: oldIV = newIV - delay.
- // TODO(bondhugula): check if srcForStmt is actually used in elt.second
- // instead of just checking if it's used at all.
- if (!srcForStmt->use_empty() && elt.first != 0) {
- auto b = MLFuncBuilder::getForStmtBodyBuilder(loopChunk);
- auto *oldIV =
- b.create<AffineApplyOp>(
- srcForStmt->getLoc(),
- b.getSingleDimShiftAffineMap(-static_cast<int64_t>(elt.first)),
- loopChunk)
- ->getResult(0);
- operandMap[srcForStmt] = cast<MLValue>(oldIV);
- } else {
- operandMap[srcForStmt] = static_cast<MLValue *>(loopChunk);
- }
- for (auto *stmt : elt.second) {
- loopChunk->push_back(stmt->clone(operandMap, b->getContext()));
- }
- }
- if (promoteIfSingleIteration(loopChunk))
- return nullptr;
- return loopChunk;
-}
-
-/// Skew the statements in the body of a 'for' statement with the specified
-/// statement-wise delays. The delays are with respect to the original execution
-/// order. A delay of zero for each statement will lead to no change.
-// The skewing of statements with respect to one another can be used for example
-// to allow overlap of asynchronous operations (such as DMA communication) with
-// computation, or just relative shifting of statements for better register
-// reuse, locality or parallelism. As such, the delays are typically expected to
-// be at most of the order of the number of statements. This method should not
-// be used as a substitute for loop distribution/fission.
-// This method uses an algorithm// in time linear in the number of statements in
-// the body of the for loop - (using the 'sweep line' paradigm). This method
-// asserts preservation of SSA dominance. A check for that as well as that for
-// memory-based depedence preservation check rests with the users of this
-// method.
-UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
- bool unrollPrologueEpilogue) {
- if (forStmt->getStatements().empty())
- return UtilResult::Success;
-
- // If the trip counts aren't constant, we would need versioning and
- // conditional guards (or context information to prevent such versioning). The
- // better way to pipeline for such loops is to first tile them and extract
- // constant trip count "full tiles" before applying this.
- auto mayBeConstTripCount = getConstantTripCount(*forStmt);
- if (!mayBeConstTripCount.hasValue()) {
- LLVM_DEBUG(llvm::dbgs() << "non-constant trip count loop\n";);
- return UtilResult::Success;
- }
- uint64_t tripCount = mayBeConstTripCount.getValue();
-
- assert(isStmtwiseShiftValid(*forStmt, delays) &&
- "shifts will lead to an invalid transformation\n");
-
- unsigned numChildStmts = forStmt->getStatements().size();
-
- // Do a linear time (counting) sort for the delays.
- uint64_t maxDelay = 0;
- for (unsigned i = 0; i < numChildStmts; i++) {
- maxDelay = std::max(maxDelay, delays[i]);
- }
- // Such large delays are not the typical use case.
- if (maxDelay >= numChildStmts) {
- LLVM_DEBUG(llvm::dbgs() << "stmt delays too large - unexpected\n";);
- return UtilResult::Success;
- }
-
- // An array of statement groups sorted by delay amount; each group has all
- // statements with the same delay in the order in which they appear in the
- // body of the 'for' stmt.
- std::vector<std::vector<Statement *>> sortedStmtGroups(maxDelay + 1);
- unsigned pos = 0;
- for (auto &stmt : *forStmt) {
- auto delay = delays[pos++];
- sortedStmtGroups[delay].push_back(&stmt);
- }
-
- // Unless the shifts have a specific pattern (which actually would be the
- // common use case), prologue and epilogue are not meaningfully defined.
- // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
- // loop generated as the prologue and the last as epilogue and unroll these
- // fully.
- ForStmt *prologue = nullptr;
- ForStmt *epilogue = nullptr;
-
- // Do a sweep over the sorted delays while storing open groups in a
- // vector, and generating loop portions as necessary during the sweep. A block
- // of statements is paired with its delay.
- std::vector<std::pair<uint64_t, ArrayRef<Statement *>>> stmtGroupQueue;
-
- auto origLbMap = forStmt->getLowerBoundMap();
- uint64_t lbDelay = 0;
- MLFuncBuilder b(forStmt);
- for (uint64_t d = 0, e = sortedStmtGroups.size(); d < e; ++d) {
- // If nothing is delayed by d, continue.
- if (sortedStmtGroups[d].empty())
- continue;
- if (!stmtGroupQueue.empty()) {
- assert(d >= 1 &&
- "Queue expected to be empty when the first block is found");
- // The interval for which the loop needs to be generated here is:
- // ( lbDelay, min(lbDelay + tripCount - 1, d - 1) ] and the body of the
- // loop needs to have all statements in stmtQueue in that order.
- ForStmt *res;
- if (lbDelay + tripCount - 1 < d - 1) {
- res = generateLoop(
- b.getShiftedAffineMap(origLbMap, lbDelay),
- b.getShiftedAffineMap(origLbMap, lbDelay + tripCount - 1),
- stmtGroupQueue, 0, forStmt, &b);
- // Entire loop for the queued stmt groups generated, empty it.
- stmtGroupQueue.clear();
- lbDelay += tripCount;
- } else {
- res = generateLoop(b.getShiftedAffineMap(origLbMap, lbDelay),
- b.getShiftedAffineMap(origLbMap, d - 1),
- stmtGroupQueue, 0, forStmt, &b);
- lbDelay = d;
- }
- if (!prologue && res)
- prologue = res;
- epilogue = res;
- } else {
- // Start of first interval.
- lbDelay = d;
- }
- // Augment the list of statements that get into the current open interval.
- stmtGroupQueue.push_back({d, sortedStmtGroups[d]});
- }
-
- // Those statements groups left in the queue now need to be processed (FIFO)
- // and their loops completed.
- for (unsigned i = 0, e = stmtGroupQueue.size(); i < e; ++i) {
- uint64_t ubDelay = stmtGroupQueue[i].first + tripCount - 1;
- epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbDelay),
- b.getShiftedAffineMap(origLbMap, ubDelay),
- stmtGroupQueue, i, forStmt, &b);
- lbDelay = ubDelay + 1;
- if (!prologue)
- prologue = epilogue;
- }
-
- // Erase the original for stmt.
- forStmt->erase();
-
- if (unrollPrologueEpilogue && prologue)
- loopUnrollFull(prologue);
- if (unrollPrologueEpilogue && !epilogue && epilogue != prologue)
- loopUnrollFull(epilogue);
-
- return UtilResult::Success;
-}
+++ /dev/null
-//===- Pass.cpp - Pass infrastructure implementation ----------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements common pass infrastructure.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Transforms/Pass.h"
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLFunction.h"
-#include "mlir/IR/Module.h"
-
-using namespace mlir;
-
-/// Function passes walk a module and look at each function with their
-/// corresponding hooks and terminates upon error encountered.
-PassResult FunctionPass::runOnModule(Module *m) {
- for (auto &fn : *m) {
- if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
- if (runOnMLFunction(mlFunc))
- return failure();
- if (auto *cfgFunc = dyn_cast<CFGFunction>(&fn))
- if (runOnCFGFunction(cfgFunc))
- return failure();
- }
- return success();
-}
+++ /dev/null
-//===- PatternMatch.cpp - Base classes for pattern match ------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/IR/SSAValue.h"
-#include "mlir/IR/Statements.h"
-#include "mlir/StandardOps/StandardOps.h"
-#include "mlir/Transforms/PatternMatch.h"
-using namespace mlir;
-
-PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
- assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
- "This pattern match benefit is too large to represent");
-}
-
-unsigned short PatternBenefit::getBenefit() const {
- assert(representation != ImpossibleToMatchSentinel &&
- "Pattern doesn't match");
- return representation;
-}
-
-bool PatternBenefit::operator==(const PatternBenefit& other) {
- if (isImpossibleToMatch())
- return other.isImpossibleToMatch();
- if (other.isImpossibleToMatch())
- return false;
- return getBenefit() == other.getBenefit();
-}
-
-bool PatternBenefit::operator!=(const PatternBenefit& other) {
- return !(*this == other);
-}
-
-//===----------------------------------------------------------------------===//
-// Pattern implementation
-//===----------------------------------------------------------------------===//
-
-Pattern::Pattern(StringRef rootName, MLIRContext *context,
- Optional<PatternBenefit> staticBenefit)
- : rootKind(OperationName(rootName, context)), staticBenefit(staticBenefit) {
-}
-
-Pattern::Pattern(StringRef rootName, MLIRContext *context,
- unsigned staticBenefit)
- : rootKind(rootName, context), staticBenefit(staticBenefit) {}
-
-Optional<PatternBenefit> Pattern::getStaticBenefit() const {
- return staticBenefit;
-}
-
-OperationName Pattern::getRootKind() const { return rootKind; }
-
-void Pattern::rewrite(Operation *op, std::unique_ptr<PatternState> state,
- PatternRewriter &rewriter) const {
- rewrite(op, rewriter);
-}
-
-void Pattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
- llvm_unreachable("need to implement one of the rewrite functions!");
-}
-
-/// This method indicates that no match was found.
-PatternMatchResult Pattern::matchFailure() {
- return {PatternBenefit::impossibleToMatch(), std::unique_ptr<PatternState>()};
-}
-
-/// This method indicates that a match was found and has the specified cost.
-PatternMatchResult
-Pattern::matchSuccess(PatternBenefit benefit,
- std::unique_ptr<PatternState> state) const {
- assert((!getStaticBenefit().hasValue() ||
- getStaticBenefit().getValue() == benefit) &&
- "This version of matchSuccess must be called with a benefit that "
- "matches the static benefit if set!");
-
- return {benefit, std::move(state)};
-}
-
-/// This method indicates that a match was found for patterns that have a
-/// known static benefit.
-PatternMatchResult
-Pattern::matchSuccess(std::unique_ptr<PatternState> state) const {
- auto benefit = getStaticBenefit();
- assert(benefit.hasValue() && "Pattern doesn't have a static benefit");
- return matchSuccess(benefit.getValue(), std::move(state));
-}
-
-//===----------------------------------------------------------------------===//
-// PatternRewriter implementation
-//===----------------------------------------------------------------------===//
-
-PatternRewriter::~PatternRewriter() {
- // Out of line to provide a vtable anchor for the class.
-}
-
-/// This method is used as the final replacement hook for patterns that match
-/// a single result value. In addition to replacing and removing the
-/// specified operation, clients can specify a list of other nodes that this
-/// replacement may make (perhaps transitively) dead. If any of those ops are
-/// dead, this will remove them as well.
-void PatternRewriter::replaceSingleResultOp(
- Operation *op, SSAValue *newValue, ArrayRef<SSAValue *> opsToRemoveIfDead) {
- // Notify the rewriter subclass that we're about to replace this root.
- notifyRootReplaced(op);
-
- assert(op->getNumResults() == 1 && "op isn't a SingleResultOp!");
- op->getResult(0)->replaceAllUsesWith(newValue);
-
- notifyOperationRemoved(op);
- op->erase();
-
- // TODO: Process the opsToRemoveIfDead list, removing things and calling the
- // notifyOperationRemoved hook in the process.
-}
-
-/// This method is used as the final notification hook for patterns that end
-/// up modifying the pattern root in place, by changing its operands. This is
-/// a minor efficiency win (it avoids creating a new instruction and removing
-/// the old one) but also often allows simpler code in the client.
-///
-/// The opsToRemoveIfDead list is an optional list of nodes that the rewriter
-/// should remove if they are dead at this point.
-///
-void PatternRewriter::updatedRootInPlace(
- Operation *op, ArrayRef<SSAValue *> opsToRemoveIfDead) {
- // Notify the rewriter subclass that we're about to replace this root.
- notifyRootUpdated(op);
-
- // TODO: Process the opsToRemoveIfDead list, removing things and calling the
- // notifyOperationRemoved hook in the process.
-}
-
-//===----------------------------------------------------------------------===//
-// PatternMatcher implementation
-//===----------------------------------------------------------------------===//
-
-/// Find the highest benefit pattern available in the pattern set for the DAG
-/// rooted at the specified node. This returns the pattern if found, or null
-/// if there are no matches.
-auto PatternMatcher::findMatch(Operation *op) -> MatchResult {
- // TODO: This is a completely trivial implementation, expand this in the
- // future.
-
- // Keep track of the best match, the benefit of it, and any matcher specific
- // state it is maintaining.
- MatchResult bestMatch = {nullptr, nullptr};
- Optional<PatternBenefit> bestBenefit;
-
- for (auto &pattern : patterns) {
- // Ignore patterns that are for the wrong root.
- if (pattern->getRootKind() != op->getName())
- continue;
-
- // If we know the static cost of the pattern is worse than what we've
- // already found then don't run it.
- auto staticBenefit = pattern->getStaticBenefit();
- if (staticBenefit.hasValue() && bestBenefit.hasValue() &&
- staticBenefit.getValue().getBenefit() <
- bestBenefit.getValue().getBenefit())
- continue;
-
- // Check to see if this pattern matches this node.
- auto result = pattern->match(op);
- auto benefit = result.first;
-
- // If this pattern failed to match, ignore it.
- if (benefit.isImpossibleToMatch())
- continue;
-
- // If it matched but had lower benefit than our best match so far, then
- // ignore it.
- if (bestBenefit.hasValue() &&
- benefit.getBenefit() < bestBenefit.getValue().getBenefit())
- continue;
-
- // Okay we found a match that is better than our previous one, remember it.
- bestBenefit = benefit;
- bestMatch = {pattern.get(), std::move(result.second)};
- }
-
- // If we found any match, return it.
- return bestMatch;
-}
+++ /dev/null
-//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements miscellaneous transformation routines for non-loop IR
-// structures.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Transforms/Utils.h"
-
-#include "mlir/Analysis/AffineAnalysis.h"
-#include "mlir/Analysis/AffineStructures.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/StandardOps/StandardOps.h"
-#include "mlir/Support/MathExtras.h"
-#include "llvm/ADT/DenseMap.h"
-
-using namespace mlir;
-
-/// Return true if this operation dereferences one or more memref's.
-// Temporary utility: will be replaced when this is modeled through
-// side-effects/op traits. TODO(b/117228571)
-static bool isMemRefDereferencingOp(const Operation &op) {
- if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
- op.isa<DmaWaitOp>())
- return true;
- return false;
-}
-
-/// Replaces all uses of oldMemRef with newMemRef while optionally remapping
-/// old memref's indices to the new memref using the supplied affine map
-/// and adding any additional indices. The new memref could be of a different
-/// shape or rank, but of the same elemental type. Additional indices are added
-/// at the start for now.
-// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
-// extended to add additional indices at any position.
-bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
- MLValue *newMemRef,
- ArrayRef<MLValue *> extraIndices,
- AffineMap indexRemap) {
- unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
- (void)newMemRefRank; // unused in opt mode
- unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
- (void)newMemRefRank;
- if (indexRemap) {
- assert(indexRemap.getNumInputs() == oldMemRefRank);
- assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
- } else {
- assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
- }
-
- // Assert same elemental type.
- assert(cast<MemRefType>(oldMemRef->getType())->getElementType() ==
- cast<MemRefType>(newMemRef->getType())->getElementType());
-
- // Check if memref was used in a non-deferencing context.
- for (const StmtOperand &use : oldMemRef->getUses()) {
- auto *opStmt = cast<OperationStmt>(use.getOwner());
- // Failure: memref used in a non-deferencing op (potentially escapes); no
- // replacement in these cases.
- if (!isMemRefDereferencingOp(*opStmt))
- return false;
- }
-
- // Walk all uses of old memref. Statement using the memref gets replaced.
- for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
- StmtOperand &use = *(it++);
- auto *opStmt = cast<OperationStmt>(use.getOwner());
- assert(isMemRefDereferencingOp(*opStmt) &&
- "memref deferencing op expected");
-
- auto getMemRefOperandPos = [&]() -> unsigned {
- unsigned i;
- for (i = 0; i < opStmt->getNumOperands(); i++) {
- if (opStmt->getOperand(i) == oldMemRef)
- break;
- }
- assert(i < opStmt->getNumOperands() && "operand guaranteed to be found");
- return i;
- };
- unsigned memRefOperandPos = getMemRefOperandPos();
-
- // Construct the new operation statement using this memref.
- SmallVector<MLValue *, 8> operands;
- operands.reserve(opStmt->getNumOperands() + extraIndices.size());
- // Insert the non-memref operands.
- operands.insert(operands.end(), opStmt->operand_begin(),
- opStmt->operand_begin() + memRefOperandPos);
- operands.push_back(newMemRef);
-
- MLFuncBuilder builder(opStmt);
- for (auto *extraIndex : extraIndices) {
- // TODO(mlir-team): An operation/SSA value should provide a method to
- // return the position of an SSA result in its defining
- // operation.
- assert(extraIndex->getDefiningStmt()->getNumResults() == 1 &&
- "single result op's expected to generate these indices");
- assert((cast<MLValue>(extraIndex)->isValidDim() ||
- cast<MLValue>(extraIndex)->isValidSymbol()) &&
- "invalid memory op index");
- operands.push_back(cast<MLValue>(extraIndex));
- }
-
- // Construct new indices. The indices of a memref come right after it, i.e.,
- // at position memRefOperandPos + 1.
- SmallVector<SSAValue *, 4> indices(
- opStmt->operand_begin() + memRefOperandPos + 1,
- opStmt->operand_begin() + memRefOperandPos + 1 + oldMemRefRank);
- if (indexRemap) {
- auto remapOp =
- builder.create<AffineApplyOp>(opStmt->getLoc(), indexRemap, indices);
- // Remapped indices.
- for (auto *index : remapOp->getOperation()->getResults())
- operands.push_back(cast<MLValue>(index));
- } else {
- // No remapping specified.
- for (auto *index : indices)
- operands.push_back(cast<MLValue>(index));
- }
-
- // Insert the remaining operands unmodified.
- operands.insert(operands.end(),
- opStmt->operand_begin() + memRefOperandPos + 1 +
- oldMemRefRank,
- opStmt->operand_end());
-
- // Result types don't change. Both memref's are of the same elemental type.
- SmallVector<Type *, 8> resultTypes;
- resultTypes.reserve(opStmt->getNumResults());
- for (const auto *result : opStmt->getResults())
- resultTypes.push_back(result->getType());
-
- // Create the new operation.
- auto *repOp =
- builder.createOperation(opStmt->getLoc(), opStmt->getName(), operands,
- resultTypes, opStmt->getAttrs());
- // Replace old memref's deferencing op's uses.
- unsigned r = 0;
- for (auto *res : opStmt->getResults()) {
- res->replaceAllUsesWith(repOp->getResult(r++));
- }
- opStmt->erase();
- }
- return true;
-}
-
-// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
-// its results equal to the number of 'operands, as a composition
-// of all other AffineApplyOps reachable from input parameter 'operands'. If the
-// operands were drawing results from multiple affine apply ops, this also leads
-// to a collapse into a single affine apply op. The final results of the
-// composed AffineApplyOp are returned in output parameter 'results'.
-OperationStmt *
-mlir::createComposedAffineApplyOp(MLFuncBuilder *builder, Location *loc,
- ArrayRef<MLValue *> operands,
- ArrayRef<OperationStmt *> affineApplyOps,
- SmallVectorImpl<SSAValue *> &results) {
- // Create identity map with same number of dimensions as number of operands.
- auto map = builder->getMultiDimIdentityMap(operands.size());
- // Initialize AffineValueMap with identity map.
- AffineValueMap valueMap(map, operands);
-
- for (auto *opStmt : affineApplyOps) {
- assert(opStmt->isa<AffineApplyOp>());
- auto affineApplyOp = opStmt->cast<AffineApplyOp>();
- // Forward substitute 'affineApplyOp' into 'valueMap'.
- valueMap.forwardSubstitute(*affineApplyOp);
- }
- // Compose affine maps from all ancestor AffineApplyOps.
- // Create new AffineApplyOp from 'valueMap'.
- unsigned numOperands = valueMap.getNumOperands();
- SmallVector<SSAValue *, 4> outOperands(numOperands);
- for (unsigned i = 0; i < numOperands; ++i) {
- outOperands[i] = valueMap.getOperand(i);
- }
- // Create new AffineApplyOp based on 'valueMap'.
- auto affineApplyOp =
- builder->create<AffineApplyOp>(loc, valueMap.getAffineMap(), outOperands);
- results.resize(operands.size());
- for (unsigned i = 0, e = operands.size(); i < e; ++i) {
- results[i] = affineApplyOp->getResult(i);
- }
- return cast<OperationStmt>(affineApplyOp->getOperation());
-}
-
-/// Given an operation statement, inserts a new single affine apply operation,
-/// that is exclusively used by this operation statement, and that provides all
-/// operands that are results of an affine_apply as a function of loop iterators
-/// and program parameters and whose results are.
-///
-/// Before
-///
-/// for %i = 0 to #map(%N)
-/// %idx = affine_apply (d0) -> (d0 mod 2) (%i)
-/// "send"(%idx, %A, ...)
-/// "compute"(%idx)
-///
-/// After
-///
-/// for %i = 0 to #map(%N)
-/// %idx = affine_apply (d0) -> (d0 mod 2) (%i)
-/// "send"(%idx, %A, ...)
-/// %idx_ = affine_apply (d0) -> (d0 mod 2) (%i)
-/// "compute"(%idx_)
-///
-/// This allows applying different transformations on send and compute (for eg.
-/// different shifts/delays).
-///
-/// Returns nullptr either if none of opStmt's operands were the result of an
-/// affine_apply and thus there was no affine computation slice to create, or if
-/// all the affine_apply op's supplying operands to this opStmt do not have any
-/// uses besides this opStmt. Returns the new affine_apply operation statement
-/// otherwise.
-OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
- // Collect all operands that are results of affine apply ops.
- SmallVector<MLValue *, 4> subOperands;
- subOperands.reserve(opStmt->getNumOperands());
- for (auto *operand : opStmt->getOperands()) {
- auto *defStmt = operand->getDefiningStmt();
- if (defStmt && defStmt->isa<AffineApplyOp>()) {
- subOperands.push_back(operand);
- }
- }
-
- // Gather sequence of AffineApplyOps reachable from 'subOperands'.
- SmallVector<OperationStmt *, 4> affineApplyOps;
- getReachableAffineApplyOps(subOperands, affineApplyOps);
- // Skip transforming if there are no affine maps to compose.
- if (affineApplyOps.empty())
- return nullptr;
-
- // Check if all uses of the affine apply op's lie in this op stmt
- // itself, in which case there would be nothing to do.
- bool localized = true;
- for (auto *op : affineApplyOps) {
- for (auto *result : op->getResults()) {
- for (auto &use : result->getUses()) {
- if (use.getOwner() != opStmt) {
- localized = false;
- break;
- }
- }
- }
- }
- if (localized)
- return nullptr;
-
- MLFuncBuilder builder(opStmt);
- SmallVector<SSAValue *, 4> results;
- auto *affineApplyStmt = createComposedAffineApplyOp(
- &builder, opStmt->getLoc(), subOperands, affineApplyOps, results);
- assert(results.size() == subOperands.size() &&
- "number of results should be the same as the number of subOperands");
-
- // Construct the new operands that include the results from the composed
- // affine apply op above instead of existing ones (subOperands). So, they
- // differ from opStmt's operands only for those operands in 'subOperands', for
- // which they will be replaced by the corresponding one from 'results'.
- SmallVector<MLValue *, 4> newOperands(opStmt->getOperands());
- for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
- // Replace the subOperands from among the new operands.
- unsigned j, f;
- for (j = 0, f = subOperands.size(); j < f; j++) {
- if (newOperands[i] == subOperands[j])
- break;
- }
- if (j < subOperands.size()) {
- newOperands[i] = cast<MLValue>(results[j]);
- }
- }
-
- for (unsigned idx = 0; idx < newOperands.size(); idx++) {
- opStmt->setOperand(idx, newOperands[idx]);
- }
-
- return affineApplyStmt;
-}
-
-void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
- if (affineApplyOp->getOperation()->getOperationFunction()->getKind() !=
- Function::Kind::MLFunc) {
- // TODO: Support forward substitution for CFGFunctions.
- return;
- }
- auto *opStmt = cast<OperationStmt>(affineApplyOp->getOperation());
- // Iterate through all uses of all results of 'opStmt', forward substituting
- // into any uses which are AffineApplyOps.
- for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e;
- ++resultIndex) {
- const MLValue *result = opStmt->getResult(resultIndex);
- for (auto it = result->use_begin(); it != result->use_end();) {
- StmtOperand &use = *(it++);
- auto *useStmt = use.getOwner();
- auto *useOpStmt = dyn_cast<OperationStmt>(useStmt);
- // Skip if use is not AffineApplyOp.
- if (useOpStmt == nullptr || !useOpStmt->isa<AffineApplyOp>())
- continue;
- // Advance iterator past 'opStmt' operands which also use 'result'.
- while (it != result->use_end() && it->getOwner() == useStmt)
- ++it;
-
- MLFuncBuilder builder(useOpStmt);
- // Initialize AffineValueMap with 'affineApplyOp' which uses 'result'.
- auto oldAffineApplyOp = useOpStmt->cast<AffineApplyOp>();
- AffineValueMap valueMap(*oldAffineApplyOp);
- // Forward substitute 'result' at index 'i' into 'valueMap'.
- valueMap.forwardSubstituteSingle(*affineApplyOp, resultIndex);
-
- // Create new AffineApplyOp from 'valueMap'.
- unsigned numOperands = valueMap.getNumOperands();
- SmallVector<SSAValue *, 4> operands(numOperands);
- for (unsigned i = 0; i < numOperands; ++i) {
- operands[i] = valueMap.getOperand(i);
- }
- auto newAffineApplyOp = builder.create<AffineApplyOp>(
- useOpStmt->getLoc(), valueMap.getAffineMap(), operands);
-
- // Update all uses to use results from 'newAffineApplyOp'.
- for (unsigned i = 0, e = useOpStmt->getNumResults(); i < e; ++i) {
- oldAffineApplyOp->getResult(i)->replaceAllUsesWith(
- newAffineApplyOp->getResult(i));
- }
- // Erase 'oldAffineApplyOp'.
- oldAffineApplyOp->getOperation()->erase();
- }
- }
-}
-
-/// Folds the specified (lower or upper) bound to a constant if possible
-/// considering its operands. Returns false if the folding happens for any of
-/// the bounds, true otherwise.
-bool mlir::constantFoldBounds(ForStmt *forStmt) {
- auto foldLowerOrUpperBound = [forStmt](bool lower) {
- // Check if the bound is already a constant.
- if (lower && forStmt->hasConstantLowerBound())
- return true;
- if (!lower && forStmt->hasConstantUpperBound())
- return true;
-
- // Check to see if each of the operands is the result of a constant. If so,
- // get the value. If not, ignore it.
- SmallVector<Attribute *, 8> operandConstants;
- auto boundOperands = lower ? forStmt->getLowerBoundOperands()
- : forStmt->getUpperBoundOperands();
- for (const auto *operand : boundOperands) {
- Attribute *operandCst = nullptr;
- if (auto *operandOp = operand->getDefiningOperation()) {
- if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
- operandCst = operandConstantOp->getValue();
- }
- operandConstants.push_back(operandCst);
- }
-
- AffineMap boundMap =
- lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap();
- assert(boundMap.getNumResults() >= 1 &&
- "bound maps should have at least one result");
- SmallVector<Attribute *, 4> foldedResults;
- if (boundMap.constantFold(operandConstants, foldedResults))
- return true;
-
- // Compute the max or min as applicable over the results.
- assert(!foldedResults.empty() && "bounds should have at least one result");
- auto maxOrMin = cast<IntegerAttr>(foldedResults[0])->getValue();
- for (unsigned i = 1; i < foldedResults.size(); i++) {
- auto foldedResult = cast<IntegerAttr>(foldedResults[i])->getValue();
- maxOrMin = lower ? std::max(maxOrMin, foldedResult)
- : std::min(maxOrMin, foldedResult);
- }
- lower ? forStmt->setConstantLowerBound(maxOrMin)
- : forStmt->setConstantUpperBound(maxOrMin);
-
- // Return false on success.
- return false;
- };
-
- bool ret = foldLowerOrUpperBound(/*lower=*/true);
- ret &= foldLowerOrUpperBound(/*lower=*/false);
- return ret;
-}
--- /dev/null
+//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements mlir::applyPatternsGreedily.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/StandardOps/StandardOps.h"
+#include "mlir/Transforms/PatternMatch.h"
+#include "llvm/ADT/DenseMap.h"
+using namespace mlir;
+
+namespace {
+class WorklistRewriter;
+
+/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
+/// applies the locally optimal patterns in a roughly "bottom up" way.
+class GreedyPatternRewriteDriver {
+public:
+ explicit GreedyPatternRewriteDriver(OwningPatternList &&patterns)
+ : matcher(std::move(patterns)) {
+ worklist.reserve(64);
+ }
+
+ void simplifyFunction(Function *currentFunction, WorklistRewriter &rewriter);
+
+ void addToWorklist(Operation *op) {
+ worklistMap[op] = worklist.size();
+ worklist.push_back(op);
+ }
+
+ Operation *popFromWorklist() {
+ auto *op = worklist.back();
+ worklist.pop_back();
+
+ // This operation is no longer in the worklist, keep worklistMap up to date.
+ if (op)
+ worklistMap.erase(op);
+ return op;
+ }
+
+ /// If the specified operation is in the worklist, remove it. If not, this is
+ /// a no-op.
+ void removeFromWorklist(Operation *op) {
+ auto it = worklistMap.find(op);
+ if (it != worklistMap.end()) {
+ assert(worklist[it->second] == op && "malformed worklist data structure");
+ worklist[it->second] = nullptr;
+ }
+ }
+
+private:
+ /// The low-level pattern matcher.
+ PatternMatcher matcher;
+
+ /// The worklist for this transformation keeps track of the operations that
+ /// need to be revisited, plus their index in the worklist. This allows us to
+ /// efficiently remove operations from the worklist when they are removed even
+ /// if they aren't the root of a pattern.
+ std::vector<Operation *> worklist;
+ DenseMap<Operation *, unsigned> worklistMap;
+
+ /// As part of canonicalization, we move constants to the top of the entry
+ /// block of the current function and de-duplicate them. This keeps track of
+ /// constants we have done this for.
+ DenseMap<std::pair<Attribute *, Type *>, Operation *> uniquedConstants;
+};
+}; // end anonymous namespace
+
+/// This is a listener object that updates our worklists and other data
+/// structures in response to operations being added and removed.
+namespace {
+class WorklistRewriter : public PatternRewriter {
+public:
+ WorklistRewriter(GreedyPatternRewriteDriver &driver, MLIRContext *context)
+ : PatternRewriter(context), driver(driver) {}
+
+ virtual void setInsertionPoint(Operation *op) = 0;
+
+ // If an operation is about to be removed, make sure it is not in our
+ // worklist anymore because we'd get dangling references to it.
+ void notifyOperationRemoved(Operation *op) override {
+ driver.removeFromWorklist(op);
+ }
+
+ GreedyPatternRewriteDriver &driver;
+};
+
+} // end anonymous namespace
+
+void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
+ WorklistRewriter &rewriter) {
+ // These are scratch vectors used in the constant folding loop below.
+ SmallVector<Attribute *, 8> operandConstants, resultConstants;
+
+ while (!worklist.empty()) {
+ auto *op = popFromWorklist();
+
+ // Nulls get added to the worklist when operations are removed, ignore them.
+ if (op == nullptr)
+ continue;
+
+ // If we have a constant op, unique it into the entry block.
+ if (auto constant = op->dyn_cast<ConstantOp>()) {
+ // If this constant is dead, remove it, being careful to keep
+ // uniquedConstants up to date.
+ if (constant->use_empty()) {
+ auto it =
+ uniquedConstants.find({constant->getValue(), constant->getType()});
+ if (it != uniquedConstants.end() && it->second == op)
+ uniquedConstants.erase(it);
+ constant->erase();
+ continue;
+ }
+
+ // Check to see if we already have a constant with this type and value:
+ auto &entry = uniquedConstants[std::make_pair(constant->getValue(),
+ constant->getType())];
+ if (entry) {
+ // If this constant is already our uniqued one, then leave it alone.
+ if (entry == op)
+ continue;
+
+ // Otherwise replace this redundant constant with the uniqued one. We
+ // know this is safe because we move constants to the top of the
+ // function when they are uniqued, so we know they dominate all uses.
+ constant->replaceAllUsesWith(entry->getResult(0));
+ constant->erase();
+ continue;
+ }
+
+ // If we have no entry, then we should unique this constant as the
+ // canonical version. To ensure safe dominance, move the operation to the
+ // top of the function.
+ entry = op;
+
+ // TODO: If we make terminators into Operations then we could turn this
+ // into a nice Operation::moveBefore(Operation*) method. We just need the
+ // guarantee that a block is non-empty.
+ if (auto *cfgFunc = dyn_cast<CFGFunction>(currentFunction)) {
+ auto &entryBB = cfgFunc->front();
+ cast<OperationInst>(op)->moveBefore(&entryBB, entryBB.begin());
+ } else {
+ auto *mlFunc = cast<MLFunction>(currentFunction);
+ cast<OperationStmt>(op)->moveBefore(mlFunc, mlFunc->begin());
+ }
+
+ continue;
+ }
+
+ // If the operation has no side effects, and no users, then it is trivially
+ // dead - remove it.
+ if (op->hasNoSideEffect() && op->use_empty()) {
+ op->erase();
+ continue;
+ }
+
+ // Check to see if any operands to the instruction is constant and whether
+ // the operation knows how to constant fold itself.
+ operandConstants.clear();
+ for (auto *operand : op->getOperands()) {
+ Attribute *operandCst = nullptr;
+ if (auto *operandOp = operand->getDefiningOperation()) {
+ if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
+ operandCst = operandConstantOp->getValue();
+ }
+ operandConstants.push_back(operandCst);
+ }
+
+ // If constant folding was successful, create the result constants, RAUW the
+ // operation and remove it.
+ resultConstants.clear();
+ if (!op->constantFold(operandConstants, resultConstants)) {
+ rewriter.setInsertionPoint(op);
+
+ for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
+ auto *res = op->getResult(i);
+ if (res->use_empty()) // ignore dead uses.
+ continue;
+
+ // If we already have a canonicalized version of this constant, just
+ // reuse it. Otherwise create a new one.
+ SSAValue *cstValue;
+ auto it = uniquedConstants.find({resultConstants[i], res->getType()});
+ if (it != uniquedConstants.end())
+ cstValue = it->second->getResult(0);
+ else
+ cstValue = rewriter.create<ConstantOp>(
+ op->getLoc(), resultConstants[i], res->getType());
+ res->replaceAllUsesWith(cstValue);
+ }
+
+ assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
+ op->erase();
+ continue;
+ }
+
+ // If this is an associative binary operation with a constant on the LHS,
+ // move it to the right side.
+ if (operandConstants.size() == 2 && operandConstants[0] &&
+ !operandConstants[1]) {
+ auto *newLHS = op->getOperand(1);
+ op->setOperand(1, op->getOperand(0));
+ op->setOperand(0, newLHS);
+ }
+
+ // Check to see if we have any patterns that match this node.
+ auto match = matcher.findMatch(op);
+ if (!match.first)
+ continue;
+
+ // Make sure that any new operations are inserted at this point.
+ rewriter.setInsertionPoint(op);
+ match.first->rewrite(op, std::move(match.second), rewriter);
+ }
+
+ uniquedConstants.clear();
+}
+
+static void processMLFunction(MLFunction *fn, OwningPatternList &&patterns) {
+ class MLFuncRewriter : public WorklistRewriter {
+ public:
+ MLFuncRewriter(GreedyPatternRewriteDriver &driver, MLFuncBuilder &builder)
+ : WorklistRewriter(driver, builder.getContext()), builder(builder) {}
+
+ // Implement the hook for creating operations, and make sure that newly
+ // created ops are added to the worklist for processing.
+ Operation *createOperation(const OperationState &state) override {
+ auto *result = builder.createOperation(state);
+ driver.addToWorklist(result);
+ return result;
+ }
+
+ // When the root of a pattern is about to be replaced, it can trigger
+ // simplifications to its users - make sure to add them to the worklist
+ // before the root is changed.
+ void notifyRootReplaced(Operation *op) override {
+ auto *opStmt = cast<OperationStmt>(op);
+ for (auto *result : opStmt->getResults())
+ // TODO: Add a result->getUsers() iterator.
+ for (auto &user : result->getUses()) {
+ if (auto *op = dyn_cast<OperationStmt>(user.getOwner()))
+ driver.addToWorklist(op);
+ }
+
+ // TODO: Walk the operand list dropping them as we go. If any of them
+ // drop to zero uses, then add them to the worklist to allow them to be
+ // deleted as dead.
+ }
+
+ void setInsertionPoint(Operation *op) override {
+ // Any new operations should be added before this statement.
+ builder.setInsertionPoint(cast<OperationStmt>(op));
+ }
+
+ private:
+ MLFuncBuilder &builder;
+ };
+
+ GreedyPatternRewriteDriver driver(std::move(patterns));
+ fn->walk([&](OperationStmt *stmt) { driver.addToWorklist(stmt); });
+
+ MLFuncBuilder mlBuilder(fn);
+ MLFuncRewriter rewriter(driver, mlBuilder);
+ driver.simplifyFunction(fn, rewriter);
+}
+
+static void processCFGFunction(CFGFunction *fn, OwningPatternList &&patterns) {
+ class CFGFuncRewriter : public WorklistRewriter {
+ public:
+ CFGFuncRewriter(GreedyPatternRewriteDriver &driver, CFGFuncBuilder &builder)
+ : WorklistRewriter(driver, builder.getContext()), builder(builder) {}
+
+ // Implement the hook for creating operations, and make sure that newly
+ // created ops are added to the worklist for processing.
+ Operation *createOperation(const OperationState &state) override {
+ auto *result = builder.createOperation(state);
+ driver.addToWorklist(result);
+ return result;
+ }
+
+ // When the root of a pattern is about to be replaced, it can trigger
+ // simplifications to its users - make sure to add them to the worklist
+ // before the root is changed.
+ void notifyRootReplaced(Operation *op) override {
+ auto *opStmt = cast<OperationInst>(op);
+ for (auto *result : opStmt->getResults())
+ // TODO: Add a result->getUsers() iterator.
+ for (auto &user : result->getUses()) {
+ if (auto *op = dyn_cast<OperationInst>(user.getOwner()))
+ driver.addToWorklist(op);
+ }
+
+ // TODO: Walk the operand list dropping them as we go. If any of them
+ // drop to zero uses, then add them to the worklist to allow them to be
+ // deleted as dead.
+ }
+
+ void setInsertionPoint(Operation *op) override {
+ // Any new operations should be added before this instruction.
+ builder.setInsertionPoint(cast<OperationInst>(op));
+ }
+
+ private:
+ CFGFuncBuilder &builder;
+ };
+
+ GreedyPatternRewriteDriver driver(std::move(patterns));
+ for (auto &bb : *fn)
+ for (auto &op : bb)
+ driver.addToWorklist(&op);
+
+ CFGFuncBuilder cfgBuilder(fn);
+ CFGFuncRewriter rewriter(driver, cfgBuilder);
+ driver.simplifyFunction(fn, rewriter);
+}
+
+/// Rewrite the specified function by repeatedly applying the highest benefit
+/// patterns in a greedy work-list driven manner.
+///
+void mlir::applyPatternsGreedily(Function *fn, OwningPatternList &&patterns) {
+ if (auto *cfg = dyn_cast<CFGFunction>(fn)) {
+ processCFGFunction(cfg, std::move(patterns));
+ } else {
+ processMLFunction(cast<MLFunction>(fn), std::move(patterns));
+ }
+}
--- /dev/null
+//===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements miscellaneous loop transformation routines.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/LoopUtils.h"
+
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Statements.h"
+#include "mlir/IR/StmtVisitor.h"
+#include "mlir/StandardOps/StandardOps.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "LoopUtils"
+
+using namespace mlir;
+
+/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with
+/// the specified trip count, stride, and unroll factor. Returns nullptr when
+/// the trip count can't be expressed as an affine expression.
+AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
+ unsigned unrollFactor,
+ MLFuncBuilder *builder) {
+ auto lbMap = forStmt.getLowerBoundMap();
+
+ // Single result lower bound map only.
+ if (lbMap.getNumResults() != 1)
+ return AffineMap::Null();
+
+ // Sometimes, the trip count cannot be expressed as an affine expression.
+ auto tripCount = getTripCountExpr(forStmt);
+ if (!tripCount)
+ return AffineMap::Null();
+
+ AffineExpr lb(lbMap.getResult(0));
+ unsigned step = forStmt.getStep();
+ auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step;
+
+ return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
+ {newUb}, {});
+}
+
+/// Returns the lower bound of the cleanup loop when unrolling a loop with lower
+/// bound 'lb' and with the specified trip count, stride, and unroll factor.
+/// Returns an AffinMap with nullptr storage (that evaluates to false)
+/// when the trip count can't be expressed as an affine expression.
+AffineMap mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
+ unsigned unrollFactor,
+ MLFuncBuilder *builder) {
+ auto lbMap = forStmt.getLowerBoundMap();
+
+ // Single result lower bound map only.
+ if (lbMap.getNumResults() != 1)
+ return AffineMap::Null();
+
+ // Sometimes the trip count cannot be expressed as an affine expression.
+ AffineExpr tripCount(getTripCountExpr(forStmt));
+ if (!tripCount)
+ return AffineMap::Null();
+
+ AffineExpr lb(lbMap.getResult(0));
+ unsigned step = forStmt.getStep();
+ auto newLb = lb + (tripCount - tripCount % unrollFactor) * step;
+ return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
+ {newLb}, {});
+}
+
+/// Promotes the loop body of a forStmt to its containing block if the forStmt
+/// was known to have a single iteration. Returns false otherwise.
+// TODO(bondhugula): extend this for arbitrary affine bounds.
+bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
+ Optional<uint64_t> tripCount = getConstantTripCount(*forStmt);
+ if (!tripCount.hasValue() || tripCount.getValue() != 1)
+ return false;
+
+ // TODO(mlir-team): there is no builder for a max.
+ if (forStmt->getLowerBoundMap().getNumResults() != 1)
+ return false;
+
+ // Replaces all IV uses to its single iteration value.
+ if (!forStmt->use_empty()) {
+ if (forStmt->hasConstantLowerBound()) {
+ auto *mlFunc = forStmt->findFunction();
+ MLFuncBuilder topBuilder(&mlFunc->front());
+ auto constOp = topBuilder.create<ConstantIndexOp>(
+ forStmt->getLoc(), forStmt->getConstantLowerBound());
+ forStmt->replaceAllUsesWith(constOp);
+ } else {
+ const AffineBound lb = forStmt->getLowerBound();
+ SmallVector<SSAValue *, 4> lbOperands(lb.operand_begin(),
+ lb.operand_end());
+ MLFuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt));
+ auto affineApplyOp = builder.create<AffineApplyOp>(
+ forStmt->getLoc(), lb.getMap(), lbOperands);
+ forStmt->replaceAllUsesWith(affineApplyOp->getResult(0));
+ }
+ }
+ // Move the loop body statements to the loop's containing block.
+ auto *block = forStmt->getBlock();
+ block->getStatements().splice(StmtBlock::iterator(forStmt),
+ forStmt->getStatements());
+ forStmt->erase();
+ return true;
+}
+
+/// Promotes all single iteration for stmt's in the MLFunction, i.e., moves
+/// their body into the containing StmtBlock.
+void mlir::promoteSingleIterationLoops(MLFunction *f) {
+ // Gathers all innermost loops through a post order pruned walk.
+ class LoopBodyPromoter : public StmtWalker<LoopBodyPromoter> {
+ public:
+ void visitForStmt(ForStmt *forStmt) { promoteIfSingleIteration(forStmt); }
+ };
+
+ LoopBodyPromoter fsw;
+ fsw.walkPostOrder(f);
+}
+
+/// Generates a for 'stmt' with the specified lower and upper bounds while
+/// generating the right IV remappings for the delayed statements. The
+/// statement blocks that go into the loop are specified in stmtGroupQueue
+/// starting from the specified offset, and in that order; the first element of
+/// the pair specifies the delay applied to that group of statements. Returns
+/// nullptr if the generated loop simplifies to a single iteration one.
+static ForStmt *
+generateLoop(AffineMap lb, AffineMap ub,
+ const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>>
+ &stmtGroupQueue,
+ unsigned offset, ForStmt *srcForStmt, MLFuncBuilder *b) {
+ SmallVector<MLValue *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
+ SmallVector<MLValue *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
+
+ auto *loopChunk =
+ b->createFor(srcForStmt->getLoc(), lbOperands, lb, ubOperands, ub);
+ OperationStmt::OperandMapTy operandMap;
+
+ for (auto it = stmtGroupQueue.begin() + offset, e = stmtGroupQueue.end();
+ it != e; ++it) {
+ auto elt = *it;
+ // All 'same delay' statements get added with the operands being remapped
+ // (to results of cloned statements).
+ // Generate the remapping if the delay is not zero: oldIV = newIV - delay.
+ // TODO(bondhugula): check if srcForStmt is actually used in elt.second
+ // instead of just checking if it's used at all.
+ if (!srcForStmt->use_empty() && elt.first != 0) {
+ auto b = MLFuncBuilder::getForStmtBodyBuilder(loopChunk);
+ auto *oldIV =
+ b.create<AffineApplyOp>(
+ srcForStmt->getLoc(),
+ b.getSingleDimShiftAffineMap(-static_cast<int64_t>(elt.first)),
+ loopChunk)
+ ->getResult(0);
+ operandMap[srcForStmt] = cast<MLValue>(oldIV);
+ } else {
+ operandMap[srcForStmt] = static_cast<MLValue *>(loopChunk);
+ }
+ for (auto *stmt : elt.second) {
+ loopChunk->push_back(stmt->clone(operandMap, b->getContext()));
+ }
+ }
+ if (promoteIfSingleIteration(loopChunk))
+ return nullptr;
+ return loopChunk;
+}
+
+/// Skew the statements in the body of a 'for' statement with the specified
+/// statement-wise delays. The delays are with respect to the original execution
+/// order. A delay of zero for each statement will lead to no change.
+// The skewing of statements with respect to one another can be used for example
+// to allow overlap of asynchronous operations (such as DMA communication) with
+// computation, or just relative shifting of statements for better register
+// reuse, locality or parallelism. As such, the delays are typically expected to
+// be at most of the order of the number of statements. This method should not
+// be used as a substitute for loop distribution/fission.
+// This method uses an algorithm// in time linear in the number of statements in
+// the body of the for loop - (using the 'sweep line' paradigm). This method
+// asserts preservation of SSA dominance. A check for that as well as that for
+// memory-based depedence preservation check rests with the users of this
+// method.
+UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
+ bool unrollPrologueEpilogue) {
+ if (forStmt->getStatements().empty())
+ return UtilResult::Success;
+
+ // If the trip counts aren't constant, we would need versioning and
+ // conditional guards (or context information to prevent such versioning). The
+ // better way to pipeline for such loops is to first tile them and extract
+ // constant trip count "full tiles" before applying this.
+ auto mayBeConstTripCount = getConstantTripCount(*forStmt);
+ if (!mayBeConstTripCount.hasValue()) {
+ LLVM_DEBUG(llvm::dbgs() << "non-constant trip count loop\n";);
+ return UtilResult::Success;
+ }
+ uint64_t tripCount = mayBeConstTripCount.getValue();
+
+ assert(isStmtwiseShiftValid(*forStmt, delays) &&
+ "shifts will lead to an invalid transformation\n");
+
+ unsigned numChildStmts = forStmt->getStatements().size();
+
+ // Do a linear time (counting) sort for the delays.
+ uint64_t maxDelay = 0;
+ for (unsigned i = 0; i < numChildStmts; i++) {
+ maxDelay = std::max(maxDelay, delays[i]);
+ }
+ // Such large delays are not the typical use case.
+ if (maxDelay >= numChildStmts) {
+ LLVM_DEBUG(llvm::dbgs() << "stmt delays too large - unexpected\n";);
+ return UtilResult::Success;
+ }
+
+ // An array of statement groups sorted by delay amount; each group has all
+ // statements with the same delay in the order in which they appear in the
+ // body of the 'for' stmt.
+ std::vector<std::vector<Statement *>> sortedStmtGroups(maxDelay + 1);
+ unsigned pos = 0;
+ for (auto &stmt : *forStmt) {
+ auto delay = delays[pos++];
+ sortedStmtGroups[delay].push_back(&stmt);
+ }
+
+ // Unless the shifts have a specific pattern (which actually would be the
+ // common use case), prologue and epilogue are not meaningfully defined.
+ // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
+ // loop generated as the prologue and the last as epilogue and unroll these
+ // fully.
+ ForStmt *prologue = nullptr;
+ ForStmt *epilogue = nullptr;
+
+ // Do a sweep over the sorted delays while storing open groups in a
+ // vector, and generating loop portions as necessary during the sweep. A block
+ // of statements is paired with its delay.
+ std::vector<std::pair<uint64_t, ArrayRef<Statement *>>> stmtGroupQueue;
+
+ auto origLbMap = forStmt->getLowerBoundMap();
+ uint64_t lbDelay = 0;
+ MLFuncBuilder b(forStmt);
+ for (uint64_t d = 0, e = sortedStmtGroups.size(); d < e; ++d) {
+ // If nothing is delayed by d, continue.
+ if (sortedStmtGroups[d].empty())
+ continue;
+ if (!stmtGroupQueue.empty()) {
+ assert(d >= 1 &&
+ "Queue expected to be empty when the first block is found");
+ // The interval for which the loop needs to be generated here is:
+ // ( lbDelay, min(lbDelay + tripCount - 1, d - 1) ] and the body of the
+ // loop needs to have all statements in stmtQueue in that order.
+ ForStmt *res;
+ if (lbDelay + tripCount - 1 < d - 1) {
+ res = generateLoop(
+ b.getShiftedAffineMap(origLbMap, lbDelay),
+ b.getShiftedAffineMap(origLbMap, lbDelay + tripCount - 1),
+ stmtGroupQueue, 0, forStmt, &b);
+ // Entire loop for the queued stmt groups generated, empty it.
+ stmtGroupQueue.clear();
+ lbDelay += tripCount;
+ } else {
+ res = generateLoop(b.getShiftedAffineMap(origLbMap, lbDelay),
+ b.getShiftedAffineMap(origLbMap, d - 1),
+ stmtGroupQueue, 0, forStmt, &b);
+ lbDelay = d;
+ }
+ if (!prologue && res)
+ prologue = res;
+ epilogue = res;
+ } else {
+ // Start of first interval.
+ lbDelay = d;
+ }
+ // Augment the list of statements that get into the current open interval.
+ stmtGroupQueue.push_back({d, sortedStmtGroups[d]});
+ }
+
+ // Those statements groups left in the queue now need to be processed (FIFO)
+ // and their loops completed.
+ for (unsigned i = 0, e = stmtGroupQueue.size(); i < e; ++i) {
+ uint64_t ubDelay = stmtGroupQueue[i].first + tripCount - 1;
+ epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbDelay),
+ b.getShiftedAffineMap(origLbMap, ubDelay),
+ stmtGroupQueue, i, forStmt, &b);
+ lbDelay = ubDelay + 1;
+ if (!prologue)
+ prologue = epilogue;
+ }
+
+ // Erase the original for stmt.
+ forStmt->erase();
+
+ if (unrollPrologueEpilogue && prologue)
+ loopUnrollFull(prologue);
+ if (unrollPrologueEpilogue && !epilogue && epilogue != prologue)
+ loopUnrollFull(epilogue);
+
+ return UtilResult::Success;
+}
--- /dev/null
+//===- Pass.cpp - Pass infrastructure implementation ----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements common pass infrastructure.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Pass.h"
+#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Module.h"
+
+using namespace mlir;
+
+/// Function passes walk a module and look at each function with their
+/// corresponding hooks and terminates upon error encountered.
+PassResult FunctionPass::runOnModule(Module *m) {
+ for (auto &fn : *m) {
+ if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
+ if (runOnMLFunction(mlFunc))
+ return failure();
+ if (auto *cfgFunc = dyn_cast<CFGFunction>(&fn))
+ if (runOnCFGFunction(cfgFunc))
+ return failure();
+ }
+ return success();
+}
--- /dev/null
+//===- PatternMatch.cpp - Base classes for pattern match ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/SSAValue.h"
+#include "mlir/IR/Statements.h"
+#include "mlir/StandardOps/StandardOps.h"
+#include "mlir/Transforms/PatternMatch.h"
+using namespace mlir;
+
+PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
+ assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
+ "This pattern match benefit is too large to represent");
+}
+
+unsigned short PatternBenefit::getBenefit() const {
+ assert(representation != ImpossibleToMatchSentinel &&
+ "Pattern doesn't match");
+ return representation;
+}
+
+bool PatternBenefit::operator==(const PatternBenefit& other) {
+ if (isImpossibleToMatch())
+ return other.isImpossibleToMatch();
+ if (other.isImpossibleToMatch())
+ return false;
+ return getBenefit() == other.getBenefit();
+}
+
+bool PatternBenefit::operator!=(const PatternBenefit& other) {
+ return !(*this == other);
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern implementation
+//===----------------------------------------------------------------------===//
+
+Pattern::Pattern(StringRef rootName, MLIRContext *context,
+ Optional<PatternBenefit> staticBenefit)
+ : rootKind(OperationName(rootName, context)), staticBenefit(staticBenefit) {
+}
+
+Pattern::Pattern(StringRef rootName, MLIRContext *context,
+ unsigned staticBenefit)
+ : rootKind(rootName, context), staticBenefit(staticBenefit) {}
+
+Optional<PatternBenefit> Pattern::getStaticBenefit() const {
+ return staticBenefit;
+}
+
+OperationName Pattern::getRootKind() const { return rootKind; }
+
+void Pattern::rewrite(Operation *op, std::unique_ptr<PatternState> state,
+ PatternRewriter &rewriter) const {
+ rewrite(op, rewriter);
+}
+
+void Pattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
+ llvm_unreachable("need to implement one of the rewrite functions!");
+}
+
+/// This method indicates that no match was found.
+PatternMatchResult Pattern::matchFailure() {
+ return {PatternBenefit::impossibleToMatch(), std::unique_ptr<PatternState>()};
+}
+
+/// This method indicates that a match was found and has the specified cost.
+PatternMatchResult
+Pattern::matchSuccess(PatternBenefit benefit,
+ std::unique_ptr<PatternState> state) const {
+ assert((!getStaticBenefit().hasValue() ||
+ getStaticBenefit().getValue() == benefit) &&
+ "This version of matchSuccess must be called with a benefit that "
+ "matches the static benefit if set!");
+
+ return {benefit, std::move(state)};
+}
+
+/// This method indicates that a match was found for patterns that have a
+/// known static benefit.
+PatternMatchResult
+Pattern::matchSuccess(std::unique_ptr<PatternState> state) const {
+ auto benefit = getStaticBenefit();
+ assert(benefit.hasValue() && "Pattern doesn't have a static benefit");
+ return matchSuccess(benefit.getValue(), std::move(state));
+}
+
+//===----------------------------------------------------------------------===//
+// PatternRewriter implementation
+//===----------------------------------------------------------------------===//
+
+PatternRewriter::~PatternRewriter() {
+ // Out of line to provide a vtable anchor for the class.
+}
+
+/// This method is used as the final replacement hook for patterns that match
+/// a single result value. In addition to replacing and removing the
+/// specified operation, clients can specify a list of other nodes that this
+/// replacement may make (perhaps transitively) dead. If any of those ops are
+/// dead, this will remove them as well.
+void PatternRewriter::replaceSingleResultOp(
+ Operation *op, SSAValue *newValue, ArrayRef<SSAValue *> opsToRemoveIfDead) {
+ // Notify the rewriter subclass that we're about to replace this root.
+ notifyRootReplaced(op);
+
+ assert(op->getNumResults() == 1 && "op isn't a SingleResultOp!");
+ op->getResult(0)->replaceAllUsesWith(newValue);
+
+ notifyOperationRemoved(op);
+ op->erase();
+
+ // TODO: Process the opsToRemoveIfDead list, removing things and calling the
+ // notifyOperationRemoved hook in the process.
+}
+
+/// This method is used as the final notification hook for patterns that end
+/// up modifying the pattern root in place, by changing its operands. This is
+/// a minor efficiency win (it avoids creating a new instruction and removing
+/// the old one) but also often allows simpler code in the client.
+///
+/// The opsToRemoveIfDead list is an optional list of nodes that the rewriter
+/// should remove if they are dead at this point.
+///
+void PatternRewriter::updatedRootInPlace(
+ Operation *op, ArrayRef<SSAValue *> opsToRemoveIfDead) {
+ // Notify the rewriter subclass that we're about to replace this root.
+ notifyRootUpdated(op);
+
+ // TODO: Process the opsToRemoveIfDead list, removing things and calling the
+ // notifyOperationRemoved hook in the process.
+}
+
+//===----------------------------------------------------------------------===//
+// PatternMatcher implementation
+//===----------------------------------------------------------------------===//
+
+/// Find the highest benefit pattern available in the pattern set for the DAG
+/// rooted at the specified node. This returns the pattern if found, or null
+/// if there are no matches.
+auto PatternMatcher::findMatch(Operation *op) -> MatchResult {
+ // TODO: This is a completely trivial implementation, expand this in the
+ // future.
+
+ // Keep track of the best match, the benefit of it, and any matcher specific
+ // state it is maintaining.
+ MatchResult bestMatch = {nullptr, nullptr};
+ Optional<PatternBenefit> bestBenefit;
+
+ for (auto &pattern : patterns) {
+ // Ignore patterns that are for the wrong root.
+ if (pattern->getRootKind() != op->getName())
+ continue;
+
+ // If we know the static cost of the pattern is worse than what we've
+ // already found then don't run it.
+ auto staticBenefit = pattern->getStaticBenefit();
+ if (staticBenefit.hasValue() && bestBenefit.hasValue() &&
+ staticBenefit.getValue().getBenefit() <
+ bestBenefit.getValue().getBenefit())
+ continue;
+
+ // Check to see if this pattern matches this node.
+ auto result = pattern->match(op);
+ auto benefit = result.first;
+
+ // If this pattern failed to match, ignore it.
+ if (benefit.isImpossibleToMatch())
+ continue;
+
+ // If it matched but had lower benefit than our best match so far, then
+ // ignore it.
+ if (bestBenefit.hasValue() &&
+ benefit.getBenefit() < bestBenefit.getValue().getBenefit())
+ continue;
+
+ // Okay we found a match that is better than our previous one, remember it.
+ bestBenefit = benefit;
+ bestMatch = {pattern.get(), std::move(result.second)};
+ }
+
+ // If we found any match, return it.
+ return bestMatch;
+}
--- /dev/null
+//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements miscellaneous transformation routines for non-loop IR
+// structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Utils.h"
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/StandardOps/StandardOps.h"
+#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/DenseMap.h"
+
+using namespace mlir;
+
+/// Return true if this operation dereferences one or more memref's.
+// Temporary utility: will be replaced when this is modeled through
+// side-effects/op traits. TODO(b/117228571)
+static bool isMemRefDereferencingOp(const Operation &op) {
+ if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
+ op.isa<DmaWaitOp>())
+ return true;
+ return false;
+}
+
+/// Replaces all uses of oldMemRef with newMemRef while optionally remapping
+/// old memref's indices to the new memref using the supplied affine map
+/// and adding any additional indices. The new memref could be of a different
+/// shape or rank, but of the same elemental type. Additional indices are added
+/// at the start for now.
+// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
+// extended to add additional indices at any position.
+bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
+ MLValue *newMemRef,
+ ArrayRef<MLValue *> extraIndices,
+ AffineMap indexRemap) {
+ unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
+ (void)newMemRefRank; // unused in opt mode
+ unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
+ (void)newMemRefRank;
+ if (indexRemap) {
+ assert(indexRemap.getNumInputs() == oldMemRefRank);
+ assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
+ } else {
+ assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
+ }
+
+ // Assert same elemental type.
+ assert(cast<MemRefType>(oldMemRef->getType())->getElementType() ==
+ cast<MemRefType>(newMemRef->getType())->getElementType());
+
+ // Check if memref was used in a non-deferencing context.
+ for (const StmtOperand &use : oldMemRef->getUses()) {
+ auto *opStmt = cast<OperationStmt>(use.getOwner());
+ // Failure: memref used in a non-deferencing op (potentially escapes); no
+ // replacement in these cases.
+ if (!isMemRefDereferencingOp(*opStmt))
+ return false;
+ }
+
+ // Walk all uses of old memref. Statement using the memref gets replaced.
+ for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
+ StmtOperand &use = *(it++);
+ auto *opStmt = cast<OperationStmt>(use.getOwner());
+ assert(isMemRefDereferencingOp(*opStmt) &&
+ "memref deferencing op expected");
+
+ auto getMemRefOperandPos = [&]() -> unsigned {
+ unsigned i;
+ for (i = 0; i < opStmt->getNumOperands(); i++) {
+ if (opStmt->getOperand(i) == oldMemRef)
+ break;
+ }
+ assert(i < opStmt->getNumOperands() && "operand guaranteed to be found");
+ return i;
+ };
+ unsigned memRefOperandPos = getMemRefOperandPos();
+
+ // Construct the new operation statement using this memref.
+ SmallVector<MLValue *, 8> operands;
+ operands.reserve(opStmt->getNumOperands() + extraIndices.size());
+ // Insert the non-memref operands.
+ operands.insert(operands.end(), opStmt->operand_begin(),
+ opStmt->operand_begin() + memRefOperandPos);
+ operands.push_back(newMemRef);
+
+ MLFuncBuilder builder(opStmt);
+ for (auto *extraIndex : extraIndices) {
+ // TODO(mlir-team): An operation/SSA value should provide a method to
+ // return the position of an SSA result in its defining
+ // operation.
+ assert(extraIndex->getDefiningStmt()->getNumResults() == 1 &&
+ "single result op's expected to generate these indices");
+ assert((cast<MLValue>(extraIndex)->isValidDim() ||
+ cast<MLValue>(extraIndex)->isValidSymbol()) &&
+ "invalid memory op index");
+ operands.push_back(cast<MLValue>(extraIndex));
+ }
+
+ // Construct new indices. The indices of a memref come right after it, i.e.,
+ // at position memRefOperandPos + 1.
+ SmallVector<SSAValue *, 4> indices(
+ opStmt->operand_begin() + memRefOperandPos + 1,
+ opStmt->operand_begin() + memRefOperandPos + 1 + oldMemRefRank);
+ if (indexRemap) {
+ auto remapOp =
+ builder.create<AffineApplyOp>(opStmt->getLoc(), indexRemap, indices);
+ // Remapped indices.
+ for (auto *index : remapOp->getOperation()->getResults())
+ operands.push_back(cast<MLValue>(index));
+ } else {
+ // No remapping specified.
+ for (auto *index : indices)
+ operands.push_back(cast<MLValue>(index));
+ }
+
+ // Insert the remaining operands unmodified.
+ operands.insert(operands.end(),
+ opStmt->operand_begin() + memRefOperandPos + 1 +
+ oldMemRefRank,
+ opStmt->operand_end());
+
+ // Result types don't change. Both memref's are of the same elemental type.
+ SmallVector<Type *, 8> resultTypes;
+ resultTypes.reserve(opStmt->getNumResults());
+ for (const auto *result : opStmt->getResults())
+ resultTypes.push_back(result->getType());
+
+ // Create the new operation.
+ auto *repOp =
+ builder.createOperation(opStmt->getLoc(), opStmt->getName(), operands,
+ resultTypes, opStmt->getAttrs());
+ // Replace old memref's deferencing op's uses.
+ unsigned r = 0;
+ for (auto *res : opStmt->getResults()) {
+ res->replaceAllUsesWith(repOp->getResult(r++));
+ }
+ opStmt->erase();
+ }
+ return true;
+}
+
+// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
+// its results equal to the number of 'operands, as a composition
+// of all other AffineApplyOps reachable from input parameter 'operands'. If the
+// operands were drawing results from multiple affine apply ops, this also leads
+// to a collapse into a single affine apply op. The final results of the
+// composed AffineApplyOp are returned in output parameter 'results'.
+OperationStmt *
+mlir::createComposedAffineApplyOp(MLFuncBuilder *builder, Location *loc,
+ ArrayRef<MLValue *> operands,
+ ArrayRef<OperationStmt *> affineApplyOps,
+ SmallVectorImpl<SSAValue *> &results) {
+ // Create identity map with same number of dimensions as number of operands.
+ auto map = builder->getMultiDimIdentityMap(operands.size());
+ // Initialize AffineValueMap with identity map.
+ AffineValueMap valueMap(map, operands);
+
+ for (auto *opStmt : affineApplyOps) {
+ assert(opStmt->isa<AffineApplyOp>());
+ auto affineApplyOp = opStmt->cast<AffineApplyOp>();
+ // Forward substitute 'affineApplyOp' into 'valueMap'.
+ valueMap.forwardSubstitute(*affineApplyOp);
+ }
+ // Compose affine maps from all ancestor AffineApplyOps.
+ // Create new AffineApplyOp from 'valueMap'.
+ unsigned numOperands = valueMap.getNumOperands();
+ SmallVector<SSAValue *, 4> outOperands(numOperands);
+ for (unsigned i = 0; i < numOperands; ++i) {
+ outOperands[i] = valueMap.getOperand(i);
+ }
+ // Create new AffineApplyOp based on 'valueMap'.
+ auto affineApplyOp =
+ builder->create<AffineApplyOp>(loc, valueMap.getAffineMap(), outOperands);
+ results.resize(operands.size());
+ for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+ results[i] = affineApplyOp->getResult(i);
+ }
+ return cast<OperationStmt>(affineApplyOp->getOperation());
+}
+
+/// Given an operation statement, inserts a new single affine apply operation,
+/// that is exclusively used by this operation statement, and that provides all
+/// operands that are results of an affine_apply as a function of loop iterators
+/// and program parameters and whose results are.
+///
+/// Before
+///
+/// for %i = 0 to #map(%N)
+/// %idx = affine_apply (d0) -> (d0 mod 2) (%i)
+/// "send"(%idx, %A, ...)
+/// "compute"(%idx)
+///
+/// After
+///
+/// for %i = 0 to #map(%N)
+/// %idx = affine_apply (d0) -> (d0 mod 2) (%i)
+/// "send"(%idx, %A, ...)
+/// %idx_ = affine_apply (d0) -> (d0 mod 2) (%i)
+/// "compute"(%idx_)
+///
+/// This allows applying different transformations on send and compute (for eg.
+/// different shifts/delays).
+///
+/// Returns nullptr either if none of opStmt's operands were the result of an
+/// affine_apply and thus there was no affine computation slice to create, or if
+/// all the affine_apply op's supplying operands to this opStmt do not have any
+/// uses besides this opStmt. Returns the new affine_apply operation statement
+/// otherwise.
+OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
+ // Collect all operands that are results of affine apply ops.
+ SmallVector<MLValue *, 4> subOperands;
+ subOperands.reserve(opStmt->getNumOperands());
+ for (auto *operand : opStmt->getOperands()) {
+ auto *defStmt = operand->getDefiningStmt();
+ if (defStmt && defStmt->isa<AffineApplyOp>()) {
+ subOperands.push_back(operand);
+ }
+ }
+
+ // Gather sequence of AffineApplyOps reachable from 'subOperands'.
+ SmallVector<OperationStmt *, 4> affineApplyOps;
+ getReachableAffineApplyOps(subOperands, affineApplyOps);
+ // Skip transforming if there are no affine maps to compose.
+ if (affineApplyOps.empty())
+ return nullptr;
+
+ // Check if all uses of the affine apply op's lie in this op stmt
+ // itself, in which case there would be nothing to do.
+ bool localized = true;
+ for (auto *op : affineApplyOps) {
+ for (auto *result : op->getResults()) {
+ for (auto &use : result->getUses()) {
+ if (use.getOwner() != opStmt) {
+ localized = false;
+ break;
+ }
+ }
+ }
+ }
+ if (localized)
+ return nullptr;
+
+ MLFuncBuilder builder(opStmt);
+ SmallVector<SSAValue *, 4> results;
+ auto *affineApplyStmt = createComposedAffineApplyOp(
+ &builder, opStmt->getLoc(), subOperands, affineApplyOps, results);
+ assert(results.size() == subOperands.size() &&
+ "number of results should be the same as the number of subOperands");
+
+ // Construct the new operands that include the results from the composed
+ // affine apply op above instead of existing ones (subOperands). So, they
+ // differ from opStmt's operands only for those operands in 'subOperands', for
+ // which they will be replaced by the corresponding one from 'results'.
+ SmallVector<MLValue *, 4> newOperands(opStmt->getOperands());
+ for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
+ // Replace the subOperands from among the new operands.
+ unsigned j, f;
+ for (j = 0, f = subOperands.size(); j < f; j++) {
+ if (newOperands[i] == subOperands[j])
+ break;
+ }
+ if (j < subOperands.size()) {
+ newOperands[i] = cast<MLValue>(results[j]);
+ }
+ }
+
+ for (unsigned idx = 0; idx < newOperands.size(); idx++) {
+ opStmt->setOperand(idx, newOperands[idx]);
+ }
+
+ return affineApplyStmt;
+}
+
+void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
+ if (affineApplyOp->getOperation()->getOperationFunction()->getKind() !=
+ Function::Kind::MLFunc) {
+ // TODO: Support forward substitution for CFGFunctions.
+ return;
+ }
+ auto *opStmt = cast<OperationStmt>(affineApplyOp->getOperation());
+ // Iterate through all uses of all results of 'opStmt', forward substituting
+ // into any uses which are AffineApplyOps.
+ for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e;
+ ++resultIndex) {
+ const MLValue *result = opStmt->getResult(resultIndex);
+ for (auto it = result->use_begin(); it != result->use_end();) {
+ StmtOperand &use = *(it++);
+ auto *useStmt = use.getOwner();
+ auto *useOpStmt = dyn_cast<OperationStmt>(useStmt);
+ // Skip if use is not AffineApplyOp.
+ if (useOpStmt == nullptr || !useOpStmt->isa<AffineApplyOp>())
+ continue;
+ // Advance iterator past 'opStmt' operands which also use 'result'.
+ while (it != result->use_end() && it->getOwner() == useStmt)
+ ++it;
+
+ MLFuncBuilder builder(useOpStmt);
+ // Initialize AffineValueMap with 'affineApplyOp' which uses 'result'.
+ auto oldAffineApplyOp = useOpStmt->cast<AffineApplyOp>();
+ AffineValueMap valueMap(*oldAffineApplyOp);
+ // Forward substitute 'result' at index 'i' into 'valueMap'.
+ valueMap.forwardSubstituteSingle(*affineApplyOp, resultIndex);
+
+ // Create new AffineApplyOp from 'valueMap'.
+ unsigned numOperands = valueMap.getNumOperands();
+ SmallVector<SSAValue *, 4> operands(numOperands);
+ for (unsigned i = 0; i < numOperands; ++i) {
+ operands[i] = valueMap.getOperand(i);
+ }
+ auto newAffineApplyOp = builder.create<AffineApplyOp>(
+ useOpStmt->getLoc(), valueMap.getAffineMap(), operands);
+
+ // Update all uses to use results from 'newAffineApplyOp'.
+ for (unsigned i = 0, e = useOpStmt->getNumResults(); i < e; ++i) {
+ oldAffineApplyOp->getResult(i)->replaceAllUsesWith(
+ newAffineApplyOp->getResult(i));
+ }
+ // Erase 'oldAffineApplyOp'.
+ oldAffineApplyOp->getOperation()->erase();
+ }
+ }
+}
+
+/// Folds the specified (lower or upper) bound to a constant if possible
+/// considering its operands. Returns false if the folding happens for any of
+/// the bounds, true otherwise.
+bool mlir::constantFoldBounds(ForStmt *forStmt) {
+ auto foldLowerOrUpperBound = [forStmt](bool lower) {
+ // Check if the bound is already a constant.
+ if (lower && forStmt->hasConstantLowerBound())
+ return true;
+ if (!lower && forStmt->hasConstantUpperBound())
+ return true;
+
+ // Check to see if each of the operands is the result of a constant. If so,
+ // get the value. If not, ignore it.
+ SmallVector<Attribute *, 8> operandConstants;
+ auto boundOperands = lower ? forStmt->getLowerBoundOperands()
+ : forStmt->getUpperBoundOperands();
+ for (const auto *operand : boundOperands) {
+ Attribute *operandCst = nullptr;
+ if (auto *operandOp = operand->getDefiningOperation()) {
+ if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
+ operandCst = operandConstantOp->getValue();
+ }
+ operandConstants.push_back(operandCst);
+ }
+
+ AffineMap boundMap =
+ lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap();
+ assert(boundMap.getNumResults() >= 1 &&
+ "bound maps should have at least one result");
+ SmallVector<Attribute *, 4> foldedResults;
+ if (boundMap.constantFold(operandConstants, foldedResults))
+ return true;
+
+ // Compute the max or min as applicable over the results.
+ assert(!foldedResults.empty() && "bounds should have at least one result");
+ auto maxOrMin = cast<IntegerAttr>(foldedResults[0])->getValue();
+ for (unsigned i = 1; i < foldedResults.size(); i++) {
+ auto foldedResult = cast<IntegerAttr>(foldedResults[i])->getValue();
+ maxOrMin = lower ? std::max(maxOrMin, foldedResult)
+ : std::min(maxOrMin, foldedResult);
+ }
+ lower ? forStmt->setConstantLowerBound(maxOrMin)
+ : forStmt->setConstantUpperBound(maxOrMin);
+
+ // Return false on success.
+ return false;
+ };
+
+ bool ret = foldLowerOrUpperBound(/*lower=*/true);
+ ret &= foldLowerOrUpperBound(/*lower=*/false);
+ return ret;
+}