[mlir][math] Expand math.exp2 to use math.exp.
authorBalaji V. Iyer <bviyer@gmail.com>
Thu, 13 Apr 2023 15:54:21 +0000 (15:54 +0000)
committerRobert Suderman <suderman@google.com>
Thu, 13 Apr 2023 16:06:04 +0000 (16:06 +0000)
Exp2 functions are pushed directly to libm. This is problematic for
situations where libm is not available. This patch will expand the exp2
function to use exp2 with the input multiplied by ln2 (natural log).

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D148064

mlir/include/mlir/Dialect/Math/Transforms/Passes.h
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/test/Dialect/Math/expand-math.mlir
mlir/test/lib/Dialect/Math/TestExpandMath.cpp
mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir [new file with mode: 0644]

index 1b32de2..3ac18c3 100644 (file)
@@ -19,6 +19,7 @@ void populateExpandTanhPattern(RewritePatternSet &patterns);
 void populateExpandFmaFPattern(RewritePatternSet &patterns);
 void populateExpandFloorFPattern(RewritePatternSet &patterns);
 void populateExpandCeilFPattern(RewritePatternSet &patterns);
+void populateExpandExp2FPattern(RewritePatternSet &patterns);
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 
 struct MathPolynomialApproximationOptions {
index b70ac4e..e9447dc 100644 (file)
@@ -158,6 +158,22 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   return success();
 }
 
+// exp2f(float x) -> exp(x * ln(2))
+//   Proof: Let's say 2^x = y
+//   ln(2^x) = ln(y)
+//   x * ln(2) = ln(y) => e ^(x*ln(2)) = y
+static LogicalResult convertExp2fOp(math::Exp2Op op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operand = op.getOperand();
+  Type opType = operand.getType();
+  Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
+  Value mult = b.create<arith::MulFOp>(opType, operand, ln2);
+  Value exp = b.create<math::ExpOp>(op->getLoc(), mult);
+  rewriter.replaceOp(op, exp);
+  return success();
+}
+
 // Converts math.ctlz to scf and arith operations. This is done
 // by performing a binary search on the bits.
 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
@@ -222,6 +238,10 @@ void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
   patterns.add(convertCeilOp);
 }
 
+void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
+  patterns.add(convertExp2fOp);
+}
+
 void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
   patterns.add(convertFloorOp);
 }
index 8ab6449..5098696 100644 (file)
@@ -165,3 +165,27 @@ func.func @ceilf_func(%a: f64) -> f64 {
   %ret = math.ceil %a : f64
   return %ret : f64
 }
+
+// -----
+
+// CHECK-LABEL:     func @exp2f_func
+// CHECK-SAME:      ([[ARG0:%.+]]: f64) -> f64
+func.func @exp2f_func(%a: f64) -> f64 {
+  // CHECK-DAG:     [[CST:%.+]]  = arith.constant 0.69314718055994529
+  // CHECK:         [[MULF:%.+]] = arith.mulf [[ARG0]], [[CST]]
+  // CHECK:         [[EXP:%.+]]  = math.exp [[MULF]]
+  // CHECK:         return [[EXP]]
+  %ret = math.exp2 %a : f64
+  return %ret : f64
+}
+
+// CHECK-LABEL:     func @exp2f_func_tensor
+// CHECK-SAME:      ([[ARG0:%.+]]: tensor<1xf32>) -> tensor<1xf32>
+func.func @exp2f_func_tensor(%a: tensor<1xf32>) -> tensor<1xf32> {
+  // CHECK-DAG:     [[CST:%.+]]  = arith.constant dense<0.693147182>
+  // CHECK:         [[MULF:%.+]] = arith.mulf [[ARG0]], [[CST]]
+  // CHECK:         [[EXP:%.+]]  = math.exp [[MULF]]
+  // CHECK:         return [[EXP]]
+  %ret = math.exp2 %a : tensor<1xf32>
+  return %ret : tensor<1xf32>
+}
index c670617..29eff99 100644 (file)
@@ -37,6 +37,7 @@ struct TestExpandMathPass
 void TestExpandMathPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   populateExpandCtlzPattern(patterns);
+  populateExpandExp2FPattern(patterns);
   populateExpandTanPattern(patterns);
   populateExpandTanhPattern(patterns);
   populateExpandFmaFPattern(patterns);
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
new file mode 100644 (file)
index 0000000..3fb3b2b
--- /dev/null
@@ -0,0 +1,61 @@
+// RUN:   mlir-opt %s -pass-pipeline="builtin.module(func.func(test-expand-math,convert-arith-to-llvm),convert-vector-to-llvm,func.func(convert-math-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \
+// RUN: | mlir-cpu-runner                                                      \
+// RUN:     -e main -entry-point-result=void -O0                               \
+// RUN:     -shared-libs=%mlir_c_runner_utils  \
+// RUN:     -shared-libs=%mlir_runner_utils    \
+// RUN: | FileCheck %s
+
+// -------------------------------------------------------------------------- //
+// exp2f.
+// -------------------------------------------------------------------------- //
+func.func @func_exp2f(%a : f64) {
+  %r = math.exp2 %a : f64
+  vector.print %r : f64
+  return
+}
+
+func.func @exp2f() {
+  // CHECK: 2
+  %a = arith.constant 1.0 : f64
+  call @func_exp2f(%a) : (f64) -> ()  
+
+  // CHECK: 4
+  %b = arith.constant 2.0 : f64
+  call @func_exp2f(%b) : (f64) -> ()
+
+  // CHECK: 5.65685
+  %c = arith.constant 2.5 : f64
+  call @func_exp2f(%c) : (f64) -> ()
+
+  // CHECK: 0.29730
+  %d = arith.constant -1.75 : f64
+  call @func_exp2f(%d) : (f64) -> ()
+
+  // CHECK: 1.09581
+  %e = arith.constant 0.132 : f64
+  call @func_exp2f(%e) : (f64) -> ()
+
+  // CHECK: inf
+  %f1 = arith.constant 0.00 : f64
+  %f2 = arith.constant 1.00 : f64
+  %f = arith.divf %f2, %f1 : f64
+  call @func_exp2f(%f) : (f64) -> ()
+
+  // CHECK: inf
+  %g = arith.constant 5038939.0 : f64
+  call @func_exp2f(%g) : (f64) -> ()
+
+  // CHECK: 0
+  %neg_inf = arith.constant 0xff80000000000000 : f64
+  call @func_exp2f(%neg_inf) : (f64) -> ()
+
+  // CHECK: inf
+  %i = arith.constant 0x7fc0000000000000 : f64
+  call @func_exp2f(%i) : (f64) -> ()
+  return
+}
+
+func.func @main() {
+  call @exp2f() : () -> ()
+  return
+}