From: Balaji V. Iyer Date: Thu, 13 Apr 2023 15:54:21 +0000 (+0000) Subject: [mlir][math] Expand math.exp2 to use math.exp. X-Git-Tag: upstream/17.0.6~11736 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=4da96515ea8552cdf14c6aa6310d2a91fbe74641;p=platform%2Fupstream%2Fllvm.git [mlir][math] Expand math.exp2 to use math.exp. Exp2 functions are pushed directly to libm. This is problematic for situations where libm is not available. This patch will expand the exp2 function to use exp2 with the input multiplied by ln2 (natural log). Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D148064 --- diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h index 1b32de2..3ac18c3 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -19,6 +19,7 @@ void populateExpandTanhPattern(RewritePatternSet &patterns); void populateExpandFmaFPattern(RewritePatternSet &patterns); void populateExpandFloorFPattern(RewritePatternSet &patterns); void populateExpandCeilFPattern(RewritePatternSet &patterns); +void populateExpandExp2FPattern(RewritePatternSet &patterns); void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns); struct MathPolynomialApproximationOptions { diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index b70ac4e..e9447dc 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -158,6 +158,22 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { return success(); } +// exp2f(float x) -> exp(x * ln(2)) +// Proof: Let's say 2^x = y +// ln(2^x) = ln(y) +// x * ln(2) = ln(y) => e ^(x*ln(2)) = y +static LogicalResult convertExp2fOp(math::Exp2Op op, + PatternRewriter &rewriter) { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + Value operand = op.getOperand(); + Type opType = operand.getType(); + Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b); + Value mult = b.create(opType, operand, ln2); + Value exp = b.create(op->getLoc(), mult); + rewriter.replaceOp(op, exp); + return success(); +} + // Converts math.ctlz to scf and arith operations. This is done // by performing a binary search on the bits. static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, @@ -222,6 +238,10 @@ void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) { patterns.add(convertCeilOp); } +void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) { + patterns.add(convertExp2fOp); +} + void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) { patterns.add(convertFloorOp); } diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index 8ab6449..50986969 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -165,3 +165,27 @@ func.func @ceilf_func(%a: f64) -> f64 { %ret = math.ceil %a : f64 return %ret : f64 } + +// ----- + +// CHECK-LABEL: func @exp2f_func +// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64 +func.func @exp2f_func(%a: f64) -> f64 { + // CHECK-DAG: [[CST:%.+]] = arith.constant 0.69314718055994529 + // CHECK: [[MULF:%.+]] = arith.mulf [[ARG0]], [[CST]] + // CHECK: [[EXP:%.+]] = math.exp [[MULF]] + // CHECK: return [[EXP]] + %ret = math.exp2 %a : f64 + return %ret : f64 +} + +// CHECK-LABEL: func @exp2f_func_tensor +// CHECK-SAME: ([[ARG0:%.+]]: tensor<1xf32>) -> tensor<1xf32> +func.func @exp2f_func_tensor(%a: tensor<1xf32>) -> tensor<1xf32> { + // CHECK-DAG: [[CST:%.+]] = arith.constant dense<0.693147182> + // CHECK: [[MULF:%.+]] = arith.mulf [[ARG0]], [[CST]] + // CHECK: [[EXP:%.+]] = math.exp [[MULF]] + // CHECK: return [[EXP]] + %ret = math.exp2 %a : tensor<1xf32> + return %ret : tensor<1xf32> +} diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp index c670617..29eff99 100644 --- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp +++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp @@ -37,6 +37,7 @@ struct TestExpandMathPass void TestExpandMathPass::runOnOperation() { RewritePatternSet patterns(&getContext()); populateExpandCtlzPattern(patterns); + populateExpandExp2FPattern(patterns); populateExpandTanPattern(patterns); populateExpandTanhPattern(patterns); populateExpandFmaFPattern(patterns); diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir new file mode 100644 index 0000000..3fb3b2b --- /dev/null +++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir @@ -0,0 +1,61 @@ +// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-expand-math,convert-arith-to-llvm),convert-vector-to-llvm,func.func(convert-math-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \ +// RUN: | mlir-cpu-runner \ +// RUN: -e main -entry-point-result=void -O0 \ +// RUN: -shared-libs=%mlir_c_runner_utils \ +// RUN: -shared-libs=%mlir_runner_utils \ +// RUN: | FileCheck %s + +// -------------------------------------------------------------------------- // +// exp2f. +// -------------------------------------------------------------------------- // +func.func @func_exp2f(%a : f64) { + %r = math.exp2 %a : f64 + vector.print %r : f64 + return +} + +func.func @exp2f() { + // CHECK: 2 + %a = arith.constant 1.0 : f64 + call @func_exp2f(%a) : (f64) -> () + + // CHECK: 4 + %b = arith.constant 2.0 : f64 + call @func_exp2f(%b) : (f64) -> () + + // CHECK: 5.65685 + %c = arith.constant 2.5 : f64 + call @func_exp2f(%c) : (f64) -> () + + // CHECK: 0.29730 + %d = arith.constant -1.75 : f64 + call @func_exp2f(%d) : (f64) -> () + + // CHECK: 1.09581 + %e = arith.constant 0.132 : f64 + call @func_exp2f(%e) : (f64) -> () + + // CHECK: inf + %f1 = arith.constant 0.00 : f64 + %f2 = arith.constant 1.00 : f64 + %f = arith.divf %f2, %f1 : f64 + call @func_exp2f(%f) : (f64) -> () + + // CHECK: inf + %g = arith.constant 5038939.0 : f64 + call @func_exp2f(%g) : (f64) -> () + + // CHECK: 0 + %neg_inf = arith.constant 0xff80000000000000 : f64 + call @func_exp2f(%neg_inf) : (f64) -> () + + // CHECK: inf + %i = arith.constant 0x7fc0000000000000 : f64 + call @func_exp2f(%i) : (f64) -> () + return +} + +func.func @main() { + call @exp2f() : () -> () + return +}