Move the ConversionListBuilder utility to PatternMatch.h and rename it to Rewrite...
authorRiver Riddle <riverriddle@google.com>
Sat, 18 May 2019 20:23:38 +0000 (13:23 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 20 May 2019 20:47:28 +0000 (13:47 -0700)
--

PiperOrigin-RevId: 248884466

13 files changed:
mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp
mlir/examples/toy/Ch5/mlir/LateLowering.cpp
mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
mlir/g3doc/Tutorials/Linalg/LLVMConversion.md
mlir/g3doc/Tutorials/Toy/Ch-5.md
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/StandardOps/Ops.cpp

index 3027d92..0b64029 100644 (file)
@@ -397,8 +397,8 @@ public:
 
 void linalg::getDescriptorConverters(mlir::OwningRewritePatternList &patterns,
                                      mlir::MLIRContext *context) {
-  ConversionListBuilder<DropConsumer, RangeOpConversion, SliceOpConversion,
-                        ViewOpConversion>::build(patterns, context);
+  RewriteListBuilder<DropConsumer, RangeOpConversion, SliceOpConversion,
+                     ViewOpConversion>::build(patterns, context);
 }
 
 namespace {
index 059ef32..8d945f3 100644 (file)
@@ -138,8 +138,8 @@ class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
 static void getConversions(mlir::OwningRewritePatternList &patterns,
                            mlir::MLIRContext *context) {
   linalg::getDescriptorConverters(patterns, context);
-  ConversionListBuilder<LoadOpConversion, StoreOpConversion>::build(patterns,
-                                                                    context);
+  RewriteListBuilder<LoadOpConversion, StoreOpConversion>::build(patterns,
+                                                                 context);
 }
 
 void linalg::convertLinalg3ToLLVM(Module &module) {
index 4175fc2..37aa47f 100644 (file)
@@ -161,9 +161,8 @@ void TransposeOp::getCanonicalizationPatterns(
 // Register our patterns for rewrite by the Canonicalization framework.
 void ReshapeOp::getCanonicalizationPatterns(
     mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
-  results.push_back(llvm::make_unique<SimplifyReshapeConstant>(context));
-  results.push_back(llvm::make_unique<SimplifyReshapeReshape>(context));
-  results.push_back(llvm::make_unique<SimplifyNullReshape>(context));
+  mlir::RewriteListBuilder<SimplifyReshapeConstant, SimplifyReshapeReshape,
+                           SimplifyNullReshape>::build(results, context);
 }
 
 } // namespace toy
index 3509f8b..093a595 100644 (file)
@@ -125,7 +125,7 @@ protected:
   // Initialize the list of converters.
   void initConverters(OwningRewritePatternList &patterns,
                       MLIRContext *context) override {
-    ConversionListBuilder<MulOpConversion>::build(patterns, context);
+    RewriteListBuilder<MulOpConversion>::build(patterns, context);
   }
 };
 
index b54c8dd..4cb34d1 100644 (file)
@@ -326,9 +326,9 @@ protected:
   /// Initialize the list of converters.
   void initConverters(OwningRewritePatternList &patterns,
                       MLIRContext *context) override {
-    ConversionListBuilder<AddOpConversion, PrintOpConversion,
-                          ConstantOpConversion, TransposeOpConversion,
-                          ReturnOpConversion>::build(patterns, context);
+    RewriteListBuilder<AddOpConversion, PrintOpConversion, ConstantOpConversion,
+                       TransposeOpConversion,
+                       ReturnOpConversion>::build(patterns, context);
   }
 
   /// Convert a Toy type, this gets called for block and region arguments, and
index 2330447..6e05eaf 100644 (file)
@@ -169,9 +169,8 @@ void TransposeOp::getCanonicalizationPatterns(
 // Register our patterns for rewrite by the Canonicalization framework.
 void ReshapeOp::getCanonicalizationPatterns(
     mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
-  results.push_back(llvm::make_unique<SimplifyReshapeConstant>(context));
-  results.push_back(llvm::make_unique<SimplifyReshapeReshape>(context));
-  results.push_back(llvm::make_unique<SimplifyNullReshape>(context));
+  mlir::RewriteListBuilder<SimplifyReshapeConstant, SimplifyReshapeReshape,
+                           SimplifyNullReshape>::build(results, context);
 }
 
 namespace {
index 22797e4..72ca5a7 100644 (file)
@@ -628,7 +628,7 @@ protected:
     // conversion instances given a list of classes as template parameters.
     // These instances will be allocated within `allocator` and their lifetime
     // is managed by the Lowering class.
-    return ConversionListBuilder<
+    return RewriteListBuilder<
         LoadOpConversion, SliceOpConversion, StoreOpConversion,
         ViewOpConversion>::build(allocator, context);
   }
index 5b69bdd..8124c79 100644 (file)
@@ -51,9 +51,8 @@ public:
   // This gets called once to set up operation converters.
   llvm::DenseSet<DialectOpConversion *>
   initConverters(MLIRContext *context) override {
-    return ConversionListBuilder<MulOpConversion,
-                                 PrintOpConversion,
-                                 TransposeOpConversion>::build(allocator, context);
+    RewriteListBuilder<MulOpConversion, PrintOpConversion,
+                       TransposeOpConversion>::build(allocator, context);
   }
 
 private:
index 4ce0eb0..784674b 100644 (file)
@@ -348,6 +348,29 @@ private:
 ///
 bool applyPatternsGreedily(Function &fn, OwningRewritePatternList &&patterns);
 
+/// Helper class to create a list of rewrite patterns given a list of their
+/// types and a list of attributes perfect-forwarded to each of the conversion
+/// constructors.
+template <typename Arg, typename... Args> struct RewriteListBuilder {
+  template <typename... ConstructorArgs>
+  static void build(OwningRewritePatternList &patterns,
+                    ConstructorArgs &&... constructorArgs) {
+    RewriteListBuilder<Args...>::build(
+        patterns, std::forward<ConstructorArgs>(constructorArgs)...);
+    RewriteListBuilder<Arg>::build(
+        patterns, std::forward<ConstructorArgs>(constructorArgs)...);
+  }
+};
+
+// Template specialization to stop recursion.
+template <typename Arg> struct RewriteListBuilder<Arg> {
+  template <typename... ConstructorArgs>
+  static void build(OwningRewritePatternList &patterns,
+                    ConstructorArgs &&... constructorArgs) {
+    patterns.emplace_back(llvm::make_unique<Arg>(
+        std::forward<ConstructorArgs>(constructorArgs)...));
+  }
+};
 } // end namespace mlir
 
 #endif // MLIR_PATTERN_MATCH_H
index 107db47..0790e45 100644 (file)
@@ -103,30 +103,6 @@ private:
   using RewritePattern::rewrite;
 };
 
-// Helper class to create a list of dialect conversion patterns given a list of
-// their types and a list of attributes perfect-forwarded to each of the
-// conversion constructors.
-template <typename Arg, typename... Args> struct ConversionListBuilder {
-  template <typename... ConstructorArgs>
-  static void build(OwningRewritePatternList &patterns,
-                    ConstructorArgs &&... constructorArgs) {
-    ConversionListBuilder<Args...>::build(
-        patterns, std::forward<ConstructorArgs>(constructorArgs)...);
-    ConversionListBuilder<Arg>::build(
-        patterns, std::forward<ConstructorArgs>(constructorArgs)...);
-  }
-};
-
-// Template specialization to stop recursion.
-template <typename Arg> struct ConversionListBuilder<Arg> {
-  template <typename... ConstructorArgs>
-  static void build(OwningRewritePatternList &patterns,
-                    ConstructorArgs &&... constructorArgs) {
-    patterns.emplace_back(llvm::make_unique<Arg>(
-        std::forward<ConstructorArgs>(constructorArgs)...));
-  }
-};
-
 /// Base class for dialect conversion interface.  Specific converters must
 /// derive this class and implement the pure virtual functions.
 ///
index 9758f98..6de9a15 100644 (file)
@@ -941,7 +941,7 @@ void LLVMLowering::initConverters(OwningRewritePatternList &patterns,
   module = &llvmDialect->getLLVMModule();
 
   // FIXME: this should be tablegen'ed
-  ConversionListBuilder<
+  RewriteListBuilder<
       AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
       BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
       CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering,
index 4a07b29..2ecea9c 100644 (file)
@@ -588,12 +588,14 @@ namespace {
 class Lowering : public LLVMLowering {
 protected:
   void initAdditionalConverters(OwningRewritePatternList &patterns) override {
-    return ConversionListBuilder<
-        BufferAllocOpConversion, BufferDeallocOpConversion,
-        BufferSizeOpConversion, DimOpConversion, DotOpConversion,
-        LoadOpConversion, RangeOpConversion, RangeIntersectOpConversion,
-        SliceOpConversion, StoreOpConversion,
-        ViewOpConversion>::build(patterns, llvmDialect->getContext(), *this);
+    RewriteListBuilder<BufferAllocOpConversion, BufferDeallocOpConversion,
+                       BufferSizeOpConversion, DimOpConversion, DotOpConversion,
+                       LoadOpConversion, RangeOpConversion,
+                       RangeIntersectOpConversion, SliceOpConversion,
+                       StoreOpConversion,
+                       ViewOpConversion>::build(patterns,
+                                                llvmDialect->getContext(),
+                                                *this);
   }
 
   Type convertAdditionalType(Type t) override {
index aabaed5..e2989d1 100644 (file)
@@ -378,8 +378,8 @@ struct SimplifyDeadAlloc : public RewritePattern {
 
 void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
-  results.push_back(llvm::make_unique<SimplifyAllocConst>(context));
-  results.push_back(llvm::make_unique<SimplifyDeadAlloc>(context));
+  RewriteListBuilder<SimplifyAllocConst, SimplifyDeadAlloc>::build(results,
+                                                                   context);
 }
 
 //===----------------------------------------------------------------------===//