let hasFolder = 1;
}
+
+//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+def SwitchOp : Std_Op<"switch",
+ [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
+ NoSideEffect, Terminator]> {
+ let summary = "switch operation";
+ let description = [{
+ The `switch` terminator operation represents a switch on a signless integer
+ value. If the flag matches one of the specified cases, then the
+ corresponding destination is jumped to. If the flag does not match any of
+ the cases, the default destination is jumped to. The count and types of
+ operands must align with the arguments in the corresponding target blocks.
+
+ Example:
+
+ ```mlir
+ switch %flag : i32, [
+ default: ^bb1(%a : i32),
+ 42: ^bb1(%b : i32),
+ 43: ^bb3(%c : i32)
+ ]
+ ```
+ }];
+
+ let arguments = (ins AnyInteger:$flag,
+ Variadic<AnyType>:$defaultOperands,
+ Variadic<AnyType>:$caseOperands,
+ OptionalAttr<AnyIntElementsAttr>:$case_values,
+ OptionalAttr<I32ElementsAttr>:$case_operand_offsets);
+ let successors = (successor
+ AnySuccessor:$defaultDestination,
+ VariadicSuccessor<AnySuccessor>:$caseDestinations);
+ let builders = [
+ OpBuilder<(ins "Value":$flag,
+ "Block *":$defaultDestination,
+ "ValueRange":$defaultOperands,
+ CArg<"ArrayRef<APInt>", "{}">:$caseValues,
+ CArg<"BlockRange", "{}">:$caseDestinations,
+ CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
+ OpBuilder<(ins "Value":$flag,
+ "Block *":$defaultDestination,
+ "ValueRange":$defaultOperands,
+ CArg<"ArrayRef<int32_t>", "{}">:$caseValues,
+ CArg<"BlockRange", "{}">:$caseDestinations,
+ CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
+ OpBuilder<(ins "Value":$flag,
+ "Block *":$defaultDestination,
+ "ValueRange":$defaultOperands,
+ CArg<"DenseIntElementsAttr", "{}">:$caseValues,
+ CArg<"BlockRange", "{}">:$caseDestinations,
+ CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>
+ ];
+
+ let assemblyFormat = [{
+ $flag `:` type($flag) `,` `[` `\n`
+ custom<SwitchOpCases>(ref(type($flag)),$defaultDestination,
+ $defaultOperands,
+ type($defaultOperands),
+ $case_values,
+ $caseDestinations,
+ $caseOperands,
+ type($caseOperands),
+ $case_operand_offsets)
+ `]`
+ attr-dict
+ }];
+
+ let extraClassDeclaration = [{
+ /// Return the operands for the case destination block at the given index.
+ OperandRange getCaseOperands(unsigned index);
+
+ /// Return a mutable range of operands for the case destination block at the
+ /// given index.
+ MutableOperandRange getCaseOperandsMutable(unsigned index);
+ }];
+
+ let hasCanonicalizer = 1;
+}
+
//===----------------------------------------------------------------------===//
// TruncateIOp
//===----------------------------------------------------------------------===//
/// Given a successor, try to collapse it to a new destination if it only
/// contains a passthrough unconditional branch. If the successor is
/// collapsable, `successor` and `successorOperands` are updated to reference
-/// the new destination and values. `argStorage` is an optional storage to use
-/// if operands to the collapsed successor need to be remapped.
+/// the new destination and values. `argStorage` is used as storage if operands
+/// to the collapsed successor need to be remapped. It must outlive uses of
+/// successorOperands.
static LogicalResult collapseBranch(Block *&successor,
ValueRange &successorOperands,
SmallVectorImpl<Value> &argStorage) {
}
//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+ Block *defaultDestination, ValueRange defaultOperands,
+ DenseIntElementsAttr caseValues,
+ BlockRange caseDestinations,
+ ArrayRef<ValueRange> caseOperands) {
+ SmallVector<Value> flattenedCaseOperands;
+ SmallVector<int32_t> caseOperandOffsets;
+ int32_t offset = 0;
+ for (ValueRange operands : caseOperands) {
+ flattenedCaseOperands.append(operands.begin(), operands.end());
+ caseOperandOffsets.push_back(offset);
+ offset += operands.size();
+ }
+ DenseIntElementsAttr caseOperandOffsetsAttr;
+ if (!caseOperandOffsets.empty())
+ caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets);
+
+ build(builder, result, value, defaultOperands, flattenedCaseOperands,
+ caseValues, caseOperandOffsetsAttr, defaultDestination,
+ caseDestinations);
+}
+
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+ Block *defaultDestination, ValueRange defaultOperands,
+ ArrayRef<APInt> caseValues, BlockRange caseDestinations,
+ ArrayRef<ValueRange> caseOperands) {
+ DenseIntElementsAttr caseValuesAttr;
+ if (!caseValues.empty()) {
+ ShapedType caseValueType = VectorType::get(
+ static_cast<int64_t>(caseValues.size()), value.getType());
+ caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
+ }
+ build(builder, result, value, defaultDestination, defaultOperands,
+ caseValuesAttr, caseDestinations, caseOperands);
+}
+
+/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
+/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
+static ParseResult
+parseSwitchOpCases(OpAsmParser &parser, Type &flagType,
+ Block *&defaultDestination,
+ SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
+ SmallVectorImpl<Type> &defaultOperandTypes,
+ DenseIntElementsAttr &caseValues,
+ SmallVectorImpl<Block *> &caseDestinations,
+ SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
+ SmallVectorImpl<Type> &caseOperandTypes,
+ DenseIntElementsAttr &caseOperandOffsets) {
+
+ if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) ||
+ failed(parser.parseSuccessor(defaultDestination)))
+ return failure();
+ if (succeeded(parser.parseOptionalLParen())) {
+ if (failed(parser.parseRegionArgumentList(defaultOperands)) ||
+ failed(parser.parseColonTypeList(defaultOperandTypes)) ||
+ failed(parser.parseRParen()))
+ return failure();
+ }
+
+ SmallVector<APInt> values;
+ SmallVector<int32_t> offsets;
+ unsigned bitWidth = flagType.getIntOrFloatBitWidth();
+ int64_t offset = 0;
+ while (succeeded(parser.parseOptionalComma())) {
+ int64_t value = 0;
+ if (failed(parser.parseInteger(value)))
+ return failure();
+ values.push_back(APInt(bitWidth, value));
+
+ Block *destination;
+ SmallVector<OpAsmParser::OperandType> operands;
+ if (failed(parser.parseColon()) ||
+ failed(parser.parseSuccessor(destination)))
+ return failure();
+ if (succeeded(parser.parseOptionalLParen())) {
+ if (failed(parser.parseRegionArgumentList(operands)) ||
+ failed(parser.parseColonTypeList(caseOperandTypes)) ||
+ failed(parser.parseRParen()))
+ return failure();
+ }
+ caseDestinations.push_back(destination);
+ caseOperands.append(operands.begin(), operands.end());
+ offsets.push_back(offset);
+ offset += operands.size();
+ }
+
+ if (values.empty())
+ return success();
+
+ Builder &builder = parser.getBuilder();
+ ShapedType caseValueType =
+ VectorType::get(static_cast<int64_t>(values.size()), flagType);
+ caseValues = DenseIntElementsAttr::get(caseValueType, values);
+ caseOperandOffsets = builder.getI32VectorAttr(offsets);
+
+ return success();
+}
+
+static void printSwitchOpCases(
+ OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
+ OperandRange defaultOperands, TypeRange defaultOperandTypes,
+ DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
+ OperandRange caseOperands, TypeRange caseOperandTypes,
+ ElementsAttr caseOperandOffsets) {
+ p << " default: ";
+ p.printSuccessorAndUseList(defaultDestination, defaultOperands);
+
+ if (!caseValues)
+ return;
+
+ for (int64_t i = 0, size = caseValues.size(); i < size; ++i) {
+ p << ',';
+ p.printNewline();
+ p << " ";
+ p << caseValues.getValue<APInt>(i).getLimitedValue();
+ p << ": ";
+ p.printSuccessorAndUseList(caseDestinations[i], op.getCaseOperands(i));
+ }
+ p.printNewline();
+}
+
+static LogicalResult verify(SwitchOp op) {
+ auto caseValues = op.case_values();
+ auto caseDestinations = op.caseDestinations();
+
+ if (!caseValues && caseDestinations.empty())
+ return success();
+
+ Type flagType = op.flag().getType();
+ Type caseValueType = caseValues->getType().getElementType();
+ if (caseValueType != flagType)
+ return op.emitOpError()
+ << "'flag' type (" << flagType << ") should match case value type ("
+ << caseValueType << ")";
+
+ if (caseValues &&
+ caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
+ return op.emitOpError() << "number of case values (" << caseValues->size()
+ << ") should match number of "
+ "case destinations ("
+ << caseDestinations.size() << ")";
+ return success();
+}
+
+OperandRange SwitchOp::getCaseOperands(unsigned index) {
+ return getCaseOperandsMutable(index);
+}
+
+MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
+ MutableOperandRange caseOperands = caseOperandsMutable();
+ if (!case_operand_offsets()) {
+ assert(caseOperands.size() == 0 &&
+ "non-empty case operands must have offsets");
+ return caseOperands;
+ }
+
+ ElementsAttr offsets = case_operand_offsets().getValue();
+ assert(index < offsets.size() && "invalid case operand offset index");
+
+ int64_t begin = offsets.getValue(index).cast<IntegerAttr>().getInt();
+ int64_t end = index + 1 == offsets.size()
+ ? caseOperands.size()
+ : offsets.getValue(index + 1).cast<IntegerAttr>().getInt();
+ return caseOperandsMutable().slice(begin, end - begin);
+}
+
+Optional<MutableOperandRange>
+SwitchOp::getMutableSuccessorOperands(unsigned index) {
+ assert(index < getNumSuccessors() && "invalid successor index");
+ return index == 0 ? defaultOperandsMutable()
+ : getCaseOperandsMutable(index - 1);
+}
+
+Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
+ Optional<DenseIntElementsAttr> caseValues = case_values();
+
+ if (!caseValues)
+ return defaultDestination();
+
+ SuccessorRange caseDests = caseDestinations();
+ if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
+ for (int64_t i = 0, size = case_values()->size(); i < size; ++i)
+ if (value == caseValues->getValue<IntegerAttr>(i))
+ return caseDests[i];
+ return defaultDestination();
+ }
+ return nullptr;
+}
+
+/// switch %flag : i32, [
+/// default: ^bb1
+/// ]
+/// -> br ^bb1
+static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
+ PatternRewriter &rewriter) {
+ if (!op.caseDestinations().empty())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
+ op.defaultOperands());
+ return success();
+}
+
+/// switch %flag : i32, [
+/// default: ^bb1,
+/// 42: ^bb1,
+/// 43: ^bb2
+/// ]
+/// ->
+/// switch %flag : i32, [
+/// default: ^bb1,
+/// 43: ^bb2
+/// ]
+static LogicalResult
+dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
+ SmallVector<Block *> newCaseDestinations;
+ SmallVector<ValueRange> newCaseOperands;
+ SmallVector<APInt> newCaseValues;
+ bool requiresChange = false;
+ auto caseValues = op.case_values();
+ auto caseDests = op.caseDestinations();
+
+ for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
+ if (caseDests[i] == op.defaultDestination() &&
+ op.getCaseOperands(i) == op.defaultOperands()) {
+ requiresChange = true;
+ continue;
+ }
+ newCaseDestinations.push_back(caseDests[i]);
+ newCaseOperands.push_back(op.getCaseOperands(i));
+ newCaseValues.push_back(caseValues->getValue<APInt>(i));
+ }
+
+ if (!requiresChange)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
+ op.defaultOperands(), newCaseValues,
+ newCaseDestinations, newCaseOperands);
+ return success();
+}
+
+/// Helper for folding a switch with a constant value.
+/// switch %c_42 : i32, [
+/// default: ^bb1 ,
+/// 42: ^bb2,
+/// 43: ^bb3
+/// ]
+/// -> br ^bb2
+static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
+ APInt caseValue) {
+ auto caseValues = op.case_values();
+ for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
+ if (caseValues->getValue<APInt>(i) == caseValue) {
+ rewriter.replaceOpWithNewOp<BranchOp>(op, op.caseDestinations()[i],
+ op.getCaseOperands(i));
+ return;
+ }
+ }
+ rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
+ op.defaultOperands());
+}
+
+/// switch %c_42 : i32, [
+/// default: ^bb1,
+/// 42: ^bb2,
+/// 43: ^bb3
+/// ]
+/// -> br ^bb2
+static LogicalResult simplifyConstSwitchValue(SwitchOp op,
+ PatternRewriter &rewriter) {
+ APInt caseValue;
+ if (!matchPattern(op.flag(), m_ConstantInt(&caseValue)))
+ return failure();
+
+ foldSwitch(op, rewriter, caseValue);
+ return success();
+}
+
+/// switch %c_42 : i32, [
+/// default: ^bb1,
+/// 42: ^bb2,
+/// ]
+/// ^bb2:
+/// br ^bb3
+/// ->
+/// switch %c_42 : i32, [
+/// default: ^bb1,
+/// 42: ^bb3,
+/// ]
+static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
+ PatternRewriter &rewriter) {
+
+ SmallVector<Block *> newCaseDests;
+ SmallVector<ValueRange> newCaseOperands;
+ SmallVector<SmallVector<Value>> argStorage;
+ auto caseValues = op.case_values();
+ auto caseDests = op.caseDestinations();
+ bool requiresChange = false;
+ for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
+ Block *caseDest = caseDests[i];
+ ValueRange caseOperands = op.getCaseOperands(i);
+ argStorage.emplace_back();
+ if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
+ requiresChange = true;
+
+ newCaseDests.push_back(caseDest);
+ newCaseOperands.push_back(caseOperands);
+ }
+
+ Block *defaultDest = op.defaultDestination();
+ ValueRange defaultOperands = op.defaultOperands();
+ argStorage.emplace_back();
+
+ if (succeeded(
+ collapseBranch(defaultDest, defaultOperands, argStorage.back())))
+ requiresChange = true;
+
+ if (!requiresChange)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), defaultDest,
+ defaultOperands, caseValues.getValue(),
+ newCaseDests, newCaseOperands);
+ return success();
+}
+
+/// switch %flag : i32, [
+/// default: ^bb1,
+/// 42: ^bb2,
+/// ]
+/// ^bb2:
+/// switch %flag : i32, [
+/// default: ^bb3,
+/// 42: ^bb4
+/// ]
+/// ->
+/// switch %flag : i32, [
+/// default: ^bb1,
+/// 42: ^bb2,
+/// ]
+/// ^bb2:
+/// br ^bb4
+///
+/// and
+///
+/// switch %flag : i32, [
+/// default: ^bb1,
+/// 42: ^bb2,
+/// ]
+/// ^bb2:
+/// switch %flag : i32, [
+/// default: ^bb3,
+/// 43: ^bb4
+/// ]
+/// ->
+/// switch %flag : i32, [
+/// default: ^bb1,
+/// 42: ^bb2,
+/// ]
+/// ^bb2:
+/// br ^bb3
+static LogicalResult
+simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
+ PatternRewriter &rewriter) {
+ // Check that we have a single distinct predecessor.
+ Block *currentBlock = op->getBlock();
+ Block *predecessor = currentBlock->getSinglePredecessor();
+ if (!predecessor)
+ return failure();
+
+ // Check that the predecessor terminates with a switch branch to this block
+ // and that it branches on the same condition and that this branch isn't the
+ // default destination.
+ auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
+ if (!predSwitch || op.flag() != predSwitch.flag() ||
+ predSwitch.defaultDestination() == currentBlock)
+ return failure();
+
+ // Fold this switch to an unconditional branch.
+ APInt caseValue;
+ bool isDefault = true;
+ SuccessorRange predDests = predSwitch.caseDestinations();
+ Optional<DenseIntElementsAttr> predCaseValues = predSwitch.case_values();
+ for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) {
+ if (currentBlock == predDests[i]) {
+ caseValue = predCaseValues->getValue<APInt>(i);
+ isDefault = false;
+ break;
+ }
+ }
+ if (isDefault)
+ rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
+ op.defaultOperands());
+ else
+ foldSwitch(op, rewriter, caseValue);
+ return success();
+}
+
+/// switch %flag : i32, [
+/// default: ^bb1,
+/// 42: ^bb2
+/// ]
+/// ^bb1:
+/// switch %flag : i32, [
+/// default: ^bb3,
+/// 42: ^bb4,
+/// 43: ^bb5
+/// ]
+/// ->
+/// switch %flag : i32, [
+/// default: ^bb1,
+/// 42: ^bb2,
+/// ]
+/// ^bb1:
+/// switch %flag : i32, [
+/// default: ^bb3,
+/// 43: ^bb5
+/// ]
+static LogicalResult
+simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
+ PatternRewriter &rewriter) {
+ // Check that we have a single distinct predecessor.
+ Block *currentBlock = op->getBlock();
+ Block *predecessor = currentBlock->getSinglePredecessor();
+ if (!predecessor)
+ return failure();
+
+ // Check that the predecessor terminates with a switch branch to this block
+ // and that it branches on the same condition and that this branch is the
+ // default destination.
+ auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
+ if (!predSwitch || op.flag() != predSwitch.flag() ||
+ predSwitch.defaultDestination() != currentBlock)
+ return failure();
+
+ // Delete case values that are not possible here.
+ DenseSet<APInt> caseValuesToRemove;
+ auto predDests = predSwitch.caseDestinations();
+ auto predCaseValues = predSwitch.case_values();
+ for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
+ if (currentBlock != predDests[i])
+ caseValuesToRemove.insert(predCaseValues->getValue<APInt>(i));
+
+ SmallVector<Block *> newCaseDestinations;
+ SmallVector<ValueRange> newCaseOperands;
+ SmallVector<APInt> newCaseValues;
+ bool requiresChange = false;
+
+ auto caseValues = op.case_values();
+ auto caseDests = op.caseDestinations();
+ for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
+ if (caseValuesToRemove.contains(caseValues->getValue<APInt>(i))) {
+ requiresChange = true;
+ continue;
+ }
+ newCaseDestinations.push_back(caseDests[i]);
+ newCaseOperands.push_back(op.getCaseOperands(i));
+ newCaseValues.push_back(caseValues->getValue<APInt>(i));
+ }
+
+ if (!requiresChange)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
+ op.defaultOperands(), newCaseValues,
+ newCaseDestinations, newCaseOperands);
+ return success();
+}
+
+void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add(&simplifySwitchWithOnlyDefault)
+ .add(&dropSwitchCasesThatMatchDefault)
+ .add(&simplifyConstSwitchValue)
+ .add(&simplifyPassThroughSwitch)
+ .add(&simplifySwitchFromSwitchOnSameCondition)
+ .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
+}
+
+//===----------------------------------------------------------------------===//
// TruncateIOp
//===----------------------------------------------------------------------===//
-// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck --dump-input-context 20 %s
/// Test the folding of BranchOp.
return
}
+
+/// Test the folding of SwitchOp
+
+// CHECK-LABEL: func @switch_only_default(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+func @switch_only_default(%flag : i32, %caseOperand0 : f32) {
+ // add predecessors for all blocks to avoid other canonicalizations.
+ "foo.pred"() [^bb1, ^bb2] : () -> ()
+ ^bb1:
+ // CHECK-NOT: switch
+ // CHECK: br ^[[BB2:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+ switch %flag : i32, [
+ default: ^bb2(%caseOperand0 : f32)
+ ]
+ // CHECK: ^[[BB2]]({{.*}}):
+ ^bb2(%bb2Arg : f32):
+ // CHECK-NEXT: "foo.bb2Terminator"
+ "foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
+}
+
+
+// CHECK-LABEL: func @switch_case_matching_default(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+func @switch_case_matching_default(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) {
+ // add predecessors for all blocks to avoid other canonicalizations.
+ "foo.pred"() [^bb1, ^bb2, ^bb3] : () -> ()
+ ^bb1:
+ // CHECK: switch %[[FLAG]]
+ // CHECK-NEXT: default: ^[[BB1:.+]](%[[CASE_OPERAND_0]] : f32)
+ // CHECK-NEXT: 10: ^[[BB2:.+]](%[[CASE_OPERAND_1]] : f32)
+ // CHECK-NEXT: ]
+ switch %flag : i32, [
+ default: ^bb2(%caseOperand0 : f32),
+ 42: ^bb2(%caseOperand0 : f32),
+ 10: ^bb3(%caseOperand1 : f32),
+ 17: ^bb2(%caseOperand0 : f32)
+ ]
+ ^bb2(%bb2Arg : f32):
+ "foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
+ ^bb3(%bb3Arg : f32):
+ "foo.bb3Terminator"(%bb3Arg) : (f32) -> ()
+}
+
+
+// CHECK-LABEL: func @switch_on_const_no_match(
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+func @switch_on_const_no_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
+ // add predecessors for all blocks to avoid other canonicalizations.
+ "foo.pred"() [^bb1, ^bb2, ^bb3, ^bb4] : () -> ()
+ ^bb1:
+ // CHECK-NOT: switch
+ // CHECK: br ^[[BB2:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+ %c0_i32 = constant 0 : i32
+ switch %c0_i32 : i32, [
+ default: ^bb2(%caseOperand0 : f32),
+ -1: ^bb3(%caseOperand1 : f32),
+ 1: ^bb4(%caseOperand2 : f32)
+ ]
+ // CHECK: ^[[BB2]]({{.*}}):
+ // CHECK-NEXT: "foo.bb2Terminator"
+ ^bb2(%bb2Arg : f32):
+ "foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
+ ^bb3(%bb3Arg : f32):
+ "foo.bb3Terminator"(%bb3Arg) : (f32) -> ()
+ ^bb4(%bb4Arg : f32):
+ "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+}
+
+// CHECK-LABEL: func @switch_on_const_with_match(
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+func @switch_on_const_with_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
+ // add predecessors for all blocks to avoid other canonicalizations.
+ "foo.pred"() [^bb1, ^bb2, ^bb3, ^bb4] : () -> ()
+ ^bb1:
+ // CHECK-NOT: switch
+ // CHECK: br ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_2]]
+ %c0_i32 = constant 1 : i32
+ switch %c0_i32 : i32, [
+ default: ^bb2(%caseOperand0 : f32),
+ -1: ^bb3(%caseOperand1 : f32),
+ 1: ^bb4(%caseOperand2 : f32)
+ ]
+ ^bb2(%bb2Arg : f32):
+ "foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
+ ^bb3(%bb3Arg : f32):
+ "foo.bb3Terminator"(%bb3Arg) : (f32) -> ()
+ // CHECK: ^[[BB4]]({{.*}}):
+ // CHECK-NEXT: "foo.bb4Terminator"
+ ^bb4(%bb4Arg : f32):
+ "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+}
+
+// CHECK-LABEL: func @switch_passthrough(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_3:[a-zA-Z0-9_]+]]
+func @switch_passthrough(%flag : i32,
+ %caseOperand0 : f32,
+ %caseOperand1 : f32,
+ %caseOperand2 : f32,
+ %caseOperand3 : f32) {
+ // add predecessors for all blocks to avoid other canonicalizations.
+ "foo.pred"() [^bb1, ^bb2, ^bb3, ^bb4, ^bb5, ^bb6] : () -> ()
+
+ ^bb1:
+ // CHECK: switch %[[FLAG]]
+ // CHECK-NEXT: default: ^[[BB5:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+ // CHECK-NEXT: 43: ^[[BB6:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_1]]
+ // CHECK-NEXT: 44: ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_2]]
+ // CHECK-NEXT: ]
+ switch %flag : i32, [
+ default: ^bb2(%caseOperand0 : f32),
+ 43: ^bb3(%caseOperand1 : f32),
+ 44: ^bb4(%caseOperand2 : f32)
+ ]
+ ^bb2(%bb2Arg : f32):
+ br ^bb5(%bb2Arg : f32)
+ ^bb3(%bb3Arg : f32):
+ br ^bb6(%bb3Arg : f32)
+ ^bb4(%bb4Arg : f32):
+ "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+
+ // CHECK: ^[[BB5]]({{.*}}):
+ // CHECK-NEXT: "foo.bb5Terminator"
+ ^bb5(%bb5Arg : f32):
+ "foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
+
+ // CHECK: ^[[BB6]]({{.*}}):
+ // CHECK-NEXT: "foo.bb6Terminator"
+ ^bb6(%bb6Arg : f32):
+ "foo.bb6Terminator"(%bb6Arg) : (f32) -> ()
+}
+
+// CHECK-LABEL: func @switch_from_switch_with_same_value_with_match(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+func @switch_from_switch_with_same_value_with_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) {
+ // add predecessors for all blocks except ^bb3 to avoid other canonicalizations.
+ "foo.pred"() [^bb1, ^bb2, ^bb4, ^bb5] : () -> ()
+
+ ^bb1:
+ // CHECK: switch %[[FLAG]]
+ switch %flag : i32, [
+ default: ^bb2,
+ 42: ^bb3
+ ]
+
+ ^bb2:
+ "foo.bb2Terminator"() : () -> ()
+ ^bb3:
+ // prevent this block from being simplified away
+ "foo.op"() : () -> ()
+ // CHECK-NOT: switch %[[FLAG]]
+ // CHECK: br ^[[BB5:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_1]]
+ switch %flag : i32, [
+ default: ^bb4(%caseOperand0 : f32),
+ 42: ^bb5(%caseOperand1 : f32)
+ ]
+
+ ^bb4(%bb4Arg : f32):
+ "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+
+ // CHECK: ^[[BB5]]({{.*}}):
+ // CHECK-NEXT: "foo.bb5Terminator"
+ ^bb5(%bb5Arg : f32):
+ "foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
+}
+
+// CHECK-LABEL: func @switch_from_switch_with_same_value_no_match(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+func @switch_from_switch_with_same_value_no_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
+ // add predecessors for all blocks except ^bb3 to avoid other canonicalizations.
+ "foo.pred"() [^bb1, ^bb2, ^bb4, ^bb5, ^bb6] : () -> ()
+
+ ^bb1:
+ // CHECK: switch %[[FLAG]]
+ switch %flag : i32, [
+ default: ^bb2,
+ 42: ^bb3
+ ]
+
+ ^bb2:
+ "foo.bb2Terminator"() : () -> ()
+ ^bb3:
+ "foo.op"() : () -> ()
+ // CHECK-NOT: switch %[[FLAG]]
+ // CHECK: br ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+ switch %flag : i32, [
+ default: ^bb4(%caseOperand0 : f32),
+ 0: ^bb5(%caseOperand1 : f32),
+ 43: ^bb6(%caseOperand2 : f32)
+ ]
+
+ // CHECK: ^[[BB4]]({{.*}})
+ // CHECK-NEXT: "foo.bb4Terminator"
+ ^bb4(%bb4Arg : f32):
+ "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+
+ ^bb5(%bb5Arg : f32):
+ "foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
+
+ ^bb6(%bb6Arg : f32):
+ "foo.bb6Terminator"(%bb6Arg) : (f32) -> ()
+}
+
+// CHECK-LABEL: func @switch_from_switch_default_with_same_value(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+func @switch_from_switch_default_with_same_value(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
+ // add predecessors for all blocks except ^bb3 to avoid other canonicalizations.
+ "foo.pred"() [^bb1, ^bb2, ^bb4, ^bb5, ^bb6] : () -> ()
+
+ ^bb1:
+ // CHECK: switch %[[FLAG]]
+ switch %flag : i32, [
+ default: ^bb3,
+ 42: ^bb2
+ ]
+
+ ^bb2:
+ "foo.bb2Terminator"() : () -> ()
+ ^bb3:
+ "foo.op"() : () -> ()
+ // CHECK: switch %[[FLAG]]
+ // CHECK-NEXT: default: ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+ // CHECK-NEXT: 43: ^[[BB6:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_2]]
+ // CHECK-NOT: 42
+ switch %flag : i32, [
+ default: ^bb4(%caseOperand0 : f32),
+ 42: ^bb5(%caseOperand1 : f32),
+ 43: ^bb6(%caseOperand2 : f32)
+ ]
+
+ // CHECK: ^[[BB4]]({{.*}}):
+ // CHECK-NEXT: "foo.bb4Terminator"
+ ^bb4(%bb4Arg : f32):
+ "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+
+ ^bb5(%bb5Arg : f32):
+ "foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
+
+ // CHECK: ^[[BB6]]({{.*}}):
+ // CHECK-NEXT: "foo.bb6Terminator"
+ ^bb6(%bb6Arg : f32):
+ "foo.bb6Terminator"(%bb6Arg) : (f32) -> ()
+}
+
/// Test folding conditional branches that are successors of conditional
/// branches with the same condition.