[mlir][affine-loop-fusion] Fix a bug that AffineIfOp prevents fusion of the other...
authorTung D. Le <tung@jp.ibm.com>
Fri, 30 Jul 2021 09:52:21 +0000 (15:22 +0530)
committerUday Bondhugula <uday@polymagelabs.com>
Fri, 30 Jul 2021 09:52:46 +0000 (15:22 +0530)
The presence of AffineIfOp inside AffineFor prevents fusion of the other loops to happen. For example:

```
  affine.for %i0 = 0 to 10 {
    affine.store %cf7, %a[%i0] : memref<10xf32>
  }
  affine.for %i1 = 0 to 10 {
    %v0 = affine.load %a[%i1] : memref<10xf32>
    affine.store %v0, %b[%i1] : memref<10xf32>
  }
  affine.for %i2 = 0 to 10 {
    affine.if #set(%i2) {
      %v0 = affine.load %b[%i2] : memref<10xf32>
    }
  }
```

The first two loops were not be fused because of `affine.if` inside the last `affine.for`.

The issue seems to come from a conservative constraint that does not allow fusion if there are ops whose number of regions != 0 (affine.if is one of them).

This patch just removes such a constraint when`affine.if` is inside `affine.for`.  The existing `canFuseLoops` method is able to handle `affine.if` correctly.

Reviewed By: bondhugula, vinayaka-polymage

Differential Revision: https://reviews.llvm.org/D105963

mlir/lib/Transforms/LoopFusion.cpp
mlir/test/Transforms/loop-fusion.mlir

index 49bd52d..955230d 100644 (file)
@@ -70,19 +70,20 @@ mlir::createLoopFusionPass(unsigned fastMemorySpace,
 namespace {
 
 // LoopNestStateCollector walks loop nests and collects load and store
-// operations, and whether or not an IfInst was encountered in the loop nest.
+// operations, and whether or not a region holding op other than ForOp and IfOp
+// was encountered in the loop nest.
 struct LoopNestStateCollector {
   SmallVector<AffineForOp, 4> forOps;
   SmallVector<Operation *, 4> loadOpInsts;
   SmallVector<Operation *, 4> storeOpInsts;
-  bool hasNonForRegion = false;
+  bool hasNonAffineRegionOp = false;
 
   void collect(Operation *opToWalk) {
     opToWalk->walk([&](Operation *op) {
       if (isa<AffineForOp>(op))
         forOps.push_back(cast<AffineForOp>(op));
-      else if (op->getNumRegions() != 0)
-        hasNonForRegion = true;
+      else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op))
+        hasNonAffineRegionOp = true;
       else if (isa<AffineReadOpInterface>(op))
         loadOpInsts.push_back(op);
       else if (isa<AffineWriteOpInterface>(op))
@@ -744,9 +745,9 @@ bool MemRefDependenceGraph::init(FuncOp f) {
       // all loads and store accesses it contains.
       LoopNestStateCollector collector;
       collector.collect(&op);
-      // Return false if a non 'affine.for' region was found (not currently
-      // supported).
-      if (collector.hasNonForRegion)
+      // Return false if a region holding op other than 'affine.for' and
+      // 'affine.if' was found (not currently supported).
+      if (collector.hasNonAffineRegionOp)
         return false;
       Node node(nextNodeId++, &op);
       for (auto *opInst : collector.loadOpInsts) {
index 650a8ad..2a3bad1 100644 (file)
@@ -445,8 +445,8 @@ func @should_fuse_no_top_level_access() {
 
 #set0 = affine_set<(d0) : (1 == 0)>
 
-// CHECK-LABEL: func @should_not_fuse_if_inst_at_top_level() {
-func @should_not_fuse_if_inst_at_top_level() {
+// CHECK-LABEL: func @should_not_fuse_if_op_at_top_level() {
+func @should_not_fuse_if_op_at_top_level() {
   %m = memref.alloc() : memref<10xf32>
   %cf7 = constant 7.0 : f32
 
@@ -473,8 +473,8 @@ func @should_not_fuse_if_inst_at_top_level() {
 
 #set0 = affine_set<(d0) : (1 == 0)>
 
-// CHECK-LABEL: func @should_not_fuse_if_inst_in_loop_nest() {
-func @should_not_fuse_if_inst_in_loop_nest() {
+// CHECK-LABEL: func @should_not_fuse_if_op_in_loop_nest() {
+func @should_not_fuse_if_op_in_loop_nest() {
   %m = memref.alloc() : memref<10xf32>
   %cf7 = constant 7.0 : f32
   %c4 = constant 4 : index
@@ -488,7 +488,7 @@ func @should_not_fuse_if_inst_in_loop_nest() {
     %v0 = affine.load %m[%i1] : memref<10xf32>
   }
 
-  // IfOp in ForInst should prevent fusion.
+  // IfOp in ForOp should prevent fusion.
   // CHECK:      affine.for %{{.*}} = 0 to 10 {
   // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT: }
@@ -502,6 +502,83 @@ func @should_not_fuse_if_inst_in_loop_nest() {
 
 // -----
 
+#set = affine_set<(d0) : (d0 - 1 >= 0)>
+
+// CHECK-LABEL: func @should_fuse_if_op_in_loop_nest_not_sandwiched() -> memref<10xf32> {
+func @should_fuse_if_op_in_loop_nest_not_sandwiched() -> memref<10xf32> {
+  %a = memref.alloc() : memref<10xf32>
+  %b = memref.alloc() : memref<10xf32>
+  %cf7 = constant 7.0 : f32
+
+  affine.for %i0 = 0 to 10 {
+    affine.store %cf7, %a[%i0] : memref<10xf32>
+  }
+  affine.for %i1 = 0 to 10 {
+    %v0 = affine.load %a[%i1] : memref<10xf32>
+    affine.store %v0, %b[%i1] : memref<10xf32>
+  }
+  affine.for %i2 = 0 to 10 {
+    affine.if #set(%i2) {
+      %v0 = affine.load %b[%i2] : memref<10xf32>
+    }
+  }
+
+  // IfOp in ForOp should not prevent fusion if it does not in between the
+  // source and dest ForOp ops.
+
+  // CHECK:      affine.for
+  // CHECK-NEXT:   affine.store
+  // CHECK-NEXT:   affine.load
+  // CHECK-NEXT:   affine.store
+  // CHECK:      affine.for
+  // CHECK-NEXT:   affine.if
+  // CHECK-NEXT:     affine.load
+  // CHECK-NOT:  affine.for
+  // CHECK:      return
+
+  return %a : memref<10xf32>
+}
+
+// -----
+
+#set = affine_set<(d0) : (d0 - 1 >= 0)>
+
+// CHECK-LABEL: func @should_not_fuse_if_op_in_loop_nest_between_src_and_dest() -> memref<10xf32> {
+func @should_not_fuse_if_op_in_loop_nest_between_src_and_dest() -> memref<10xf32> {
+  %a = memref.alloc() : memref<10xf32>
+  %b = memref.alloc() : memref<10xf32>
+  %cf7 = constant 7.0 : f32
+
+  affine.for %i0 = 0 to 10 {
+    affine.store %cf7, %a[%i0] : memref<10xf32>
+  }
+  affine.for %i1 = 0 to 10 {
+    affine.if #set(%i1) {
+      affine.store %cf7, %a[%i1] : memref<10xf32>
+    }
+  }
+  affine.for %i3 = 0 to 10 {
+    %v0 = affine.load %a[%i3] : memref<10xf32>
+    affine.store %v0, %b[%i3] : memref<10xf32>
+  }
+  return %b : memref<10xf32>
+
+  // IfOp in ForOp which modifies the memref should prevent fusion if it is in
+  // between the source and dest ForOp.
+
+  // CHECK:      affine.for
+  // CHECK-NEXT:   affine.store
+  // CHECK:      affine.for
+  // CHECK-NEXT:   affine.if
+  // CHECK-NEXT:     affine.store
+  // CHECK:      affine.for
+  // CHECK-NEXT:   affine.load
+  // CHECK-NEXT:   affine.store
+  // CHECK:      return
+}
+
+// -----
+
 // CHECK-LABEL: func @permute_and_fuse() {
 func @permute_and_fuse() {
   %m = memref.alloc() : memref<10x20x30xf32>