LogicalResult matchAndRewrite(WhileOp whileOp,
PatternRewriter &rewriter) const override;
};
+
+/// Lower an `scf.index_switch` operation to a `cf.switch` operation.
+struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IndexSwitchOp op,
+ PatternRewriter &rewriter) const override;
+};
} // namespace
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
return success();
}
+LogicalResult
+IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
+ PatternRewriter &rewriter) const {
+ // Split the block at the op.
+ Block *condBlock = rewriter.getInsertionBlock();
+ Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op));
+
+ // Create the arguments on the continue block with which to replace the
+ // results of the op.
+ SmallVector<Value> results;
+ results.reserve(op.getNumResults());
+ for (Type resultType : op.getResultTypes())
+ results.push_back(continueBlock->addArgument(resultType, op.getLoc()));
+
+ // Handle the regions.
+ auto convertRegion = [&](Region ®ion) -> FailureOr<Block *> {
+ Block *block = ®ion.front();
+
+ // Convert the yield terminator to a branch to the continue block.
+ auto yield = cast<scf::YieldOp>(block->getTerminator());
+ rewriter.setInsertionPoint(yield);
+ rewriter.replaceOpWithNewOp<cf::BranchOp>(yield, continueBlock,
+ yield.getOperands());
+
+ // Inline the region.
+ rewriter.inlineRegionBefore(region, continueBlock);
+ return block;
+ };
+
+ // Convert the case regions.
+ SmallVector<Block *> caseSuccessors;
+ SmallVector<int32_t> caseValues;
+ caseSuccessors.reserve(op.getCases().size());
+ caseValues.reserve(op.getCases().size());
+ for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
+ FailureOr<Block *> block = convertRegion(region);
+ if (failed(block))
+ return failure();
+ caseSuccessors.push_back(*block);
+ caseValues.push_back(value);
+ }
+
+ // Convert the default region.
+ FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion());
+ if (failed(defaultBlock))
+ return failure();
+
+ // Create the switch.
+ rewriter.setInsertionPointToEnd(condBlock);
+ SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {});
+ rewriter.create<cf::SwitchOp>(
+ op.getLoc(), op.getArg(), *defaultBlock, ValueRange(),
+ rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands);
+ rewriter.replaceOp(op, continueBlock->getArguments());
+ return success();
+}
+
void mlir::populateSCFToControlFlowConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
- ExecuteRegionLowering>(patterns.getContext());
+ ExecuteRegionLowering, IndexSwitchLowering>(
+ patterns.getContext());
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
}
scf.condition(%0) %2, %3 : i64, f64
} do {
// CHECK: ^[[AFTER]](%[[ARG4:.*]]: i64, %[[ARG5:.*]]: f64):
- ^bb0(%arg2: i64, %arg3: f64):
+ ^bb0(%arg2: i64, %arg3: f64):
// CHECK: cf.br ^[[BEFORE]](%{{.*}}, %{{.*}} : i32, f32)
scf.yield %c0_i32, %cst : i32, f32
}
// CHECK: ^[[bb3]](%[[z:.+]]: i64):
// CHECK: "test.bar"(%[[z]])
// CHECK: return
+
+// SWITCH-LABEL: @index_switch
+func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
+ // SWITCH: cf.switch %arg0 : index
+ // SWITCH-NEXT: default: ^bb3
+ // SWITCH-NEXT: 0: ^bb1
+ // SWITCH-NEXT: 1: ^bb2
+ %0 = scf.index_switch %i -> i32
+ // SWITCH: ^bb1:
+ case 0 {
+ // SWITCH-NEXT: llvm.br ^bb4(%arg1
+ scf.yield %a : i32
+ }
+ // SWITCH: ^bb2:
+ case 1 {
+ // SWITCH-NEXT: llvm.br ^bb4(%arg2
+ scf.yield %b : i32
+ }
+ // SWITCH: ^bb3:
+ default {
+ // SWITCH-NEXT: llvm.br ^bb4(%arg3
+ scf.yield %c : i32
+ }
+ // SWITCH: ^bb4(%[[V:.*]]: i32
+ // SWITCH-NEXT: return %[[V]]
+ return %0 : i32
+}