Add a templated wrapper around RewritePattern that allows for defining match...
authorRiver Riddle <riverriddle@google.com>
Sun, 26 May 2019 00:22:27 +0000 (17:22 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:03:22 +0000 (20:03 -0700)
--

PiperOrigin-RevId: 250003405

mlir/examples/Linalg/Linalg3/lib/Transforms.cpp
mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/AffineOps/AffineOps.cpp
mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
mlir/lib/StandardOps/Ops.cpp

index 621bc26..0fe70e2 100644 (file)
@@ -248,12 +248,12 @@ namespace {
 /// mlir::StoreOp requires finding the proper indexing in the supporting MemRef.
 /// This is most easily achieved by calling emitAndReturnFullyComposedView to
 /// fold away all the SliceOp.
-template <typename LoadOrStoreOpTy> struct Rewriter : public RewritePattern {
-  explicit Rewriter(MLIRContext *context)
-      : RewritePattern(LoadOrStoreOpTy::getOperationName(), 1, context) {}
+template <typename LoadOrStoreOpTy>
+struct Rewriter : public OpRewritePattern<LoadOrStoreOpTy> {
+  using OpRewritePattern<LoadOrStoreOpTy>::OpRewritePattern;
 
   /// Performs the rewrite.
-  PatternMatchResult matchAndRewrite(Operation *op,
+  PatternMatchResult matchAndRewrite(LoadOrStoreOpTy op,
                                      PatternRewriter &rewriter) const override;
 };
 
@@ -270,9 +270,8 @@ struct LowerLinalgLoadStorePass
 
 template <>
 PatternMatchResult
-Rewriter<linalg::LoadOp>::matchAndRewrite(Operation *op,
+Rewriter<linalg::LoadOp>::matchAndRewrite(linalg::LoadOp load,
                                           PatternRewriter &rewriter) const {
-  auto load = cast<linalg::LoadOp>(op);
   SliceOp slice = dyn_cast<SliceOp>(load.getView()->getDefiningOp());
   ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
                       : cast<ViewOp>(load.getView()->getDefiningOp());
@@ -280,15 +279,14 @@ Rewriter<linalg::LoadOp>::matchAndRewrite(Operation *op,
   ScopedContext scope(builder, load.getLoc());
   auto *memRef = view.getSupportingMemRef();
   auto operands = emitAndReturnLoadStoreOperands(load, view);
-  rewriter.replaceOpWithNewOp<mlir::LoadOp>(op, memRef, operands);
+  rewriter.replaceOpWithNewOp<mlir::LoadOp>(load, memRef, operands);
   return matchSuccess();
 }
 
 template <>
 PatternMatchResult
-Rewriter<linalg::StoreOp>::matchAndRewrite(Operation *op,
+Rewriter<linalg::StoreOp>::matchAndRewrite(linalg::StoreOp store,
                                            PatternRewriter &rewriter) const {
-  auto store = cast<linalg::StoreOp>(op);
   SliceOp slice = dyn_cast<SliceOp>(store.getView()->getDefiningOp());
   ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
                       : cast<ViewOp>(store.getView()->getDefiningOp());
@@ -297,7 +295,7 @@ Rewriter<linalg::StoreOp>::matchAndRewrite(Operation *op,
   auto *valueToStore = store.getValueToStore();
   auto *memRef = view.getSupportingMemRef();
   auto operands = emitAndReturnLoadStoreOperands(store, view);
-  rewriter.replaceOpWithNewOp<mlir::StoreOp>(op, valueToStore, memRef,
+  rewriter.replaceOpWithNewOp<mlir::StoreOp>(store, valueToStore, memRef,
                                              operands);
   return matchSuccess();
 }
index 37aa47f..8baa45c 100644 (file)
@@ -33,25 +33,21 @@ namespace toy {
 namespace {
 
 /// Fold transpose(transpose(x) -> transpose(x)
-struct SimplifyRedundantTranspose : public mlir::RewritePattern {
+struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
   /// We register this pattern to match every toy.transpose in the IR.
   /// The "benefit" is used by the framework to order the patterns and process
   /// them in order of profitability.
   SimplifyRedundantTranspose(mlir::MLIRContext *context)
-      : RewritePattern(TransposeOp::getOperationName(), /* benefit = */ 1,
-                       context) {}
+      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
 
   /// This method is attempting to match a pattern and rewrite it. The rewriter
   /// argument is the orchestrator of the sequence of rewrites. It is expected
   /// to interact with it to perform any changes to the IR from here.
   mlir::PatternMatchResult
-  matchAndRewrite(mlir::Operation *op,
+  matchAndRewrite(TransposeOp op,
                   mlir::PatternRewriter &rewriter) const override {
-    // We can directly cast the current operation as this will only get invoked
-    // on TransposeOp.
-    TransposeOp transpose = llvm::cast<TransposeOp>(op);
     // Look through the input of the current transpose.
-    mlir::Value *transposeInput = transpose.getOperand();
+    mlir::Value *transposeInput = op.getOperand();
     TransposeOp transposeInputOp =
         llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
     // If the input is defined by another Transpose, bingo!
@@ -65,15 +61,12 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern {
 };
 
 /// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place.
-struct SimplifyReshapeConstant : public mlir::RewritePattern {
-  SimplifyReshapeConstant(mlir::MLIRContext *context)
-      : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
-                       context) {}
+struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
+  using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
 
   mlir::PatternMatchResult
-  matchAndRewrite(mlir::Operation *op,
+  matchAndRewrite(ReshapeOp reshape,
                   mlir::PatternRewriter &rewriter) const override {
-    ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
     // Look through the input of the current reshape.
     ConstantOp constantOp = llvm::dyn_cast_or_null<ConstantOp>(
         reshape.getOperand()->getDefiningOp());
@@ -81,7 +74,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
     if (!constantOp)
       return matchFailure();
 
-    auto reshapeType = op->getResult(0)->getType().cast<ToyArrayType>();
+    auto reshapeType = reshape.getType().cast<ToyArrayType>();
     if (auto valueAttr =
             constantOp.getAttrOfType<mlir::DenseElementsAttr>("value")) {
       // FIXME Check matching of element count!
@@ -90,7 +83,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
           reshapeType.getShape(), valueAttr.getType().getElementType());
       auto newAttr =
           mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
-      rewriter.replaceOpWithNewOp<ConstantOp>(op, reshapeType.getShape(),
+      rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
                                               newAttr);
     } else if (auto valueAttr =
                    constantOp.getAttrOfType<mlir::FloatAttr>("value")) {
@@ -102,7 +95,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
       auto tensorTy = rewriter.getTensorType(reshapeType.getShape(),
                                              reshapeType.getElementType());
       auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data);
-      rewriter.replaceOpWithNewOp<ConstantOp>(op, reshapeType.getShape(),
+      rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
                                               newAttr);
     } else {
       llvm_unreachable("Unsupported Constant format");
@@ -112,17 +105,15 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
 };
 
 /// Fold reshape(reshape(x)) -> reshape(x)
-struct SimplifyReshapeReshape : public mlir::RewritePattern {
-  SimplifyReshapeReshape(mlir::MLIRContext *context)
-      : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
-                       context) {}
+struct SimplifyReshapeReshape : public mlir::OpRewritePattern<ReshapeOp> {
+  using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
 
   mlir::PatternMatchResult
-  matchAndRewrite(mlir::Operation *op,
+  matchAndRewrite(ReshapeOp op,
                   mlir::PatternRewriter &rewriter) const override {
-    ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
     // Look through the input of the current reshape.
-    mlir::Value *reshapeInput = reshape.getOperand();
+    mlir::Value *reshapeInput = op.getOperand();
+
     // If the input is defined by another reshape, bingo!
     if (!matchPattern(reshapeInput, mlir::m_Op<ReshapeOp>()))
       return matchFailure();
@@ -134,18 +125,15 @@ struct SimplifyReshapeReshape : public mlir::RewritePattern {
 };
 
 /// Fold reshape(x)) -> x, when input type matches output type
-struct SimplifyNullReshape : public mlir::RewritePattern {
-  SimplifyNullReshape(mlir::MLIRContext *context)
-      : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
-                       context) {}
+struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> {
+  using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
 
   mlir::PatternMatchResult
-  matchAndRewrite(mlir::Operation *op,
+  matchAndRewrite(ReshapeOp op,
                   mlir::PatternRewriter &rewriter) const override {
-    ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
-    if (reshape.getOperand()->getType() != reshape.getResult()->getType())
+    if (op.getOperand()->getType() != op.getType())
       return matchFailure();
-    rewriter.replaceOp(reshape, {reshape.getOperand()});
+    rewriter.replaceOp(op, {op.getOperand()});
     return matchSuccess();
   }
 };
index 6e05eaf..64bd2c9 100644 (file)
@@ -22,7 +22,7 @@
 
 #include "toy/Dialect.h"
 
-#include "mlir/IR/Operation.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
 
@@ -32,30 +32,26 @@ namespace toy {
 
 namespace {
 
-/// Fold transpose(transpose(x)) -> transpose(x)
-struct SimplifyRedundantTranspose : public mlir::RewritePattern {
+/// Fold transpose(transpose(x) -> transpose(x)
+struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
   /// We register this pattern to match every toy.transpose in the IR.
   /// The "benefit" is used by the framework to order the patterns and process
   /// them in order of profitability.
   SimplifyRedundantTranspose(mlir::MLIRContext *context)
-      : RewritePattern(TransposeOp::getOperationName(), /* benefit = */ 1,
-                       context) {}
+      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
 
   /// This method is attempting to match a pattern and rewrite it. The rewriter
   /// argument is the orchestrator of the sequence of rewrites. It is expected
   /// to interact with it to perform any changes to the IR from here.
   mlir::PatternMatchResult
-  matchAndRewrite(mlir::Operation *op,
+  matchAndRewrite(TransposeOp op,
                   mlir::PatternRewriter &rewriter) const override {
-    // We can directly cast the current operation as this will only get invoked
-    // on TransposeOp.
-    TransposeOp transpose = llvm::cast<TransposeOp>(op);
-    // look through the input to the current transpose
-    mlir::Value *transposeInput = transpose.getOperand();
-    mlir::Operation *transposeInputInst = transposeInput->getDefiningOp();
-    // If the input is defined by another Transpose, bingo!
+    // Look through the input of the current transpose.
+    mlir::Value *transposeInput = op.getOperand();
     TransposeOp transposeInputOp =
-        mlir::dyn_cast_or_null<TransposeOp>(transposeInputInst);
+        llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
+
+    // If the input is defined by another Transpose, bingo!
     if (!transposeInputOp)
       return matchFailure();
 
@@ -66,25 +62,21 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern {
 };
 
 /// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place.
-struct SimplifyReshapeConstant : public mlir::RewritePattern {
-  SimplifyReshapeConstant(mlir::MLIRContext *context)
-      : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
-                       context) {}
+struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
+  using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
 
   mlir::PatternMatchResult
-  matchAndRewrite(mlir::Operation *op,
+  matchAndRewrite(ReshapeOp reshape,
                   mlir::PatternRewriter &rewriter) const override {
-    ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
-    // look through the input to the current reshape
-    mlir::Value *reshapeInput = reshape.getOperand();
-    mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
-    // If the input is defined by another reshape, bingo!
-    ConstantOp constantOp =
-        mlir::dyn_cast_or_null<ConstantOp>(reshapeInputInst);
+    // Look through the input of the current reshape.
+    ConstantOp constantOp = llvm::dyn_cast_or_null<ConstantOp>(
+        reshape.getOperand()->getDefiningOp());
+
+    // If the input is defined by another constant, bingo!
     if (!constantOp)
       return matchFailure();
 
-    auto reshapeType = op->getResult(0)->getType().cast<ToyArrayType>();
+    auto reshapeType = reshape.getType().cast<ToyArrayType>();
     if (auto valueAttr =
             constantOp.getAttrOfType<mlir::DenseElementsAttr>("value")) {
       // FIXME Check matching of element count!
@@ -93,9 +85,8 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
           reshapeType.getShape(), valueAttr.getType().getElementType());
       auto newAttr =
           mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
-      auto newConstant = rewriter.create<ConstantOp>(
-          constantOp.getLoc(), reshapeType.getShape(), newAttr);
-      rewriter.replaceOp(op, {newConstant});
+      rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
+                                              newAttr);
     } else if (auto valueAttr =
                    constantOp.getAttrOfType<mlir::FloatAttr>("value")) {
       // Broadcast
@@ -106,9 +97,8 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
       auto tensorTy = rewriter.getTensorType(reshapeType.getShape(),
                                              reshapeType.getElementType());
       auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data);
-      auto newConstant = rewriter.create<ConstantOp>(
-          constantOp.getLoc(), reshapeType.getShape(), newAttr);
-      rewriter.replaceOp(op, {newConstant});
+      rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
+                                              newAttr);
     } else {
       llvm_unreachable("Unsupported Constant format");
     }
@@ -117,43 +107,35 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
 };
 
 /// Fold reshape(reshape(x)) -> reshape(x)
-struct SimplifyReshapeReshape : public mlir::RewritePattern {
-  SimplifyReshapeReshape(mlir::MLIRContext *context)
-      : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
-                       context) {}
+struct SimplifyReshapeReshape : public mlir::OpRewritePattern<ReshapeOp> {
+  using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
 
   mlir::PatternMatchResult
-  matchAndRewrite(mlir::Operation *op,
+  matchAndRewrite(ReshapeOp op,
                   mlir::PatternRewriter &rewriter) const override {
-    ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
-    // look through the input to the current reshape
-    mlir::Value *reshapeInput = reshape.getOperand();
-    mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
+    // Look through the input of the current reshape.
+    mlir::Value *reshapeInput = op.getOperand();
+
     // If the input is defined by another reshape, bingo!
-    ReshapeOp reshapeInputOp =
-        mlir::dyn_cast_or_null<ReshapeOp>(reshapeInputInst);
-    if (!reshapeInputOp)
+    if (!matchPattern(reshapeInput, mlir::m_Op<ReshapeOp>()))
       return matchFailure();
 
     // Use the rewriter to perform the replacement
-    rewriter.replaceOp(op, {reshapeInputOp});
+    rewriter.replaceOp(op, {reshapeInput});
     return matchSuccess();
   }
 };
 
 /// Fold reshape(x)) -> x, when input type matches output type
-struct SimplifyNullReshape : public mlir::RewritePattern {
-  SimplifyNullReshape(mlir::MLIRContext *context)
-      : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
-                       context) {}
+struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> {
+  using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
 
   mlir::PatternMatchResult
-  matchAndRewrite(mlir::Operation *op,
+  matchAndRewrite(ReshapeOp op,
                   mlir::PatternRewriter &rewriter) const override {
-    ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
-    if (reshape.getOperand()->getType() != reshape.getResult()->getType())
+    if (op.getOperand()->getType() != op.getType())
       return matchFailure();
-    rewriter.replaceOp(reshape, {reshape.getOperand()});
+    rewriter.replaceOp(op, {op.getOperand()});
     return matchSuccess();
   }
 };
@@ -176,17 +158,14 @@ void ReshapeOp::getCanonicalizationPatterns(
 namespace {
 
 /// Fold type.cast(x) -> x, when input type matches output type
-struct SimplifyIdentityTypeCast : public mlir::RewritePattern {
-  SimplifyIdentityTypeCast(mlir::MLIRContext *context)
-      : RewritePattern(TypeCastOp::getOperationName(), /* benefit = */ 1,
-                       context) {}
+struct SimplifyIdentityTypeCast : public mlir::OpRewritePattern<TypeCastOp> {
+  using mlir::OpRewritePattern<TypeCastOp>::OpRewritePattern;
 
   mlir::PatternMatchResult
-  matchAndRewrite(mlir::Operation *op,
+  matchAndRewrite(TypeCastOp typeCast,
                   mlir::PatternRewriter &rewriter) const override {
-    TypeCastOp typeCast = llvm::cast<TypeCastOp>(op);
-    auto resTy = typeCast.getResult()->getType();
-    auto *candidateOp = op;
+    auto resTy = typeCast.getType();
+    auto *candidateOp = typeCast.getOperation();
     while (llvm::isa_and_nonnull<TypeCastOp>(candidateOp)) {
       if (resTy == candidateOp->getOperand(0)->getType()) {
         rewriter.replaceOp(typeCast, {candidateOp->getOperand(0)});
index 60c8255..bbca58b 100644 (file)
@@ -205,6 +205,53 @@ protected:
   llvm::SmallVector<OperationName, 2> generatedOps;
 };
 
+/// OpRewritePattern is a wrapper around RewritePattern that allows for
+/// matching and rewriting against an instance of a derived operation class as
+/// opposed to a raw Operation.
+template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
+  /// Patterns must specify the root operation name they match against, and can
+  /// also specify the benefit of the pattern matching.
+  OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : RewritePattern(SourceOp::getOperationName(), benefit, context) {}
+
+  /// Wrappers around the RewritePattern methods that pass the derived op type.
+  void rewrite(Operation *op, std::unique_ptr<PatternState> state,
+               PatternRewriter &rewriter) const final {
+    rewrite(llvm::cast<SourceOp>(op), std::move(state), rewriter);
+  }
+  void rewrite(Operation *op, PatternRewriter &rewriter) const final {
+    rewrite(llvm::cast<SourceOp>(op), rewriter);
+  }
+  PatternMatchResult match(Operation *op) const final {
+    return match(llvm::cast<SourceOp>(op));
+  }
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const final {
+    return matchAndRewrite(llvm::cast<SourceOp>(op), rewriter);
+  }
+
+  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// overridden by the derived pattern class.
+  virtual void rewrite(SourceOp op, std::unique_ptr<PatternState> state,
+                       PatternRewriter &rewriter) const {
+    rewrite(op, rewriter);
+  }
+  virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
+    llvm_unreachable("must override matchAndRewrite or a rewrite method");
+  }
+  virtual PatternMatchResult match(SourceOp op) const {
+    llvm_unreachable("must override match or matchAndRewrite");
+  }
+  virtual PatternMatchResult matchAndRewrite(SourceOp op,
+                                             PatternRewriter &rewriter) const {
+    if (auto matchResult = match(op)) {
+      rewrite(op, std::move(*matchResult), rewriter);
+      return matchSuccess();
+    }
+    return matchFailure();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // PatternRewriter class
 //===----------------------------------------------------------------------===//
index 130cb15..1a34c71 100644 (file)
@@ -654,24 +654,21 @@ void mlir::canonicalizeMapAndOperands(
 namespace {
 /// Simplify AffineApply operations.
 ///
-struct SimplifyAffineApply : public RewritePattern {
-  SimplifyAffineApply(MLIRContext *context)
-      : RewritePattern(AffineApplyOp::getOperationName(), 1, context) {}
+struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> {
+  using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
 
-  PatternMatchResult matchAndRewrite(Operation *op,
+  PatternMatchResult matchAndRewrite(AffineApplyOp apply,
                                      PatternRewriter &rewriter) const override {
-    auto apply = cast<AffineApplyOp>(op);
     auto map = apply.getAffineMap();
 
     AffineMap oldMap = map;
     SmallVector<Value *, 8> resultOperands(apply.getOperands());
     composeAffineMapAndOperands(&map, &resultOperands);
-    if (map != oldMap) {
-      rewriter.replaceOpWithNewOp<AffineApplyOp>(op, map, resultOperands);
-      return matchSuccess();
-    }
+    if (map == oldMap)
+      return matchFailure();
 
-    return matchFailure();
+    rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, resultOperands);
+    return matchSuccess();
   }
 };
 } // end anonymous namespace.
@@ -1002,14 +999,11 @@ void AffineForOp::print(OpAsmPrinter *p) {
 
 namespace {
 /// This is a pattern to fold constant loop bounds.
-struct AffineForLoopBoundFolder : public RewritePattern {
-  /// The rootOpName is the name of the root operation to match against.
-  AffineForLoopBoundFolder(MLIRContext *context)
-      : RewritePattern(AffineForOp::getOperationName(), 1, context) {}
+struct AffineForLoopBoundFolder : public OpRewritePattern<AffineForOp> {
+  using OpRewritePattern<AffineForOp>::OpRewritePattern;
 
-  PatternMatchResult matchAndRewrite(Operation *op,
+  PatternMatchResult matchAndRewrite(AffineForOp forOp,
                                      PatternRewriter &rewriter) const override {
-    auto forOp = cast<AffineForOp>(op);
     auto foldLowerOrUpperBound = [&forOp](bool lower) {
       // Check to see if each of the operands is the result of a constant.  If
       // so, get the value.  If not, ignore it.
@@ -1056,7 +1050,7 @@ struct AffineForLoopBoundFolder : public RewritePattern {
     // If any of the bounds were folded we return success.
     if (!folded)
       return matchFailure();
-    rewriter.updatedRootInPlace(op);
+    rewriter.updatedRootInPlace(forOp);
     return matchSuccess();
   }
 };
index 32d8de3..2a752c2 100644 (file)
@@ -118,15 +118,13 @@ static Value *emitDequantize(Location loc, Value *input,
 
 namespace {
 
-struct UniformDequantizePattern : public RewritePattern {
-  UniformDequantizePattern(MLIRContext *context)
-      : RewritePattern(DequantizeCastOp::getOperationName(), 1, context) {}
+struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
+  using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;
 
-  PatternMatchResult matchAndRewrite(Operation *op,
+  PatternMatchResult matchAndRewrite(DequantizeCastOp op,
                                      PatternRewriter &rewriter) const {
-    auto dcastOp = cast<DequantizeCastOp>(op);
-    Type inputType = dcastOp.arg()->getType();
-    Type outputType = dcastOp.getResult()->getType();
+    Type inputType = op.arg()->getType();
+    Type outputType = op.getResult()->getType();
 
     QuantizedType inputElementType =
         QuantizedType::getQuantizedElementType(inputType);
@@ -136,8 +134,7 @@ struct UniformDequantizePattern : public RewritePattern {
       return matchFailure();
     }
 
-    Value *dequantizedValue =
-        emitDequantize(dcastOp.getLoc(), dcastOp.arg(), rewriter);
+    Value *dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
     if (!dequantizedValue) {
       return matchFailure();
     }
@@ -322,15 +319,13 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
 
 namespace {
 
-struct UniformRealAddEwPattern : public RewritePattern {
-  UniformRealAddEwPattern(MLIRContext *context)
-      : RewritePattern(RealAddEwOp::getOperationName(), 1, context) {}
+struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
+  using OpRewritePattern<RealAddEwOp>::OpRewritePattern;
 
-  PatternMatchResult matchAndRewrite(Operation *op,
+  PatternMatchResult matchAndRewrite(RealAddEwOp op,
                                      PatternRewriter &rewriter) const {
-    auto addOp = cast<RealAddEwOp>(op);
-    const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(),
-                                   addOp.clamp_min(), addOp.clamp_max());
+    const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
+                                   op.clamp_max());
     if (!info.isValid()) {
       return matchFailure();
     }
@@ -344,15 +339,13 @@ struct UniformRealAddEwPattern : public RewritePattern {
   }
 };
 
-struct UniformRealMulEwPattern : public RewritePattern {
-  UniformRealMulEwPattern(MLIRContext *context)
-      : RewritePattern(RealMulEwOp::getOperationName(), 1, context) {}
+struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
+  using OpRewritePattern<RealMulEwOp>::OpRewritePattern;
 
-  PatternMatchResult matchAndRewrite(Operation *op,
+  PatternMatchResult matchAndRewrite(RealMulEwOp op,
                                      PatternRewriter &rewriter) const {
-    auto mulOp = cast<RealMulEwOp>(op);
-    const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(),
-                                   mulOp.clamp_min(), mulOp.clamp_max());
+    const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
+                                   op.clamp_max());
     if (!info.isValid()) {
       return matchFailure();
     }
index fb5b2e1..e237e8b 100644 (file)
@@ -38,26 +38,21 @@ namespace {
 
 /// Matches x -> [scast -> scast] -> y, replacing the second scast with the
 /// value of x if the casts invert each other.
-class RemoveRedundantStorageCastsRewrite : public RewritePattern {
+class RemoveRedundantStorageCastsRewrite
+    : public OpRewritePattern<StorageCastOp> {
 public:
-  RemoveRedundantStorageCastsRewrite(MLIRContext *context)
-      : RewritePattern(StorageCastOp::getOperationName(), 1, context) {}
+  using OpRewritePattern<StorageCastOp>::OpRewritePattern;
 
-  PatternMatchResult match(Operation *op) const override {
-    auto scastOp = cast<StorageCastOp>(op);
-    if (matchPattern(scastOp.arg(), m_Op<StorageCastOp>())) {
-      auto srcScastOp = cast<StorageCastOp>(scastOp.arg()->getDefiningOp());
-      if (srcScastOp.arg()->getType() == scastOp.getResult()->getType()) {
-        return matchSuccess();
-      }
-    }
-    return matchFailure();
-  }
+  PatternMatchResult matchAndRewrite(StorageCastOp op,
+                                     PatternRewriter &rewriter) const override {
+    if (!matchPattern(op.arg(), m_Op<StorageCastOp>()))
+      return matchFailure();
+    auto srcScastOp = cast<StorageCastOp>(op.arg()->getDefiningOp());
+    if (srcScastOp.arg()->getType() != op.getType())
+      return matchFailure();
 
-  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
-    auto scastOp = cast<StorageCastOp>(op);
-    auto srcScastOp = cast<StorageCastOp>(scastOp.arg()->getDefiningOp());
     rewriter.replaceOp(op, srcScastOp.arg());
+    return matchSuccess();
   }
 };
 
index 44b1156..0c8ba31 100644 (file)
@@ -36,40 +36,35 @@ public:
   void runOnFunction() override;
 };
 
-class QuantizedConstRewrite : public RewritePattern {
-public:
-  struct State : PatternState {
-    QuantizedType quantizedElementType;
-    Attribute value;
-  };
-
-  QuantizedConstRewrite(MLIRContext *context)
-      : RewritePattern(QuantizeCastOp::getOperationName(), 1, context) {}
+struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
+  using OpRewritePattern<QuantizeCastOp>::OpRewritePattern;
 
-  PatternMatchResult match(Operation *op) const override;
-  void rewrite(Operation *op, std::unique_ptr<PatternState> baseState,
-               PatternRewriter &rewriter) const override;
+  PatternMatchResult matchAndRewrite(QuantizeCastOp qbarrier,
+                                     PatternRewriter &rewriter) const override;
 };
 
 } // end anonymous namespace
 
 /// Matches a [constant] -> [qbarrier] where the qbarrier results type is
 /// quantized and the operand type is quantizable.
-PatternMatchResult QuantizedConstRewrite::match(Operation *op) const {
-  State state;
+
+PatternMatchResult
+QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
+                                       PatternRewriter &rewriter) const {
+  Attribute value;
 
   // Is the operand a constant?
-  auto qbarrier = cast<QuantizeCastOp>(op);
-  if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) {
+  if (!matchPattern(qbarrier.arg(), m_Constant(&value))) {
     return matchFailure();
   }
+
   // Does the qbarrier convert to a quantized type. This will not be true
   // if a quantized type has not yet been chosen or if the cast to an equivalent
   // storage type is not supported.
   Type qbarrierResultType = qbarrier.getResult()->getType();
-  state.quantizedElementType =
+  QuantizedType quantizedElementType =
       QuantizedType::getQuantizedElementType(qbarrierResultType);
-  if (!state.quantizedElementType) {
+  if (!quantizedElementType) {
     return matchFailure();
   }
   if (!QuantizedType::castToStorageType(qbarrierResultType)) {
@@ -79,43 +74,34 @@ PatternMatchResult QuantizedConstRewrite::match(Operation *op) const {
   // Is the operand type compatible with the expressed type of the quantized
   // type? This will not be true if the qbarrier is superfluous (converts
   // from and to a quantized type).
-  if (!state.quantizedElementType.isCompatibleExpressedType(
+  if (!quantizedElementType.isCompatibleExpressedType(
           qbarrier.arg()->getType())) {
     return matchFailure();
   }
 
   // Is the constant value a type expressed in a way that we support?
-  if (!state.value.isa<FloatAttr>() && !state.value.isa<SplatElementsAttr>() &&
-      !state.value.isa<DenseElementsAttr>() &&
-      !state.value.isa<SparseElementsAttr>()) {
+  if (!value.isa<FloatAttr>() && !value.isa<SplatElementsAttr>() &&
+      !value.isa<DenseElementsAttr>() && !value.isa<SparseElementsAttr>()) {
     return matchFailure();
   }
 
-  return matchSuccess(llvm::make_unique<State>(std::move(state)));
-}
-
-void QuantizedConstRewrite::rewrite(Operation *op,
-                                    std::unique_ptr<PatternState> baseState,
-                                    PatternRewriter &rewriter) const {
-  auto state = static_cast<State *>(baseState.get());
-
   Type newConstValueType;
-  Attribute newConstValue = quantizeAttr(
-      state->value, state->quantizedElementType, newConstValueType);
+  auto newConstValue =
+      quantizeAttr(value, quantizedElementType, newConstValueType);
   if (!newConstValue) {
-    return;
+    return matchFailure();
   }
 
-  auto *origConstOp = op->getOperand(0);
   // When creating the new const op, use a fused location that combines the
   // original const and the qbarrier that led to the quantization.
-  auto fusedLoc =
-      FusedLoc::get({origConstOp->getDefiningOp()->getLoc(), op->getLoc()},
-                    rewriter.getContext());
+  auto fusedLoc = FusedLoc::get(
+      {qbarrier.arg()->getDefiningOp()->getLoc(), qbarrier.getLoc()},
+      rewriter.getContext());
   auto newConstOp =
       rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
-  rewriter.replaceOpWithNewOp<StorageCastOp>(
-      {origConstOp}, op, *op->result_type_begin(), newConstOp);
+  rewriter.replaceOpWithNewOp<StorageCastOp>({qbarrier.arg()}, qbarrier,
+                                             qbarrier.getType(), newConstOp);
+  return matchSuccess();
 }
 
 void ConvertConstPass::runOnFunction() {
index 508ebfe..dd67546 100644 (file)
@@ -291,24 +291,19 @@ static LogicalResult verify(AllocOp op) {
 
 namespace {
 /// Fold constant dimensions into an alloc operation.
-struct SimplifyAllocConst : public RewritePattern {
-  SimplifyAllocConst(MLIRContext *context)
-      : RewritePattern(AllocOp::getOperationName(), 1, context) {}
-
-  PatternMatchResult match(Operation *op) const override {
-    auto alloc = cast<AllocOp>(op);
+struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
+  using OpRewritePattern<AllocOp>::OpRewritePattern;
 
+  PatternMatchResult matchAndRewrite(AllocOp alloc,
+                                     PatternRewriter &rewriter) const override {
     // Check to see if any dimensions operands are constants.  If so, we can
     // substitute and drop them.
-    for (auto *operand : alloc.getOperands())
-      if (matchPattern(operand, m_ConstantIndex()))
-        return matchSuccess();
-    return matchFailure();
-  }
+    if (llvm::none_of(alloc.getOperands(), [](Value *operand) {
+          return matchPattern(operand, m_ConstantIndex());
+        }))
+      return matchFailure();
 
-  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
-    auto allocOp = cast<AllocOp>(op);
-    auto memrefType = allocOp.getType();
+    auto memrefType = alloc.getType();
 
     // Ok, we have one or more constant operands.  Collect the non-constant ones
     // and keep track of the resultant memref type to build.
@@ -325,7 +320,7 @@ struct SimplifyAllocConst : public RewritePattern {
         newShapeConstants.push_back(dimSize);
         continue;
       }
-      auto *defOp = allocOp.getOperand(dynamicDimPos)->getDefiningOp();
+      auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp();
       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
         // Dynamic shape dimension will be folded.
         newShapeConstants.push_back(constantIndexOp.getValue());
@@ -334,7 +329,7 @@ struct SimplifyAllocConst : public RewritePattern {
       } else {
         // Dynamic shape dimension not folded; copy operand from old memref.
         newShapeConstants.push_back(-1);
-        newOperands.push_back(allocOp.getOperand(dynamicDimPos));
+        newOperands.push_back(alloc.getOperand(dynamicDimPos));
       }
       dynamicDimPos++;
     }
@@ -347,30 +342,29 @@ struct SimplifyAllocConst : public RewritePattern {
 
     // Create and insert the alloc op for the new memref.
     auto newAlloc =
-        rewriter.create<AllocOp>(allocOp.getLoc(), newMemRefType, newOperands);
+        rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType, newOperands);
     // Insert a cast so we have the same type as the old alloc.
-    auto resultCast = rewriter.create<MemRefCastOp>(allocOp.getLoc(), newAlloc,
-                                                    allocOp.getType());
+    auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
+                                                    alloc.getType());
 
-    rewriter.replaceOp(op, {resultCast}, droppedOperands);
+    rewriter.replaceOp(alloc, {resultCast}, droppedOperands);
+    return matchSuccess();
   }
 };
 
 /// Fold alloc operations with no uses. Alloc has side effects on the heap,
 /// but can still be deleted if it has zero uses.
-struct SimplifyDeadAlloc : public RewritePattern {
-  SimplifyDeadAlloc(MLIRContext *context)
-      : RewritePattern(AllocOp::getOperationName(), 1, context) {}
+struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
+  using OpRewritePattern<AllocOp>::OpRewritePattern;
 
-  PatternMatchResult matchAndRewrite(Operation *op,
+  PatternMatchResult matchAndRewrite(AllocOp alloc,
                                      PatternRewriter &rewriter) const override {
     // Check if the alloc'ed value has any uses.
-    auto alloc = cast<AllocOp>(op);
     if (!alloc.use_empty())
       return matchFailure();
 
     // If it doesn't, we can eliminate it.
-    op->erase();
+    alloc.erase();
     return matchSuccess();
   }
 };
@@ -484,24 +478,22 @@ FunctionType CallOp::getCalleeType() {
 //===----------------------------------------------------------------------===//
 namespace {
 /// Fold indirect calls that have a constant function as the callee operand.
-struct SimplifyIndirectCallWithKnownCallee : public RewritePattern {
-  SimplifyIndirectCallWithKnownCallee(MLIRContext *context)
-      : RewritePattern(CallIndirectOp::getOperationName(), 1, context) {}
+struct SimplifyIndirectCallWithKnownCallee
+    : public OpRewritePattern<CallIndirectOp> {
+  using OpRewritePattern<CallIndirectOp>::OpRewritePattern;
 
-  PatternMatchResult matchAndRewrite(Operation *op,
+  PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall,
                                      PatternRewriter &rewriter) const override {
-    auto indirectCall = cast<CallIndirectOp>(op);
-
     // Check that the callee is a constant callee.
     FunctionAttr calledFn;
     if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
       return matchFailure();
 
     // Replace with a direct call.
-    SmallVector<Type, 8> callResults(op->getResultTypes());
+    SmallVector<Type, 8> callResults(indirectCall.getResultTypes());
     SmallVector<Value *, 8> callOperands(indirectCall.getArgOperands());
-    rewriter.replaceOpWithNewOp<CallOp>(op, calledFn.getValue(), callResults,
-                                        callOperands);
+    rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn.getValue(),
+                                        callResults, callOperands);
     return matchSuccess();
   }
 };
@@ -964,14 +956,11 @@ namespace {
 /// cond_br true, ^bb1, ^bb2 -> br ^bb1
 /// cond_br false, ^bb1, ^bb2 -> br ^bb2
 ///
-struct SimplifyConstCondBranchPred : public RewritePattern {
-  SimplifyConstCondBranchPred(MLIRContext *context)
-      : RewritePattern(CondBranchOp::getOperationName(), 1, context) {}
+struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
+  using OpRewritePattern<CondBranchOp>::OpRewritePattern;
 
-  PatternMatchResult matchAndRewrite(Operation *op,
+  PatternMatchResult matchAndRewrite(CondBranchOp condbr,
                                      PatternRewriter &rewriter) const override {
-    auto condbr = cast<CondBranchOp>(op);
-
     // Check that the condition is a constant.
     if (!matchPattern(condbr.getCondition(), m_Op<ConstantOp>()))
       return matchFailure();
@@ -991,7 +980,7 @@ struct SimplifyConstCondBranchPred : public RewritePattern {
       branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end());
     }
 
-    rewriter.replaceOpWithNewOp<BranchOp>(op, foldedDest, branchArgs);
+    rewriter.replaceOpWithNewOp<BranchOp>(condbr, foldedDest, branchArgs);
     return matchSuccess();
   }
 };
@@ -1230,18 +1219,14 @@ void ConstantIndexOp::build(Builder *builder, OperationState *result,
 namespace {
 /// Fold Dealloc operations that are deallocating an AllocOp that is only used
 /// by other Dealloc operations.
-struct SimplifyDeadDealloc : public RewritePattern {
-  SimplifyDeadDealloc(MLIRContext *context)
-      : RewritePattern(DeallocOp::getOperationName(), 1, context) {}
+struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
+  using OpRewritePattern<DeallocOp>::OpRewritePattern;
 
-  PatternMatchResult matchAndRewrite(Operation *op,
+  PatternMatchResult matchAndRewrite(DeallocOp dealloc,
                                      PatternRewriter &rewriter) const override {
-    auto dealloc = cast<DeallocOp>(op);
-
     // Check that the memref operand's defining operation is an AllocOp.
     Value *memref = dealloc.memref();
-    Operation *defOp = memref->getDefiningOp();
-    if (!isa_and_nonnull<AllocOp>(defOp))
+    if (!isa_and_nonnull<AllocOp>(memref->getDefiningOp()))
       return matchFailure();
 
     // Check that all of the uses of the AllocOp are other DeallocOps.
@@ -1250,7 +1235,7 @@ struct SimplifyDeadDealloc : public RewritePattern {
         return matchFailure();
 
     // Erase the dealloc operation.
-    op->erase();
+    rewriter.replaceOp(dealloc, llvm::None);
     return matchSuccess();
   }
 };