From 8a9d4895df780231a14a1afc44e18b1f6b7eab93 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Tue, 21 Feb 2023 08:18:45 +0100 Subject: [PATCH] [mlir] Clean-up math -> libm/llvm conversion. At the moment, there is an optional log1pBenefit populateMathToLibmConversionPatterns which is used to increase the priority of the log1p->libm pattern compared to log1p->llvm pattern that approximates log1p with precision issues. Instead, we can have a flag for the MathToLLVM pass to enable or disable the imprecise approximation. Differential Revision: https://reviews.llvm.org/D144450 --- .../mlir/Conversion/MathToLLVM/MathToLLVM.h | 3 +- .../mlir/Conversion/MathToLibm/MathToLibm.h | 7 +-- mlir/include/mlir/Conversion/Passes.td | 7 ++- mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 8 ++- mlir/lib/Conversion/MathToLibm/MathToLibm.cpp | 65 ++++++++-------------- 5 files changed, 37 insertions(+), 53 deletions(-) diff --git a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h index d0fc2e3..b2e5db3 100644 --- a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h +++ b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h @@ -21,7 +21,8 @@ class Pass; #include "mlir/Conversion/Passes.h.inc" void populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns); + RewritePatternSet &patterns, + bool approximateLog1p = true); } // namespace mlir #endif // MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H diff --git a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h index cd79e8e..ab9a1ce 100644 --- a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h +++ b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h @@ -8,8 +8,7 @@ #ifndef MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_ #define MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_ -#include "mlir/Transforms/DialectConversion.h" -#include +#include "mlir/IR/PatternMatch.h" namespace mlir { template @@ -20,9 +19,7 @@ class OperationPass; /// Populate the given list with patterns that convert from Math to Libm calls. /// If log1pBenefit is present, use it instead of benefit for the Log1p op. -void populateMathToLibmConversionPatterns( - RewritePatternSet &patterns, PatternBenefit benefit, - std::optional log1pBenefit = std::nullopt); +void populateMathToLibmConversionPatterns(RewritePatternSet &patterns); /// Create a pass to convert Math operations to libm calls. std::unique_ptr> createConvertMathToLibmPass(); diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 02299f4..33502ab 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -561,10 +561,11 @@ def ConvertMathToLibm : Pass<"convert-math-to-libm", "ModuleOp"> { def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> { let summary = "Convert Math dialect to LLVM dialect"; - let description = [{ - This pass converts supported Math ops to LLVM dialect intrinsics. - }]; let dependentDialects = ["LLVM::LLVMDialect"]; + let options = [ + Option<"approximateLog1p", "approximate-log1p", "bool", "true", + "Enable approximation of Log1p."> + ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp index 888c512..c331f4f 100644 --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -291,7 +291,7 @@ struct ConvertMathToLLVMPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); LLVMTypeConverter converter(&getContext()); - populateMathToLLVMConversionPatterns(converter, patterns); + populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p); LLVMConversionTarget target(getContext()); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) @@ -301,7 +301,10 @@ struct ConvertMathToLLVMPass } // namespace void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns) { + RewritePatternSet &patterns, + bool approximateLog1p) { + if (approximateLog1p) + patterns.add(converter); // clang-format off patterns.add< AbsFOpLowering, @@ -319,7 +322,6 @@ void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, FloorOpLowering, FmaOpLowering, Log10OpLowering, - Log1pOpLowering, Log2OpLowering, LogOpLowering, PowFOpLowering, diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp index 93b58e2..35ac2b3 100644 --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -14,11 +14,10 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include +#include "mlir/Transforms/DialectConversion.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOLIBM @@ -52,8 +51,8 @@ struct ScalarOpToLibmCall : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc, - StringRef doubleFunc, PatternBenefit benefit) - : OpRewritePattern(context, benefit), floatFunc(floatFunc), + StringRef doubleFunc) + : OpRewritePattern(context), floatFunc(floatFunc), doubleFunc(doubleFunc){}; LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; @@ -152,53 +151,37 @@ ScalarOpToLibmCall::matchAndRewrite(Op op, return success(); } -void mlir::populateMathToLibmConversionPatterns( - RewritePatternSet &patterns, PatternBenefit benefit, - std::optional log1pBenefit) { +void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); patterns.add, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp>( - patterns.getContext(), benefit); + ctx); patterns.add, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, - PromoteOpToF32, PromoteOpToF32>( - patterns.getContext(), benefit); - patterns.add>(patterns.getContext(), "atanf", - "atan", benefit); - patterns.add>(patterns.getContext(), - "atan2f", "atan2", benefit); - patterns.add>(patterns.getContext(), "cbrtf", - "cbrt", benefit); - patterns.add>(patterns.getContext(), "erff", - "erf", benefit); - patterns.add>(patterns.getContext(), - "expm1f", "expm1", benefit); - patterns.add>(patterns.getContext(), "tanf", - "tan", benefit); - patterns.add>(patterns.getContext(), "tanhf", - "tanh", benefit); - patterns.add>( - patterns.getContext(), "roundevenf", "roundeven", benefit); - patterns.add>(patterns.getContext(), - "roundf", "round", benefit); - patterns.add>(patterns.getContext(), "cosf", - "cos", benefit); - patterns.add>(patterns.getContext(), "sinf", - "sin", benefit); - patterns.add>( - patterns.getContext(), "log1pf", "log1p", log1pBenefit.value_or(benefit)); - patterns.add>(patterns.getContext(), - "floorf", "floor", benefit); - patterns.add>(patterns.getContext(), "ceilf", - "ceil", benefit); - patterns.add>(patterns.getContext(), - "truncf", "trunc", benefit); + PromoteOpToF32, PromoteOpToF32>(ctx); + patterns.add>(ctx, "atanf", "atan"); + patterns.add>(ctx, "atan2f", "atan2"); + patterns.add>(ctx, "cbrtf", "cbrt"); + patterns.add>(ctx, "erff", "erf"); + patterns.add>(ctx, "expm1f", "expm1"); + patterns.add>(ctx, "tanf", "tan"); + patterns.add>(ctx, "tanhf", "tanh"); + patterns.add>(ctx, "roundevenf", + "roundeven"); + patterns.add>(ctx, "roundf", "round"); + patterns.add>(ctx, "cosf", "cos"); + patterns.add>(ctx, "sinf", "sin"); + patterns.add>(ctx, "log1pf", "log1p"); + patterns.add>(ctx, "floorf", "floor"); + patterns.add>(ctx, "ceilf", "ceil"); + patterns.add>(ctx, "truncf", "trunc"); } namespace { @@ -212,7 +195,7 @@ void ConvertMathToLibmPass::runOnOperation() { auto module = getOperation(); RewritePatternSet patterns(&getContext()); - populateMathToLibmConversionPatterns(patterns, /*benefit=*/1); + populateMathToLibmConversionPatterns(patterns); ConversionTarget target(getContext()); target.addLegalDialect