NFC: Expose a ConversionPatternRewriter for use with ConversionPatterns.
authorRiver Riddle <riverriddle@google.com>
Thu, 18 Jul 2019 19:04:57 +0000 (12:04 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Fri, 19 Jul 2019 18:40:00 +0000 (11:40 -0700)
This specific PatternRewriter will allow for exposing hooks in the future that are only useful for the conversion framework, e.g. type conversions.

PiperOrigin-RevId: 258818122

12 files changed:
mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp
mlir/examples/toy/Ch5/mlir/LateLowering.cpp
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/lib/TestDialect/TestPatterns.cpp
mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp

index c43a2ae..67b0ac0 100644 (file)
@@ -138,8 +138,9 @@ public:
   explicit RangeOpConversion(MLIRContext *context)
       : ConversionPattern(linalg::RangeOp::getOperationName(), 1, context) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     auto rangeOp = cast<linalg::RangeOp>(op);
     auto rangeDescriptorType =
         linalg::convertLinalgType(rangeOp.getResult()->getType());
@@ -165,8 +166,9 @@ public:
   explicit ViewOpConversion(MLIRContext *context)
       : ConversionPattern(linalg::ViewOp::getOperationName(), 1, context) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     auto viewOp = cast<linalg::ViewOp>(op);
     auto viewDescriptorType = linalg::convertLinalgType(viewOp.getViewType());
     auto memrefType =
@@ -290,8 +292,9 @@ public:
   explicit SliceOpConversion(MLIRContext *context)
       : ConversionPattern(linalg::SliceOp::getOperationName(), 1, context) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     auto sliceOp = cast<linalg::SliceOp>(op);
     auto newViewDescriptorType =
         linalg::convertLinalgType(sliceOp.getViewType());
@@ -382,8 +385,9 @@ public:
   explicit DropConsumer(MLIRContext *context)
       : ConversionPattern("some_consumer", 1, context) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOp(op, llvm::None);
     return matchSuccess();
   }
index c86f5d7..68a48d6 100644 (file)
@@ -96,8 +96,9 @@ public:
 // an LLVM IR load.
 class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
   using Base::Base;
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     edsc::ScopedContext edscContext(rewriter, op->getLoc());
     auto elementType = linalg::convertLinalgType(*op->result_type_begin());
     Value *viewDescriptor = operands[0];
@@ -113,8 +114,9 @@ class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
 // an LLVM IR store.
 class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
   using Base::Base;
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     edsc::ScopedContext edscContext(rewriter, op->getLoc());
     Value *viewDescriptor = operands[1];
     Value *data = operands[0];
index e4df917..f3463ba 100644 (file)
@@ -57,7 +57,7 @@ namespace {
 /// time both side of the cast (producer and consumer) will be lowered to a
 /// dialect like LLVM and end up with the same LLVM representation, at which
 /// point this becomes a no-op and is eliminated.
-Value *typeCast(PatternRewriter &builder, Value *val, Type destTy) {
+Value *typeCast(ConversionPatternRewriter &builder, Value *val, Type destTy) {
   if (val->getType() == destTy)
     return val;
   return builder.create<toy::TypeCastOp>(val->getLoc(), val, destTy)
@@ -67,7 +67,7 @@ Value *typeCast(PatternRewriter &builder, Value *val, Type destTy) {
 /// Create a type cast to turn a toy.array into a memref. The Toy Array will be
 /// lowered to a memref during buffer allocation, at which point the type cast
 /// becomes useless.
-Value *memRefTypeCast(PatternRewriter &builder, Value *val) {
+Value *memRefTypeCast(ConversionPatternRewriter &builder, Value *val) {
   if (val->getType().isa<MemRefType>())
     return val;
   auto toyArrayTy = val->getType().dyn_cast<toy::ToyArrayType>();
@@ -87,8 +87,9 @@ public:
   explicit MulOpConversion(MLIRContext *context)
       : ConversionPattern(toy::MulOp::getOperationName(), 1, context) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     using namespace edsc;
     using intrinsics::constant_index;
     using linalg::intrinsics::range;
index cd826fb..8b80588 100644 (file)
@@ -92,8 +92,9 @@ public:
   /// the rewritten operands for `op` in the new function.
   /// The results created by the new IR with the builder are returned, and their
   /// number must match the number of result of `op`.
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     auto add = cast<toy::AddOp>(op);
     auto loc = add.getLoc();
     // Create a `toy.alloc` operation to allocate the output buffer for this op.
@@ -133,8 +134,9 @@ public:
   explicit PrintOpConversion(MLIRContext *context)
       : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     // Get or create the declaration of the printf function in the module.
     FuncOp printfFunc = getPrintf(op->getParentOfType<ModuleOp>());
 
@@ -232,8 +234,9 @@ public:
   explicit ConstantOpConversion(MLIRContext *context)
       : ConversionPattern(toy::ConstantOp::getOperationName(), 1, context) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     toy::ConstantOp cstOp = cast<toy::ConstantOp>(op);
     auto loc = cstOp.getLoc();
     auto retTy = cstOp.getResult()->getType().cast<toy::ToyArrayType>();
@@ -276,8 +279,9 @@ public:
   explicit TransposeOpConversion(MLIRContext *context)
       : ConversionPattern(toy::TransposeOp::getOperationName(), 1, context) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     auto transpose = cast<toy::TransposeOp>(op);
     auto loc = transpose.getLoc();
     Value *result = memRefTypeCast(
@@ -309,8 +313,9 @@ public:
   explicit ReturnOpConversion(MLIRContext *context)
       : ConversionPattern(toy::ReturnOp::getOperationName(), 1, context) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     // Argument is optional, handle both cases.
     if (op->getNumOperands())
       rewriter.replaceOpWithNewOp<ReturnOp>(op, operands[0]);
index de68e4b..d739a80 100644 (file)
@@ -321,7 +321,10 @@ public:
   /// (perhaps transitively) dead.  If any of those values are dead, this will
   /// remove them as well.
   virtual void replaceOp(Operation *op, ArrayRef<Value *> newValues,
-                         ArrayRef<Value *> valuesToRemoveIfDead = {});
+                         ArrayRef<Value *> valuesToRemoveIfDead);
+  void replaceOp(Operation *op, ArrayRef<Value *> newValues) {
+    replaceOp(op, newValues, llvm::None);
+  }
 
   /// Replaces the result op with a new op that is created without verification.
   /// The result values of the two ops must be the same types.
index bfe3674..68c6f12 100644 (file)
@@ -31,6 +31,7 @@ namespace mlir {
 
 // Forward declarations.
 class Block;
+class ConversionPatternRewriter;
 class FuncOp;
 class MLIRContext;
 class Operation;
@@ -192,7 +193,7 @@ public:
   /// have successors. This function should not fail. If some specific cases of
   /// the operation are not supported, these cases should not be matched.
   virtual void rewrite(Operation *op, ArrayRef<Value *> operands,
-                       PatternRewriter &rewriter) const {
+                       ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("unimplemented rewrite");
   }
 
@@ -209,7 +210,7 @@ public:
   virtual void rewrite(Operation *op, ArrayRef<Value *> properOperands,
                        ArrayRef<Block *> destinations,
                        ArrayRef<ArrayRef<Value *>> operands,
-                       PatternRewriter &rewriter) const {
+                       ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("unimplemented rewrite for terminators");
   }
 
@@ -218,7 +219,7 @@ public:
   matchAndRewrite(Operation *op, ArrayRef<Value *> properOperands,
                   ArrayRef<Block *> destinations,
                   ArrayRef<ArrayRef<Value *>> operands,
-                  PatternRewriter &rewriter) const {
+                  ConversionPatternRewriter &rewriter) const {
     if (!match(op))
       return matchFailure();
     rewrite(op, properOperands, destinations, operands, rewriter);
@@ -226,9 +227,9 @@ public:
   }
 
   /// Hook for derived classes to implement combined matching and rewriting.
-  virtual PatternMatchResult matchAndRewrite(Operation *op,
-                                             ArrayRef<Value *> operands,
-                                             PatternRewriter &rewriter) const {
+  virtual PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const {
     if (!match(op))
       return matchFailure();
     rewrite(op, operands, rewriter);
@@ -244,6 +245,50 @@ private:
 };
 
 //===----------------------------------------------------------------------===//
+// Conversion PatternRewriter
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+struct ConversionPatternRewriterImpl;
+} // end namespace detail
+
+/// This class implements a pattern rewriter for use with ConversionPatterns.
+class ConversionPatternRewriter final : public PatternRewriter {
+public:
+  ConversionPatternRewriter(MLIRContext *ctx, TypeConverter *converter);
+  ~ConversionPatternRewriter() override;
+
+  //===--------------------------------------------------------------------===//
+  // PatternRewriter Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// PatternRewriter hook for replacing the results of an operation.
+  void replaceOp(Operation *op, ArrayRef<Value *> newValues,
+                 ArrayRef<Value *> valuesToRemoveIfDead) override;
+  using PatternRewriter::replaceOp;
+
+  /// PatternRewriter hook for splitting a block into two parts.
+  Block *splitBlock(Block *block, Block::iterator before) override;
+
+  /// PatternRewriter hook for moving blocks out of a region.
+  void inlineRegionBefore(Region &region, Region &parent,
+                          Region::iterator before) override;
+  using PatternRewriter::inlineRegionBefore;
+
+  /// PatternRewriter hook for creating a new operation.
+  Operation *createOperation(const OperationState &state) override;
+
+  /// PatternRewriter hook for updating the root operation in-place.
+  void notifyRootUpdated(Operation *op) override;
+
+  /// Return a reference to the internal implementation.
+  detail::ConversionPatternRewriterImpl &getImpl();
+
+private:
+  std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
+};
+
+//===----------------------------------------------------------------------===//
 // ConversionTarget
 //===----------------------------------------------------------------------===//
 
@@ -260,7 +305,7 @@ public:
     /// by the target.
     Dynamic,
 
-    /// This target explicitly does not support this operation.
+    /// The target explicitly does not support this operation.
     Illegal,
   };
 
index 9c2053d..1515d95 100644 (file)
@@ -104,8 +104,9 @@ struct ForLowering : public ConversionPattern {
   ForLowering(MLIRContext *ctx)
       : ConversionPattern(ForOp::getOperationName(), 1, ctx) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override;
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
 // Create a CFG subgraph for the loop.if operation (including its "then" and
@@ -154,16 +155,18 @@ struct IfLowering : public ConversionPattern {
   IfLowering(MLIRContext *ctx)
       : ConversionPattern(IfOp::getOperationName(), 1, ctx) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override;
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
 struct TerminatorLowering : public ConversionPattern {
   TerminatorLowering(MLIRContext *ctx)
       : ConversionPattern(TerminatorOp::getOperationName(), 1, ctx) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOp(op, {});
     return matchSuccess();
   }
@@ -172,7 +175,7 @@ struct TerminatorLowering : public ConversionPattern {
 
 PatternMatchResult
 ForLowering::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                             PatternRewriter &rewriter) const {
+                             ConversionPatternRewriter &rewriter) const {
   auto forOp = cast<ForOp>(op);
   Location loc = op->getLoc();
 
@@ -228,7 +231,7 @@ ForLowering::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
 
 PatternMatchResult
 IfLowering::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                            PatternRewriter &rewriter) const {
+                            ConversionPatternRewriter &rewriter) const {
   auto ifOp = cast<IfOp>(op);
   auto loc = op->getLoc();
 
index aa72a7b..042e768 100644 (file)
@@ -229,7 +229,7 @@ public:
   }
 
   // Create an LLVM IR pseudo-operation defining the given index constant.
-  Value *createIndexConstant(PatternRewriter &builder, Location loc,
+  Value *createIndexConstant(ConversionPatternRewriter &builder, Location loc,
                              uint64_t value) const {
     auto attr = builder.getIntegerAttr(builder.getIndexType(), value);
     return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
@@ -237,7 +237,7 @@ public:
 
   // Get the array attribute named "position" containing the given list of
   // integers as integer attribute elements.
-  static ArrayAttr getIntegerArrayAttr(PatternRewriter &builder,
+  static ArrayAttr getIntegerArrayAttr(ConversionPatternRewriter &builder,
                                        ArrayRef<int64_t> values) {
     SmallVector<Attribute, 4> attrs;
     attrs.reserve(values.size());
@@ -247,7 +247,8 @@ public:
   }
 
   // Extract raw data pointer value from a value representing a memref.
-  static Value *extractMemRefElementPtr(PatternRewriter &builder, Location loc,
+  static Value *extractMemRefElementPtr(ConversionPatternRewriter &builder,
+                                        Location loc,
                                         Value *convertedMemRefValue,
                                         Type elementTypePtr,
                                         bool hasStaticShape) {
@@ -274,8 +275,9 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
 
   // Convert the type of the result to an LLVM type, pass operands as is,
   // preserve attributes.
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     unsigned numResults = op->getNumResults();
 
     Type packedType;
@@ -398,7 +400,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
   }
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+               ConversionPatternRewriter &rewriter) const override {
     auto allocOp = cast<AllocOp>(op);
     MemRefType type = allocOp.getType();
 
@@ -495,8 +497,9 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
 struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
   using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     assert(operands.size() == 1 && "dealloc takes one operand");
     OperandAdaptor<DeallocOp> transformed(operands);
 
@@ -538,7 +541,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
   }
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+               ConversionPatternRewriter &rewriter) const override {
     auto memRefCastOp = cast<MemRefCastOp>(op);
     OperandAdaptor<MemRefCastOp> transformed(operands);
     auto targetType = memRefCastOp.getType();
@@ -610,7 +613,7 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
   }
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+               ConversionPatternRewriter &rewriter) const override {
     auto dimOp = cast<DimOp>(op);
     OperandAdaptor<DimOp> transformed(operands);
     MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
@@ -660,7 +663,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
   // by accumulating the running linearized value.
   // Note that `indices` and `allocSizes` are passed in the same order as they
   // appear in load/store operations and memref type declarations.
-  Value *linearizeSubscripts(PatternRewriter &builder, Location loc,
+  Value *linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
                              ArrayRef<Value *> indices,
                              ArrayRef<Value *> allocSizes) const {
     assert(indices.size() == allocSizes.size() &&
@@ -686,7 +689,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
   Value *getElementPtr(Location loc, Type elementTypePtr,
                        ArrayRef<int64_t> shape, Value *memRefDescriptor,
                        ArrayRef<Value *> indices,
-                       PatternRewriter &rewriter) const {
+                       ConversionPatternRewriter &rewriter) const {
     // Get the list of MemRef sizes.  Static sizes are defined as constants.
     // Dynamic sizes are extracted from the MemRef descriptor, where they start
     // from the position 1 (the buffer is at position 0).
@@ -722,7 +725,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
   Value *getRawElementPtr(Location loc, Type elementTypePtr,
                           ArrayRef<int64_t> shape, Value *rawDataPtr,
                           ArrayRef<Value *> indices,
-                          PatternRewriter &rewriter) const {
+                          ConversionPatternRewriter &rewriter) const {
     if (shape.empty())
       return rawDataPtr;
 
@@ -738,7 +741,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
   }
 
   Value *getDataPtr(Location loc, MemRefType type, Value *dataPtr,
-                    ArrayRef<Value *> indices, PatternRewriter &rewriter,
+                    ArrayRef<Value *> indices,
+                    ConversionPatternRewriter &rewriter,
                     llvm::Module &module) const {
     auto ptrType = getMemRefElementPtrType(type, this->lowering);
     auto shape = type.getShape();
@@ -755,8 +759,9 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
 struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
   using Base::Base;
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     auto loadOp = cast<LoadOp>(op);
     OperandAdaptor<LoadOp> transformed(operands);
     auto type = loadOp.getMemRefType();
@@ -776,8 +781,9 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
 struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
   using Base::Base;
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     auto type = cast<StoreOp>(op).getMemRefType();
     OperandAdaptor<StoreOp> transformed(operands);
 
@@ -796,8 +802,9 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
 struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> {
   using LLVMLegalizationPattern<IndexCastOp>::LLVMLegalizationPattern;
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     IndexCastOpOperandAdaptor transformed(operands);
     auto indexCastOp = cast<IndexCastOp>(op);
 
@@ -829,8 +836,9 @@ static LLVM::ICmpPredicate convertCmpIPredicate(CmpIPredicate pred) {
 struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
   using LLVMLegalizationPattern<CmpIOp>::LLVMLegalizationPattern;
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     auto cmpiOp = cast<CmpIOp>(op);
     CmpIOpOperandAdaptor transformed(operands);
 
@@ -851,11 +859,11 @@ struct OneToOneLLVMTerminatorLowering
   using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
   using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
 
-  PatternMatchResult matchAndRewrite(Operation *op,
-                                     ArrayRef<Value *> properOperands,
-                                     ArrayRef<Block *> destinations,
-                                     ArrayRef<ArrayRef<Value *>> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> properOperands,
+                  ArrayRef<Block *> destinations,
+                  ArrayRef<ArrayRef<Value *>> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<TargetOp>(op, properOperands, destinations,
                                           operands, op->getAttrs());
     return this->matchSuccess();
@@ -871,8 +879,9 @@ struct OneToOneLLVMTerminatorLowering
 struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
   using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern;
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     unsigned numArguments = op->getNumOperands();
 
     // If ReturnOp has 0 or 1 operand, create it and return immediately.
index fb26f85..98be230 100644 (file)
@@ -164,8 +164,9 @@ public:
                                    LLVMTypeConverter &lowering_)
       : LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     auto indexType = IndexType::get(op->getContext());
     auto voidPtrTy =
         LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
@@ -227,8 +228,9 @@ public:
       : LLVMOpLowering(BufferDeallocOp::getOperationName(), context,
                        lowering_) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     auto voidPtrTy =
         LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
     // Insert the `free` declaration if it is not already present.
@@ -261,8 +263,9 @@ public:
   BufferSizeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
       : LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
     edsc::ScopedContext context(rewriter, op->getLoc());
     rewriter.replaceOp(
@@ -277,8 +280,9 @@ public:
   explicit DimOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
       : LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     auto dimOp = cast<linalg::DimOp>(op);
     auto indexTy = lowering.convertType(rewriter.getIndexType());
     edsc::ScopedContext context(rewriter, op->getLoc());
@@ -307,7 +311,7 @@ public:
   // a getelementptr. This must be called under an edsc::ScopedContext.
   Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
                        ArrayRef<Value *> indices,
-                       PatternRewriter &rewriter) const {
+                       ConversionPatternRewriter &rewriter) const {
     auto loadOp = cast<Op>(op);
     auto elementTy = getPtrToElementType(loadOp.getViewType(), lowering);
     auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
@@ -333,8 +337,9 @@ public:
 // an LLVM IR load.
 class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
   using Base::Base;
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     edsc::ScopedContext edscContext(rewriter, op->getLoc());
     auto elementTy = lowering.convertType(*op->result_type_begin());
     Value *viewDescriptor = operands[0];
@@ -351,8 +356,9 @@ public:
   explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
       : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     auto rangeOp = cast<RangeOp>(op);
     auto rangeDescriptorTy =
         convertLinalgType(rangeOp.getResult()->getType(), lowering);
@@ -380,8 +386,9 @@ public:
       : LLVMOpLowering(RangeIntersectOp::getOperationName(), context,
                        lowering_) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     auto rangeIntersectOp = cast<RangeIntersectOp>(op);
     auto rangeDescriptorTy =
         convertLinalgType(rangeIntersectOp.getResult()->getType(), lowering);
@@ -423,8 +430,9 @@ public:
   explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
       : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     auto sliceOp = cast<SliceOp>(op);
     auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering);
     auto viewType = sliceOp.getBaseViewType();
@@ -503,8 +511,9 @@ public:
 // an LLVM IR store.
 class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
   using Base::Base;
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     edsc::ScopedContext edscContext(rewriter, op->getLoc());
     Value *data = operands[0];
     Value *viewDescriptor = operands[1];
@@ -521,8 +530,9 @@ public:
   explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
       : LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     auto viewOp = cast<ViewOp>(op);
     auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
     auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering);
@@ -598,9 +608,9 @@ static FuncOp getLLVMLibraryCallImplDefinition(FuncOp libFn) {
 // Get function definition for the LinalgOp. If it doesn't exist, insert a
 // definition.
 template <typename LinalgOp>
-static FuncOp getLLVMLibraryCallDeclaration(Operation *op,
-                                            LLVMTypeConverter &lowering,
-                                            PatternRewriter &rewriter) {
+static FuncOp
+getLLVMLibraryCallDeclaration(Operation *op, LLVMTypeConverter &lowering,
+                              ConversionPatternRewriter &rewriter) {
   assert(isa<LinalgOp>(op));
   auto fnName = LinalgOp::getLibraryCallName();
   auto module = op->getParentOfType<ModuleOp>();
@@ -689,8 +699,9 @@ public:
                               LinalgTypeConverter &lowering_)
       : LLVMOpLowering(LinalgOp::getOperationName(), context, lowering_) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     // Only emit library call declaration. Fill in the body later.
     auto f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
     static_cast<LinalgTypeConverter &>(lowering).addLibraryFnDeclaration(f);
index 3ef9766..ed271b6 100644 (file)
@@ -28,6 +28,7 @@
 #include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
+using namespace mlir::detail;
 
 #define DEBUG_TYPE "dialect-conversion"
 
@@ -102,6 +103,7 @@ struct ArgConverter {
   /// The pattern rewriter to use when materializing conversions.
   PatternRewriter &rewriter;
 };
+} // end anonymous namespace
 
 constexpr StringLiteral ArgConverter::kCastName;
 
@@ -283,9 +285,9 @@ Operation *ArgConverter::createCast(ArrayRef<Value *> inputs, Type outputType) {
 }
 
 //===----------------------------------------------------------------------===//
-// DialectConversionRewriter
+// ConversionPatternRewriterImpl
 //===----------------------------------------------------------------------===//
-
+namespace {
 /// This class contains a snapshot of the current conversion rewriter state.
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
@@ -307,10 +309,11 @@ struct RewriterState {
   /// The current number of type conversion actions performed.
   unsigned numTypeConversions;
 };
+} // end anonymous namespace
 
-/// This class implements a pattern rewriter for ConversionPattern
-/// patterns. It automatically performs remapping of replaced operation values.
-struct DialectConversionRewriter final : public PatternRewriter {
+namespace mlir {
+namespace detail {
+struct ConversionPatternRewriterImpl {
   /// This class represents one requested operation replacement via 'replaceOp'.
   struct OpReplacement {
     OpReplacement() = default;
@@ -362,205 +365,55 @@ struct DialectConversionRewriter final : public PatternRewriter {
     NamedAttributeList originalParentAttributes;
   };
 
-  DialectConversionRewriter(MLIRContext *ctx, TypeConverter *converter)
-      : PatternRewriter(ctx), argConverter(converter, *this) {}
-  ~DialectConversionRewriter() = default;
+  ConversionPatternRewriterImpl(PatternRewriter &rewriter,
+                                TypeConverter *converter)
+      : argConverter(converter, rewriter) {}
 
   /// Return the current state of the rewriter.
-  RewriterState getCurrentState() {
-    return RewriterState(createdOps.size(), replacements.size(),
-                         blockActions.size(), typeConversions.size());
-  }
+  RewriterState getCurrentState();
 
   /// Reset the state of the rewriter to a previously saved point.
-  void resetState(RewriterState state) {
-    // Undo any type conversions or block actions.
-    undoTypeConversions(state.numTypeConversions);
-    undoBlockActions(state.numBlockActions);
-
-    // Reset any replaced operations and undo any saved mappings.
-    for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
-      for (auto *result : repl.op->getResults())
-        mapping.erase(result);
-    replacements.resize(state.numReplacements);
-
-    // Pop all of the newly created operations.
-    while (createdOps.size() != state.numCreatedOperations)
-      createdOps.pop_back_val()->erase();
-  }
+  void resetState(RewriterState state);
 
   /// Undo the block actions (motions, splits) one by one in reverse order until
   /// "numActionsToKeep" actions remains.
-  void undoBlockActions(unsigned numActionsToKeep = 0) {
-    for (auto &action :
-         llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
-      switch (action.kind) {
-      // Merge back the block that was split out.
-      case BlockActionKind::Split: {
-        action.originalBlock->getOperations().splice(
-            action.originalBlock->end(), action.block->getOperations());
-        action.block->erase();
-        break;
-      }
-      // Move the block back to its original position.
-      case BlockActionKind::Move: {
-        Region *originalRegion = action.originalPosition.region;
-        originalRegion->getBlocks().splice(
-            std::next(originalRegion->begin(),
-                      action.originalPosition.position),
-            action.block->getParent()->getBlocks(), action.block);
-        break;
-      }
-      }
-    }
-    blockActions.resize(numActionsToKeep);
-  }
+  void undoBlockActions(unsigned numActionsToKeep = 0);
 
   /// Undo the type conversion actions one by one, until "numActionsToKeep"
   /// actions remain.
-  void undoTypeConversions(unsigned numActionsToKeep = 0) {
-    for (auto &conversion :
-         llvm::drop_begin(typeConversions, numActionsToKeep)) {
-      if (auto *region = conversion.object.dyn_cast<Region *>())
-        region->getContainingOp()->setAttrs(
-            conversion.originalParentAttributes);
-      else
-        argConverter.discardPendingRewrites(conversion.object.get<Block *>());
-    }
-    typeConversions.resize(numActionsToKeep);
-  }
+  void undoTypeConversions(unsigned numActionsToKeep = 0);
 
   /// Cleanup and destroy any generated rewrite operations. This method is
   /// invoked when the conversion process fails.
-  void discardRewrites() {
-    undoTypeConversions();
-    undoBlockActions();
-
-    // Remove any newly created ops.
-    for (auto *op : createdOps) {
-      op->dropAllDefinedValueUses();
-      op->erase();
-    }
-  }
+  void discardRewrites();
 
   /// Apply all requested operation rewrites. This method is invoked when the
   /// conversion process succeeds.
-  void applyRewrites() {
-    // Apply all of the rewrites replacements requested during conversion.
-    for (auto &repl : replacements) {
-      for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i)
-        repl.op->getResult(i)->replaceAllUsesWith(
-            mapping.lookupOrDefault(repl.newValues[i]));
-
-      // if this operation defines any regions, drop any pending argument
-      // rewrites.
-      if (argConverter.typeConverter && repl.op->getNumRegions()) {
-        for (auto &region : repl.op->getRegions())
-          for (auto &block : region)
-            argConverter.cancelPendingRewrites(&block);
-      }
-    }
-
-    // In a second pass, erase all of the replaced operations in reverse. This
-    // allows processing nested operations before their parent region is
-    // destroyed.
-    for (auto &repl : llvm::reverse(replacements))
-      repl.op->erase();
-
-    argConverter.applyRewrites();
-  }
+  void applyRewrites();
 
   /// Return if the given block has already been converted.
-  bool hasSignatureBeenConverted(Block *block) {
-    return argConverter.hasBeenConverted(block);
-  }
+  bool hasSignatureBeenConverted(Block *block);
 
   /// Convert the signature of the given region.
-  LogicalResult convertRegionSignature(Region &region) {
-    auto parentAttrs = region.getContainingOp()->getAttrList();
-    auto result = argConverter.convertSignature(region, mapping);
-    if (succeeded(result)) {
-      typeConversions.push_back(TypeConversion{&region, parentAttrs});
-      if (!region.empty())
-        typeConversions.push_back(
-            TypeConversion{&region.front(), NamedAttributeList()});
-    }
-    return result;
-  }
+  LogicalResult convertRegionSignature(Region &region);
 
   /// Convert the signature of the given block.
-  LogicalResult convertBlockSignature(Block *block) {
-    auto result = argConverter.convertSignature(block, mapping);
-    if (succeeded(result))
-      typeConversions.push_back(TypeConversion{block, NamedAttributeList()});
-    return result;
-  }
+  LogicalResult convertBlockSignature(Block *block);
 
   /// PatternRewriter hook for replacing the results of an operation.
   void replaceOp(Operation *op, ArrayRef<Value *> newValues,
-                 ArrayRef<Value *> valuesToRemoveIfDead) override {
-    assert(newValues.size() == op->getNumResults());
-
-    // Create mappings for each of the new result values.
-    for (unsigned i = 0, e = newValues.size(); i < e; ++i) {
-      assert((newValues[i] || op->getResult(i)->use_empty()) &&
-             "result value has remaining uses that must be replaced");
-      if (newValues[i])
-        mapping.map(op->getResult(i), newValues[i]);
-    }
+                 ArrayRef<Value *> valuesToRemoveIfDead);
 
-    // Record the requested operation replacement.
-    replacements.emplace_back(op, newValues);
-  }
-
-  /// PatternRewriter hook for splitting a block into two parts.
-  Block *splitBlock(Block *block, Block::iterator before) override {
-    auto *continuation = PatternRewriter::splitBlock(block, before);
-    BlockAction action;
-    action.kind = BlockActionKind::Split;
-    action.block = continuation;
-    action.originalBlock = block;
-    blockActions.push_back(action);
-    return continuation;
-  }
-
-  /// PatternRewriter hook for moving blocks out of a region.
-  void inlineRegionBefore(Region &region, Region &parent,
-                          Region::iterator before) override {
-    for (auto &pair : llvm::enumerate(region)) {
-      Block &block = pair.value();
-      unsigned position = pair.index();
-      BlockAction action;
-      action.kind = BlockActionKind::Move;
-      action.block = &block;
-      action.originalPosition = {&region, position};
-      blockActions.push_back(action);
-    }
-    PatternRewriter::inlineRegionBefore(region, parent, before);
-  }
-
-  /// PatternRewriter hook for creating a new operation.
-  Operation *createOperation(const OperationState &state) override {
-    auto *result = OpBuilder::createOperation(state);
-    createdOps.push_back(result);
-    return result;
-  }
+  /// Notifies that a block was split.
+  void notifySplitBlock(Block *block, Block *continuation);
 
-  /// PatternRewriter hook for updating the root operation in-place.
-  void notifyRootUpdated(Operation *op) override {
-    // The rewriter caches changes to the IR to allow for operating in-place and
-    // backtracking. The rewrite is currently not capable of backtracking
-    // in-place modifications.
-    llvm_unreachable("in-place operation updates are not supported");
-  }
+  /// Notifies that the blocks of a region are about to be moved.
+  void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
+                                        Region::iterator before);
 
   /// Remap the given operands to those with potentially different types.
   void remapValues(Operation::operand_range operands,
-                   SmallVectorImpl<Value *> &remapped) {
-    remapped.reserve(llvm::size(operands));
-    for (Value *operand : operands)
-      remapped.push_back(mapping.lookupOrDefault(operand));
-  }
+                   SmallVectorImpl<Value *> &remapped);
 
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
@@ -581,7 +434,226 @@ struct DialectConversionRewriter final : public PatternRewriter {
   /// Ordered list of type conversion actions.
   SmallVector<TypeConversion, 4> typeConversions;
 };
-} // end anonymous namespace
+} // end namespace detail
+} // end namespace mlir
+
+RewriterState ConversionPatternRewriterImpl::getCurrentState() {
+  return RewriterState(createdOps.size(), replacements.size(),
+                       blockActions.size(), typeConversions.size());
+}
+
+void ConversionPatternRewriterImpl::resetState(RewriterState state) {
+  // Undo any type conversions or block actions.
+  undoTypeConversions(state.numTypeConversions);
+  undoBlockActions(state.numBlockActions);
+
+  // Reset any replaced operations and undo any saved mappings.
+  for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
+    for (auto *result : repl.op->getResults())
+      mapping.erase(result);
+  replacements.resize(state.numReplacements);
+
+  // Pop all of the newly created operations.
+  while (createdOps.size() != state.numCreatedOperations)
+    createdOps.pop_back_val()->erase();
+}
+
+void ConversionPatternRewriterImpl::undoBlockActions(
+    unsigned numActionsToKeep) {
+  for (auto &action :
+       llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
+    switch (action.kind) {
+    // Merge back the block that was split out.
+    case BlockActionKind::Split: {
+      action.originalBlock->getOperations().splice(
+          action.originalBlock->end(), action.block->getOperations());
+      action.block->erase();
+      break;
+    }
+    // Move the block back to its original position.
+    case BlockActionKind::Move: {
+      Region *originalRegion = action.originalPosition.region;
+      originalRegion->getBlocks().splice(
+          std::next(originalRegion->begin(), action.originalPosition.position),
+          action.block->getParent()->getBlocks(), action.block);
+      break;
+    }
+    }
+  }
+  blockActions.resize(numActionsToKeep);
+}
+
+void ConversionPatternRewriterImpl::undoTypeConversions(
+    unsigned numActionsToKeep) {
+  for (auto &conversion : llvm::drop_begin(typeConversions, numActionsToKeep)) {
+    if (auto *region = conversion.object.dyn_cast<Region *>())
+      region->getContainingOp()->setAttrs(conversion.originalParentAttributes);
+    else
+      argConverter.discardPendingRewrites(conversion.object.get<Block *>());
+  }
+  typeConversions.resize(numActionsToKeep);
+}
+
+void ConversionPatternRewriterImpl::discardRewrites() {
+  undoTypeConversions();
+  undoBlockActions();
+
+  // Remove any newly created ops.
+  for (auto *op : createdOps) {
+    op->dropAllDefinedValueUses();
+    op->erase();
+  }
+}
+
+void ConversionPatternRewriterImpl::applyRewrites() {
+  // Apply all of the rewrites replacements requested during conversion.
+  for (auto &repl : replacements) {
+    for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i)
+      repl.op->getResult(i)->replaceAllUsesWith(
+          mapping.lookupOrDefault(repl.newValues[i]));
+
+    // If this operation defines any regions, drop any pending argument
+    // rewrites.
+    if (argConverter.typeConverter && repl.op->getNumRegions()) {
+      for (auto &region : repl.op->getRegions())
+        for (auto &block : region)
+          argConverter.cancelPendingRewrites(&block);
+    }
+  }
+
+  // In a second pass, erase all of the replaced operations in reverse. This
+  // allows processing nested operations before their parent region is
+  // destroyed.
+  for (auto &repl : llvm::reverse(replacements))
+    repl.op->erase();
+
+  argConverter.applyRewrites();
+}
+
+bool ConversionPatternRewriterImpl::hasSignatureBeenConverted(Block *block) {
+  return argConverter.hasBeenConverted(block);
+}
+
+LogicalResult
+ConversionPatternRewriterImpl::convertRegionSignature(Region &region) {
+  auto parentAttrs = region.getContainingOp()->getAttrList();
+  auto result = argConverter.convertSignature(region, mapping);
+  if (succeeded(result)) {
+    typeConversions.push_back(TypeConversion{&region, parentAttrs});
+    if (!region.empty())
+      typeConversions.push_back(
+          TypeConversion{&region.front(), NamedAttributeList()});
+  }
+  return result;
+}
+
+LogicalResult
+ConversionPatternRewriterImpl::convertBlockSignature(Block *block) {
+  auto result = argConverter.convertSignature(block, mapping);
+  if (succeeded(result))
+    typeConversions.push_back(TypeConversion{block, NamedAttributeList()});
+  return result;
+}
+
+void ConversionPatternRewriterImpl::replaceOp(
+    Operation *op, ArrayRef<Value *> newValues,
+    ArrayRef<Value *> valuesToRemoveIfDead) {
+  assert(newValues.size() == op->getNumResults());
+
+  // Create mappings for each of the new result values.
+  for (unsigned i = 0, e = newValues.size(); i < e; ++i) {
+    assert((newValues[i] || op->getResult(i)->use_empty()) &&
+           "result value has remaining uses that must be replaced");
+    if (newValues[i])
+      mapping.map(op->getResult(i), newValues[i]);
+  }
+
+  // Record the requested operation replacement.
+  replacements.emplace_back(op, newValues);
+}
+
+void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
+                                                     Block *continuation) {
+  BlockAction action;
+  action.kind = BlockActionKind::Split;
+  action.block = continuation;
+  action.originalBlock = block;
+  blockActions.push_back(action);
+}
+
+void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
+    Region &region, Region &parent, Region::iterator before) {
+  for (auto &pair : llvm::enumerate(region)) {
+    Block &block = pair.value();
+    unsigned position = pair.index();
+    BlockAction action;
+    action.kind = BlockActionKind::Move;
+    action.block = &block;
+    action.originalPosition = {&region, position};
+    blockActions.push_back(action);
+  }
+}
+
+void ConversionPatternRewriterImpl::remapValues(
+    Operation::operand_range operands, SmallVectorImpl<Value *> &remapped) {
+  remapped.reserve(llvm::size(operands));
+  for (Value *operand : operands)
+    remapped.push_back(mapping.lookupOrDefault(operand));
+}
+
+//===----------------------------------------------------------------------===//
+// ConversionPatternRewriter
+//===----------------------------------------------------------------------===//
+
+ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx,
+                                                     TypeConverter *converter)
+    : PatternRewriter(ctx),
+      impl(new detail::ConversionPatternRewriterImpl(*this, converter)) {}
+ConversionPatternRewriter::~ConversionPatternRewriter() {}
+
+/// PatternRewriter hook for replacing the results of an operation.
+void ConversionPatternRewriter::replaceOp(
+    Operation *op, ArrayRef<Value *> newValues,
+    ArrayRef<Value *> valuesToRemoveIfDead) {
+  impl->replaceOp(op, newValues, valuesToRemoveIfDead);
+}
+
+/// PatternRewriter hook for splitting a block into two parts.
+Block *ConversionPatternRewriter::splitBlock(Block *block,
+                                             Block::iterator before) {
+  auto *continuation = PatternRewriter::splitBlock(block, before);
+  impl->notifySplitBlock(block, continuation);
+  return continuation;
+}
+
+/// PatternRewriter hook for moving blocks out of a region.
+void ConversionPatternRewriter::inlineRegionBefore(Region &region,
+                                                   Region &parent,
+                                                   Region::iterator before) {
+  impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
+  PatternRewriter::inlineRegionBefore(region, parent, before);
+}
+
+/// PatternRewriter hook for creating a new operation.
+Operation *
+ConversionPatternRewriter::createOperation(const OperationState &state) {
+  auto *result = OpBuilder::createOperation(state);
+  impl->createdOps.push_back(result);
+  return result;
+}
+
+/// PatternRewriter hook for updating the root operation in-place.
+void ConversionPatternRewriter::notifyRootUpdated(Operation *op) {
+  // The rewriter caches changes to the IR to allow for operating in-place and
+  // backtracking. The rewriter is currently not capable of backtracking
+  // in-place modifications.
+  llvm_unreachable("in-place operation updates are not supported");
+}
+
+/// Return a reference to the internal implementation.
+detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
+  return *impl;
+}
 
 //===----------------------------------------------------------------------===//
 // Conversion Patterns
@@ -592,12 +664,12 @@ PatternMatchResult
 ConversionPattern::matchAndRewrite(Operation *op,
                                    PatternRewriter &rewriter) const {
   SmallVector<Value *, 4> operands;
-  auto &dialectRewriter = static_cast<DialectConversionRewriter &>(rewriter);
-  dialectRewriter.remapValues(op->getOperands(), operands);
+  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
+  dialectRewriter.getImpl().remapValues(op->getOperands(), operands);
 
   // If this operation has no successors, invoke the rewrite directly.
   if (op->getNumSuccessors() == 0)
-    return matchAndRewrite(op, operands, rewriter);
+    return matchAndRewrite(op, operands, dialectRewriter);
 
   // Otherwise, we need to remap the successors.
   SmallVector<Block *, 2> destinations;
@@ -620,7 +692,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
       op,
       llvm::makeArrayRef(operands.data(),
                          operands.data() + firstSuccessorOperand),
-      destinations, operandsPerDestination, rewriter);
+      destinations, operandsPerDestination, dialectRewriter);
 }
 
 //===----------------------------------------------------------------------===//
@@ -648,13 +720,13 @@ public:
 
   /// Attempt to legalize the given operation. Returns success if the operation
   /// was legalized, failure otherwise.
-  LogicalResult legalize(Operation *op, DialectConversionRewriter &rewriter);
+  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
 
 private:
   /// Attempt to legalize the given operation by applying the provided pattern.
   /// Returns success if the operation was legalized, failure otherwise.
   LogicalResult legalizePattern(Operation *op, RewritePattern *pattern,
-                                DialectConversionRewriter &rewriter);
+                                ConversionPatternRewriter &rewriter);
 
   /// Build an optimistic legalization graph given the provided patterns. This
   /// function populates 'legalizerPatterns' with the operations that are not
@@ -693,15 +765,16 @@ bool OperationLegalizer::isIllegal(Operation *op) const {
 
 LogicalResult
 OperationLegalizer::legalize(Operation *op,
-                             DialectConversionRewriter &rewriter) {
+                             ConversionPatternRewriter &rewriter) {
   // Make sure that the signature of the parent block of this operation has been
   // converted.
-  if (rewriter.argConverter.typeConverter) {
+  auto &rewriterImpl = rewriter.getImpl();
+  if (rewriterImpl.argConverter.typeConverter) {
     auto *block = op->getBlock();
-    if (block && !rewriter.hasSignatureBeenConverted(block)) {
+    if (block && !rewriterImpl.hasSignatureBeenConverted(block)) {
       if (failed(block->isEntryBlock()
-                     ? rewriter.convertRegionSignature(*block->getParent())
-                     : rewriter.convertBlockSignature(block)))
+                     ? rewriterImpl.convertRegionSignature(*block->getParent())
+                     : rewriterImpl.convertBlockSignature(block)))
         return failure();
     }
   }
@@ -743,7 +816,7 @@ OperationLegalizer::legalize(Operation *op,
 
 LogicalResult
 OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
-                                    DialectConversionRewriter &rewriter) {
+                                    ConversionPatternRewriter &rewriter) {
   LLVM_DEBUG({
     llvm::dbgs() << "-* Applying rewrite pattern '" << op->getName() << " -> (";
     interleaveComma(pattern->getGeneratedOps(), llvm::dbgs());
@@ -759,10 +832,11 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
     return failure();
   }
 
-  RewriterState curState = rewriter.getCurrentState();
+  auto &rewriterImpl = rewriter.getImpl();
+  RewriterState curState = rewriterImpl.getCurrentState();
   auto cleanupFailure = [&] {
     // Reset the rewriter state and pop this pattern.
-    rewriter.resetState(curState);
+    rewriterImpl.resetState(curState);
     appliedPatterns.erase(pattern);
     return failure();
   };
@@ -776,9 +850,9 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
 
   // Recursively legalize each of the new operations.
   for (unsigned i = curState.numCreatedOperations,
-                e = rewriter.createdOps.size();
+                e = rewriterImpl.createdOps.size();
        i != e; ++i) {
-    if (failed(legalize(rewriter.createdOps[i], rewriter))) {
+    if (failed(legalize(rewriterImpl.createdOps[i], rewriter))) {
       LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n");
       return cleanupFailure();
     }
@@ -941,7 +1015,7 @@ struct OperationConverter {
 
 private:
   /// Converts an operation with the given rewriter.
-  LogicalResult convert(DialectConversionRewriter &rewriter, Operation *op);
+  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
 
   /// Recursively collect all of the operations, to convert from within
   /// 'region'.
@@ -991,7 +1065,7 @@ OperationConverter::computeConversionSet(Region &region,
 }
 
 /// Converts an operation with the given rewriter.
-LogicalResult OperationConverter::convert(DialectConversionRewriter &rewriter,
+LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
                                           Operation *op) {
   // Legalize the given operation.
   if (failed(opLegalizer.legalize(op, rewriter))) {
@@ -1013,16 +1087,17 @@ LogicalResult OperationConverter::convert(DialectConversionRewriter &rewriter,
   // within.
   // FIXME(riverriddle) This should be replaced by patterns when the pattern
   // rewriter exposes functionality to remap region signatures.
-  if (rewriter.argConverter.typeConverter) {
+  auto &rewriterImpl = rewriter.getImpl();
+  if (rewriterImpl.argConverter.typeConverter) {
     for (auto &region : op->getRegions())
-      if (region.empty() && failed(rewriter.convertRegionSignature(region)))
+      if (region.empty() && failed(rewriterImpl.convertRegionSignature(region)))
         return failure();
   }
 
   return success();
 }
 
-/// Converts the given top-level operation to the conversion target.
+/// Converts the given operations to the conversion target.
 LogicalResult
 OperationConverter::convertOperations(ArrayRef<Operation *> ops,
                                       TypeConverter *typeConverter) {
@@ -1039,16 +1114,16 @@ OperationConverter::convertOperations(ArrayRef<Operation *> ops,
   }
 
   // Convert each operation and discard rewrites on failure.
-  DialectConversionRewriter rewriter(ops.front()->getContext(), typeConverter);
+  ConversionPatternRewriter rewriter(ops.front()->getContext(), typeConverter);
   for (auto *op : toConvert) {
     if (failed(convert(rewriter, op))) {
-      rewriter.discardRewrites();
+      rewriter.getImpl().discardRewrites();
       return failure();
     }
   }
 
   // Otherwise the body conversion succeeded, so apply all rewrites.
-  rewriter.applyRewrites();
+  rewriter.getImpl().applyRewrites();
   return success();
 }
 
index 6b1266e..410536c 100644 (file)
@@ -62,8 +62,9 @@ struct TestRegionRewriteBlockMovement : public ConversionPattern {
   TestRegionRewriteBlockMovement(MLIRContext *ctx)
       : ConversionPattern("test.region", 1, ctx) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const final {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const final {
     // Inline this region into the parent region.
     auto &parentRegion = *op->getContainingRegion();
     rewriter.inlineRegionBefore(op->getRegion(0), parentRegion,
@@ -101,8 +102,9 @@ struct TestRegionRewriteUndo : public RewritePattern {
 /// This pattern simply erases the given operation.
 struct TestDropOp : public ConversionPattern {
   TestDropOp(MLIRContext *ctx) : ConversionPattern("test.drop_op", 1, ctx) {}
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const final {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const final {
     rewriter.replaceOp(op, llvm::None);
     return matchSuccess();
   }
@@ -111,8 +113,9 @@ struct TestDropOp : public ConversionPattern {
 struct TestPassthroughInvalidOp : public ConversionPattern {
   TestPassthroughInvalidOp(MLIRContext *ctx)
       : ConversionPattern("test.invalid", 1, ctx) {}
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const final {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const final {
     rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
                                              llvm::None);
     return matchSuccess();
@@ -122,8 +125,9 @@ struct TestPassthroughInvalidOp : public ConversionPattern {
 struct TestSplitReturnType : public ConversionPattern {
   TestSplitReturnType(MLIRContext *ctx)
       : ConversionPattern("test.return", 1, ctx) {}
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const final {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const final {
     // Check for a return of F32.
     if (op->getNumOperands() != 1 || !op->getOperand(0)->getType().isF32())
       return matchFailure();
index b76a565..0077e95 100644 (file)
@@ -115,8 +115,9 @@ public:
                        lowering_.getDialect()->getContext(), lowering_) {}
 
   // Convert the kernel arguments to an LLVM type, preserve the rest.
-  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                                     PatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     rewriter.clone(*op)->setOperands(operands);
     return rewriter.replaceOp(op, llvm::None), matchSuccess();
   }