[mlir][arith] Allow to specify `constFoldBinaryOp` result type
authorJakub Kuderski <kubak@google.com>
Mon, 13 Feb 2023 19:18:12 +0000 (14:18 -0500)
committerJakub Kuderski <kubak@google.com>
Mon, 13 Feb 2023 19:18:14 +0000 (14:18 -0500)
This enables us to use the common fold helpers on elementwise ops that
produce different result type than operand types, e.g., `arith.cmpi` or
`arith.addui_extended`.

Use the updated helper to teach `arith.cmpi` to fold constant vectors.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D143779

mlir/include/mlir/Dialect/CommonFolders.h
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir
mlir/test/Dialect/SCF/canonicalize.mlir

index 425a8c2..4f24580 100644 (file)
 namespace mlir {
 /// Performs constant folding `calculate` with element-wise behavior on the two
 /// attributes in `operands` and returns the result if possible.
+/// Uses `resultType` for the type of the returned attribute.
 template <class AttrElementT,
           class ElementValueT = typename AttrElementT::ValueType,
           class CalculationT = function_ref<
               std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
 Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
+                                       Type resultType,
                                        const CalculationT &calculate) {
   assert(operands.size() == 2 && "binary op takes two operands");
-  if (!operands[0] || !operands[1])
+  if (!resultType || !operands[0] || !operands[1])
     return {};
 
   if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) {
@@ -45,7 +47,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
     if (!calRes)
       return {};
 
-    return AttrElementT::get(lhs.getType(), *calRes);
+    return AttrElementT::get(resultType, *calRes);
   }
 
   if (operands[0].isa<SplatElementsAttr>() &&
@@ -62,9 +64,10 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
     if (!elementResult)
       return {};
 
-    return DenseElementsAttr::get(lhs.getType(), *elementResult);
-  } else if (operands[0].isa<ElementsAttr>() &&
-             operands[1].isa<ElementsAttr>()) {
+    return DenseElementsAttr::get(resultType, *elementResult);
+  }
+
+  if (operands[0].isa<ElementsAttr>() && operands[1].isa<ElementsAttr>()) {
     // Operands are ElementsAttr-derived; perform an element-wise fold by
     // expanding the values.
     auto lhs = operands[0].cast<ElementsAttr>();
@@ -83,11 +86,53 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
       elementResults.push_back(*elementResult);
     }
 
-    return DenseElementsAttr::get(lhs.getType(), elementResults);
+    return DenseElementsAttr::get(resultType, elementResults);
   }
   return {};
 }
 
+/// Performs constant folding `calculate` with element-wise behavior on the two
+/// attributes in `operands` and returns the result if possible.
+/// Uses the operand element type for the element type of the returned
+/// attribute.
+template <class AttrElementT,
+          class ElementValueT = typename AttrElementT::ValueType,
+          class CalculationT = function_ref<
+              std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
+Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
+                                       const CalculationT &calculate) {
+  assert(operands.size() == 2 && "binary op takes two operands");
+  auto getResultType = [](Attribute attr) -> Type {
+    if (auto typed = attr.dyn_cast_or_null<TypedAttr>())
+      return typed.getType();
+    return {};
+  };
+
+  Type lhsType = getResultType(operands[0]);
+  Type rhsType = getResultType(operands[1]);
+  if (!lhsType || !rhsType)
+    return {};
+  if (lhsType != rhsType)
+    return {};
+
+  return constFoldBinaryOpConditional<AttrElementT, ElementValueT,
+                                      CalculationT>(operands, lhsType,
+                                                    calculate);
+}
+
+template <class AttrElementT,
+          class ElementValueT = typename AttrElementT::ValueType,
+          class CalculationT =
+              function_ref<ElementValueT(ElementValueT, ElementValueT)>>
+Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
+                            const CalculationT &calculate) {
+  return constFoldBinaryOpConditional<AttrElementT>(
+      operands, resultType,
+      [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
+        return calculate(a, b);
+      });
+}
+
 template <class AttrElementT,
           class ElementValueT = typename AttrElementT::ValueType,
           class CalculationT =
index d3739f8..775ee84 100644 (file)
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 
+#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallString.h"
@@ -108,6 +111,23 @@ namespace {
 } // namespace
 
 //===----------------------------------------------------------------------===//
+// Common helpers
+//===----------------------------------------------------------------------===//
+
+/// Return the type of the same shape (scalar, vector or tensor) containing i1.
+static Type getI1SameShape(Type type) {
+  auto i1Type = IntegerType::get(type.getContext(), 1);
+  if (auto tensorType = type.dyn_cast<RankedTensorType>())
+    return RankedTensorType::get(tensorType.getShape(), i1Type);
+  if (type.isa<UnrankedTensorType>())
+    return UnrankedTensorType::get(i1Type);
+  if (auto vectorType = type.dyn_cast<VectorType>())
+    return VectorType::get(vectorType.getShape(), i1Type,
+                           vectorType.getNumScalableDims());
+  return i1Type;
+}
+
+//===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
 
@@ -276,41 +296,16 @@ arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
   // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
   // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
   // operands. If that succeeds, calculate the overflow bit based on the sum
-  // and the first (constant) operand, `lhs`. Note that we cannot simply call
-  // `constFoldBinaryOp` again to calculate the overflow bit because the
-  // constructed attribute is of the same element type as both operands.
+  // and the first (constant) operand, `lhs`.
   if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
           adaptor.getOperands(),
           [](APInt a, const APInt &b) { return std::move(a) + b; })) {
-    Attribute overflowAttr;
-    if (auto lhs = adaptor.getLhs().dyn_cast<IntegerAttr>()) {
-      // Both arguments are scalars, calculate the scalar overflow value.
-      auto sum = sumAttr.cast<IntegerAttr>();
-      overflowAttr = IntegerAttr::get(
-          overflowTy,
-          calculateUnsignedOverflow(sum.getValue(), lhs.getValue()));
-    } else if (auto lhs = adaptor.getLhs().dyn_cast<SplatElementsAttr>()) {
-      // Both arguments are splats, calculate the splat overflow value.
-      auto sum = sumAttr.cast<SplatElementsAttr>();
-      APInt overflow = calculateUnsignedOverflow(sum.getSplatValue<APInt>(),
-                                                 lhs.getSplatValue<APInt>());
-      overflowAttr = SplatElementsAttr::get(overflowTy, overflow);
-    } else if (auto lhs = adaptor.getLhs().dyn_cast<ElementsAttr>()) {
-      // Othwerwise calculate element-wise overflow values.
-      auto sum = sumAttr.cast<ElementsAttr>();
-      const auto numElems = static_cast<size_t>(sum.getNumElements());
-      SmallVector<APInt> overflowValues;
-      overflowValues.reserve(numElems);
-
-      auto sumIt = sum.value_begin<APInt>();
-      auto lhsIt = lhs.value_begin<APInt>();
-      for (size_t i = 0, e = numElems; i != e; ++i, ++sumIt, ++lhsIt)
-        overflowValues.push_back(calculateUnsignedOverflow(*sumIt, *lhsIt));
-
-      overflowAttr = DenseElementsAttr::get(overflowTy, overflowValues);
-    } else {
+    Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
+        ArrayRef({sumAttr, adaptor.getLhs()}),
+        getI1SameShape(sumAttr.cast<TypedAttr>().getType()),
+        calculateUnsignedOverflow);
+    if (!overflowAttr)
       return failure();
-    }
 
     results.push_back(sumAttr);
     results.push_back(overflowAttr);
@@ -1535,23 +1530,6 @@ void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 }
 
 //===----------------------------------------------------------------------===//
-// Helpers for compare ops
-//===----------------------------------------------------------------------===//
-
-/// Return the type of the same shape (scalar, vector or tensor) containing i1.
-static Type getI1SameShape(Type type) {
-  auto i1Type = IntegerType::get(type.getContext(), 1);
-  if (auto tensorType = type.dyn_cast<RankedTensorType>())
-    return RankedTensorType::get(tensorType.getShape(), i1Type);
-  if (type.isa<UnrankedTensorType>())
-    return UnrankedTensorType::get(i1Type);
-  if (auto vectorType = type.dyn_cast<VectorType>())
-    return VectorType::get(vectorType.getShape(), i1Type,
-                           vectorType.getNumScalableDims());
-  return i1Type;
-}
-
-//===----------------------------------------------------------------------===//
 // CmpIOp
 //===----------------------------------------------------------------------===//
 
@@ -1671,16 +1649,18 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
     llvm_unreachable("unknown cmpi predicate kind");
   }
 
-  auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
-  if (!lhs)
-    return {};
-
   // We are moving constants to the right side; So if lhs is constant rhs is
   // guaranteed to be a constant.
-  auto rhs = adaptor.getRhs().cast<IntegerAttr>();
+  if (auto lhs = adaptor.getLhs().dyn_cast_or_null<TypedAttr>()) {
+    return constFoldBinaryOp<IntegerAttr>(
+        adaptor.getOperands(), getI1SameShape(lhs.getType()),
+        [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
+          return APInt(1,
+                       static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
+        });
+  }
 
-  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
-  return BoolAttr::get(getContext(), val);
+  return {};
 }
 
 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
index 0ee1b0b..355e7a8 100644 (file)
@@ -322,6 +322,46 @@ func.func @cmpIExtUIEQ(%arg0: i8, %arg1: i8) -> i1 {
   return %res : i1
 }
 
+// CHECK-LABEL: @cmpIFoldEQ
+//       CHECK:  %[[res:.+]] = arith.constant dense<[true, true, false]> : vector<3xi1>
+//       CHECK:   return %[[res]]
+func.func @cmpIFoldEQ() -> vector<3xi1> {
+  %lhs = arith.constant dense<[1, 2, 3]> : vector<3xi32>
+  %rhs = arith.constant dense<[1, 2, 4]> : vector<3xi32>
+  %res = arith.cmpi eq, %lhs, %rhs : vector<3xi32>
+  return %res : vector<3xi1>
+}
+
+// CHECK-LABEL: @cmpIFoldNE
+//       CHECK:  %[[res:.+]] = arith.constant dense<[false, false, true]> : vector<3xi1>
+//       CHECK:   return %[[res]]
+func.func @cmpIFoldNE() -> vector<3xi1> {
+  %lhs = arith.constant dense<[1, 2, 3]> : vector<3xi32>
+  %rhs = arith.constant dense<[1, 2, 4]> : vector<3xi32>
+  %res = arith.cmpi ne, %lhs, %rhs : vector<3xi32>
+  return %res : vector<3xi1>
+}
+
+// CHECK-LABEL: @cmpIFoldSGE
+//       CHECK:  %[[res:.+]] = arith.constant dense<[true, true, false]> : vector<3xi1>
+//       CHECK:   return %[[res]]
+func.func @cmpIFoldSGE() -> vector<3xi1> {
+  %lhs = arith.constant dense<2> : vector<3xi32>
+  %rhs = arith.constant dense<[1, 2, 4]> : vector<3xi32>
+  %res = arith.cmpi sge, %lhs, %rhs : vector<3xi32>
+  return %res : vector<3xi1>
+}
+
+// CHECK-LABEL: @cmpIFoldULT
+//       CHECK:  %[[res:.+]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK:   return %[[res]]
+func.func @cmpIFoldULT() -> vector<3xi1> {
+  %lhs = arith.constant dense<2> : vector<3xi32>
+  %rhs = arith.constant dense<1> : vector<3xi32>
+  %res = arith.cmpi ult, %lhs, %rhs : vector<3xi32>
+  return %res : vector<3xi1>
+}
+
 // -----
 
 // CHECK-LABEL: @andOfExtSI
index 220adc5..7ee88b6 100644 (file)
@@ -1070,13 +1070,13 @@ func.func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i3
 // CHECK:    return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#2, %[[ZERO]]
 
 // CHECK-LABEL: @while_loop_invariant_argument_different_order
-func.func @while_loop_invariant_argument_different_order() -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
+func.func @while_loop_invariant_argument_different_order(%arg : tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
   %cst_0 = arith.constant dense<0> : tensor<i32>
   %cst_1 = arith.constant dense<1> : tensor<i32>
   %cst_42 = arith.constant dense<42> : tensor<i32>
 
   %0:6 = scf.while (%arg0 = %cst_0, %arg1 = %cst_1, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
-    %1 = arith.cmpi slt, %arg0, %cst_42 : tensor<i32>
+    %1 = arith.cmpi slt, %arg0, %arg : tensor<i32>
     %2 = tensor.extract %1[] : tensor<i1>
     scf.condition(%2) %arg1, %arg0, %arg2, %arg0, %arg3, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
   } do {
@@ -1087,11 +1087,11 @@ func.func @while_loop_invariant_argument_different_order() -> (tensor<i32>, tens
   }
   return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
 }
+// CHECK-SAME: (%[[ARG:.+]]: tensor<i32>)
 // CHECK:    %[[ZERO:.*]] = arith.constant dense<0>
 // CHECK:    %[[ONE:.*]] = arith.constant dense<1>
-// CHECK:    %[[CST42:.*]] = arith.constant dense<42>
 // CHECK:    %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]])
-// CHECK:       arith.cmpi slt, %[[ZERO]], %[[CST42]]
+// CHECK:       arith.cmpi sgt, %[[ARG]], %[[ZERO]]
 // CHECK:       tensor.extract %{{.*}}[]
 // CHECK:       scf.condition(%{{.*}}) %[[ARG1]], %[[ARG4]]
 // CHECK:    } do {