[mlir][scf] Add scf-to-cf lowering for `scf.index_switch`
authorJeff Niu <jeff@modular.com>
Thu, 27 Oct 2022 20:43:41 +0000 (13:43 -0700)
committerJeff Niu <jeff@modular.com>
Mon, 31 Oct 2022 19:01:22 +0000 (12:01 -0700)
This patch adds lowering from `scf.index_switch` to `cf.switch.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D136883

mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir

index 72f483c..c15832a 100644 (file)
@@ -290,6 +290,14 @@ struct DoWhileLowering : public OpRewritePattern<WhileOp> {
   LogicalResult matchAndRewrite(WhileOp whileOp,
                                 PatternRewriter &rewriter) const override;
 };
+
+/// Lower an `scf.index_switch` operation to a `cf.switch` operation.
+struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IndexSwitchOp op,
+                                PatternRewriter &rewriter) const override;
+};
 } // namespace
 
 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
@@ -615,10 +623,68 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
   return success();
 }
 
+LogicalResult
+IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
+                                     PatternRewriter &rewriter) const {
+  // Split the block at the op.
+  Block *condBlock = rewriter.getInsertionBlock();
+  Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op));
+
+  // Create the arguments on the continue block with which to replace the
+  // results of the op.
+  SmallVector<Value> results;
+  results.reserve(op.getNumResults());
+  for (Type resultType : op.getResultTypes())
+    results.push_back(continueBlock->addArgument(resultType, op.getLoc()));
+
+  // Handle the regions.
+  auto convertRegion = [&](Region &region) -> FailureOr<Block *> {
+    Block *block = &region.front();
+
+    // Convert the yield terminator to a branch to the continue block.
+    auto yield = cast<scf::YieldOp>(block->getTerminator());
+    rewriter.setInsertionPoint(yield);
+    rewriter.replaceOpWithNewOp<cf::BranchOp>(yield, continueBlock,
+                                              yield.getOperands());
+
+    // Inline the region.
+    rewriter.inlineRegionBefore(region, continueBlock);
+    return block;
+  };
+
+  // Convert the case regions.
+  SmallVector<Block *> caseSuccessors;
+  SmallVector<int32_t> caseValues;
+  caseSuccessors.reserve(op.getCases().size());
+  caseValues.reserve(op.getCases().size());
+  for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
+    FailureOr<Block *> block = convertRegion(region);
+    if (failed(block))
+      return failure();
+    caseSuccessors.push_back(*block);
+    caseValues.push_back(value);
+  }
+
+  // Convert the default region.
+  FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion());
+  if (failed(defaultBlock))
+    return failure();
+
+  // Create the switch.
+  rewriter.setInsertionPointToEnd(condBlock);
+  SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {});
+  rewriter.create<cf::SwitchOp>(
+      op.getLoc(), op.getArg(), *defaultBlock, ValueRange(),
+      rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands);
+  rewriter.replaceOp(op, continueBlock->getArguments());
+  return success();
+}
+
 void mlir::populateSCFToControlFlowConversionPatterns(
     RewritePatternSet &patterns) {
   patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
-               ExecuteRegionLowering>(patterns.getContext());
+               ExecuteRegionLowering, IndexSwitchLowering>(
+      patterns.getContext());
   patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
 }
 
index df5d60c..94bacd2 100644 (file)
@@ -473,7 +473,7 @@ func.func @while_values(%arg0: i32, %arg1: f32) {
     scf.condition(%0) %2, %3 : i64, f64
   } do {
   // CHECK:   ^[[AFTER]](%[[ARG4:.*]]: i64, %[[ARG5:.*]]: f64):
-  ^bb0(%arg2: i64, %arg3: f64):  
+  ^bb0(%arg2: i64, %arg3: f64):
     // CHECK:   cf.br ^[[BEFORE]](%{{.*}}, %{{.*}} : i32, f32)
     scf.yield %c0_i32, %cst : i32, f32
   }
@@ -620,3 +620,30 @@ func.func @func_execute_region_elim_multi_yield() {
 // CHECK:   ^[[bb3]](%[[z:.+]]: i64):
 // CHECK:     "test.bar"(%[[z]])
 // CHECK:     return
+
+// SWITCH-LABEL: @index_switch
+func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
+  // SWITCH: cf.switch %arg0 : index
+  // SWITCH-NEXT: default: ^bb3
+  // SWITCH-NEXT: 0: ^bb1
+  // SWITCH-NEXT: 1: ^bb2
+  %0 = scf.index_switch %i -> i32
+  // SWITCH: ^bb1:
+  case 0 {
+    // SWITCH-NEXT: llvm.br ^bb4(%arg1
+    scf.yield %a : i32
+  }
+  // SWITCH: ^bb2:
+  case 1 {
+    // SWITCH-NEXT: llvm.br ^bb4(%arg2
+    scf.yield %b : i32
+  }
+  // SWITCH: ^bb3:
+  default {
+    // SWITCH-NEXT: llvm.br ^bb4(%arg3
+    scf.yield %c : i32
+  }
+  // SWITCH: ^bb4(%[[V:.*]]: i32
+  // SWITCH-NEXT: return %[[V]]
+  return %0 : i32
+}