[mlir][Arith] Fix folder of CmpIOp to not fail when element type is not integer.
authorMahesh Ravishankar <ravishankarm@google.com>
Thu, 3 Nov 2022 20:38:34 +0000 (20:38 +0000)
committerMahesh Ravishankar <ravishankarm@google.com>
Thu, 3 Nov 2022 20:38:34 +0000 (20:38 +0000)
The folder used `cast<IntegerType>`  which would segfault if the type were
a vector type. Handle this case appropriately and avoid failure.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D137345

mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir

index d1d03a5..2c0fc51 100644 (file)
@@ -19,6 +19,7 @@
 
 #include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::arith;
@@ -1444,6 +1445,16 @@ static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
   return DenseElementsAttr::get(shapedType, boolAttr);
 }
 
+static Optional<int64_t> getIntegerWidth(Type t) {
+  if (auto intType = t.dyn_cast<IntegerType>()) {
+    return intType.getWidth();
+  }
+  if (auto vectorIntType = t.dyn_cast<VectorType>()) {
+    return vectorIntType.getElementType().cast<IntegerType>().getWidth();
+  }
+  return llvm::None;
+}
+
 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "cmpi takes two operands");
 
@@ -1456,13 +1467,17 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
   if (matchPattern(getRhs(), m_Zero())) {
     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
       // extsi(%x : i1 -> iN) != 0  ->  %x
-      if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
+      Optional<int64_t> integerWidth =
+          getIntegerWidth(extOp.getOperand().getType());
+      if (integerWidth && integerWidth.value() == 1 &&
           getPredicate() == arith::CmpIPredicate::ne)
         return extOp.getOperand();
     }
     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
       // extui(%x : i1 -> iN) != 0  ->  %x
-      if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
+      Optional<int64_t> integerWidth =
+          getIntegerWidth(extOp.getOperand().getType());
+      if (integerWidth && integerWidth.value() == 1 &&
           getPredicate() == arith::CmpIPredicate::ne)
         return extOp.getOperand();
     }
index 337eec0..336324e 100644 (file)
@@ -162,7 +162,7 @@ func.func @cmpi_const_right(%arg0: i64)
 
 // -----
 
-// CHECK-LABEL: @cmpOfExtSI
+// CHECK-LABEL: @cmpOfExtSI(
 //  CHECK-NEXT:   return %arg0
 func.func @cmpOfExtSI(%arg0: i1) -> i1 {
   %ext = arith.extsi %arg0 : i1 to i64
@@ -171,7 +171,7 @@ func.func @cmpOfExtSI(%arg0: i1) -> i1 {
   return %res : i1
 }
 
-// CHECK-LABEL: @cmpOfExtUI
+// CHECK-LABEL: @cmpOfExtUI(
 //  CHECK-NEXT:   return %arg0
 func.func @cmpOfExtUI(%arg0: i1) -> i1 {
   %ext = arith.extui %arg0 : i1 to i64
@@ -182,6 +182,26 @@ func.func @cmpOfExtUI(%arg0: i1) -> i1 {
 
 // -----
 
+// CHECK-LABEL: @cmpOfExtSIVector(
+//  CHECK-NEXT:   return %arg0
+func.func @cmpOfExtSIVector(%arg0: vector<4xi1>) -> vector<4xi1> {
+  %ext = arith.extsi %arg0 : vector<4xi1> to vector<4xi64>
+  %c0 = arith.constant dense<0> : vector<4xi64>
+  %res = arith.cmpi ne, %ext, %c0 : vector<4xi64>
+  return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpOfExtUIVector(
+//  CHECK-NEXT:   return %arg0
+func.func @cmpOfExtUIVector(%arg0: vector<4xi1>) -> vector<4xi1> {
+  %ext = arith.extui %arg0 : vector<4xi1> to vector<4xi64>
+  %c0 = arith.constant dense<0> : vector<4xi64>
+  %res = arith.cmpi ne, %ext, %c0 : vector<4xi64>
+  return %res : vector<4xi1>
+}
+
+// -----
+
 // CHECK-LABEL: @extSIOfExtUI
 //       CHECK:   %[[res:.+]] = arith.extui %arg0 : i1 to i64
 //       CHECK:   return %[[res]]
@@ -1660,3 +1680,5 @@ func.func @xorxor3(%a : i32, %b : i32) -> i32 {
   %res = arith.xori %b, %c : i32
   return %res : i32
 }
+
+// -----