}
//----------------------------------------------------------------------------//
+// Cbrt approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+struct CbrtApproximation : public OpRewritePattern<math::CbrtOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(math::CbrtOp op,
+ PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+// Estimation of cube-root using an algorithm defined in
+// Hacker's Delight 2nd Edition.
+LogicalResult
+CbrtApproximation::matchAndRewrite(math::CbrtOp op,
+ PatternRewriter &rewriter) const {
+ auto operand = op.getOperand();
+ if (!getElementTypeOrSelf(operand).isF32())
+ return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ ArrayRef<int64_t> shape = vectorShape(operand);
+
+ Type floatTy = getElementTypeOrSelf(operand.getType());
+ Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
+
+ // Convert to vector types if necessary.
+ floatTy = broadcast(floatTy, shape);
+ intTy = broadcast(intTy, shape);
+
+ auto bconst = [&](Attribute attr) -> Value {
+ Value value = b.create<arith::ConstantOp>(attr);
+ return broadcast(b, value, shape);
+ };
+
+ // Declare the initial values:
+ Value intTwo = bconst(b.getI32IntegerAttr(2));
+ Value intFour = bconst(b.getI32IntegerAttr(4));
+ Value intEight = bconst(b.getI32IntegerAttr(8));
+ Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
+ Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
+ Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
+ Value fpZero = bconst(b.getF32FloatAttr(0.0f));
+
+ // Compute an approximation of one third:
+ // union {int ix; float x;};
+ // x = x0;
+ // ix = ix/4 + ix/16;
+ Value absValue = b.create<math::AbsFOp>(operand);
+ Value intValue = b.create<arith::BitcastOp>(intTy, absValue);
+ Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo);
+ Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
+ intValue = b.create<arith::AddIOp>(divideBy4, divideBy16);
+
+ // ix = ix + ix/16;
+ divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
+ intValue = b.create<arith::AddIOp>(intValue, divideBy16);
+
+ // ix = ix + ix/256;
+ Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight);
+ intValue = b.create<arith::AddIOp>(intValue, divideBy256);
+
+ // ix = 0x2a5137a0 + ix;
+ intValue = b.create<arith::AddIOp>(intValue, intMagic);
+
+ // Perform one newtons step:
+ // x = 0.33333333f*(2.0f*x + x0/(x*x));
+ Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue);
+ Value squared = b.create<arith::MulFOp>(floatValue, floatValue);
+ Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
+ Value divSquared = b.create<arith::DivFOp>(absValue, squared);
+ floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
+ floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
+
+ // x = 0.33333333f*(2.0f*x + x0/(x*x));
+ squared = b.create<arith::MulFOp>(floatValue, floatValue);
+ mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
+ divSquared = b.create<arith::DivFOp>(absValue, squared);
+ floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
+ floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
+
+ // Check for zero and restore sign.
+ Value isZero =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero);
+ floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue);
+ floatValue = b.create<math::CopySignOp>(floatValue, operand);
+
+ rewriter.replaceOp(op, floatValue);
+ return success();
+}
+
+//----------------------------------------------------------------------------//
// Rsqrt approximation.
//----------------------------------------------------------------------------//
patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
LogApproximation, Log2Approximation, Log1pApproximation,
ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
- ReuseF32Expansion<math::Atan2Op>,
+ CbrtApproximation, ReuseF32Expansion<math::Atan2Op>,
SinAndCosApproximation<true, math::SinOp>,
SinAndCosApproximation<false, math::CosOp>>(
patterns.getContext());
%0 = math.atan2 %arg0, %arg1 : f16
return %0 : f16
}
+
+// CHECK-LABEL: @cbrt_vector
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xf32>
+
+// CHECK: %[[TWO_INT:.+]] = arith.constant dense<2>
+// CHECK: %[[FOUR_INT:.+]] = arith.constant dense<4>
+// CHECK: %[[EIGHT_INT:.+]] = arith.constant dense<8>
+// CHECK: %[[MAGIC:.+]] = arith.constant dense<709965728>
+// CHECK: %[[THIRD_FP:.+]] = arith.constant dense<0.333333343> : vector<4xf32>
+// CHECK: %[[TWO_FP:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
+// CHECK: %[[ZERO_FP:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+
+// CHECK: %[[ABS:.+]] = math.absf %[[ARG0]] : vector<4xf32>
+
+// Perform the initial approximation:
+// CHECK: %[[CAST:.+]] = arith.bitcast %[[ABS]] : vector<4xf32> to vector<4xi32>
+// CHECK: %[[SH_TWO:.+]] = arith.shrsi %[[CAST]], %[[TWO_INT]]
+// CHECK: %[[SH_FOUR:.+]] = arith.shrsi %[[CAST]], %[[FOUR_INT]]
+// CHECK: %[[APPROX0:.+]] = arith.addi %[[SH_TWO]], %[[SH_FOUR]]
+// CHECK: %[[SH_FOUR:.+]] = arith.shrsi %[[APPROX0]], %[[FOUR_INT]]
+// CHECK: %[[APPROX1:.+]] = arith.addi %[[APPROX0]], %[[SH_FOUR]]
+// CHECK: %[[SH_EIGHT:.+]] = arith.shrsi %[[APPROX1]], %[[EIGHT_INT]]
+// CHECK: %[[APPROX2:.+]] = arith.addi %[[APPROX1]], %[[SH_EIGHT]]
+// CHECK: %[[FIX:.+]] = arith.addi %[[APPROX2]], %[[MAGIC]]
+// CHECK: %[[BCAST:.+]] = arith.bitcast %[[FIX]]
+
+// First Newton Step:
+// CHECK: %[[SQR:.+]] = arith.mulf %[[BCAST]], %[[BCAST]]
+// CHECK: %[[DOUBLE:.+]] = arith.mulf %[[BCAST]], %[[TWO_FP]]
+// CHECK: %[[DIV:.+]] = arith.divf %[[ABS]], %[[SQR]]
+// CHECK: %[[ADD:.+]] = arith.addf %[[DOUBLE]], %[[DIV]]
+// CHECK: %[[APPROX3:.+]] = arith.mulf %[[ADD]], %[[THIRD_FP]]
+
+// Second Newton Step:
+// CHECK: %[[SQR:.+]] = arith.mulf %[[APPROX3]], %[[APPROX3]]
+// CHECK: %[[DOUBLE:.+]] = arith.mulf %[[APPROX3]], %[[TWO_FP]]
+// CHECK: %[[DIV:.+]] = arith.divf %[[ABS]], %[[SQR]]
+// CHECK: %[[ADD:.+]] = arith.addf %[[DOUBLE]], %[[DIV]]
+// CHECK: %[[APPROX4:.+]] = arith.mulf %[[ADD]], %[[THIRD_FP]]
+
+// Check for zero special case and copy the sign:
+// CHECK: %[[CMP:.+]] = arith.cmpf oeq, %[[ABS]], %[[ZERO_FP]]
+// CHECK: %[[SEL:.+]] = arith.select %[[CMP]], %[[ZERO_FP]], %[[APPROX4]]
+// CHECK: %[[SIGN:.+]] = math.copysign %[[SEL]], %[[ARG0]]
+// CHECK: return %[[SIGN]]
+
+func.func @cbrt_vector(%arg0: vector<4xf32>) -> vector<4xf32> {
+ %0 = "math.cbrt"(%arg0) : (vector<4xf32>) -> vector<4xf32>
+ func.return %0 : vector<4xf32>
+}
\ No newline at end of file