LLVM switchop currently only permits i32. Both LLVM IR and MLIR Standard switch permit other integer types leading to an illegal state when lowering an i8 switch from MLIR standard
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D113955
[AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
NoSideEffect]> {
let arguments = (ins
- I32:$value,
+ AnyInteger:$value,
Variadic<AnyType>:$defaultOperands,
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
OptionalAttr<ElementsAttr>:$case_values,
let verifier = [{ return ::verify(*this); }];
let assemblyFormat = [{
- $value `,`
+ $value `:` type($value) `,`
$defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
- `[` `\n` custom<SwitchOpCases>($case_values, $caseDestinations,
+ `[` `\n` custom<SwitchOpCases>(ref(type($value)), $case_values, $caseDestinations,
$caseOperands, type($caseOperands)) `]`
attr-dict
}];
/// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )?
static ParseResult parseSwitchOpCases(
- OpAsmParser &parser, ElementsAttr &caseValues,
+ OpAsmParser &parser, Type &flagType, ElementsAttr &caseValues,
SmallVectorImpl<Block *> &caseDestinations,
SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
- SmallVector<int32_t> values;
- int32_t value = 0;
+ SmallVector<APInt> values;
+ unsigned bitWidth = flagType.getIntOrFloatBitWidth();
do {
+ int64_t value = 0;
OptionalParseResult integerParseResult = parser.parseOptionalInteger(value);
if (values.empty() && !integerParseResult.hasValue())
return success();
if (!integerParseResult.hasValue() || integerParseResult.getValue())
return failure();
- values.push_back(value);
+ values.push_back(APInt(bitWidth, value));
Block *destination;
SmallVector<OpAsmParser::OperandType> operands;
caseOperandTypes.emplace_back(operandTypes);
} while (!parser.parseOptionalComma());
- caseValues = parser.getBuilder().getI32VectorAttr(values);
+ ShapedType caseValueType =
+ VectorType::get(static_cast<int64_t>(values.size()), flagType);
+ caseValues = DenseIntElementsAttr::get(caseValueType, values);
return success();
}
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op,
+ Type &flagType,
ElementsAttr caseValues,
SuccessorRange caseDestinations,
OperandRangeRange caseOperands,
// CHECK: %[[FINAL:.*]] = llvm.mlir.constant(false) : i1
// CHECK: %[[RET:.*]] = llvm.intr.coro.suspend %[[STATE]], %[[FINAL]]
// CHECK: %[[SEXT:.*]] = llvm.sext %[[RET]] : i8 to i32
- // CHECK: llvm.switch %[[SEXT]], ^[[SUSPEND:[b0-9]+]]
+ // CHECK: llvm.switch %[[SEXT]] : i32, ^[[SUSPEND:[b0-9]+]]
// CHECK-NEXT: 0: ^[[RESUME:[b0-9]+]]
// CHECK-NEXT: 1: ^[[CLEANUP:[b0-9]+]]
async.coro.suspend %2, ^suspend, ^resume, ^cleanup
// Decide the next block based on the code returned from suspend.
// CHECK: %[[SEXT:.*]] = llvm.sext %[[SUSPENDED]] : i8 to i32
-// CHECK: llvm.switch %[[SEXT]], ^[[SUSPEND:[b0-9]+]]
+// CHECK: llvm.switch %[[SEXT]] : i32, ^[[SUSPEND:[b0-9]+]]
// CHECK-NEXT: 0: ^[[RESUME:[b0-9]+]]
// CHECK-NEXT: 1: ^[[CLEANUP:[b0-9]+]]
%0 = select %arg0, %arg1, %arg2 : vector<4x3xi1>, vector<4x3xi32>
std.return
}
+
+// -----
+
+// CHECK-LABEL: func @switchi8(
+func @switchi8(%arg0 : i8) -> i32 {
+switch %arg0 : i8, [
+ default: ^bb1,
+ 42: ^bb1,
+ 43: ^bb3
+ ]
+^bb1:
+ %c_1 = arith.constant 1 : i32
+ std.return %c_1 : i32
+^bb3:
+ %c_42 = arith.constant 42 : i32
+ std.return %c_42: i32
+}
+// CHECK: llvm.switch %arg0 : i8, ^bb1 [
+// CHECK-NEXT: 42: ^bb1,
+// CHECK-NEXT: 43: ^bb2
+// CHECK-NEXT: ]
+// CHECK: ^bb1: // 2 preds: ^bb0, ^bb0
+// CHECK-NEXT: %[[E0:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK-NEXT: llvm.return %[[E0]] : i32
+// CHECK: ^bb2: // pred: ^bb0
+// CHECK-NEXT: %[[E1:.+]] = llvm.mlir.constant(42 : i32) : i32
+// CHECK-NEXT: llvm.return %[[E1]] : i32
+// CHECK-NEXT: }
func @switch_wrong_number_of_weights(%arg0 : i32) {
// expected-error@+1 {{expects number of branch weights to match number of successors: 3 vs 2}}
- llvm.switch %arg0, ^bb1 [
+ llvm.switch %arg0 : i32, ^bb1 [
42: ^bb2(%arg0, %arg0 : i32, i32)
] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}
// CHECK: %{{.*}} = llvm.mlir.constant(42 : i64) : i47
%22 = llvm.mlir.undef : !llvm.struct<(i32, f64, i32)>
%23 = llvm.mlir.constant(42) : i47
- // CHECK: llvm.switch %0, ^[[BB3]] [
+ // CHECK: llvm.switch %0 : i32, ^[[BB3]] [
// CHECK-NEXT: 1: ^[[BB4:.*]],
// CHECK-NEXT: 2: ^[[BB5:.*]],
// CHECK-NEXT: 3: ^[[BB6:.*]]
// CHECK-NEXT: ]
- llvm.switch %0, ^bb3 [
+ llvm.switch %0 : i32, ^bb3 [
1: ^bb4,
2: ^bb5,
3: ^bb6
// CHECK: ^[[BB3]]
^bb3:
-// CHECK: llvm.switch %0, ^[[BB7:.*]] [
+// CHECK: llvm.switch %0 : i32, ^[[BB7:.*]] [
// CHECK-NEXT: ]
- llvm.switch %0, ^bb7 [
+ llvm.switch %0 : i32, ^bb7 [
]
// CHECK: ^[[BB4]]
^bb4:
- llvm.switch %0, ^bb7 [
+ llvm.switch %0 : i32, ^bb7 [
]
// CHECK: ^[[BB5]]
^bb5:
- llvm.switch %0, ^bb7 [
+ llvm.switch %0 : i32, ^bb7 [
]
// CHECK: ^[[BB6]]
^bb6:
- llvm.switch %0, ^bb7 [
+ llvm.switch %0 : i32, ^bb7 [
]
// CHECK: ^[[BB7]]
// CHECK-NEXT: i32 -1, label %[[SWITCHCASE_bb2:[0-9]+]]
// CHECK-NEXT: i32 1, label %[[SWITCHCASE_bb3:[0-9]+]]
// CHECK-NEXT: ]
- llvm.switch %arg0, ^bb1 [
+ llvm.switch %arg0 : i32, ^bb1 [
-1: ^bb2(%0 : i32),
1: ^bb3(%1, %2 : i32, i32)
]
%1 = llvm.mlir.constant(23 : i32) : i32
%2 = llvm.mlir.constant(29 : i32) : i32
// CHECK: !prof ![[SWITCH_WEIGHT_NODE:[0-9]+]]
- llvm.switch %arg0, ^bb1(%0 : i32) [
+ llvm.switch %arg0 : i32, ^bb1(%0 : i32) [
9: ^bb2(%1, %2 : i32, i32),
99: ^bb3
] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}