[MLIR][Math] Add constant folder for powf
authorWilliam S. Moses <gh@wsmoses.com>
Wed, 16 Mar 2022 20:51:03 +0000 (16:51 -0400)
committerWilliam S. Moses <gh@wsmoses.com>
Thu, 17 Mar 2022 18:19:47 +0000 (14:19 -0400)
Constant fold powf, given two constant operands and a compatible type

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D121845

mlir/include/mlir/Dialect/Math/IR/MathOps.td
mlir/lib/Dialect/Math/IR/MathOps.cpp
mlir/test/Dialect/Math/canonicalize.mlir

index c636c70..ca91d53 100644 (file)
@@ -684,6 +684,7 @@ def Math_PowFOp : Math_FloatBinaryOp<"powf"> {
     %x = math.powf %y, %z : tensor<4x?xbf16>
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
index 4396c32..4410e93 100644 (file)
@@ -63,7 +63,40 @@ OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
     return FloatAttr::get(getType(), log2(apf.convertToDouble()));
 
   if (ft.getWidth() == 32)
-    return FloatAttr::get(getType(), log2f(apf.convertToDouble()));
+    return FloatAttr::get(getType(), log2f(apf.convertToFloat()));
+
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// PowFOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::PowFOp::fold(ArrayRef<Attribute> operands) {
+  auto ft = getType().dyn_cast<FloatType>();
+  if (!ft)
+    return {};
+
+  APFloat vals[2]{APFloat(ft.getFloatSemantics()),
+                  APFloat(ft.getFloatSemantics())};
+  for (int i = 0; i < 2; ++i) {
+    if (!operands[i])
+      return {};
+
+    auto attr = operands[i].dyn_cast<FloatAttr>();
+    if (!attr)
+      return {};
+
+    vals[i] = attr.getValue();
+  }
+
+  if (ft.getWidth() == 64)
+    return FloatAttr::get(
+        getType(), pow(vals[0].convertToDouble(), vals[1].convertToDouble()));
+
+  if (ft.getWidth() == 32)
+    return FloatAttr::get(
+        getType(), powf(vals[0].convertToFloat(), vals[1].convertToFloat()));
 
   return {};
 }
index f62f0cf..5ee63b6 100644 (file)
@@ -73,3 +73,12 @@ func @log2_nofold2_64() -> f64 {
   %r = math.log2 %c : f64
   return %r : f64
 }
+
+// CHECK-LABEL: @powf_fold
+// CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f32
+// CHECK: return %[[cst]]
+func @powf_fold() -> f32 {
+  %c = arith.constant 2.0 : f32
+  %r = math.powf %c, %c : f32
+  return %r : f32
+}