[MLIR][SCF] Fix nested if merging bug
authorWilliam S. Moses <gh@wsmoses.com>
Mon, 21 Mar 2022 00:32:36 +0000 (20:32 -0400)
committerWilliam S. Moses <gh@wsmoses.com>
Mon, 21 Mar 2022 15:42:26 +0000 (11:42 -0400)
The current nested if merging has a bug. Specifically, consider the following code:

```
    %r = scf.if %arg3 -> (i32) {
      scf.if %arg1 {
        "test.op"() : () -> ()
      }
      scf.yield %arg0 : i32
    } else {
      scf.yield %arg2 : i32
    }
```

When the above gets merged, it will become:
```
    %r = scf.if %arg3 && %arg1-> (i32) {
      "test.op"() : () -> ()
      scf.yield %arg0 : i32
    } else {
      scf.yield %arg2 : i32
    }
```

However, this means that when only %arg3 is true, we will incorrectly return %arg2 instead
of %arg0. This change updates the behavior of the pass to only enable nested if merging where
the outer yield contains only values from the inner if, or values defined outside of the if.

In the case of the latter, they can turned into a select of only the outer if condition, thus
maintaining correctness.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D122108

mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir

index c36e4cc..114d93b 100644 (file)
@@ -1768,6 +1768,10 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
     if (op.elseBlock())
       llvm::append_range(elseYield, op.elseYield().getOperands());
 
+    // A list of indices for which we should upgrade the value yielded
+    // in the else to a select.
+    SmallVector<unsigned> elseYieldsToUpgradeToSelect;
+
     // If the outer scf.if yields a value produced by the inner scf.if,
     // only permit combining if the value yielded when the condition
     // is false in the outer scf.if is the same value yielded when the
@@ -1785,6 +1789,22 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
         // If the correctness test passes, we will yield
         // corresponding value from the inner scf.if
         thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
+        continue;
+      }
+
+      // Otherwise, we need to ensure the else block of the combined
+      // condition still returns the same value when the outer condition is
+      // true and the inner condition is false. This can be accomplished if
+      // the then value is defined outside the outer scf.if and we replace the
+      // value with a select that considers just the outer condition. Since
+      // the else region contains just the yield, its yielded value is
+      // defined outside the scf.if, by definition.
+
+      // If the then value is defined within the scf.if, bail.
+      if (tup.value().getParentRegion() == &op.getThenRegion()) {
+        return failure();
+      } else {
+        elseYieldsToUpgradeToSelect.push_back(tup.index());
       }
     }
 
@@ -1792,6 +1812,15 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
     Value newCondition = rewriter.create<arith::AndIOp>(
         loc, op.getCondition(), nestedIf.getCondition());
     auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
+
+    SmallVector<Value> results;
+    llvm::append_range(results, newIf.getResults());
+    rewriter.setInsertionPoint(newIf);
+
+    for (auto idx : elseYieldsToUpgradeToSelect)
+      results[idx] = rewriter.create<arith::SelectOp>(
+          op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
+
     Block *newIfBlock = newIf.thenBlock();
     if (newIfBlock)
       rewriter.eraseOp(newIfBlock->getTerminator());
@@ -1805,7 +1834,7 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
       rewriter.setInsertionPointToEnd(newIf.elseBlock());
       rewriter.create<YieldOp>(loc, elseYield);
     }
-    rewriter.replaceOp(op, newIf.getResults());
+    rewriter.replaceOp(op, results);
     return success();
   }
 };
index c9c2b08..de176fb 100644 (file)
@@ -556,7 +556,7 @@ func @merge_yielding_nested_if_nv2(%arg0: i1, %arg1: i1) -> i32 {
 // CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
 // CHECK: %[[PRE1:.*]] = "test.op1"() : () -> i32
 // CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
-// CHECK: %[[RES:.*]] = arith.select %[[COND]], %[[PRE0]], %[[PRE1]]
+// CHECK: %[[RES:.*]] = arith.select %[[ARG0]], %[[PRE0]], %[[PRE1]]
 // CHECK: scf.if %[[COND]] 
 // CHECK:   "test.run"() : () -> ()
 // CHECK: }
@@ -596,6 +596,7 @@ func @merge_fail_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8)
   }
   return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
 }
+
 // -----
 
 // CHECK-LABEL:   func @if_condition_swap