[mlir][TilingInterface] Add pattern to tile using TilingInterface and implement Tilin...
authorMahesh Ravishankar <ravishankarm@google.com>
Mon, 13 Jun 2022 19:56:32 +0000 (19:56 +0000)
committerMahesh Ravishankar <ravishankarm@google.com>
Mon, 13 Jun 2022 20:37:44 +0000 (20:37 +0000)
This patch adds support for tiling operations that implement the
TilingInterface.
- It separates the loop constructs that are used to iterate over tile
  from the implementation of the tiling itself. For example, the use
  of destructive updates is more related to use of scf.for for
  iterating over tiles that are tensors.
- To test the transformation, TilingInterface is implemented for
  LinalgOps. The separation of the looping constructs used from the
  implementation of tile code generation greatly simplifies the
  latter.
- The implementation of TilingInterface for LinalgOp is kept as an
  external model for now till this approach can be fully fleshed out
  to replace the existing tiling + fusion approaches in Linalg.

Differential Revision: https://reviews.llvm.org/D127133

21 files changed:
mlir/include/mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h [new file with mode: 0644]
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/Dialect/SCF/TileUsingInterface.h [new file with mode: 0644]
mlir/include/mlir/Dialect/SCF/Utils/Utils.h
mlir/include/mlir/Interfaces/TilingInterface.td
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp [new file with mode: 0644]
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp [new file with mode: 0644]
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir [new file with mode: 0644]
mlir/test/lib/CMakeLists.txt
mlir/test/lib/Interfaces/CMakeLists.txt [new file with mode: 0644]
mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt [new file with mode: 0644]
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp [new file with mode: 0644]
mlir/tools/mlir-opt/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h
new file mode 100644 (file)
index 0000000..5b88f1d
--- /dev/null
@@ -0,0 +1,20 @@
+//===- TilingInterfaceImpl.h - Implementation of TilingInterface ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H
+#define MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+/// Registers with `registry` the external-model implementations of
+/// `TilingInterface` for Linalg operations.
+void registerTilingInterfaceExternalModels(DialectRegistry &registry);
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H
index 3e9d07209017cf976cec6fc396aeeac91f8b1122..36b143b2d2ce85c48efe7340ce09e49e1d9eb857 100644 (file)
@@ -164,11 +164,11 @@ bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
 SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
                                       ValueRange ivs, ValueRange tileSizes);
 
-/// Compute tile sizes, given a list of loop `ivs`, `tileSizes` and dimension
+/// Compute tile sizes, given a list of `tileSizes` and dimension
 /// sizes (`sizeBounds`). In case a tile size is zero (i.e., no tiling), the
 /// corresponding result size is the corresponding value from `sizeBounds`.
 /// Note: The returned tile sizes are closed intervals.
-SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs,
+SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
                                     ValueRange tileSizes,
                                     ArrayRef<Value> sizeBounds);
 
diff --git a/mlir/include/mlir/Dialect/SCF/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/TileUsingInterface.h
new file mode 100644 (file)
index 0000000..25911ce
--- /dev/null
@@ -0,0 +1,87 @@
+//===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H
+#define MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/TilingInterface.h"
+
+namespace mlir {
+class Operation;
+class PatternRewriter;
+class TilingInterface;
+} // namespace mlir
+
+namespace mlir {
+namespace scf {
+
+/// Callback that, given a builder and the operation being tiled, returns the
+/// tile sizes to use as a list of `Value`s.
+using SCFTileSizeComputationFunction =
+    std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
+
+/// Options to use to control tiling.
+struct SCFTilingOptions {
+  /// Computation function that returns the tile sizes for each operation.
+  /// Delayed construction of constant tile sizes should occur to interoperate
+  /// with folding. Defaults to null, i.e. no tile sizes have been set.
+  SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
+
+  /// Set the `tileSizeComputationFunction` to `fun`. Returns `*this` so that
+  /// setters can be chained.
+  SCFTilingOptions &
+  setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) {
+    tileSizeComputationFunction = std::move(fun);
+    return *this;
+  }
+  /// Set the `tileSizeComputationFunction` to return the values `ts`. The
+  /// values must not fold away when tiling. Otherwise, use a more robust
+  /// `tileSizeComputationFunction`. `ts` is copied into the stored callback.
+  SCFTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
+    tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
+    return *this;
+  }
+  /// Convenience function to set the `tileSizeComputationFunction` to a
+  /// function that computes tile sizes at the point they are needed. Allows
+  /// proper interaction with folding.
+  SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
+};
+
+/// The significant pieces of IR produced by tiling an operation with
+/// `TileUsingSCFForOp`.
+struct SCFTilingResult {
+  /// The tiled implementation of the original operation.
+  Operation *tiledOp;
+  /// The `scf.for` loops that iterate over the tiles, outermost loop first.
+  /// Empty when no loop was generated (all tile sizes are zero).
+  SmallVector<scf::ForOp> loops;
+};
+
+/// Pattern to tile an op that implements the `TilingInterface` using
+/// `scf.for` for iterating over the tiles.
+struct TileUsingSCFForOp : public OpInterfaceRewritePattern<TilingInterface> {
+  /// Construct a generic pattern applied to all TilingInterface ops.
+  TileUsingSCFForOp(MLIRContext *context, SCFTilingOptions options,
+                    PatternBenefit benefit = 1);
+
+  /// Construct a generic pattern applied to `opName`.
+  /// NOTE(review): the definition in TileUsingInterface.cpp does not use
+  /// `opName`; confirm whether filtering by op name is actually implemented.
+  TileUsingSCFForOp(StringRef opName, MLIRContext *context,
+                    SCFTilingOptions options, PatternBenefit benefit = 1);
+
+  /// `matchAndRewrite` implementation that returns the significant transformed
+  /// pieces of IR.
+  FailureOr<SCFTilingResult>
+  returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
+
+  /// Standard pattern entry point; forwards to `returningMatchAndRewrite` and
+  /// discards the returned IR, keeping only success/failure.
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(op, rewriter);
+  }
+
+private:
+  /// Options to control tiling.
+  SCFTilingOptions options;
+};
+
+} // namespace scf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H
index ebd055cad4ee8ac7c8c39c38b9aae38986dcdc76..3c75754d64125d337d0c09b166e4e44c4338fda6 100644 (file)
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SCF_UTILS_UTILS_H_
 #define MLIR_DIALECT_SCF_UTILS_UTILS_H_
 
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
@@ -32,12 +33,6 @@ class CallOp;
 class FuncOp;
 } // namespace func
 
-namespace scf {
-class IfOp;
-class ForOp;
-class ParallelOp;
-} // namespace scf
-
 /// Replace the `loop` with `newIterOperands` added as new initialization
 /// values. `newYieldValuesFn` is a callback that can be used to specify
 /// the additional values to be yielded by the loop. The number of
@@ -57,6 +52,25 @@ scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
                                     ValueRange newIterOperands,
                                     const NewYieldValueFn &newYieldValuesFn);
 
+/// Update a perfectly nested loop nest to yield new values from the innermost
+/// loop and propagate them up through the loop nest. This function
+/// - Expects `loopNest` to be a perfectly nested loop with outermost loop
+///   first and innermost loop last.
+/// - `newIterOperands` are the initialization values to be used for the
+///   outermost loop.
+/// - `newYieldValueFn` is the callback that generates the new values to be
+///   yielded from within the innermost loop.
+/// - The original loops are not erased, but are left in a "no-op" state where
+///   the body of the loop just yields the basic block arguments that correspond
+///   to the initialization values of a loop. The original loops are dead after
+///   this method.
+/// - All uses of the `newIterOperands` within the generated new loop
+///   are replaced with the corresponding `BlockArgument` in the loop body.
+SmallVector<scf::ForOp>
+replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
+                             ValueRange newIterOperands,
+                             NewYieldValueFn newYieldValueFn);
+
 /// Outline a region with a single block into a new FuncOp.
 /// Assumes the FuncOp result types is the type of the yielded operands of the
 /// single block. This constraint makes it easy to determine the result.
index 6346899b3998108312890147943704052535bd4e..606901375ede8fbea5312d42dc58dc1cf3b9ff54 100644 (file)
@@ -98,6 +98,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
         /*defaultImplementation=*/[{
           return {};
         }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the position of the result tile computed by the tiled operation.
+
+          Specifies what tile of the result of the original tensor is computed
+          by the tiled implementation. Expects the same `offsets` and `sizes` as
+          used to obtain the tiled implementation of the operation.
+        }],
+        /*retType=*/"LogicalResult",
+        /*methodName=*/"getResultTilePosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$resultNumber,
+          "ArrayRef<OpFoldResult> ":$offsets,
+          "ArrayRef<OpFoldResult> ":$sizes,
+          "SmallVector<OpFoldResult> &":$resultOffsets,
+          "SmallVector<OpFoldResult> &":$resultSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
       >
   ];
 }
index fc17fba490aa17b5e2dfa1b96da0de2f09d613ef..cf771861ff58033e4539d84555269d33b8dcea1a 100644 (file)
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   SparseTensorRewriting.cpp
   SplitReduction.cpp
   Tiling.cpp
+  TilingInterfaceImpl.cpp
   Transforms.cpp
   Vectorization.cpp
 
index dfc78977c560eb57bdce7efcd02656d54e73010f..bb4760588bc8e8d49eb859958b34512b67511170 100644 (file)
@@ -320,8 +320,7 @@ static LogicalResult tilePadOp(RewriterBase &builder, tensor::PadOp op,
         // Compute offsets and sizes of ExtractSliceOp.
         SmallVector<Value> offsets =
             computeTileOffsets(b, loc, localIvs, tileSizes);
-        SmallVector<Value> sizes =
-            computeTileSizes(b, loc, localIvs, tileSizes, allDims);
+        SmallVector<Value> sizes = computeTileSizes(b, loc, tileSizes, allDims);
         // Create ExtractSliceOp: Extract a tile from the tensor::PadOp.
         // Note: The tensor::PadOp is located outside of the loop nest. It is
         // later moved inside by ExtractSliceOfPadTensorSwapPattern.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
new file mode 100644 (file)
index 0000000..c67097a
--- /dev/null
@@ -0,0 +1,156 @@
+//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/TilingInterface.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// External model implementation of TilingInterface for LinalgOps. An external
+/// model implementation is used for now till the use of `TilingInterface` is
+/// on-par with the current Linalg tiling + fusion patterns. Once it is, it
+/// may be possible to move this into the op-definition (though there are
+/// advantages to leaving it as an external model).
+template <typename LinalgOpTy>
+struct LinalgOpTilingInterface
+    : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
+                                            LinalgOpTy> {
+
+  /// Return the destination operands, i.e. the output operands of the op.
+  SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
+    return llvm::cast<LinalgOp>(op).getOutputOperands();
+  }
+
+  /// Return the iterator type of each loop of the op, as strings taken from
+  /// its `iterator_types` attribute.
+  SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+    LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
+    return llvm::to_vector(
+        llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) {
+          return strAttr.cast<StringAttr>().getValue();
+        }));
+  }
+
+  /// Return the iteration domain range: one `Range` per loop, going from 0 to
+  /// the loop bound with step 1. The loop bounds are obtained by applying the
+  /// shapes-to-loops map of the op to the flat list of its operand dimensions.
+  /// Creates the constants and dimension computations using `b`.
+  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
+    Location loc = op->getLoc();
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+    auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc);
+    AffineMap map = linalgOp.getShapesToLoopsMap();
+    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+    return llvm::to_vector(llvm::map_range(
+        applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) {
+          return Range{zero, v, one};
+        }));
+  }
+
+  /// Instantiate the tiled implementation of the operation: tile all input and
+  /// output operands for the tile at `offsets` with extent `sizes`, then clone
+  /// the op onto the tiled operands. Returns a single operation.
+  /// Note: `dest` and `tileDestOperands` are not used by this implementation.
+  SmallVector<Operation *>
+  getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
+                         ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes,
+                         bool tileDestOperands) const {
+    // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
+    // specified could lead to out of bounds accesses.
+    Location loc = op->getLoc();
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+    SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
+    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
+        b, loc, linalgOp, valuesToTile,
+        getValueOrCreateConstantIndexOp(b, loc, offsets),
+        getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true);
+
+    // The result types of the tiled op are the types of the tiled output
+    // tensor operands.
+    SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
+        linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) {
+          return tiledOperands[opOperand->getOperandNumber()].getType();
+        }));
+
+    Operation *tiledOp =
+        linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
+
+    return {tiledOp};
+  }
+
+  /// Return the details of the output tile generated by the tiled
+  /// implementation. For result `resultNumber`, tiles the corresponding output
+  /// operand with the given iteration-space `offsets` and `sizes`, and reports
+  /// the offsets and sizes of the resulting `tensor.extract_slice` in
+  /// `resultOffsets` and `resultSizes`. Fails if tiling the output operand
+  /// does not produce a `tensor.extract_slice`.
+  LogicalResult
+  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
+                        ArrayRef<OpFoldResult> offsets,
+                        ArrayRef<OpFoldResult> sizes,
+                        SmallVector<OpFoldResult> &resultOffsets,
+                        SmallVector<OpFoldResult> &resultSizes) const {
+    Location loc = op->getLoc();
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+
+    AffineExpr d0;
+    bindDims(b.getContext(), d0);
+
+    // Helper that builds (or folds) an `affine.apply` of `expr` on `operands`
+    // after composing and canonicalizing the map and its operands.
+    auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc,
+                                               AffineExpr expr,
+                                               ValueRange operands) -> Value {
+      AffineMap map = AffineMap::inferFromExprList({expr}).front();
+      SmallVector<Value> normalizedOperands(operands.begin(), operands.end());
+      mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands);
+      canonicalizeMapAndOperands(&map, &normalizedOperands);
+      return builder.createOrFold<AffineApplyOp>(loc, map, normalizedOperands);
+    };
+
+    SmallVector<Value> sizeVals =
+        getValueOrCreateConstantIndexOp(b, loc, sizes);
+    // Sub-shape sizes are `size - 1`; presumably `makeTiledShape` expects them
+    // as closed intervals, like the sizes produced by `computeTileSizes` --
+    // TODO confirm against `makeTiledShape`.
+    SmallVector<Value> subShapeSizes =
+        llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) {
+          return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v);
+        }));
+    OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
+    Value sliceOpResult =
+        makeTiledShape(b, loc, outOperand->get(), sizeVals,
+                       linalgOp.getTiedIndexingMap(outOperand),
+                       getValueOrCreateConstantIndexOp(b, loc, offsets),
+                       /*ubs*/ {}, subShapeSizes, true);
+    auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>();
+    if (!sliceOp)
+      return failure();
+    resultOffsets = sliceOp.getMixedOffsets();
+    resultSizes = sliceOp.getMixedSizes();
+    return success();
+  }
+};
+
+} // namespace
+
+/// Attach the `LinalgOpTilingInterface` external model to `OpType` in `ctx`.
+template <typename OpType> static void registerOne(MLIRContext *ctx) {
+  OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
+}
+
+/// Variadic helper function: calls `registerOne` for each type in `OpTypes`,
+/// in order.
+template <typename... OpTypes> static void registerAll(MLIRContext *ctx) {
+  // FIXME: In c++17 this can be simplified by using 'fold expressions'.
+  // The braced initializer list forces left-to-right evaluation of the pack
+  // expansion; the leading `0` keeps the list non-empty for an empty pack.
+  (void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...};
+}
+
+#define GET_OP_LIST
+
+/// Adds an extension to `registry` that attaches the `LinalgOpTilingInterface`
+/// external model to `linalg.generic` and to every op type listed by the
+/// included `LinalgStructuredOps.cpp.inc`.
+void mlir::linalg::registerTilingInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
+    registerOne<linalg::GenericOp>(ctx);
+    // With `GET_OP_LIST` defined above, the include presumably expands to a
+    // comma-separated list of op classes used as template arguments here.
+    registerAll<
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+        >(ctx);
+  });
+}
index bfd2c68cfa68a3f0a069a021271f0d612548e06b..bf684344387b31cd1a42921b54b8bf47f78b6798 100644 (file)
@@ -893,7 +893,7 @@ SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
   return offsets;
 }
 
-SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs,
+SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
                                     ValueRange tileSizes,
                                     ArrayRef<Value> sizeBounds) {
   SmallVector<Value> sizes;
@@ -923,7 +923,7 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
   // that define tile subshapes.
   SmallVector<Value> lbs = computeTileOffsets(b, loc, ivs, tileSizes);
   SmallVector<Value> subShapeSizes =
-      computeTileSizes(b, loc, ivs, tileSizes, sizeBounds);
+      computeTileSizes(b, loc, tileSizes, sizeBounds);
 
   assert(static_cast<int64_t>(valuesToTile.size()) ==
              linalgOp.getNumInputsAndOutputs() &&
index 8f5322dc7b9da9b8ac8fd102c05028f9117674b9..c876c9071269c3a3c49346f8b529db74fe3884b0 100644 (file)
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp
   StructuralTypeConversions.cpp
+  TileUsingInterface.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
new file mode 100644 (file)
index 0000000..0f71d52
--- /dev/null
@@ -0,0 +1,249 @@
+//===- TileUsingInterface.cpp - Tiling using TilingInterface --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the tiling using TilingInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/TileUsingInterface.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "tile-using-interface"
+
+using namespace mlir;
+
+/// Sets `tileSizeComputationFunction` to a callback that materializes the
+/// constant tile sizes `ts` as `arith.constant` index ops at the start of the
+/// first block of the `func.func` enclosing the op being tiled. The constants
+/// are created when the callback runs (i.e. at tiling time), not here.
+/// Asserts that no tile size computation function has been set before.
+/// Returns `*this` to allow chaining.
+scf::SCFTilingOptions &
+scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
+  assert(!tileSizeComputationFunction && "tile sizes already set");
+  // Copy the sizes into the closure: `ts` is a non-owning view that need not
+  // outlive this call.
+  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
+  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
+    // The constants are hoisted to the enclosing function's entry block, so
+    // the op must be nested in a `func.func`; fail loudly otherwise instead of
+    // dereferencing a null op.
+    auto funcOp = op->getParentOfType<func::FuncOp>();
+    assert(funcOp && "expected op to be nested within a func.func");
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPointToStart(&funcOp.getBody().front());
+    return llvm::to_vector<4>(llvm::map_range(tileSizes, [&](int64_t s) {
+      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
+      return v;
+    }));
+  };
+  return *this;
+}
+
+/// Generate an empty loop nest that represents the tiled loop nest shell.
+/// - `loopRanges` specifies the untiled iteration space. Only the `offset`
+///   (lower bound) and `size` (upper bound) of each range are used; the
+///   stride is not.
+/// - `tileSizeVals` is the tile sizes to use. A zero tile size means the
+///   corresponding loop is not tiled.
+/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
+///   the tile processed within the innermost loop.
+/// Returns the generated loops, outermost first. The result is empty when all
+/// tile sizes are zero. The insertion point of `builder` is restored on exit.
+static SmallVector<scf::ForOp>
+generateTileLoopNest(OpBuilder &builder, Location loc,
+                     ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
+                     SmallVector<OpFoldResult> &offsets,
+                     SmallVector<OpFoldResult> &sizes) {
+  assert(!loopRanges.empty() && "expected at least one loop range");
+  assert(loopRanges.size() == tileSizeVals.size() &&
+         "expected as many tile sizes as loop ranges");
+  OpBuilder::InsertionGuard guard(builder);
+  SmallVector<scf::ForOp> loops;
+  offsets.resize(loopRanges.size());
+  sizes.resize(loopRanges.size());
+
+  // The tile size to use (to avoid out of bounds access) is the minimum of
+  // `tileSize` and `ub - iv`, where `iv` is the induction variable
+  // of the tiled loop.
+  AffineExpr s0, s1, d0;
+  bindDims(builder.getContext(), d0);
+  bindSymbols(builder.getContext(), s0, s1);
+  AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext());
+
+  for (auto loopRange : llvm::enumerate(loopRanges)) {
+    // No loops if tile size is zero. Set offset and size to the loop
+    // offset and size.
+    if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) {
+      offsets[loopRange.index()] = loopRange.value().offset;
+      sizes[loopRange.index()] = loopRange.value().size;
+      continue;
+    }
+
+    auto loop = builder.create<scf::ForOp>(
+        loc, loopRange.value().offset, loopRange.value().size,
+        tileSizeVals[loopRange.index()], ValueRange{},
+        [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
+            ValueRange /*iterArgs*/) {
+          // Build the body with the builder and location handed to the
+          // callback rather than the captured outer ones, so that this does
+          // not depend on how `scf.for` positions the outer builder.
+          Value boundedTileSize = bodyBuilder.create<AffineMinOp>(
+              bodyLoc, minMap,
+              ValueRange{iv, tileSizeVals[loopRange.index()],
+                         loopRange.value().size});
+          sizes[loopRange.index()] = boundedTileSize;
+          bodyBuilder.create<scf::YieldOp>(bodyLoc);
+        });
+    offsets[loopRange.index()] = loop.getInductionVar();
+    loops.push_back(loop);
+    // Nest the next loop (if any) inside this one, before its terminator.
+    builder.setInsertionPoint(loop.getBody()->getTerminator());
+  }
+  return loops;
+}
+
+/// Construct a pattern that applies to every op implementing
+/// `TilingInterface`, tiling it according to `options`.
+scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context,
+                                          scf::SCFTilingOptions options,
+                                          PatternBenefit benefit)
+    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+      options(std::move(options)) {}
+
+/// NOTE(review): `opName` is not used by this constructor, so it currently
+/// builds exactly the same pattern as the overload above. Confirm whether
+/// restricting the pattern to `opName` was intended.
+scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName,
+                                          MLIRContext *context,
+                                          scf::SCFTilingOptions options,
+                                          PatternBenefit benefit)
+    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+      options(std::move(options)) {}
+
+/// Tiles `op` with a nest of `scf.for` loops, using the tile sizes produced by
+/// `options.tileSizeComputationFunction`. On success the original `op` is
+/// replaced (or erased, if it has no results) and the tiled op together with
+/// the generated loops is returned. Fails if no tile size computation function
+/// is set, if the op has an empty iteration domain, or if the tiled
+/// implementation is not a single op.
+FailureOr<scf::SCFTilingResult>
+scf::TileUsingSCFForOp::returningMatchAndRewrite(
+    TilingInterface op, PatternRewriter &rewriter) const {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointAfter(op);
+
+  if (!options.tileSizeComputationFunction) {
+    return rewriter.notifyMatchFailure(
+        op, "missing tile size computation function");
+  }
+
+  // 1. Get the range of the loops that are represented by the operation.
+  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
+  size_t numLoops = iterationDomain.size();
+  if (numLoops == 0) {
+    return rewriter.notifyMatchFailure(
+        op, "unable to tile op with no iteration domain");
+  }
+
+  // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
+  // skips tiling a particular dimension. This convention is significantly
+  // simpler to handle instead of adjusting affine maps to account for missing
+  // dimensions.
+  SmallVector<Value, 4> tileSizeVector =
+      options.tileSizeComputationFunction(rewriter, op);
+  if (tileSizeVector.size() < numLoops) {
+    auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
+    tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
+  } else if (tileSizeVector.size() > numLoops) {
+    // Surplus tile sizes have no loop to apply to. Drop them so that the loop
+    // nest generation sees exactly one tile size per loop, instead of tripping
+    // its size assertion in asserts-enabled builds only.
+    tileSizeVector.resize(numLoops);
+  }
+
+  scf::SCFTilingResult tilingResult;
+  SmallVector<OpFoldResult> offsets, sizes;
+  {
+    // 3. Materialize an empty loop nest that iterates over the tiles. These
+    // loops for now do not return any values even if the original operation has
+    // results.
+    tilingResult.loops = generateTileLoopNest(
+        rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
+
+    LLVM_DEBUG({
+      if (!tilingResult.loops.empty()) {
+        llvm::errs() << "LoopNest shell :\n";
+        tilingResult.loops.front().dump();
+        llvm::errs() << "\n";
+      }
+    });
+
+    // 4. Generate the tiled implementation within the innermost loop.
+    if (!tilingResult.loops.empty())
+      rewriter.setInsertionPoint(
+          tilingResult.loops.back().getBody()->getTerminator());
+    SmallVector<Operation *> tiledImplementation = op.getTiledImplementation(
+        rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true);
+    if (tiledImplementation.size() != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "expected tiled implementation to return a single op");
+    }
+    tilingResult.tiledOp = tiledImplementation[0];
+
+    LLVM_DEBUG({
+      if (!tilingResult.loops.empty()) {
+        llvm::errs() << "After tiled implementation :\n";
+        tilingResult.loops.front().dump();
+        llvm::errs() << "\n";
+      }
+    });
+  }
+
+  if (op->getNumResults() == 0) {
+    rewriter.eraseOp(op);
+    return tilingResult;
+  }
+
+  // 5. If the original operation has results, modify the loop nest to yield
+  // the replacement values.
+  if (tilingResult.loops.empty()) {
+    // 5a. If there were no loops, the tiled implementation results are the
+    // replacements.
+    rewriter.replaceOp(op, tilingResult.tiledOp->getResults());
+    return tilingResult;
+  }
+
+  // 5b. `scf.for` with tensor semantics requires the loop nest to yield the
+  // replacement values using destructive updates. Use the `TilingInterface`
+  // to get the position of the result tiles and use that to generate the
+  // destructive update pattern, i.e.,
+  //
+  // ```mlir
+  // scf.for %iv0 = ... {
+  //   %0 = tiled_op
+  // }
+  // ```
+  //
+  // is transformed to
+  //
+  // ```mlir
+  // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. {
+  //   %0 = tiled_op
+  //   %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
+  //   scf.yield %1
+  // }
+  // ```
+  NewYieldValueFn yieldValueFn =
+      [&](OpBuilder &b, Location loc,
+          ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
+    SmallVector<Value> yieldedValues;
+    Attribute one = b.getIndexAttr(1);
+    for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) {
+      SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes;
+      if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes,
+                                          resultTileOffsets,
+                                          resultTileSizes))) {
+        // The error is reported on the op; an empty list is returned to the
+        // loop-rewriting utility in that case.
+        op.emitOpError("unable to get position of result ")
+            << resultNum << " of the tiled implementation";
+        return {};
+      }
+      // Result tiles are inserted with unit strides.
+      SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(),
+                                                  one);
+      Value yieldedValue = b.create<tensor::InsertSliceOp>(
+          op->getLoc(), tilingResult.tiledOp->getResult(resultNum),
+          newBBArgs[resultNum], resultTileOffsets, resultTileSizes,
+          resultTileStrides);
+      yieldedValues.push_back(yieldedValue);
+    }
+    return yieldedValues;
+  };
+  SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
+      rewriter, tilingResult.loops, op.getDestinationOperands(rewriter),
+      yieldValueFn);
+  // The original (now dead) loops are erased and replaced in the result by the
+  // new loops that yield the updated tensors.
+  for (auto loop : llvm::enumerate(tilingResult.loops)) {
+    rewriter.eraseOp(loop.value());
+    tilingResult.loops[loop.index()] = newLoops[loop.index()];
+  }
+  rewriter.replaceOp(op, tilingResult.loops.front().getResults());
+  return tilingResult;
+}
index bce73bd3c432d56289a73cdc7416cef73602eb4a..2ffe2e955d1bc6ef1de911008a9ecb62312ad9dc 100644 (file)
@@ -23,6 +23,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
 
 using namespace mlir;
 
@@ -101,6 +102,31 @@ mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
   return newLoop;
 }
 
+/// Rebuilds `loopNest` (outermost loop first) so that every loop carries
+/// `newIterOperands` as additional iteration arguments. The innermost loop
+/// yields the values produced by `newYieldValueFn`; each enclosing loop yields
+/// the trailing results of the new loop nested directly inside it. Returns the
+/// new loops in the same order as `loopNest`, or an empty vector when
+/// `loopNest` is empty.
+SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
+    OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
+    ValueRange newIterOperands, NewYieldValueFn newYieldValueFn) {
+  SmallVector<scf::ForOp> newLoopNest;
+  if (loopNest.empty())
+    return newLoopNest;
+  newLoopNest.resize(loopNest.size());
+
+  // Innermost loop: the new yields come from the user-provided callback.
+  const unsigned innermost = loopNest.size() - 1;
+  newLoopNest[innermost] = replaceLoopWithNewYields(
+      builder, loopNest[innermost], newIterOperands, newYieldValueFn);
+
+  // Walk outwards. Each remaining loop forwards the last
+  // `newIterOperands.size()` results of the already rebuilt loop one level
+  // deeper.
+  for (unsigned loopDepth = innermost; loopDepth-- > 0;) {
+    NewYieldValueFn forwardInnerResults =
+        [&](OpBuilder &innerBuilder, Location loc,
+            ArrayRef<BlockArgument> innerNewBBArgs) {
+          scf::ForOp innerLoop = newLoopNest[loopDepth + 1];
+          SmallVector<Value> forwarded(
+              innerLoop->getResults().take_back(newIterOperands.size()));
+          return forwarded;
+        };
+    newLoopNest[loopDepth] = replaceLoopWithNewYields(
+        builder, loopNest[loopDepth], newIterOperands, forwardInnerResults);
+  }
+  return newLoopNest;
+}
+
 /// Outline a region with a single block into a new FuncOp.
 /// Assumes the FuncOp result types is the type of the yielded operands of the
 /// single block. This constraint makes it easy to determine the result.
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
new file mode 100644 (file)
index 0000000..1e09432
--- /dev/null
@@ -0,0 +1,194 @@
+// RUN: mlir-opt -test-tiling-interface -split-input-file %s | FileCheck %s
+
+func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul {__internal_linalg_transform__ = "simple_gemm"} // Marker matched by the "simple_gemm" pattern in TestTilingInterface.cpp (tile sizes {10, 20}).
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+//      CHECK: func.func @simple_matmul(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+//      CHECK:   %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
+// CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[ARG2]])
+//      CHECK:     %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[M]]]
+//      CHECK:     %[[INNER:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
+// CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
+//      CHECK:       %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[N]]]
+//  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME:           [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1]
+//  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]]
+// CHECK-SAME:           [0, %[[IV1]]] [%[[K]], %[[TS_X]]] [1, 1]
+//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT1]]
+// CHECK-SAME:           [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
+//      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:           outs(%[[INIT_TILE]] :
+//      CHECK:       %[[UPDATE:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[INIT1]]
+// CHECK-SAME:           [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
+//      CHECK:       scf.yield %[[UPDATE]]
+//      CHECK:     scf.yield %[[INNER]]
+//      CHECK:   return %[[OUTER]]
+
+// -----
+
+func.func @simple_matmul_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
+    %arg2 : memref<?x?xf32>) {
+  linalg.matmul {__internal_linalg_transform__ = "simple_gemm_memref"} // Marker matched by the "simple_gemm_memref" pattern in TestTilingInterface.cpp (tile sizes {10, 20, 30}).
+      ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg2 : memref<?x?xf32>)
+  return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
+//      CHECK: func.func @simple_matmul_memref(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[C30:.+]] = arith.constant 30 : index
+//  CHECK-DAG:   %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
+//      CHECK:     %[[TS_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[M]]]
+//      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
+//      CHECK:       %[[TS_N:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[N]]]
+//      CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]]
+//      CHECK:         %[[TS_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[C30]], %[[K]]]
+//  CHECK-DAG:         %[[LHS_TILE:.+]] = memref.subview %[[ARG0]]
+// CHECK-SAME:             [%[[IV0]], %[[IV2]]] [%[[TS_M]], %[[TS_K]]] [1, 1]
+//  CHECK-DAG:         %[[RHS_TILE:.+]] = memref.subview %[[ARG1]]
+// CHECK-SAME:             [%[[IV2]], %[[IV1]]] [%[[TS_K]], %[[TS_N]]] [1, 1]
+//  CHECK-DAG:         %[[OUT_TILE:.+]] = memref.subview %[[ARG2]]
+// CHECK-SAME:             [%[[IV0]], %[[IV1]]] [%[[TS_M]], %[[TS_N]]] [1, 1]
+//      CHECK:         linalg.matmul
+// CHECK-SAME:             ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:             outs(%[[OUT_TILE]] :
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // Indexing map of the input operand.
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)> // Indexing map of the first output.
+#map2 = affine_map<(d0, d1, d2) -> (d2, d0, d1)> // Indexing map of the second output.
+func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>) {
+  %init0 = linalg.init_tensor [128, 300, 200] : tensor<128x300x200xf32>
+  %init1 = linalg.init_tensor [300, 128, 200] : tensor<300x128x200xf32>
+  %0:2 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      {__internal_linalg_transform__ = "parallel_generic_transpose"} // Marker matched by the "parallel_generic_transpose" pattern in TestTilingInterface.cpp (tile sizes {10, 0, 20}).
+      ins(%arg0 : tensor<128x200x300xf32>)
+      outs(%init0, %init1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      linalg.yield %b0, %b0 : f32, f32 // Both results receive the input element.
+    } -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>)
+  return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+//      CHECK: func.func @multi_result(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[C128:.+]] = arith.constant 128 : index
+//  CHECK-DAG:   %[[C300:.+]] = arith.constant 300 : index
+//  CHECK-DAG:   %[[INIT0:.+]] = linalg.init_tensor [128, 300, 200]
+//  CHECK-DAG:   %[[INIT1:.+]] = linalg.init_tensor [300, 128, 200]
+//      CHECK:   %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]]
+// CHECK-SAME:       iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
+//      CHECK:     %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[C128]]]
+//      CHECK:     %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]]
+// CHECK-SAME:         iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]])
+//      CHECK:       %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[C300]]]
+//  CHECK-DAG:       %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME:           [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, %[[TS_X]]] [1, 1, 1]
+//  CHECK-DAG:       %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG3]]
+// CHECK-SAME:           [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], %[[TS_X]], 200] [1, 1, 1]
+//  CHECK-DAG:       %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG4]]
+// CHECK-SAME:           [%[[IV1]], %[[IV0]], 0] [%[[TS_X]], %[[TS_Y]], 200] [1, 1, 1]
+//      CHECK:       %[[RESULT_TILE:.+]]:2 = linalg.generic
+// CHECK-SAME:           ins(%[[ARG_TILE]] :
+// CHECK-SAME:           outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
+//      CHECK:       %[[UPDATE0:.+]] = tensor.insert_slice %[[RESULT_TILE]]#0 into %[[ARG3]]
+// CHECK-SAME:           [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], %[[TS_X]], 200] [1, 1, 1]
+//      CHECK:       %[[UPDATE1:.+]] = tensor.insert_slice %[[RESULT_TILE]]#1 into %[[ARG4]]
+// CHECK-SAME:           [%[[IV1]], %[[IV0]], 0] [%[[TS_X]], %[[TS_Y]], 200] [1, 1, 1]
+//      CHECK:       scf.yield %[[UPDATE0]], %[[UPDATE1]]
+//      CHECK:     scf.yield %[[INNER]]#0, %[[INNER]]#1
+//      CHECK:   return %[[OUTER]]#0, %[[OUTER]]#1
+
+// -----
+
+func.func @conv2D(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
+    %arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf {
+      strides = dense<[2, 3]> : tensor<2xi64>,
+      dilation = dense<[4, 5]> : tensor<2xi64>, // NOTE(review): the expected affine maps below carry factors 2 and 3 (the strides) but nothing derived from [4, 5]; confirm whether `dilation` is the intended attribute name (vs. `dilations`).
+      __internal_linalg_transform__ = "simple_conv"} // Marker matched by the "simple_conv" pattern in TestTilingInterface.cpp (tile sizes {0, 0, 0, 0, 10, 20, 30}).
+      ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+      outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 2 - 2)>
+//  CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 3 - 3)>
+//      CHECK: func.func @conv2D(
+// CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:     %[[FILTER:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[C30:.+]] = arith.constant 30 : index
+//  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[INPUT]], %[[C0]]
+//  CHECK-DAG:   %[[C:.+]] = tensor.dim %[[INPUT]], %[[C3]]
+//  CHECK-DAG:   %[[P:.+]] = tensor.dim %[[FILTER]], %[[C0]]
+//  CHECK-DAG:   %[[Q:.+]] = tensor.dim %[[FILTER]], %[[C1]]
+//  CHECK-DAG:   %[[F:.+]] = tensor.dim %[[FILTER]], %[[C3]]
+//  CHECK-DAG:   %[[R:.+]] = tensor.dim %[[INIT]], %[[C1]]
+//  CHECK-DAG:   %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]]
+//      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[P]] step %[[C10]]
+// CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[INIT]])
+//      CHECK:     %[[TS_P:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[P]]]
+//      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[Q]] step %[[C20]]
+// CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
+//      CHECK:       %[[TS_Q:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[Q]]]
+//      CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C30]]
+// CHECK-SAME:           iter_args(%[[INIT2:.+]] = %[[INIT1]])
+//  CHECK-DAG:         %[[TS_C:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[C30]], %[[C]]]
+//  CHECK-DAG:         %[[TS_H:.+]] = affine.apply #[[MAP3]](%[[TS_P]])[%[[R]]]
+//  CHECK-DAG:         %[[TS_W:.+]] = affine.apply #[[MAP4]](%[[TS_Q]])[%[[S]]]
+//  CHECK-DAG:         %[[INPUT_TILE:.+]] = tensor.extract_slice %[[INPUT]]
+// CHECK-SAME:             [0, %[[IV0]], %[[IV1]], %[[IV2]]] [%[[N]], %[[TS_H]], %[[TS_W]], %[[TS_C]]]
+//  CHECK-DAG:         %[[FILTER_TILE:.+]] = tensor.extract_slice %[[FILTER]]
+// CHECK-SAME:             [%[[IV0]], %[[IV1]], %[[IV2]], 0] [%[[TS_P]], %[[TS_Q]], %[[TS_C]], %[[F]]]
+//  CHECK-DAG:         %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT2]]
+// CHECK-SAME:             [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]]
+//      CHECK:         %[[CONV_TILE:.+]] = linalg.conv_2d_nhwc_hwcf
+// CHECK-SAME:             dilation = dense<[4, 5]> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>
+// CHECK-SAME:             ins(%[[INPUT_TILE]], %[[FILTER_TILE]] :
+// CHECK-SAME:             outs(%[[INIT_TILE]] :
+//      CHECK:         tensor.insert_slice %[[CONV_TILE]] into %[[INIT2]]
+// CHECK-SAME:             [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]]
index 97149dcf38d8e5157f6023caa8e298ed9a129fea..88e55e77a3fb93e4a218b748b73923298d56ccae 100644 (file)
@@ -1,6 +1,7 @@
 add_subdirectory(Analysis)
 add_subdirectory(Conversion)
 add_subdirectory(Dialect)
+add_subdirectory(Interfaces)
 add_subdirectory(IR)
 add_subdirectory(Pass)
 add_subdirectory(Reducer)
diff --git a/mlir/test/lib/Interfaces/CMakeLists.txt b/mlir/test/lib/Interfaces/CMakeLists.txt
new file mode 100644 (file)
index 0000000..4a0567a
--- /dev/null
@@ -0,0 +1 @@
+add_subdirectory(TilingInterface)
diff --git a/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt b/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt
new file mode 100644 (file)
index 0000000..437e39c
--- /dev/null
@@ -0,0 +1,15 @@
+add_mlir_library(MLIRTilingInterfaceTestPasses # Library built from TestTilingInterface.cpp, which defines the -test-tiling-interface pass.
+  TestTilingInterface.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC # Libraries linked publicly into this target.
+  MLIRAffine
+  MLIRArithmetic
+  MLIRLinalg
+  MLIRLinalgTransforms
+  MLIRMemRef
+  MLIRSCF
+  MLIRSCFTransforms
+  MLIRTensor
+  )
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
new file mode 100644 (file)
index 0000000..c3795c3
--- /dev/null
@@ -0,0 +1,126 @@
+//===- TestTilingInterface.cpp - Test tiling using `TilingInterface` -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing tiling operations using
+// `TilingInterface`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/TileUsingInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Test pattern derived from scf::TileUsingSCFForOp that tiles only the TilingInterface ops accepted by
+/// `filter`, and afterwards re-marks the tiled op through the same filter.
+struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
+  TestTileUsingSCFForOpWithFilter(MLIRContext *context, // Builds the pattern from tiling `options` and a `filter` (default-constructed when omitted).
+                                  scf::SCFTilingOptions options,
+                                  linalg::LinalgTransformationFilter filter =
+                                      linalg::LinalgTransformationFilter(),
+                                  PatternBenefit benefit = 1)
+      : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {}
+
+  /// Constructor overload that accepts `opName`. NOTE(review): `opName` is not forwarded to the base-class constructor below, so this overload builds the same pattern as the one above; confirm whether that is intended.
+  TestTileUsingSCFForOpWithFilter(StringRef opName, MLIRContext *context,
+                                  scf::SCFTilingOptions options,
+                                  linalg::LinalgTransformationFilter filter =
+                                      linalg::LinalgTransformationFilter(),
+                                  PatternBenefit benefit = 1)
+      : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {}
+
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(filter.checkAndNotify(rewriter, op))) // Bail out when the filter rejects `op`.
+      return failure();
+
+    FailureOr<scf::SCFTilingResult> tilingResult = // Delegate the actual tiling to the base pattern.
+        returningMatchAndRewrite(op, rewriter);
+    if (failed(tilingResult)) {
+      return failure();
+    }
+    filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp); // Update the marker on the tiled op (the test patterns use "tiled" as the replacement).
+    return success();
+  }
+
+private:
+  linalg::LinalgTransformationFilter filter; // Filter consulted before, and applied after, tiling.
+};
+
+struct TestTilingInterfacePass // Function-level test pass that applies the tiling test patterns (see runOnOperation).
+    : public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass)
+
+  TestTilingInterfacePass() = default;
+  TestTilingInterfacePass(const TestTilingInterfacePass &pass) // Copy constructor; forwards to the PassWrapper copy constructor.
+      : PassWrapper(pass) {}
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect, // Dialects declared as dependencies of this pass.
+                    tensor::TensorDialect>();
+    linalg::registerTilingInterfaceExternalModels(registry); // Registers the Linalg TilingInterface external models on the same registry.
+  }
+  StringRef getArgument() const final { return "test-tiling-interface"; } // Command-line flag, used as -test-tiling-interface in the lit test.
+  StringRef getDescription() const final {
+    return "Test tiling using TilingInterface";
+  }
+
+  void runOnOperation() override; // Defined out of line below.
+};
+} // namespace
+
+static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) { // Populates `patterns` with one filtered tiling pattern per test case below.
+  auto addPatternForTiling = [&](ArrayRef<int64_t> tileSizes, // Helper: adds a pattern using `tileSizes` that matches ops marked `filterName`.
+                                 StringRef filterName) {
+    scf::SCFTilingOptions tilingOptions;
+    tilingOptions.setTileSizes(tileSizes);
+    linalg::LinalgTransformationFilter filter( // Match the marker `filterName`; replace it with "tiled" after rewriting.
+        StringAttr::get(context, filterName),
+        StringAttr::get(context, "tiled"));
+    patterns.add<TestTileUsingSCFForOpWithFilter>(context, tilingOptions,
+                                                  filter);
+  };
+  // 1. Tiling M and N dims of `linalg.matmul` on tensors.
+  addPatternForTiling({10, 20}, "simple_gemm");
+  // 2. Tiling M, N and K of `linalg.matmul` on buffers.
+  addPatternForTiling({10, 20, 30}, "simple_gemm_memref");
+  // 3. Tiling 3D parallel generic op which implements a transpose. A tile size of 0 presumably leaves that loop untiled (confirm against SCFTilingOptions).
+  addPatternForTiling({10, 0, 20}, "parallel_generic_transpose");
+  // 4. Tiling 2D conv op.
+  addPatternForTiling({0, 0, 0, 0, 10, 20, 30}, "simple_conv");
+}
+
+void TestTilingInterfacePass::runOnOperation() { // Applies the tiling test patterns to the function this pass runs on.
+  MLIRContext *context = &getContext();
+
+  RewritePatternSet tilingPatterns(context);
+  addTestPatterns(context, tilingPatterns);
+  if (failed(applyPatternsAndFoldGreedily(getOperation(), // Signal pass failure when the greedy driver reports failure.
+                                          std::move(tilingPatterns))))
+    return signalPassFailure();
+}
+
+namespace mlir {
+namespace test {
+void registerTestTilingInterface() { // Registers TestTilingInterfacePass with the pass registry.
+  PassRegistration<TestTilingInterfacePass>();
+}
+} // namespace test
+} // namespace mlir
index a8172b83f1a476edf1ca050655789f36b50bcc55..97b082e83e5dac5d583570bcafc2017896327a5c 100644 (file)
@@ -33,6 +33,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTestRewrite
     MLIRTestTransformDialect
     MLIRTestTransforms
+    MLIRTilingInterfaceTestPasses
     MLIRVectorTestPasses
     )
 endif()
index aa94294b4ea8de7491ae4a88195edd680861bfe7..b50cfa964290fa270aea6c654c0a7304cb4da607 100644 (file)
@@ -111,6 +111,7 @@ void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
 void registerTestSliceAnalysisPass();
 void registerTestTensorTransforms();
+void registerTestTilingInterface();
 void registerTestTransformDialectInterpreterPass();
 void registerTestVectorLowerings();
 } // namespace test
@@ -206,6 +207,7 @@ void registerTestPasses() {
   mlir::test::registerTestSCFUtilsPass();
   mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestTensorTransforms();
+  mlir::test::registerTestTilingInterface();
   mlir::test::registerTestTransformDialectInterpreterPass();
   mlir::test::registerTestVectorLowerings();
 }
index 5fde5e771e538ae86026cdceca7da122e90ad618..49c08f7c01535da6937a3a539b3cc98e2bc5a67e 100644 (file)
@@ -1864,6 +1864,7 @@ cc_library(
         "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
         "include/mlir/Dialect/SCF/Passes.h",
         "include/mlir/Dialect/SCF/Patterns.h",
+        "include/mlir/Dialect/SCF/TileUsingInterface.h",
         "include/mlir/Dialect/SCF/Transforms.h",
     ],
     includes = ["include"],
@@ -1883,6 +1884,7 @@ cc_library(
         ":SCFUtils",
         ":Support",
         ":TensorDialect",
+        ":TilingInterface",
         ":Transforms",
         "//llvm:Support",
     ],
@@ -2645,6 +2647,7 @@ cc_library(
         exclude = [
             "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
             "include/mlir/Dialect/SCF/Patterns.h",
+            "include/mlir/Dialect/SCF/TileUsingInterface.h",
             "include/mlir/Dialect/SCF/Transforms.h",
         ],
     ),
@@ -6313,6 +6316,7 @@ cc_binary(
         "//mlir/test:TestSPIRV",
         "//mlir/test:TestShapeDialect",
         "//mlir/test:TestTensor",
+        "//mlir/test:TestTilingInterface",
         "//mlir/test:TestTosaDialect",
         "//mlir/test:TestTransformDialect",
         "//mlir/test:TestTransforms",
@@ -7492,6 +7496,7 @@ cc_library(
         ":TensorTilingInterfaceImpl",
         ":TensorTransforms",
         ":TensorUtils",
+        ":TilingInterface",
         ":TransformUtils",
         ":Transforms",
         ":VectorDialect",
index 742e7b610453adab893aec5e3a45bc8cbd30d64d..fa89b9c990c0be70d15b6462c695d084cea61228 100644 (file)
@@ -293,6 +293,28 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "TestTilingInterface",  # Bazel target listed in the deps of the cc_binary hunk earlier in this diff.
+    srcs = glob(["lib/Interfaces/TilingInterface/*.cpp"]),  # Globs every .cpp under lib/Interfaces/TilingInterface.
+    includes = ["lib/Dialect/Test"],
+    deps = [
+        "//llvm:Support",
+        "//mlir:Affine",
+        "//mlir:ArithmeticDialect",
+        "//mlir:FuncDialect",
+        "//mlir:IR",
+        "//mlir:LinalgDialect",
+        "//mlir:LinalgTransforms",
+        "//mlir:MemRefDialect",
+        "//mlir:Pass",
+        "//mlir:SCFDialect",
+        "//mlir:SCFTransforms",
+        "//mlir:TensorDialect",
+        "//mlir:TilingInterface",
+        "//mlir:Transforms",
+    ],
+)
+
+
 cc_library(
     name = "TestPass",
     srcs = glob(["lib/Pass/*.cpp"]),