[mlir][TilingInterface] Add pattern to tile using TilingInterface and implement Tilin...
authorMahesh Ravishankar <ravishankarm@google.com>
Mon, 13 Jun 2022 19:56:32 +0000 (19:56 +0000)
committerMahesh Ravishankar <ravishankarm@google.com>
Mon, 13 Jun 2022 20:37:44 +0000 (20:37 +0000)
This patch adds support for tiling operations that implement the
TilingInterface.
- It separates the loop constructs that are used to iterate over tile
  from the implementation of the tiling itself. For example, the use
  of destructive updates is more related to use of scf.for for
  iterating over tiles that are tensors.
- To test the transformation, TilingInterface is implemented for
  LinalgOps. The separation of the looping constructs used from the
  implementation of tile code generation greatly simplifies the
  latter.
- The implementation of TilingInterface for LinalgOp is kept as an
  external model for now till this approach can be fully flushed out
  to replace the existing tiling + fusion approaches in Linalg.

Differential Revision: https://reviews.llvm.org/D127133

21 files changed:
mlir/include/mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h [new file with mode: 0644]
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/Dialect/SCF/TileUsingInterface.h [new file with mode: 0644]
mlir/include/mlir/Dialect/SCF/Utils/Utils.h
mlir/include/mlir/Interfaces/TilingInterface.td
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp [new file with mode: 0644]
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp [new file with mode: 0644]
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir [new file with mode: 0644]
mlir/test/lib/CMakeLists.txt
mlir/test/lib/Interfaces/CMakeLists.txt [new file with mode: 0644]
mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt [new file with mode: 0644]
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp [new file with mode: 0644]
mlir/tools/mlir-opt/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h
new file mode 100644 (file)
index 0000000..5b88f1d
--- /dev/null
@@ -0,0 +1,20 @@
+//===- TilingInterfaceImpl.h - Implementation of TilingInterface ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H
+#define MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+void registerTilingInterfaceExternalModels(DialectRegistry &registry);
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H
index 3e9d072..36b143b 100644 (file)
@@ -164,11 +164,11 @@ bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
 SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
                                       ValueRange ivs, ValueRange tileSizes);
 
-/// Compute tile sizes, given a list of loop `ivs`, `tileSizes` and dimension
+/// Compute tile sizes, given a list of `tileSizes` and dimension
 /// sizes (`sizeBounds`). In case a tile size is zero (i.e., no tiling), the
 /// corresponding result size is the corresponding value from `sizeBounds`.
 /// Note: The returned tile sizes are closed intervals.
-SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs,
+SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
                                     ValueRange tileSizes,
                                     ArrayRef<Value> sizeBounds);
 
diff --git a/mlir/include/mlir/Dialect/SCF/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/TileUsingInterface.h
new file mode 100644 (file)
index 0000000..25911ce
--- /dev/null
@@ -0,0 +1,87 @@
+//===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H
+#define MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/TilingInterface.h"
+
+namespace mlir {
+class Operation;
+class PatternRewriter;
+class TilingInterface;
+} // namespace mlir
+
+namespace mlir {
+namespace scf {
+
+using SCFTileSizeComputationFunction =
+    std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
+
+/// Options to use to control tiling.
+struct SCFTilingOptions {
+  /// Computation function that returns the tile sizes for each operation.
+  /// Delayed construction of constant tile sizes should occur to interoperate
+  /// with folding.
+  SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
+
+  SCFTilingOptions &
+  setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) {
+    tileSizeComputationFunction = std::move(fun);
+    return *this;
+  }
+  /// Set the `tileSizeComputationFunction` to return the values `ts`. The
+  /// values must not fold away when tiling. Otherwise, use a more robust
+  /// `tileSizeComputationFunction`.
+  SCFTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
+    tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
+    return *this;
+  }
+  /// Convenience function to set the `tileSizeComputationFunction` to a
+  /// function that computes tile sizes at the point they are needed. Allows
+  /// proper interaction with folding.
+  SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
+};
+
+struct SCFTilingResult {
+  Operation *tiledOp;
+  SmallVector<scf::ForOp> loops;
+};
+
+/// Pattern to tile an op that implementas the `TilingInterface` using
+/// `scf.for` for iterating over the tiles.
+struct TileUsingSCFForOp : public OpInterfaceRewritePattern<TilingInterface> {
+  /// Construct a generic pattern applied to all TilingInterface ops.
+  TileUsingSCFForOp(MLIRContext *context, SCFTilingOptions options,
+                    PatternBenefit benefit = 1);
+
+  /// Construct a generic pattern applied to `opName`.
+  TileUsingSCFForOp(StringRef opName, MLIRContext *context,
+                    SCFTilingOptions options, PatternBenefit benefit = 1);
+
+  /// `matchAndRewrite` implementation that returns the significant transformed
+  /// pieces of IR.
+  FailureOr<SCFTilingResult>
+  returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
+
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(op, rewriter);
+  }
+
+private:
+  /// Options to control tiling;
+  SCFTilingOptions options;
+};
+
+} // namespace scf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H
index ebd055c..3c75754 100644 (file)
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SCF_UTILS_UTILS_H_
 #define MLIR_DIALECT_SCF_UTILS_UTILS_H_
 
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
@@ -32,12 +33,6 @@ class CallOp;
 class FuncOp;
 } // namespace func
 
-namespace scf {
-class IfOp;
-class ForOp;
-class ParallelOp;
-} // namespace scf
-
 /// Replace the `loop` with `newIterOperands` added as new initialization
 /// values. `newYieldValuesFn` is a callback that can be used to specify
 /// the additional values to be yielded by the loop. The number of
@@ -57,6 +52,25 @@ scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
                                     ValueRange newIterOperands,
                                     const NewYieldValueFn &newYieldValuesFn);
 
+/// Update a perfectly nested loop nest to yield new values from the innermost
+/// loop and propagating it up through the loop nest. This function
+/// - Expects `loopNest` to be a perfectly nested loop with outer most loop
+///   first and innermost loop last.
+/// - `newIterOperands` are the initialization values to be used for the
+///    outermost loop
+/// - `newYielValueFn` is the callback that generates the new values to be
+///   yielded from within the innermost loop.
+/// - The original loops are not erased,  but are left in a "no-op" state where
+///   the body of the loop just yields the basic block arguments that correspond
+///   to the initialization values of a loop. The original loops are dead after
+///   this method.
+/// - All uses of the `newIterOperands` within the generated new loop
+///   are replaced with the corresponding `BlockArgument` in the loop body.
+SmallVector<scf::ForOp>
+replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
+                             ValueRange newIterOperands,
+                             NewYieldValueFn newYieldValueFn);
+
 /// Outline a region with a single block into a new FuncOp.
 /// Assumes the FuncOp result types is the type of the yielded operands of the
 /// single block. This constraint makes it easy to determine the result.
index 6346899..6069013 100644 (file)
@@ -98,6 +98,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
         /*defaultImplementation=*/[{
           return {};
         }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the position of the result tile computed by the tiled operation.
+
+          Specifies what tile of the result of the original tensor is computed
+          by the tiled implementation. Expects the same `offsets` and `sizes` as
+          used to obtain the tiled implementation of the operation.
+        }],
+        /*retType=*/"LogicalResult",
+        /*methodName=*/"getResultTilePosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$resultNumber,
+          "ArrayRef<OpFoldResult> ":$offsets,
+          "ArrayRef<OpFoldResult> ":$sizes,
+          "SmallVector<OpFoldResult> &":$resultOffsets,
+          "SmallVector<OpFoldResult> &":$resultSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
       >
   ];
 }
index fc17fba..cf77186 100644 (file)
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   SparseTensorRewriting.cpp
   SplitReduction.cpp
   Tiling.cpp
+  TilingInterfaceImpl.cpp
   Transforms.cpp
   Vectorization.cpp
 
index dfc7897..bb47605 100644 (file)
@@ -320,8 +320,7 @@ static LogicalResult tilePadOp(RewriterBase &builder, tensor::PadOp op,
         // Compute offsets and sizes of ExtractSliceOp.
         SmallVector<Value> offsets =
             computeTileOffsets(b, loc, localIvs, tileSizes);
-        SmallVector<Value> sizes =
-            computeTileSizes(b, loc, localIvs, tileSizes, allDims);
+        SmallVector<Value> sizes = computeTileSizes(b, loc, tileSizes, allDims);
         // Create ExtractSliceOp: Extract a tile from the tensor::PadOp.
         // Note: The tensor::PadOp is located outside of the loop nest. It is
         // later moved inside by ExtractSliceOfPadTensorSwapPattern.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
new file mode 100644 (file)
index 0000000..c67097a
--- /dev/null
@@ -0,0 +1,156 @@
+//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/TilingInterface.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// External model implementation of TilingInterface for LinalgOps. An external
+/// model implementation is used for now till the use of `TilingInterface` is
+/// on-par with the current Linalg tiling + fusion patterns. Once it is
+/// maybe possible to move this into the op-definition (though there are
+/// advantages to leaving it as an external model)
+template <typename LinalgOpTy>
+struct LinalgOpTilingInterface
+    : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
+                                            LinalgOpTy> {
+
+  /// Return the destination operands.
+  SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
+    return llvm::cast<LinalgOp>(op).getOutputOperands();
+  }
+
+  /// Return the loop iterator type.
+  SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+    LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
+    return llvm::to_vector(
+        llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) {
+          return strAttr.cast<StringAttr>().getValue();
+        }));
+  }
+
+  /// Return the iteration domain range.
+  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
+    Location loc = op->getLoc();
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+    auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc);
+    AffineMap map = linalgOp.getShapesToLoopsMap();
+    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+    return llvm::to_vector(llvm::map_range(
+        applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) {
+          return Range{zero, v, one};
+        }));
+  }
+
+  // Instantiate the tiled implementation of the operation.
+  SmallVector<Operation *>
+  getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
+                         ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes,
+                         bool tileDestOperands) const {
+    // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
+    // specified could lead to out of bounds accesses.
+    Location loc = op->getLoc();
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+    SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
+    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
+        b, loc, linalgOp, valuesToTile,
+        getValueOrCreateConstantIndexOp(b, loc, offsets),
+        getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true);
+
+    SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
+        linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) {
+          return tiledOperands[opOperand->getOperandNumber()].getType();
+        }));
+
+    Operation *tiledOp =
+        linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
+
+    return {tiledOp};
+  }
+
+  // Return the details of the output tile generated by the tiled
+  // implementation.
+  LogicalResult
+  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
+                        ArrayRef<OpFoldResult> offsets,
+                        ArrayRef<OpFoldResult> sizes,
+                        SmallVector<OpFoldResult> &resultOffsets,
+                        SmallVector<OpFoldResult> &resultSizes) const {
+    Location loc = op->getLoc();
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+
+    AffineExpr d0;
+    bindDims(b.getContext(), d0);
+
+    auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc,
+                                               AffineExpr expr,
+                                               ValueRange operands) -> Value {
+      AffineMap map = AffineMap::inferFromExprList({expr}).front();
+      SmallVector<Value> normalizedOperands(operands.begin(), operands.end());
+      mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands);
+      canonicalizeMapAndOperands(&map, &normalizedOperands);
+      return builder.createOrFold<AffineApplyOp>(loc, map, normalizedOperands);
+    };
+
+    SmallVector<Value> sizeVals =
+        getValueOrCreateConstantIndexOp(b, loc, sizes);
+    SmallVector<Value> subShapeSizes =
+        llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) {
+          return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v);
+        }));
+    OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
+    Value sliceOpResult =
+        makeTiledShape(b, loc, outOperand->get(), sizeVals,
+                       linalgOp.getTiedIndexingMap(outOperand),
+                       getValueOrCreateConstantIndexOp(b, loc, offsets),
+                       /*ubs*/ {}, subShapeSizes, true);
+    auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>();
+    if (!sliceOp)
+      return failure();
+    resultOffsets = sliceOp.getMixedOffsets();
+    resultSizes = sliceOp.getMixedSizes();
+    return success();
+  }
+};
+
+} // namespace
+
+template <typename OpType> static void registerOne(MLIRContext *ctx) {
+  OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
+}
+
+/// Variadic helper function.
+template <typename... OpTypes> static void registerAll(MLIRContext *ctx) {
+  // FIXME: In c++17 this can be simplified by using 'fold expressions'.
+  (void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...};
+}
+
+#define GET_OP_LIST
+
+void mlir::linalg::registerTilingInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
+    registerOne<linalg::GenericOp>(ctx);
+    registerAll<
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+        >(ctx);
+  });
+}
index bfd2c68..bf68434 100644 (file)
@@ -893,7 +893,7 @@ SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
   return offsets;
 }
 
-SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs,
+SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
                                     ValueRange tileSizes,
                                     ArrayRef<Value> sizeBounds) {
   SmallVector<Value> sizes;
@@ -923,7 +923,7 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
   // that define tile subshapes.
   SmallVector<Value> lbs = computeTileOffsets(b, loc, ivs, tileSizes);
   SmallVector<Value> subShapeSizes =
-      computeTileSizes(b, loc, ivs, tileSizes, sizeBounds);
+      computeTileSizes(b, loc, tileSizes, sizeBounds);
 
   assert(static_cast<int64_t>(valuesToTile.size()) ==
              linalgOp.getNumInputsAndOutputs() &&
index 8f5322d..c876c90 100644 (file)
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp
   StructuralTypeConversions.cpp
+  TileUsingInterface.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
new file mode 100644 (file)
index 0000000..0f71d52
--- /dev/null
@@ -0,0 +1,249 @@
+//===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the tiling using TilingInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/TileUsingInterface.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "tile-using-interface"
+
+using namespace mlir;
+
+scf::SCFTilingOptions &
+scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
+  assert(!tileSizeComputationFunction && "tile sizes already set");
+  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
+  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPointToStart(
+        &op->getParentOfType<func::FuncOp>().getBody().front());
+    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
+      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
+      return v;
+    }));
+  };
+  return *this;
+}
+
+/// Generate an empty loop nest that represents the tiled loop nest shell.
+/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
+/// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
+/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
+/// the
+///   tile processed within the inner most loop.
+static SmallVector<scf::ForOp>
+generateTileLoopNest(OpBuilder &builder, Location loc,
+                     ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
+                     SmallVector<OpFoldResult> &offsets,
+                     SmallVector<OpFoldResult> &sizes) {
+  assert(!loopRanges.empty() && "expected at least one loop range");
+  assert(loopRanges.size() == tileSizeVals.size() &&
+         "expected as many tile sizes as loop ranges");
+  OpBuilder::InsertionGuard guard(builder);
+  SmallVector<scf::ForOp> loops;
+  offsets.resize(loopRanges.size());
+  sizes.resize(loopRanges.size());
+
+  // The tile size to use (to avoid out of bounds access) is  minimum of
+  // `tileSize` and `ub - iv`, where `iv` is the induction variable
+  // of the tiled loop.
+  AffineExpr s0, s1, d0;
+  bindDims(builder.getContext(), d0);
+  bindSymbols(builder.getContext(), s0, s1);
+  AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext());
+
+  for (auto loopRange : llvm::enumerate(loopRanges)) {
+    // No loops if tile size is zero. Set offset and size to the loop
+    // offset and size.
+    if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) {
+      offsets[loopRange.index()] = loopRange.value().offset;
+      sizes[loopRange.index()] = loopRange.value().size;
+      continue;
+    }
+
+    auto loop = builder.create<scf::ForOp>(
+        loc, loopRange.value().offset, loopRange.value().size,
+        tileSizeVals[loopRange.index()], ValueRange{},
+        [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
+            ValueRange /*iterArgs*/) {
+          Value boundedTileSize = builder.create<AffineMinOp>(
+              bodyLoc, minMap,
+              ValueRange{iv, tileSizeVals[loopRange.index()],
+                         loopRange.value().size});
+          sizes[loopRange.index()] = boundedTileSize;
+          builder.create<scf::YieldOp>(loc);
+        });
+    offsets[loopRange.index()] = loop.getInductionVar();
+    loops.push_back(loop);
+    builder.setInsertionPoint(loop.getBody()->getTerminator());
+  }
+  return loops;
+}
+
+scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context,
+                                          scf::SCFTilingOptions options,
+                                          PatternBenefit benefit)
+    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+      options(std::move(options)) {}
+
+scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName,
+                                          MLIRContext *context,
+                                          scf::SCFTilingOptions options,
+                                          PatternBenefit benefit)
+    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+      options(std::move(options)) {}
+
+FailureOr<scf::SCFTilingResult>
+scf::TileUsingSCFForOp::returningMatchAndRewrite(
+    TilingInterface op, PatternRewriter &rewriter) const {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointAfter(op);
+
+  if (!options.tileSizeComputationFunction) {
+    return rewriter.notifyMatchFailure(
+        op, "missing tile size computation function");
+  }
+
+  // 1. Get the range of the loops that are represented by the operation.
+  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
+  size_t numLoops = iterationDomain.size();
+  if (numLoops == 0) {
+    return rewriter.notifyMatchFailure(
+        op, "unable to tile op with no iteration domain");
+  }
+
+  // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
+  // skips tiling a particular dimension. This convention is significantly
+  // simpler to handle instead of adjusting affine maps to account for missing
+  // dimensions.
+  SmallVector<Value, 4> tileSizeVector =
+      options.tileSizeComputationFunction(rewriter, op);
+  if (tileSizeVector.size() < iterationDomain.size()) {
+    auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
+    tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
+  }
+
+  scf::SCFTilingResult tilingResult;
+  SmallVector<OpFoldResult> offsets, sizes;
+  {
+    // 3. Materialize an empty loop nest that iterates over the tiles. These
+    // loops for now do not return any values even if the original operation has
+    // results.
+    tilingResult.loops = generateTileLoopNest(
+        rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
+
+    LLVM_DEBUG({
+      if (!tilingResult.loops.empty()) {
+        llvm::errs() << "LoopNest shell :\n";
+        tilingResult.loops.front().dump();
+        llvm::errs() << "\n";
+      }
+    });
+
+    // 4. Generate the tiled implementation within the inner most loop.
+    if (!tilingResult.loops.empty())
+      rewriter.setInsertionPoint(
+          tilingResult.loops.back().getBody()->getTerminator());
+    SmallVector<Operation *> tiledImplementation = op.getTiledImplementation(
+        rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true);
+    if (tiledImplementation.size() != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "expected tiled implementation to return a single op");
+    }
+    tilingResult.tiledOp = tiledImplementation[0];
+
+    LLVM_DEBUG({
+      if (!tilingResult.loops.empty()) {
+        llvm::errs() << "After tiled implementation :\n";
+        tilingResult.loops.front().dump();
+        llvm::errs() << "\n";
+      }
+    });
+  }
+
+  if (op->getNumResults() == 0) {
+    rewriter.eraseOp(op);
+    return tilingResult;
+  }
+
+  // 5. If the original operations has results, modify the loop nest to yield
+  // the replacement values.
+  SmallVector<Value> replacements;
+  if (tilingResult.loops.empty()) {
+    // 5a. If there were no loops, the tiled implementation results are the
+    // replacements.
+    rewriter.replaceOp(op, tilingResult.tiledOp->getResults());
+    return tilingResult;
+  }
+
+  // 5b. `scf.for` with tensor semantics requires the loop nest to yield the
+  // replacement values using destructive updates. Use the `TilingInterface`
+  // to get the position of the result tiles and use that to generate the
+  // destructive update pattern, i.e.,
+  //
+  // ```mlir
+  // scf.for %iv0 = ... {
+  //   %0 = tiled_op
+  // }
+  // ```
+  //
+  // is transformed to
+  //
+  // ```mlir
+  // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. {
+  //   %0 = tiled_op
+  //   %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
+  //   scf.yield %1
+  // }
+  // ```
+  NewYieldValueFn yieldValueFn =
+      [&](OpBuilder &b, Location loc,
+          ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
+    SmallVector<Value> yieldedValues;
+    Attribute one = b.getIndexAttr(1);
+    for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) {
+      SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes;
+      if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes,
+                                          resultTileOffsets,
+                                          resultTileSizes))) {
+        op.emitOpError("unable to get position of result ")
+            << resultNum << " of the tiled implementation";
+        return {};
+      }
+      SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(),
+                                                  one);
+      Value yieldedValue = b.create<tensor::InsertSliceOp>(
+          op->getLoc(), tilingResult.tiledOp->getResult(resultNum),
+          newBBArgs[resultNum], resultTileOffsets, resultTileSizes,
+          resultTileStrides);
+      yieldedValues.push_back(yieldedValue);
+    }
+    return yieldedValues;
+  };
+  SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
+      rewriter, tilingResult.loops, op.getDestinationOperands(rewriter),
+      yieldValueFn);
+  for (auto loop : llvm::enumerate(tilingResult.loops)) {
+    rewriter.eraseOp(loop.value());
+    tilingResult.loops[loop.index()] = newLoops[loop.index()];
+  }
+  rewriter.replaceOp(op, tilingResult.loops.front().getResults());
+  return tilingResult;
+}
index bce73bd..2ffe2e9 100644 (file)
@@ -23,6 +23,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
 
 using namespace mlir;
 
@@ -101,6 +102,31 @@ mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
   return newLoop;
 }
 
+SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
+    OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
+    ValueRange newIterOperands, NewYieldValueFn newYieldValueFn) {
+  if (loopNest.empty())
+    return {};
+  SmallVector<scf::ForOp> newLoopNest(loopNest.size());
+
+  newLoopNest.back() = replaceLoopWithNewYields(
+      builder, loopNest.back(), newIterOperands, newYieldValueFn);
+
+  for (unsigned loopDepth :
+       llvm::reverse(llvm::seq<unsigned>(0, loopNest.size() - 1))) {
+    NewYieldValueFn fn = [&](OpBuilder &innerBuilder, Location loc,
+                             ArrayRef<BlockArgument> innerNewBBArgs) {
+      SmallVector<Value> newYields(
+          newLoopNest[loopDepth + 1]->getResults().take_back(
+              newIterOperands.size()));
+      return newYields;
+    };
+    newLoopNest[loopDepth] = replaceLoopWithNewYields(
+        builder, loopNest[loopDepth], newIterOperands, fn);
+  }
+  return newLoopNest;
+}
+
 /// Outline a region with a single block into a new FuncOp.
 /// Assumes the FuncOp result types is the type of the yielded operands of the
 /// single block. This constraint makes it easy to determine the result.
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
new file mode 100644 (file)
index 0000000..1e09432
--- /dev/null
@@ -0,0 +1,194 @@
+// RUN: mlir-opt -test-tiling-interface -split-input-file %s | FileCheck %s
+
+func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul {__internal_linalg_transform__ = "simple_gemm"}
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+//      CHECK: func.func @simple_matmul(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+//      CHECK:   %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
+// CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[ARG2]])
+//      CHECK:     %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[M]]]
+//      CHECK:     %[[INNER:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
+// CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
+//      CHECK:       %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[N]]]
+//  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME:           [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1]
+//  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]]
+// CHECK-SAME:           [0, %[[IV1]]] [%[[K]], %[[TS_X]]] [1, 1]
+//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT1]]
+// CHECK-SAME:           [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
+//      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:           outs(%[[INIT_TILE]] :
+//      CHECK:       %[[UPDATE:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[INIT1]]
+// CHECK-SAME:           [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
+//      CHECK:       scf.yield %[[UPDATE]]
+//      CHECK:     scf.yield %[[INNER]]
+//      CHECK:   return %[[OUTER]]
+
+// -----
+
+func.func @simple_matmul_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
+    %arg2 : memref<?x?xf32>) {
+  linalg.matmul {__internal_linalg_transform__ = "simple_gemm_memref"}
+      ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg2 : memref<?x?xf32>)
+  return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
+//      CHECK: func.func @simple_matmul_memref(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[C30:.+]] = arith.constant 30 : index
+//  CHECK-DAG:   %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
+//      CHECK:     %[[TS_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[M]]]
+//      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
+//      CHECK:       %[[TS_N:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[N]]]
+//      CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]]
+//      CHECK:         %[[TS_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[C30]], %[[K]]]
+//  CHECK-DAG:         %[[LHS_TILE:.+]] = memref.subview %[[ARG0]]
+// CHECK-SAME:             [%[[IV0]], %[[IV2]]] [%[[TS_M]], %[[TS_K]]] [1, 1]
+//  CHECK-DAG:         %[[RHS_TILE:.+]] = memref.subview %[[ARG1]]
+// CHECK-SAME:             [%[[IV2]], %[[IV1]]] [%[[TS_K]], %[[TS_N]]] [1, 1]
+//  CHECK-DAG:         %[[OUT_TILE:.+]] = memref.subview %[[ARG2]]
+// CHECK-SAME:             [%[[IV0]], %[[IV1]]] [%[[TS_M]], %[[TS_N]]] [1, 1]
+//      CHECK:         linalg.matmul
+// CHECK-SAME:             ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:             outs(%[[OUT_TILE]] :
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>) {
+  %init0 = linalg.init_tensor [128, 300, 200] : tensor<128x300x200xf32>
+  %init1 = linalg.init_tensor [300, 128, 200] : tensor<300x128x200xf32>
+  %0:2 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      {__internal_linalg_transform__ = "parallel_generic_transpose"}
+      ins(%arg0 : tensor<128x200x300xf32>)
+      outs(%init0, %init1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      linalg.yield %b0, %b0 : f32, f32
+    } -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>)
+  return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+//      CHECK: func.func @multi_result(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[C128:.+]] = arith.constant 128 : index
+//  CHECK-DAG:   %[[C300:.+]] = arith.constant 300 : index
+//  CHECK-DAG:   %[[INIT0:.+]] = linalg.init_tensor [128, 300, 200]
+//  CHECK-DAG:   %[[INIT1:.+]] = linalg.init_tensor [300, 128, 200]
+//      CHECK:   %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]]
+// CHECK-SAME:       iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
+//      CHECK:     %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[C128]]]
+//      CHECK:     %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]]
+// CHECK-SAME:         iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]])
+//      CHECK:       %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[C300]]]
+//  CHECK-DAG:       %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME:           [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, %[[TS_X]]] [1, 1, 1]
+//  CHECK-DAG:       %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG3]]
+// CHECK-SAME:           [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], %[[TS_X]], 200] [1, 1, 1]
+//  CHECK-DAG:       %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG4]]
+// CHECK-SAME:           [%[[IV1]], %[[IV0]], 0] [%[[TS_X]], %[[TS_Y]], 200] [1, 1, 1]
+//      CHECK:       %[[RESULT_TILE:.+]]:2 = linalg.generic
+// CHECK-SAME:           ins(%[[ARG_TILE]] :
+// CHECK-SAME:           outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
+//      CHECK:       %[[UPDATE0:.+]] = tensor.insert_slice %[[RESULT_TILE]]#0 into %[[ARG3]]
+// CHECK-SAME:           [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], %[[TS_X]], 200] [1, 1, 1]
+//      CHECK:       %[[UPDATE1:.+]] = tensor.insert_slice %[[RESULT_TILE]]#1 into %[[ARG4]]
+// CHECK-SAME:           [%[[IV1]], %[[IV0]], 0] [%[[TS_X]], %[[TS_Y]], 200] [1, 1, 1]
+//      CHECK:       scf.yield %[[UPDATE0]], %[[UPDATE1]]
+//      CHECK:     scf.yield %[[INNER]]#0, %[[INNER]]#1
+//      CHECK:   return %[[OUTER]]#0, %[[OUTER]]#1
+
+// -----
+
+func.func @conv2D(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
+    %arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf {
+      strides = dense<[2, 3]> : tensor<2xi64>,
+      dilation = dense<[4, 5]> : tensor<2xi64>,
+      __internal_linalg_transform__ = "simple_conv"}
+      ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+      outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 2 - 2)>
+//  CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 3 - 3)>
+//      CHECK: func.func @conv2D(
+// CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:     %[[FILTER:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[C30:.+]] = arith.constant 30 : index
+//  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[INPUT]], %[[C0]]
+//  CHECK-DAG:   %[[C:.+]] = tensor.dim %[[INPUT]], %[[C3]]
+//  CHECK-DAG:   %[[P:.+]] = tensor.dim %[[FILTER]], %[[C0]]
+//  CHECK-DAG:   %[[Q:.+]] = tensor.dim %[[FILTER]], %[[C1]]
+//  CHECK-DAG:   %[[F:.+]] = tensor.dim %[[FILTER]], %[[C3]]
+//  CHECK-DAG:   %[[R:.+]] = tensor.dim %[[INIT]], %[[C1]]
+//  CHECK-DAG:   %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]]
+//      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[P]] step %[[C10]]
+// CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[INIT]])
+//      CHECK:     %[[TS_P:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[P]]]
+//      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[Q]] step %[[C20]]
+// CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
+//      CHECK:       %[[TS_Q:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[Q]]]
+//      CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C30]]
+// CHECK-SAME:           iter_args(%[[INIT2:.+]] = %[[INIT1]])
+//  CHECK-DAG:         %[[TS_C:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[C30]], %[[C]]]
+//  CHECK-DAG:         %[[TS_H:.+]] = affine.apply #[[MAP3]](%[[TS_P]])[%[[R]]]
+//  CHECK-DAG:         %[[TS_W:.+]] = affine.apply #[[MAP4]](%[[TS_Q]])[%[[S]]]
+//  CHECK-DAG:         %[[INPUT_TILE:.+]] = tensor.extract_slice %[[INPUT]]
+// CHECK-SAME:             [0, %[[IV0]], %[[IV1]], %[[IV2]]] [%[[N]], %[[TS_H]], %[[TS_W]], %[[TS_C]]]
+//  CHECK-DAG:         %[[FILTER_TILE:.+]] = tensor.extract_slice %[[FILTER]]
+// CHECK-SAME:             [%[[IV0]], %[[IV1]], %[[IV2]], 0] [%[[TS_P]], %[[TS_Q]], %[[TS_C]], %[[F]]]
+//  CHECK-DAG:         %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT2]]
+// CHECK-SAME:             [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]]
+//      CHECK:         %[[CONV_TILE:.+]] = linalg.conv_2d_nhwc_hwcf
+// CHECK-SAME:             dilation = dense<[4, 5]> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>
+// CHECK-SAME:             ins(%[[INPUT_TILE]], %[[FILTER_TILE]] :
+// CHECK-SAME:             outs(%[[INIT_TILE]] :
+//      CHECK:         tensor.insert_slice %[[CONV_TILE]] into %[[INIT2]]
+// CHECK-SAME:             [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]]
index 97149dc..88e55e7 100644 (file)
@@ -1,6 +1,7 @@
 add_subdirectory(Analysis)
 add_subdirectory(Conversion)
 add_subdirectory(Dialect)
+add_subdirectory(Interfaces)
 add_subdirectory(IR)
 add_subdirectory(Pass)
 add_subdirectory(Reducer)
diff --git a/mlir/test/lib/Interfaces/CMakeLists.txt b/mlir/test/lib/Interfaces/CMakeLists.txt
new file mode 100644 (file)
index 0000000..4a0567a
--- /dev/null
@@ -0,0 +1 @@
+add_subdirectory(TilingInterface)
diff --git a/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt b/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt
new file mode 100644 (file)
index 0000000..437e39c
--- /dev/null
@@ -0,0 +1,15 @@
+add_mlir_library(MLIRTilingInterfaceTestPasses
+  TestTilingInterface.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRAffine
+  MLIRArithmetic
+  MLIRLinalg
+  MLIRLinalgTransforms
+  MLIRMemRef
+  MLIRSCF
+  MLIRSCFTransforms
+  MLIRTensor
+  )
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
new file mode 100644 (file)
index 0000000..c3795c3
--- /dev/null
@@ -0,0 +1,126 @@
+//===- TestTilingInterface.cpp - Test tiling using `TilingInterface` -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing tiling operations using
+// `TilingInterface`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/TileUsingInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Construct a generic pattern applied to all TilingInterface ops that verify
+/// `filter`.
+struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
+  TestTileUsingSCFForOpWithFilter(MLIRContext *context,
+                                  scf::SCFTilingOptions options,
+                                  linalg::LinalgTransformationFilter filter =
+                                      linalg::LinalgTransformationFilter(),
+                                  PatternBenefit benefit = 1)
+      : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {}
+
+  /// Construct a generic pattern applied to `opName`.
+  TestTileUsingSCFForOpWithFilter(StringRef opName, MLIRContext *context,
+                                  scf::SCFTilingOptions options,
+                                  linalg::LinalgTransformationFilter filter =
+                                      linalg::LinalgTransformationFilter(),
+                                  PatternBenefit benefit = 1)
+      : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {}
+
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(filter.checkAndNotify(rewriter, op)))
+      return failure();
+
+    FailureOr<scf::SCFTilingResult> tilingResult =
+        returningMatchAndRewrite(op, rewriter);
+    if (failed(tilingResult)) {
+      return failure();
+    }
+    filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp);
+    return success();
+  }
+
+private:
+  linalg::LinalgTransformationFilter filter;
+};
+
+struct TestTilingInterfacePass
+    : public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass)
+
+  TestTilingInterfacePass() = default;
+  TestTilingInterfacePass(const TestTilingInterfacePass &pass)
+      : PassWrapper(pass) {}
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
+                    tensor::TensorDialect>();
+    linalg::registerTilingInterfaceExternalModels(registry);
+  }
+  StringRef getArgument() const final { return "test-tiling-interface"; }
+  StringRef getDescription() const final {
+    return "Test tiling using TilingInterface";
+  }
+
+  void runOnOperation() override;
+};
+} // namespace
+
+static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) {
+  auto addPatternForTiling = [&](ArrayRef<int64_t> tileSizes,
+                                 StringRef filterName) {
+    scf::SCFTilingOptions tilingOptions;
+    tilingOptions.setTileSizes(tileSizes);
+    linalg::LinalgTransformationFilter filter(
+        StringAttr::get(context, filterName),
+        StringAttr::get(context, "tiled"));
+    patterns.add<TestTileUsingSCFForOpWithFilter>(context, tilingOptions,
+                                                  filter);
+  };
+  // 1. Tiling M and N dims of `linalg.matmul` on tensors.
+  addPatternForTiling({10, 20}, "simple_gemm");
+  // 2. Tiling M, N and K of `linalg.matmul` on buffers.
+  addPatternForTiling({10, 20, 30}, "simple_gemm_memref");
+  // 3. Tiling 3D parallel generic op which implements a transpose
+  addPatternForTiling({10, 0, 20}, "parallel_generic_transpose");
+  // 4. Tiling 2D conv op.
+  addPatternForTiling({0, 0, 0, 0, 10, 20, 30}, "simple_conv");
+}
+
+void TestTilingInterfacePass::runOnOperation() {
+  MLIRContext *context = &getContext();
+
+  RewritePatternSet tilingPatterns(context);
+  addTestPatterns(context, tilingPatterns);
+  if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                          std::move(tilingPatterns))))
+    return signalPassFailure();
+}
+
+namespace mlir {
+namespace test {
+void registerTestTilingInterface() {
+  PassRegistration<TestTilingInterfacePass>();
+}
+} // namespace test
+} // namespace mlir
index a8172b8..97b082e 100644 (file)
@@ -33,6 +33,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTestRewrite
     MLIRTestTransformDialect
     MLIRTestTransforms
+    MLIRTilingInterfaceTestPasses
     MLIRVectorTestPasses
     )
 endif()
index aa94294..b50cfa9 100644 (file)
@@ -111,6 +111,7 @@ void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
 void registerTestSliceAnalysisPass();
 void registerTestTensorTransforms();
+void registerTestTilingInterface();
 void registerTestTransformDialectInterpreterPass();
 void registerTestVectorLowerings();
 } // namespace test
@@ -206,6 +207,7 @@ void registerTestPasses() {
   mlir::test::registerTestSCFUtilsPass();
   mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestTensorTransforms();
+  mlir::test::registerTestTilingInterface();
   mlir::test::registerTestTransformDialectInterpreterPass();
   mlir::test::registerTestVectorLowerings();
 }
index 5fde5e7..49c08f7 100644 (file)
@@ -1864,6 +1864,7 @@ cc_library(
         "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
         "include/mlir/Dialect/SCF/Passes.h",
         "include/mlir/Dialect/SCF/Patterns.h",
+        "include/mlir/Dialect/SCF/TileUsingInterface.h",
         "include/mlir/Dialect/SCF/Transforms.h",
     ],
     includes = ["include"],
@@ -1883,6 +1884,7 @@ cc_library(
         ":SCFUtils",
         ":Support",
         ":TensorDialect",
+        ":TilingInterface",
         ":Transforms",
         "//llvm:Support",
     ],
@@ -2645,6 +2647,7 @@ cc_library(
         exclude = [
             "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
             "include/mlir/Dialect/SCF/Patterns.h",
+            "include/mlir/Dialect/SCF/TileUsingInterface.h",
             "include/mlir/Dialect/SCF/Transforms.h",
         ],
     ),
@@ -6313,6 +6316,7 @@ cc_binary(
         "//mlir/test:TestSPIRV",
         "//mlir/test:TestShapeDialect",
         "//mlir/test:TestTensor",
+        "//mlir/test:TestTilingInterface",
         "//mlir/test:TestTosaDialect",
         "//mlir/test:TestTransformDialect",
         "//mlir/test:TestTransforms",
@@ -7492,6 +7496,7 @@ cc_library(
         ":TensorTilingInterfaceImpl",
         ":TensorTransforms",
         ":TensorUtils",
+        ":TilingInterface",
         ":TransformUtils",
         ":Transforms",
         ":VectorDialect",
index 742e7b6..fa89b9c 100644 (file)
@@ -294,6 +294,28 @@ cc_library(
 )
 
 cc_library(
+    name = "TestTilingInterface",
+    srcs = glob(["lib/Interfaces/TilingInterface/*.cpp"]),
+    includes = ["lib/Dialect/Test"],
+    deps = [
+        "//llvm:Support",
+        "//mlir:Affine",
+        "//mlir:ArithmeticDialect",
+        "//mlir:FuncDialect",
+        "//mlir:IR",
+        "//mlir:LinalgDialect",
+        "//mlir:LinalgTransforms",
+        "//mlir:MemRefDialect",
+        "//mlir:Pass",
+        "//mlir:SCFDialect",
+        "//mlir:SCFTransforms",
+        "//mlir:TensorDialect",
+        "//mlir:TilingInterface",
+        "//mlir:Transforms",
+    ],
+)
+
+cc_library(
     name = "TestPass",
     srcs = glob(["lib/Pass/*.cpp"]),
     deps = [