[mlir] Add lowering from math.expm1 to LLVM.
authorAdrian Kuegel <akuegel@google.com>
Tue, 16 Feb 2021 13:34:35 +0000 (14:34 +0100)
committerAdrian Kuegel <akuegel@google.com>
Tue, 4 May 2021 12:22:10 +0000 (14:22 +0200)
Differential Revision: https://reviews.llvm.org/D96776

mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir

index 8df6504..5f94804 100644 (file)
@@ -2352,6 +2352,60 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
   }
 };
 
+// A `expm1` is converted into `exp - 1`.
+struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
+  using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(math::ExpM1Op op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    math::ExpM1Op::Adaptor transformed(operands);
+    auto operandType = transformed.operand().getType();
+
+    if (!operandType || !LLVM::isCompatibleType(operandType))
+      return failure();
+
+    auto loc = op.getLoc();
+    auto resultType = op.getResult().getType();
+    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+
+    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+      LLVM::ConstantOp one;
+      if (LLVM::isCompatibleVectorType(operandType)) {
+        one = rewriter.create<LLVM::ConstantOp>(
+            loc, operandType,
+            SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
+      } else {
+        one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+      }
+      auto exp = rewriter.create<LLVM::ExpOp>(loc, transformed.operand());
+      rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
+      return success();
+    }
+
+    auto vectorType = resultType.dyn_cast<VectorType>();
+    if (!vectorType)
+      return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+    return handleMultidimensionalVectors(
+        op.getOperation(), operands, *getTypeConverter(),
+        [&](Type llvm1DVectorTy, ValueRange operands) {
+          auto splatAttr = SplatElementsAttr::get(
+              mlir::VectorType::get(
+                  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
+                  floatType),
+              floatOne);
+          auto one =
+              rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
+          auto exp =
+              rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
+          return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
+        },
+        rewriter);
+  }
+};
+
 // A `log1p` is converted into `log(1 + ...)`.
 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
   using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
@@ -3924,6 +3978,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       DivFOpLowering,
       ExpOpLowering,
       Exp2OpLowering,
+      ExpM1OpLowering,
       FloorFOpLowering,
       FmaFOpLowering,
       GenericAtomicRMWOpLowering,
index 4a2983b..a743e31 100644 (file)
@@ -37,6 +37,18 @@ func @log1p_2dvector(%arg0 : vector<4x3xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @expm1(
+// CHECK-SAME: f32
+func @expm1(%arg0 : f32) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+  // CHECK: %[[EXP:.*]] = "llvm.intr.exp"(%arg0) : (f32) -> f32
+  // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : f32
+  %0 = math.expm1 %arg0 : f32
+  std.return
+}
+
+// -----
+
 // CHECK-LABEL: func @rsqrt(
 // CHECK-SAME: f32
 func @rsqrt(%arg0 : f32) {