void populateExpandFmaFPattern(RewritePatternSet &patterns);
void populateExpandFloorFPattern(RewritePatternSet &patterns);
void populateExpandCeilFPattern(RewritePatternSet &patterns);
+void populateExpandExp2FPattern(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
struct MathPolynomialApproximationOptions {
return success();
}
+// exp2f(float x) -> exp(x * ln(2))
+// Proof: Let's say 2^x = y
+// ln(2^x) = ln(y)
+// x * ln(2) = ln(y) => e ^(x*ln(2)) = y
+static LogicalResult convertExp2fOp(math::Exp2Op op,
+ PatternRewriter &rewriter) {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ Value operand = op.getOperand();
+ Type opType = operand.getType();
+ Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
+ Value mult = b.create<arith::MulFOp>(opType, operand, ln2);
+ Value exp = b.create<math::ExpOp>(op->getLoc(), mult);
+ rewriter.replaceOp(op, exp);
+ return success();
+}
+
// Converts math.ctlz to scf and arith operations. This is done
// by performing a binary search on the bits.
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
patterns.add(convertCeilOp);
}
+void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
+ patterns.add(convertExp2fOp);
+}
+
void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
patterns.add(convertFloorOp);
}
%ret = math.ceil %a : f64
return %ret : f64
}
+
+// -----
+
+// CHECK-LABEL: func @exp2f_func
+// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64
+func.func @exp2f_func(%a: f64) -> f64 {
+ // CHECK-DAG: [[CST:%.+]] = arith.constant 0.69314718055994529
+ // CHECK: [[MULF:%.+]] = arith.mulf [[ARG0]], [[CST]]
+ // CHECK: [[EXP:%.+]] = math.exp [[MULF]]
+ // CHECK: return [[EXP]]
+ %ret = math.exp2 %a : f64
+ return %ret : f64
+}
+
+// CHECK-LABEL: func @exp2f_func_tensor
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<1xf32>) -> tensor<1xf32>
+func.func @exp2f_func_tensor(%a: tensor<1xf32>) -> tensor<1xf32> {
+ // CHECK-DAG: [[CST:%.+]] = arith.constant dense<0.693147182>
+ // CHECK: [[MULF:%.+]] = arith.mulf [[ARG0]], [[CST]]
+ // CHECK: [[EXP:%.+]] = math.exp [[MULF]]
+ // CHECK: return [[EXP]]
+ %ret = math.exp2 %a : tensor<1xf32>
+ return %ret : tensor<1xf32>
+}
void TestExpandMathPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateExpandCtlzPattern(patterns);
+ populateExpandExp2FPattern(patterns);
populateExpandTanPattern(patterns);
populateExpandTanhPattern(patterns);
populateExpandFmaFPattern(patterns);
--- /dev/null
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-expand-math,convert-arith-to-llvm),convert-vector-to-llvm,func.func(convert-math-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \
+// RUN: | mlir-cpu-runner \
+// RUN: -e main -entry-point-result=void -O0 \
+// RUN: -shared-libs=%mlir_c_runner_utils \
+// RUN: -shared-libs=%mlir_runner_utils \
+// RUN: | FileCheck %s
+
+// -------------------------------------------------------------------------- //
+// exp2f.
+// -------------------------------------------------------------------------- //
+func.func @func_exp2f(%a : f64) {
+ %r = math.exp2 %a : f64
+ vector.print %r : f64
+ return
+}
+
+func.func @exp2f() {
+ // CHECK: 2
+ %a = arith.constant 1.0 : f64
+ call @func_exp2f(%a) : (f64) -> ()
+
+ // CHECK: 4
+ %b = arith.constant 2.0 : f64
+ call @func_exp2f(%b) : (f64) -> ()
+
+ // CHECK: 5.65685
+ %c = arith.constant 2.5 : f64
+ call @func_exp2f(%c) : (f64) -> ()
+
+ // CHECK: 0.29730
+ %d = arith.constant -1.75 : f64
+ call @func_exp2f(%d) : (f64) -> ()
+
+ // CHECK: 1.09581
+ %e = arith.constant 0.132 : f64
+ call @func_exp2f(%e) : (f64) -> ()
+
+ // CHECK: inf
+ %f1 = arith.constant 0.00 : f64
+ %f2 = arith.constant 1.00 : f64
+ %f = arith.divf %f2, %f1 : f64
+ call @func_exp2f(%f) : (f64) -> ()
+
+ // CHECK: inf
+ %g = arith.constant 5038939.0 : f64
+ call @func_exp2f(%g) : (f64) -> ()
+
+ // CHECK: 0
+ %neg_inf = arith.constant 0xff80000000000000 : f64
+ call @func_exp2f(%neg_inf) : (f64) -> ()
+
+ // CHECK: inf
+ %i = arith.constant 0x7fc0000000000000 : f64
+ call @func_exp2f(%i) : (f64) -> ()
+ return
+}
+
+func.func @main() {
+ call @exp2f() : () -> ()
+ return
+}