NFC: Tidy up DialectConversion.cpp and rename DialectOpConversion to DialectConv...
authorRiver Riddle <riverriddle@google.com>
Mon, 20 May 2019 03:54:13 +0000 (20:54 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 20 May 2019 20:48:19 +0000 (13:48 -0700)
--

PiperOrigin-RevId: 248980810

mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h
mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp
mlir/examples/toy/Ch5/mlir/LateLowering.cpp
mlir/g3doc/Tutorials/Linalg/LLVMConversion.md
mlir/g3doc/Tutorials/Toy/Ch-5.md
mlir/include/mlir/LLVMIR/LLVMLowering.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
mlir/lib/Transforms/DialectConversion.cpp

index 2abe6ba..cf5d7ef 100644 (file)
@@ -25,7 +25,7 @@
 
 namespace mlir {
 class DialectConversion;
-class DialectOpConversion;
+class DialectConversionPattern;
 class MLIRContext;
 class Module;
 class RewritePattern;
index 0b8d6bd..a23828d 100644 (file)
@@ -142,10 +142,11 @@ static ArrayAttr makePositionAttr(FuncBuilder &builder,
 }
 
 // RangeOp creates a new range descriptor.
-class RangeOpConversion : public DialectOpConversion {
+class RangeOpConversion : public DialectConversionPattern {
 public:
   explicit RangeOpConversion(MLIRContext *context)
-      : DialectOpConversion(linalg::RangeOp::getOperationName(), 1, context) {}
+      : DialectConversionPattern(linalg::RangeOp::getOperationName(), 1,
+                                 context) {}
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
                PatternRewriter &rewriter) const override {
@@ -168,10 +169,11 @@ public:
   }
 };
 
-class ViewOpConversion : public DialectOpConversion {
+class ViewOpConversion : public DialectConversionPattern {
 public:
   explicit ViewOpConversion(MLIRContext *context)
-      : DialectOpConversion(linalg::ViewOp::getOperationName(), 1, context) {}
+      : DialectConversionPattern(linalg::ViewOp::getOperationName(), 1,
+                                 context) {}
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
                PatternRewriter &rewriter) const override {
@@ -294,10 +296,11 @@ public:
   }
 };
 
-class SliceOpConversion : public DialectOpConversion {
+class SliceOpConversion : public DialectConversionPattern {
 public:
   explicit SliceOpConversion(MLIRContext *context)
-      : DialectOpConversion(linalg::SliceOp::getOperationName(), 1, context) {}
+      : DialectConversionPattern(linalg::SliceOp::getOperationName(), 1,
+                                 context) {}
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
                PatternRewriter &rewriter) const override {
@@ -387,10 +390,10 @@ public:
 
 // When converting the "some_consumer" operation, don't emit anything and
 // effectively drop it.
-class DropConsumer : public DialectOpConversion {
+class DropConsumer : public DialectConversionPattern {
 public:
   explicit DropConsumer(MLIRContext *context)
-      : DialectOpConversion("some_consumer", 1, context) {}
+      : DialectConversionPattern("some_consumer", 1, context) {}
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
                PatternRewriter &rewriter) const override {
index b475ed2..98bd867 100644 (file)
@@ -52,10 +52,10 @@ namespace {
 // Common functionality for Linalg LoadOp and StoreOp conversion to the
 // LLVM IR Dialect.
 template <typename Op>
-class LoadStoreOpConversion : public DialectOpConversion {
+class LoadStoreOpConversion : public DialectConversionPattern {
 public:
   explicit LoadStoreOpConversion(MLIRContext *context)
-      : DialectOpConversion(Op::getOperationName(), 1, context) {}
+      : DialectConversionPattern(Op::getOperationName(), 1, context) {}
   using Base = LoadStoreOpConversion<Op>;
 
   // Compute the pointer to an element of the buffer underlying the view given
index 093a595..b6e0703 100644 (file)
@@ -77,14 +77,14 @@ Value *memRefTypeCast(FuncBuilder &builder, Value *val) {
 
 /// Lower toy.mul to Linalg `matmul`.
 ///
-/// This class inherit from `DialectOpConversion` and override `rewrite`,
+/// This class inherit from `DialectConversionPattern` and override `rewrite`,
 /// similarly to the PatternRewriter introduced in the previous chapter.
 /// It will be called by the DialectConversion framework (see `LateLowering`
 /// class below).
-class MulOpConversion : public DialectOpConversion {
+class MulOpConversion : public DialectConversionPattern {
 public:
   explicit MulOpConversion(MLIRContext *context)
-      : DialectOpConversion(toy::MulOp::getOperationName(), 1, context) {}
+      : DialectConversionPattern(toy::MulOp::getOperationName(), 1, context) {}
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
                PatternRewriter &rewriter) const override {
index 00b3a74..d504865 100644 (file)
@@ -77,14 +77,14 @@ Value *memRefTypeCast(FuncBuilder &builder, Value *val) {
 
 /// Lower a toy.add to an affine loop nest.
 ///
-/// This class inherit from `DialectOpConversion` and override `rewrite`,
+/// This class inherit from `DialectConversionPattern` and override `rewrite`,
 /// similarly to the PatternRewriter introduced in the previous chapter.
 /// It will be called by the DialectConversion framework (see `LateLowering`
 /// class below).
-class AddOpConversion : public DialectOpConversion {
+class AddOpConversion : public DialectConversionPattern {
 public:
   explicit AddOpConversion(MLIRContext *context)
-      : DialectOpConversion(toy::AddOp::getOperationName(), 1, context) {}
+      : DialectConversionPattern(toy::AddOp::getOperationName(), 1, context) {}
 
   /// Lower the `op` by generating IR using the `rewriter` builder. The builder
   /// is setup with a new function, the `operands` array has been populated with
@@ -125,10 +125,11 @@ public:
 
 /// Lowers `toy.print` to a loop nest calling `printf` on every individual
 /// elements of the array.
-class PrintOpConversion : public DialectOpConversion {
+class PrintOpConversion : public DialectConversionPattern {
 public:
   explicit PrintOpConversion(MLIRContext *context)
-      : DialectOpConversion(toy::PrintOp::getOperationName(), 1, context) {}
+      : DialectConversionPattern(toy::PrintOp::getOperationName(), 1, context) {
+  }
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
                PatternRewriter &rewriter) const override {
@@ -227,10 +228,11 @@ private:
 };
 
 /// Lowers constant to a sequence of store in a buffer.
-class ConstantOpConversion : public DialectOpConversion {
+class ConstantOpConversion : public DialectConversionPattern {
 public:
   explicit ConstantOpConversion(MLIRContext *context)
-      : DialectOpConversion(toy::ConstantOp::getOperationName(), 1, context) {}
+      : DialectConversionPattern(toy::ConstantOp::getOperationName(), 1,
+                                 context) {}
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
                PatternRewriter &rewriter) const override {
@@ -270,10 +272,11 @@ public:
 };
 
 /// Lower transpose operation to an affine loop nest.
-class TransposeOpConversion : public DialectOpConversion {
+class TransposeOpConversion : public DialectConversionPattern {
 public:
   explicit TransposeOpConversion(MLIRContext *context)
-      : DialectOpConversion(toy::TransposeOp::getOperationName(), 1, context) {}
+      : DialectConversionPattern(toy::TransposeOp::getOperationName(), 1,
+                                 context) {}
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
                PatternRewriter &rewriter) const override {
@@ -302,10 +305,11 @@ public:
 };
 
 // Lower toy.return to standard return operation.
-class ReturnOpConversion : public DialectOpConversion {
+class ReturnOpConversion : public DialectConversionPattern {
 public:
   explicit ReturnOpConversion(MLIRContext *context)
-      : DialectOpConversion(toy::ReturnOp::getOperationName(), 1, context) {}
+      : DialectConversionPattern(toy::ReturnOp::getOperationName(), 1,
+                                 context) {}
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
                PatternRewriter &rewriter) const override {
index 72ca5a7..9f5c552 100644 (file)
@@ -212,11 +212,11 @@ offset, the min/max values or the strides. Their static (constant) dimensions
 are available directly in the type signature.
 
 An operation conversion is defined as special pattern by inheriting from
-`mlir::DialectOpConversion` and by reimplementing the matching and the rewriting
-functions:
+`mlir::DialectConversionPattern` and by reimplementing the matching and the
+rewriting functions:
 
 ```c++
-class ViewOpConversion : public DialectOpConversion {
+class ViewOpConversion : public DialectConversionPattern {
 public:
   // A conversion constructor, may take arbtirary operands but must be able
   // to obtain an MLIRContext from them to call the parent constructor.
@@ -237,15 +237,15 @@ public:
 }
 ```
 
-The `DialectOpConversion` constructor takes, in addition to the context, the
-name of the main operation to be matched and the "benefit" of a match. These
+The `DialectConversionPattern` constructor takes, in addition to the context,
+the name of the main operation to be matched and the "benefit" of a match. These
 operands are intended to be useful for defining an optimization problem across
 multiple possible conversions but are currently ignored by the conversion
 framework.
 
 ```c++
 ViewOpConversion::ViewOpConversion(MLIRContext *context)
-      : DialectOpConversion(linalg::ViewOp::getOperationName(), 1, context) {}
+      : DialectConversionPattern(linalg::ViewOp::getOperationName(), 1, context) {}
 ```
 
 The matching function can be used, for example, to apply different conversions
@@ -621,7 +621,7 @@ class Lowering : public DialectConversion {
 protected:
   // Produce a set of operation conversion patterns.  This is called once per
   // conversion.
-  llvm::DenseSet<DialectOpConversion *>
+  llvm::DenseSet<DialectConversionPattern *>
   initConverter(MLIRContext *context) override {
     allocator.Reset();
     // Call a helper function provided by MLIR to build a set of operation
index 8124c79..37ac637 100644 (file)
@@ -49,7 +49,7 @@ public:
       SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) { /*...*/ }
 
   // This gets called once to set up operation converters.
-  llvm::DenseSet<DialectOpConversion *>
+  llvm::DenseSet<DialectConversionPattern *>
   initConverters(MLIRContext *context) override {
     RewriteListBuilder<MulOpConversion, PrintOpConversion,
                        TransposeOpConversion>::build(allocator, context);
@@ -65,14 +65,14 @@ Individual operation converters are following this pattern:
 ```c++
 /// Lower a toy.add to an affine loop nest.
 ///
-/// This class inherit from `DialectOpConversion` and override `rewrite`,
+/// This class inherit from `DialectConversionPattern` and override `rewrite`,
 /// similarly to the PatternRewriter introduced in the previous chapter.
 /// It will be called by the DialectConversion framework (see `LateLowering`
 /// class below).
-class AddOpConversion : public DialectOpConversion {
+class AddOpConversion : public DialectConversionPattern {
 public:
   explicit AddOpConversion(MLIRContext *context)
-      : DialectOpConversion(toy::AddOp::getOperationName(), 1, context) {}
+      : DialectConversionPattern(toy::AddOp::getOperationName(), 1, context) {}
 
   /// Lower the `op` by generating IR using the `rewriter` builder. The builder
   /// is setup with a new function, the `operands` array has been populated with
index 6cec38c..c2bf040 100644 (file)
@@ -131,7 +131,7 @@ private:
 /// Base class for operation conversions targeting the LLVM IR dialect. Provides
 /// conversion patterns with an access to the containing LLVMLowering for the
 /// purpose of type conversions.
-class LLVMOpLowering : public DialectOpConversion {
+class LLVMOpLowering : public DialectConversionPattern {
 public:
   LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
                  LLVMLowering &lowering);
index 48376e5..8bbce1d 100644 (file)
@@ -36,23 +36,17 @@ class Operation;
 class Type;
 class Value;
 
-// Private implementation class.
-namespace impl {
-class FunctionConverter;
-}
-
-/// Base class for the dialect op conversion patterns.  Specific conversions
-/// must derive this class and implement least one `rewrite` method. Optionally
-/// they can also override `PatternMatch match(Operation *)` to match more
-/// specific operations than the `rootName` provided in the constructor.
+/// Base class for the dialect conversion patterns that require type changes.
+/// Specific conversions must derive this class and implement least one
+/// `rewrite` method.
 /// NOTE: These conversion patterns can only be used with the DialectConversion
 /// class.
-class DialectOpConversion : public RewritePattern {
+class DialectConversionPattern : public RewritePattern {
 public:
-  /// Construct an DialectOpConversion.  `rootName` must correspond to the
+  /// Construct an DialectConversionPattern.  `rootName` must correspond to the
   /// canonical name of the first operation matched by the pattern.
-  DialectOpConversion(StringRef rootName, PatternBenefit benefit,
-                      MLIRContext *ctx)
+  DialectConversionPattern(StringRef rootName, PatternBenefit benefit,
+                           MLIRContext *ctx)
       : RewritePattern(rootName, benefit, ctx) {}
 
   /// Hook for derived classes to implement matching. Dialect conversion
@@ -65,10 +59,10 @@ public:
   /// Hook for derived classes to implement rewriting. `op` is the (first)
   /// operation matched by the pattern, `operands` is a list of rewritten values
   /// that are passed to this operation, `rewriter` can be used to emit the new
-  /// operations. This function must be reimplemented if the DialectOpConversion
-  /// ever needs to replace an operation that does not have successors. This
-  /// function should not fail. If some specific cases of the operation are not
-  /// supported, these cases should not be matched.
+  /// operations. This function must be reimplemented if the
+  /// DialectConversionPattern ever needs to replace an operation that does not
+  /// have successors. This function should not fail. If some specific cases of
+  /// the operation are not supported, these cases should not be matched.
   virtual void rewrite(Operation *op, ArrayRef<Value *> operands,
                        PatternRewriter &rewriter) const {
     llvm_unreachable("unimplemented rewrite");
@@ -80,7 +74,7 @@ public:
   /// of (potentially rewritten) successor blocks, `operands` is a list of lists
   /// of rewritten values passed to each of the successors, co-indexed with
   /// `destinations`, `rewriter` can be used to emit the new operations. It must
-  /// be reimplemented if the DialectOpConversion ever needs to replace a
+  /// be reimplemented if the DialectConversionPattern ever needs to replace a
   /// terminator operation that has successors. This function should not fail
   /// the pass. If some specific cases of the operation are not supported,
   /// these cases should not be matched.
@@ -126,8 +120,6 @@ private:
 ///
 /// If the conversion fails, the module is not modified.
 class DialectConversion {
-  friend class impl::FunctionConverter;
-
 public:
   virtual ~DialectConversion() = default;
 
@@ -135,7 +127,6 @@ public:
   LLVM_NODISCARD
   LogicalResult convert(Module *m);
 
-protected:
   /// Derived classes must implement this hook to produce a set of conversion
   /// patterns to apply.  They may use `mlirContext` to obtain registered
   /// dialects or operations.  This will be called in the beginning of the
index f1c43ef..36267e9 100644 (file)
@@ -195,7 +195,7 @@ static Type getMemRefElementPtrType(MemRefType t, LLVMLowering &lowering) {
 
 LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
                                LLVMLowering &lowering_)
-    : DialectOpConversion(rootOpName, /*benefit=*/1, context),
+    : DialectConversionPattern(rootOpName, /*benefit=*/1, context),
       lowering(lowering_) {}
 
 namespace {
index 3b16323..9c428d4 100644 (file)
@@ -25,6 +25,9 @@
 using namespace mlir;
 using namespace mlir::impl;
 
+//===----------------------------------------------------------------------===//
+// ProducerGenerator
+//===----------------------------------------------------------------------===//
 namespace {
 /// This class provides a simple interface for generating fake producers during
 /// the conversion process. These fake producers are used when replacing the
@@ -87,8 +90,12 @@ struct ProducerGenerator {
   UnknownLoc loc;
 };
 
-/// This class implements a pattern rewriter for DialectOpConversion patterns.
-/// It automatically performs remapping of replaced operation values.
+//===----------------------------------------------------------------------===//
+// DialectConversionRewriter
+//===----------------------------------------------------------------------===//
+
+/// This class implements a pattern rewriter for DialectConversionPattern
+/// patterns. It automatically performs remapping of replaced operation values.
 struct DialectConversionRewriter final : public PatternRewriter {
   DialectConversionRewriter(Function *fn)
       : PatternRewriter(fn), tempGenerator(fn->getContext()) {}
@@ -130,13 +137,17 @@ struct DialectConversionRewriter final : public PatternRewriter {
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// DialectConversionPattern
+//===----------------------------------------------------------------------===//
+
 /// Rewrite the IR rooted at the specified operation with the result of
 /// this pattern, generating any new operations with the specified
 /// builder.  If an unexpected error is encountered (an internal
 /// compiler error), it is emitted through the normal MLIR diagnostic
 /// hooks and the IR is left in a valid state.
-void DialectOpConversion::rewrite(Operation *op,
-                                  PatternRewriter &rewriter) const {
+void DialectConversionPattern::rewrite(Operation *op,
+                                       PatternRewriter &rewriter) const {
   SmallVector<Value *, 4> operands;
   auto &dialectRewriter = static_cast<DialectConversionRewriter &>(rewriter);
   dialectRewriter.lookupValues(op->getOperands(), operands);
@@ -168,8 +179,10 @@ void DialectOpConversion::rewrite(Operation *op,
           destinations, operandsPerDestination, rewriter);
 }
 
-namespace mlir {
-namespace impl {
+//===----------------------------------------------------------------------===//
+// FunctionConverter
+//===----------------------------------------------------------------------===//
+namespace {
 // Implementation detail class of the DialectConversion utility.  Performs
 // function-by-function conversions by creating new functions, filling them in
 // with converted blocks, updating the function attributes, and replacing the
@@ -211,8 +224,7 @@ public:
   /// The matcher to use when converting operations.
   RewritePatternMatcher &matcher;
 };
-} // end namespace impl
-} // end namespace mlir
+} // end anonymous namespace
 
 LogicalResult
 FunctionConverter::convertArgument(DialectConversionRewriter &rewriter,
@@ -326,6 +338,10 @@ Function *FunctionConverter::convertFunction(Function *f) {
   return newFunc;
 }
 
+//===----------------------------------------------------------------------===//
+// DialectConversion
+//===----------------------------------------------------------------------===//
+
 // Create a function type with arguments and results converted, and argument
 // attributes passed through.
 FunctionType DialectConversion::convertFunctionSignatureType(