[mlir] Debug print pattern before and after matchAndRewrite call
authorButygin <ivan.butygin@intel.com>
Sat, 10 Apr 2021 16:38:11 +0000 (19:38 +0300)
committerButygin <ivan.butygin@intel.com>
Sat, 8 May 2021 09:00:36 +0000 (12:00 +0300)
Motivation: we have passes with lot of rewrites and when one one them segfaults or asserts, it is very hard to find waht exactly pattern failed without debug info.

Differential Revision: https://reviews.llvm.org/D101443

mlir/include/mlir/IR/PatternMatch.h
mlir/lib/Rewrite/PatternApplicator.cpp

index 6d7a506..b2161cf 100644 (file)
@@ -12,6 +12,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "llvm/ADT/FunctionExtras.h"
+#include "llvm/Support/TypeName.h"
 
 namespace mlir {
 
@@ -132,6 +133,13 @@ public:
     return contextAndHasBoundedRecursion.getPointer();
   }
 
+  /// Return readable pattern name. Should only be used for debugging purposes.
+  /// Can be empty.
+  StringRef getDebugName() const { return debugName; }
+
+  /// Set readable pattern name. Should only be used for debugging purposes.
+  void setDebugName(StringRef name) { debugName = name; }
+
 protected:
   /// This class acts as a special tag that makes the desire to match "any"
   /// operation type explicit. This helps to avoid unnecessary usages of this
@@ -202,6 +210,9 @@ private:
   /// A list of the potential operations that may be generated when rewriting
   /// an op with this pattern.
   SmallVector<OperationName, 2> generatedOps;
+
+  /// Readable pattern name. Can be empty.
+  StringRef debugName;
 };
 
 //===----------------------------------------------------------------------===//
@@ -959,7 +970,9 @@ public:
     struct FnPattern final : public OpRewritePattern<OpType> {
       FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
                 MLIRContext *context)
-          : OpRewritePattern<OpType>(context), implFn(implFn) {}
+          : OpRewritePattern<OpType>(context), implFn(implFn) {
+        setDebugName(llvm::getTypeName<FnPattern>());
+      }
 
       LogicalResult matchAndRewrite(OpType op,
                                     PatternRewriter &rewriter) const override {
@@ -979,8 +992,13 @@ private:
   template <typename T, typename... Args>
   std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
   addImpl(Args &&... args) {
-    nativePatterns.emplace_back(
-        std::make_unique<T>(std::forward<Args>(args)...));
+    auto pattern = std::make_unique<T>(std::forward<Args>(args)...);
+
+    // Pattern can potentially set name in ctor. Preserve old name if present.
+    if (pattern->getDebugName().empty())
+      pattern->setDebugName(llvm::getTypeName<T>());
+
+    nativePatterns.emplace_back(std::move(pattern));
   }
   template <typename T, typename... Args>
   std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
index 1632195..ae8beff 100644 (file)
@@ -195,7 +195,13 @@ LogicalResult PatternApplicator::matchAndRewrite(
       result = success(!onSuccess || succeeded(onSuccess(*bestPattern)));
     } else {
       const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
+
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Trying to match \"" << pattern->getDebugName() << "\"\n");
       result = pattern->matchAndRewrite(op, rewriter);
+      LLVM_DEBUG(llvm::dbgs() << "\"" << pattern->getDebugName() << "\" result "
+                              << succeeded(result) << "\n");
+
       if (succeeded(result) && onSuccess && failed(onSuccess(*pattern)))
         result = failure();
     }