From 164c7afaf5cbf924806fef7c280e2d71bdac0037 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 16 Mar 2022 16:51:03 -0400 Subject: [PATCH] [MLIR][Math] Add constant folder for powf Constant fold powf, given two constant operands and a compatible type Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D121845 --- mlir/include/mlir/Dialect/Math/IR/MathOps.td | 1 + mlir/lib/Dialect/Math/IR/MathOps.cpp | 35 +++++++++++++++++++++++++++- mlir/test/Dialect/Math/canonicalize.mlir | 9 +++++++ 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td index c636c70..ca91d53 100644 --- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -684,6 +684,7 @@ def Math_PowFOp : Math_FloatBinaryOp<"powf"> { %x = math.powf %y, %z : tensor<4x?xbf16> ``` }]; + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp index 4396c32..4410e93 100644 --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -63,7 +63,40 @@ OpFoldResult math::Log2Op::fold(ArrayRef operands) { return FloatAttr::get(getType(), log2(apf.convertToDouble())); if (ft.getWidth() == 32) - return FloatAttr::get(getType(), log2f(apf.convertToDouble())); + return FloatAttr::get(getType(), log2f(apf.convertToFloat())); + + return {}; +} + +//===----------------------------------------------------------------------===// +// PowFOp folder +//===----------------------------------------------------------------------===// + +OpFoldResult math::PowFOp::fold(ArrayRef operands) { + auto ft = getType().dyn_cast(); + if (!ft) + return {}; + + APFloat vals[2]{APFloat(ft.getFloatSemantics()), + APFloat(ft.getFloatSemantics())}; + for (int i = 0; i < 2; ++i) { + if (!operands[i]) + return {}; + + auto attr = operands[i].dyn_cast(); + if (!attr) + return {}; + + vals[i] = attr.getValue(); + } + + if (ft.getWidth() == 64) + return FloatAttr::get( + getType(), pow(vals[0].convertToDouble(), vals[1].convertToDouble())); + + if (ft.getWidth() == 32) + return FloatAttr::get( + getType(), powf(vals[0].convertToFloat(), vals[1].convertToFloat())); return {}; } diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir index f62f0cf..5ee63b6 100644 --- a/mlir/test/Dialect/Math/canonicalize.mlir +++ b/mlir/test/Dialect/Math/canonicalize.mlir @@ -73,3 +73,12 @@ func @log2_nofold2_64() -> f64 { %r = math.log2 %c : f64 return %r : f64 } + +// CHECK-LABEL: @powf_fold +// CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f32 +// CHECK: return %[[cst]] +func @powf_fold() -> f32 { + %c = arith.constant 2.0 : f32 + %r = math.powf %c, %c : f32 + return %r : f32 +} -- 2.7.4