From b8c6b15283000f1f065acd10d487ef87df0542c9 Mon Sep 17 00:00:00 2001 From: Chia-hung Duan Date: Sat, 4 Dec 2021 04:35:24 +0000 Subject: [PATCH] [mlir] Support collecting logs from notifyMatchFailure(). Let the user registers their own handler to processing the matching failure information. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D110896 --- mlir/include/mlir/Transforms/DialectConversion.h | 24 ++++++++------ mlir/lib/Transforms/Utils/DialectConversion.cpp | 41 ++++++++++++++++++------ 2 files changed, 47 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index e66dbbc..d5fb2938 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -525,7 +525,7 @@ struct ConversionPatternRewriterImpl; /// hooks. class ConversionPatternRewriter final : public PatternRewriter { public: - ConversionPatternRewriter(MLIRContext *ctx); + explicit ConversionPatternRewriter(MLIRContext *ctx); ~ConversionPatternRewriter() override; /// Apply a signature conversion to the entry block of the given region. This @@ -932,14 +932,20 @@ LogicalResult applyFullConversion(Operation *op, ConversionTarget &target, /// provided 'convertedOps' set; note that no actual rewrites are applied to the /// operations on success and only pre-existing operations are added to the set. /// This method only returns failure if there are unreachable blocks in any of -/// the regions nested within 'ops'. -LogicalResult applyAnalysisConversion(ArrayRef ops, - ConversionTarget &target, - const FrozenRewritePatternSet &patterns, - DenseSet &convertedOps); -LogicalResult applyAnalysisConversion(Operation *op, ConversionTarget &target, - const FrozenRewritePatternSet &patterns, - DenseSet &convertedOps); +/// the regions nested within 'ops'. There's an additional argument +/// `notifyCallback` which is used for collecting match failure diagnostics +/// generated during the conversion. Diagnostics are only reported to this +/// callback may only be available in debug mode. +LogicalResult applyAnalysisConversion( + ArrayRef ops, ConversionTarget &target, + const FrozenRewritePatternSet &patterns, + DenseSet &convertedOps, + function_ref notifyCallback = nullptr); +LogicalResult applyAnalysisConversion( + Operation *op, ConversionTarget &target, + const FrozenRewritePatternSet &patterns, + DenseSet &convertedOps, + function_ref notifyCallback = nullptr); } // end namespace mlir #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_ diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 1d793f9..ad34eeb 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -851,8 +851,9 @@ void ArgConverter::insertConversion(Block *newBlock, namespace mlir { namespace detail { struct ConversionPatternRewriterImpl { - ConversionPatternRewriterImpl(PatternRewriter &rewriter) - : argConverter(rewriter, unresolvedMaterializations) {} + explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter) + : argConverter(rewriter, unresolvedMaterializations), + notifyCallback(nullptr) {} /// Cleanup and destroy any generated rewrite operations. This method is /// invoked when the conversion process fails. @@ -1004,6 +1005,9 @@ struct ConversionPatternRewriterImpl { /// active. TypeConverter *currentTypeConverter = nullptr; + /// This allows the user to collect the match failure message. + function_ref notifyCallback; + #ifndef NDEBUG /// A set of operations that have pending updates. This tracking isn't /// strictly necessary, and is thus only active during debug builds for extra @@ -1475,6 +1479,8 @@ LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure( Diagnostic diag(loc, DiagnosticSeverity::Remark); reasonCallback(diag); logger.startLine() << "** Failure : " << diag.str() << "\n"; + if (notifyCallback) + notifyCallback(diag); }); return failure(); } @@ -1949,7 +1955,16 @@ OperationLegalizer::legalizeWithPattern(Operation *op, // Functor that cleans up the rewriter state after a pattern failed to match. RewriterState curState = rewriterImpl.getCurrentState(); auto onFailure = [&](const Pattern &pattern) { - LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match")); + LLVM_DEBUG({ + logFailure(rewriterImpl.logger, "pattern failed to match"); + if (rewriterImpl.notifyCallback) { + Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark); + diag << "Failed to apply pattern \"" << pattern.getDebugName() + << "\" on op:\n" + << *op; + rewriterImpl.notifyCallback(diag); + } + }); rewriterImpl.resetState(curState); appliedPatterns.erase(&pattern); }; @@ -2333,7 +2348,9 @@ struct OperationConverter { : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {} /// Converts the given operations to the conversion target. - LogicalResult convertOperations(ArrayRef ops); + LogicalResult + convertOperations(ArrayRef ops, + function_ref notifyCallback = nullptr); private: /// Converts an operation with the given rewriter. @@ -2410,7 +2427,9 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, return success(); } -LogicalResult OperationConverter::convertOperations(ArrayRef ops) { +LogicalResult OperationConverter::convertOperations( + ArrayRef ops, + function_ref notifyCallback) { if (ops.empty()) return success(); ConversionTarget &target = opLegalizer.getTarget(); @@ -2428,6 +2447,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { // Convert each operation and discard rewrites on failure. ConversionPatternRewriter rewriter(ops.front()->getContext()); ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); + rewriterImpl.notifyCallback = notifyCallback; + for (auto *op : toConvert) if (failed(convert(rewriter, op))) return rewriterImpl.discardRewrites(), failure(); @@ -3275,15 +3296,17 @@ LogicalResult mlir::applyAnalysisConversion(ArrayRef ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, - DenseSet &convertedOps) { + DenseSet &convertedOps, + function_ref notifyCallback) { OperationConverter opConverter(target, patterns, OpConversionMode::Analysis, &convertedOps); - return opConverter.convertOperations(ops); + return opConverter.convertOperations(ops, notifyCallback); } LogicalResult mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target, const FrozenRewritePatternSet &patterns, - DenseSet &convertedOps) { + DenseSet &convertedOps, + function_ref notifyCallback) { return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns, - convertedOps); + convertedOps, notifyCallback); } -- 2.7.4