NFC: Implement OwningRewritePatternList as a class instead of a using directive.
authorRiver Riddle <riverriddle@google.com>
Tue, 6 Aug 2019 01:37:56 +0000 (18:37 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 6 Aug 2019 01:38:22 +0000 (18:38 -0700)
This allows for proper forward declaration, as opposed to leaking the internal implementation via a using directive. This also allows for all pattern building to go through 'insert' methods on the OwningRewritePatternList, replacing uses of 'push_back' and 'RewriteListBuilder'.

PiperOrigin-RevId: 261816316

34 files changed:
mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h
mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg3/lib/Transforms.cpp
mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp
mlir/examples/toy/Ch5/mlir/LateLowering.cpp
mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Transforms/LowerAffine.h
mlir/lib/AffineOps/AffineOps.cpp
mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
mlir/lib/IR/PatternMatch.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
mlir/lib/StandardOps/Ops.cpp
mlir/lib/Transforms/DialectConversion.cpp
mlir/lib/Transforms/LowerAffine.cpp
mlir/lib/Transforms/LowerVectorTransfers.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/lib/TestDialect/TestPatterns.cpp
mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp

index 2d4a4a2..8a5eddd 100644 (file)
@@ -31,7 +31,7 @@ class MLIRContext;
 class ModuleOp;
 class RewritePattern;
 class Type;
-using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
+class OwningRewritePatternList;
 namespace LLVM {
 class LLVMType;
 } // end namespace LLVM
index 411a7af..58e6159 100644 (file)
@@ -395,8 +395,8 @@ public:
 
 void linalg::populateLinalg1ToLLVMConversionPatterns(
     mlir::OwningRewritePatternList &patterns, mlir::MLIRContext *context) {
-  RewriteListBuilder<DropConsumer, RangeOpConversion, SliceOpConversion,
-                     ViewOpConversion>::build(patterns, context);
+  patterns.insert<DropConsumer, RangeOpConversion, SliceOpConversion,
+                  ViewOpConversion>(context);
 }
 
 namespace {
index 8c77737..e4a401e 100644 (file)
@@ -145,8 +145,7 @@ struct LinalgTypeConverter : public LLVMTypeConverter {
 // coverters to the list.
 static void populateLinalg3ToLLVMConversionPatterns(
     mlir::OwningRewritePatternList &patterns, mlir::MLIRContext *context) {
-  RewriteListBuilder<LoadOpConversion, StoreOpConversion>::build(patterns,
-                                                                 context);
+  patterns.insert<LoadOpConversion, StoreOpConversion>(context);
 }
 
 LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) {
index d81eec0..8f97f43 100644 (file)
@@ -261,8 +261,8 @@ struct LowerLinalgLoadStorePass
   void runOnFunction() {
     OwningRewritePatternList patterns;
     auto *context = &getContext();
-    patterns.push_back(llvm::make_unique<Rewriter<linalg::LoadOp>>(context));
-    patterns.push_back(llvm::make_unique<Rewriter<linalg::StoreOp>>(context));
+    patterns.insert<Rewriter<linalg::LoadOp>, Rewriter<linalg::StoreOp>>(
+        context);
     applyPatternsGreedily(getFunction(), std::move(patterns));
   }
 };
index 92e80d2..b89cb85 100644 (file)
@@ -142,14 +142,14 @@ struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> {
 // Register our patterns for rewrite by the Canonicalization framework.
 void TransposeOp::getCanonicalizationPatterns(
     mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
-  results.push_back(llvm::make_unique<SimplifyRedundantTranspose>(context));
+  results.insert<SimplifyRedundantTranspose>(context);
 }
 
 // Register our patterns for rewrite by the Canonicalization framework.
 void ReshapeOp::getCanonicalizationPatterns(
     mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
-  mlir::RewriteListBuilder<SimplifyReshapeConstant, SimplifyReshapeReshape,
-                           SimplifyNullReshape>::build(results, context);
+  results.insert<SimplifyReshapeConstant, SimplifyReshapeReshape,
+                 SimplifyNullReshape>(context);
 }
 
 } // namespace toy
index f3463ba..72bc289 100644 (file)
@@ -132,7 +132,7 @@ struct EarlyLoweringPass : public FunctionPass<EarlyLoweringPass> {
     target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
 
     OwningRewritePatternList patterns;
-    RewriteListBuilder<MulOpConversion>::build(patterns, &getContext());
+    patterns.insert<MulOpConversion>(&getContext());
     if (failed(applyPartialConversion(getFunction(), target,
                                       std::move(patterns)))) {
       emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering Toy\n");
index 5a01122..8b2cc21 100644 (file)
@@ -352,9 +352,9 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
   void runOnModule() override {
     ToyTypeConverter typeConverter;
     OwningRewritePatternList toyPatterns;
-    RewriteListBuilder<AddOpConversion, PrintOpConversion, ConstantOpConversion,
-                       TransposeOpConversion,
-                       ReturnOpConversion>::build(toyPatterns, &getContext());
+    toyPatterns.insert<AddOpConversion, PrintOpConversion, ConstantOpConversion,
+                       TransposeOpConversion, ReturnOpConversion>(
+        &getContext());
     mlir::populateFuncOpTypeConversionPattern(toyPatterns, &getContext(),
                                               typeConverter);
 
index 8e9e8eb..4798ad1 100644 (file)
@@ -144,14 +144,14 @@ struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> {
 // Register our patterns for rewrite by the Canonicalization framework.
 void TransposeOp::getCanonicalizationPatterns(
     mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
-  results.push_back(llvm::make_unique<SimplifyRedundantTranspose>(context));
+  results.insert<SimplifyRedundantTranspose>(context);
 }
 
 // Register our patterns for rewrite by the Canonicalization framework.
 void ReshapeOp::getCanonicalizationPatterns(
     mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
-  mlir::RewriteListBuilder<SimplifyReshapeConstant, SimplifyReshapeReshape,
-                           SimplifyNullReshape>::build(results, context);
+  results.insert<SimplifyReshapeConstant, SimplifyReshapeReshape,
+                 SimplifyNullReshape>(context);
 }
 
 namespace {
@@ -180,7 +180,7 @@ struct SimplifyIdentityTypeCast : public mlir::OpRewritePattern<TypeCastOp> {
 
 void TypeCastOp::getCanonicalizationPatterns(
     mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
-  results.push_back(llvm::make_unique<SimplifyIdentityTypeCast>(context));
+  results.insert<SimplifyIdentityTypeCast>(context);
 }
 
 } // namespace toy
index e8ab273..78e4356 100644 (file)
@@ -29,7 +29,7 @@ class MLIRContext;
 class RewritePattern;
 
 // Owning list of rewriting patterns.
-using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
+class OwningRewritePatternList;
 
 /// Collect a set of patterns to lower from loop.for, loop.if, and
 /// loop.terminator to CFG operations within the Standard dialect, in particular
index 361294a..941e382 100644 (file)
@@ -38,7 +38,7 @@ class RewritePattern;
 class Type;
 
 // Owning list of rewriting patterns.
-using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
+class OwningRewritePatternList;
 
 /// Type for a callback constructing the owning list of patterns for the
 /// conversion to the LLVMIR dialect.  The callback is expected to append
index c76f1d6..204da29 100644 (file)
@@ -57,9 +57,7 @@ class Value;
 /// either OpTy or OperandAdaptor<OpTy> seamlessly.
 template <typename OpTy> using OperandAdaptor = typename OpTy::OperandAdaptor;
 
-/// This is a vector that owns the patterns inside of it.
-using OwningPatternList = std::vector<std::unique_ptr<Pattern>>;
-using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
+class OwningRewritePatternList;
 
 enum class OperationProperty {
   /// This bit is set for an operation if it is a commutative operation: that
index d739a80..e3897b1 100644 (file)
@@ -394,8 +394,39 @@ private:
 // Pattern-driven rewriters
 //===----------------------------------------------------------------------===//
 
-/// This is a vector that owns the patterns inside of it.
-using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
+class OwningRewritePatternList {
+  using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
+
+public:
+  PatternListT::iterator begin() { return patterns.begin(); }
+  PatternListT::iterator end() { return patterns.end(); }
+  PatternListT::const_iterator begin() const { return patterns.begin(); }
+  PatternListT::const_iterator end() const { return patterns.end(); }
+
+  //===--------------------------------------------------------------------===//
+  // Pattern Insertion
+  //===--------------------------------------------------------------------===//
+
+  void insert(RewritePattern *pattern) { patterns.emplace_back(pattern); }
+
+  /// Add an instance of each of the pattern types 'Ts' to the pattern list with
+  /// the given arguments.
+  // Note: ConstructorArg is necessary here to separate the two variadic lists.
+  template <typename... Ts, typename ConstructorArg,
+            typename... ConstructorArgs>
+  void insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
+    // The following expands a call to emplace_back for each of the pattern
+    // types 'Ts'. This magic is necessary due to a limitation in the places
+    // that a parameter pack can be expanded in c++11.
+    // FIXME: In c++17 this can be simplified by using 'fold expressions'.
+    using dummy = int[];
+    (void)dummy{
+        0, (patterns.emplace_back(llvm::make_unique<Ts>(arg, args...)), 0)...};
+  }
+
+private:
+  PatternListT patterns;
+};
 
 /// This class manages optimization and execution of a group of rewrite
 /// patterns, providing an API for finding and applying, the best match against
@@ -404,7 +435,7 @@ using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
 class RewritePatternMatcher {
 public:
   /// Create a RewritePatternMatcher with the specified set of patterns.
-  explicit RewritePatternMatcher(OwningRewritePatternList &&patterns);
+  explicit RewritePatternMatcher(OwningRewritePatternList &patterns);
 
   /// Try to match the given operation to a pattern and rewrite it. Return
   /// true if any pattern matches.
@@ -416,7 +447,7 @@ private:
 
   /// The group of patterns that are matched for optimization through this
   /// matcher.
-  OwningRewritePatternList patterns;
+  std::vector<RewritePattern *> patterns;
 };
 
 /// Rewrite the regions of the specified operation, which must be isolated from
@@ -427,29 +458,6 @@ private:
 ///
 bool applyPatternsGreedily(Operation *op, OwningRewritePatternList &&patterns);
 
-/// Helper class to create a list of rewrite patterns given a list of their
-/// types and a list of attributes perfect-forwarded to each of the conversion
-/// constructors.
-template <typename Arg, typename... Args> struct RewriteListBuilder {
-  template <typename... ConstructorArgs>
-  static void build(OwningRewritePatternList &patterns,
-                    ConstructorArgs &&... constructorArgs) {
-    RewriteListBuilder<Args...>::build(
-        patterns, std::forward<ConstructorArgs>(constructorArgs)...);
-    RewriteListBuilder<Arg>::build(
-        patterns, std::forward<ConstructorArgs>(constructorArgs)...);
-  }
-};
-
-// Template specialization to stop recursion.
-template <typename Arg> struct RewriteListBuilder<Arg> {
-  template <typename... ConstructorArgs>
-  static void build(OwningRewritePatternList &patterns,
-                    ConstructorArgs &&... constructorArgs) {
-    patterns.emplace_back(llvm::make_unique<Arg>(
-        std::forward<ConstructorArgs>(constructorArgs)...));
-  }
-};
 } // end namespace mlir
 
 #endif // MLIR_PATTERN_MATCH_H
index 9ad3f66..5fae476 100644 (file)
@@ -32,7 +32,7 @@ class RewritePattern;
 class Value;
 
 // Owning list of rewriting patterns.
-using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
+class OwningRewritePatternList;
 
 /// Emit code that computes the given affine expression using standard
 /// arithmetic operations applied to the provided dimension and symbol values.
index 9a02623..767c2e3 100644 (file)
@@ -708,7 +708,7 @@ struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> {
 
 void AffineApplyOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.push_back(llvm::make_unique<SimplifyAffineApply>(context));
+  results.insert<SimplifyAffineApply>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -912,8 +912,7 @@ LogicalResult AffineDmaStartOp::verify() {
 void AffineDmaStartOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   /// dma_start(memrefcast) -> dma_start
-  results.push_back(
-      llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
+  results.insert<MemRefCastFolder>(getOperationName(), context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -989,8 +988,7 @@ LogicalResult AffineDmaWaitOp::verify() {
 void AffineDmaWaitOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   /// dma_wait(memrefcast) -> dma_wait
-  results.push_back(
-      llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
+  results.insert<MemRefCastFolder>(getOperationName(), context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1333,7 +1331,7 @@ struct AffineForLoopBoundFolder : public OpRewritePattern<AffineForOp> {
 
 void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
-  results.push_back(llvm::make_unique<AffineForLoopBoundFolder>(context));
+  results.insert<AffineForLoopBoundFolder>(context);
 }
 
 AffineBound AffineForOp::getLowerBound() {
@@ -1659,8 +1657,7 @@ LogicalResult AffineLoadOp::verify() {
 void AffineLoadOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   /// load(memrefcast) -> load
-  results.push_back(
-      llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
+  results.insert<MemRefCastFolder>(getOperationName(), context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1752,8 +1749,7 @@ LogicalResult AffineStoreOp::verify() {
 void AffineStoreOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   /// load(memrefcast) -> load
-  results.push_back(
-      llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
+  results.insert<MemRefCastFolder>(getOperationName(), context);
 }
 
 #define GET_OP_CLASSES
index c37decf..034aa22 100644 (file)
@@ -258,8 +258,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
 
 void mlir::populateLoopToStdConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  RewriteListBuilder<ForLowering, IfLowering, TerminatorLowering>::build(
-      patterns, ctx);
+  patterns.insert<ForLowering, IfLowering, TerminatorLowering>(ctx);
 }
 
 void ControlFlowToCFGPass::runOnFunction() {
index 4eadb87..58f01fc 100644 (file)
@@ -104,8 +104,7 @@ void GPUToSPIRVPass::runOnModule() {
   SPIRVTypeConverter typeConverter(context);
   SPIRVEntryFnTypeConverter entryFnConverter(context);
   OwningRewritePatternList patterns;
-  RewriteListBuilder<KernelFnConversion>::build(
-      patterns, context, typeConverter, entryFnConverter);
+  patterns.insert<KernelFnConversion>(context, typeConverter, entryFnConverter);
   populateStandardToSPIRVPatterns(context, patterns);
 
   ConversionTarget target(*context);
index af8812c..09ddcd1 100644 (file)
@@ -1023,7 +1023,7 @@ void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) {
 void mlir::populateStdToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
   // FIXME: this should be tablegen'ed
-  RewriteListBuilder<
+  patterns.insert<
       AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
       BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
       CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering,
@@ -1032,8 +1032,7 @@ void mlir::populateStdToLLVMConversionPatterns(
       MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
       RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
       SelectOpLowering, SIToFPLowering, StoreOpLowering, SubFOpLowering,
-      SubIOpLowering, XOrOpLowering>::build(patterns, *converter.getDialect(),
-                                            converter);
+      SubIOpLowering, XOrOpLowering>(*converter.getDialect(), converter);
 }
 
 // Convert types using the stored LLVM IR module.
index d32d866..067f2ae 100644 (file)
@@ -201,6 +201,6 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
                                      OwningRewritePatternList &patterns) {
   populateWithGenerated(context, &patterns);
   // Add the return op conversion.
-  RewriteListBuilder<ReturnToSPIRVConversion>::build(patterns, context);
+  patterns.insert<ReturnToSPIRVConversion>(context);
 }
 } // namespace mlir
index dafc8e7..d2f3881 100644 (file)
@@ -368,8 +368,7 @@ void LowerUniformRealMathPass::runOnFunction() {
   auto fn = getFunction();
   OwningRewritePatternList patterns;
   auto *context = &getContext();
-  patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
-  patterns.push_back(llvm::make_unique<UniformRealMulEwPattern>(context));
+  patterns.insert<UniformRealAddEwPattern, UniformRealMulEwPattern>(context);
   applyPatternsGreedily(fn, std::move(patterns));
 }
 
@@ -389,7 +388,7 @@ void LowerUniformCastsPass::runOnFunction() {
   auto fn = getFunction();
   OwningRewritePatternList patterns;
   auto *context = &getContext();
-  patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context));
+  patterns.insert<UniformDequantizePattern>(context);
   applyPatternsGreedily(fn, std::move(patterns));
 }
 
index bda5979..2fbaa49 100644 (file)
@@ -372,7 +372,7 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
 
 void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
-  RewriteListBuilder<PropagateConstantBounds>::build(results, context);
+  results.insert<PropagateConstantBounds>(context);
 }
 
 //===----------------------------------------------------------------------===//
index e237e8b..3bd49d4 100644 (file)
@@ -60,8 +60,7 @@ public:
 
 void StorageCastOp::getCanonicalizationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.push_back(
-      llvm::make_unique<RemoveRedundantStorageCastsRewrite>(context));
+  patterns.insert<RemoveRedundantStorageCastsRewrite>(context);
 }
 
 QuantizationDialect::QuantizationDialect(MLIRContext *context)
index 8469fa2..2276fbd 100644 (file)
@@ -108,7 +108,7 @@ void ConvertConstPass::runOnFunction() {
   OwningRewritePatternList patterns;
   auto func = getFunction();
   auto *context = &getContext();
-  patterns.push_back(llvm::make_unique<QuantizedConstRewrite>(context));
+  patterns.insert<QuantizedConstRewrite>(context);
   applyPatternsGreedily(func, std::move(patterns));
 }
 
index 32d8c8a..8f5d1b3 100644 (file)
@@ -97,8 +97,7 @@ void ConvertSimulatedQuantPass::runOnFunction() {
   OwningRewritePatternList patterns;
   auto func = getFunction();
   auto *context = &getContext();
-  patterns.push_back(
-      llvm::make_unique<ConstFakeQuantRewrite>(context, &hadFailure));
+  patterns.insert<ConstFakeQuantRewrite>(context, &hadFailure);
   applyPatternsGreedily(func, std::move(patterns));
   if (hadFailure)
     signalPassFailure();
index 5010b84..94fa7ab 100644 (file)
@@ -149,12 +149,13 @@ void PatternRewriter::updatedRootInPlace(
 //===----------------------------------------------------------------------===//
 
 RewritePatternMatcher::RewritePatternMatcher(
-    OwningRewritePatternList &&patterns)
-    : patterns(std::move(patterns)) {
+    OwningRewritePatternList &patterns) {
+  for (auto &pattern : patterns)
+    this->patterns.push_back(pattern.get());
+
   // Sort the patterns by benefit to simplify the matching logic.
   std::stable_sort(this->patterns.begin(), this->patterns.end(),
-                   [](const std::unique_ptr<RewritePattern> &l,
-                      const std::unique_ptr<RewritePattern> &r) {
+                   [](RewritePattern *l, RewritePattern *r) {
                      return r->getBenefit() < l->getBenefit();
                    });
 }
@@ -162,7 +163,7 @@ RewritePatternMatcher::RewritePatternMatcher(
 /// Try to match the given operation to a pattern and rewrite it.
 bool RewritePatternMatcher::matchAndRewrite(Operation *op,
                                             PatternRewriter &rewriter) {
-  for (auto &pattern : patterns) {
+  for (auto *pattern : patterns) {
     // Ignore patterns that are for the wrong root or are impossible to match.
     if (pattern->getRootKind() != op->getName() ||
         pattern->getBenefit().isImpossibleToMatch())
index 6b62a8e..7c2ea59 100644 (file)
@@ -678,12 +678,11 @@ static void
 populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
                                        OwningRewritePatternList &patterns,
                                        MLIRContext *ctx) {
-  RewriteListBuilder<BufferAllocOpConversion, BufferDeallocOpConversion,
-                     BufferSizeOpConversion, DimOpConversion,
-                     LinalgOpConversion<DotOp>, LinalgOpConversion<MatmulOp>,
-                     LoadOpConversion, RangeOpConversion, SliceOpConversion,
-                     StoreOpConversion, ViewOpConversion>::build(patterns, ctx,
-                                                                 converter);
+  patterns.insert<BufferAllocOpConversion, BufferDeallocOpConversion,
+                  BufferSizeOpConversion, DimOpConversion,
+                  LinalgOpConversion<DotOp>, LinalgOpConversion<MatmulOp>,
+                  LoadOpConversion, RangeOpConversion, SliceOpConversion,
+                  StoreOpConversion, ViewOpConversion>(ctx, converter);
 }
 
 namespace {
index 6b376db..3de8913 100644 (file)
@@ -60,12 +60,9 @@ void RemoveInstrumentationPass::runOnFunction() {
   OwningRewritePatternList patterns;
   auto func = getFunction();
   auto *context = &getContext();
-  patterns.push_back(
-      llvm::make_unique<RemoveIdentityOpRewrite<StatisticsOp>>(context));
-  patterns.push_back(
-      llvm::make_unique<RemoveIdentityOpRewrite<StatisticsRefOp>>(context));
-  patterns.push_back(
-      llvm::make_unique<RemoveIdentityOpRewrite<CoupledRefOp>>(context));
+  patterns.insert<RemoveIdentityOpRewrite<StatisticsOp>,
+                  RemoveIdentityOpRewrite<StatisticsRefOp>,
+                  RemoveIdentityOpRewrite<CoupledRefOp>>(context);
   applyPatternsGreedily(func, std::move(patterns));
 }
 
index df99f00..9ecd99a 100644 (file)
@@ -365,8 +365,7 @@ struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
 
 void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
-  RewriteListBuilder<SimplifyAllocConst, SimplifyDeadAlloc>::build(results,
-                                                                   context);
+  results.insert<SimplifyAllocConst, SimplifyDeadAlloc>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -544,8 +543,7 @@ static LogicalResult verify(CallIndirectOp op) {
 
 void CallIndirectOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.push_back(
-      llvm::make_unique<SimplifyIndirectCallWithKnownCallee>(context));
+  results.insert<SimplifyIndirectCallWithKnownCallee>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1015,7 +1013,7 @@ static void print(OpAsmPrinter *p, CondBranchOp op) {
 
 void CondBranchOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.push_back(llvm::make_unique<SimplifyConstCondBranchPred>(context));
+  results.insert<SimplifyConstCondBranchPred>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1231,9 +1229,8 @@ static LogicalResult verify(DeallocOp op) {
 void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
   /// dealloc(memrefcast) -> dealloc
-  results.push_back(
-      llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
-  results.push_back(llvm::make_unique<SimplifyDeadDealloc>(context));
+  results.insert<MemRefCastFolder>(getOperationName(), context);
+  results.insert<SimplifyDeadDealloc>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1497,8 +1494,7 @@ LogicalResult DmaStartOp::verify() {
 void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
   /// dma_start(memrefcast) -> dma_start
-  results.push_back(
-      llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
+  results.insert<MemRefCastFolder>(getOperationName(), context);
 }
 
 // ---------------------------------------------------------------------------
@@ -1561,8 +1557,7 @@ ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
 void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
   /// dma_wait(memrefcast) -> dma_wait
-  results.push_back(
-      llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
+  results.insert<MemRefCastFolder>(getOperationName(), context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1695,8 +1690,7 @@ static LogicalResult verify(LoadOp op) {
 void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
   /// load(memrefcast) -> load
-  results.push_back(
-      llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
+  results.insert<MemRefCastFolder>(getOperationName(), context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2007,8 +2001,7 @@ static LogicalResult verify(StoreOp op) {
 void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
   /// store(memrefcast) -> store
-  results.push_back(
-      llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
+  results.insert<MemRefCastFolder>(getOperationName(), context);
 }
 
 //===----------------------------------------------------------------------===//
index 50c636f..6f264b0 100644 (file)
@@ -1243,8 +1243,7 @@ struct FuncOpSignatureConversion : public ConversionPattern {
 void mlir::populateFuncOpTypeConversionPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx,
     TypeConverter &converter) {
-  RewriteListBuilder<FuncOpSignatureConversion>::build(patterns, ctx,
-                                                       converter);
+  patterns.insert<FuncOpSignatureConversion>(ctx, converter);
 }
 
 /// This function converts the type signature of the given block, by invoking
index f35f963..1c558ef 100644 (file)
@@ -507,10 +507,11 @@ public:
 
 void mlir::populateAffineToStdConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
-                     AffineDmaWaitLowering, AffineLoadLowering,
-                     AffineStoreLowering, AffineForLowering, AffineIfLowering,
-                     AffineTerminatorLowering>::build(patterns, ctx);
+  patterns
+      .insert<AffineApplyLowering, AffineDmaStartLowering,
+              AffineDmaWaitLowering, AffineLoadLowering, AffineStoreLowering,
+              AffineForLowering, AffineIfLowering, AffineTerminatorLowering>(
+          ctx);
 }
 
 namespace {
index 3585e2b..ef67488 100644 (file)
@@ -365,12 +365,8 @@ struct LowerVectorTransfersPass
   void runOnFunction() {
     OwningRewritePatternList patterns;
     auto *context = &getContext();
-    patterns.push_back(
-        llvm::make_unique<VectorTransferRewriter<VectorTransferReadOp>>(
-            context));
-    patterns.push_back(
-        llvm::make_unique<VectorTransferRewriter<VectorTransferWriteOp>>(
-            context));
+    patterns.insert<VectorTransferRewriter<VectorTransferReadOp>,
+                    VectorTransferRewriter<VectorTransferWriteOp>>(context);
     applyPatternsGreedily(getFunction(), std::move(patterns));
   }
 };
index 5295217..1df4cee 100644 (file)
@@ -44,8 +44,8 @@ namespace {
 class GreedyPatternRewriteDriver : public PatternRewriter {
 public:
   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
-                                      OwningRewritePatternList &&patterns)
-      : PatternRewriter(ctx), matcher(std::move(patterns)) {
+                                      OwningRewritePatternList &patterns)
+      : PatternRewriter(ctx), matcher(patterns) {
     worklist.reserve(64);
   }
 
@@ -224,7 +224,7 @@ bool mlir::applyPatternsGreedily(Operation *op,
   if (!op->isKnownIsolatedFromAbove())
     return false;
 
-  GreedyPatternRewriteDriver driver(op->getContext(), std::move(patterns));
+  GreedyPatternRewriteDriver driver(op->getContext(), patterns);
   bool converged = driver.simplify(op, maxPatternMatchIterations);
   LLVM_DEBUG(if (!converged) {
     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
index 201dfc3..ed94eed 100644 (file)
@@ -41,7 +41,7 @@ struct TestPatternDriver : public FunctionPass<TestPatternDriver> {
     populateWithGenerated(&getContext(), &patterns);
 
     // Verify named pattern is generated with expected name.
-    RewriteListBuilder<TestNamedPatternRule>::build(patterns, &getContext());
+    patterns.insert<TestNamedPatternRule>(&getContext());
 
     applyPatternsGreedily(getFunction(), std::move(patterns));
   }
@@ -193,9 +193,9 @@ struct TestLegalizePatternDriver
     TestTypeConverter converter;
     mlir::OwningRewritePatternList patterns;
     populateWithGenerated(&getContext(), &patterns);
-    RewriteListBuilder<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
-                       TestDropOp, TestPassthroughInvalidOp,
-                       TestSplitReturnType>::build(patterns, &getContext());
+    patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
+                    TestDropOp, TestPassthroughInvalidOp, TestSplitReturnType>(
+        &getContext());
     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                               converter);
 
index edf6aea..f75413f 100644 (file)
@@ -133,7 +133,7 @@ static LogicalResult runMLIRPasses(ModuleOp m) {
   pm.addPass(createConvertToLLVMIRPass([](LLVMTypeConverter &converter,
                                           OwningRewritePatternList &patterns) {
     populateStdToLLVMConversionPatterns(converter, patterns);
-    patterns.push_back(llvm::make_unique<GPULaunchFuncOpLowering>(converter));
+    patterns.insert<GPULaunchFuncOpLowering>(converter);
   }));
   pm.addPass(createLowerGpuOpsToNVVMOpsPass());
   pm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin));
index d408ecf..24eeaf5 100644 (file)
@@ -935,8 +935,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
   os << "void populateWithGenerated(MLIRContext *context, "
      << "OwningRewritePatternList *patterns) {\n";
   for (const auto &name : rewriterNames) {
-    os << "  patterns->push_back(llvm::make_unique<" << name
-       << ">(context));\n";
+    os << "  patterns->insert<" << name << ">(context);\n";
   }
   os << "}\n";
 }