[mlir][Math] Support fold Log2Op with constant dense.
authorjacquesguan <Jianjian.Guan@streamcomputing.com>
Thu, 7 Jul 2022 06:17:50 +0000 (14:17 +0800)
committerjacquesguan <Jianjian.Guan@streamcomputing.com>
Mon, 11 Jul 2022 02:34:28 +0000 (10:34 +0800)
This patch is similar to D129108, it adds a conditional unary constant folder which allow to exit when the constants not meet the fold condition. And use it for Log2Op to make it able to fold the constant dense.

Differential Revision: https://reviews.llvm.org/D129251

mlir/include/mlir/Dialect/CommonFolders.h
mlir/lib/Dialect/Math/IR/MathOps.cpp
mlir/test/Dialect/Math/canonicalize.mlir
mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir

index 55dc5ec..8680893 100644 (file)
@@ -98,11 +98,11 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
 
 /// Performs constant folding `calculate` with element-wise behavior on the one
 /// attributes in `operands` and returns the result if possible.
-template <class AttrElementT,
-          class ElementValueT = typename AttrElementT::ValueType,
-          class CalculationT = function_ref<ElementValueT(ElementValueT)>>
-Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
-                           const CalculationT &&calculate) {
+template <
+    class AttrElementT, class ElementValueT = typename AttrElementT::ValueType,
+    class CalculationT = function_ref<Optional<ElementValueT>(ElementValueT)>>
+Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
+                                      const CalculationT &&calculate) {
   assert(operands.size() == 1 && "unary op takes one operands");
   if (!operands[0])
     return {};
@@ -110,7 +110,10 @@ Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
   if (operands[0].isa<AttrElementT>()) {
     auto op = operands[0].cast<AttrElementT>();
 
-    return AttrElementT::get(op.getType(), calculate(op.getValue()));
+    auto res = calculate(op.getValue());
+    if (!res)
+      return {};
+    return AttrElementT::get(op.getType(), *res);
   }
   if (operands[0].isa<SplatElementsAttr>()) {
     // Both operands are splats so we can avoid expanding the values out and
@@ -118,7 +121,9 @@ Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
     auto op = operands[0].cast<SplatElementsAttr>();
 
     auto elementResult = calculate(op.getSplatValue<ElementValueT>());
-    return DenseElementsAttr::get(op.getType(), elementResult);
+    if (!elementResult)
+      return {};
+    return DenseElementsAttr::get(op.getType(), *elementResult);
   } else if (operands[0].isa<ElementsAttr>()) {
     // Operands are ElementsAttr-derived; perform an element-wise fold by
     // expanding the values.
@@ -127,13 +132,27 @@ Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
     auto opIt = op.value_begin<ElementValueT>();
     SmallVector<ElementValueT> elementResults;
     elementResults.reserve(op.getNumElements());
-    for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt)
-      elementResults.push_back(calculate(*opIt));
+    for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
+      auto elementResult = calculate(*opIt);
+      if (!elementResult)
+        return {};
+      elementResults.push_back(*elementResult);
+    }
     return DenseElementsAttr::get(op.getType(), elementResults);
   }
   return {};
 }
 
+template <class AttrElementT,
+          class ElementValueT = typename AttrElementT::ValueType,
+          class CalculationT = function_ref<ElementValueT(ElementValueT)>>
+Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
+                           const CalculationT &&calculate) {
+  return constFoldUnaryOpConditional<AttrElementT>(
+      operands,
+      [&](ElementValueT a) -> Optional<ElementValueT> { return calculate(a); });
+}
+
 template <
     class AttrElementT, class TargetAttrElementT,
     class ElementValueT = typename AttrElementT::ValueType,
index 34e2072..035e9b4 100644 (file)
@@ -92,28 +92,19 @@ OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
-  auto constOperand = operands.front();
-  if (!constOperand)
-    return {};
-
-  auto attr = constOperand.dyn_cast<FloatAttr>();
-  if (!attr)
-    return {};
+  return constFoldUnaryOpConditional<FloatAttr>(
+      operands, [](const APFloat &a) -> Optional<APFloat> {
+        if (a.isNegative())
+          return {};
 
-  auto ft = getType().cast<FloatType>();
+        if (a.getSizeInBits(a.getSemantics()) == 64)
+          return APFloat(log2(a.convertToDouble()));
 
-  APFloat apf = attr.getValue();
+        if (a.getSizeInBits(a.getSemantics()) == 32)
+          return APFloat(log2f(a.convertToFloat()));
 
-  if (apf.isNegative())
-    return {};
-
-  if (ft.getWidth() == 64)
-    return FloatAttr::get(getType(), log2(apf.convertToDouble()));
-
-  if (ft.getWidth() == 32)
-    return FloatAttr::get(getType(), log2f(apf.convertToFloat()));
-
-  return {};
+        return {};
+      });
 }
 
 //===----------------------------------------------------------------------===//
index bcfdf1b..2ddd766 100644 (file)
@@ -74,6 +74,15 @@ func.func @log2_nofold2_64() -> f64 {
   return %r : f64
 }
 
+// CHECK-LABEL: @log2_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00, 1.58496249, 2.000000e+00]> : vector<4xf32>
+// CHECK: return %[[cst]]
+func.func @log2_fold_vec() -> (vector<4xf32>) {
+  %v1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
+  %0 = math.log2 %v1 : vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
 // CHECK-LABEL: @powf_fold
 // CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f32
 // CHECK: return %[[cst]]
index 3f34bc9..e0fe1a8 100644 (file)
@@ -80,7 +80,7 @@ func.func @log2() {
   %1 = math.log2 %0 : f32
   vector.print %1 : f32
 
-  // CHECK: -2, -0.415037, 0, 0.321928
+  // CHECK: -2, -0.415038, 0, 0.321928
   %2 = arith.constant dense<[0.25, 0.75, 1.0, 1.25]> : vector<4xf32>
   %3 = math.log2 %2 : vector<4xf32>
   vector.print %3 : vector<4xf32>