return false;
}
+ // Returns the unique AffineStoreOp in `node` that meets all of the
+ // following conditions:
+ // *) the store is the only one that writes to a function-local memref
+ // live out of `node`,
+ // *) the store is not the source of a self-dependence on `node`.
+ // Otherwise, returns a null AffineStoreOp.
+ AffineStoreOp getUniqueOutgoingStore(Node *node) {
+ AffineStoreOp uniqueStore;
+
+ // Return null if `node` doesn't have any outgoing edges.
+ auto outEdgeIt = outEdges.find(node->id);
+ if (outEdgeIt == outEdges.end())
+ return nullptr;
+
+ const auto &nodeOutEdges = outEdgeIt->second;
+ for (auto *op : node->stores) {
+ auto storeOp = cast<AffineStoreOp>(op);
+ auto *memref = storeOp.getMemRef();
+ // Skip this store if there are no dependences on its memref. This means
+ // that the store either:
+ // *) writes to a memref that is only read within the same loop nest
+ // (self-dependence edges are not represented in graph at the moment),
+ // *) writes to a function live out memref (function parameter), or
+ // *) is dead.
+ // NOTE(review): 'edge.value' appears to hold the memref underlying a
+ // dependence edge, given the comparison below — confirm against Edge.
+ if (llvm::all_of(nodeOutEdges, [=](const Edge &edge) {
+ return (edge.value != memref);
+ }))
+ continue;
+
+ if (uniqueStore)
+ // Found multiple stores to function-local live-out memrefs; no unique
+ // outgoing store exists, so report failure with a null op.
+ return nullptr;
+ // Found first store to function-local live-out memref.
+ uniqueStore = storeOp;
+ }
+
+ return uniqueStore;
+ }
+
// Returns true if node 'id' can be removed from the graph. Returns false
// otherwise. A node can be removed from the graph iff the following
// conditions are met:
return newMemRef;
}
-// Checks if node 'srcId' (which writes to a live out memref), can be safely
-// fused into node 'dstId'. Returns true if the following conditions are met:
-// *) 'srcNode' only writes to live out 'memref'.
-// *) 'srcNode' has exactly one output edge on 'memref' (which is to 'dstId').
-// *) 'dstNode's read/write region to 'memref' is a super set of 'srcNode's
-// write region to 'memref'.
+// Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId'
+// may write to multiple memrefs but it is required that only one of them,
+// 'srcLiveOutStoreOp', have an output edge.
+// Returns true if 'dstNode's read/write region to 'memref' is a super set of
+// 'srcNode's write region to 'memref'.
// TODO(andydavis) Generalize this to handle more live in/out cases.
static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
- Value *memref,
+ AffineStoreOp srcLiveOutStoreOp,
MemRefDependenceGraph *mdg) {
- auto *srcNode = mdg->getNode(srcId);
+ assert(srcLiveOutStoreOp && "Expected a valid store op");
+ assert(mdg->getOutEdgeCount(srcId) == 1 && "Expected only one output edge");
auto *dstNode = mdg->getNode(dstId);
+ Value *memref = srcLiveOutStoreOp.getMemRef();
- // Gather all memrefs from 'srcNode' store ops.
- DenseSet<Value *> storeMemrefs;
- for (auto *storeOpInst : srcNode->stores) {
- storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef());
- }
- // Return false if any of the following are true:
- // *) 'srcNode' writes to a live in/out memref other than 'memref'.
- // *) 'srcNode' has more than one output edge on 'memref'.
- // Check that all stores are to the same memref.
- if (storeMemrefs.size() != 1 ||
- mdg->getOutEdgeCount(srcNode->id, memref) != 1)
- return false;
- // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'.
- auto *srcStoreOpInst = srcNode->stores.front();
- MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
- if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
+ // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOp' on 'memref'.
+ MemRefRegion srcWriteRegion(srcLiveOutStoreOp.getLoc());
+ if (failed(srcWriteRegion.compute(srcLiveOutStoreOp, /*loopDepth=*/0))) {
LLVM_DEBUG(llvm::dbgs()
<< "Unable to compute MemRefRegion for source operation\n.");
return false;
}
SmallVector<int64_t, 4> srcShape;
// Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'.
- // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
+ // by 'srcStoreOp' at depth 'dstLoopDepth'.
Optional<int64_t> srcNumElements =
srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape);
if (!srcNumElements.hasValue())
// Skip if 'srcNode' is not a loop nest.
if (!isa<AffineForOp>(srcNode->op))
continue;
- // Skip if 'srcNode' has more than one store to any memref.
- // TODO(andydavis) Support fusing multi-output src loop nests.
- if (srcNode->stores.size() != 1)
+ // Skip if 'srcNode' has more than one live-out store to a
+ // function-local memref.
+ // TODO(andydavis) Support more generic multi-output src loop nests
+ // fusion.
+ auto srcStoreOp = mdg->getUniqueOutgoingStore(srcNode);
+ if (!srcStoreOp)
continue;
+ // The unique outgoing store found must write to 'memref' since 'memref'
+ // is the memref that established the producer-consumer relationship
+ // between 'srcNode' and 'dstNode'.
+ assert(srcStoreOp.getMemRef() == memref &&
+ "Found store to unexpected memref");
// Skip if 'srcNode' writes to any live in or escaping memrefs,
// and cannot be fused.
bool writesToLiveInOrOut =
mdg->writesToLiveInOrEscapingMemrefs(srcNode->id);
if (writesToLiveInOrOut &&
- !canFuseSrcWhichWritesToLiveOut(srcId, dstId, memref, mdg))
+ !canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg))
continue;
// Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'.
if (insertPointInst == nullptr)
continue;
- // Get unique 'srcNode' store op.
- auto *srcStoreOpInst = srcNode->stores.front();
// Gather 'dstNode' store ops to 'memref'.
SmallVector<Operation *, 2> dstStoreOpInsts;
for (auto *storeOpInst : dstNode->stores)
unsigned bestDstLoopDepth;
mlir::ComputationSliceState sliceState;
// Check if fusion would be profitable.
- if (!isFusionProfitable(srcStoreOpInst, srcStoreOpInst,
- dstLoadOpInsts, dstStoreOpInsts, &sliceState,
+ if (!isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts,
+ dstStoreOpInsts, &sliceState,
&bestDstLoopDepth, maximalFusion))
continue;
// TODO(andydavis) Remove the following test code when canFuseLoops
}
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
auto sliceLoopNest = mlir::insertBackwardComputationSlice(
- srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
+ srcStoreOp, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
if (sliceLoopNest) {
LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n"
<< *sliceLoopNest.getOperation() << "\n");
}
affine.for %i1 = 0 to 10 {
%v0 = affine.load %a[%i1] : memref<10xf32>
+ %v1 = affine.load %b[%i1] : memref<10xf32>
}
// CHECK: affine.for %{{.*}} = 0 to 10 {
// CHECK-NEXT: }
// CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
// CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+ // CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return
return
// CHECK-NEXT: }
return
}
+
+// -----
+
+// CHECK-LABEL: func @should_fuse_self_dependence_multi_store_producer() {
+func @should_fuse_self_dependence_multi_store_producer() {
+ // '%local_m' is both stored to and loaded from within the producer loop
+ // (a self-dependence), so '%m' is the only store with an outgoing edge to
+ // the consumer loop. Per the CHECK lines below, fusion is expected to
+ // proceed, rewriting the '%m' accesses through a single-element memref
+ // while the '%local_m' accesses are left untouched.
+ %m = alloc() : memref<10xf32>
+ %local_m = alloc() : memref<10xf32>
+ %cf7 = constant 7.0 : f32
+
+ affine.for %i0 = 0 to 10 {
+ affine.store %cf7, %local_m[%i0] : memref<10xf32>
+ %v0 = affine.load %local_m[%i0] : memref<10xf32>
+ affine.store %v0, %m[%i0] : memref<10xf32>
+ }
+ affine.for %i1 = 0 to 10 {
+ %v1 = affine.load %m[%i1] : memref<10xf32>
+ }
+ // CHECK: affine.for %[[i0:.*]] = 0 to 10 {
+ // CHECK-NEXT: affine.store %{{.*}}, [[LOCAL_M:%.*]][%[[i0]]] : memref<10xf32>
+ // CHECK-NEXT: [[v0:%.*]] = affine.load [[LOCAL_M]][%[[i0]]] : memref<10xf32>
+ // CHECK-NEXT: affine.store [[v0]], %{{.*}}[0] : memref<1xf32>
+ // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @should_fuse_dead_multi_store_producer() {
+func @should_fuse_dead_multi_store_producer() {
+ // The store to '%dead_m' is never read, so '%m' is the only store with an
+ // outgoing edge. Per the CHECK lines below, fusion is expected to proceed:
+ // the dead store is kept as-is while the '%m' accesses are rewritten
+ // through a single-element memref.
+ %m = alloc() : memref<10xf32>
+ %dead_m = alloc() : memref<10xf32>
+ %cf7 = constant 7.0 : f32
+
+ affine.for %i0 = 0 to 10 {
+ affine.store %cf7, %dead_m[%i0] : memref<10xf32>
+ affine.store %cf7, %m[%i0] : memref<10xf32>
+ }
+ affine.for %i1 = 0 to 10 {
+ %v0 = affine.load %m[%i1] : memref<10xf32>
+ }
+ // CHECK: affine.for %[[i0:.*]] = 0 to 10 {
+ // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32>
+ // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+ // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @should_fuse_function_live_out_multi_store_producer
+func @should_fuse_function_live_out_multi_store_producer(%live_in_out_m : memref<10xf32>) {
+ // '%live_in_out_m' is a function argument, so its store must be preserved
+ // and carries no graph out-edge; '%m' is the only store with an outgoing
+ // edge. Per the CHECK lines below, the consumer load is fused into the
+ // producer loop without shrinking '%m' (all accesses stay memref<10xf32>).
+ %m = alloc() : memref<10xf32>
+ %cf7 = constant 7.0 : f32
+
+ affine.for %i0 = 0 to 10 {
+ affine.store %cf7, %live_in_out_m[%i0] : memref<10xf32>
+ affine.store %cf7, %m[%i0] : memref<10xf32>
+ }
+ affine.for %i1 = 0 to 10 {
+ %v0 = affine.load %m[%i1] : memref<10xf32>
+ }
+ // CHECK: affine.for %[[i0:.*]] = 0 to 10 {
+ // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32>
+ // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32>
+ // CHECK-NEXT: affine.load %{{.*}}[%[[i0]]] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ return
+}