[mlir] Change the internal representation of FrozenRewritePatternList to use shared_ptr
authorRiver Riddle <riddleriver@gmail.com>
Mon, 14 Dec 2020 20:32:21 +0000 (12:32 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Mon, 14 Dec 2020 20:32:44 +0000 (12:32 -0800)
This will allow for caching pattern lists across multiple pass instances, such as when multithreading. This is an extremely important invariant for PDL patterns, which are compiled at runtime when the FrozenRewritePatternList is built.

Differential Revision: https://reviews.llvm.org/D93146

mlir/include/mlir/Rewrite/FrozenRewritePatternList.h
mlir/lib/Rewrite/FrozenRewritePatternList.cpp

index c2335b9..0e583aa 100644 (file)
@@ -18,34 +18,52 @@ class PDLByteCode;
 
 /// This class represents a frozen set of patterns that can be processed by a
 /// pattern applicator. This class is designed to enable caching pattern lists
-/// such that they need not be continuously recomputed.
+/// such that they need not be continuously recomputed. Note that all copies of
+/// this class share the same compiled pattern list, allowing for a reduction in
+/// the number of duplicated patterns that need to be created.
 class FrozenRewritePatternList {
   using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
 
 public:
   /// Freeze the patterns held in `patterns`, and take ownership.
+  FrozenRewritePatternList();
   FrozenRewritePatternList(OwningRewritePatternList &&patterns);
-  FrozenRewritePatternList(FrozenRewritePatternList &&patterns);
+  FrozenRewritePatternList(FrozenRewritePatternList &&patterns) = default;
+  FrozenRewritePatternList(const FrozenRewritePatternList &patterns) = default;
+  FrozenRewritePatternList &
+  operator=(const FrozenRewritePatternList &patterns) = default;
+  FrozenRewritePatternList &
+  operator=(FrozenRewritePatternList &&patterns) = default;
   ~FrozenRewritePatternList();
 
   /// Return the native patterns held by this list.
   iterator_range<llvm::pointee_iterator<NativePatternListT::const_iterator>>
   getNativePatterns() const {
+    const NativePatternListT &nativePatterns = impl->nativePatterns;
     return llvm::make_pointee_range(nativePatterns);
   }
 
   /// Return the compiled PDL bytecode held by this list. Returns null if
   /// there are no PDL patterns within the list.
   const detail::PDLByteCode *getPDLByteCode() const {
-    return pdlByteCode.get();
+    return impl->pdlByteCode.get();
   }
 
 private:
-  /// The set of.
-  std::vector<std::unique_ptr<RewritePattern>> nativePatterns;
+  /// The internal implementation of the frozen pattern list.
+  struct Impl {
+    /// The set of native C++ rewrite patterns.
+    NativePatternListT nativePatterns;
 
-  /// The bytecode containing the compiled PDL patterns.
-  std::unique_ptr<detail::PDLByteCode> pdlByteCode;
+    /// The bytecode containing the compiled PDL patterns.
+    std::unique_ptr<detail::PDLByteCode> pdlByteCode;
+  };
+
+  /// A pointer to the internal pattern list. This uses a shared_ptr to avoid
+  /// the need to compile the same pattern list multiple times. For example,
+  /// during multi-threaded pass execution, all copies of a pass can share the
+  /// same pattern list.
+  std::shared_ptr<Impl> impl;
 };
 
 } // end namespace mlir
index 60f6dce..40d7fcd 100644 (file)
@@ -50,12 +50,16 @@ static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
 // FrozenRewritePatternList
 //===----------------------------------------------------------------------===//
 
+FrozenRewritePatternList::FrozenRewritePatternList()
+    : impl(std::make_shared<Impl>()) {}
+
 FrozenRewritePatternList::FrozenRewritePatternList(
     OwningRewritePatternList &&patterns)
-    : nativePatterns(std::move(patterns.getNativePatterns())) {
-  PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
+    : impl(std::make_shared<Impl>()) {
+  impl->nativePatterns = std::move(patterns.getNativePatterns());
 
   // Generate the bytecode for the PDL patterns if any were provided.
+  PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
   ModuleOp pdlModule = pdlPatterns.getModule();
   if (!pdlModule)
     return;
@@ -64,14 +68,9 @@ FrozenRewritePatternList::FrozenRewritePatternList(
         "failed to lower PDL pattern module to the PDL Interpreter");
 
   // Generate the pdl bytecode.
-  pdlByteCode = std::make_unique<detail::PDLByteCode>(
+  impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
       pdlModule, pdlPatterns.takeConstraintFunctions(),
       pdlPatterns.takeCreateFunctions(), pdlPatterns.takeRewriteFunctions());
 }
 
-FrozenRewritePatternList::FrozenRewritePatternList(
-    FrozenRewritePatternList &&patterns)
-    : nativePatterns(std::move(patterns.nativePatterns)),
-      pdlByteCode(std::move(patterns.pdlByteCode)) {}
-
 FrozenRewritePatternList::~FrozenRewritePatternList() {}