From 6b5388104803262fedc783ad09d4b4fdfcc3646f Mon Sep 17 00:00:00 2001 From: Robert Suderman Date: Mon, 6 Mar 2023 11:09:11 -0800 Subject: [PATCH] [mlir][math] Add math.cbrt polynomial approximation Cbrt can be approximated with some relatively simple polynomial operators. This includes a lit test validating the implementation and some run tests that validate numerical correct. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D145019 --- .../Math/Transforms/PolynomialApproximation.cpp | 95 +++++++++++++++++++++- .../Dialect/Math/polynomial-approximation.mlir | 50 ++++++++++++ .../mlir-cpu-runner/math-polynomial-approx.mlir | 45 ++++++++++ 3 files changed, 189 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp index c0f3028..0d170f9 100644 --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -1213,6 +1213,99 @@ LogicalResult SinAndCosApproximation::matchAndRewrite( } //----------------------------------------------------------------------------// +// Cbrt approximation. +//----------------------------------------------------------------------------// + +namespace { +struct CbrtApproximation : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(math::CbrtOp op, + PatternRewriter &rewriter) const final; +}; +} // namespace + +// Estimation of cube-root using an algorithm defined in +// Hacker's Delight 2nd Edition. +LogicalResult +CbrtApproximation::matchAndRewrite(math::CbrtOp op, + PatternRewriter &rewriter) const { + auto operand = op.getOperand(); + if (!getElementTypeOrSelf(operand).isF32()) + return rewriter.notifyMatchFailure(op, "unsupported operand type"); + + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + ArrayRef shape = vectorShape(operand); + + Type floatTy = getElementTypeOrSelf(operand.getType()); + Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth()); + + // Convert to vector types if necessary. + floatTy = broadcast(floatTy, shape); + intTy = broadcast(intTy, shape); + + auto bconst = [&](Attribute attr) -> Value { + Value value = b.create(attr); + return broadcast(b, value, shape); + }; + + // Declare the initial values: + Value intTwo = bconst(b.getI32IntegerAttr(2)); + Value intFour = bconst(b.getI32IntegerAttr(4)); + Value intEight = bconst(b.getI32IntegerAttr(8)); + Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0)); + Value fpThird = bconst(b.getF32FloatAttr(0.33333333f)); + Value fpTwo = bconst(b.getF32FloatAttr(2.0f)); + Value fpZero = bconst(b.getF32FloatAttr(0.0f)); + + // Compute an approximation of one third: + // union {int ix; float x;}; + // x = x0; + // ix = ix/4 + ix/16; + Value absValue = b.create(operand); + Value intValue = b.create(intTy, absValue); + Value divideBy4 = b.create(intValue, intTwo); + Value divideBy16 = b.create(intValue, intFour); + intValue = b.create(divideBy4, divideBy16); + + // ix = ix + ix/16; + divideBy16 = b.create(intValue, intFour); + intValue = b.create(intValue, divideBy16); + + // ix = ix + ix/256; + Value divideBy256 = b.create(intValue, intEight); + intValue = b.create(intValue, divideBy256); + + // ix = 0x2a5137a0 + ix; + intValue = b.create(intValue, intMagic); + + // Perform one newtons step: + // x = 0.33333333f*(2.0f*x + x0/(x*x)); + Value floatValue = b.create(floatTy, intValue); + Value squared = b.create(floatValue, floatValue); + Value mulTwo = b.create(floatValue, fpTwo); + Value divSquared = b.create(absValue, squared); + floatValue = b.create(mulTwo, divSquared); + floatValue = b.create(floatValue, fpThird); + + // x = 0.33333333f*(2.0f*x + x0/(x*x)); + squared = b.create(floatValue, floatValue); + mulTwo = b.create(floatValue, fpTwo); + divSquared = b.create(absValue, squared); + floatValue = b.create(mulTwo, divSquared); + floatValue = b.create(floatValue, fpThird); + + // Check for zero and restore sign. + Value isZero = + b.create(arith::CmpFPredicate::OEQ, absValue, fpZero); + floatValue = b.create(isZero, fpZero, floatValue); + floatValue = b.create(floatValue, operand); + + rewriter.replaceOp(op, floatValue); + return success(); +} + +//----------------------------------------------------------------------------// // Rsqrt approximation. //----------------------------------------------------------------------------// @@ -1291,7 +1384,7 @@ void mlir::populateMathPolynomialApproximationPatterns( patterns.add, + CbrtApproximation, ReuseF32Expansion, SinAndCosApproximation, SinAndCosApproximation>( patterns.getContext()); diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir index 33ac11b..4b490e4 100644 --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -593,3 +593,53 @@ func.func @atan2_scalar(%arg0: f16, %arg1: f16) -> f16 { %0 = math.atan2 %arg0, %arg1 : f16 return %0 : f16 } + +// CHECK-LABEL: @cbrt_vector +// CHECK-SAME: %[[ARG0:.+]]: vector<4xf32> + +// CHECK: %[[TWO_INT:.+]] = arith.constant dense<2> +// CHECK: %[[FOUR_INT:.+]] = arith.constant dense<4> +// CHECK: %[[EIGHT_INT:.+]] = arith.constant dense<8> +// CHECK: %[[MAGIC:.+]] = arith.constant dense<709965728> +// CHECK: %[[THIRD_FP:.+]] = arith.constant dense<0.333333343> : vector<4xf32> +// CHECK: %[[TWO_FP:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32> +// CHECK: %[[ZERO_FP:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> + +// CHECK: %[[ABS:.+]] = math.absf %[[ARG0]] : vector<4xf32> + +// Perform the initial approximation: +// CHECK: %[[CAST:.+]] = arith.bitcast %[[ABS]] : vector<4xf32> to vector<4xi32> +// CHECK: %[[SH_TWO:.+]] = arith.shrsi %[[CAST]], %[[TWO_INT]] +// CHECK: %[[SH_FOUR:.+]] = arith.shrsi %[[CAST]], %[[FOUR_INT]] +// CHECK: %[[APPROX0:.+]] = arith.addi %[[SH_TWO]], %[[SH_FOUR]] +// CHECK: %[[SH_FOUR:.+]] = arith.shrsi %[[APPROX0]], %[[FOUR_INT]] +// CHECK: %[[APPROX1:.+]] = arith.addi %[[APPROX0]], %[[SH_FOUR]] +// CHECK: %[[SH_EIGHT:.+]] = arith.shrsi %[[APPROX1]], %[[EIGHT_INT]] +// CHECK: %[[APPROX2:.+]] = arith.addi %[[APPROX1]], %[[SH_EIGHT]] +// CHECK: %[[FIX:.+]] = arith.addi %[[APPROX2]], %[[MAGIC]] +// CHECK: %[[BCAST:.+]] = arith.bitcast %[[FIX]] + +// First Newton Step: +// CHECK: %[[SQR:.+]] = arith.mulf %[[BCAST]], %[[BCAST]] +// CHECK: %[[DOUBLE:.+]] = arith.mulf %[[BCAST]], %[[TWO_FP]] +// CHECK: %[[DIV:.+]] = arith.divf %[[ABS]], %[[SQR]] +// CHECK: %[[ADD:.+]] = arith.addf %[[DOUBLE]], %[[DIV]] +// CHECK: %[[APPROX3:.+]] = arith.mulf %[[ADD]], %[[THIRD_FP]] + +// Second Newton Step: +// CHECK: %[[SQR:.+]] = arith.mulf %[[APPROX3]], %[[APPROX3]] +// CHECK: %[[DOUBLE:.+]] = arith.mulf %[[APPROX3]], %[[TWO_FP]] +// CHECK: %[[DIV:.+]] = arith.divf %[[ABS]], %[[SQR]] +// CHECK: %[[ADD:.+]] = arith.addf %[[DOUBLE]], %[[DIV]] +// CHECK: %[[APPROX4:.+]] = arith.mulf %[[ADD]], %[[THIRD_FP]] + +// Check for zero special case and copy the sign: +// CHECK: %[[CMP:.+]] = arith.cmpf oeq, %[[ABS]], %[[ZERO_FP]] +// CHECK: %[[SEL:.+]] = arith.select %[[CMP]], %[[ZERO_FP]], %[[APPROX4]] +// CHECK: %[[SIGN:.+]] = math.copysign %[[SEL]], %[[ARG0]] +// CHECK: return %[[SIGN]] + +func.func @cbrt_vector(%arg0: vector<4xf32>) -> vector<4xf32> { + %0 = "math.cbrt"(%arg0) : (vector<4xf32>) -> vector<4xf32> + func.return %0 : vector<4xf32> +} \ No newline at end of file diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir index dbd8166..665d328 100644 --- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir +++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir @@ -568,6 +568,48 @@ func.func @atan2() { } +// -------------------------------------------------------------------------- // +// Cbrt. +// -------------------------------------------------------------------------- // + +func.func @cbrt_f32(%a : f32) { + %r = math.cbrt %a : f32 + vector.print %r : f32 + return +} + +func.func @cbrt() { + // CHECK: 1 + %a = arith.constant 1.0 : f32 + call @cbrt_f32(%a) : (f32) -> () + + // CHECK: -1 + %b = arith.constant -1.0 : f32 + call @cbrt_f32(%b) : (f32) -> () + + // CHECK: 0 + %c = arith.constant 0.0 : f32 + call @cbrt_f32(%c) : (f32) -> () + + // CHECK: -0 + %d = arith.constant -0.0 : f32 + call @cbrt_f32(%d) : (f32) -> () + + // CHECK: 10 + %e = arith.constant 1000.0 : f32 + call @cbrt_f32(%e) : (f32) -> () + + // CHECK: -10 + %f = arith.constant -1000.0 : f32 + call @cbrt_f32(%f) : (f32) -> () + + // CHECK: 2.57128 + %g = arith.constant 17.0 : f32 + call @cbrt_f32(%g) : (f32) -> () + + return +} + func.func @main() { call @tanh(): () -> () call @log(): () -> () @@ -580,5 +622,8 @@ func.func @main() { call @cos(): () -> () call @atan() : () -> () call @atan2() : () -> () + call @cbrt() : () -> () return } + + -- 2.7.4