return inits();
}
+ConditionOp WhileOp::getConditionOp() {
+ return cast<ConditionOp>(before().front().getTerminator());
+}
+
+Block::BlockArgListType WhileOp::getAfterArguments() {
+ return after().front().getArguments();
+}
+
void WhileOp::getSuccessorRegions(Optional<unsigned> index,
ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> ®ions) {
return success(afterTerminator != nullptr);
}
+namespace {
+/// Replace uses of the condition within the do block with true, since otherwise
+/// the block would not be evaluated.
+///
+/// scf.while (..) : (i1, ...) -> ... {
+/// %condition = call @evaluate_condition() : () -> i1
+/// scf.condition(%condition) %condition : i1, ...
+/// } do {
+/// ^bb0(%arg0: i1, ...):
+/// use(%arg0)
+/// ...
+///
+/// becomes
+/// scf.while (..) : (i1, ...) -> ... {
+/// %condition = call @evaluate_condition() : () -> i1
+/// scf.condition(%condition) %condition : i1, ...
+/// } do {
+/// ^bb0(%arg0: i1, ...):
+/// use(%true)
+/// ...
+struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
+ using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(WhileOp op,
+ PatternRewriter &rewriter) const override {
+ auto term = op.getConditionOp();
+
+ // These variables serve to prevent creating duplicate constants
+ // and hold constant true or false values.
+ Value constantTrue = nullptr;
+
+ bool replaced = false;
+ for (auto yieldedAndBlockArgs :
+ llvm::zip(term.args(), op.getAfterArguments())) {
+ if (std::get<0>(yieldedAndBlockArgs) == term.condition()) {
+ if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
+ if (!constantTrue)
+ constantTrue = rewriter.create<mlir::ConstantOp>(
+ op.getLoc(), term.condition().getType(),
+ rewriter.getBoolAttr(true));
+
+ std::get<1>(yieldedAndBlockArgs).replaceAllUsesWith(constantTrue);
+ replaced = true;
+ }
+ }
+ }
+ return success(replaced);
+ }
+};
+} // namespace
+
+void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<WhileConditionTruth>(context);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
// CHECK-NEXT: scf.yield %[[sv2]] : i32
// CHECK-NEXT: }
// CHECK-NEXT: return %[[if]], %arg1 : i32, i64
+
+
+// CHECK-LABEL: @while_cond_true
+func @while_cond_true() {
+ %0 = scf.while () : () -> i1 {
+ %condition = "test.condition"() : () -> i1
+ scf.condition(%condition) %condition : i1
+ } do {
+ ^bb0(%arg0: i1):
+ "test.use"(%arg0) : (i1) -> ()
+ scf.yield
+ }
+ return
+}
+// CHECK-NEXT: %[[true:.+]] = constant true
+// CHECK-NEXT: %{{.+}} = scf.while : () -> i1 {
+// CHECK-NEXT: %[[cmp:.+]] = "test.condition"() : () -> i1
+// CHECK-NEXT: scf.condition(%[[cmp]]) %[[cmp]] : i1
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%arg0: i1): // no predecessors
+// CHECK-NEXT: "test.use"(%[[true]]) : (i1) -> ()
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: }