From 74f6138bd98f480be2bd39d8ecc2cf66089739c3 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 15 Mar 2021 14:39:19 -0700 Subject: [PATCH] [mlir] Add lowering from math::Log1p to LLVM [mlir] Add lowering from math::Log1p to LLVM Reviewed By: cota Differential Revision: https://reviews.llvm.org/D98662 --- .../Conversion/StandardToLLVM/StandardToLLVM.cpp | 56 ++++++++++++++++++++++ .../StandardToLLVM/standard-to-llvm.mlir | 12 +++++ 2 files changed, 68 insertions(+) diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index b3a2bb6..de1df34 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -2303,6 +2303,61 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering { } }; +// A `log1p` is converted into `log(1 + ...)`. +struct Log1pOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(math::Log1pOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + math::Log1pOp::Adaptor transformed(operands); + auto operandType = transformed.operand().getType(); + + if (!operandType || !LLVM::isCompatibleType(operandType)) + return rewriter.notifyMatchFailure(op, "unsupported operand type"); + + auto loc = op.getLoc(); + auto resultType = op.getResult().getType(); + auto floatType = getElementTypeOrSelf(resultType).cast(); + auto floatOne = rewriter.getFloatAttr(floatType, 1.0); + + if (!operandType.isa()) { + LLVM::ConstantOp one = + LLVM::isCompatibleVectorType(operandType) + ? rewriter.create( + loc, operandType, + SplatElementsAttr::get(resultType.cast(), + floatOne)) + : rewriter.create(loc, operandType, floatOne); + + auto add = rewriter.create(loc, operandType, one, + transformed.operand()); + rewriter.replaceOpWithNewOp(op, operandType, add); + return success(); + } + + auto vectorType = resultType.dyn_cast(); + if (!vectorType) + return rewriter.notifyMatchFailure(op, "expected vector result type"); + + return handleMultidimensionalVectors( + op.getOperation(), operands, *getTypeConverter(), + [&](Type llvm1DVectorTy, ValueRange operands) { + auto splatAttr = SplatElementsAttr::get( + mlir::VectorType::get( + {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, + floatType), + floatOne); + auto one = + rewriter.create(loc, llvm1DVectorTy, splatAttr); + auto add = rewriter.create(loc, llvm1DVectorTy, one, + transformed.operand()); + return rewriter.create(loc, llvm1DVectorTy, add); + }, + rewriter); + } +}; + // A `rsqrt` is converted into `1 / sqrt`. struct RsqrtOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -3788,6 +3843,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns( GenericAtomicRMWOpLowering, LogOpLowering, Log10OpLowering, + Log1pOpLowering, Log2OpLowering, FPExtLowering, FPToSILowering, diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir index fcb5b1c..5eca81d 100644 --- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir +++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir @@ -12,6 +12,18 @@ func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) { // ----- +// CHECK-LABEL: func @log1p( +// CHECK-SAME: f32 +func @log1p(%arg0 : f32) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 + // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %arg0 : f32 + // CHECK: %[[LOG:.*]] = "llvm.intr.log"(%[[ADD]]) : (f32) -> f32 + %0 = math.log1p %arg0 : f32 + std.return +} + +// ----- + // CHECK-LABEL: func @rsqrt( // CHECK-SAME: f32 func @rsqrt(%arg0 : f32) { -- 2.7.4