[mlir][Linalg] NFC - Modernize padding pattern
authorNicolas Vasilache <ntv@google.com>
Thu, 6 Jan 2022 13:30:49 +0000 (08:30 -0500)
committerNicolas Vasilache <ntv@google.com>
Thu, 6 Jan 2022 13:59:35 +0000 (08:59 -0500)
Differential Revision: https://reviews.llvm.org/D116739

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

index c1185c0..7592094 100644 (file)
@@ -688,7 +688,7 @@ struct LinalgGenericTilingPattern : public LinalgBaseTilingPattern {
 /// Apply the `padding` transformation as a pattern.
 /// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `padding` for more details.
-struct LinalgPaddingPattern : public RewritePattern {
+struct LinalgPaddingPattern : public OpInterfaceRewritePattern<LinalgOp> {
   // Entry point to match any LinalgOp OpInterface.
   LinalgPaddingPattern(
       MLIRContext *context,
@@ -701,7 +701,7 @@ struct LinalgPaddingPattern : public RewritePattern {
       LinalgPaddingOptions options = LinalgPaddingOptions(),
       LinalgTransformationFilter filter = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
-  LogicalResult matchAndRewrite(Operation *op,
+  LogicalResult matchAndRewrite(LinalgOp,
                                 PatternRewriter &rewriter) const override;
 
 private:
index 8b9c7bd..177a2ab 100644 (file)
@@ -489,23 +489,24 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
 mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
     MLIRContext *context, LinalgPaddingOptions options,
     LinalgTransformationFilter filter, PatternBenefit benefit)
-    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
+    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
       filter(std::move(filter)), options(std::move(options)) {}
 
 mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
     StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
     LinalgTransformationFilter filter, PatternBenefit benefit)
-    : RewritePattern(opName, benefit, context, {}), filter(std::move(filter)),
-      options(std::move(options)) {}
+    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+      filter(std::move(filter)), options(std::move(options)) {
+  this->filter.addFilter([opName](Operation *op) {
+    return success(op->getName().getStringRef() == opName);
+  });
+}
 
 LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
-  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  if (!linalgOp)
-    return failure();
+    LinalgOp linalgOp, PatternRewriter &rewriter) const {
   if (!linalgOp.hasTensorSemantics())
     return failure();
-  if (failed(filter.checkAndNotify(rewriter, op)))
+  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
 
   // Pad the operation.
@@ -538,7 +539,7 @@ LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
   }
 
   // Replace the original operation to pad.
-  rewriter.replaceOp(op, newResults.getValue());
+  rewriter.replaceOp(linalgOp, newResults.getValue());
   filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
   return success();
 }