Remove overly conservative check in LoopFusion pass (enables fusion in tutorial example).
authorMLIR Team <no-reply@google.com>
Thu, 28 Mar 2019 21:54:49 +0000 (14:54 -0700)
committerjpienaar <jpienaar@google.com>
Sat, 30 Mar 2019 00:51:16 +0000 (17:51 -0700)
PiperOrigin-RevId: 240859227

mlir/lib/Transforms/LoopFusion.cpp
mlir/test/Transforms/loop-fusion.mlir

index 80308ea..c35b75f 100644 (file)
@@ -1275,7 +1275,8 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
   // Compute MemRefRegion 'dstWriteRegion' for 'dstStoreOpInst' on 'memref'.
   SmallVector<Operation *, 2> dstStoreOps;
   dstNode->getStoreOpsForMemref(memref, &dstStoreOps);
-  assert(dstStoreOps.size() == 1);
+  // TODO(andydavis) Compute 'unionboundingbox' of all write regions (one for
+  // each store op in 'dstStoreOps').
   auto *dstStoreOpInst = dstStoreOps[0];
   MemRefRegion dstWriteRegion(dstStoreOpInst->getLoc());
   if (failed(dstWriteRegion.compute(dstStoreOpInst, /*loopDepth=*/0))) {
@@ -1886,13 +1887,6 @@ public:
           if (srcNode->stores.size() != 1)
             continue;
 
-          // Skip 'srcNode' if it has in edges on 'memref'.
-          // TODO(andydavis) Track dependence type with edges, and just check
-          // for WAW dependence edge here. Note that this check is overly
-          // conservative and will be removed in the future.
-          if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) != 0)
-            continue;
-
           // Skip if 'srcNode' writes to any live in or escaping memrefs,
           // and cannot be fused.
           bool writesToLiveInOrOut =
index 4d21d00..dd3af06 100644 (file)
@@ -342,8 +342,10 @@ func @should_not_fuse_would_create_cycle() {
 
 // -----
 
-// CHECK-LABEL: func @should_not_fuse_across_waw_dep() {
-func @should_not_fuse_across_waw_dep() {
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1)
+
+// CHECK-LABEL: func @should_fuse_producer_consumer() {
+func @should_fuse_producer_consumer() {
   %m = alloc() : memref<10xf32>
   %cf7 = constant 7.0 : f32
 
@@ -356,15 +358,20 @@ func @should_not_fuse_across_waw_dep() {
   affine.for %i2 = 0 to 10 {
     %v1 = load %m[%i2] : memref<10xf32>
   }
-  // Fusing loop %i0 to %i2 would violate the WAW dependence between %i0 and %i1
+  // Fusing loop %i0 to %i2 would violate the WAW dependence between %i0 and
+  // %i1, but OK to fuse %i1 into %i2.
+  // TODO(andydavis) When the fusion pass is run to a fixed-point, it should
+  // fuse all three of these loop nests.
+  // CHECK:      %0 = alloc() : memref<1xf32>
+  // CHECK:      %1 = alloc() : memref<10xf32>
   // CHECK:      affine.for %i0 = 0 to 10 {
-  // CHECK-NEXT:   store %cst, %0[%i0] : memref<10xf32>
-  // CHECK-NEXT: }
-  // CHECK:      affine.for %i1 = 0 to 10 {
-  // CHECK-NEXT:   store %cst, %0[%i1] : memref<10xf32>
+  // CHECK-NEXT:   store %cst, %1[%i0] : memref<10xf32>
   // CHECK-NEXT: }
-  // CHECK:      affine.for %i2 = 0 to 10 {
-  // CHECK-NEXT:   %1 = load %0[%i2] : memref<10xf32>
+  // CHECK-NEXT: affine.for %i1 = 0 to 10 {
+  // CHECK-NEXT:   %2 = affine.apply [[MAP0]](%i1, %i1)
+  // CHECK-NEXT:   store %cst, %0[%2] : memref<1xf32>
+  // CHECK-NEXT:   %3 = affine.apply [[MAP0]](%i1, %i1)
+  // CHECK-NEXT:   %4 = load %0[%3] : memref<1xf32>
   // CHECK-NEXT: }
   // CHECK-NEXT: return
   return
@@ -2289,3 +2296,47 @@ func @should_fuse_with_slice_union() {
 // CHECK-NEXT: return
   return
 }
+
+// -----
+
+func @affine_add_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf32>) {
+  affine.for %i2 = 0 to 1024 {
+    affine.for %i3 = 0 to 1024 {
+      %0 = load %arg3[%i2, %i3] : memref<1024x1024xf32>
+      %1 = load %arg2[%i2, %i3] : memref<1024x1024xf32>
+      %2 = addf %1, %0 : f32
+      store %2, %arg2[%i2, %i3] : memref<1024x1024xf32>
+    }
+  }
+  affine.for %i4 = 0 to 1024 {
+    affine.for %i5 = 0 to 1024 {
+      affine.for %i6 = 0 to 1024 {
+        %3 = load %arg1[%i6, %i5] : memref<1024x1024xf32>
+        %4 = load %arg0[%i4, %i6] : memref<1024x1024xf32>
+        %5 = mulf %4, %3 : f32
+        %6 = load %arg2[%i4, %i5] : memref<1024x1024xf32>
+        %7 = addf %6, %5 : f32
+        store %7, %arg2[%i4, %i5] : memref<1024x1024xf32>
+      }
+    }
+  }
+  // Should fuse elementwise add loop at loop depth 2, above loop-carried
+  // dependence between load/store on '%arg2', carried on reduction loop %i6.
+  // CHECK:       affine.for %i0 = 0 to 1024 {
+  // CHECK-NEXT:    affine.for %i1 = 0 to 1024 {
+  // CHECK-NEXT:      %0 = load %arg3[%i0, %i1] : memref<1024x1024xf32>
+  // CHECK-NEXT:      %1 = load %arg2[%i0, %i1] : memref<1024x1024xf32>
+  // CHECK-NEXT:      %2 = addf %1, %0 : f32
+  // CHECK-NEXT:      store %2, %arg2[%i0, %i1] : memref<1024x1024xf32>
+  // CHECK-NEXT:      affine.for %i2 = 0 to 1024 {
+  // CHECK-NEXT:        %3 = load %arg1[%i2, %i1] : memref<1024x1024xf32>
+  // CHECK-NEXT:        %4 = load %arg0[%i0, %i2] : memref<1024x1024xf32>
+  // CHECK-NEXT:        %5 = mulf %4, %3 : f32
+  // CHECK-NEXT:        %6 = load %arg2[%i0, %i1] : memref<1024x1024xf32>
+  // CHECK-NEXT:        %7 = addf %6, %5 : f32
+  // CHECK-NEXT:        store %7, %arg2[%i0, %i1] : memref<1024x1024xf32>
+  // CHECK-NEXT:      }
+  // CHECK-NEXT:    }
+  // CHECK-NEXT:  }
+  return
+}