Fix Affine Loop Fusion test case reported on github.
authorAndy Davis <andydavis@google.com>
Mon, 18 Nov 2019 19:20:03 +0000 (11:20 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 18 Nov 2019 19:20:37 +0000 (11:20 -0800)
This CL utilizies the more robust fusion feasibility analysis being built out in LoopFusionUtils, which will eventually be used to replace the current affine loop fusion pass.

PiperOrigin-RevId: 281112340

mlir/lib/Analysis/Utils.cpp
mlir/lib/Transforms/LoopFusion.cpp
mlir/test/Transforms/loop-fusion.mlir

index 042c744..23361e3 100644 (file)
@@ -616,7 +616,9 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
           return failure();
       }
       // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
-      if (failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
+      if (sliceUnionCst.getNumLocalIds() > 0 ||
+          tmpSliceCst.getNumLocalIds() > 0 ||
+          failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
         LLVM_DEBUG(llvm::dbgs()
                    << "Unable to compute union bounding box of slice bounds."
                       "\n.");
index 24d91c2..7985ca1 100644 (file)
@@ -546,8 +546,10 @@ public:
   }
 
   // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
-  // has been replaced in node at 'dstId' by a private memref.
-  void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef) {
+  // has been replaced in node at 'dstId' by a private memref depending
+  // on the value of 'createPrivateMemRef'.
+  void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef,
+                   bool createPrivateMemRef) {
     // For each edge in 'inEdges[srcId]': add new edge remaping to 'dstId'.
     if (inEdges.count(srcId) > 0) {
       SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
@@ -569,7 +571,7 @@ public:
     // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
     // replaced by a private memref). These edges could come from nodes
     // other than 'srcId' which were removed in the previous step.
-    if (inEdges.count(dstId) > 0) {
+    if (inEdges.count(dstId) > 0 && createPrivateMemRef) {
       SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
       for (auto &inEdge : oldInEdges)
         if (inEdge.value == oldMemRef)
@@ -1522,8 +1524,27 @@ public:
           // TODO(andydavis) Support more generic multi-output src loop nests
           // fusion.
           auto srcStoreOp = mdg->getUniqueOutgoingStore(srcNode);
-          if (!srcStoreOp)
-            continue;
+          if (!srcStoreOp) {
+            // Get the src store op at the deepest loop depth.
+            // We will use 'LoopFusionUtils::canFuseLoops' to check fusion
+            // feasibility for loops with multiple stores.
+            unsigned maxLoopDepth = 0;
+            for (auto *op : srcNode->stores) {
+              auto storeOp = cast<AffineStoreOp>(op);
+              if (storeOp.getMemRef() != memref) {
+                srcStoreOp = nullptr;
+                break;
+              }
+              unsigned loopDepth = getNestingDepth(*storeOp);
+              if (loopDepth > maxLoopDepth) {
+                maxLoopDepth = loopDepth;
+                srcStoreOp = storeOp;
+              }
+            }
+            if (!srcStoreOp)
+              continue;
+          }
+
           // Unique outgoing store found must write to 'memref' since 'memref'
           // is the one that established the producer-consumer relationship
           // between 'srcNode' and 'dstNode'.
@@ -1538,6 +1559,15 @@ public:
               !canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg))
             continue;
 
+          // Dont create a private memref if 'writesToLiveInOrOut'.
+          bool createPrivateMemref = !writesToLiveInOrOut;
+          // Dont create a private memref if 'srcNode' has in edges on 'memref',
+          // or if 'dstNode' has out edges on 'memref'.
+          if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) > 0 ||
+              mdg->getOutEdgeCount(dstNode->id, memref) > 0) {
+            createPrivateMemref = false;
+          }
+
           // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'.
           if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount)
             continue;
@@ -1549,6 +1579,29 @@ public:
           if (insertPointInst == nullptr)
             continue;
 
+          // Compute the innermost common loop depth for dstNode loads/stores.
+          SmallVector<Operation *, 2> dstOps(dstNode->loads.begin(),
+                                             dstNode->loads.end());
+          dstOps.append(dstNode->stores.begin(), dstNode->stores.end());
+          unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstOps);
+          // Check the feasibility of fusing src loop nest into dst loop nest
+          // at loop depths in range [1, dstLoopDepthTest].
+          // TODO(andydavis) Use slice union computation and union of memref
+          // read/write regions to cost model and fusion.
+          bool canFuse = false;
+          for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
+            ComputationSliceState sliceUnion;
+            FusionResult result = mlir::canFuseLoops(
+                cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op),
+                /*dstLoopDepth=*/i, &sliceUnion);
+            if (result.value == FusionResult::Success)
+              canFuse = true;
+          }
+
+          // Skip if fusion is not feasible at all loop depths.
+          if (!canFuse)
+            continue;
+
           // Gather 'dstNode' store ops to 'memref'.
           SmallVector<Operation *, 2> dstStoreOpInsts;
           for (auto *storeOpInst : dstNode->stores)
@@ -1562,16 +1615,7 @@ public:
                                   dstStoreOpInsts, &sliceState,
                                   &bestDstLoopDepth, maximalFusion))
             continue;
-          // TODO(andydavis) Remove the following test code when canFuseLoops
-          // is fully functional.
-          mlir::ComputationSliceState sliceUnion;
-          if (!maximalFusion) {
-            FusionResult result = mlir::canFuseLoops(
-                cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op),
-                bestDstLoopDepth, &sliceUnion);
-            assert(result.value == FusionResult::Success);
-            (void)result;
-          }
+
           // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
           auto sliceLoopNest = mlir::insertBackwardComputationSlice(
               srcStoreOp, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
@@ -1584,7 +1628,8 @@ public:
               dstAffineForOp.getOperation()->moveBefore(insertPointInst);
             }
             // Update edges between 'srcNode' and 'dstNode'.
-            mdg->updateEdges(srcNode->id, dstNode->id, memref);
+            mdg->updateEdges(srcNode->id, dstNode->id, memref,
+                             createPrivateMemref);
 
             // Collect slice loop stats.
             LoopNestStateCollector sliceCollector;
@@ -1593,14 +1638,15 @@ public:
             for (auto forOp : sliceCollector.forOps) {
               promoteIfSingleIteration(forOp);
             }
-            if (!writesToLiveInOrOut) {
+            if (createPrivateMemref) {
               // Create private memref for 'memref' in 'dstAffineForOp'.
               SmallVector<Operation *, 4> storesForMemref;
               for (auto *storeOpInst : sliceCollector.storeOpInsts) {
                 if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
                   storesForMemref.push_back(storeOpInst);
               }
-              assert(storesForMemref.size() == 1);
+              // TODO(andydavis) Use union of memref write regions to compute
+              // private memref footprint.
               auto *newMemRef = createPrivateMemRef(
                   dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
                   fastMemorySpace, localBufSizeThreshold);
index 36bcd0e..592b45d 100644 (file)
@@ -321,11 +321,8 @@ func @should_fuse_producer_consumer() {
   // TODO(andydavis) When the fusion pass is run to a fixed-point, it should
   // fuse all three of these loop nests.
   // CHECK:      %{{.*}} = alloc() : memref<1xf32>
-  // CHECK:      %{{.*}} = alloc() : memref<10xf32>
   // CHECK:      affine.for %{{.*}} = 0 to 10 {
-  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
   // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
   // CHECK-NEXT:   %{{.*}} = affine.load %{{.*}}[0] : memref<1xf32>
   // CHECK-NEXT: }
@@ -1238,7 +1235,6 @@ func @R3_to_R2_reshape() {
 
 // -----
 
-// CHECK-LABEL: func @should_not_fuse_multi_output_producer() {
 func @should_not_fuse_multi_output_producer() {
   %a = alloc() : memref<10xf32>
   %b = alloc() : memref<10xf32>
@@ -2341,3 +2337,57 @@ func @should_fuse_function_live_out_multi_store_producer(%live_in_out_m : memref
   // CHECK-NEXT: return
   return
 }
+
+// -----
+
+// Test case from github bug 777.
+// CHECK-LABEL: func @mul_add_0
+func @mul_add_0(%arg0: memref<3x4xf32>, %arg1: memref<4x3xf32>, %arg2: memref<3x3xf32>, %arg3: memref<3x3xf32>) {
+  %cst = constant 0.000000e+00 : f32
+  %0 = alloc() : memref<3x3xf32>
+  affine.for %arg4 = 0 to 3 {
+    affine.for %arg5 = 0 to 3 {
+      affine.store %cst, %0[%arg4, %arg5] : memref<3x3xf32>
+    }
+  }
+  affine.for %arg4 = 0 to 3 {
+    affine.for %arg5 = 0 to 3 {
+      affine.for %arg6 = 0 to 4 {
+        %1 = affine.load %arg1[%arg6, %arg5] : memref<4x3xf32>
+        %2 = affine.load %arg0[%arg4, %arg6] : memref<3x4xf32>
+        %3 = mulf %2, %1 : f32
+        %4 = affine.load %0[%arg4, %arg5] : memref<3x3xf32>
+        %5 = addf %4, %3 : f32
+        affine.store %5, %0[%arg4, %arg5] : memref<3x3xf32>
+      }
+    }
+  }
+  affine.for %arg4 = 0 to 3 {
+    affine.for %arg5 = 0 to 3 {
+      %6 = affine.load %arg2[%arg4, %arg5] : memref<3x3xf32>
+      %7 = affine.load %0[%arg4, %arg5] : memref<3x3xf32>
+      %8 = addf %7, %6 : f32
+      affine.store %8, %arg3[%arg4, %arg5] : memref<3x3xf32>
+    }
+  }
+  // CHECK:      affine.for %[[i0:.*]] = 0 to 3 {
+  // CHECK-NEXT:   affine.for %[[i1:.*]] = 0 to 3 {
+  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xf32>
+  // CHECK-NEXT:     affine.for %[[i2:.*]] = 0 to 4 {
+  // CHECK-NEXT:       affine.load %{{.*}}[%[[i2]], %[[i1]]] : memref<4x3xf32>
+  // CHECK-NEXT:       affine.load %{{.*}}[%[[i0]], %[[i2]]] : memref<3x4xf32>
+  // CHECK-NEXT:       %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
+  // CHECK-NEXT:       affine.load %{{.*}}[0, 0] : memref<1x1xf32>
+  // CHECK-NEXT:       %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+  // CHECK-NEXT:       affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xf32>
+  // CHECK-NEXT:     }
+  // CHECK-NEXT:     affine.load %{{.*}}[%[[i0]], %[[i1]]] : memref<3x3xf32>
+  // CHECK-NEXT:     affine.load %{{.*}}[0, 0] : memref<1x1xf32>
+  // CHECK-NEXT:     %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+  // CHECK-NEXT:     affine.store %{{.*}}, %{{.*}}[%[[i0]], %[[i1]]] : memref<3x3xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+
+  return
+}