Add support for matchAndRewrite to the DialectConversion patterns. This also drops...
authorRiver Riddle <riverriddle@google.com>
Thu, 6 Jun 2019 22:38:08 +0000 (15:38 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 9 Jun 2019 23:21:20 +0000 (16:21 -0700)
PiperOrigin-RevId: 251941625

mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp
mlir/examples/toy/Ch5/mlir/LateLowering.cpp
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Transforms/DialectConversion.cpp

index d13f7f3..f090856 100644 (file)
@@ -135,8 +135,8 @@ public:
   explicit RangeOpConversion(MLIRContext *context)
       : ConversionPattern(linalg::RangeOp::getOperationName(), 1, context) {}
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     auto rangeOp = cast<linalg::RangeOp>(op);
     auto rangeDescriptorType =
         linalg::convertLinalgType(rangeOp.getResult()->getType());
@@ -153,6 +153,7 @@ public:
     rangeDescriptor = insertvalue(rangeDescriptorType, rangeDescriptor,
                                   operands[2], makePositionAttr(rewriter, 2));
     rewriter.replaceOp(op, rangeDescriptor);
+    return matchSuccess();
   }
 };
 
@@ -161,8 +162,8 @@ public:
   explicit ViewOpConversion(MLIRContext *context)
       : ConversionPattern(linalg::ViewOp::getOperationName(), 1, context) {}
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     auto viewOp = cast<linalg::ViewOp>(op);
     auto viewDescriptorType = linalg::convertLinalgType(viewOp.getViewType());
     auto memrefType =
@@ -277,6 +278,7 @@ public:
     }
 
     rewriter.replaceOp(op, viewDescriptor);
+    return matchSuccess();
   }
 };
 
@@ -285,8 +287,8 @@ public:
   explicit SliceOpConversion(MLIRContext *context)
       : ConversionPattern(linalg::SliceOp::getOperationName(), 1, context) {}
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     auto sliceOp = cast<linalg::SliceOp>(op);
     auto newViewDescriptorType =
         linalg::convertLinalgType(sliceOp.getViewType());
@@ -366,6 +368,7 @@ public:
     }
 
     rewriter.replaceOp(op, newViewDescriptor);
+    return matchSuccess();
   }
 };
 
@@ -376,9 +379,10 @@ public:
   explicit DropConsumer(MLIRContext *context)
       : ConversionPattern("some_consumer", 1, context) {}
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     rewriter.replaceOp(op, llvm::None);
+    return matchSuccess();
   }
 };
 
index ef0d858..26d6af8 100644 (file)
@@ -95,8 +95,8 @@ public:
 // an LLVM IR load.
 class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
   using Base::Base;
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     edsc::ScopedContext edscContext(rewriter, op->getLoc());
     auto elementType = linalg::convertLinalgType(*op->result_type_begin());
     Value *viewDescriptor = operands[0];
@@ -104,6 +104,7 @@ class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
     Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
     Value *element = intrinsics::load(elementType, ptr);
     rewriter.replaceOp(op, {element});
+    return matchSuccess();
   }
 };
 
@@ -111,8 +112,8 @@ class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
 // an LLVM IR store.
 class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
   using Base::Base;
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     edsc::ScopedContext edscContext(rewriter, op->getLoc());
     Value *viewDescriptor = operands[1];
     Value *data = operands[0];
@@ -120,6 +121,7 @@ class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
     Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
     intrinsics::store(data, ptr);
     rewriter.replaceOp(op, llvm::None);
+    return matchSuccess();
   }
 };
 
index 189add0..82541f8 100644 (file)
@@ -87,8 +87,8 @@ public:
   explicit MulOpConversion(MLIRContext *context)
       : ConversionPattern(toy::MulOp::getOperationName(), 1, context) {}
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     using namespace edsc;
     using intrinsics::constant_index;
     using linalg::intrinsics::range;
@@ -117,6 +117,7 @@ public:
     auto resultView = view(result, {r0, r2});
     rewriter.create<linalg::MatmulOp>(loc, lhsView, rhsView, resultView);
     rewriter.replaceOp(op, {typeCast(rewriter, result, mul.getType())});
+    return matchSuccess();
   }
 };
 
index ecf6c9d..4434e1b 100644 (file)
@@ -92,8 +92,8 @@ public:
   /// the rewritten operands for `op` in the new function.
   /// The results created by the new IR with the builder are returned, and their
   /// number must match the number of result of `op`.
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     auto add = cast<toy::AddOp>(op);
     auto loc = add.getLoc();
     // Create a `toy.alloc` operation to allocate the output buffer for this op.
@@ -122,6 +122,7 @@ public:
     // Return the newly allocated buffer, with a type.cast to preserve the
     // consumers.
     rewriter.replaceOp(op, {typeCast(rewriter, result, add.getType())});
+    return matchSuccess();
   }
 };
 
@@ -132,8 +133,8 @@ public:
   explicit PrintOpConversion(MLIRContext *context)
       : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     // Get or create the declaration of the printf function in the module.
     Function *printfFunc = getPrintf(*op->getFunction()->getModule());
 
@@ -178,6 +179,7 @@ public:
       // clang-format on
     }
     rewriter.replaceOp(op, llvm::None);
+    return matchSuccess();
   }
 
 private:
@@ -230,8 +232,8 @@ public:
   explicit ConstantOpConversion(MLIRContext *context)
       : ConversionPattern(toy::ConstantOp::getOperationName(), 1, context) {}
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     toy::ConstantOp cstOp = cast<toy::ConstantOp>(op);
     auto loc = cstOp.getLoc();
     auto retTy = cstOp.getResult()->getType().cast<toy::ToyArrayType>();
@@ -264,6 +266,7 @@ public:
       }
     }
     rewriter.replaceOp(op, result);
+    return matchSuccess();
   }
 };
 
@@ -273,8 +276,8 @@ public:
   explicit TransposeOpConversion(MLIRContext *context)
       : ConversionPattern(toy::TransposeOp::getOperationName(), 1, context) {}
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     auto transpose = cast<toy::TransposeOp>(op);
     auto loc = transpose.getLoc();
     Value *result = memRefTypeCast(
@@ -296,6 +299,7 @@ public:
     // clang-format on
 
     rewriter.replaceOp(op, {typeCast(rewriter, result, transpose.getType())});
+    return matchSuccess();
   }
 };
 
@@ -305,13 +309,14 @@ public:
   explicit ReturnOpConversion(MLIRContext *context)
       : ConversionPattern(toy::ReturnOp::getOperationName(), 1, context) {}
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     // Argument is optional, handle both cases.
     if (op->getNumOperands())
       rewriter.replaceOpWithNewOp<ReturnOp>(op, operands[0]);
     else
       rewriter.replaceOpWithNewOp<ReturnOp>(op);
+    return matchSuccess();
   }
 };
 
index 8b476c0..af08a1f 100644 (file)
@@ -48,13 +48,6 @@ public:
                     MLIRContext *ctx)
       : RewritePattern(rootName, benefit, ctx) {}
 
-  /// Hook for derived classes to implement matching. Dialect conversion
-  /// generally unconditionally match the root operation, so default to success
-  /// here.
-  virtual PatternMatchResult match(Operation *op) const override {
-    return matchSuccess();
-  }
-
   /// Hook for derived classes to implement rewriting. `op` is the (first)
   /// operation matched by the pattern, `operands` is a list of rewritten values
   /// that are passed to this operation, `rewriter` can be used to emit the new
@@ -84,14 +77,33 @@ public:
     llvm_unreachable("unimplemented rewrite for terminators");
   }
 
-  /// Rewrite the IR rooted at the specified operation with the result of
-  /// this pattern. If an unexpected error is encountered (an internal compiler
-  /// error), it is emitted through the normal MLIR diagnostic hooks and the IR
-  /// is left in a valid state.
-  void rewrite(Operation *op, PatternRewriter &rewriter) const final;
+  /// Hook for derived classes to implement combined matching and rewriting.
+  virtual PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> properOperands,
+                  ArrayRef<Block *> destinations,
+                  ArrayRef<ArrayRef<Value *>> operands,
+                  PatternRewriter &rewriter) const {
+    if (!match(op))
+      return matchFailure();
+    rewrite(op, properOperands, destinations, operands, rewriter);
+    return matchSuccess();
+  }
+
+  /// Hook for derived classes to implement combined matching and rewriting.
+  virtual PatternMatchResult matchAndRewrite(Operation *op,
+                                             ArrayRef<Value *> operands,
+                                             PatternRewriter &rewriter) const {
+    if (!match(op))
+      return matchFailure();
+    rewrite(op, operands, rewriter);
+    return matchSuccess();
+  }
+
+  /// Attempt to match and rewrite the IR root at the specified operation.
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const final;
 
 private:
-  using RewritePattern::matchAndRewrite;
   using RewritePattern::rewrite;
 };
 
index 1b50320..dd91f06 100644 (file)
@@ -210,10 +210,6 @@ public:
                        lowering_),
         dialect(dialect_) {}
 
-  PatternMatchResult match(Operation *op) const override {
-    return this->matchSuccess();
-  }
-
   // Get the LLVM IR dialect.
   LLVM::LLVMDialect &getDialect() const { return dialect; }
   // Get the LLVM context.
@@ -279,8 +275,8 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
 
   // Convert the type of the result to an LLVM type, pass operands as is,
   // preserve attributes.
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     unsigned numResults = op->getNumResults();
 
     Type packedType;
@@ -296,9 +292,10 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
 
     // If the operation produced 0 or 1 result, return them immediately.
     if (numResults == 0)
-      return rewriter.replaceOp(op, llvm::None);
+      return rewriter.replaceOp(op, llvm::None), this->matchSuccess();
     if (numResults == 1)
-      return rewriter.replaceOp(op, newOp.getOperation()->getResult(0));
+      return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)),
+             this->matchSuccess();
 
     // Otherwise, it had been converted to an operation producing a structure.
     // Extract individual results from the structure and return them as list.
@@ -311,6 +308,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
           this->getIntegerArrayAttr(rewriter, i)));
     }
     rewriter.replaceOp(op, results);
+    return this->matchSuccess();
   }
 };
 
@@ -500,8 +498,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
 struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
   using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     assert(operands.size() == 1 && "dealloc takes one operand");
     OperandAdaptor<DeallocOp> transformed(operands);
 
@@ -524,6 +522,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
         op->getLoc(), getVoidPtrType(), bufferPtr);
     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
         op, ArrayRef<Type>(), rewriter.getFunctionAttr(freeFunc), casted);
+    return matchSuccess();
   }
 };
 
@@ -759,8 +758,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
 struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
   using Base::Base;
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     auto loadOp = cast<LoadOp>(op);
     OperandAdaptor<LoadOp> transformed(operands);
     auto type = loadOp.getMemRefType();
@@ -771,6 +770,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
 
     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, elementType,
                                               ArrayRef<Value *>{dataPtr});
+    return matchSuccess();
   }
 };
 
@@ -779,8 +779,8 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
 struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
   using Base::Base;
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     auto type = cast<StoreOp>(op).getMemRefType();
     OperandAdaptor<StoreOp> transformed(operands);
 
@@ -788,6 +788,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
                                 transformed.indices(), rewriter, getModule());
     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
                                                dataPtr);
+    return matchSuccess();
   }
 };
 
@@ -798,12 +799,14 @@ struct OneToOneLLVMTerminatorLowering
   using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
   using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
 
-  void rewrite(Operation *op, ArrayRef<Value *> properOperands,
-               ArrayRef<Block *> destinations,
-               ArrayRef<ArrayRef<Value *>> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     ArrayRef<Value *> properOperands,
+                                     ArrayRef<Block *> destinations,
+                                     ArrayRef<ArrayRef<Value *>> operands,
+                                     PatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<TargetOp>(op, properOperands, destinations,
                                           operands, op->getAttrs());
+    return this->matchSuccess();
   }
 };
 
@@ -816,21 +819,23 @@ struct OneToOneLLVMTerminatorLowering
 struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
   using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern;
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     unsigned numArguments = op->getNumOperands();
 
     // If ReturnOp has 0 or 1 operand, create it and return immediately.
     if (numArguments == 0) {
-      return rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
+      rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
           op, llvm::ArrayRef<Value *>(), llvm::ArrayRef<Block *>(),
           llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs());
+      return matchSuccess();
     }
     if (numArguments == 1) {
-      return rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
+      rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
           op, llvm::ArrayRef<Value *>(operands.front()),
           llvm::ArrayRef<Block *>(), llvm::ArrayRef<llvm::ArrayRef<Value *>>(),
           op->getAttrs());
+      return matchSuccess();
     }
 
     // Otherwise, we need to pack the arguments into an LLVM struct type before
@@ -847,6 +852,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
         op, llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(),
         llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs());
+    return matchSuccess();
   }
 };
 
index b3857ac..60c16d9 100644 (file)
@@ -159,8 +159,8 @@ public:
                                    LLVMTypeConverter &lowering_)
       : LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {}
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     auto indexType = IndexType::get(op->getContext());
     auto voidPtrTy =
         LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
@@ -204,6 +204,7 @@ public:
     desc = insertvalue(bufferDescriptorType, desc, size,
                        positionAttr(rewriter, 1));
     rewriter.replaceOp(op, desc);
+    return matchSuccess();
   }
 };
 
@@ -215,8 +216,8 @@ public:
       : LLVMOpLowering(BufferDeallocOp::getOperationName(), context,
                        lowering_) {}
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     auto voidPtrTy =
         LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
     // Insert the `free` declaration if it is not already present.
@@ -239,6 +240,7 @@ public:
                                                     positionAttr(rewriter, 0)));
     call(ArrayRef<Type>(), rewriter.getFunctionAttr(freeFunc), casted);
     rewriter.replaceOp(op, llvm::None);
+    return matchSuccess();
   }
 };
 
@@ -248,12 +250,13 @@ public:
   BufferSizeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
       : LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {}
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     auto int64Ty = lowering.convertType(operands[0]->getType());
     edsc::ScopedContext context(rewriter, op->getLoc());
     rewriter.replaceOp(
         op, {extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1))});
+    return matchSuccess();
   }
 };
 
@@ -263,8 +266,8 @@ public:
   explicit DimOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
       : LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {}
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     auto dimOp = cast<linalg::DimOp>(op);
     auto indexTy = lowering.convertType(rewriter.getIndexType());
     edsc::ScopedContext context(rewriter, op->getLoc());
@@ -273,6 +276,7 @@ public:
         {extractvalue(
             indexTy, operands[0],
             positionAttr(rewriter, {2, static_cast<int>(dimOp.getIndex())}))});
+    return matchSuccess();
   }
 };
 
@@ -318,14 +322,15 @@ public:
 // an LLVM IR load.
 class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
   using Base::Base;
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     edsc::ScopedContext edscContext(rewriter, op->getLoc());
     auto elementTy = lowering.convertType(*op->result_type_begin());
     Value *viewDescriptor = operands[0];
     ArrayRef<Value *> indices = operands.drop_front();
     auto ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
     rewriter.replaceOp(op, {llvm_load(elementTy, ptr)});
+    return matchSuccess();
   }
 };
 
@@ -335,8 +340,8 @@ public:
   explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
       : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     auto rangeOp = cast<RangeOp>(op);
     auto rangeDescriptorTy =
         convertLinalgType(rangeOp.getResult()->getType(), lowering);
@@ -352,6 +357,7 @@ public:
     desc = insertvalue(rangeDescriptorTy, desc, operands[2],
                        positionAttr(rewriter, 2));
     rewriter.replaceOp(op, desc);
+    return matchSuccess();
   }
 };
 
@@ -363,8 +369,8 @@ public:
       : LLVMOpLowering(RangeIntersectOp::getOperationName(), context,
                        lowering_) {}
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     auto rangeIntersectOp = cast<RangeIntersectOp>(op);
     auto rangeDescriptorTy =
         convertLinalgType(rangeIntersectOp.getResult()->getType(), lowering);
@@ -397,6 +403,7 @@ public:
     desc = insertvalue(rangeDescriptorTy, desc, mul(step1, step2),
                        positionAttr(rewriter, 2));
     rewriter.replaceOp(op, desc);
+    return matchSuccess();
   }
 };
 
@@ -405,8 +412,8 @@ public:
   explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
       : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     auto sliceOp = cast<SliceOp>(op);
     auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering);
     auto viewType = sliceOp.getBaseViewType();
@@ -477,6 +484,7 @@ public:
     }
 
     rewriter.replaceOp(op, desc);
+    return matchSuccess();
   }
 };
 
@@ -484,8 +492,8 @@ public:
 // an LLVM IR store.
 class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
   using Base::Base;
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     edsc::ScopedContext edscContext(rewriter, op->getLoc());
     Value *data = operands[0];
     Value *viewDescriptor = operands[1];
@@ -493,6 +501,7 @@ class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
     Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
     llvm_store(data, ptr);
     rewriter.replaceOp(op, llvm::None);
+    return matchSuccess();
   }
 };
 
@@ -501,8 +510,8 @@ public:
   explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
       : LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {}
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     auto viewOp = cast<ViewOp>(op);
     auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
     auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering);
@@ -548,6 +557,7 @@ public:
     }
 
     rewriter.replaceOp(op, desc);
+    return matchSuccess();
   }
 };
 
@@ -560,20 +570,21 @@ public:
 
   static StringRef libraryFunctionName() { return "linalg_dot"; }
 
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               PatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                     PatternRewriter &rewriter) const override {
     auto *f =
         op->getFunction()->getModule()->getNamedFunction(libraryFunctionName());
     if (!f) {
       op->emitError("Could not find function: " + libraryFunctionName() +
                     "in lowering to LLVM ");
-      return;
+      return matchFailure();
     }
 
     auto fAttr = rewriter.getFunctionAttr(f);
     auto named = rewriter.getNamedAttr("callee", fAttr);
     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands,
                                               ArrayRef<NamedAttribute>{named});
+    return matchSuccess();
   }
 };
 
index 6647434..00ae605 100644 (file)
@@ -226,19 +226,17 @@ struct DialectConversionRewriter final : public PatternRewriter {
 // ConversionPattern
 //===----------------------------------------------------------------------===//
 
-/// Rewrite the IR rooted at the specified operation with the result of this
-/// pattern.  If an unexpected error is encountered (an internal compiler
-/// error), it is emitted through the normal MLIR diagnostic hooks and the IR is
-/// left in a valid state.
-void ConversionPattern::rewrite(Operation *op,
-                                PatternRewriter &rewriter) const {
+/// Attempt to match and rewrite the IR root at the specified operation.
+PatternMatchResult
+ConversionPattern::matchAndRewrite(Operation *op,
+                                   PatternRewriter &rewriter) const {
   SmallVector<Value *, 4> operands;
   auto &dialectRewriter = static_cast<DialectConversionRewriter &>(rewriter);
   dialectRewriter.remapValues(op->getOperands(), operands);
 
   // If this operation has no successors, invoke the rewrite directly.
   if (op->getNumSuccessors() == 0)
-    return rewrite(op, operands, rewriter);
+    return matchAndRewrite(op, operands, rewriter);
 
   // Otherwise, we need to remap the successors.
   SmallVector<Block *, 2> destinations;
@@ -257,10 +255,11 @@ void ConversionPattern::rewrite(Operation *op,
   }
 
   // Rewrite the operation.
-  rewrite(op,
-          llvm::makeArrayRef(operands.data(),
-                             operands.data() + firstSuccessorOperand),
-          destinations, operandsPerDestination, rewriter);
+  return matchAndRewrite(
+      op,
+      llvm::makeArrayRef(operands.data(),
+                         operands.data() + firstSuccessorOperand),
+      destinations, operandsPerDestination, rewriter);
 }
 
 //===----------------------------------------------------------------------===//