[MLIR][SCF] Assume uses of condition in the body of scf.while is true
authorWilliam S. Moses <gh@wsmoses.com>
Tue, 4 May 2021 03:40:44 +0000 (23:40 -0400)
committerWilliam S. Moses <gh@wsmoses.com>
Tue, 4 May 2021 15:39:07 +0000 (11:39 -0400)
Differential Revision: https://reviews.llvm.org/D101801

mlir/include/mlir/Dialect/SCF/SCFOps.td
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir

index 28348f0..c3c64e0 100644 (file)
@@ -586,7 +586,11 @@ def WhileOp : SCF_Op<"while",
 
   let extraClassDeclaration = [{
     OperandRange getSuccessorEntryOperands(unsigned index);
+    ConditionOp getConditionOp();
+    Block::BlockArgListType getAfterArguments();
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
index 141b880..c28e438 100644 (file)
@@ -1694,6 +1694,14 @@ OperandRange WhileOp::getSuccessorEntryOperands(unsigned index) {
   return inits();
 }
 
+ConditionOp WhileOp::getConditionOp() {
+  return cast<ConditionOp>(before().front().getTerminator());
+}
+
+Block::BlockArgListType WhileOp::getAfterArguments() {
+  return after().front().getArguments();
+}
+
 void WhileOp::getSuccessorRegions(Optional<unsigned> index,
                                   ArrayRef<Attribute> operands,
                                   SmallVectorImpl<RegionSuccessor> &regions) {
@@ -1835,6 +1843,62 @@ static LogicalResult verify(scf::WhileOp op) {
   return success(afterTerminator != nullptr);
 }
 
+namespace {
+/// Replace uses of the condition within the do block with true, since otherwise
+/// the block would not be evaluated.
+///
+/// scf.while (..) : (i1, ...) -> ... {
+///  %condition = call @evaluate_condition() : () -> i1
+///  scf.condition(%condition) %condition : i1, ...
+/// } do {
+/// ^bb0(%arg0: i1, ...):
+///    use(%arg0)
+///    ...
+///
+/// becomes
+/// scf.while (..) : (i1, ...) -> ... {
+///  %condition = call @evaluate_condition() : () -> i1
+///  scf.condition(%condition) %condition : i1, ...
+/// } do {
+/// ^bb0(%arg0: i1, ...):
+///    use(%true)
+///    ...
+struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    auto term = op.getConditionOp();
+
+    // These variables serve to prevent creating duplicate constants
+    // and hold constant true or false values.
+    Value constantTrue = nullptr;
+
+    bool replaced = false;
+    for (auto yieldedAndBlockArgs :
+         llvm::zip(term.args(), op.getAfterArguments())) {
+      if (std::get<0>(yieldedAndBlockArgs) == term.condition()) {
+        if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
+          if (!constantTrue)
+            constantTrue = rewriter.create<mlir::ConstantOp>(
+                op.getLoc(), term.condition().getType(),
+                rewriter.getBoolAttr(true));
+
+          std::get<1>(yieldedAndBlockArgs).replaceAllUsesWith(constantTrue);
+          replaced = true;
+        }
+      }
+    }
+    return success(replaced);
+  }
+};
+} // namespace
+
+void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context) {
+  results.insert<WhileConditionTruth>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
index 4dee382..3ba8e80 100644 (file)
@@ -724,3 +724,26 @@ func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
 // CHECK-NEXT:       scf.yield %[[sv2]] : i32
 // CHECK-NEXT:     }
 // CHECK-NEXT:     return %[[if]], %arg1 : i32, i64
+
+
+// CHECK-LABEL: @while_cond_true
+func @while_cond_true() {
+  %0 = scf.while () : () -> i1 {
+    %condition = "test.condition"() : () -> i1
+    scf.condition(%condition) %condition : i1
+  } do {
+  ^bb0(%arg0: i1):
+    "test.use"(%arg0) : (i1) -> ()
+    scf.yield
+  }
+  return
+}
+// CHECK-NEXT:         %[[true:.+]] = constant true
+// CHECK-NEXT:         %{{.+}} = scf.while : () -> i1 {
+// CHECK-NEXT:           %[[cmp:.+]] = "test.condition"() : () -> i1
+// CHECK-NEXT:           scf.condition(%[[cmp]]) %[[cmp]] : i1
+// CHECK-NEXT:         } do {
+// CHECK-NEXT:         ^bb0(%arg0: i1):  // no predecessors
+// CHECK-NEXT:           "test.use"(%[[true]]) : (i1) -> ()
+// CHECK-NEXT:           scf.yield
+// CHECK-NEXT:         }