#include "mlir/Dialect/Affine/Passes.h"
-#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
// Returns the graph node for 'forOp'.
Node *getForOpNode(AffineForOp forOp) {
for (auto &idAndNode : nodes)
- if (idAndNode.second.op == forOp.getOperation())
+ if (idAndNode.second.op == forOp)
return &idAndNode.second;
return nullptr;
}
producerConsumerMemrefs);
}
+/// A memref escapes the function if either:
+/// 1. it is a block argument (e.g., a function argument), or
+/// 2. it is used by a non-affine op (e.g., std load/store, std
+/// call, etc.)
+/// FIXME: Support alias creating ops like memref view ops.
+static bool isEscapingMemref(Value memref) {
+ // Check if 'memref' escapes because it's a block argument.
+ if (memref.isa<BlockArgument>())
+ return true;
+
+ // Check if 'memref' escapes through a non-affine op (e.g., std load/store,
+ // call op, etc.). This also covers ops that create aliases of 'memref'.
+ for (Operation *user : memref.getUsers())
+ if (!isa<AffineMapAccessInterface>(*user))
+ return true;
+ return false;
+}
+
/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
-/// that escape the function. A memref escapes the function if either:
-/// 1. It's a function argument, or
-/// 2. It's used by a non-affine op (e.g., std load/store, std call, etc.)
+/// that escape the function.
void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
DenseSet<Value> &escapingMemRefs) {
auto *node = mdg->getNode(id);
- for (auto *storeOpInst : node->stores) {
- auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
+ for (Operation *storeOp : node->stores) {
+ auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
if (escapingMemRefs.count(memref))
continue;
- // Check if 'memref' escapes because it's a block argument.
- if (memref.isa<BlockArgument>()) {
+ if (isEscapingMemref(memref))
escapingMemRefs.insert(memref);
- continue;
- }
- // Check if 'memref' escapes through a non-affine op (e.g., std load/store,
- // call op, etc.).
- for (Operation *user : memref.getUsers())
- if (!isa<AffineMapAccessInterface>(*user))
- escapingMemRefs.insert(memref);
}
}
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(func::FuncOp f) {
LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
+ // Map from a memref to the set of ids of the nodes that have ops accessing
+ // the memref.
DenseMap<Value, SetVector<unsigned>> memrefAccesses;
// TODO: support multi-block functions.
getLoopIVs(*user, &loops);
if (loops.empty())
continue;
- assert(forToNodeMap.count(loops[0].getOperation()) > 0);
- unsigned userLoopNestId = forToNodeMap[loops[0].getOperation()];
+ assert(forToNodeMap.count(loops[0]) > 0);
+ unsigned userLoopNestId = forToNodeMap[loops[0]];
addEdge(node.id, userLoopNestId, value);
}
}
static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
assert(isa<AffineForOp>(node->op));
AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
- node->op = newRootForOp.getOperation();
+ node->op = newRootForOp;
}
// TODO: improve/complete this when we have target data.
unsigned dstLoopDepth,
Optional<unsigned> fastMemorySpace,
uint64_t localBufSizeThreshold) {
- auto *forInst = forOp.getOperation();
+ Operation *forInst = forOp.getOperation();
// Create builder to insert alloc op just before 'forOp'.
OpBuilder b(forInst);
eraseUnusedMemRefAllocations();
}
+ /// Visit each node in the graph, and for each node, attempt to fuse it with
+ /// producer-consumer candidates. No fusion is performed when producers with a
+ /// user count greater than `maxSrcUserCount` for any of the memrefs involved
+ /// are encountered.
void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
init();
<< dstAffineForOp << "\n");
// Move 'dstAffineForOp' before 'insertPointInst' if needed.
- if (fusedLoopInsPoint != dstAffineForOp.getOperation())
- dstAffineForOp.getOperation()->moveBefore(fusedLoopInsPoint);
+ if (fusedLoopInsPoint != dstAffineForOp)
+ dstAffineForOp->moveBefore(fusedLoopInsPoint);
// Update edges between 'srcNode' and 'dstNode'.
mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
Value storeMemRef = storeOp.getMemRef();
if (privateMemrefs.count(storeMemRef) > 0)
- privateMemRefToStores[storeMemRef].push_back(
- storeOp.getOperation());
+ privateMemRefToStores[storeMemRef].push_back(storeOp);
});
// Replace original memrefs with private memrefs. Note that all the
// Collect dst loop stats after memref privatization transformation.
LoopNestStateCollector dstLoopCollector;
- dstLoopCollector.collect(dstAffineForOp.getOperation());
+ dstLoopCollector.collect(dstAffineForOp);
// Clear and add back loads and stores.
mdg->clearNodeLoadAndStores(dstNode->id);
auto dstForInst = cast<AffineForOp>(dstNode->op);
// Update operation position of fused loop nest (if needed).
- if (insertPointInst != dstForInst.getOperation()) {
+ if (insertPointInst != dstForInst) {
dstForInst->moveBefore(insertPointInst);
}
// Update data dependence graph state post fusion.
// Collect dst loop stats after memref privatization transformation.
auto dstForInst = cast<AffineForOp>(dstNode->op);
LoopNestStateCollector dstLoopCollector;
- dstLoopCollector.collect(dstForInst.getOperation());
+ dstLoopCollector.collect(dstForInst);
// Clear and add back loads and stores
mdg->clearNodeLoadAndStores(dstNode->id);
mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
-#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
}
return WalkResult::advance();
}
- for (auto value : op->getResults()) {
+ for (Value value : op->getResults()) {
for (Operation *user : value.getUsers()) {
SmallVector<AffineForOp, 4> loops;
// Check if any loop in loop nest surrounding 'user' is 'opB'.
// dependences. Returns nullptr if no such insertion point is found.
static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
AffineForOp dstForOp) {
- bool isSrcForOpBeforeDstForOp =
- srcForOp->isBeforeInBlock(dstForOp.getOperation());
+ bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
- auto *firstDepOpA =
- getFirstDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
- auto *lastDepOpB =
- getLastDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
+ Operation *firstDepOpA = getFirstDependentOpInRange(forOpA, forOpB);
+ Operation *lastDepOpB = getLastDependentOpInRange(forOpA, forOpB);
// Block:
// ...
// |-- opA
}
// No dependences from 'opA' to operation in range ('opA', 'opB'), return
// 'opB' insertion point.
- return forOpB.getOperation();
+ return forOpB;
}
// Gathers all load and store ops in loop nest rooted at 'forOp' into
}
// Check if 'srcForOp' precedes 'dstForOp' in 'block'.
- bool isSrcForOpBeforeDstForOp =
- srcForOp->isBeforeInBlock(dstForOp.getOperation());
+ bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
// 'forOpA' executes before 'forOpB' in 'block'.
auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
}
// Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
- unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops(
- *srcForOp.getOperation(), *dstForOp.getOperation());
+ unsigned numCommonLoops =
+ mlir::getNumCommonSurroundingLoops(*srcForOp, *dstForOp);
// Filter out ops in 'opsA' to compute the slice union based on the
// assumptions made by the fusion strategy.
int64_t opCount = stats.opCountMap[forOp] - 1;
if (stats.loopMap.count(forOp) > 0) {
for (auto childForOp : stats.loopMap[forOp]) {
- opCount += getComputeCostHelper(childForOp.getOperation(), stats,
- tripCountOverrideMap, computeCostMap);
+ opCount += getComputeCostHelper(childForOp, stats, tripCountOverrideMap,
+ computeCostMap);
}
}
// Add in additional op instances from slice (if specified in map).
/// instance count (i.e. total number of operations in the loop body * loop
/// trip count) for the entire loop nest.
int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
- return getComputeCostHelper(forOp.getOperation(), stats,
+ return getComputeCostHelper(forOp, stats,
/*tripCountOverrideMap=*/nullptr,
/*computeCostMap=*/nullptr);
}
computeCostMap[insertPointParent] = -storeCount;
// Subtract out any load users of 'storeMemrefs' nested below
// 'insertPointParent'.
- for (auto value : storeMemrefs) {
- for (auto *user : value.getUsers()) {
+ for (Value memref : storeMemrefs) {
+ for (auto *user : memref.getUsers()) {
if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
SmallVector<AffineForOp, 4> loops;
// Check if any loop in loop nest surrounding 'user' is
// Compute op instance count for the src loop nest with iteration slicing.
int64_t sliceComputeCost = getComputeCostHelper(
- srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap);
+ srcForOp, srcStats, &sliceTripCountMap, &computeCostMap);
// Compute cost of fusion for this depth.
computeCostMap[insertPointParent] = sliceComputeCost;
*computeCost =
- getComputeCostHelper(dstForOp.getOperation(), dstStats,
+ getComputeCostHelper(dstForOp, dstStats,
/*tripCountOverrideMap=*/nullptr, &computeCostMap);
return true;
}