[mlir][scf] Add an IndexSwitchOp
authorJeff Niu <jeff@modular.com>
Sun, 16 Oct 2022 18:30:08 +0000 (11:30 -0700)
committerJeff Niu <jeff@modular.com>
Fri, 21 Oct 2022 16:21:10 +0000 (09:21 -0700)
The `scf.index_switch` is a control-flow operation that branches to one of the
given regions based on the values of the argument and the cases. The
argument is always of type `index`.

Example:

```mlir
%0 = scf.index_switch %arg0 -> i32
case 2 {
  %1 = arith.constant 10 : i32
  scf.yield %1 : i32
}
case 5 {
  %2 = arith.constant 20 : i32
  scf.yield %2 : i32
}
default {
  %3 = arith.constant 30 : i32
  scf.yield %3 : i32
}
```

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D136003

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/test/Dialect/SCF/invalid.mlir
mlir/test/Dialect/SCF/ops.mlir

index 41b7099..38dd0ac 100644 (file)
@@ -235,7 +235,7 @@ def ForOp : SCF_Op<"for",
     }
     /// Return the `index`-th region iteration argument.
     BlockArgument getRegionIterArg(unsigned index) {
-      assert(index < getNumRegionIterArgs() && 
+      assert(index < getNumRegionIterArgs() &&
         "expected an index less than the number of region iter args");
       return getBody()->getArguments().drop_front(getNumInductionVars())[index];
     }
@@ -434,7 +434,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     ```
 
     Example with thread_dim_mapping attribute:
-    
+
     ```mlir
     //
     // Sequential context.
@@ -456,7 +456,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     ```
 
     Example with privatized tensors:
-    
+
     ```mlir
     %t0 = ...
     %t1 = ...
@@ -527,8 +527,8 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
       return getBody()->getArguments().drop_front(getRank());
     }
 
-    /// Return the thread indices in the order specified by the 
-    /// thread_dim_mapping attribute. Return failure is 
+    /// Return the thread indices in the order specified by the
+    /// thread_dim_mapping attribute. Return failure is
     /// thread_dim_mapping is not a valid permutation.
     FailureOr<SmallVector<Value>> getPermutedThreadIndices();
 
@@ -989,12 +989,76 @@ def WhileOp : SCF_Op<"while",
 }
 
 //===----------------------------------------------------------------------===//
+// IndexSwitchOp
+//===----------------------------------------------------------------------===//
+
+def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
+    SingleBlockImplicitTerminator<"scf::YieldOp">,
+    DeclareOpInterfaceMethods<RegionBranchOpInterface,
+                              ["getRegionInvocationBounds"]>]> {
+  let summary = "switch-case operation on an index argument";
+  let description = [{
+    The `scf.index_switch` is a control-flow operation that branches to one of
+    the given regions based on the values of the argument and the cases. The
+    argument is always of type `index`.
+
+    The operation always has a "default" region and any number of case regions
+    denoted by integer constants. Control-flow transfers to the case region
+    whose constant value equals the value of the argument. If the argument does
+    not equal any of the case values, control-flow transfer to the "default"
+    region.
+
+    Example:
+
+    ```mlir
+    %0 = scf.index_switch %arg0 : index -> i32
+    case 2 {
+      %1 = arith.constant 10 : i32
+      scf.yield %1 : i32
+    }
+    case 5 {
+      %2 = arith.constant 20 : i32
+      scf.yield %2 : i32
+    }
+    default {
+      %3 = arith.constant 30 : i32
+      scf.yield %3 : i32
+    }
+    ```
+  }];
+
+  let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$defaultRegion,
+                        VariadicRegion<SizedRegion<1>>:$caseRegions);
+
+  let assemblyFormat = [{
+    $arg attr-dict (`->` type($results)^)?
+    custom<SwitchCases>($cases, $caseRegions) `\n`
+    `` `default` $defaultRegion
+  }];
+
+  let extraClassDeclaration = [{
+    /// Get the number of cases.
+    unsigned getNumCases();
+
+    /// Get the default region body.
+    Block &getDefaultBlock();
+
+    /// Get the body of a case region.
+    Block &getCaseBlock(unsigned idx);
+  }];
+
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//
 
 def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
-                               ParentOneOf<["ExecuteRegionOp, ForOp",
-                                            "IfOp, ParallelOp, WhileOp"]>]> {
+    ParentOneOf<["ExecuteRegionOp, ForOp", "IfOp", "IndexSwitchOp",
+                 "ParallelOp", "WhileOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
     "scf.yield" yields an SSA value from the SCF dialect op region and
index b3e84f0..cd1d382 100644 (file)
@@ -3387,6 +3387,137 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 //===----------------------------------------------------------------------===//
+// IndexSwitchOp
+//===----------------------------------------------------------------------===//
+
+/// Parse the case regions and values.
+static ParseResult
+parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
+                 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
+  SmallVector<int64_t> caseValues;
+  while (succeeded(p.parseOptionalKeyword("case"))) {
+    int64_t value;
+    Region &region =
+        *caseRegions.emplace_back(std::make_unique<Region>()).get();
+    if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
+      return failure();
+    caseValues.push_back(value);
+  }
+  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
+  return success();
+}
+
+/// Print the case regions and values.
+static void printSwitchCases(OpAsmPrinter &p, Operation *op,
+                             DenseI64ArrayAttr cases, RegionRange caseRegions) {
+  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
+    p.printNewline();
+    p << "case " << value << ' ';
+    p.printRegion(*region, /*printEntryBlockArgs=*/false);
+  }
+}
+
+LogicalResult scf::IndexSwitchOp::verify() {
+  if (getCases().size() != getCaseRegions().size()) {
+    return emitOpError("has ")
+           << getCaseRegions().size() << " case regions but "
+           << getCases().size() << " case values";
+  }
+
+  DenseSet<int64_t> valueSet;
+  for (int64_t value : getCases())
+    if (!valueSet.insert(value).second)
+      return emitOpError("has duplicate case value: ") << value;
+
+  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
+    auto yield = cast<YieldOp>(region.front().getTerminator());
+    if (yield.getNumOperands() != getNumResults()) {
+      return (emitOpError("expected each region to return ")
+              << getNumResults() << " values, but " << name << " returns "
+              << yield.getNumOperands())
+                 .attachNote(yield.getLoc())
+             << "see yield operation here";
+    }
+    for (auto [idx, result, operand] :
+         llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
+                   yield.getOperandTypes())) {
+      if (result == operand)
+        continue;
+      return (emitOpError("expected result #")
+              << idx << " of each region to be " << result)
+                 .attachNote(yield.getLoc())
+             << name << " returns " << operand << " here";
+    }
+    return success();
+  };
+
+  if (failed(verifyRegion(getDefaultRegion(), "default region")))
+    return failure();
+  for (auto &[idx, caseRegion] : llvm::enumerate(getCaseRegions()))
+    if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
+      return failure();
+
+  return success();
+}
+
+unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
+
+Block &scf::IndexSwitchOp::getDefaultBlock() {
+  return getDefaultRegion().front();
+}
+
+Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
+  assert(idx < getNumCases() && "case index out-of-bounds");
+  return getCaseRegions()[idx].front();
+}
+
+void IndexSwitchOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &successors) {
+  // All regions branch back to the parent op.
+  if (index) {
+    successors.emplace_back(getResults());
+    return;
+  }
+
+  // If a constant was not provided, all regions are possible successors.
+  auto operandValue = operands.front().dyn_cast_or_null<IntegerAttr>();
+  if (!operandValue) {
+    for (Region &caseRegion : getCaseRegions())
+      successors.emplace_back(&caseRegion);
+    successors.emplace_back(&getDefaultRegion());
+    return;
+  }
+
+  // Otherwise, try to find a case with a matching value. If not, the default
+  // region is the only successor.
+  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
+    if (caseValue == operandValue.getInt()) {
+      successors.emplace_back(&caseRegion);
+      return;
+    }
+  }
+  successors.emplace_back(&getDefaultRegion());
+}
+
+void IndexSwitchOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
+  auto operandValue = operands.front().dyn_cast_or_null<IntegerAttr>();
+  if (!operandValue) {
+    // All regions are invoked at most once.
+    bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
+    return;
+  }
+
+  unsigned liveIndex = getNumRegions() - 1;
+  auto it = llvm::find(getCases(), operandValue.getInt());
+  if (it != getCases().end())
+    liveIndex = std::distance(getCases().begin(), it);
+  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
+    bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
+}
+
+//===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
 
index b79ecb4..fa91ba0 100644 (file)
@@ -428,7 +428,7 @@ func.func @parallel_invalid_yield(
 
 func.func @yield_invalid_parent_op() {
   "my.op"() ({
-   // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.parallel, scf.while'}}
+   // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.parallel, scf.while'}}
    scf.yield
   }) : () -> ()
   return
@@ -572,3 +572,57 @@ func.func @wrong_terminator_op(%in: tensor<100xf32>, %out: tensor<100xf32>) {
   }
   return
 }
+
+// -----
+
+func.func @switch_wrong_case_count(%arg0: index) {
+  // expected-error @below {{'scf.index_switch' op has 0 case regions but 1 case values}}
+  "scf.index_switch"(%arg0) ({
+    scf.yield
+  }) {cases = array<i64: 1>} : (index) -> ()
+  return
+}
+
+// -----
+
+func.func @switch_duplicate_case(%arg0: index) {
+  // expected-error @below {{'scf.index_switch' op has duplicate case value: 0}}
+  scf.index_switch %arg0
+  case 0 {
+    scf.yield
+  }
+  case 0 {
+    scf.yield
+  }
+  default {
+    scf.yield
+  }
+  return
+}
+
+// -----
+
+func.func @switch_wrong_types(%arg0: index) {
+  // expected-error @below {{'scf.index_switch' op expected each region to return 0 values, but default region returns 1}}
+  scf.index_switch %arg0
+  default {
+    // expected-note @below {{see yield operation here}}
+    scf.yield %arg0 : index
+  }
+  return
+}
+
+// -----
+
+func.func @switch_wrong_types(%arg0: index, %arg1: i32) {
+  // expected-error @below {{'scf.index_switch' op expected result #0 of each region to be 'index'}}
+  scf.index_switch %arg0 -> index
+  case 0 {
+    // expected-note @below {{case region #0 returns 'i32' here}}
+    scf.yield %arg1 : i32
+  }
+  default {
+    scf.yield %arg0 : index
+  }
+  return
+}
index c1fa4e6..e563838 100644 (file)
@@ -346,3 +346,36 @@ func.func @elide_terminator() -> () {
   } {thread_dim_mapping = [42]}
   return
 }
+
+// CHECK-LABEL: @switch
+func.func @switch(%arg0: index) -> i32 {
+  // CHECK: %{{.*}} = scf.index_switch %arg0 -> i32
+  %0 = scf.index_switch %arg0 -> i32
+  // CHECK-NEXT: case 2 {
+  case 2 {
+    // CHECK-NEXT: arith.constant
+    %c10_i32 = arith.constant 10 : i32
+    // CHECK-NEXT: scf.yield %{{.*}} : i32
+    scf.yield %c10_i32 : i32
+    // CHECK-NEXT: }
+  }
+  // CHECK-NEXT: case 5 {
+  case 5 {
+    %c20_i32 = arith.constant 20 : i32
+    scf.yield %c20_i32 : i32
+  }
+  // CHECK: default {
+  default {
+    %c30_i32 = arith.constant 30 : i32
+    scf.yield %c30_i32 : i32
+  }
+
+  // CHECK: scf.index_switch %arg0
+  scf.index_switch %arg0
+  // CHECK-NEXT: default {
+  default {
+    scf.yield
+  }
+
+  return %0 : i32
+}