[mlir][math] Expand math.powf to exp, log and multiply
authorBalaji V. Iyer <bviyer@gmail.com>
Fri, 14 Apr 2023 13:52:17 +0000 (13:52 +0000)
committerRobert Suderman <suderman@google.com>
Fri, 14 Apr 2023 14:04:19 +0000 (14:04 +0000)
Powf functions are pushed directly to libm. This is problematic for
situations where libm is not available. This patch will decompose the
powf function into log of exponent multiplied by log of base and raise
it to the exp.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D148164

mlir/include/mlir/Dialect/Math/Transforms/Passes.h
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/test/Dialect/Math/expand-math.mlir
mlir/test/lib/Dialect/Math/TestExpandMath.cpp
mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir

index 6cd5b0a..245a117 100644 (file)
@@ -20,6 +20,7 @@ void populateExpandFmaFPattern(RewritePatternSet &patterns);
 void populateExpandFloorFPattern(RewritePatternSet &patterns);
 void populateExpandCeilFPattern(RewritePatternSet &patterns);
 void populateExpandExp2FPattern(RewritePatternSet &patterns);
+void populateExpandPowFPattern(RewritePatternSet &patterns);
 void populateExpandRoundFPattern(RewritePatternSet &patterns);
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 
index bc35263..a37340d 100644 (file)
@@ -157,6 +157,19 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   rewriter.replaceOp(op, ret);
   return success();
 }
+// Converts  Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operandA = op.getOperand(0);
+  Value operandB = op.getOperand(1);
+  Type opType = operandA.getType();
+
+  Value logA = b.create<math::LogOp>(opType, operandA);
+  Value mult = b.create<arith::MulFOp>(opType, logA, operandB);
+  Value expResult = b.create<math::ExpOp>(opType, mult);
+  rewriter.replaceOp(op, expResult);
+  return success();
+}
 
 // exp2f(float x) -> exp(x * ln(2))
 //   Proof: Let's say 2^x = y
@@ -264,6 +277,10 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
   patterns.add(convertExp2fOp);
 }
 
+void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
+  patterns.add(convertPowfOp);
+}
+
 void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
   patterns.add(convertRoundOp);
 }
index b3a5668..382278c 100644 (file)
@@ -207,3 +207,16 @@ func.func @roundf_func(%a: f64) -> f64 {
   %ret = math.round %a : f64
   return %ret : f64
 }
+
+// -----
+
+// CHECK-LABEL:   func @powf_func
+// CHECK-SAME:    ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
+func.func @powf_func(%a: f64, %b: f64) ->f64 {
+  // CHECK-DAG: [[LOG:%.+]] = math.log [[ARG0]]
+  // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[LOG]], [[ARG1]]
+  // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
+  // CHECK: return [[EXPR]]
+  %ret = math.powf %a, %b : f64
+  return %ret : f64
+}
index 5692ecf..c9b3357 100644 (file)
@@ -43,6 +43,7 @@ void TestExpandMathPass::runOnOperation() {
   populateExpandFmaFPattern(patterns);
   populateExpandFloorFPattern(patterns);
   populateExpandCeilFPattern(patterns);
+  populateExpandPowFPattern(patterns);
   populateExpandRoundFPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
index f3c7a2c..b72f9ba 100644 (file)
@@ -100,9 +100,68 @@ func.func @roundf() {
   return
 }
 
+// -------------------------------------------------------------------------- //
+// pow.
+// -------------------------------------------------------------------------- //
+func.func @func_powff64(%a : f64, %b : f64) {
+  %r = math.powf %a, %b : f64
+  vector.print %r : f64
+  return
+}
+
+func.func @powf() {
+  // CHECK: 16
+  %a   = arith.constant 4.0 : f64
+  %a_p = arith.constant 2.0 : f64
+  call @func_powff64(%a, %a_p) : (f64, f64) -> ()
+
+  // CHECK: -nan
+  %b   = arith.constant -3.0 : f64
+  %b_p = arith.constant 3.0 : f64
+  call @func_powff64(%b, %b_p) : (f64, f64) -> ()
+
+  // CHECK: 2.343
+  %c   = arith.constant 2.343 : f64
+  %c_p = arith.constant 1.000 : f64
+  call @func_powff64(%c, %c_p) : (f64, f64) -> ()
+
+  // CHECK: 0.176171
+  %d   = arith.constant 4.25 : f64
+  %d_p = arith.constant -1.2  : f64
+  call @func_powff64(%d, %d_p) : (f64, f64) -> ()
+
+  // CHECK: 1
+  %e   = arith.constant 4.385 : f64
+  %e_p = arith.constant 0.00 : f64
+  call @func_powff64(%e, %e_p) : (f64, f64) -> ()
+
+  // CHECK: 6.62637
+  %f    = arith.constant 4.835 : f64
+  %f_p  = arith.constant 1.2 : f64
+  call @func_powff64(%f, %f_p) : (f64, f64) -> ()
+
+  // CHECK: -nan
+  %g    = arith.constant 0xff80000000000000 : f64
+  call @func_powff64(%g, %g) : (f64, f64) -> ()
+
+  // CHECK: nan
+  %h = arith.constant 0x7fffffffffffffff : f64
+  call @func_powff64(%h, %h) : (f64, f64) -> ()
+
+  // CHECK: nan
+  %i = arith.constant 1.0 : f64
+  call @func_powff64(%i, %h) : (f64, f64) -> ()
+
+  // CHECK: inf
+  %j   = arith.constant 29385.0 : f64
+  %j_p = arith.constant 23598.0 : f64
+  call @func_powff64(%j, %j_p) : (f64, f64) -> () 
+  return
+}
 
 func.func @main() {
   call @exp2f() : () -> ()
   call @roundf() : () -> ()
+  call @powf() : () -> ()
   return
 }