From: River Riddle Date: Mon, 14 Dec 2020 20:32:21 +0000 (-0800) Subject: [mlir] Change the internal representation of FrozenRewritePatternList to use shared_ptr X-Git-Tag: llvmorg-13-init~3457 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=6af2c4ca9bdb37e56cfda8dae4f6c3c6ca21b8d7;p=platform%2Fupstream%2Fllvm.git [mlir] Change the internal representation of FrozenRewritePatternList to use shared_ptr This will allow for caching pattern lists across multiple pass instances, such as when multithreading. This is an extremely important invariant for PDL patterns, which are compiled at runtime when the FrozenRewritePatternList is built. Differential Revision: https://reviews.llvm.org/D93146 --- diff --git a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h index c2335b9..0e583aa 100644 --- a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h +++ b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h @@ -18,34 +18,52 @@ class PDLByteCode; /// This class represents a frozen set of patterns that can be processed by a /// pattern applicator. This class is designed to enable caching pattern lists -/// such that they need not be continuously recomputed. +/// such that they need not be continuously recomputed. Note that all copies of +/// this class share the same compiled pattern list, allowing for a reduction in +/// the number of duplicated patterns that need to be created. class FrozenRewritePatternList { using NativePatternListT = std::vector>; public: /// Freeze the patterns held in `patterns`, and take ownership. + FrozenRewritePatternList(); FrozenRewritePatternList(OwningRewritePatternList &&patterns); - FrozenRewritePatternList(FrozenRewritePatternList &&patterns); + FrozenRewritePatternList(FrozenRewritePatternList &&patterns) = default; + FrozenRewritePatternList(const FrozenRewritePatternList &patterns) = default; + FrozenRewritePatternList & + operator=(const FrozenRewritePatternList &patterns) = default; + FrozenRewritePatternList & + operator=(FrozenRewritePatternList &&patterns) = default; ~FrozenRewritePatternList(); /// Return the native patterns held by this list. iterator_range> getNativePatterns() const { + const NativePatternListT &nativePatterns = impl->nativePatterns; return llvm::make_pointee_range(nativePatterns); } /// Return the compiled PDL bytecode held by this list. Returns null if /// there are no PDL patterns within the list. const detail::PDLByteCode *getPDLByteCode() const { - return pdlByteCode.get(); + return impl->pdlByteCode.get(); } private: - /// The set of. - std::vector> nativePatterns; + /// The internal implementation of the frozen pattern list. + struct Impl { + /// The set of native C++ rewrite patterns. + NativePatternListT nativePatterns; - /// The bytecode containing the compiled PDL patterns. - std::unique_ptr pdlByteCode; + /// The bytecode containing the compiled PDL patterns. + std::unique_ptr pdlByteCode; + }; + + /// A pointer to the internal pattern list. This uses a shared_ptr to avoid + /// the need to compile the same pattern list multiple times. For example, + /// during multi-threaded pass execution, all copies of a pass can share the + /// same pattern list. + std::shared_ptr impl; }; } // end namespace mlir diff --git a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp index 60f6dce..40d7fcd 100644 --- a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp +++ b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp @@ -50,12 +50,16 @@ static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) { // FrozenRewritePatternList //===----------------------------------------------------------------------===// +FrozenRewritePatternList::FrozenRewritePatternList() + : impl(std::make_shared()) {} + FrozenRewritePatternList::FrozenRewritePatternList( OwningRewritePatternList &&patterns) - : nativePatterns(std::move(patterns.getNativePatterns())) { - PDLPatternModule &pdlPatterns = patterns.getPDLPatterns(); + : impl(std::make_shared()) { + impl->nativePatterns = std::move(patterns.getNativePatterns()); // Generate the bytecode for the PDL patterns if any were provided. + PDLPatternModule &pdlPatterns = patterns.getPDLPatterns(); ModuleOp pdlModule = pdlPatterns.getModule(); if (!pdlModule) return; @@ -64,14 +68,9 @@ FrozenRewritePatternList::FrozenRewritePatternList( "failed to lower PDL pattern module to the PDL Interpreter"); // Generate the pdl bytecode. - pdlByteCode = std::make_unique( + impl->pdlByteCode = std::make_unique( pdlModule, pdlPatterns.takeConstraintFunctions(), pdlPatterns.takeCreateFunctions(), pdlPatterns.takeRewriteFunctions()); } -FrozenRewritePatternList::FrozenRewritePatternList( - FrozenRewritePatternList &&patterns) - : nativePatterns(std::move(patterns.nativePatterns)), - pdlByteCode(std::move(patterns.pdlByteCode)) {} - FrozenRewritePatternList::~FrozenRewritePatternList() {}