}
};
+// A `log1p` is converted into `log(1 + ...)`.
+struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
+ using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(math::Log1pOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ math::Log1pOp::Adaptor transformed(operands);
+ auto operandType = transformed.operand().getType();
+
+ if (!operandType || !LLVM::isCompatibleType(operandType))
+ return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+ auto loc = op.getLoc();
+ auto resultType = op.getResult().getType();
+ auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+ auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+
+ if (!operandType.isa<LLVM::LLVMArrayType>()) {
+ LLVM::ConstantOp one =
+ LLVM::isCompatibleVectorType(operandType)
+ ? rewriter.create<LLVM::ConstantOp>(
+ loc, operandType,
+ SplatElementsAttr::get(resultType.cast<ShapedType>(),
+ floatOne))
+ : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+
+ auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
+ transformed.operand());
+ rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
+ return success();
+ }
+
+ auto vectorType = resultType.dyn_cast<VectorType>();
+ if (!vectorType)
+ return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+ return handleMultidimensionalVectors(
+ op.getOperation(), operands, *getTypeConverter(),
+ [&](Type llvm1DVectorTy, ValueRange operands) {
+ auto splatAttr = SplatElementsAttr::get(
+ mlir::VectorType::get(
+ {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
+ floatType),
+ floatOne);
+ auto one =
+ rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
+ auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
+ transformed.operand());
+ return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
+ },
+ rewriter);
+ }
+};
+
// A `rsqrt` is converted into `1 / sqrt`.
struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
GenericAtomicRMWOpLowering,
LogOpLowering,
Log10OpLowering,
+ Log1pOpLowering,
Log2OpLowering,
FPExtLowering,
FPToSILowering,