#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::arith;
return DenseElementsAttr::get(shapedType, boolAttr);
}
+static Optional<int64_t> getIntegerWidth(Type t) {
+ if (auto intType = t.dyn_cast<IntegerType>()) {
+ return intType.getWidth();
+ }
+ if (auto vectorIntType = t.dyn_cast<VectorType>()) {
+ return vectorIntType.getElementType().cast<IntegerType>().getWidth();
+ }
+ return llvm::None;
+}
+
OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "cmpi takes two operands");
if (matchPattern(getRhs(), m_Zero())) {
if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
// extsi(%x : i1 -> iN) != 0 -> %x
- if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
+ Optional<int64_t> integerWidth =
+ getIntegerWidth(extOp.getOperand().getType());
+ if (integerWidth && integerWidth.value() == 1 &&
getPredicate() == arith::CmpIPredicate::ne)
return extOp.getOperand();
}
if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
// extui(%x : i1 -> iN) != 0 -> %x
- if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
+ Optional<int64_t> integerWidth =
+ getIntegerWidth(extOp.getOperand().getType());
+ if (integerWidth && integerWidth.value() == 1 &&
getPredicate() == arith::CmpIPredicate::ne)
return extOp.getOperand();
}
// -----
-// CHECK-LABEL: @cmpOfExtSI
+// CHECK-LABEL: @cmpOfExtSI(
// CHECK-NEXT: return %arg0
func.func @cmpOfExtSI(%arg0: i1) -> i1 {
%ext = arith.extsi %arg0 : i1 to i64
return %res : i1
}
-// CHECK-LABEL: @cmpOfExtUI
+// CHECK-LABEL: @cmpOfExtUI(
// CHECK-NEXT: return %arg0
func.func @cmpOfExtUI(%arg0: i1) -> i1 {
%ext = arith.extui %arg0 : i1 to i64
// -----
+// CHECK-LABEL: @cmpOfExtSIVector(
+// CHECK-NEXT: return %arg0
+func.func @cmpOfExtSIVector(%arg0: vector<4xi1>) -> vector<4xi1> {
+ %ext = arith.extsi %arg0 : vector<4xi1> to vector<4xi64>
+ %c0 = arith.constant dense<0> : vector<4xi64>
+ %res = arith.cmpi ne, %ext, %c0 : vector<4xi64>
+ return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpOfExtUIVector(
+// CHECK-NEXT: return %arg0
+func.func @cmpOfExtUIVector(%arg0: vector<4xi1>) -> vector<4xi1> {
+ %ext = arith.extui %arg0 : vector<4xi1> to vector<4xi64>
+ %c0 = arith.constant dense<0> : vector<4xi64>
+ %res = arith.cmpi ne, %ext, %c0 : vector<4xi64>
+ return %res : vector<4xi1>
+}
+
+// -----
+
// CHECK-LABEL: @extSIOfExtUI
// CHECK: %[[res:.+]] = arith.extui %arg0 : i1 to i64
// CHECK: return %[[res]]
%res = arith.xori %b, %c : i32
return %res : i32
}
+
+// -----