From 68a8da4a938e5489ba915d615352af0b069ae56a Mon Sep 17 00:00:00 2001
From: Andy Davis
Date: Mon, 18 Nov 2019 11:20:03 -0800
Subject: [PATCH] Fix Affine Loop Fusion test case reported on github.

This CL utilizes the more robust fusion feasibility analysis being built
out in LoopFusionUtils, which will eventually be used to replace the
current affine loop fusion pass.

PiperOrigin-RevId: 281112340
---
 mlir/lib/Analysis/Utils.cpp           |  4 +-
 mlir/lib/Transforms/LoopFusion.cpp    | 82 +++++++++++++++++++++++++++--------
 mlir/test/Transforms/loop-fusion.mlir | 60 ++++++++++++++++++++++---
 3 files changed, 122 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 042c744..23361e3 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -616,7 +616,9 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
       return failure();
     }
     // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
-    if (failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
+    if (sliceUnionCst.getNumLocalIds() > 0 ||
+        tmpSliceCst.getNumLocalIds() > 0 ||
+        failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
       LLVM_DEBUG(llvm::dbgs()
                  << "Unable to compute union bounding box of slice bounds."
                     "\n.");
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 24d91c2f..7985ca1 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -546,8 +546,10 @@ public:
   }
 
   // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
-  // has been replaced in node at 'dstId' by a private memref.
-  void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef) {
+  // has been replaced in node at 'dstId' by a private memref, depending
+  // on the value of 'createPrivateMemRef'.
+  void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef,
+                   bool createPrivateMemRef) {
     // For each edge in 'inEdges[srcId]': add new edge remaping to 'dstId'.
     if (inEdges.count(srcId) > 0) {
       SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
@@ -569,7 +571,7 @@ public:
     // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
     // replaced by a private memref). These edges could come from nodes
     // other than 'srcId' which were removed in the previous step.
-    if (inEdges.count(dstId) > 0) {
+    if (inEdges.count(dstId) > 0 && createPrivateMemRef) {
      SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
       for (auto &inEdge : oldInEdges)
         if (inEdge.value == oldMemRef)
@@ -1522,8 +1524,27 @@ public:
         // TODO(andydavis) Support more generic multi-output src loop nests
         // fusion.
         auto srcStoreOp = mdg->getUniqueOutgoingStore(srcNode);
-        if (!srcStoreOp)
-          continue;
+        if (!srcStoreOp) {
+          // Get the src store op at the deepest loop depth.
+          // We will use 'LoopFusionUtils::canFuseLoops' to check fusion
+          // feasibility for loops with multiple stores.
+          unsigned maxLoopDepth = 0;
+          for (auto *op : srcNode->stores) {
+            auto storeOp = cast<AffineStoreOp>(op);
+            if (storeOp.getMemRef() != memref) {
+              srcStoreOp = nullptr;
+              break;
+            }
+            unsigned loopDepth = getNestingDepth(*storeOp);
+            if (loopDepth > maxLoopDepth) {
+              maxLoopDepth = loopDepth;
+              srcStoreOp = storeOp;
+            }
+          }
+          if (!srcStoreOp)
+            continue;
+        }
+
         // Unique outgoing store found must write to 'memref' since 'memref'
         // is the one that established the producer-consumer relationship
         // between 'srcNode' and 'dstNode'.
@@ -1538,6 +1559,15 @@ public:
             !canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg))
           continue;
 
+        // Don't create a private memref if 'writesToLiveInOrOut'.
+        bool createPrivateMemref = !writesToLiveInOrOut;
+        // Don't create a private memref if 'srcNode' has in edges on 'memref',
+        // or if 'dstNode' has out edges on 'memref'.
+        if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) > 0 ||
+            mdg->getOutEdgeCount(dstNode->id, memref) > 0) {
+          createPrivateMemref = false;
+        }
+
         // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'.
         if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount)
           continue;
@@ -1549,6 +1579,29 @@ public:
         if (insertPointInst == nullptr)
           continue;
 
+        // Compute the innermost common loop depth for dstNode loads/stores.
+        SmallVector<Operation *, 2> dstOps(dstNode->loads.begin(),
+                                           dstNode->loads.end());
+        dstOps.append(dstNode->stores.begin(), dstNode->stores.end());
+        unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstOps);
+        // Check the feasibility of fusing src loop nest into dst loop nest
+        // at loop depths in range [1, dstLoopDepthTest].
+        // TODO(andydavis) Use slice union computation and union of memref
+        // read/write regions to cost model and fusion.
+        bool canFuse = false;
+        for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
+          ComputationSliceState sliceUnion;
+          FusionResult result = mlir::canFuseLoops(
+              cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op),
+              /*dstLoopDepth=*/i, &sliceUnion);
+          if (result.value == FusionResult::Success)
+            canFuse = true;
+        }
+
+        // Skip if fusion is not feasible at all loop depths.
+        if (!canFuse)
+          continue;
+
         // Gather 'dstNode' store ops to 'memref'.
         SmallVector<Operation *, 2> dstStoreOpInsts;
         for (auto *storeOpInst : dstNode->stores)
@@ -1562,16 +1615,7 @@ public:
                                  dstStoreOpInsts, &sliceState,
                                  &bestDstLoopDepth, maximalFusion))
           continue;
-        // TODO(andydavis) Remove the following test code when canFuseLoops
-        // is fully functional.
-        mlir::ComputationSliceState sliceUnion;
-        if (!maximalFusion) {
-          FusionResult result = mlir::canFuseLoops(
-              cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op),
-              bestDstLoopDepth, &sliceUnion);
-          assert(result.value == FusionResult::Success);
-          (void)result;
-        }
+
         // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
         auto sliceLoopNest = mlir::insertBackwardComputationSlice(
             srcStoreOp, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
@@ -1584,7 +1628,8 @@ public:
           dstAffineForOp.getOperation()->moveBefore(insertPointInst);
         }
         // Update edges between 'srcNode' and 'dstNode'.
-        mdg->updateEdges(srcNode->id, dstNode->id, memref);
+        mdg->updateEdges(srcNode->id, dstNode->id, memref,
+                         createPrivateMemref);
 
         // Collect slice loop stats.
         LoopNestStateCollector sliceCollector;
@@ -1593,14 +1638,15 @@ public:
         for (auto forOp : sliceCollector.forOps) {
           promoteIfSingleIteration(forOp);
         }
-        if (!writesToLiveInOrOut) {
+        if (createPrivateMemref) {
           // Create private memref for 'memref' in 'dstAffineForOp'.
           SmallVector<Operation *, 2> storesForMemref;
           for (auto *storeOpInst : sliceCollector.storeOpInsts) {
             if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
               storesForMemref.push_back(storeOpInst);
           }
-          assert(storesForMemref.size() == 1);
+          // TODO(andydavis) Use union of memref write regions to compute
+          // private memref footprint.
           auto *newMemRef = createPrivateMemRef(
               dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
               fastMemorySpace, localBufSizeThreshold);
diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index 36bcd0e..592b45d 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -321,11 +321,8 @@ func @should_fuse_producer_consumer() {
   // TODO(andydavis) When the fusion pass is run to a fixed-point, it should
   // fuse all three of these loop nests.
   // CHECK: %{{.*}} = alloc() : memref<1xf32>
-  // CHECK: %{{.*}} = alloc() : memref<10xf32>
   // CHECK: affine.for %{{.*}} = 0 to 10 {
-  // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
   // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
   // CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[0] : memref<1xf32>
   // CHECK-NEXT: }
@@ -1238,7 +1235,6 @@
 
 // -----
 
-// CHECK-LABEL: func @should_not_fuse_multi_output_producer() {
 func @should_not_fuse_multi_output_producer() {
   %a = alloc() : memref<10xf32>
   %b = alloc() : memref<10xf32>
@@ -2341,3 +2337,57 @@ func @should_fuse_function_live_out_multi_store_producer(%live_in_out_m : memref<10xf32>) {
   // CHECK-NEXT: return
   return
 }
+
+// -----
+
+// Test case from github bug 777.
+// CHECK-LABEL: func @mul_add_0
+func @mul_add_0(%arg0: memref<3x4xf32>, %arg1: memref<4x3xf32>, %arg2: memref<3x3xf32>, %arg3: memref<3x3xf32>) {
+  %cst = constant 0.000000e+00 : f32
+  %0 = alloc() : memref<3x3xf32>
+  affine.for %arg4 = 0 to 3 {
+    affine.for %arg5 = 0 to 3 {
+      affine.store %cst, %0[%arg4, %arg5] : memref<3x3xf32>
+    }
+  }
+  affine.for %arg4 = 0 to 3 {
+    affine.for %arg5 = 0 to 3 {
+      affine.for %arg6 = 0 to 4 {
+        %1 = affine.load %arg1[%arg6, %arg5] : memref<4x3xf32>
+        %2 = affine.load %arg0[%arg4, %arg6] : memref<3x4xf32>
+        %3 = mulf %2, %1 : f32
+        %4 = affine.load %0[%arg4, %arg5] : memref<3x3xf32>
+        %5 = addf %4, %3 : f32
+        affine.store %5, %0[%arg4, %arg5] : memref<3x3xf32>
+      }
+    }
+  }
+  affine.for %arg4 = 0 to 3 {
+    affine.for %arg5 = 0 to 3 {
+      %6 = affine.load %arg2[%arg4, %arg5] : memref<3x3xf32>
+      %7 = affine.load %0[%arg4, %arg5] : memref<3x3xf32>
+      %8 = addf %7, %6 : f32
+      affine.store %8, %arg3[%arg4, %arg5] : memref<3x3xf32>
+    }
+  }
+  // CHECK: affine.for %[[i0:.*]] = 0 to 3 {
+  // CHECK-NEXT: affine.for %[[i1:.*]] = 0 to 3 {
+  // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xf32>
+  // CHECK-NEXT: affine.for %[[i2:.*]] = 0 to 4 {
+  // CHECK-NEXT: affine.load %{{.*}}[%[[i2]], %[[i1]]] : memref<4x3xf32>
+  // CHECK-NEXT: affine.load %{{.*}}[%[[i0]], %[[i2]]] : memref<3x4xf32>
+  // CHECK-NEXT: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
+  // CHECK-NEXT: affine.load %{{.*}}[0, 0] : memref<1x1xf32>
+  // CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+  // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xf32>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: affine.load %{{.*}}[%[[i0]], %[[i1]]] : memref<3x3xf32>
+  // CHECK-NEXT: affine.load %{{.*}}[0, 0] : memref<1x1xf32>
+  // CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+  // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[i0]], %[[i1]]] : memref<3x3xf32>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+
+  return
+}
-- 
2.7.4
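
Appendix (editor's illustration, not part of the patch): a minimal standalone
C++ sketch of the feasibility probe this change adds to LoopFusion.cpp. It
walks every candidate destination loop depth in [1, dstLoopDepthTest] and
proceeds only if canFuseLoops reports success at some depth. The FusionResult
and canFuseLoops definitions below are simplified stand-ins for illustration;
the real MLIR entry point takes AffineForOp roots and a ComputationSliceState
out-parameter.

  #include <cstdio>

  // Simplified stand-in for mlir::FusionResult.
  enum class FusionResult { Success, FailPrecondition, FailBlockDependence };

  // Simplified stand-in for mlir::canFuseLoops: in this toy model, assume
  // fusion is legal only when inserting at depth 1 or 2.
  static FusionResult canFuseLoops(unsigned dstLoopDepth) {
    return dstLoopDepth <= 2 ? FusionResult::Success
                             : FusionResult::FailBlockDependence;
  }

  int main() {
    // Innermost common loop depth of the destination loads/stores
    // (computed by getInnermostCommonLoopDepth in the real pass).
    unsigned dstLoopDepthTest = 4;
    bool canFuse = false;
    // Probe each candidate insertion depth; one feasible depth is enough.
    for (unsigned i = 1; i <= dstLoopDepthTest; ++i)
      if (canFuseLoops(i) == FusionResult::Success)
        canFuse = true;
    std::printf("fusion feasible: %s\n", canFuse ? "yes" : "no");
    return 0;
  }

The design point, as visible in the diff: the old code ran canFuseLoops only
as an assert-based cross-check after profitability analysis, whereas the
patch moves it up front as a filter. Producer nests that cannot be fused at
any depth, such as the multi-store producer in the new @mul_add_0 test, are
now skipped rather than reaching the removed asserts.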