[MLIR][affine] Certain Call Ops to prevent fusion
authorVinayaka Bandishti <vinayaka@polymagelabs.com>
Fri, 26 Feb 2021 09:56:12 +0000 (15:26 +0530)
committerUday Bondhugula <uday@polymagelabs.com>
Fri, 26 Feb 2021 09:57:41 +0000 (15:27 +0530)
Fixes a bug in affine fusion pipeline where an incorrect fusion is performed
despite a Call Op that potentially modifies memrefs under consideration
exists between source and target.

Fixes part of https://bugs.llvm.org/show_bug.cgi?id=49220

Reviewed By: bondhugula, dcaballe

Differential Revision: https://reviews.llvm.org/D97252

mlir/lib/Transforms/LoopFusion.cpp
mlir/test/Transforms/loop-fusion.mlir

index 4e02f27..26a20ee 100644 (file)
@@ -784,6 +784,15 @@ bool MemRefDependenceGraph::init(FuncOp f) {
       // could be used by loop nest nodes.
       Node node(nextNodeId++, &op);
       nodes.insert({node.id, node});
+    } else if (isa<CallOpInterface>(op)) {
+      // Create graph node for top-level Call Op that takes any argument of
+      // memref type. Call Op that returns one or more memref type results
+      // is already taken care of, by the previous conditions.
+      if (llvm::any_of(op.getOperandTypes(),
+                       [&](Type t) { return t.isa<MemRefType>(); })) {
+        Node node(nextNodeId++, &op);
+        nodes.insert({node.id, node});
+      }
     } else if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
       // Create graph node for top-level op, which could have a memory write
       // side effect.
index 9b2966b..e800076 100644 (file)
@@ -3016,3 +3016,55 @@ func @should_not_fuse_src_loop_nest_return_value(
 
   return
 }
+
+// -----
+
+func private @some_function(memref<16xf32>)
+func @call_op_prevents_fusion(%arg0: memref<16xf32>){
+  %A = alloc() : memref<16xf32>
+  %cst_1 = constant 1.000000e+00 : f32
+  affine.for %arg1 = 0 to 16 {
+    %a = affine.load %arg0[%arg1] : memref<16xf32>
+    affine.store %a, %A[%arg1] : memref<16xf32>
+  }
+  call @some_function(%A) : (memref<16xf32>) -> ()
+  %B = alloc() : memref<16xf32>
+  affine.for %arg1 = 0 to 16 {
+    %a = affine.load %A[%arg1] : memref<16xf32>
+    %b = addf %cst_1, %a : f32
+    affine.store %b, %B[%arg1] : memref<16xf32>
+  }
+  return
+}
+// CHECK-LABEL: func @call_op_prevents_fusion
+// CHECK:         affine.for
+// CHECK-NEXT:      affine.load
+// CHECK-NEXT:      affine.store
+// CHECK:         call
+// CHECK:         affine.for
+// CHECK-NEXT:      affine.load
+// CHECK-NEXT:      addf
+// CHECK-NEXT:      affine.store
+
+// -----
+
+func private @some_function()
+func @call_op_does_not_prevent_fusion(%arg0: memref<16xf32>){
+  %A = alloc() : memref<16xf32>
+  %cst_1 = constant 1.000000e+00 : f32
+  affine.for %arg1 = 0 to 16 {
+    %a = affine.load %arg0[%arg1] : memref<16xf32>
+    affine.store %a, %A[%arg1] : memref<16xf32>
+  }
+  call @some_function() : () -> ()
+  %B = alloc() : memref<16xf32>
+  affine.for %arg1 = 0 to 16 {
+    %a = affine.load %A[%arg1] : memref<16xf32>
+    %b = addf %cst_1, %a : f32
+    affine.store %b, %B[%arg1] : memref<16xf32>
+  }
+  return
+}
+// CHECK-LABEL: func @call_op_does_not_prevent_fusion
+// CHECK:         affine.for
+// CHECK-NOT:     affine.for