[mlir][Linalg] NFC - Drop filter from LinalgGeneralizationPattern
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 12 Oct 2022 11:14:33 +0000 (04:14 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 12 Oct 2022 11:47:12 +0000 (04:47 -0700)
Differential Revision: https://reviews.llvm.org/D135761

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

index 044ce8d..e77d4e8 100644 (file)
@@ -772,35 +772,25 @@ struct DownscaleDepthwiseConv2DNhwcHwcOp final
 /// Linalg generalization pattern.
 ///
 /// Apply the `generalization` transformation as a pattern.
-/// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `generalization` for more details.
+//
+// TODO: Automatic default pattern class that just unwraps a function returning
+// FailureOr<GenericOp>.
 struct LinalgGeneralizationPattern
     : public OpInterfaceRewritePattern<LinalgOp> {
-  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
-  LinalgGeneralizationPattern(
-      MLIRContext *context,
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1);
-
-  /// Construct a pattern specifically applied to `opName`.
-  LinalgGeneralizationPattern(
-      StringRef opName, MLIRContext *context,
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1);
+  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
 
   /// `matchAndRewrite` implementation that returns the significant transformed
   /// pieces of IR.
   FailureOr<GenericOp>
-  returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const;
+  returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const {
+    return generalizeNamedOp(rewriter, op);
+  }
 
   LogicalResult matchAndRewrite(LinalgOp op,
                                 PatternRewriter &rewriter) const override {
     return returningMatchAndRewrite(op, rewriter);
   }
-
-private:
-  /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgTransformationFilter filter;
 };
 
 ///
@@ -917,9 +907,7 @@ private:
 
 /// Populates `patterns` with patterns to convert spec-generated named ops to
 /// linalg.generic ops.
-void populateLinalgNamedOpsGeneralizationPatterns(
-    RewritePatternSet &patterns,
-    const LinalgTransformationFilter &filter = LinalgTransformationFilter());
+void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
 
 /// Linalg decompose convolutions patterns
 
index 97eee8b..d656e92 100644 (file)
@@ -86,8 +86,8 @@ void LinalgGeneralizationPass::runOnOperation() {
 }
 
 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
-    RewritePatternSet &patterns, const LinalgTransformationFilter &marker) {
-  patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker);
+    RewritePatternSet &patterns) {
+  patterns.add<LinalgGeneralizationPattern>(patterns.getContext());
 }
 
 std::unique_ptr<OperationPass<func::FuncOp>>
index 58923bc..63f74a3 100644 (file)
@@ -427,30 +427,6 @@ mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
   return paddedOp;
 }
 
-/// Linalg generalization pattern.
-mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
-    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
-    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
-      filter(std::move(f)) {}
-
-mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
-    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
-    PatternBenefit benefit)
-    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
-      filter(f.addOpNameFilter(opName)) {}
-
-FailureOr<GenericOp>
-mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
-    LinalgOp linalgOp, PatternRewriter &rewriter) const {
-  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
-    return failure();
-  FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
-  if (failed(genericOp))
-    return failure();
-  filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
-  return genericOp;
-}
-
 LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
     memref::CopyOp copyOp, PatternRewriter &rewriter) const {
   return vectorizeCopy(rewriter, copyOp);