From: Jeff Niu Date: Sun, 16 Oct 2022 18:30:08 +0000 (-0700) Subject: [mlir][scf] Add an IndexSwitchOp X-Git-Tag: upstream/17.0.6~29905 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=07d8fe9391a1bda7bb5fdfd17a5b897df7a003f5;p=platform%2Fupstream%2Fllvm.git [mlir][scf] Add an IndexSwitchOp The `scf.index_switch` is a control-flow operation that branches to one of the given regions based on the values of the argument and the cases. The argument is always of type `index`. Example: ```mlir %0 = scf.index_switch %arg0 -> i32 case 2 { %1 = arith.constant 10 : i32 scf.yield %1 : i32 } case 5 { %2 = arith.constant 20 : i32 scf.yield %2 : i32 } default { %3 = arith.constant 30 : i32 scf.yield %3 : i32 } ``` Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D136003 --- diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 41b7099..38dd0ac 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -235,7 +235,7 @@ def ForOp : SCF_Op<"for", } /// Return the `index`-th region iteration argument. BlockArgument getRegionIterArg(unsigned index) { - assert(index < getNumRegionIterArgs() && + assert(index < getNumRegionIterArgs() && "expected an index less than the number of region iter args"); return getBody()->getArguments().drop_front(getNumInductionVars())[index]; } @@ -434,7 +434,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [ ``` Example with thread_dim_mapping attribute: - + ```mlir // // Sequential context. @@ -456,7 +456,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [ ``` Example with privatized tensors: - + ```mlir %t0 = ... %t1 = ... @@ -527,8 +527,8 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [ return getBody()->getArguments().drop_front(getRank()); } - /// Return the thread indices in the order specified by the - /// thread_dim_mapping attribute. Return failure is + /// Return the thread indices in the order specified by the + /// thread_dim_mapping attribute. Return failure is /// thread_dim_mapping is not a valid permutation. FailureOr> getPermutedThreadIndices(); @@ -989,12 +989,76 @@ def WhileOp : SCF_Op<"while", } //===----------------------------------------------------------------------===// +// IndexSwitchOp +//===----------------------------------------------------------------------===// + +def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects, + SingleBlockImplicitTerminator<"scf::YieldOp">, + DeclareOpInterfaceMethods]> { + let summary = "switch-case operation on an index argument"; + let description = [{ + The `scf.index_switch` is a control-flow operation that branches to one of + the given regions based on the values of the argument and the cases. The + argument is always of type `index`. + + The operation always has a "default" region and any number of case regions + denoted by integer constants. Control-flow transfers to the case region + whose constant value equals the value of the argument. If the argument does + not equal any of the case values, control-flow transfer to the "default" + region. + + Example: + + ```mlir + %0 = scf.index_switch %arg0 : index -> i32 + case 2 { + %1 = arith.constant 10 : i32 + scf.yield %1 : i32 + } + case 5 { + %2 = arith.constant 20 : i32 + scf.yield %2 : i32 + } + default { + %3 = arith.constant 30 : i32 + scf.yield %3 : i32 + } + ``` + }]; + + let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases); + let results = (outs Variadic:$results); + let regions = (region SizedRegion<1>:$defaultRegion, + VariadicRegion>:$caseRegions); + + let assemblyFormat = [{ + $arg attr-dict (`->` type($results)^)? + custom($cases, $caseRegions) `\n` + `` `default` $defaultRegion + }]; + + let extraClassDeclaration = [{ + /// Get the number of cases. + unsigned getNumCases(); + + /// Get the default region body. + Block &getDefaultBlock(); + + /// Get the body of a case region. + Block &getCaseBlock(unsigned idx); + }]; + + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator, - ParentOneOf<["ExecuteRegionOp, ForOp", - "IfOp, ParallelOp, WhileOp"]>]> { + ParentOneOf<["ExecuteRegionOp, ForOp", "IfOp", "IndexSwitchOp", + "ParallelOp", "WhileOp"]>]> { let summary = "loop yield and termination operation"; let description = [{ "scf.yield" yields an SSA value from the SCF dialect op region and diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index b3e84f0..cd1d382 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -3387,6 +3387,137 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, } //===----------------------------------------------------------------------===// +// IndexSwitchOp +//===----------------------------------------------------------------------===// + +/// Parse the case regions and values. +static ParseResult +parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, + SmallVectorImpl> &caseRegions) { + SmallVector caseValues; + while (succeeded(p.parseOptionalKeyword("case"))) { + int64_t value; + Region ®ion = + *caseRegions.emplace_back(std::make_unique()).get(); + if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{})) + return failure(); + caseValues.push_back(value); + } + cases = p.getBuilder().getDenseI64ArrayAttr(caseValues); + return success(); +} + +/// Print the case regions and values. +static void printSwitchCases(OpAsmPrinter &p, Operation *op, + DenseI64ArrayAttr cases, RegionRange caseRegions) { + for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) { + p.printNewline(); + p << "case " << value << ' '; + p.printRegion(*region, /*printEntryBlockArgs=*/false); + } +} + +LogicalResult scf::IndexSwitchOp::verify() { + if (getCases().size() != getCaseRegions().size()) { + return emitOpError("has ") + << getCaseRegions().size() << " case regions but " + << getCases().size() << " case values"; + } + + DenseSet valueSet; + for (int64_t value : getCases()) + if (!valueSet.insert(value).second) + return emitOpError("has duplicate case value: ") << value; + + auto verifyRegion = [&](Region ®ion, const Twine &name) -> LogicalResult { + auto yield = cast(region.front().getTerminator()); + if (yield.getNumOperands() != getNumResults()) { + return (emitOpError("expected each region to return ") + << getNumResults() << " values, but " << name << " returns " + << yield.getNumOperands()) + .attachNote(yield.getLoc()) + << "see yield operation here"; + } + for (auto [idx, result, operand] : + llvm::zip(llvm::seq(0, getNumResults()), getResultTypes(), + yield.getOperandTypes())) { + if (result == operand) + continue; + return (emitOpError("expected result #") + << idx << " of each region to be " << result) + .attachNote(yield.getLoc()) + << name << " returns " << operand << " here"; + } + return success(); + }; + + if (failed(verifyRegion(getDefaultRegion(), "default region"))) + return failure(); + for (auto &[idx, caseRegion] : llvm::enumerate(getCaseRegions())) + if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx)))) + return failure(); + + return success(); +} + +unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); } + +Block &scf::IndexSwitchOp::getDefaultBlock() { + return getDefaultRegion().front(); +} + +Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) { + assert(idx < getNumCases() && "case index out-of-bounds"); + return getCaseRegions()[idx].front(); +} + +void IndexSwitchOp::getSuccessorRegions( + Optional index, ArrayRef operands, + SmallVectorImpl &successors) { + // All regions branch back to the parent op. + if (index) { + successors.emplace_back(getResults()); + return; + } + + // If a constant was not provided, all regions are possible successors. + auto operandValue = operands.front().dyn_cast_or_null(); + if (!operandValue) { + for (Region &caseRegion : getCaseRegions()) + successors.emplace_back(&caseRegion); + successors.emplace_back(&getDefaultRegion()); + return; + } + + // Otherwise, try to find a case with a matching value. If not, the default + // region is the only successor. + for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) { + if (caseValue == operandValue.getInt()) { + successors.emplace_back(&caseRegion); + return; + } + } + successors.emplace_back(&getDefaultRegion()); +} + +void IndexSwitchOp::getRegionInvocationBounds( + ArrayRef operands, SmallVectorImpl &bounds) { + auto operandValue = operands.front().dyn_cast_or_null(); + if (!operandValue) { + // All regions are invoked at most once. + bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1)); + return; + } + + unsigned liveIndex = getNumRegions() - 1; + auto it = llvm::find(getCases(), operandValue.getInt()); + if (it != getCases().end()) + liveIndex = std::distance(getCases().begin(), it); + for (unsigned i = 0, e = getNumRegions(); i < e; ++i) + bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex); +} + +//===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir index b79ecb4..fa91ba0 100644 --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -428,7 +428,7 @@ func.func @parallel_invalid_yield( func.func @yield_invalid_parent_op() { "my.op"() ({ - // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.parallel, scf.while'}} + // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.parallel, scf.while'}} scf.yield }) : () -> () return @@ -572,3 +572,57 @@ func.func @wrong_terminator_op(%in: tensor<100xf32>, %out: tensor<100xf32>) { } return } + +// ----- + +func.func @switch_wrong_case_count(%arg0: index) { + // expected-error @below {{'scf.index_switch' op has 0 case regions but 1 case values}} + "scf.index_switch"(%arg0) ({ + scf.yield + }) {cases = array} : (index) -> () + return +} + +// ----- + +func.func @switch_duplicate_case(%arg0: index) { + // expected-error @below {{'scf.index_switch' op has duplicate case value: 0}} + scf.index_switch %arg0 + case 0 { + scf.yield + } + case 0 { + scf.yield + } + default { + scf.yield + } + return +} + +// ----- + +func.func @switch_wrong_types(%arg0: index) { + // expected-error @below {{'scf.index_switch' op expected each region to return 0 values, but default region returns 1}} + scf.index_switch %arg0 + default { + // expected-note @below {{see yield operation here}} + scf.yield %arg0 : index + } + return +} + +// ----- + +func.func @switch_wrong_types(%arg0: index, %arg1: i32) { + // expected-error @below {{'scf.index_switch' op expected result #0 of each region to be 'index'}} + scf.index_switch %arg0 -> index + case 0 { + // expected-note @below {{case region #0 returns 'i32' here}} + scf.yield %arg1 : i32 + } + default { + scf.yield %arg0 : index + } + return +} diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir index c1fa4e6..e563838 100644 --- a/mlir/test/Dialect/SCF/ops.mlir +++ b/mlir/test/Dialect/SCF/ops.mlir @@ -346,3 +346,36 @@ func.func @elide_terminator() -> () { } {thread_dim_mapping = [42]} return } + +// CHECK-LABEL: @switch +func.func @switch(%arg0: index) -> i32 { + // CHECK: %{{.*}} = scf.index_switch %arg0 -> i32 + %0 = scf.index_switch %arg0 -> i32 + // CHECK-NEXT: case 2 { + case 2 { + // CHECK-NEXT: arith.constant + %c10_i32 = arith.constant 10 : i32 + // CHECK-NEXT: scf.yield %{{.*}} : i32 + scf.yield %c10_i32 : i32 + // CHECK-NEXT: } + } + // CHECK-NEXT: case 5 { + case 5 { + %c20_i32 = arith.constant 20 : i32 + scf.yield %c20_i32 : i32 + } + // CHECK: default { + default { + %c30_i32 = arith.constant 30 : i32 + scf.yield %c30_i32 : i32 + } + + // CHECK: scf.index_switch %arg0 + scf.index_switch %arg0 + // CHECK-NEXT: default { + default { + scf.yield + } + + return %0 : i32 +}