}
// Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
- // has been replaced in node at 'dstId' by a private memref.
- void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef) {
+ // has been replaced in node at 'dstId' by a private memref depending
+ // on the value of 'createPrivateMemRef'.
+ void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef,
+ bool createPrivateMemRef) {
    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
if (inEdges.count(srcId) > 0) {
SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
// Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
// replaced by a private memref). These edges could come from nodes
// other than 'srcId' which were removed in the previous step.
- if (inEdges.count(dstId) > 0) {
+ if (inEdges.count(dstId) > 0 && createPrivateMemRef) {
SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
for (auto &inEdge : oldInEdges)
if (inEdge.value == oldMemRef)
// TODO(andydavis) Support more generic multi-output src loop nests
// fusion.
auto srcStoreOp = mdg->getUniqueOutgoingStore(srcNode);
- if (!srcStoreOp)
- continue;
+ if (!srcStoreOp) {
+ // Get the src store op at the deepest loop depth.
+ // We will use 'LoopFusionUtils::canFuseLoops' to check fusion
+ // feasibility for loops with multiple stores.
+ unsigned maxLoopDepth = 0;
+ for (auto *op : srcNode->stores) {
+ auto storeOp = cast<AffineStoreOp>(op);
+ if (storeOp.getMemRef() != memref) {
+ srcStoreOp = nullptr;
+ break;
+ }
+ unsigned loopDepth = getNestingDepth(*storeOp);
+ if (loopDepth > maxLoopDepth) {
+ maxLoopDepth = loopDepth;
+ srcStoreOp = storeOp;
+ }
+ }
+ if (!srcStoreOp)
+ continue;
+ }
+
// Unique outgoing store found must write to 'memref' since 'memref'
// is the one that established the producer-consumer relationship
// between 'srcNode' and 'dstNode'.
!canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg))
continue;
+      // Don't create a private memref if 'writesToLiveInOrOut'.
+ bool createPrivateMemref = !writesToLiveInOrOut;
+      // Don't create a private memref if 'srcNode' has in edges on 'memref',
+ // or if 'dstNode' has out edges on 'memref'.
+ if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) > 0 ||
+ mdg->getOutEdgeCount(dstNode->id, memref) > 0) {
+ createPrivateMemref = false;
+ }
+
// Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'.
if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount)
continue;
if (insertPointInst == nullptr)
continue;
+ // Compute the innermost common loop depth for dstNode loads/stores.
+ SmallVector<Operation *, 2> dstOps(dstNode->loads.begin(),
+ dstNode->loads.end());
+ dstOps.append(dstNode->stores.begin(), dstNode->stores.end());
+ unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstOps);
+ // Check the feasibility of fusing src loop nest into dst loop nest
+ // at loop depths in range [1, dstLoopDepthTest].
+ // TODO(andydavis) Use slice union computation and union of memref
+ // read/write regions to cost model and fusion.
+ bool canFuse = false;
+ for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
+ ComputationSliceState sliceUnion;
+ FusionResult result = mlir::canFuseLoops(
+ cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op),
+ /*dstLoopDepth=*/i, &sliceUnion);
+ if (result.value == FusionResult::Success)
+ canFuse = true;
+ }
+
+ // Skip if fusion is not feasible at all loop depths.
+ if (!canFuse)
+ continue;
+
// Gather 'dstNode' store ops to 'memref'.
SmallVector<Operation *, 2> dstStoreOpInsts;
for (auto *storeOpInst : dstNode->stores)
dstStoreOpInsts, &sliceState,
&bestDstLoopDepth, maximalFusion))
continue;
- // TODO(andydavis) Remove the following test code when canFuseLoops
- // is fully functional.
- mlir::ComputationSliceState sliceUnion;
- if (!maximalFusion) {
- FusionResult result = mlir::canFuseLoops(
- cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op),
- bestDstLoopDepth, &sliceUnion);
- assert(result.value == FusionResult::Success);
- (void)result;
- }
+
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
auto sliceLoopNest = mlir::insertBackwardComputationSlice(
srcStoreOp, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
dstAffineForOp.getOperation()->moveBefore(insertPointInst);
}
// Update edges between 'srcNode' and 'dstNode'.
- mdg->updateEdges(srcNode->id, dstNode->id, memref);
+ mdg->updateEdges(srcNode->id, dstNode->id, memref,
+ createPrivateMemref);
// Collect slice loop stats.
LoopNestStateCollector sliceCollector;
for (auto forOp : sliceCollector.forOps) {
promoteIfSingleIteration(forOp);
}
- if (!writesToLiveInOrOut) {
+ if (createPrivateMemref) {
// Create private memref for 'memref' in 'dstAffineForOp'.
SmallVector<Operation *, 4> storesForMemref;
for (auto *storeOpInst : sliceCollector.storeOpInsts) {
if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
storesForMemref.push_back(storeOpInst);
}
- assert(storesForMemref.size() == 1);
+ // TODO(andydavis) Use union of memref write regions to compute
+ // private memref footprint.
auto *newMemRef = createPrivateMemRef(
dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
fastMemorySpace, localBufSizeThreshold);
// TODO(andydavis) When the fusion pass is run to a fixed-point, it should
// fuse all three of these loop nests.
// CHECK: %{{.*}} = alloc() : memref<1xf32>
- // CHECK: %{{.*}} = alloc() : memref<10xf32>
// CHECK: affine.for %{{.*}} = 0 to 10 {
- // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
- // CHECK-NEXT: }
- // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
+ // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
// CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[0] : memref<1xf32>
// CHECK-NEXT: }
// -----
-// CHECK-LABEL: func @should_not_fuse_multi_output_producer() {
func @should_not_fuse_multi_output_producer() {
  // NOTE(review): this is a diff fragment — the CHECK-LABEL for this test was
  // deleted just above (L146 of the patch), and unchanged lines between hunks
  // may be elided, so the visible body may be incomplete; confirm against the
  // full test file. Per the test name, a producer nest writing to multiple
  // output memrefs ('%a', '%b') is expected NOT to be fused.
  %a = alloc() : memref<10xf32>
  %b = alloc() : memref<10xf32>
  // CHECK-NEXT: return
  return
}
+
+// -----
+
+// Test case from github bug 777.
+// CHECK-LABEL: func @mul_add_0
+func @mul_add_0(%arg0: memref<3x4xf32>, %arg1: memref<4x3xf32>, %arg2: memref<3x3xf32>, %arg3: memref<3x3xf32>) {
+  %cst = constant 0.000000e+00 : f32
+  %0 = alloc() : memref<3x3xf32>
+  // Zero-initialize the 3x3 accumulator buffer %0.
+  affine.for %arg4 = 0 to 3 {
+    affine.for %arg5 = 0 to 3 {
+      affine.store %cst, %0[%arg4, %arg5] : memref<3x3xf32>
+    }
+  }
+  // Matmul accumulation into %0: %0[i, j] += %arg0[i, k] * %arg1[k, j]
+  // (3x4 times 4x3 producing 3x3).
+  affine.for %arg4 = 0 to 3 {
+    affine.for %arg5 = 0 to 3 {
+      affine.for %arg6 = 0 to 4 {
+        %1 = affine.load %arg1[%arg6, %arg5] : memref<4x3xf32>
+        %2 = affine.load %arg0[%arg4, %arg6] : memref<3x4xf32>
+        %3 = mulf %2, %1 : f32
+        %4 = affine.load %0[%arg4, %arg5] : memref<3x3xf32>
+        %5 = addf %4, %3 : f32
+        affine.store %5, %0[%arg4, %arg5] : memref<3x3xf32>
+      }
+    }
+  }
+  // Elementwise epilogue consuming the matmul result:
+  // %arg3[i, j] = %0[i, j] + %arg2[i, j].
+  affine.for %arg4 = 0 to 3 {
+    affine.for %arg5 = 0 to 3 {
+      %6 = affine.load %arg2[%arg4, %arg5] : memref<3x3xf32>
+      %7 = affine.load %0[%arg4, %arg5] : memref<3x3xf32>
+      %8 = addf %7, %6 : f32
+      affine.store %8, %arg3[%arg4, %arg5] : memref<3x3xf32>
+    }
+  }
+  // The CHECK lines verify that all three nests above fuse into one 3x3 nest
+  // and that the accumulator %0 is shrunk to a private memref<1x1xf32>
+  // (accessed at [0, 0] inside the fused body).
+  // CHECK: affine.for %[[i0:.*]] = 0 to 3 {
+  // CHECK-NEXT: affine.for %[[i1:.*]] = 0 to 3 {
+  // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xf32>
+  // CHECK-NEXT: affine.for %[[i2:.*]] = 0 to 4 {
+  // CHECK-NEXT: affine.load %{{.*}}[%[[i2]], %[[i1]]] : memref<4x3xf32>
+  // CHECK-NEXT: affine.load %{{.*}}[%[[i0]], %[[i2]]] : memref<3x4xf32>
+  // CHECK-NEXT: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
+  // CHECK-NEXT: affine.load %{{.*}}[0, 0] : memref<1x1xf32>
+  // CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+  // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xf32>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: affine.load %{{.*}}[%[[i0]], %[[i1]]] : memref<3x3xf32>
+  // CHECK-NEXT: affine.load %{{.*}}[0, 0] : memref<1x1xf32>
+  // CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+  // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[i0]], %[[i1]]] : memref<3x3xf32>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+
+  return
+}