[mlir][math] Add math.cbrt polynomial approximation
authorRobert Suderman <suderman@google.com>
Mon, 6 Mar 2023 19:09:11 +0000 (11:09 -0800)
committerRob Suderman <suderman@google.com>
Mon, 6 Mar 2023 21:29:49 +0000 (13:29 -0800)
Cbrt can be approximated with some relatively simple polynomial
operators. This includes a lit test validating the implementation
and some run tests that validate numerical correct.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D145019

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/test/Dialect/Math/polynomial-approximation.mlir
mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir

index c0f3028..0d170f9 100644 (file)
@@ -1213,6 +1213,99 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
 }
 
 //----------------------------------------------------------------------------//
+// Cbrt approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+struct CbrtApproximation : public OpRewritePattern<math::CbrtOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::CbrtOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+// Estimation of cube-root using an algorithm defined in
+// Hacker's Delight 2nd Edition.
+LogicalResult
+CbrtApproximation::matchAndRewrite(math::CbrtOp op,
+                                   PatternRewriter &rewriter) const {
+  auto operand = op.getOperand();
+  if (!getElementTypeOrSelf(operand).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  ArrayRef<int64_t> shape = vectorShape(operand);
+
+  Type floatTy = getElementTypeOrSelf(operand.getType());
+  Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
+
+  // Convert to vector types if necessary.
+  floatTy = broadcast(floatTy, shape);
+  intTy = broadcast(intTy, shape);
+
+  auto bconst = [&](Attribute attr) -> Value {
+    Value value = b.create<arith::ConstantOp>(attr);
+    return broadcast(b, value, shape);
+  };
+
+  // Declare the initial values:
+  Value intTwo = bconst(b.getI32IntegerAttr(2));
+  Value intFour = bconst(b.getI32IntegerAttr(4));
+  Value intEight = bconst(b.getI32IntegerAttr(8));
+  Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
+  Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
+  Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
+  Value fpZero = bconst(b.getF32FloatAttr(0.0f));
+
+  // Compute an approximation of one third:
+  // union {int ix; float x;};
+  // x = x0;
+  // ix = ix/4 + ix/16;
+  Value absValue = b.create<math::AbsFOp>(operand);
+  Value intValue = b.create<arith::BitcastOp>(intTy, absValue);
+  Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo);
+  Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
+  intValue = b.create<arith::AddIOp>(divideBy4, divideBy16);
+
+  // ix = ix + ix/16;
+  divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
+  intValue = b.create<arith::AddIOp>(intValue, divideBy16);
+
+  // ix = ix + ix/256;
+  Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight);
+  intValue = b.create<arith::AddIOp>(intValue, divideBy256);
+
+  // ix = 0x2a5137a0 + ix;
+  intValue = b.create<arith::AddIOp>(intValue, intMagic);
+
+  // Perform one newtons step:
+  // x = 0.33333333f*(2.0f*x + x0/(x*x));
+  Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue);
+  Value squared = b.create<arith::MulFOp>(floatValue, floatValue);
+  Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
+  Value divSquared = b.create<arith::DivFOp>(absValue, squared);
+  floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
+  floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
+
+  // x = 0.33333333f*(2.0f*x + x0/(x*x));
+  squared = b.create<arith::MulFOp>(floatValue, floatValue);
+  mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
+  divSquared = b.create<arith::DivFOp>(absValue, squared);
+  floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
+  floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
+
+  // Check for zero and restore sign.
+  Value isZero =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero);
+  floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue);
+  floatValue = b.create<math::CopySignOp>(floatValue, operand);
+
+  rewriter.replaceOp(op, floatValue);
+  return success();
+}
+
+//----------------------------------------------------------------------------//
 // Rsqrt approximation.
 //----------------------------------------------------------------------------//
 
@@ -1291,7 +1384,7 @@ void mlir::populateMathPolynomialApproximationPatterns(
   patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
                LogApproximation, Log2Approximation, Log1pApproximation,
                ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
-               ReuseF32Expansion<math::Atan2Op>,
+               CbrtApproximation, ReuseF32Expansion<math::Atan2Op>,
                SinAndCosApproximation<true, math::SinOp>,
                SinAndCosApproximation<false, math::CosOp>>(
       patterns.getContext());
index 33ac11b..4b490e4 100644 (file)
@@ -593,3 +593,53 @@ func.func @atan2_scalar(%arg0: f16, %arg1: f16) -> f16 {
   %0 = math.atan2 %arg0, %arg1 : f16
   return %0 : f16
 }
+
+// CHECK-LABEL: @cbrt_vector
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xf32>
+
+// CHECK: %[[TWO_INT:.+]] = arith.constant dense<2>
+// CHECK: %[[FOUR_INT:.+]] = arith.constant dense<4>
+// CHECK: %[[EIGHT_INT:.+]] = arith.constant dense<8>
+// CHECK: %[[MAGIC:.+]] = arith.constant dense<709965728>
+// CHECK: %[[THIRD_FP:.+]] = arith.constant dense<0.333333343> : vector<4xf32>
+// CHECK: %[[TWO_FP:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
+// CHECK: %[[ZERO_FP:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+
+// CHECK: %[[ABS:.+]] = math.absf %[[ARG0]] : vector<4xf32>
+
+// Perform the initial approximation:
+// CHECK: %[[CAST:.+]] = arith.bitcast %[[ABS]] : vector<4xf32> to vector<4xi32>
+// CHECK: %[[SH_TWO:.+]] = arith.shrsi %[[CAST]], %[[TWO_INT]]
+// CHECK: %[[SH_FOUR:.+]] = arith.shrsi %[[CAST]], %[[FOUR_INT]]
+// CHECK: %[[APPROX0:.+]] = arith.addi %[[SH_TWO]], %[[SH_FOUR]]
+// CHECK: %[[SH_FOUR:.+]] = arith.shrsi %[[APPROX0]], %[[FOUR_INT]]
+// CHECK: %[[APPROX1:.+]] = arith.addi %[[APPROX0]], %[[SH_FOUR]]
+// CHECK: %[[SH_EIGHT:.+]] = arith.shrsi %[[APPROX1]], %[[EIGHT_INT]]
+// CHECK: %[[APPROX2:.+]] = arith.addi %[[APPROX1]], %[[SH_EIGHT]]
+// CHECK: %[[FIX:.+]] = arith.addi %[[APPROX2]], %[[MAGIC]]
+// CHECK: %[[BCAST:.+]] = arith.bitcast %[[FIX]]
+
+// First Newton Step:
+// CHECK: %[[SQR:.+]] = arith.mulf %[[BCAST]], %[[BCAST]]
+// CHECK: %[[DOUBLE:.+]] = arith.mulf %[[BCAST]], %[[TWO_FP]]
+// CHECK: %[[DIV:.+]] = arith.divf %[[ABS]], %[[SQR]]
+// CHECK: %[[ADD:.+]] = arith.addf %[[DOUBLE]], %[[DIV]]
+// CHECK: %[[APPROX3:.+]] = arith.mulf %[[ADD]], %[[THIRD_FP]]
+
+// Second Newton Step:
+// CHECK: %[[SQR:.+]] = arith.mulf %[[APPROX3]], %[[APPROX3]]
+// CHECK: %[[DOUBLE:.+]] = arith.mulf %[[APPROX3]], %[[TWO_FP]]
+// CHECK: %[[DIV:.+]] = arith.divf %[[ABS]], %[[SQR]]
+// CHECK: %[[ADD:.+]] = arith.addf %[[DOUBLE]], %[[DIV]]
+// CHECK: %[[APPROX4:.+]] = arith.mulf %[[ADD]], %[[THIRD_FP]]
+
+// Check for zero special case and copy the sign:
+// CHECK: %[[CMP:.+]] = arith.cmpf oeq, %[[ABS]], %[[ZERO_FP]]
+// CHECK: %[[SEL:.+]] = arith.select %[[CMP]], %[[ZERO_FP]], %[[APPROX4]]
+// CHECK: %[[SIGN:.+]] = math.copysign %[[SEL]], %[[ARG0]]
+// CHECK: return %[[SIGN]]
+
+func.func @cbrt_vector(%arg0: vector<4xf32>) -> vector<4xf32> {
+  %0 = "math.cbrt"(%arg0) : (vector<4xf32>) -> vector<4xf32>
+  func.return %0 : vector<4xf32>
+}
\ No newline at end of file
index dbd8166..665d328 100644 (file)
@@ -568,6 +568,48 @@ func.func @atan2() {
 }
 
 
+// -------------------------------------------------------------------------- //
+// Cbrt.
+// -------------------------------------------------------------------------- //
+
+func.func @cbrt_f32(%a : f32) {
+  %r = math.cbrt %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @cbrt() {
+  // CHECK: 1
+  %a = arith.constant 1.0 : f32
+  call @cbrt_f32(%a) : (f32) -> ()
+
+  // CHECK: -1
+  %b = arith.constant -1.0 : f32
+  call @cbrt_f32(%b) : (f32) -> ()
+
+  // CHECK: 0
+  %c = arith.constant 0.0 : f32
+  call @cbrt_f32(%c) : (f32) -> ()
+
+  // CHECK: -0
+  %d = arith.constant -0.0 : f32
+  call @cbrt_f32(%d) : (f32) -> ()
+
+  // CHECK: 10
+  %e = arith.constant 1000.0 : f32
+  call @cbrt_f32(%e) : (f32) -> ()
+
+  // CHECK: -10
+  %f = arith.constant -1000.0 : f32
+  call @cbrt_f32(%f) : (f32) -> ()
+
+  // CHECK: 2.57128
+  %g = arith.constant 17.0 : f32
+  call @cbrt_f32(%g) : (f32) -> ()
+
+  return
+}
+
 func.func @main() {
   call @tanh(): () -> ()
   call @log(): () -> ()
@@ -580,5 +622,8 @@ func.func @main() {
   call @cos(): () -> ()
   call @atan() : () -> ()
   call @atan2() : () -> ()
+  call @cbrt() : () -> ()
   return
 }
+
+