Add support for some multi-store cases in affine fusion
authorDiego Caballero <diego.caballero@intel.com>
Wed, 9 Oct 2019 17:36:54 +0000 (10:36 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 9 Oct 2019 17:37:30 +0000 (10:37 -0700)
This PR is a stepping stone towards supporting generic multi-store
source loop nests in affine loop fusion. It extends the algorithm to
support fusion of multi-store loop nests that:
 1. have only one store that writes to a function-local live out, and
 2. the remaining stores are involved in loop nest self dependences
    or no dependences within the function.

Closes tensorflow/mlir#162

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/162 from dcaballe:dcaballe/multi-output-fusion 7fb7dec6fe8b45f5ce176f018bfe37b256420c45
PiperOrigin-RevId: 273773907

mlir/lib/Transforms/LoopFusion.cpp
mlir/test/Transforms/loop-fusion.mlir

index 188165b..15dc36c 100644 (file)
@@ -322,6 +322,44 @@ public:
     return false;
   }
 
+  // Returns the unique AffineStoreOp in `node` that meets all the following:
+  //   *) store is the only one that writes to a function-local memref live out
+  //      of `node`,
+  //   *) store is not the source of a self-dependence on `node`.
+  // Otherwise, returns a null AffineStoreOp.
+  AffineStoreOp getUniqueOutgoingStore(Node *node) {
+    AffineStoreOp uniqueStore;
+
+    // Return null if `node` doesn't have any outgoing edges.
+    auto outEdgeIt = outEdges.find(node->id);
+    if (outEdgeIt == outEdges.end())
+      return nullptr;
+
+    const auto &nodeOutEdges = outEdgeIt->second;
+    for (auto *op : node->stores) {
+      auto storeOp = cast<AffineStoreOp>(op);
+      auto *memref = storeOp.getMemRef();
+      // Skip this store if there are no dependences on its memref. This means
+      // that store either:
+      // *) writes to a memref that is only read within the same loop nest
+      //    (self-dependence edges are not represented in graph at the moment),
+      // *) writes to a function live out memref (function parameter), or
+      // *) is dead.
+      if (llvm::all_of(nodeOutEdges, [=](const Edge &edge) {
+            return (edge.value != memref);
+          }))
+        continue;
+
+      if (uniqueStore)
+        // Found multiple stores to function-local live-out memrefs.
+        return nullptr;
+      // Found first store to function-local live-out memref.
+      uniqueStore = storeOp;
+    }
+
+    return uniqueStore;
+  }
+
   // Returns true if node 'id' can be removed from the graph. Returns false
   // otherwise. A node can be removed from the graph iff the following
   // conditions are met:
@@ -963,42 +1001,30 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   return newMemRef;
 }
 
-// Checks if node 'srcId' (which writes to a live out memref), can be safely
-// fused into node 'dstId'. Returns true if the following conditions are met:
-// *) 'srcNode' only writes to live out 'memref'.
-// *) 'srcNode' has exactly one output edge on 'memref' (which is to 'dstId').
-// *) 'dstNode's read/write region to 'memref' is a super set of 'srcNode's
-//    write region to 'memref'.
+// Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId'
+// may write to multiple memrefs but it is required that only one of them,
+// 'srcLiveOutStoreOp', have an output edge.
+// Returns true if 'dstNode's read/write region to 'memref' is a super set of
+// 'srcNode's write region to 'memref'.
 // TODO(andydavis) Generalize this to handle more live in/out cases.
 static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
-                                           Value *memref,
+                                           AffineStoreOp srcLiveOutStoreOp,
                                            MemRefDependenceGraph *mdg) {
-  auto *srcNode = mdg->getNode(srcId);
+  assert(srcLiveOutStoreOp && "Expected a valid store op");
+  assert(mdg->getOutEdgeCount(srcId) == 1 && "Expected only one output edge");
   auto *dstNode = mdg->getNode(dstId);
+  Value *memref = srcLiveOutStoreOp.getMemRef();
 
-  // Gather all memrefs from 'srcNode' store ops.
-  DenseSet<Value *> storeMemrefs;
-  for (auto *storeOpInst : srcNode->stores) {
-    storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef());
-  }
-  // Return false if any of the following are true:
-  // *) 'srcNode' writes to a live in/out memref other than 'memref'.
-  // *) 'srcNode' has more than one output edge on 'memref'.
-  // Check that all stores are to the same memref.
-  if (storeMemrefs.size() != 1 ||
-      mdg->getOutEdgeCount(srcNode->id, memref) != 1)
-    return false;
-  // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'.
-  auto *srcStoreOpInst = srcNode->stores.front();
-  MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
-  if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
+  // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOp' on 'memref'.
+  MemRefRegion srcWriteRegion(srcLiveOutStoreOp.getLoc());
+  if (failed(srcWriteRegion.compute(srcLiveOutStoreOp, /*loopDepth=*/0))) {
     LLVM_DEBUG(llvm::dbgs()
                << "Unable to compute MemRefRegion for source operation\n.");
     return false;
   }
   SmallVector<int64_t, 4> srcShape;
   // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'.
-  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
+  // by 'srcStoreOp' at depth 'dstLoopDepth'.
   Optional<int64_t> srcNumElements =
       srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape);
   if (!srcNumElements.hasValue())
@@ -1491,17 +1517,25 @@ public:
           // Skip if 'srcNode' is not a loop nest.
           if (!isa<AffineForOp>(srcNode->op))
             continue;
-          // Skip if 'srcNode' has more than one store to any memref.
-          // TODO(andydavis) Support fusing multi-output src loop nests.
-          if (srcNode->stores.size() != 1)
+          // Skip if 'srcNode' has more than one live-out store to a
+          // function-local memref.
+          // TODO(andydavis) Support more generic multi-output src loop nests
+          // fusion.
+          auto srcStoreOp = mdg->getUniqueOutgoingStore(srcNode);
+          if (!srcStoreOp)
             continue;
+          // Unique outgoing store found must write to 'memref' since 'memref'
+          // is the one that established the producer-consumer relationship
+          // between 'srcNode' and 'dstNode'.
+          assert(srcStoreOp.getMemRef() == memref &&
+                 "Found store to unexpected memref");
 
           // Skip if 'srcNode' writes to any live in or escaping memrefs,
           // and cannot be fused.
           bool writesToLiveInOrOut =
               mdg->writesToLiveInOrEscapingMemrefs(srcNode->id);
           if (writesToLiveInOrOut &&
-              !canFuseSrcWhichWritesToLiveOut(srcId, dstId, memref, mdg))
+              !canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg))
             continue;
 
           // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'.
@@ -1515,8 +1549,6 @@ public:
           if (insertPointInst == nullptr)
             continue;
 
-          // Get unique 'srcNode' store op.
-          auto *srcStoreOpInst = srcNode->stores.front();
           // Gather 'dstNode' store ops to 'memref'.
           SmallVector<Operation *, 2> dstStoreOpInsts;
           for (auto *storeOpInst : dstNode->stores)
@@ -1526,8 +1558,8 @@ public:
           unsigned bestDstLoopDepth;
           mlir::ComputationSliceState sliceState;
           // Check if fusion would be profitable.
-          if (!isFusionProfitable(srcStoreOpInst, srcStoreOpInst,
-                                  dstLoadOpInsts, dstStoreOpInsts, &sliceState,
+          if (!isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts,
+                                  dstStoreOpInsts, &sliceState,
                                   &bestDstLoopDepth, maximalFusion))
             continue;
           // TODO(andydavis) Remove the following test code when canFuseLoops
@@ -1542,7 +1574,7 @@ public:
           }
           // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
           auto sliceLoopNest = mlir::insertBackwardComputationSlice(
-              srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
+              srcStoreOp, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
           if (sliceLoopNest) {
             LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n"
                                     << *sliceLoopNest.getOperation() << "\n");
index c97e3df..6ff31de 100644 (file)
@@ -1251,6 +1251,7 @@ func @should_not_fuse_multi_output_producer() {
   }
   affine.for %i1 = 0 to 10 {
     %v0 = affine.load %a[%i1] : memref<10xf32>
+    %v1 = affine.load %b[%i1] : memref<10xf32>
   }
 
   // CHECK:       affine.for %{{.*}} = 0 to 10 {
@@ -1259,6 +1260,7 @@ func @should_not_fuse_multi_output_producer() {
   // CHECK-NEXT:  }
   // CHECK-NEXT:  affine.for %{{.*}} = 0 to 10 {
   // CHECK-NEXT:    %{{.*}} = affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    %{{.*}} = affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:  }
   // CHECK-NEXT:  return
   return
@@ -2266,3 +2268,76 @@ func @affine_2_dependent_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<10
   // CHECK-NEXT:   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @should_fuse_self_dependence_multi_store_producer() {
+func @should_fuse_self_dependence_multi_store_producer() {
+  %m = alloc() : memref<10xf32>
+  %local_m = alloc() : memref<10xf32>
+  %cf7 = constant 7.0 : f32
+
+  affine.for %i0 = 0 to 10 {
+    affine.store %cf7, %local_m[%i0] : memref<10xf32>
+    %v0 = affine.load %local_m[%i0] : memref<10xf32>
+    affine.store %v0, %m[%i0] : memref<10xf32>
+  }
+  affine.for %i1 = 0 to 10 {
+    %v1 = affine.load %m[%i1] : memref<10xf32>
+  }
+  // CHECK:      affine.for %[[i0:.*]] = 0 to 10 {
+  // CHECK-NEXT:   affine.store %{{.*}}, [[LOCAL_M:%.*]][%[[i0]]] : memref<10xf32>
+  // CHECK-NEXT:   [[v0:%.*]] = affine.load [[LOCAL_M]][%[[i0]]] : memref<10xf32>
+  // CHECK-NEXT:   affine.store [[v0]], %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT:   affine.load %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @should_fuse_dead_multi_store_producer() {
+func @should_fuse_dead_multi_store_producer() {
+  %m = alloc() : memref<10xf32>
+  %dead_m = alloc() : memref<10xf32>
+  %cf7 = constant 7.0 : f32
+
+  affine.for %i0 = 0 to 10 {
+    affine.store %cf7, %dead_m[%i0] : memref<10xf32>
+    affine.store %cf7, %m[%i0] : memref<10xf32>
+  }
+  affine.for %i1 = 0 to 10 {
+    %v0 = affine.load %m[%i1] : memref<10xf32>
+  }
+  // CHECK:      affine.for %[[i0:.*]] = 0 to 10 {
+  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32>
+  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT:   affine.load %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @should_fuse_function_live_out_multi_store_producer
+func @should_fuse_function_live_out_multi_store_producer(%live_in_out_m : memref<10xf32>) {
+  %m = alloc() : memref<10xf32>
+  %cf7 = constant 7.0 : f32
+
+  affine.for %i0 = 0 to 10 {
+    affine.store %cf7, %live_in_out_m[%i0] : memref<10xf32>
+    affine.store %cf7, %m[%i0] : memref<10xf32>
+  }
+  affine.for %i1 = 0 to 10 {
+    %v0 = affine.load %m[%i1] : memref<10xf32>
+  }
+  // CHECK:      affine.for %[[i0:.*]] = 0 to 10 {
+  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32>
+  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32>
+  // CHECK-NEXT:   affine.load %{{.*}}[%[[i0]]] : memref<10xf32>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+  return
+}