[MLIR][LLVM] Permit integer types in switch other than i32
authorWilliam S. Moses <gh@wsmoses.com>
Tue, 16 Nov 2021 01:13:33 +0000 (20:13 -0500)
committerWilliam S. Moses <gh@wsmoses.com>
Tue, 16 Nov 2021 05:46:25 +0000 (00:46 -0500)
LLVM switchop currently only permits i32. Both LLVM IR and MLIR Standard switch permit other integer types leading to an illegal state when lowering an i8 switch from MLIR standard

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D113955

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Dialect/LLVMIR/roundtrip.mlir
mlir/test/Target/LLVMIR/llvmir.mlir

index 6bd64ed..055975e 100644 (file)
@@ -724,7 +724,7 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
     [AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
      NoSideEffect]> {
   let arguments = (ins
-    I32:$value,
+    AnyInteger:$value,
     Variadic<AnyType>:$defaultOperands,
     VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
     OptionalAttr<ElementsAttr>:$case_values,
@@ -738,9 +738,9 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
 
   let verifier = [{ return ::verify(*this); }];
   let assemblyFormat = [{
-    $value `,`
+    $value `:` type($value) `,`
     $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
-    `[` `\n` custom<SwitchOpCases>($case_values, $caseDestinations,
+    `[` `\n` custom<SwitchOpCases>(ref(type($value)), $case_values, $caseDestinations,
                                    $caseOperands, type($caseOperands)) `]`
     attr-dict
   }];
index ba7ec4f..2587dbd 100644 (file)
@@ -269,20 +269,21 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )?
 static ParseResult parseSwitchOpCases(
-    OpAsmParser &parser, ElementsAttr &caseValues,
+    OpAsmParser &parser, Type &flagType, ElementsAttr &caseValues,
     SmallVectorImpl<Block *> &caseDestinations,
     SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
-  SmallVector<int32_t> values;
-  int32_t value = 0;
+  SmallVector<APInt> values;
+  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
   do {
+    int64_t value = 0;
     OptionalParseResult integerParseResult = parser.parseOptionalInteger(value);
     if (values.empty() && !integerParseResult.hasValue())
       return success();
 
     if (!integerParseResult.hasValue() || integerParseResult.getValue())
       return failure();
-    values.push_back(value);
+    values.push_back(APInt(bitWidth, value));
 
     Block *destination;
     SmallVector<OpAsmParser::OperandType> operands;
@@ -299,11 +300,14 @@ static ParseResult parseSwitchOpCases(
     caseOperandTypes.emplace_back(operandTypes);
   } while (!parser.parseOptionalComma());
 
-  caseValues = parser.getBuilder().getI32VectorAttr(values);
+  ShapedType caseValueType =
+        VectorType::get(static_cast<int64_t>(values.size()), flagType); 
+  caseValues = DenseIntElementsAttr::get(caseValueType, values);
   return success();
 }
 
 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op,
+                               Type &flagType,
                                ElementsAttr caseValues,
                                SuccessorRange caseDestinations,
                                OperandRangeRange caseOperands,
index 7f5500c..7e47448 100644 (file)
@@ -64,7 +64,7 @@ func @coro_suspend() {
   // CHECK: %[[FINAL:.*]] = llvm.mlir.constant(false) : i1
   // CHECK: %[[RET:.*]] = llvm.intr.coro.suspend %[[STATE]], %[[FINAL]]
   // CHECK: %[[SEXT:.*]] = llvm.sext %[[RET]] : i8 to i32
-  // CHECK: llvm.switch %[[SEXT]], ^[[SUSPEND:[b0-9]+]]
+  // CHECK: llvm.switch %[[SEXT]] : i32, ^[[SUSPEND:[b0-9]+]]
   // CHECK-NEXT: 0: ^[[RESUME:[b0-9]+]]
   // CHECK-NEXT: 1: ^[[CLEANUP:[b0-9]+]]
   async.coro.suspend %2, ^suspend, ^resume, ^cleanup
index eb8ddbb..46ff750 100644 (file)
@@ -49,7 +49,7 @@ func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
 
 // Decide the next block based on the code returned from suspend.
 // CHECK: %[[SEXT:.*]] = llvm.sext %[[SUSPENDED]] : i8 to i32
-// CHECK: llvm.switch %[[SEXT]], ^[[SUSPEND:[b0-9]+]]
+// CHECK: llvm.switch %[[SEXT]] : i32, ^[[SUSPEND:[b0-9]+]]
 // CHECK-NEXT: 0: ^[[RESUME:[b0-9]+]]
 // CHECK-NEXT: 1: ^[[CLEANUP:[b0-9]+]]
 
index 9e9636a..7d0942c 100644 (file)
@@ -592,3 +592,31 @@ func @select_2dvector(%arg0 : vector<4x3xi1>, %arg1 : vector<4x3xi32>, %arg2 : v
   %0 = select %arg0, %arg1, %arg2 : vector<4x3xi1>, vector<4x3xi32>
   std.return
 }
+
+// -----
+
+// CHECK-LABEL: func @switchi8(
+func @switchi8(%arg0 : i8) -> i32 {
+switch %arg0 : i8, [
+  default: ^bb1,
+    42: ^bb1,
+    43: ^bb3
+  ]
+^bb1:
+  %c_1 = arith.constant 1 : i32
+  std.return %c_1 : i32
+^bb3:
+  %c_42 = arith.constant 42 : i32
+  std.return %c_42: i32
+}
+// CHECK:     llvm.switch %arg0 : i8, ^bb1 [
+// CHECK-NEXT:       42: ^bb1,
+// CHECK-NEXT:       43: ^bb2
+// CHECK-NEXT:     ]
+// CHECK:   ^bb1:  // 2 preds: ^bb0, ^bb0
+// CHECK-NEXT:     %[[E0:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK-NEXT:     llvm.return %[[E0]] : i32
+// CHECK:   ^bb2:  // pred: ^bb0
+// CHECK-NEXT:     %[[E1:.+]] = llvm.mlir.constant(42 : i32) : i32
+// CHECK-NEXT:     llvm.return %[[E1]] : i32
+// CHECK-NEXT:   }
index 3f07f17..fd9b576 100644 (file)
@@ -805,7 +805,7 @@ module attributes {llvm.data_layout = "#vjkr32"} {
 
 func @switch_wrong_number_of_weights(%arg0 : i32) {
   // expected-error@+1 {{expects number of branch weights to match number of successors: 3 vs 2}}
-  llvm.switch %arg0, ^bb1 [
+  llvm.switch %arg0 : i32, ^bb1 [
     42: ^bb2(%arg0, %arg0 : i32, i32)
   ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}
 
index 8efd14e..b931c9b 100644 (file)
@@ -84,12 +84,12 @@ func @ops(%arg0: i32, %arg1: f32,
 // CHECK: %{{.*}} = llvm.mlir.constant(42 : i64) : i47
   %22 = llvm.mlir.undef : !llvm.struct<(i32, f64, i32)>
   %23 = llvm.mlir.constant(42) : i47
-  // CHECK:      llvm.switch %0, ^[[BB3]] [
+  // CHECK:      llvm.switch %0 : i32, ^[[BB3]] [
   // CHECK-NEXT:   1: ^[[BB4:.*]],
   // CHECK-NEXT:   2: ^[[BB5:.*]],
   // CHECK-NEXT:   3: ^[[BB6:.*]]
   // CHECK-NEXT: ]
-  llvm.switch %0, ^bb3 [
+  llvm.switch %0 : i32, ^bb3 [
     1: ^bb4,
     2: ^bb5,
     3: ^bb6
@@ -97,24 +97,24 @@ func @ops(%arg0: i32, %arg1: f32,
 
 // CHECK: ^[[BB3]]
 ^bb3:
-// CHECK:      llvm.switch %0, ^[[BB7:.*]] [
+// CHECK:      llvm.switch %0 : i32, ^[[BB7:.*]] [
 // CHECK-NEXT: ]
-  llvm.switch %0, ^bb7 [
+  llvm.switch %0 : i32, ^bb7 [
   ]
 
 // CHECK: ^[[BB4]]
 ^bb4:
-  llvm.switch %0, ^bb7 [
+  llvm.switch %0 : i32, ^bb7 [
   ]
 
 // CHECK: ^[[BB5]]
 ^bb5:
-  llvm.switch %0, ^bb7 [
+  llvm.switch %0 : i32, ^bb7 [
   ]
 
 // CHECK: ^[[BB6]]
 ^bb6:
-  llvm.switch %0, ^bb7 [
+  llvm.switch %0 : i32, ^bb7 [
   ]
 
 // CHECK: ^[[BB7]]
index 7677e59..f5b6d60 100644 (file)
@@ -1560,7 +1560,7 @@ llvm.func @switch_args(%arg0: i32) -> i32 {
   // CHECK-NEXT:   i32 -1, label %[[SWITCHCASE_bb2:[0-9]+]]
   // CHECK-NEXT:   i32 1, label %[[SWITCHCASE_bb3:[0-9]+]]
   // CHECK-NEXT: ]
-  llvm.switch %arg0, ^bb1 [
+  llvm.switch %arg0 : i32, ^bb1 [
     -1: ^bb2(%0 : i32),
     1: ^bb3(%1, %2 : i32, i32)
   ]
@@ -1590,7 +1590,7 @@ llvm.func @switch_weights(%arg0: i32) -> i32 {
   %1 = llvm.mlir.constant(23 : i32) : i32
   %2 = llvm.mlir.constant(29 : i32) : i32
   // CHECK: !prof ![[SWITCH_WEIGHT_NODE:[0-9]+]]
-  llvm.switch %arg0, ^bb1(%0 : i32) [
+  llvm.switch %arg0 : i32, ^bb1(%0 : i32) [
     9: ^bb2(%1, %2 : i32, i32),
     99: ^bb3
   ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}