From e6d90a0d5e202166a9846f1845196086aa02f35e Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 24 May 2023 16:14:47 +0200 Subject: [PATCH] [mlir][Transforms] GreedyPatternRewriteDriver debugging: Detect faulty patterns Compute operation finger prints to detect incorrect API usage in RewritePatterns. Does not work for dialect conversion patterns. Detect patterns that: * Returned `failure` but changed the IR. * Returned `success` but did not change the IR. * Inserted/removed/modified ops, bypassing the rewriter. Not all cases are detected. These new checks are quite expensive, so they are only enabled with `-DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON`. Failures manifest as fatal errors (`llvm::report_fatal_error`) or crashes (accessing deallocated memory). To get better debugging information, run `mlir-opt -debug` (to see which pattern is broken) with ASAN (to see where memory was deallocated). Differential Revision: https://reviews.llvm.org/D144552 --- mlir/CMakeLists.txt | 4 + mlir/include/mlir/Config/mlir-config.h.cmake | 22 ++++ mlir/include/mlir/IR/PatternMatch.h | 32 +++++ .../Utils/GreedyPatternRewriteDriver.cpp | 135 ++++++++++++++++++++- 4 files changed, 190 insertions(+), 3 deletions(-) create mode 100644 mlir/include/mlir/Config/mlir-config.h.cmake diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index c9b0d53..cd38836 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -141,6 +141,10 @@ set(MLIR_INSTALL_AGGREGATE_OBJECTS 1 CACHE BOOL set(MLIR_BUILD_MLIR_C_DYLIB 0 CACHE BOOL "Builds libMLIR-C shared library.") +configure_file( + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Config/mlir-config.h.cmake + ${MLIR_INCLUDE_DIR}/mlir/Config/mlir-config.h) + #------------------------------------------------------------------------------- # Python Bindings Configuration # Requires: diff --git a/mlir/include/mlir/Config/mlir-config.h.cmake b/mlir/include/mlir/Config/mlir-config.h.cmake new file mode 100644 index 0000000..2bcc9bf --- /dev/null +++ b/mlir/include/mlir/Config/mlir-config.h.cmake @@ -0,0 +1,22 @@ +//===- mlir-config.h - MLIR configuration ------------------------*- C -*-===*// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +/* This file enumerates variables from the MLIR configuration so that they + can be in exported headers and won't override package specific directives. + This is a C header that can be included in the mlir-c headers. */ + +#ifndef MLIR_CONFIG_H +#define MLIR_CONFIG_H + +/* Enable expensive checks to detect invalid pattern API usage. Failed checks + manifest as fatal errors or invalid memory accesses (e.g., accessing + deallocated memory) that cause a crash. Running with ASAN is recommended for + easier debugging. */ +#cmakedefine01 MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + +#endif diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 600ace4..4614649 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -429,6 +429,38 @@ public: static bool classof(const OpBuilder::Listener *base); }; + /// A listener that forwards all notifications to another listener. This + /// struct can be used as a base to create listener chains, so that multiple + /// listeners can be notified of IR changes. + struct ForwardingListener : public RewriterBase::Listener { + ForwardingListener(Listener *listener) : listener(listener) {} + + void notifyOperationInserted(Operation *op) override { + listener->notifyOperationInserted(op); + } + void notifyBlockCreated(Block *block) override { + listener->notifyBlockCreated(block); + } + void notifyOperationModified(Operation *op) override { + listener->notifyOperationModified(op); + } + void notifyOperationReplaced(Operation *op, + ValueRange replacement) override { + listener->notifyOperationReplaced(op, replacement); + } + void notifyOperationRemoved(Operation *op) override { + listener->notifyOperationRemoved(op); + } + LogicalResult notifyMatchFailure( + Location loc, + function_ref reasonCallback) override { + return listener->notifyMatchFailure(loc, reasonCallback); + } + + private: + Listener *listener; + }; + /// Move the blocks that belong to "region" before the given position in /// another region "parent". The two regions must be different. The caller /// is responsible for creating or updating the operation transferring flow diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index f6e7fa1..c05b639 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -11,6 +11,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Config/mlir-config.h" #include "mlir/IR/Action.h" #include "mlir/IR/Matchers.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -30,10 +32,108 @@ using namespace mlir; #define DEBUG_TYPE "greedy-rewriter" //===----------------------------------------------------------------------===// -// GreedyPatternRewriteDriver +// Debugging Infrastructure //===----------------------------------------------------------------------===// namespace { +#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS +/// A helper struct that stores finger prints of ops in order to detect broken +/// RewritePatterns. A rewrite pattern is broken if it modifies IR without +/// using the rewriter API or if it returns an inconsistent return value. +struct DebugFingerPrints : public RewriterBase::ForwardingListener { + DebugFingerPrints(RewriterBase::Listener *driver) + : RewriterBase::ForwardingListener(driver) {} + + /// Compute finger prints of the given op and its nested ops. + void computeFingerPrints(Operation *topLevel) { + this->topLevel = topLevel; + this->topLevelFingerPrint.emplace(topLevel); + topLevel->walk([&](Operation *op) { fingerprints.try_emplace(op, op); }); + } + + /// Clear all finger prints. + void clear() { + topLevel = nullptr; + topLevelFingerPrint.reset(); + fingerprints.clear(); + } + + void notifyRewriteSuccess() { + // Pattern application success => IR must have changed. + OperationFingerPrint afterFingerPrint(topLevel); + if (*topLevelFingerPrint == afterFingerPrint) { + // Note: Run "mlir-opt -debug" to see which pattern is broken. + llvm::report_fatal_error( + "pattern returned success but IR did not change"); + } + for (const auto &it : fingerprints) { + // Skip top-level op, its finger print is never invalidated. + if (it.first == topLevel) + continue; + // Note: Finger print computation may crash when an op was erased + // without notifying the rewriter. (Run with ASAN to see where the op was + // erased; the op was probably erased directly, bypassing the rewriter + // API.) Finger print computation does may not crash if a new op was + // created at the same memory location. (But then the finger print should + // have changed.) + if (it.second != OperationFingerPrint(it.first)) { + // Note: Run "mlir-opt -debug" to see which pattern is broken. + llvm::report_fatal_error("operation finger print changed"); + } + } + } + + void notifyRewriteFailure() { + // Pattern application failure => IR must not have changed. + OperationFingerPrint afterFingerPrint(topLevel); + if (*topLevelFingerPrint != afterFingerPrint) { + // Note: Run "mlir-opt -debug" to see which pattern is broken. + llvm::report_fatal_error("pattern returned failure but IR did change"); + } + } + +protected: + /// Invalidate the finger print of the given op, i.e., remove it from the map. + void invalidateFingerPrint(Operation *op) { + // Invalidate all finger prints until the top level. + while (op && op != topLevel) { + fingerprints.erase(op); + op = op->getParentOp(); + } + } + + void notifyOperationInserted(Operation *op) override { + RewriterBase::ForwardingListener::notifyOperationInserted(op); + invalidateFingerPrint(op->getParentOp()); + } + + void notifyOperationModified(Operation *op) override { + RewriterBase::ForwardingListener::notifyOperationModified(op); + invalidateFingerPrint(op); + } + + void notifyOperationRemoved(Operation *op) override { + RewriterBase::ForwardingListener::notifyOperationRemoved(op); + op->walk([this](Operation *op) { invalidateFingerPrint(op); }); + } + + /// Operation finger prints to detect invalid pattern API usage. IR is checked + /// against these finger prints after pattern application to detect cases + /// where IR was modified directly, bypassing the rewriter API. + DenseMap fingerprints; + + /// Top-level operation of the current greedy rewrite. + Operation *topLevel = nullptr; + + /// Finger print of the top-level operation. + std::optional topLevelFingerPrint; +}; +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + +//===----------------------------------------------------------------------===// +// GreedyPatternRewriteDriver +//===----------------------------------------------------------------------===// + /// This is a worklist-driven driver for the PatternMatcher, which repeatedly /// applies the locally optimal patterns. /// @@ -122,21 +222,36 @@ private: /// The low-level pattern applicator. PatternApplicator matcher; + +#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + DebugFingerPrints debugFingerPrints; +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS }; } // namespace GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config) - : PatternRewriter(ctx), folder(ctx, this), config(config), - matcher(patterns) { + : PatternRewriter(ctx), folder(ctx, this), config(config), matcher(patterns) +#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + // clang-format off + , debugFingerPrints(this) +// clang-format on +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS +{ worklist.reserve(64); // Apply a simple cost model based solely on pattern benefit. matcher.applyDefaultCostModel(); // Set up listener. +#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + // Send IR notifications to the debug handler. This handler will then forward + // all notifications to this GreedyPatternRewriteDriver. + setListener(&debugFingerPrints); +#else setListener(this); +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS } bool GreedyPatternRewriteDriver::processWorklist() { @@ -231,15 +346,28 @@ bool GreedyPatternRewriteDriver::processWorklist() { function_ref onSuccess = {}; #endif +#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + debugFingerPrints.computeFingerPrints( + /*topLevel=*/config.scope ? config.scope->getParentOp() : op); + auto clearFingerprints = + llvm::make_scope_exit([&]() { debugFingerPrints.clear(); }); +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + LogicalResult matchResult = matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess); if (succeeded(matchResult)) { LLVM_DEBUG(logResultWithLine("success", "pattern matched")); +#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + debugFingerPrints.notifyRewriteSuccess(); +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS changed = true; ++numRewrites; } else { LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match")); +#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + debugFingerPrints.notifyRewriteFailure(); +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS } } @@ -247,6 +375,7 @@ bool GreedyPatternRewriteDriver::processWorklist() { } void GreedyPatternRewriteDriver::addToWorklist(Operation *op) { + assert(op && "expected valid op"); // Gather potential ancestors while looking for a "scope" parent region. SmallVector ancestors; Region *region = nullptr; -- 2.7.4