[MLIR] Add a switch operation to the standard dialect
authorGeoffrey Martin-Noble <gcmn@google.com>
Tue, 13 Apr 2021 00:01:30 +0000 (17:01 -0700)
committerGeoffrey Martin-Noble <gcmn@google.com>
Tue, 13 Apr 2021 01:46:02 +0000 (18:46 -0700)
This is similar to the definition of llvm.switch, providing
unstructured branch-based control flow. It differs from the LLVM
operation in that it accepts any signless integer (not only an i32),
takes no branch weights (the same as the Branch and CondBranch ops),
and has a slightly different syntax for the default case that includes
it in the list of cases with an explicit `default` keyword.

Also included are several canonicalizers.

See https://llvm.discourse.group/t/rfc-add-std-switch-and-scf-switch/3090

Reviewed By: rriddle, bondhugula

Differential Revision: https://reviews.llvm.org/D99925

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/IR/OpBase.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize-cf.mlir
mlir/test/Dialect/Standard/ops.mlir
mlir/test/Dialect/Standard/parser.mlir [new file with mode: 0644]

index 6d058f4..060062e 100644 (file)
@@ -2030,6 +2030,89 @@ def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
   let hasFolder = 1;
 }
 
+
+//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+def SwitchOp : Std_Op<"switch",
+    [AttrSizedOperandSegments,
+     DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
+     NoSideEffect, Terminator]> {
+  let summary = "switch operation";
+  let description = [{
+    The `switch` terminator operation represents a switch on a signless integer
+    value. If the flag matches one of the specified cases, then the
+    corresponding destination is jumped to. If the flag does not match any of
+    the cases, the default destination is jumped to. The count and types of
+    operands must align with the arguments in the corresponding target blocks.
+
+    Example:
+
+    ```mlir
+    switch %flag : i32, [
+      default: ^bb1(%a : i32),
+      42: ^bb1(%b : i32),
+      43: ^bb3(%c : i32)
+    ]
+    ```
+  }];
+
+  let arguments = (ins AnyInteger:$flag,
+                       Variadic<AnyType>:$defaultOperands,
+                       Variadic<AnyType>:$caseOperands,
+                       OptionalAttr<AnyIntElementsAttr>:$case_values,
+                       OptionalAttr<I32ElementsAttr>:$case_operand_offsets);
+  let successors = (successor
+                        AnySuccessor:$defaultDestination,
+                        VariadicSuccessor<AnySuccessor>:$caseDestinations);
+  let builders = [
+    OpBuilder<(ins "Value":$flag,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      CArg<"ArrayRef<APInt>", "{}">:$caseValues,
+      CArg<"BlockRange", "{}">:$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
+    OpBuilder<(ins "Value":$flag,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      CArg<"ArrayRef<int32_t>", "{}">:$caseValues,
+      CArg<"BlockRange", "{}">:$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
+    OpBuilder<(ins "Value":$flag,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      CArg<"DenseIntElementsAttr", "{}">:$caseValues,
+      CArg<"BlockRange", "{}">:$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>
+  ];
+
+  let assemblyFormat = [{
+    $flag `:` type($flag) `,` `[` `\n`
+      custom<SwitchOpCases>(ref(type($flag)),$defaultDestination,
+                            $defaultOperands,
+                            type($defaultOperands),
+                            $case_values,
+                            $caseDestinations,
+                            $caseOperands,
+                            type($caseOperands),
+                            $case_operand_offsets)
+   `]`
+    attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    /// Return the operands for the case destination block at the given index.
+    OperandRange getCaseOperands(unsigned index);
+
+    /// Return a mutable range of operands for the case destination block at the
+    /// given index.
+    MutableOperandRange getCaseOperandsMutable(unsigned index);
+  }];
+
+  let hasCanonicalizer = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // TruncateIOp
 //===----------------------------------------------------------------------===//
index a2469dc..f0b741e 100644 (file)
@@ -1333,13 +1333,15 @@ def IndexElementsAttr
                                       .isIndex()}]>,
                           "index elements attribute">;
 
-class AnyIntElementsAttr<int width> : IntElementsAttrBase<
+def AnyIntElementsAttr : IntElementsAttrBase<CPred<"true">, "integer elements attribute">;
+
+class IntElementsAttrOf<int width> : IntElementsAttrBase<
   CPred<"$_self.cast<::mlir::DenseIntElementsAttr>().getType()."
         "getElementType().isInteger(" # width # ")">,
   width # "-bit integer elements attribute">;
 
-def AnyI32ElementsAttr : AnyIntElementsAttr<32>;
-def AnyI64ElementsAttr : AnyIntElementsAttr<64>;
+def AnyI32ElementsAttr : IntElementsAttrOf<32>;
+def AnyI64ElementsAttr : IntElementsAttrOf<64>;
 
 class SignlessIntElementsAttr<int width> : IntElementsAttrBase<
   CPred<"$_self.cast<::mlir::DenseIntElementsAttr>().getType()."
index 3dad958..b538cba 100644 (file)
@@ -441,8 +441,9 @@ static LogicalResult verify(AtomicYieldOp op) {
 /// Given a successor, try to collapse it to a new destination if it only
 /// contains a passthrough unconditional branch. If the successor is
 /// collapsable, `successor` and `successorOperands` are updated to reference
-/// the new destination and values. `argStorage` is an optional storage to use
-/// if operands to the collapsed successor need to be remapped.
+/// the new destination and values. `argStorage` is used as storage if operands
+/// to the collapsed successor need to be remapped. It must outlive uses of
+/// successorOperands.
 static LogicalResult collapseBranch(Block *&successor,
                                     ValueRange &successorOperands,
                                     SmallVectorImpl<Value> &argStorage) {
@@ -2161,6 +2162,490 @@ void SubTensorInsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 //===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     DenseIntElementsAttr caseValues,
+                     BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands) {
+  SmallVector<Value> flattenedCaseOperands;
+  SmallVector<int32_t> caseOperandOffsets;
+  int32_t offset = 0;
+  for (ValueRange operands : caseOperands) {
+    flattenedCaseOperands.append(operands.begin(), operands.end());
+    caseOperandOffsets.push_back(offset);
+    offset += operands.size();
+  }
+  DenseIntElementsAttr caseOperandOffsetsAttr;
+  if (!caseOperandOffsets.empty())
+    caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets);
+
+  build(builder, result, value, defaultOperands, flattenedCaseOperands,
+        caseValues, caseOperandOffsetsAttr, defaultDestination,
+        caseDestinations);
+}
+
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     ArrayRef<APInt> caseValues, BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands) {
+  DenseIntElementsAttr caseValuesAttr;
+  if (!caseValues.empty()) {
+    ShapedType caseValueType = VectorType::get(
+        static_cast<int64_t>(caseValues.size()), value.getType());
+    caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
+  }
+  build(builder, result, value, defaultDestination, defaultOperands,
+        caseValuesAttr, caseDestinations, caseOperands);
+}
+
+/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
+///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
+static ParseResult
+parseSwitchOpCases(OpAsmParser &parser, Type &flagType,
+                   Block *&defaultDestination,
+                   SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
+                   SmallVectorImpl<Type> &defaultOperandTypes,
+                   DenseIntElementsAttr &caseValues,
+                   SmallVectorImpl<Block *> &caseDestinations,
+                   SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
+                   SmallVectorImpl<Type> &caseOperandTypes,
+                   DenseIntElementsAttr &caseOperandOffsets) {
+
+  if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) ||
+      failed(parser.parseSuccessor(defaultDestination)))
+    return failure();
+  if (succeeded(parser.parseOptionalLParen())) {
+    if (failed(parser.parseRegionArgumentList(defaultOperands)) ||
+        failed(parser.parseColonTypeList(defaultOperandTypes)) ||
+        failed(parser.parseRParen()))
+      return failure();
+  }
+
+  SmallVector<APInt> values;
+  SmallVector<int32_t> offsets;
+  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
+  int64_t offset = 0;
+  while (succeeded(parser.parseOptionalComma())) {
+    int64_t value = 0;
+    if (failed(parser.parseInteger(value)))
+      return failure();
+    values.push_back(APInt(bitWidth, value));
+
+    Block *destination;
+    SmallVector<OpAsmParser::OperandType> operands;
+    if (failed(parser.parseColon()) ||
+        failed(parser.parseSuccessor(destination)))
+      return failure();
+    if (succeeded(parser.parseOptionalLParen())) {
+      if (failed(parser.parseRegionArgumentList(operands)) ||
+          failed(parser.parseColonTypeList(caseOperandTypes)) ||
+          failed(parser.parseRParen()))
+        return failure();
+    }
+    caseDestinations.push_back(destination);
+    caseOperands.append(operands.begin(), operands.end());
+    offsets.push_back(offset);
+    offset += operands.size();
+  }
+
+  if (values.empty())
+    return success();
+
+  Builder &builder = parser.getBuilder();
+  ShapedType caseValueType =
+      VectorType::get(static_cast<int64_t>(values.size()), flagType);
+  caseValues = DenseIntElementsAttr::get(caseValueType, values);
+  caseOperandOffsets = builder.getI32VectorAttr(offsets);
+
+  return success();
+}
+
+static void printSwitchOpCases(
+    OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
+    OperandRange defaultOperands, TypeRange defaultOperandTypes,
+    DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
+    OperandRange caseOperands, TypeRange caseOperandTypes,
+    ElementsAttr caseOperandOffsets) {
+  p << "  default: ";
+  p.printSuccessorAndUseList(defaultDestination, defaultOperands);
+
+  if (!caseValues)
+    return;
+
+  for (int64_t i = 0, size = caseValues.size(); i < size; ++i) {
+    p << ',';
+    p.printNewline();
+    p << "  ";
+    p << caseValues.getValue<APInt>(i).getLimitedValue();
+    p << ": ";
+    p.printSuccessorAndUseList(caseDestinations[i], op.getCaseOperands(i));
+  }
+  p.printNewline();
+}
+
+static LogicalResult verify(SwitchOp op) {
+  auto caseValues = op.case_values();
+  auto caseDestinations = op.caseDestinations();
+
+  if (!caseValues && caseDestinations.empty())
+    return success();
+
+  Type flagType = op.flag().getType();
+  Type caseValueType = caseValues->getType().getElementType();
+  if (caseValueType != flagType)
+    return op.emitOpError()
+           << "'flag' type (" << flagType << ") should match case value type ("
+           << caseValueType << ")";
+
+  if (caseValues &&
+      caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
+    return op.emitOpError() << "number of case values (" << caseValues->size()
+                            << ") should match number of "
+                               "case destinations ("
+                            << caseDestinations.size() << ")";
+  return success();
+}
+
+OperandRange SwitchOp::getCaseOperands(unsigned index) {
+  return getCaseOperandsMutable(index);
+}
+
+MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
+  MutableOperandRange caseOperands = caseOperandsMutable();
+  if (!case_operand_offsets()) {
+    assert(caseOperands.size() == 0 &&
+           "non-empty case operands must have offsets");
+    return caseOperands;
+  }
+
+  ElementsAttr offsets = case_operand_offsets().getValue();
+  assert(index < offsets.size() && "invalid case operand offset index");
+
+  int64_t begin = offsets.getValue(index).cast<IntegerAttr>().getInt();
+  int64_t end = index + 1 == offsets.size()
+                    ? caseOperands.size()
+                    : offsets.getValue(index + 1).cast<IntegerAttr>().getInt();
+  return caseOperandsMutable().slice(begin, end - begin);
+}
+
+Optional<MutableOperandRange>
+SwitchOp::getMutableSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  return index == 0 ? defaultOperandsMutable()
+                    : getCaseOperandsMutable(index - 1);
+}
+
+Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
+  Optional<DenseIntElementsAttr> caseValues = case_values();
+
+  if (!caseValues)
+    return defaultDestination();
+
+  SuccessorRange caseDests = caseDestinations();
+  if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
+    for (int64_t i = 0, size = case_values()->size(); i < size; ++i)
+      if (value == caseValues->getValue<IntegerAttr>(i))
+        return caseDests[i];
+    return defaultDestination();
+  }
+  return nullptr;
+}
+
+/// switch %flag : i32, [
+///   default:  ^bb1
+/// ]
+///  -> br ^bb1
+static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
+                                                   PatternRewriter &rewriter) {
+  if (!op.caseDestinations().empty())
+    return failure();
+
+  rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
+                                        op.defaultOperands());
+  return success();
+}
+
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb1,
+///   43: ^bb2
+/// ]
+/// ->
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   43: ^bb2
+/// ]
+static LogicalResult
+dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
+  SmallVector<Block *> newCaseDestinations;
+  SmallVector<ValueRange> newCaseOperands;
+  SmallVector<APInt> newCaseValues;
+  bool requiresChange = false;
+  auto caseValues = op.case_values();
+  auto caseDests = op.caseDestinations();
+
+  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
+    if (caseDests[i] == op.defaultDestination() &&
+        op.getCaseOperands(i) == op.defaultOperands()) {
+      requiresChange = true;
+      continue;
+    }
+    newCaseDestinations.push_back(caseDests[i]);
+    newCaseOperands.push_back(op.getCaseOperands(i));
+    newCaseValues.push_back(caseValues->getValue<APInt>(i));
+  }
+
+  if (!requiresChange)
+    return failure();
+
+  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
+                                        op.defaultOperands(), newCaseValues,
+                                        newCaseDestinations, newCaseOperands);
+  return success();
+}
+
+/// Helper for folding a switch with a constant value.
+/// switch %c_42 : i32, [
+///   default: ^bb1 ,
+///   42: ^bb2,
+///   43: ^bb3
+/// ]
+/// -> br ^bb2
+static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
+                       APInt caseValue) {
+  auto caseValues = op.case_values();
+  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
+    if (caseValues->getValue<APInt>(i) == caseValue) {
+      rewriter.replaceOpWithNewOp<BranchOp>(op, op.caseDestinations()[i],
+                                            op.getCaseOperands(i));
+      return;
+    }
+  }
+  rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
+                                        op.defaultOperands());
+}
+
+/// switch %c_42 : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+///   43: ^bb3
+/// ]
+/// -> br ^bb2
+static LogicalResult simplifyConstSwitchValue(SwitchOp op,
+                                              PatternRewriter &rewriter) {
+  APInt caseValue;
+  if (!matchPattern(op.flag(), m_ConstantInt(&caseValue)))
+    return failure();
+
+  foldSwitch(op, rewriter, caseValue);
+  return success();
+}
+
+/// switch %c_42 : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   br ^bb3
+/// ->
+/// switch %c_42 : i32, [
+///   default: ^bb1,
+///   42: ^bb3,
+/// ]
+static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
+                                               PatternRewriter &rewriter) {
+
+  SmallVector<Block *> newCaseDests;
+  SmallVector<ValueRange> newCaseOperands;
+  SmallVector<SmallVector<Value>> argStorage;
+  auto caseValues = op.case_values();
+  auto caseDests = op.caseDestinations();
+  bool requiresChange = false;
+  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
+    Block *caseDest = caseDests[i];
+    ValueRange caseOperands = op.getCaseOperands(i);
+    argStorage.emplace_back();
+    if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
+      requiresChange = true;
+
+    newCaseDests.push_back(caseDest);
+    newCaseOperands.push_back(caseOperands);
+  }
+
+  Block *defaultDest = op.defaultDestination();
+  ValueRange defaultOperands = op.defaultOperands();
+  argStorage.emplace_back();
+
+  if (succeeded(
+          collapseBranch(defaultDest, defaultOperands, argStorage.back())))
+    requiresChange = true;
+
+  if (!requiresChange)
+    return failure();
+
+  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), defaultDest,
+                                        defaultOperands, caseValues.getValue(),
+                                        newCaseDests, newCaseOperands);
+  return success();
+}
+
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   switch %flag : i32, [
+///     default: ^bb3,
+///     42: ^bb4
+///   ]
+/// ->
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   br ^bb4
+///
+///  and
+///
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   switch %flag : i32, [
+///     default: ^bb3,
+///     43: ^bb4
+///   ]
+/// ->
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   br ^bb3
+static LogicalResult
+simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
+                                        PatternRewriter &rewriter) {
+  // Check that we have a single distinct predecessor.
+  Block *currentBlock = op->getBlock();
+  Block *predecessor = currentBlock->getSinglePredecessor();
+  if (!predecessor)
+    return failure();
+
+  // Check that the predecessor terminates with a switch branch to this block
+  // and that it branches on the same condition and that this branch isn't the
+  // default destination.
+  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
+  if (!predSwitch || op.flag() != predSwitch.flag() ||
+      predSwitch.defaultDestination() == currentBlock)
+    return failure();
+
+  // Fold this switch to an unconditional branch.
+  APInt caseValue;
+  bool isDefault = true;
+  SuccessorRange predDests = predSwitch.caseDestinations();
+  Optional<DenseIntElementsAttr> predCaseValues = predSwitch.case_values();
+  for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) {
+    if (currentBlock == predDests[i]) {
+      caseValue = predCaseValues->getValue<APInt>(i);
+      isDefault = false;
+      break;
+    }
+  }
+  if (isDefault)
+    rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
+                                          op.defaultOperands());
+  else
+    foldSwitch(op, rewriter, caseValue);
+  return success();
+}
+
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2
+/// ]
+/// ^bb1:
+///   switch %flag : i32, [
+///     default: ^bb3,
+///     42: ^bb4,
+///     43: ^bb5
+///   ]
+/// ->
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb1:
+///   switch %flag : i32, [
+///     default: ^bb3,
+///     43: ^bb5
+///   ]
+static LogicalResult
+simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
+                                               PatternRewriter &rewriter) {
+  // Check that we have a single distinct predecessor.
+  Block *currentBlock = op->getBlock();
+  Block *predecessor = currentBlock->getSinglePredecessor();
+  if (!predecessor)
+    return failure();
+
+  // Check that the predecessor terminates with a switch branch to this block
+  // and that it branches on the same condition and that this branch is the
+  // default destination.
+  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
+  if (!predSwitch || op.flag() != predSwitch.flag() ||
+      predSwitch.defaultDestination() != currentBlock)
+    return failure();
+
+  // Delete case values that are not possible here.
+  DenseSet<APInt> caseValuesToRemove;
+  auto predDests = predSwitch.caseDestinations();
+  auto predCaseValues = predSwitch.case_values();
+  for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
+    if (currentBlock != predDests[i])
+      caseValuesToRemove.insert(predCaseValues->getValue<APInt>(i));
+
+  SmallVector<Block *> newCaseDestinations;
+  SmallVector<ValueRange> newCaseOperands;
+  SmallVector<APInt> newCaseValues;
+  bool requiresChange = false;
+
+  auto caseValues = op.case_values();
+  auto caseDests = op.caseDestinations();
+  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
+    if (caseValuesToRemove.contains(caseValues->getValue<APInt>(i))) {
+      requiresChange = true;
+      continue;
+    }
+    newCaseDestinations.push_back(caseDests[i]);
+    newCaseOperands.push_back(op.getCaseOperands(i));
+    newCaseValues.push_back(caseValues->getValue<APInt>(i));
+  }
+
+  if (!requiresChange)
+    return failure();
+
+  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
+                                        op.defaultOperands(), newCaseValues,
+                                        newCaseDestinations, newCaseOperands);
+  return success();
+}
+
+void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add(&simplifySwitchWithOnlyDefault)
+      .add(&dropSwitchCasesThatMatchDefault)
+      .add(&simplifyConstSwitchValue)
+      .add(&simplifyPassThroughSwitch)
+      .add(&simplifySwitchFromSwitchOnSameCondition)
+      .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
+}
+
+//===----------------------------------------------------------------------===//
 // TruncateIOp
 //===----------------------------------------------------------------------===//
 
index 5f18562..d7d0a6a 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck --dump-input-context 20 %s
 
 /// Test the folding of BranchOp.
 
@@ -139,6 +139,268 @@ func @cond_br_pass_through_fail(%cond : i1) {
   return
 }
 
+
+/// Test the folding of SwitchOp
+
+// CHECK-LABEL: func @switch_only_default(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+func @switch_only_default(%flag : i32, %caseOperand0 : f32) {
+  // add predecessors for all blocks to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2] : () -> ()
+  ^bb1:
+    // CHECK-NOT: switch
+    // CHECK: br ^[[BB2:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+    switch %flag : i32, [
+      default: ^bb2(%caseOperand0 : f32)
+    ]
+  // CHECK: ^[[BB2]]({{.*}}):
+  ^bb2(%bb2Arg : f32):
+    // CHECK-NEXT: "foo.bb2Terminator"
+    "foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
+}
+
+
+// CHECK-LABEL: func @switch_case_matching_default(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+func @switch_case_matching_default(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) {
+  // add predecessors for all blocks to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb3] : () -> ()
+  ^bb1:
+    // CHECK: switch %[[FLAG]]
+    // CHECK-NEXT:   default: ^[[BB1:.+]](%[[CASE_OPERAND_0]] : f32)
+    // CHECK-NEXT:   10: ^[[BB2:.+]](%[[CASE_OPERAND_1]] : f32)
+    // CHECK-NEXT: ]
+    switch %flag : i32, [
+      default: ^bb2(%caseOperand0 : f32),
+      42: ^bb2(%caseOperand0 : f32),
+      10: ^bb3(%caseOperand1 : f32),
+      17: ^bb2(%caseOperand0 : f32)
+    ]
+  ^bb2(%bb2Arg : f32):
+    "foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
+  ^bb3(%bb3Arg : f32):
+    "foo.bb3Terminator"(%bb3Arg) : (f32) -> ()
+}
+
+
+// CHECK-LABEL: func @switch_on_const_no_match(
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+func @switch_on_const_no_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
+  // add predecessors for all blocks to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb3, ^bb4] : () -> ()
+  ^bb1:
+    // CHECK-NOT: switch
+    // CHECK: br ^[[BB2:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+    %c0_i32 = constant 0 : i32
+    switch %c0_i32 : i32, [
+      default: ^bb2(%caseOperand0 : f32),
+      -1: ^bb3(%caseOperand1 : f32),
+      1: ^bb4(%caseOperand2 : f32)
+    ]
+  // CHECK: ^[[BB2]]({{.*}}):
+  // CHECK-NEXT: "foo.bb2Terminator"
+  ^bb2(%bb2Arg : f32):
+    "foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
+  ^bb3(%bb3Arg : f32):
+    "foo.bb3Terminator"(%bb3Arg) : (f32) -> ()
+  ^bb4(%bb4Arg : f32):
+    "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+}
+
+// CHECK-LABEL: func @switch_on_const_with_match(
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+func @switch_on_const_with_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
+  // add predecessors for all blocks to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb3, ^bb4] : () -> ()
+  ^bb1:
+    // CHECK-NOT: switch
+    // CHECK: br ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_2]]
+    %c0_i32 = constant 1 : i32
+    switch %c0_i32 : i32, [
+      default: ^bb2(%caseOperand0 : f32),
+      -1: ^bb3(%caseOperand1 : f32),
+      1: ^bb4(%caseOperand2 : f32)
+    ]
+  ^bb2(%bb2Arg : f32):
+    "foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
+  ^bb3(%bb3Arg : f32):
+    "foo.bb3Terminator"(%bb3Arg) : (f32) -> ()
+  // CHECK: ^[[BB4]]({{.*}}):
+  // CHECK-NEXT: "foo.bb4Terminator"
+  ^bb4(%bb4Arg : f32):
+    "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+}
+
+// CHECK-LABEL: func @switch_passthrough(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_3:[a-zA-Z0-9_]+]]
+func @switch_passthrough(%flag : i32,
+                         %caseOperand0 : f32,
+                         %caseOperand1 : f32,
+                         %caseOperand2 : f32,
+                         %caseOperand3 : f32) {
+  // add predecessors for all blocks to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb3, ^bb4, ^bb5, ^bb6] : () -> ()
+
+  ^bb1:
+  //      CHECK: switch %[[FLAG]]
+  // CHECK-NEXT:   default: ^[[BB5:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+  // CHECK-NEXT:   43: ^[[BB6:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_1]]
+  // CHECK-NEXT:   44: ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_2]]
+  // CHECK-NEXT: ]
+    switch %flag : i32, [
+      default: ^bb2(%caseOperand0 : f32),
+      43: ^bb3(%caseOperand1 : f32),
+      44: ^bb4(%caseOperand2 : f32)
+    ]
+  ^bb2(%bb2Arg : f32):
+    br ^bb5(%bb2Arg : f32)
+  ^bb3(%bb3Arg : f32):
+    br ^bb6(%bb3Arg : f32)
+  ^bb4(%bb4Arg : f32):
+    "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+
+  // CHECK: ^[[BB5]]({{.*}}):
+  // CHECK-NEXT: "foo.bb5Terminator"
+  ^bb5(%bb5Arg : f32):
+    "foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
+
+  // CHECK: ^[[BB6]]({{.*}}):
+  // CHECK-NEXT: "foo.bb6Terminator"
+  ^bb6(%bb6Arg : f32):
+    "foo.bb6Terminator"(%bb6Arg) : (f32) -> ()
+}
+
+// CHECK-LABEL: func @switch_from_switch_with_same_value_with_match(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+func @switch_from_switch_with_same_value_with_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) {
+  // add predecessors for all blocks except ^bb3 to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb4, ^bb5] : () -> ()
+
+  ^bb1:
+    // CHECK: switch %[[FLAG]]
+    switch %flag : i32, [
+      default: ^bb2,
+      42: ^bb3
+    ]
+
+  ^bb2:
+    "foo.bb2Terminator"() : () -> ()
+  ^bb3:
+    // prevent this block from being simplified away
+    "foo.op"() : () -> ()
+    // CHECK-NOT: switch %[[FLAG]]
+    // CHECK: br ^[[BB5:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_1]]
+    switch %flag : i32, [
+      default: ^bb4(%caseOperand0 : f32),
+      42: ^bb5(%caseOperand1 : f32)
+    ]
+
+  ^bb4(%bb4Arg : f32):
+    "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+
+  // CHECK: ^[[BB5]]({{.*}}):
+  // CHECK-NEXT: "foo.bb5Terminator"
+  ^bb5(%bb5Arg : f32):
+    "foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
+}
+
+// CHECK-LABEL: func @switch_from_switch_with_same_value_no_match(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+func @switch_from_switch_with_same_value_no_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
+  // add predecessors for all blocks except ^bb3 to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb4, ^bb5, ^bb6] : () -> ()
+
+  ^bb1:
+    // CHECK: switch %[[FLAG]]
+    switch %flag : i32, [
+      default: ^bb2,
+      42: ^bb3
+    ]
+
+  ^bb2:
+    "foo.bb2Terminator"() : () -> ()
+  ^bb3:
+    "foo.op"() : () -> ()
+    // CHECK-NOT: switch %[[FLAG]]
+    // CHECK: br ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+    switch %flag : i32, [
+      default: ^bb4(%caseOperand0 : f32),
+      0: ^bb5(%caseOperand1 : f32),
+      43: ^bb6(%caseOperand2 : f32)
+    ]
+
+  // CHECK: ^[[BB4]]({{.*}})
+  // CHECK-NEXT: "foo.bb4Terminator"
+  ^bb4(%bb4Arg : f32):
+    "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+
+  ^bb5(%bb5Arg : f32):
+    "foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
+
+  ^bb6(%bb6Arg : f32):
+    "foo.bb6Terminator"(%bb6Arg) : (f32) -> ()
+}
+
+// CHECK-LABEL: func @switch_from_switch_default_with_same_value(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+func @switch_from_switch_default_with_same_value(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
+  // add predecessors for all blocks except ^bb3 to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb4, ^bb5, ^bb6] : () -> ()
+
+  ^bb1:
+    // CHECK: switch %[[FLAG]]
+    switch %flag : i32, [
+      default: ^bb3,
+      42: ^bb2
+    ]
+
+  ^bb2:
+    "foo.bb2Terminator"() : () -> ()
+  ^bb3:
+    "foo.op"() : () -> ()
+    // CHECK: switch %[[FLAG]]
+    // CHECK-NEXT: default: ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+    // CHECK-NEXT: 43: ^[[BB6:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_2]]
+    // CHECK-NOT: 42
+    switch %flag : i32, [
+      default: ^bb4(%caseOperand0 : f32),
+      42: ^bb5(%caseOperand1 : f32),
+      43: ^bb6(%caseOperand2 : f32)
+    ]
+
+  // CHECK: ^[[BB4]]({{.*}}):
+  // CHECK-NEXT: "foo.bb4Terminator"
+  ^bb4(%bb4Arg : f32):
+    "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+
+  ^bb5(%bb5Arg : f32):
+    "foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
+
+  // CHECK: ^[[BB6]]({{.*}}):
+  // CHECK-NEXT: "foo.bb6Terminator"
+  ^bb6(%bb6Arg : f32):
+    "foo.bb6Terminator"(%bb6Arg) : (f32) -> ()
+}
+
 /// Test folding conditional branches that are successors of conditional
 /// branches with the same condition.
 
index 02ec47f..53a4ad5 100644 (file)
@@ -96,3 +96,35 @@ func @read_global_memref() {
   %1 = memref.tensor_load %0 : memref<2xf32>
   return
 }
+
+// CHECK-LABEL: func @switch(
+func @switch(%flag : i32, %caseOperand : i32) {
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand : i32),
+    42: ^bb2(%caseOperand : i32),
+    43: ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
+
+// CHECK-LABEL: func @switch_i64(
+func @switch_i64(%flag : i64, %caseOperand : i32) {
+  switch %flag : i64, [
+    default: ^bb1(%caseOperand : i32),
+    42: ^bb2(%caseOperand : i32),
+    43: ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
diff --git a/mlir/test/Dialect/Standard/parser.mlir b/mlir/test/Dialect/Standard/parser.mlir
new file mode 100644 (file)
index 0000000..9fcf952
--- /dev/null
@@ -0,0 +1,69 @@
+// RUN: mlir-opt -verify-diagnostics -split-input-file %s
+
+func @switch_missing_case_value(%flag : i32, %caseOperand : i32) {
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand : i32),
+    45: ^bb2(%caseOperand : i32),
+    // expected-error@+1 {{expected integer value}}
+    : ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
+
+// -----
+
+func @switch_wrong_type_case_value(%flag : i32, %caseOperand : i32) {
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand : i32),
+    // expected-error@+1 {{expected integer value}}
+    "hello": ^bb2(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
+
+// -----
+
+func @switch_missing_comma(%flag : i32, %caseOperand : i32) {
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand : i32),
+    45: ^bb2(%caseOperand : i32)
+    // expected-error@+1 {{expected ']'}}
+    43: ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
+
+// -----
+
+func @switch_missing_default(%flag : i32, %caseOperand : i32) {
+  switch %flag : i32, [
+    // expected-error@+1 {{expected 'default'}}
+    45: ^bb2(%caseOperand : i32)
+    43: ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}