}
//----------------------------------------------------------------------------//
+// ExpM1 approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(math::ExpM1Op op,
+ PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
+ PatternRewriter &rewriter) const {
+ auto width = vectorWidth(op.operand().getType(), isF32);
+ if (!width.hasValue())
+ return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ auto bcast = [&](Value value) -> Value {
+ return broadcast(builder, value, *width);
+ };
+
+ // expm1(x) = exp(x) - 1 = u - 1.
+ // We have to handle it carefully when x is near 0, i.e. u ~= 1,
+ // and when the input is ~= -inf, i.e. u - 1 ~= -1.
+ Value cstOne = bcast(f32Cst(builder, 1.0f));
+ Value cstNegOne = bcast(f32Cst(builder, -1.0f));
+ Value x = op.operand();
+ Value u = builder.create<math::ExpOp>(x);
+ Value uEqOne = builder.create<CmpFOp>(CmpFPredicate::OEQ, u, cstOne);
+ Value uMinusOne = builder.create<SubFOp>(u, cstOne);
+ Value uMinusOneEqNegOne =
+ builder.create<CmpFOp>(CmpFPredicate::OEQ, uMinusOne, cstNegOne);
+ // logU = log(u) ~= x
+ Value logU = builder.create<math::LogOp>(u);
+
+ // Detect exp(x) = +inf; written this way to avoid having to form +inf.
+ Value isInf = builder.create<CmpFOp>(CmpFPredicate::OEQ, logU, u);
+
+ // (u - 1) * (x / ~x)
+ Value expm1 =
+ builder.create<MulFOp>(uMinusOne, builder.create<DivFOp>(x, logU));
+ expm1 = builder.create<SelectOp>(isInf, u, expm1);
+ Value approximation = builder.create<SelectOp>(
+ uEqOne, x, builder.create<SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
+ rewriter.replaceOp(op, approximation);
+ return success();
+}
+
+//----------------------------------------------------------------------------//
void mlir::populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns) {
patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
- Log1pApproximation, ExpApproximation>(patterns.getContext());
+ Log1pApproximation, ExpApproximation, ExpM1Approximation>(
+ patterns.getContext());
}
%1 = math.log %0 : f32
%2 = math.log2 %1 : f32
%3 = math.log1p %2 : f32
- return %3 : f32
+ // CHECK-NOT: exp
+ %4 = math.exp %3 : f32
+ %5 = math.expm1 %4 : f32
+ return %5 : f32
}
// CHECK-LABEL: @vector
%1 = math.log %0 : vector<8xf32>
%2 = math.log2 %1 : vector<8xf32>
%3 = math.log1p %2 : vector<8xf32>
- return %3 : vector<8xf32>
-}
-
-// CHECK-LABEL: @exp_scalar
-func @exp_scalar(%arg0: f32) -> f32 {
- %0 = math.exp %arg0 : f32
- return %0 : f32
-}
-
-// CHECK-LABEL: @exp_vector
-func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
- // CHECK-NOT: math.exp
- %0 = math.exp %arg0 : vector<8xf32>
- return %0 : vector<8xf32>
+ // CHECK-NOT: exp
+ %4 = math.exp %3 : vector<8xf32>
+ %5 = math.expm1 %4 : vector<8xf32>
+ return %5 : vector<8xf32>
}
return
}
+func @expm1() {
+ // CHECK: 1e-10
+ %0 = constant 1.0e-10 : f32
+ %1 = math.expm1 %0 : f32
+ vector.print %1 : f32
+
+ // CHECK: -0.00995016, 0.0100502, 0.648721, 6.38905
+ %2 = constant dense<[-0.01, 0.01, 0.5, 2.0]> : vector<4xf32>
+ %3 = math.expm1 %2 : vector<4xf32>
+ vector.print %3 : vector<4xf32>
+
+ // CHECK: -0.181269, 0, 0.221403, 0.491825, 0.822119, 1.22554, 1.71828, 2.32012
+ %4 = constant dense<[-0.2, 0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2]> : vector<8xf32>
+ %5 = math.expm1 %4 : vector<8xf32>
+ vector.print %5 : vector<8xf32>
+
+ // CHECK: -1
+ %neg_inf = constant 0xff800000 : f32
+ %expm1_neg_inf = math.expm1 %neg_inf : f32
+ vector.print %expm1_neg_inf : f32
+
+ // CHECK: inf
+ %inf = constant 0x7f800000 : f32
+ %expm1_inf = math.expm1 %inf : f32
+ vector.print %expm1_inf : f32
+
+ // CHECK: -1, inf, 1e-10
+ %special_vec = constant dense<[0xff800000, 0x7f800000, 1.0e-10]> : vector<3xf32>
+ %log_special_vec = math.expm1 %special_vec : vector<3xf32>
+ vector.print %log_special_vec : vector<3xf32>
+
+ return
+}
+
func @main() {
call @tanh(): () -> ()
call @log(): () -> ()
call @log2(): () -> ()
call @log1p(): () -> ()
call @exp(): () -> ()
+ call @expm1(): () -> ()
return
}