[mlir][Linalg] Quarantine usage of LinalgTransformationFilter to TestTilingInterface.
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 12 Oct 2022 12:52:30 +0000 (05:52 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 12 Oct 2022 15:36:51 +0000 (08:36 -0700)
This revision also retires code that has now become dead.

Context: https://discourse.llvm.org/t/psa-retire-linalg-filter-based-patterns/63785

Differential Revision: https://reviews.llvm.org/D135771

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp

index bbac089..af95ed7 100644 (file)
@@ -362,67 +362,6 @@ LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp);
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
 //===----------------------------------------------------------------------===//
-// Marker used as attribute name in generated Linalg rewriting transformations.
-struct LinalgTransforms {
-  static const StringLiteral kLinalgTransformMarker;
-};
-
-/// Helper class to control application of linalg transformation patterns.
-/// Control comes in 2 forms:
-///   1. attribute matching and setting behavior using the attribute named
-///      `kLinalgTransformMarker`. This can be used to build a state machine
-///      using attributes and incrementally applying patterns to advance states.
-///   2. filter function, which is a simple lambda on the Operation* that
-///      returns a LogicalResult.
-struct LinalgTransformationFilter {
-  using FilterFunction = std::function<LogicalResult(Operation *)>;
-
-  explicit LinalgTransformationFilter(
-      ArrayRef<StringAttr> matchDisjunction = {},
-      Optional<StringAttr> replacement = None);
-
-  explicit LinalgTransformationFilter(
-      const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction = {},
-      Optional<StringAttr> replacement = None);
-
-  LinalgTransformationFilter(LinalgTransformationFilter &&) = default;
-  LinalgTransformationFilter(const LinalgTransformationFilter &) = default;
-  LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
-  void replaceLinalgTransformationFilter(PatternRewriter &rewriter,
-                                         Operation *op) const;
-  bool hasReplacementFilter(Operation *op) const;
-
-  LinalgTransformationFilter &addFilter(const FilterFunction &f) {
-    if (f)
-      filters.push_back(f);
-    return *this;
-  }
-
-  template <typename... OpTypes>
-  LinalgTransformationFilter &addOpFilter() {
-    return addFilter(
-        [](Operation *op) { return success(isa<OpTypes...>(op)); });
-  }
-
-  LinalgTransformationFilter &addOpNameFilter(StringRef opName) {
-    return addFilter([opName](Operation *op) {
-      return success(op->getName().getStringRef() == opName);
-    });
-  }
-
-  LinalgTransformationFilter &setMatchByDefault() {
-    matchByDefault = true;
-    return *this;
-  }
-
-private:
-  SmallVector<FilterFunction> filters;
-  SmallVector<StringAttr> matchDisjunction;
-  Optional<StringAttr> replacement;
-  /// When set to true, if the attribute is not set, it will be treated as
-  /// a match. Default is false.
-  bool matchByDefault;
-};
 
 using TileSizeComputationFunction =
     std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
@@ -793,14 +732,7 @@ struct LinalgGeneralizationPattern
   }
 };
 
-///
-/// Linalg vectorization patterns.
-///
-/// Empty for now, used for SFINAE purposes only.
-struct LinalgVectorizationOptions {};
-
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `vectorizeLinalgOp` for more details.
+/// Vectorization pattern for memref::CopyOp.
 struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
   using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
 
@@ -812,34 +744,6 @@ struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
 llvm::Optional<vector::CombiningKind> getCombinerOpKind(Operation *combinerOp);
 
 //===----------------------------------------------------------------------===//
-// Transformation and lowering options exposed as auxiliary structs.
-//===----------------------------------------------------------------------===//
-/// Options to control the application of enabling transformations.
-/// Hoisting transformations are always deemed beneficial and must be disabled
-/// explicitly.
-struct LinalgEnablingOptions {
-  /// Enable LICM.
-  bool licm = true;
-  LinalgEnablingOptions &enableLICM(bool val = true) {
-    licm = val;
-    return *this;
-  }
-  /// Enable hoisting of redundant vector transfer ops.
-  bool hoistRedundantVectorTransfers = true;
-  LinalgEnablingOptions &enableHoistRedundantVectorTransfers(bool val = true) {
-    hoistRedundantVectorTransfers = val;
-    return *this;
-  }
-  /// Enable hoisting of redundant vector transfer ops on tensor.
-  bool hoistRedundantVectorTransfersOnTensor = true;
-  LinalgEnablingOptions &
-  enableHoistRedundantVectorTransfersOnTensor(bool val = true) {
-    hoistRedundantVectorTransfersOnTensor = val;
-    return *this;
-  }
-};
-
-//===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
 //===----------------------------------------------------------------------===//
 
@@ -971,24 +875,6 @@ struct LinalgCopyVTWForwardingPattern
                                 PatternRewriter &rewriter) const override;
 };
 
-//===----------------------------------------------------------------------===//
-// Support for staged pattern application.
-//===----------------------------------------------------------------------===//
-/// Helper function to allow applying rewrite patterns, interleaved with more
-/// global transformations, in a staged fashion:
-///   1. the first stage consists of a list of FrozenRewritePatternSet. Each
-///   FrozenRewritePatternSet in this list is applied once, in order.
-///   2. the second stage consists of a single RewritePattern that is applied
-///      greedily until convergence.
-///   3. the third stage consists of applying a lambda, generally used for
-///   non-local transformation effects. This allows creating custom fused
-///   transformations where patterns can be ordered and applied at a finer
-///   granularity than a sequence of traditional compiler passes.
-LogicalResult applyStagedPatterns(
-    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
-    const FrozenRewritePatternSet &stage2Patterns,
-    function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
-
 /// Rewrite extract_slice(tensor.pad(x)) into tensor.pad(extract_slice(x)).
 struct ExtractSliceOfPadTensorSwapPattern
     : public OpRewritePattern<tensor::ExtractSliceOp> {
@@ -1015,20 +901,6 @@ private:
   ControlFn controlFn;
 };
 
-//===----------------------------------------------------------------------===//
-// Helper classes for type list expansion.
-//===----------------------------------------------------------------------===//
-template <typename... OpTypes>
-class VectorizationPatterns;
-
-template <>
-class VectorizationPatterns<> {
-public:
-  static void insert(RewritePatternSet &patterns,
-                     const LinalgVectorizationOptions &options,
-                     const LinalgTransformationFilter &f) {}
-};
-
 /// Split Reduction options.
 struct SplitReductionOptions {
   // Ratio used to split the reduction dimension.  If the ratio is <= 1, nothing
index 34df2c8..dd04d00 100644 (file)
@@ -659,14 +659,10 @@ struct PadOpTilingPattern : public OpRewritePattern<tensor::PadOp> {
 
   LogicalResult matchAndRewrite(tensor::PadOp op,
                                 PatternRewriter &rewriter) const override {
-    if (op->hasAttr(LinalgTransforms::kLinalgTransformMarker))
-      return failure();
     tensor::PadOp newPadOp;
     LoopNest loopNest;
     if (failed(tilePadOp(rewriter, op, newPadOp, loopNest, options)))
       return failure();
-    newPadOp->setAttr(LinalgTransforms::kLinalgTransformMarker,
-                      rewriter.getUnitAttr());
     // Replace all uses of the original tensor::PadOp.
     rewriter.replaceOp(op, loopNest.getResults()[0]);
     return success();
index 63f74a3..8eb41c5 100644 (file)
@@ -47,75 +47,6 @@ using namespace mlir::linalg;
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
 //===----------------------------------------------------------------------===//
-// Marker used as attribute name in generated Linalg rewriting transformations.
-const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
-    "__internal_linalg_transform__";
-
-mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
-    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
-    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
-      replacement(replacement), matchByDefault(false) {}
-
-mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
-    const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
-    Optional<StringAttr> replacement)
-    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
-      replacement(replacement), matchByDefault(false) {
-  if (f)
-    filters.push_back(f);
-}
-
-LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
-    PatternRewriter &rewriter, Operation *op) const {
-  if (llvm::any_of(filters,
-                   [&](const FilterFunction &f) { return failed(f(op)); }))
-    return failure();
-
-  auto attr = op->template getAttrOfType<StringAttr>(
-      LinalgTransforms::kLinalgTransformMarker);
-
-  if (!attr) {
-    // 1. Has no filter case and matchDisjunction is empty.
-    if (matchDisjunction.empty() || matchByDefault)
-      return success();
-
-    // 2. Has no filter but was expecting a filter.
-    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
-      diag << " does not have any filter from list: ";
-      interleaveComma(matchDisjunction, diag);
-    });
-  }
-
-  // 4. Match explicit filter.
-  for (auto filter : matchDisjunction)
-    if (attr.getValue() == filter)
-      return success();
-
-  // 5. Fail to match.
-  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
-    diag << " does not have any filter from list: ";
-    interleaveComma(matchDisjunction, diag);
-  });
-}
-
-void mlir::linalg::LinalgTransformationFilter::
-    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
-                                      Operation *op) const {
-  if (replacement.has_value())
-    op->setAttr(LinalgTransforms::kLinalgTransformMarker, replacement.value());
-  else
-    op->removeAttr(
-        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
-}
-
-bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
-    Operation *op) const {
-  if (!replacement)
-    return false;
-  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
-                  .dyn_cast<StringAttr>();
-  return attr && attr == *replacement;
-}
 
 LinalgTilingOptions &
 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
@@ -432,37 +363,6 @@ LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
   return vectorizeCopy(rewriter, copyOp);
 }
 
-LogicalResult mlir::linalg::applyStagedPatterns(
-    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
-    const FrozenRewritePatternSet &stage2Patterns,
-    function_ref<LogicalResult(Operation *)> stage3Lambda) {
-  unsigned iteration = 0;
-  (void)iteration;
-  for (const auto &patterns : stage1Patterns) {
-    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
-                      << *op);
-    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
-      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
-      return failure();
-    }
-    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
-                      << *op);
-    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
-      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
-      return failure();
-    }
-    LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n"
-                      << *op);
-    if (stage3Lambda) {
-      if (failed(stage3Lambda(op)))
-        return failure();
-      LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n"
-                        << *op);
-    }
-  }
-  return success();
-}
-
 static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
   return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
 }
index 146a32e..7313c30 100644 (file)
@@ -127,11 +127,6 @@ static void applyPatterns(func::FuncOp funcOp) {
   patterns.add<CopyVectorizationPattern>(ctx);
 
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-
-  // Drop the marker.
-  funcOp.walk([](LinalgOp op) {
-    op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
-  });
 }
 
 static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) {
@@ -182,13 +177,6 @@ static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) {
 
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
-  auto lambda = [&](void *) {
-    getOperation().walk([](LinalgOp op) {
-      op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
-    });
-  };
-  std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
-
   if (testPatterns)
     return applyPatterns(getOperation());
   if (testVectorTransferForwardingPatterns)
index 8e3b976..31e3c1a 100644 (file)
 
 using namespace mlir;
 
+// TODO: this file should disappear and instead tests should make use of the
+// transform dialect.
 namespace {
 
+/// Marker used as attribute name in generated Linalg rewriting transformations.
+const StringLiteral kLinalgTransformMarker = "__internal_linalg_transform__";
+
+/// Helper class to control application of linalg transformation patterns.
+/// Control comes in 2 forms:
+///   1. attribute matching and setting behavior using the attribute named
+///      `kLinalgTransformMarker`. This can be used to build a state machine
+///      using attributes and incrementally applying patterns to advance states.
+///   2. filter function, which is a simple lambda on the Operation* that
+///      returns a LogicalResult.
+struct LinalgTransformationFilter {
+  using FilterFunction = std::function<LogicalResult(Operation *)>;
+
+  explicit LinalgTransformationFilter(
+      ArrayRef<StringAttr> matchDisjunction = {},
+      Optional<StringAttr> replacement = None);
+
+  explicit LinalgTransformationFilter(
+      const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction = {},
+      Optional<StringAttr> replacement = None);
+
+  LinalgTransformationFilter(LinalgTransformationFilter &&) = default;
+  LinalgTransformationFilter(const LinalgTransformationFilter &) = default;
+  LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
+  void replaceLinalgTransformationFilter(PatternRewriter &rewriter,
+                                         Operation *op) const;
+  bool hasReplacementFilter(Operation *op) const;
+
+  LinalgTransformationFilter &addFilter(const FilterFunction &f) {
+    if (f)
+      filters.push_back(f);
+    return *this;
+  }
+
+  template <typename... OpTypes>
+  LinalgTransformationFilter &addOpFilter() {
+    return addFilter(
+        [](Operation *op) { return success(isa<OpTypes...>(op)); });
+  }
+
+  LinalgTransformationFilter &addOpNameFilter(StringRef opName) {
+    return addFilter([opName](Operation *op) {
+      return success(op->getName().getStringRef() == opName);
+    });
+  }
+
+  LinalgTransformationFilter &setMatchByDefault() {
+    matchByDefault = true;
+    return *this;
+  }
+
+private:
+  SmallVector<FilterFunction> filters;
+  SmallVector<StringAttr> matchDisjunction;
+  Optional<StringAttr> replacement;
+  /// When set to true, if the attribute is not set, it will be treated as
+  /// a match. Default is false.
+  bool matchByDefault;
+};
+
+LinalgTransformationFilter::LinalgTransformationFilter(
+    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
+    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
+      replacement(replacement), matchByDefault(false) {}
+
+LinalgTransformationFilter::LinalgTransformationFilter(
+    const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
+    Optional<StringAttr> replacement)
+    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
+      replacement(replacement), matchByDefault(false) {
+  if (f)
+    filters.push_back(f);
+}
+
+LogicalResult
+LinalgTransformationFilter::checkAndNotify(PatternRewriter &rewriter,
+                                           Operation *op) const {
+  if (llvm::any_of(filters,
+                   [&](const FilterFunction &f) { return failed(f(op)); }))
+    return failure();
+
+  auto attr = op->template getAttrOfType<StringAttr>(kLinalgTransformMarker);
+
+  if (!attr) {
+    // 1. Has no filter case and matchDisjunction is empty.
+    if (matchDisjunction.empty() || matchByDefault)
+      return success();
+
+    // 2. Has no filter but was expecting a filter.
+    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+      diag << " does not have any filter from list: ";
+      interleaveComma(matchDisjunction, diag);
+    });
+  }
+
+  // 4. Match explicit filter.
+  for (auto filter : matchDisjunction)
+    if (attr.getValue() == filter)
+      return success();
+
+  // 5. Fail to match.
+  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+    diag << " does not have any filter from list: ";
+    interleaveComma(matchDisjunction, diag);
+  });
+}
+
+void LinalgTransformationFilter::replaceLinalgTransformationFilter(
+    PatternRewriter &rewriter, Operation *op) const {
+  if (replacement.has_value())
+    op->setAttr(kLinalgTransformMarker, replacement.value());
+  else
+    op->removeAttr(rewriter.getStringAttr(kLinalgTransformMarker));
+}
+
+bool LinalgTransformationFilter::hasReplacementFilter(Operation *op) const {
+  if (!replacement)
+    return false;
+  auto attr = op->getAttr(kLinalgTransformMarker).dyn_cast<StringAttr>();
+  return attr && attr == *replacement;
+}
+
 /// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using
 /// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while
 /// using a `filter` to avoid recursive application.
 struct TestTileUsingSCFForOp
     : public OpInterfaceRewritePattern<TilingInterface> {
-  TestTileUsingSCFForOp(MLIRContext *context, scf::SCFTilingOptions options,
-                        linalg::LinalgTransformationFilter filter =
-                            linalg::LinalgTransformationFilter(),
-                        PatternBenefit benefit = 1)
+  TestTileUsingSCFForOp(
+      MLIRContext *context, scf::SCFTilingOptions options,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
       : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
         options(std::move(options)), filter(std::move(filter)) {}
 
   /// Construct a generic pattern applied to `opName`.
-  TestTileUsingSCFForOp(StringRef opName, MLIRContext *context,
-                        scf::SCFTilingOptions options,
-                        linalg::LinalgTransformationFilter filter =
-                            linalg::LinalgTransformationFilter(),
-                        PatternBenefit benefit = 1)
+  TestTileUsingSCFForOp(
+      StringRef opName, MLIRContext *context, scf::SCFTilingOptions options,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
       : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
         options(std::move(options)), filter(std::move(filter)) {}
 
@@ -76,7 +199,7 @@ struct TestTileUsingSCFForOp
 
 private:
   scf::SCFTilingOptions options;
-  linalg::LinalgTransformationFilter filter;
+  LinalgTransformationFilter filter;
 };
 
 /// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern
@@ -87,8 +210,7 @@ struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp
     : public OpInterfaceRewritePattern<TilingInterface> {
   TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp(
       MLIRContext *context, scf::SCFTileAndFuseOptions options,
-      linalg::LinalgTransformationFilter filter =
-          linalg::LinalgTransformationFilter(),
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
       : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
         options(std::move(options)), filter(std::move(filter)) {}
@@ -97,8 +219,7 @@ struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp
   TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp(
       StringRef opName, MLIRContext *context,
       scf::SCFTileAndFuseOptions options,
-      linalg::LinalgTransformationFilter filter =
-          linalg::LinalgTransformationFilter(),
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
       : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
         options(std::move(options)), filter(std::move(filter)) {}
@@ -129,7 +250,7 @@ struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp
 
 private:
   scf::SCFTileAndFuseOptions options;
-  linalg::LinalgTransformationFilter filter;
+  LinalgTransformationFilter filter;
 };
 
 /// Pattern to lower operations that implement the `TilingInterface` to
@@ -202,8 +323,8 @@ static void addPatternForTiling(MLIRContext *context,
                                 ArrayRef<int64_t> interchange = {}) {
   scf::SCFTilingOptions tilingOptions;
   tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
-  linalg::LinalgTransformationFilter filter(
-      StringAttr::get(context, filterName), StringAttr::get(context, "tiled"));
+  LinalgTransformationFilter filter(StringAttr::get(context, filterName),
+                                    StringAttr::get(context, "tiled"));
   patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
 }
 
@@ -215,8 +336,8 @@ static void addPatternForTileAndFuse(MLIRContext *context,
   scf::SCFTileAndFuseOptions tileAndFuseOptions;
   tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange(
       interchange);
-  linalg::LinalgTransformationFilter filter(
-      StringAttr::get(context, filterName), StringAttr::get(context, "tiled"));
+  LinalgTransformationFilter filter(StringAttr::get(context, filterName),
+                                    StringAttr::get(context, "tiled"));
   patterns.add<TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp>(
       context, tileAndFuseOptions, filter);
 }