[mlir][std] Fold comparisons when the operands are equal
authorStephan Herhut <herhut@google.com>
Fri, 20 Nov 2020 10:46:22 +0000 (11:46 +0100)
committerStephan Herhut <herhut@google.com>
Fri, 20 Nov 2020 12:26:41 +0000 (13:26 +0100)
For equal operands, comparisons can be decided statically.

Differential Revision: https://reviews.llvm.org/D91856

mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir

index 342d732..6e755da 100644 (file)
@@ -916,17 +916,41 @@ bool mlir::applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
   llvm_unreachable("unknown comparison predicate");
 }
 
+// Returns true if the predicate is true for two equal operands.
+static bool applyCmpPredicateToEqualOperands(CmpIPredicate predicate) {
+  switch (predicate) {
+  case CmpIPredicate::eq:
+  case CmpIPredicate::sle:
+  case CmpIPredicate::sge:
+  case CmpIPredicate::ule:
+  case CmpIPredicate::uge:
+    return true;
+  case CmpIPredicate::ne:
+  case CmpIPredicate::slt:
+  case CmpIPredicate::sgt:
+  case CmpIPredicate::ult:
+  case CmpIPredicate::ugt:
+    return false;
+  }
+  llvm_unreachable("unknown comparison predicate");
+}
+
 // Constant folding hook for comparisons.
 OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "cmpi takes two arguments");
 
+  if (lhs() == rhs()) {
+    auto val = applyCmpPredicateToEqualOperands(getPredicate());
+    return BoolAttr::get(val, getContext());
+  }
+
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
   if (!lhs || !rhs)
     return {};
 
   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
-  return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
+  return BoolAttr::get(val, getContext());
 }
 
 //===----------------------------------------------------------------------===//
index 1e2e4a5..5147537 100644 (file)
@@ -59,3 +59,25 @@ func @dim_of_dynamic_tensor_from_elements(%arg0: index, %arg1: index) -> index {
   %1 = dim %0, %c3 : tensor<2x?x4x?x5xindex>
   return %1 : index
 }
+
+// Test case: Folding of comparisons with equal operands.
+// CHECK-LABEL: @cmpi_equal_operands
+//   CHECK-DAG:   %[[T:.*]] = constant true
+//   CHECK-DAG:   %[[F:.*]] = constant false
+//       CHECK:   return %[[T]], %[[T]], %[[T]], %[[T]], %[[T]],
+//  CHECK-SAME:          %[[F]], %[[F]], %[[F]], %[[F]], %[[F]]
+func @cmpi_equal_operands(%arg0: i64)
+    -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
+  %0 = cmpi "eq", %arg0, %arg0 : i64
+  %1 = cmpi "sle", %arg0, %arg0 : i64
+  %2 = cmpi "sge", %arg0, %arg0 : i64
+  %3 = cmpi "ule", %arg0, %arg0 : i64
+  %4 = cmpi "uge", %arg0, %arg0 : i64
+  %5 = cmpi "ne", %arg0, %arg0 : i64
+  %6 = cmpi "slt", %arg0, %arg0 : i64
+  %7 = cmpi "sgt", %arg0, %arg0 : i64
+  %8 = cmpi "ult", %arg0, %arg0 : i64
+  %9 = cmpi "ugt", %arg0, %arg0 : i64
+  return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
+      : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
+}