llvm_unreachable("unknown comparison predicate");
}
+// Returns true if the predicate is true for two equal operands.
+static bool applyCmpPredicateToEqualOperands(CmpIPredicate predicate) {
+ switch (predicate) {
+ case CmpIPredicate::eq:
+ case CmpIPredicate::sle:
+ case CmpIPredicate::sge:
+ case CmpIPredicate::ule:
+ case CmpIPredicate::uge:
+ return true;
+ case CmpIPredicate::ne:
+ case CmpIPredicate::slt:
+ case CmpIPredicate::sgt:
+ case CmpIPredicate::ult:
+ case CmpIPredicate::ugt:
+ return false;
+ }
+ llvm_unreachable("unknown comparison predicate");
+}
+
// Constant folding hook for comparisons.
OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "cmpi takes two arguments");
+ if (lhs() == rhs()) {
+ auto val = applyCmpPredicateToEqualOperands(getPredicate());
+ return BoolAttr::get(val, getContext());
+ }
+
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
if (!lhs || !rhs)
return {};
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
- return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
+ return BoolAttr::get(val, getContext());
}
//===----------------------------------------------------------------------===//
%1 = dim %0, %c3 : tensor<2x?x4x?x5xindex>
return %1 : index
}
+
+// Test case: Folding of comparisons with equal operands.
+// CHECK-LABEL: @cmpi_equal_operands
+// CHECK-DAG: %[[T:.*]] = constant true
+// CHECK-DAG: %[[F:.*]] = constant false
+// CHECK: return %[[T]], %[[T]], %[[T]], %[[T]], %[[T]],
+// CHECK-SAME: %[[F]], %[[F]], %[[F]], %[[F]], %[[F]]
+func @cmpi_equal_operands(%arg0: i64)
+ -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
+ %0 = cmpi "eq", %arg0, %arg0 : i64
+ %1 = cmpi "sle", %arg0, %arg0 : i64
+ %2 = cmpi "sge", %arg0, %arg0 : i64
+ %3 = cmpi "ule", %arg0, %arg0 : i64
+ %4 = cmpi "uge", %arg0, %arg0 : i64
+ %5 = cmpi "ne", %arg0, %arg0 : i64
+ %6 = cmpi "slt", %arg0, %arg0 : i64
+ %7 = cmpi "sgt", %arg0, %arg0 : i64
+ %8 = cmpi "ult", %arg0, %arg0 : i64
+ %9 = cmpi "ugt", %arg0, %arg0 : i64
+ return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
+ : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
+}