[mlir][math] Expand math.powf to exp, log and multiply
authorBalaji V. Iyer <bviyer@gmail.com>
Fri, 14 Apr 2023 13:52:17 +0000 (13:52 +0000)
committerRobert Suderman <suderman@google.com>
Fri, 14 Apr 2023 14:04:19 +0000 (14:04 +0000)
Powf functions are pushed directly to libm. This is problematic for
situations where libm is not available. This patch will decompose the
powf function into log of exponent multiplied by log of base and raise
it to the exp.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D148164

mlir/include/mlir/Dialect/Math/Transforms/Passes.h
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/test/Dialect/Math/expand-math.mlir
mlir/test/lib/Dialect/Math/TestExpandMath.cpp
mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir

index 6cd5b0a409223edd938b390984d46e22673ac9a3..245a11747d5c8856aecfccbd1f0132c04f9f2fc4 100644 (file)
@@ -20,6 +20,7 @@ void populateExpandFmaFPattern(RewritePatternSet &patterns);
 void populateExpandFloorFPattern(RewritePatternSet &patterns);
 void populateExpandCeilFPattern(RewritePatternSet &patterns);
 void populateExpandExp2FPattern(RewritePatternSet &patterns);
+void populateExpandPowFPattern(RewritePatternSet &patterns);
 void populateExpandRoundFPattern(RewritePatternSet &patterns);
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 
index bc35263e12b2d29a4a622320d2047f45c74e5917..a37340d312f51a4a459260a16a3a9f065ffa0fdb 100644 (file)
@@ -157,6 +157,19 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   rewriter.replaceOp(op, ret);
   return success();
 }
+// Converts  Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operandA = op.getOperand(0);
+  Value operandB = op.getOperand(1);
+  Type opType = operandA.getType();
+
+  Value logA = b.create<math::LogOp>(opType, operandA);
+  Value mult = b.create<arith::MulFOp>(opType, logA, operandB);
+  Value expResult = b.create<math::ExpOp>(opType, mult);
+  rewriter.replaceOp(op, expResult);
+  return success();
+}
 
 // exp2f(float x) -> exp(x * ln(2))
 //   Proof: Let's say 2^x = y
@@ -264,6 +277,10 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
   patterns.add(convertExp2fOp);
 }
 
+void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
+  patterns.add(convertPowfOp);
+}
+
 void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
   patterns.add(convertRoundOp);
 }
index b3a5668f3235b7dd592e2a78afcd50f7b6768356..382278c060c8ea0cf4d69b91296a472abc922014 100644 (file)
@@ -207,3 +207,16 @@ func.func @roundf_func(%a: f64) -> f64 {
   %ret = math.round %a : f64
   return %ret : f64
 }
+
+// -----
+
+// CHECK-LABEL:   func @powf_func
+// CHECK-SAME:    ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
+func.func @powf_func(%a: f64, %b: f64) ->f64 {
+  // CHECK-DAG: [[LOG:%.+]] = math.log [[ARG0]]
+  // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[LOG]], [[ARG1]]
+  // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
+  // CHECK: return [[EXPR]]
+  %ret = math.powf %a, %b : f64
+  return %ret : f64
+}
index 5692ecf8d72379fe6637d40c1812bf1784959f7d..c9b3357c9b5080f13736c391fd152b69766b8749 100644 (file)
@@ -43,6 +43,7 @@ void TestExpandMathPass::runOnOperation() {
   populateExpandFmaFPattern(patterns);
   populateExpandFloorFPattern(patterns);
   populateExpandCeilFPattern(patterns);
+  populateExpandPowFPattern(patterns);
   populateExpandRoundFPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
index f3c7a2c4051b49de5d13fcaf1ba150c38fdb0612..b72f9ba8fd2585f20f86f782d6beac80d357dbe3 100644 (file)
@@ -100,9 +100,68 @@ func.func @roundf() {
   return
 }
 
+// -------------------------------------------------------------------------- //
+// pow.
+// -------------------------------------------------------------------------- //
+func.func @func_powff64(%a : f64, %b : f64) {
+  %r = math.powf %a, %b : f64
+  vector.print %r : f64
+  return
+}
+
+func.func @powf() {
+  // CHECK: 16
+  %a   = arith.constant 4.0 : f64
+  %a_p = arith.constant 2.0 : f64
+  call @func_powff64(%a, %a_p) : (f64, f64) -> ()
+
+  // CHECK: -nan
+  %b   = arith.constant -3.0 : f64
+  %b_p = arith.constant 3.0 : f64
+  call @func_powff64(%b, %b_p) : (f64, f64) -> ()
+
+  // CHECK: 2.343
+  %c   = arith.constant 2.343 : f64
+  %c_p = arith.constant 1.000 : f64
+  call @func_powff64(%c, %c_p) : (f64, f64) -> ()
+
+  // CHECK: 0.176171
+  %d   = arith.constant 4.25 : f64
+  %d_p = arith.constant -1.2  : f64
+  call @func_powff64(%d, %d_p) : (f64, f64) -> ()
+
+  // CHECK: 1
+  %e   = arith.constant 4.385 : f64
+  %e_p = arith.constant 0.00 : f64
+  call @func_powff64(%e, %e_p) : (f64, f64) -> ()
+
+  // CHECK: 6.62637
+  %f    = arith.constant 4.835 : f64
+  %f_p  = arith.constant 1.2 : f64
+  call @func_powff64(%f, %f_p) : (f64, f64) -> ()
+
+  // CHECK: -nan
+  %g    = arith.constant 0xff80000000000000 : f64
+  call @func_powff64(%g, %g) : (f64, f64) -> ()
+
+  // CHECK: nan
+  %h = arith.constant 0x7fffffffffffffff : f64
+  call @func_powff64(%h, %h) : (f64, f64) -> ()
+
+  // CHECK: nan
+  %i = arith.constant 1.0 : f64
+  call @func_powff64(%i, %h) : (f64, f64) -> ()
+
+  // CHECK: inf
+  %j   = arith.constant 29385.0 : f64
+  %j_p = arith.constant 23598.0 : f64
+  call @func_powff64(%j, %j_p) : (f64, f64) -> () 
+  return
+}
 
 func.func @main() {
   call @exp2f() : () -> ()
   call @roundf() : () -> ()
+  call @powf() : () -> ()
   return
 }