Fix single producer check in loop fusion pass.
authorMLIR Team <no-reply@google.com>
Wed, 23 Jan 2019 19:11:43 +0000 (11:11 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 22:32:20 +0000 (15:32 -0700)
PiperOrigin-RevId: 230565482

mlir/lib/Transforms/LoopFusion.cpp
mlir/test/Transforms/loop-fusion.mlir

index 239915b1d4b42ee96e3ba98b659ed947853ccd23..94d763fcbd18b65c11236cbde3b59cb73ef944c7 100644 (file)
@@ -1177,9 +1177,9 @@ public:
           // Skip if 'srcNode' is not a loop nest.
           if (!isa<ForInst>(srcNode->inst))
             continue;
-
-          // Skip if 'srcNode' has more than one store to 'memref'.
-          if (srcNode->getStoreOpCount(memref) != 1)
+          // Skip if 'srcNode' has more than one store to any memref.
+          // TODO(andydavis) Support fusing multi-output src loop nests.
+          if (srcNode->stores.size() != 1)
             continue;
 
           // Skip 'srcNode' if it has in dependence edges. NOTE: This is overly
index 57b5d8dd0effe9977b4811053d3569a2421918b1..86a24cf77964c9a29609404b14f98b29092d9682 100644 (file)
@@ -1288,3 +1288,31 @@ func @R3_to_R2_reshape() {
 // CHECK-NEXT:   }
 // CHECK-NEXT:   return
 // CHECK-NEXT: }
+
+// -----
+
+// CHECK-LABEL: func @should_not_fuse_multi_output_producer() {
+func @should_not_fuse_multi_output_producer() {
+  %a = alloc() : memref<10xf32>
+  %b = alloc() : memref<10xf32>
+
+  %cf7 = constant 7.0 : f32
+
+  for %i0 = 0 to 10 {
+    store %cf7, %a[%i0] : memref<10xf32>
+    store %cf7, %b[%i0] : memref<10xf32>
+  }
+  for %i1 = 0 to 10 {
+    %v0 = load %a[%i1] : memref<10xf32>
+  }
+
+  // CHECK:       for %i0 = 0 to 10 {
+  // CHECK-NEXT:    store %cst, %0[%i0] : memref<10xf32>
+  // CHECK-NEXT:    store %cst, %1[%i0] : memref<10xf32>
+  // CHECK-NEXT:  }
+  // CHECK-NEXT:  for %i1 = 0 to 10 {
+  // CHECK-NEXT:    %2 = load %0[%i1] : memref<10xf32>
+  // CHECK-NEXT:  }
+  // CHECK-NEXT:  return
+  return
+}