[mlir][Linalg] Retire LinalgStrategyTileAndFusePass and filter-based pattern.
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Mon, 10 Oct 2022 06:36:10 +0000 (23:36 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Mon, 10 Oct 2022 14:04:01 +0000 (07:04 -0700)
Context: https://discourse.llvm.org/t/psa-retire-linalg-filter-based-patterns/63785

In the process, also retire `tileConsumerAndFuseProducers` that is now replaced by `tileConsumerAndFuseProducerGreedilyUsingSCFForOp`.

Context: https://discourse.llvm.org/t/psa-retire-tileandfuselinalgops-method/63850

When performing this replacement, a change of behavior appeared: the older `tileConsumerAndFuseProducers` would split the parallel
and non-parallel dimensions automatically and perform a first level of tile-and-fuse on parallel dimensions only and then introduce a
second level of tiling-only on the reduction dimensions. The newer `tileConsumerAndFuseProducerGreedilyUsingSCFForOp` on the other hand
does not perform this breakdown. As a consequence, the transform specification is evolved to produce the same output.

Additionally, replace some uses of `unsigned` by `int64_t` where possible without pulling in larger interface changes (left for a future PR).

Context: https://www.youtube.com/watch?v=Puio5dly9N8

Lastly, tests that were performing tile and fuse and distribute on tensors are retired: the generated IR mixing scf.for, tensors and
distributed processor ids was racy at best ..

Differential Revision: https://reviews.llvm.org/D135559

15 files changed:
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir [deleted file]
mlir/test/Dialect/Linalg/transform-op-fuse.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp

index 719a792..6e41f05 100644 (file)
@@ -76,13 +76,6 @@ std::unique_ptr<Pass> createLinalgDetensorizePass();
 //===----------------------------------------------------------------------===//
 /// Linalg strategy passes.
 //===----------------------------------------------------------------------===//
-/// Create a LinalgStrategyTileAndFusePass.
-std::unique_ptr<OperationPass<func::FuncOp>>
-createLinalgStrategyTileAndFusePass(
-    StringRef opName = "", const linalg::LinalgTilingAndFusionOptions &opt = {},
-    const linalg::LinalgTransformationFilter &filter =
-        linalg::LinalgTransformationFilter());
-
 /// Create a LinalgStrategyTilePass.
 std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyTilePass(
     StringRef opName = "",
index 1889d1e..40a2f11 100644 (file)
@@ -162,18 +162,6 @@ def LinalgDetensorize : Pass<"linalg-detensorize", ""> {
   ];
 }
 
-def LinalgStrategyTileAndFusePass
-    : Pass<"linalg-strategy-tile-and-fuse-pass", "func::FuncOp"> {
-  let summary = "Configurable pass to apply pattern-based tiling and fusion.";
-  let constructor = "mlir::createLinalgStrategyTileAndFusePass()";
-  let options = [
-    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
-      "Which func op is the anchor to latch on.">,
-    Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
-      "Which linalg op within the func is the anchor to latch on.">,
-  ];
-}
-
 def LinalgStrategyTilePass
     : Pass<"linalg-strategy-tile-pass", "func::FuncOp"> {
   let summary = "Configurable pass to apply pattern-based linalg tiling.";
index 6f56702..d7c0d22 100644 (file)
@@ -30,23 +30,6 @@ struct Transformation {
   LinalgTransformationFilter::FilterFunction filter = nullptr;
 };
 
-/// Represent one application of LinalgStrategyTileAndFusePass.
-struct TileAndFuse : public Transformation {
-  TileAndFuse(StringRef name, linalg::LinalgTilingAndFusionOptions options,
-              LinalgTransformationFilter::FilterFunction f = nullptr)
-      : Transformation(std::move(f)), opName(name),
-        options(std::move(options)) {}
-
-  void addToPassPipeline(OpPassManager &pm,
-                         LinalgTransformationFilter m) const override {
-    pm.addPass(createLinalgStrategyTileAndFusePass(opName, options, m));
-  }
-
-private:
-  std::string opName;
-  linalg::LinalgTilingAndFusionOptions options;
-};
-
 /// Represent one application of LinalgStrategyTilePass.
 struct Tile : public Transformation {
   Tile(StringRef name, linalg::LinalgTilingOptions options,
@@ -66,22 +49,6 @@ private:
 
 /// Codegen strategy controls how a Linalg op is progressively lowered.
 struct CodegenStrategy {
-  /// Append a pattern to tile the Op `opName` and fuse its producers with
-  /// tiling and fusion `options`.
-  CodegenStrategy &
-  tileAndFuse(StringRef opName, const LinalgTilingAndFusionOptions &options,
-              const LinalgTransformationFilter::FilterFunction &f = nullptr) {
-    transformationSequence.emplace_back(
-        std::make_unique<TileAndFuse>(opName, options, f));
-    return *this;
-  }
-  /// Conditionally append a pattern to tile the Op `opName` and fuse its
-  /// producers with tiling and fusion `options`.
-  CodegenStrategy &
-  tileAndFuseIf(bool b, StringRef opName, LinalgTilingAndFusionOptions options,
-                LinalgTransformationFilter::FilterFunction f = nullptr) {
-    return b ? tileAndFuse(opName, std::move(options), std::move(f)) : *this;
-  }
   /// Append a pattern to add a level of tiling for Op `opName` with tiling
   /// `options`.
   CodegenStrategy &
index b7f99ab..62dcc8e 100644 (file)
@@ -788,42 +788,6 @@ struct DownscaleDepthwiseConv2DNhwcHwcOp final
 };
 
 ///
-/// Linalg tile and fuse tensor ops pattern.
-///
-/// Apply tiling and fusion as a pattern.
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `tileConsumerAndFuseProducers` for more details.
-struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
-  // Entry point to match any LinalgOp.
-  LinalgTileAndFuseTensorOpsPattern(
-      MLIRContext *context, LinalgTilingAndFusionOptions options,
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1);
-  // Entry point to match a specific LinalgOp.
-  LinalgTileAndFuseTensorOpsPattern(
-      StringRef opName, MLIRContext *context,
-      LinalgTilingAndFusionOptions options,
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1);
-
-  /// `matchAndRewrite` implementation that returns the significant transformed
-  /// pieces of IR.
-  FailureOr<TileLoopNest>
-  returningMatchAndRewrite(Operation *op, PatternRewriter &rewriter) const;
-
-  LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override {
-    return returningMatchAndRewrite(op, rewriter);
-  }
-
-private:
-  /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgTransformationFilter filter;
-  /// Tile sizes and interchange used to tile the root operation.
-  LinalgTilingAndFusionOptions options;
-};
-
-///
 /// Linalg generalization pattern.
 ///
 /// Apply the `generalization` transformation as a pattern.
index 3ec6fc7..305b859 100644 (file)
@@ -445,14 +445,6 @@ private:
   DenseMap<Operation *, SmallVector<int64_t>> tiledRootAndFusedOpsLoops;
 };
 
-/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
-/// `tileSizes`, `tileInterchange`, and `tileDistribution` parameters to control
-/// the tiling.
-FailureOr<TileLoopNest> tileConsumerAndFuseProducers(
-    OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
-    ArrayRef<int64_t> tileInterchange,
-    const Optional<LinalgLoopDistributionOptions> &tileDistribution);
-
 //===----------------------------------------------------------------------===//
 // Generic op region utilities
 //===----------------------------------------------------------------------===//
index 1c374d6..174b39c 100644 (file)
@@ -53,8 +53,8 @@ struct SCFTilingOptions {
   SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
 
   /// The interchange vector to reorder the tiled loops.
-  SmallVector<unsigned> interchangeVector = {};
-  SCFTilingOptions &setInterchange(ArrayRef<unsigned> interchange) {
+  SmallVector<int64_t> interchangeVector = {};
+  SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) {
     interchangeVector = llvm::to_vector(interchange);
     return *this;
   }
index 99f93ed..5b82520 100644 (file)
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Interfaces/TilingInterface.h"
@@ -99,45 +100,63 @@ transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
   results.assign(1, nullptr);
   return emitDefaultSilenceableFailure(target);
 }
-
 //===----------------------------------------------------------------------===//
 // FuseOp
 //===----------------------------------------------------------------------===//
 
 /// Apply a tiling transformation to all payload ops and store both the
 /// tiled operation as well as the created tile loops.
-static LogicalResult
-applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps,
-                 unsigned numLoops,
-                 transform::TransformResults &transformResults,
-                 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
+static LogicalResult applyTilingToAll(
+    Operation *transformOp, ArrayRef<Operation *> payloadOps, unsigned numLoops,
+    transform::TransformResults &transformResults,
+    function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
+        applyFn) {
   SmallVector<Operation *> tiledLinalgOps;
   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
   for (unsigned int i = 0; i < numLoops; ++i)
     loopOps[i].reserve(payloadOps.size());
 
   for (Operation *target : payloadOps) {
-    auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
-    if (!linalgOp)
-      return transformOp->emitError("only LinalgOps are supported");
-
-    FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
-    if (failed(tiled))
+    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
+    if (!tilingInterfaceOp)
+      return transformOp->emitError("only TilingInterface ops are supported");
+
+    SimpleRewriter rewriter(target->getContext());
+    rewriter.setInsertionPoint(target);
+    FailureOr<scf::SCFTileAndFuseResult> tiledResults =
+        applyFn(tilingInterfaceOp);
+    if (failed(tiledResults))
       return failure();
 
-    tiledLinalgOps.push_back(tiled->op);
-    if (tiled->loops.size() != numLoops)
-      // Not enough loops were generated. This usually means that the input size
-      // was smaller than the tiling size.
-      // TODO: LinalgTilingPattern should return failure().
-      return failure();
+    // Perform the replacement of tiled and fused values.
+    SmallVector<Operation *> opsToReplace{target};
+    llvm::append_range(opsToReplace, tiledResults->fusedProducers);
+    for (Operation *toReplace : opsToReplace) {
+      SmallVector<Value> replacements;
+      replacements.reserve(toReplace->getNumResults());
+      for (OpResult res : toReplace->getResults()) {
+        auto it = tiledResults->replacements.find(res);
+        if (it == tiledResults->replacements.end())
+          replacements.push_back(res);
+        else
+          replacements.push_back(it->getSecond());
+      }
+      rewriter.replaceOp(toReplace, replacements);
+    }
+
+    // Report back the relevant handles to the transform op.
+    tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
+    assert(tiledResults->loops.size() == numLoops &&
+           "Mismatched number of loops, tile and fuse transform should have "
+           "failed");
     for (unsigned int i = 0; i < numLoops; ++i)
-      loopOps[i].push_back(tiled->loops[i]);
+      loopOps[i].push_back(tiledResults->loops[i]);
   }
 
   transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
   for (unsigned int i = 0; i < numLoops; ++i)
     transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
+
   return success();
 }
 
@@ -172,27 +191,23 @@ static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
 DiagnosedSilenceableFailure
 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
                          mlir::transform::TransformState &state) {
-  LinalgTilingAndFusionOptions fusionOptions;
-  fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes());
-  fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange());
+  SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes());
+  SmallVector<int64_t> tileInterchange =
+      extractFromI64ArrayAttr(getTileInterchange());
 
+  scf::SCFTilingOptions tilingOptions;
+  tilingOptions.interchangeVector = tileInterchange;
+  tilingOptions = tilingOptions.setTileSizes(tileSizes);
+  scf::SCFTileAndFuseOptions tileAndFuseOptions;
+  tileAndFuseOptions.tilingOptions = tilingOptions;
   LogicalResult result = applyTilingToAll(
       getOperation(), state.getPayloadOps(getTarget()),
-      fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0),
-      transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
-        LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
+      tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
+      [&](TilingInterface tilingInterfaceOp)
+          -> FailureOr<scf::SCFTileAndFuseResult> {
         SimpleRewriter rewriter(getContext());
-        rewriter.setInsertionPoint(linalgOp);
-        FailureOr<TileLoopNest> tileLoopNest =
-            pattern.returningMatchAndRewrite(linalgOp, rewriter);
-        if (failed(tileLoopNest))
-          return failure();
-
-        TiledLinalgOp tiledLinalgOp;
-        tiledLinalgOp.op = tileLoopNest->getRootOp();
-        tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
-                               tileLoopNest->getLoopOps().end()};
-        return tiledLinalgOp;
+        return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
+            rewriter, tilingInterfaceOp, tileAndFuseOptions);
       });
   return DiagnosedSilenceableFailure(result);
 }
index d29b767..2451c79 100644 (file)
@@ -414,68 +414,3 @@ SmallVector<LinalgOp> TileLoopNest::getAllTiledAndFusedOps() {
   }
   return result;
 }
-
-//===----------------------------------------------------------------------===//
-// Tile and fuse entry-points.
-//===----------------------------------------------------------------------===//
-
-FailureOr<TileLoopNest> mlir::linalg::tileConsumerAndFuseProducers(
-    OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
-    ArrayRef<int64_t> tileInterchange,
-    const Optional<LinalgLoopDistributionOptions> &tileDistribution) {
-  assert(tileSizes.size() == tileInterchange.size() &&
-         "expect the number of tile sizes and interchange dims to match");
-  assert(isPermutation(tileInterchange) &&
-         "expect tile interchange is a permutation");
-
-  // Create an empty tile loop nest.
-  TileLoopNest tileLoopNest(consumerOp);
-
-  // Search the number of outer parallel loops to separate them from possible
-  // inner reduction dimensions.
-  SmallVector<StringRef> iterTypes = consumerOp.getIteratorTypesArray();
-  applyPermutationToVector(iterTypes, tileInterchange);
-  auto *it = find_if_not(iterTypes, isParallelIterator);
-  int64_t split = std::distance(iterTypes.begin(), it);
-
-  // Helper to fuse the producers greedily using a queue of fusion candidates.
-  auto fuseProducersGreedily = [&](ArrayRef<OpOperand *> operands) {
-    SmallVector<OpOperand *> candidates(operands.begin(), operands.end());
-    while (!candidates.empty()) {
-      FailureOr<LinalgOp> fusedProducer =
-          tileLoopNest.fuseProducer(b, candidates.pop_back_val());
-      if (failed(fusedProducer))
-        continue;
-      candidates.append(fusedProducer->getInputAndOutputOperands());
-    }
-  };
-
-  // Perform tiling and fusion in two steps. We need to respect the loop
-  // interchange here; filter parellel dimensions based on their order *after*
-  // permutation but pass in the original configuration *before* permuation,
-  // given the tiling and interchange happen together.
-  SmallVector<int64_t> outerTileSizes(tileSizes.size(), 0);
-  SmallVector<int64_t> innerTileSizes(tileSizes.size(), 0);
-  for (int64_t i : tileInterchange.take_front(split))
-    outerTileSizes[i] = tileSizes[i];
-  for (int64_t i : tileInterchange.drop_front(split))
-    innerTileSizes[i] = tileSizes[i];
-
-  // Tile the outer parallel loops and fuse the output operands.
-  if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange,
-                                     tileDistribution)))
-    return failure();
-  fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());
-
-  // Tile the remaining loops and fuse the input operands.
-  if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange,
-                                     tileDistribution)))
-    return failure();
-  fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());
-
-  // Exit if the tile loop nest is empty since all tile sizes are zero.
-  if (tileLoopNest.isEmpty())
-    return failure();
-
-  return tileLoopNest;
-}
index 3faf45e..162e74f 100644 (file)
@@ -51,44 +51,6 @@ using namespace linalg;
 
 namespace {
 
-/// Configurable pass to apply pattern-based tiling and fusion.
-struct LinalgStrategyTileAndFusePass
-    : public impl::LinalgStrategyTileAndFusePassBase<
-          LinalgStrategyTileAndFusePass> {
-
-  LinalgStrategyTileAndFusePass() = default;
-
-  LinalgStrategyTileAndFusePass(StringRef opName,
-                                LinalgTilingAndFusionOptions opt,
-                                LinalgTransformationFilter filt)
-      : options(std::move(opt)), filter(std::move(filt)) {
-    this->anchorOpName.setValue(opName.str());
-  }
-
-  void runOnOperation() override {
-    auto funcOp = getOperation();
-    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
-      return;
-
-    RewritePatternSet tilingAndFusionPattern(funcOp.getContext());
-    if (!anchorOpName.empty()) {
-      tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
-          anchorOpName, funcOp.getContext(), options, filter);
-    } else {
-      tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
-          funcOp.getContext(), options, filter);
-    }
-    // Search the root operation using bottom up traversal.
-    GreedyRewriteConfig config;
-    config.useTopDownTraversal = false;
-    (void)applyPatternsAndFoldGreedily(
-        funcOp, std::move(tilingAndFusionPattern), config);
-  }
-
-  LinalgTilingAndFusionOptions options;
-  LinalgTransformationFilter filter;
-};
-
 /// Configurable pass to apply pattern-based linalg tiling.
 struct LinalgStrategyTilePass
     : public impl::LinalgStrategyTilePassBase<LinalgStrategyTilePass> {
@@ -139,15 +101,6 @@ struct LinalgStrategyRemoveMarkersPass
 };
 } // namespace
 
-/// Create a LinalgStrategyTileAndFusePass.
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createLinalgStrategyTileAndFusePass(
-    StringRef opName, const LinalgTilingAndFusionOptions &options,
-    const LinalgTransformationFilter &filter) {
-  return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options,
-                                                         filter);
-}
-
 /// Create a LinalgStrategyTilePass.
 std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::createLinalgStrategyTilePass(StringRef opName,
index b3062f5..938b9e7 100644 (file)
@@ -447,82 +447,6 @@ mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
   return paddedOp;
 }
 
-/// Linalg tile and fuse tensor ops pattern.
-mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
-    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
-                                      LinalgTilingAndFusionOptions options,
-                                      LinalgTransformationFilter f,
-                                      PatternBenefit benefit)
-    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
-      filter(std::move(f)), options(std::move(options)) {}
-
-mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
-    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
-                                      LinalgTilingAndFusionOptions options,
-                                      LinalgTransformationFilter f,
-                                      PatternBenefit benefit)
-    : RewritePattern(opName, benefit, context), filter(std::move(f)),
-      options(std::move(options)) {}
-
-FailureOr<mlir::linalg::TileLoopNest>
-mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
-  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
-  if (!rootOp)
-    return failure();
-  if (failed(filter.checkAndNotify(rewriter, op)))
-    return failure();
-
-  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
-  if (options.tileSizes.size() < rootOp.getNumLoops())
-    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");
-
-  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
-  if (!options.tileInterchange.empty() &&
-      options.tileInterchange.size() != options.tileSizes.size())
-    return rewriter.notifyMatchFailure(
-        op, "expect the number of tile sizes and interchange dims to match");
-
-  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
-  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
-                                     options.tileSizes.begin() +
-                                         rootOp.getNumLoops());
-  SmallVector<int64_t> rootInterchange =
-      options.tileInterchange.empty()
-          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
-          : SmallVector<int64_t>(options.tileInterchange.begin(),
-                                 options.tileInterchange.begin() +
-                                     rootOp.getNumLoops());
-
-  // Check `rootTileSizes` contains non-zero tile sizes.
-  if (llvm::count(rootTileSizes, 0) == static_cast<long>(rootTileSizes.size()))
-    return rewriter.notifyMatchFailure(
-        op, "expect at least one non-zero tile size");
-
-  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
-  // It has to be a permutation since the tiling cannot tile the same loop
-  // dimension multiple times.
-  if (!isPermutation(rootInterchange))
-    return rewriter.notifyMatchFailure(
-        op, "expect the tile interchange permutes the root loops");
-
-  // Tile `rootOp` and fuse its producers.
-  FailureOr<TileLoopNest> tileLoopNest =
-      tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
-                                   rootInterchange, options.tileDistribution);
-  if (failed(tileLoopNest))
-    return rewriter.notifyMatchFailure(
-        op, "tileConsumerAndFuseProducers failed unexpectedly");
-
-  // Replace all uses of the tiled loop operation.
-  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
-
-  // Apply the filter if specified.
-  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
-    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
-  return tileLoopNest;
-}
-
 /// Linalg generalization pattern.
 mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
     MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
index 65bc941..2630da3 100644 (file)
@@ -45,12 +45,12 @@ scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
 
 /// Helper method to adjust the interchange vector to match the iteration
 /// domain.
-static SmallVector<unsigned>
-fillInterchangeVector(ArrayRef<unsigned> interchangeVector,
+static SmallVector<int64_t>
+fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
                       size_t iterationDomainSize) {
-  SmallVector<unsigned> filledVector = llvm::to_vector(interchangeVector);
+  SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
   if (filledVector.size() < iterationDomainSize) {
-    auto range = llvm::seq<unsigned>(filledVector.size(), iterationDomainSize);
+    auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
     filledVector.append(range.begin(), range.end());
   }
   if (filledVector.size() > iterationDomainSize)
@@ -61,23 +61,23 @@ fillInterchangeVector(ArrayRef<unsigned> interchangeVector,
 /// Helper method to apply permutation to a vector
 template <typename T>
 static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector,
-                                               ArrayRef<unsigned> interchange) {
+                                               ArrayRef<int64_t> interchange) {
   assert(interchange.size() == vector.size());
   return llvm::to_vector(
-      llvm::map_range(interchange, [&](unsigned val) { return vector[val]; }));
+      llvm::map_range(interchange, [&](int64_t val) { return vector[val]; }));
 }
 /// Helper method to apply to invert a permutation.
-static SmallVector<unsigned>
-invertPermutationVector(ArrayRef<unsigned> interchange) {
-  SmallVector<unsigned> inversion(interchange.size());
+static SmallVector<int64_t>
+invertPermutationVector(ArrayRef<int64_t> interchange) {
+  SmallVector<int64_t> inversion(interchange.size());
   for (const auto &pos : llvm::enumerate(interchange)) {
     inversion[pos.value()] = pos.index();
   }
   return inversion;
 }
 /// Method to check if an interchange vector is a permutation.
-static bool isPermutation(ArrayRef<unsigned> interchange) {
-  llvm::SmallDenseSet<unsigned, 4> seenVals;
+static bool isPermutation(ArrayRef<int64_t> interchange) {
+  llvm::SmallDenseSet<int64_t, 4> seenVals;
   for (auto val : interchange) {
     if (seenVals.count(val))
       return false;
@@ -298,7 +298,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   {
     // If there is an interchange specified, permute the iteration domain and
     // the tile sizes.
-    SmallVector<unsigned> interchangeVector;
+    SmallVector<int64_t> interchangeVector;
     if (!options.interchangeVector.empty()) {
       interchangeVector = fillInterchangeVector(options.interchangeVector,
                                                 iterationDomain.size());
@@ -365,7 +365,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   // 5. Yield all the results of the tiled operation. The surrounding loop
   //    nest is modified to insert a destructive update pattern to yield
   //    from the loop nest values to replace the untiled op with.
-  unsigned numResults = op->getNumResults();
+  int64_t numResults = op->getNumResults();
   SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
       resultSizesList(numResults);
   for (auto result : llvm::enumerate(op->getResults())) {
@@ -443,7 +443,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
 
   // 1. First tile the consumer.
   scf::SCFTileAndFuseResult tileAndFuseResult;
-  llvm::SmallDenseMap<Value, unsigned> yieldedValueToResultNumber;
+  llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
   {
     FailureOr<scf::SCFTilingResult> tilingResult =
         tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
@@ -566,7 +566,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
           *destinationIterArg.value());
     }
     if (iterArgNumber) {
-      unsigned resultNumber = fusableProducer.getResultNumber();
+      int64_t resultNumber = fusableProducer.getResultNumber();
       if (auto producerOp =
               dyn_cast<TilingInterface>(fusableProducer.getOwner())) {
         SmallVector<Value> destination =
diff --git a/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir
deleted file mode 100644 (file)
index 01f2195..0000000
+++ /dev/null
@@ -1,55 +0,0 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-tile-fuse-and-distribute-options -split-input-file | FileCheck %s
-
-//      CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
-//      CHECK: #[[ADDMAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
-//      CHECK: func @fill_matmul_tensors(
-// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
-// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-func.func @fill_matmul_tensors(
-  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)
-    -> tensor<?x?xf32> {
-//  CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-//  CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-//  CHECK-DAG: %[[BIDY:.*]] = gpu.block_id y
-//  CHECK-DAG: %[[NBLOCKSY:.*]] = gpu.grid_dim y
-//  CHECK-DAG: %[[BIDX:.*]] = gpu.block_id x
-//  CHECK-DAG: %[[NBLOCKSX:.*]] = gpu.grid_dim x
-//  CHECK-DAG: %[[INIT:.+]] = tensor.empty
-//      CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDY]], %[[C8]]]
-//      CHECK: %[[LBY:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]]
-//      CHECK: %[[STEPY:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSY]], %[[C8]]]
-//      CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[INIT]]) -> (tensor<?x?xf32>) {
-//      CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDX]], %[[C8]]]
-//      CHECK: %[[LBX:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]]
-//      CHECK: %[[STEPX:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSX]], %[[C8]]]
-//      CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
-//      CHECK:     %[[OUTSLICEA:.+]] = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-//      CHECK:     %[[OUTSLICEB:.+]] = tensor.extract_slice %{{.*}}[0, %{{.*}}] [%{{.*}}, %{{.*}}] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-//      CHECK:     %[[SLICE:.+]] = tensor.extract_slice %[[TC1]]
-//      CHECK:     %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[SLICE]]
-//      CHECK:     %[[sTD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[FILL]]) -> (tensor<?x?xf32>) {
-//      CHECK:       %[[sTA:.*]] = tensor.extract_slice %[[OUTSLICEA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
-//      CHECK:       %[[sTB:.*]] = tensor.extract_slice %[[OUTSLICEB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
-//      CHECK:       %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
-//      CHECK:       %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
-// CHECK-SAME:                                  outs(%[[sTC]] : tensor<?x?xf32>)  -> tensor<?x?xf32>
-//      CHECK:       %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}]  : tensor<?x?xf32> into tensor<?x?xf32>
-//      CHECK:       scf.yield %[[TD]] : tensor<?x?xf32>
-//      CHECK:     %[[TD2:.*]] = tensor.insert_slice %[[sTD2]] into %[[TC1]][{{.*}}]  : tensor<?x?xf32> into tensor<?x?xf32>
-//      CHECK:     scf.yield %[[TD2]] : tensor<?x?xf32>
-//      CHECK:   scf.yield %[[TD1]] : tensor<?x?xf32>
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %cst = arith.constant 0.0 : f32
-  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-  %1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
-  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
-  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
-  %4 = linalg.matmul {__internal_linalg_transform__ = "tensors_fuse_distribute1"}
-       ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%3: tensor<?x?xf32>)
-    -> tensor<?x?xf32>
-
-//      CHECK: return %[[TD0]] : tensor<?x?xf32>
-  return %4 : tensor<?x?xf32>
-}
index f26462b..e9801d8 100644 (file)
@@ -3,10 +3,11 @@
 // CHECK-LABEL: func.func @fuse_unary
 func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
 
-  //     CHECK:   scf.for
-  //     CHECK:     scf.for
+  //     CHECK: %[[RES:.*]] = scf.for
+  //     CHECK:    scf.for
   //     CHECK:       linalg.elemwise_unary
   //     CHECK:       linalg.elemwise_binary
+  //     CHECK: return %[[RES]]
   %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
                              outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
   %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
@@ -28,14 +29,15 @@ transform.with_pdl_patterns {
 // CHECK-LABEL: func.func @fuse_unary
 func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
 
-  //     CHECK:   scf.for
+  //     CHECK: %[[PARTIAL_RES:.*]] = scf.for
   //     CHECK:     scf.for
   //     CHECK:       linalg.elemwise_unary
   //     CHECK:       linalg.elemwise_binary
-  //     CHECK:   scf.for
+  //     CHECK: %[[RES:.*]] = scf.for {{.*}}%[[PARTIAL_RES]]
   //     CHECK:     scf.for
   //     CHECK:       linalg.elemwise_unary
   //     CHECK:       linalg.elemwise_binary
+  //     CHECK: return %[[RES]]
   %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
                              outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
   %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
@@ -61,19 +63,23 @@ func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf3
   %five = arith.constant 5.0 : f32
   %init = tensor.empty() : tensor<12x25xf32>
 
-//       CHECK: %[[INIT:.+]] = tensor.empty()
+//   CHECK-DAG: %[[INIT:.+]] = tensor.empty()
 //   CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
 //   CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
-//       CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step %[[C5]] iter_args(%[[FOR_ARG0:.+]] = %[[INIT]])
+//       CHECK: %[[RES:.*]] = scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step %[[C5]] iter_args(%[[FOR_ARG0:.+]] = %[[INIT]])
 //       CHECK:   scf.for %[[IV1:.+]] = %{{.+}} to %{{.+}} step %[[C7]] iter_args(%[[FOR_ARG1:.+]] = %[[FOR_ARG0]])
 //       CHECK:     %[[OUT_SLICE0:.+]] = tensor.extract_slice %[[INPUT]][%[[IV0]], 0, %[[IV1]]]
 //       CHECK:     %[[OUT_SLICE1:.+]] = tensor.extract_slice %[[FOR_ARG1]][%[[IV0]], %[[IV1]]]
 //       CHECK:     %[[FILL:.+]] = linalg.fill {{.+}} outs(%[[OUT_SLICE1]] : tensor<?x?xf32>)
+//
+// Extra 4 constant is introduced, discard it.
+//       CHECK:     arith.constant 4 : index
 //       CHECK:     %[[C4:.+]] = arith.constant 4 : index
 //       CHECK:     scf.for %[[IV2:.+]] = %{{.+}} to %{{.+}} step %[[C4]] iter_args(%[[FOR_ARG2:.+]] = %[[FILL]])
 //       CHECK:       %[[IN_SLICE:.+]] = tensor.extract_slice %[[OUT_SLICE0]]
 //       CHECK:       %[[OUT_SLICE2:.+]] = tensor.extract_slice %[[FOR_ARG2]][0, 0]
 //       CHECK:       linalg.generic {{.+}} ins(%[[IN_SLICE]] : tensor<?x?x?xf32>) outs(%[[OUT_SLICE2]] : tensor<?x?xf32>)
+//       CHECK: return %[[RES]]
 
   %fill = linalg.fill ins(%five : f32) outs(%init : tensor<12x25xf32>) -> tensor<12x25xf32>
   %0 = linalg.generic {
@@ -92,6 +98,7 @@ transform.with_pdl_patterns {
   transform.sequence %arg0 failures(propagate) {
   ^bb1(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-    %1, %loops:3 = transform.structured.fuse %0 {tile_sizes = [5, 4, 7], tile_interchange = [0, 2, 1]}
+    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [5, 0, 7], tile_interchange = [0, 2, 1]}
+    %2, %loops_2 = transform.structured.tile %1 [0, 4]
   }
 }
index 01ea7e1..781936f 100644 (file)
@@ -65,10 +65,6 @@ struct TestLinalgTransforms
       *this, "test-tile-and-distribute-options",
       llvm::cl::desc("Test tile and distribute options"),
       llvm::cl::init(false)};
-  Option<bool> testTileFuseAndDistributionOptions{
-      *this, "test-tile-fuse-and-distribute-options",
-      llvm::cl::desc("Test tile, fuse and distribute options"),
-      llvm::cl::init(false)};
   Option<bool> testVectorTransferForwardingPatterns{
       *this, "test-vector-transfer-forwarding-patterns",
       llvm::cl::desc(
@@ -415,27 +411,6 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
   }
 }
 
-static void fillTileFuseAndDistributePatterns(MLIRContext *context,
-                                              RewritePatternSet &patterns) {
-  LinalgLoopDistributionOptions cyclicNprocsEqNiters;
-  SmallVector<linalg::DistributionMethod> distributionMethod = {
-      DistributionMethod::Cyclic, DistributionMethod::Cyclic};
-  cyclicNprocsEqNiters.procInfo =
-      [distributionMethod](OpBuilder &b, Location loc,
-                           ArrayRef<Range> parallelLoopRanges) {
-        return getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>(
-            b, loc, parallelLoopRanges, distributionMethod);
-      };
-  patterns.add<LinalgTileAndFuseTensorOpsPattern>(
-      MatmulOp::getOperationName(), context,
-      LinalgTilingAndFusionOptions()
-          .setTileSizes({8, 8, 4})
-          .setDistributionOptions(cyclicNprocsEqNiters),
-      LinalgTransformationFilter(
-          StringAttr::get(context, "tensors_fuse_distribute1"),
-          StringAttr::get(context, "tensors_after_fuse_distribute1")));
-}
-
 static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) {
   RewritePatternSet forwardPattern(funcOp.getContext());
   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
@@ -552,12 +527,6 @@ void TestLinalgTransforms::runOnOperation() {
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
     return;
   }
-  if (testTileFuseAndDistributionOptions) {
-    RewritePatternSet patterns(&getContext());
-    fillTileFuseAndDistributePatterns(&getContext(), patterns);
-    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
-    return;
-  }
   if (testPatterns)
     return applyPatterns(getOperation());
   if (testVectorTransferForwardingPatterns)
index 977a054..8e3b976 100644 (file)
@@ -199,7 +199,7 @@ static void addPatternForTiling(MLIRContext *context,
                                 RewritePatternSet &patterns,
                                 StringRef filterName,
                                 ArrayRef<int64_t> tileSizes,
-                                ArrayRef<unsigned> interchange = {}) {
+                                ArrayRef<int64_t> interchange = {}) {
   scf::SCFTilingOptions tilingOptions;
   tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
   linalg::LinalgTransformationFilter filter(
@@ -211,7 +211,7 @@ static void addPatternForTileAndFuse(MLIRContext *context,
                                      RewritePatternSet &patterns,
                                      StringRef filterName,
                                      ArrayRef<int64_t> tileSizes,
-                                     ArrayRef<unsigned> interchange = {}) {
+                                     ArrayRef<int64_t> interchange = {}) {
   scf::SCFTileAndFuseOptions tileAndFuseOptions;
   tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange(
       interchange);