--- /dev/null
+//===- TilingInterfaceImpl.h - Implementation of TilingInterface ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H
+#define MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+void registerTilingInterfaceExternalModels(DialectRegistry ®istry);
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H
SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
ValueRange ivs, ValueRange tileSizes);
-/// Compute tile sizes, given a list of loop `ivs`, `tileSizes` and dimension
+/// Compute tile sizes, given a list of `tileSizes` and dimension
/// sizes (`sizeBounds`). In case a tile size is zero (i.e., no tiling), the
/// corresponding result size is the corresponding value from `sizeBounds`.
/// Note: The returned tile sizes are closed intervals.
-SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs,
+SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
ValueRange tileSizes,
ArrayRef<Value> sizeBounds);
--- /dev/null
+//===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H
+#define MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/TilingInterface.h"
+
+namespace mlir {
+class Operation;
+class PatternRewriter;
+class TilingInterface;
+} // namespace mlir
+
+namespace mlir {
+namespace scf {
+
+using SCFTileSizeComputationFunction =
+ std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
+
+/// Options to use to control tiling.
+struct SCFTilingOptions {
+ /// Computation function that returns the tile sizes for each operation.
+ /// Delayed construction of constant tile sizes should occur to interoperate
+ /// with folding.
+ SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
+
+ SCFTilingOptions &
+ setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) {
+ tileSizeComputationFunction = std::move(fun);
+ return *this;
+ }
+ /// Set the `tileSizeComputationFunction` to return the values `ts`. The
+ /// values must not fold away when tiling. Otherwise, use a more robust
+ /// `tileSizeComputationFunction`.
+ SCFTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
+ tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
+ return *this;
+ }
+ /// Convenience function to set the `tileSizeComputationFunction` to a
+ /// function that computes tile sizes at the point they are needed. Allows
+ /// proper interaction with folding.
+ SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
+};
+
+struct SCFTilingResult {
+ Operation *tiledOp;
+ SmallVector<scf::ForOp> loops;
+};
+
+/// Pattern to tile an op that implementas the `TilingInterface` using
+/// `scf.for` for iterating over the tiles.
+struct TileUsingSCFForOp : public OpInterfaceRewritePattern<TilingInterface> {
+ /// Construct a generic pattern applied to all TilingInterface ops.
+ TileUsingSCFForOp(MLIRContext *context, SCFTilingOptions options,
+ PatternBenefit benefit = 1);
+
+ /// Construct a generic pattern applied to `opName`.
+ TileUsingSCFForOp(StringRef opName, MLIRContext *context,
+ SCFTilingOptions options, PatternBenefit benefit = 1);
+
+ /// `matchAndRewrite` implementation that returns the significant transformed
+ /// pieces of IR.
+ FailureOr<SCFTilingResult>
+ returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
+
+ LogicalResult matchAndRewrite(TilingInterface op,
+ PatternRewriter &rewriter) const override {
+ return returningMatchAndRewrite(op, rewriter);
+ }
+
+private:
+ /// Options to control tiling;
+ SCFTilingOptions options;
+};
+
+} // namespace scf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H
#ifndef MLIR_DIALECT_SCF_UTILS_UTILS_H_
#define MLIR_DIALECT_SCF_UTILS_UTILS_H_
+#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
class FuncOp;
} // namespace func
-namespace scf {
-class IfOp;
-class ForOp;
-class ParallelOp;
-} // namespace scf
-
/// Replace the `loop` with `newIterOperands` added as new initialization
/// values. `newYieldValuesFn` is a callback that can be used to specify
/// the additional values to be yielded by the loop. The number of
ValueRange newIterOperands,
const NewYieldValueFn &newYieldValuesFn);
+/// Update a perfectly nested loop nest to yield new values from the innermost
+/// loop and propagating it up through the loop nest. This function
+/// - Expects `loopNest` to be a perfectly nested loop with outer most loop
+/// first and innermost loop last.
+/// - `newIterOperands` are the initialization values to be used for the
+/// outermost loop
+/// - `newYielValueFn` is the callback that generates the new values to be
+/// yielded from within the innermost loop.
+/// - The original loops are not erased, but are left in a "no-op" state where
+/// the body of the loop just yields the basic block arguments that correspond
+/// to the initialization values of a loop. The original loops are dead after
+/// this method.
+/// - All uses of the `newIterOperands` within the generated new loop
+/// are replaced with the corresponding `BlockArgument` in the loop body.
+SmallVector<scf::ForOp>
+replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
+ ValueRange newIterOperands,
+ NewYieldValueFn newYieldValueFn);
+
/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types is the type of the yielded operands of the
/// single block. This constraint makes it easy to determine the result.
/*defaultImplementation=*/[{
return {};
}]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to return the position of the result tile computed by the tiled operation.
+
+ Specifies what tile of the result of the original tensor is computed
+ by the tiled implementation. Expects the same `offsets` and `sizes` as
+ used to obtain the tiled implementation of the operation.
+ }],
+ /*retType=*/"LogicalResult",
+ /*methodName=*/"getResultTilePosition",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$resultNumber,
+ "ArrayRef<OpFoldResult> ":$offsets,
+ "ArrayRef<OpFoldResult> ":$sizes,
+ "SmallVector<OpFoldResult> &":$resultOffsets,
+ "SmallVector<OpFoldResult> &":$resultSizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
>
];
}
SparseTensorRewriting.cpp
SplitReduction.cpp
Tiling.cpp
+ TilingInterfaceImpl.cpp
Transforms.cpp
Vectorization.cpp
// Compute offsets and sizes of ExtractSliceOp.
SmallVector<Value> offsets =
computeTileOffsets(b, loc, localIvs, tileSizes);
- SmallVector<Value> sizes =
- computeTileSizes(b, loc, localIvs, tileSizes, allDims);
+ SmallVector<Value> sizes = computeTileSizes(b, loc, tileSizes, allDims);
// Create ExtractSliceOp: Extract a tile from the tensor::PadOp.
// Note: The tensor::PadOp is located outside of the loop nest. It is
// later moved inside by ExtractSliceOfPadTensorSwapPattern.
--- /dev/null
+//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/TilingInterface.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// External model implementation of TilingInterface for LinalgOps. An external
+/// model implementation is used for now till the use of `TilingInterface` is
+/// on-par with the current Linalg tiling + fusion patterns. Once it is
+/// maybe possible to move this into the op-definition (though there are
+/// advantages to leaving it as an external model)
+template <typename LinalgOpTy>
+struct LinalgOpTilingInterface
+ : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
+ LinalgOpTy> {
+
+ /// Return the destination operands.
+ SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
+ return llvm::cast<LinalgOp>(op).getOutputOperands();
+ }
+
+ /// Return the loop iterator type.
+ SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+ LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
+ return llvm::to_vector(
+ llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) {
+ return strAttr.cast<StringAttr>().getValue();
+ }));
+ }
+
+ /// Return the iteration domain range.
+ SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
+ Location loc = op->getLoc();
+ LinalgOp linalgOp = cast<LinalgOp>(op);
+ auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc);
+ AffineMap map = linalgOp.getShapesToLoopsMap();
+ Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+ return llvm::to_vector(llvm::map_range(
+ applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) {
+ return Range{zero, v, one};
+ }));
+ }
+
+ // Instantiate the tiled implementation of the operation.
+ SmallVector<Operation *>
+ getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ bool tileDestOperands) const {
+ // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
+ // specified could lead to out of bounds accesses.
+ Location loc = op->getLoc();
+ LinalgOp linalgOp = cast<LinalgOp>(op);
+ SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
+ SmallVector<Value, 4> tiledOperands = makeTiledShapes(
+ b, loc, linalgOp, valuesToTile,
+ getValueOrCreateConstantIndexOp(b, loc, offsets),
+ getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true);
+
+ SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
+ linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) {
+ return tiledOperands[opOperand->getOperandNumber()].getType();
+ }));
+
+ Operation *tiledOp =
+ linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
+
+ return {tiledOp};
+ }
+
+ // Return the details of the output tile generated by the tiled
+ // implementation.
+ LogicalResult
+ getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVector<OpFoldResult> &resultOffsets,
+ SmallVector<OpFoldResult> &resultSizes) const {
+ Location loc = op->getLoc();
+ LinalgOp linalgOp = cast<LinalgOp>(op);
+
+ AffineExpr d0;
+ bindDims(b.getContext(), d0);
+
+ auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc,
+ AffineExpr expr,
+ ValueRange operands) -> Value {
+ AffineMap map = AffineMap::inferFromExprList({expr}).front();
+ SmallVector<Value> normalizedOperands(operands.begin(), operands.end());
+ mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands);
+ canonicalizeMapAndOperands(&map, &normalizedOperands);
+ return builder.createOrFold<AffineApplyOp>(loc, map, normalizedOperands);
+ };
+
+ SmallVector<Value> sizeVals =
+ getValueOrCreateConstantIndexOp(b, loc, sizes);
+ SmallVector<Value> subShapeSizes =
+ llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) {
+ return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v);
+ }));
+ OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
+ Value sliceOpResult =
+ makeTiledShape(b, loc, outOperand->get(), sizeVals,
+ linalgOp.getTiedIndexingMap(outOperand),
+ getValueOrCreateConstantIndexOp(b, loc, offsets),
+ /*ubs*/ {}, subShapeSizes, true);
+ auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>();
+ if (!sliceOp)
+ return failure();
+ resultOffsets = sliceOp.getMixedOffsets();
+ resultSizes = sliceOp.getMixedSizes();
+ return success();
+ }
+};
+
+} // namespace
+
+template <typename OpType> static void registerOne(MLIRContext *ctx) {
+ OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
+}
+
+/// Variadic helper function.
+template <typename... OpTypes> static void registerAll(MLIRContext *ctx) {
+ // FIXME: In c++17 this can be simplified by using 'fold expressions'.
+ (void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...};
+}
+
+#define GET_OP_LIST
+
+void mlir::linalg::registerTilingInterfaceExternalModels(
+ DialectRegistry ®istry) {
+ registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
+ registerOne<linalg::GenericOp>(ctx);
+ registerAll<
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+ >(ctx);
+ });
+}
return offsets;
}
-SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs,
+SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
ValueRange tileSizes,
ArrayRef<Value> sizeBounds) {
SmallVector<Value> sizes;
// that define tile subshapes.
SmallVector<Value> lbs = computeTileOffsets(b, loc, ivs, tileSizes);
SmallVector<Value> subShapeSizes =
- computeTileSizes(b, loc, ivs, tileSizes, sizeBounds);
+ computeTileSizes(b, loc, tileSizes, sizeBounds);
assert(static_cast<int64_t>(valuesToTile.size()) ==
linalgOp.getNumInputsAndOutputs() &&
ParallelLoopFusion.cpp
ParallelLoopTiling.cpp
StructuralTypeConversions.cpp
+ TileUsingInterface.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
--- /dev/null
+//===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the tiling using TilingInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/TileUsingInterface.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "tile-using-interface"
+
+using namespace mlir;
+
+scf::SCFTilingOptions &
+scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
+ assert(!tileSizeComputationFunction && "tile sizes already set");
+ SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
+ tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPointToStart(
+ &op->getParentOfType<func::FuncOp>().getBody().front());
+ return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
+ Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
+ return v;
+ }));
+ };
+ return *this;
+}
+
+/// Generate an empty loop nest that represents the tiled loop nest shell.
+/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
+/// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
+/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
+/// the
+/// tile processed within the inner most loop.
+static SmallVector<scf::ForOp>
+generateTileLoopNest(OpBuilder &builder, Location loc,
+ ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
+ SmallVector<OpFoldResult> &offsets,
+ SmallVector<OpFoldResult> &sizes) {
+ assert(!loopRanges.empty() && "expected at least one loop range");
+ assert(loopRanges.size() == tileSizeVals.size() &&
+ "expected as many tile sizes as loop ranges");
+ OpBuilder::InsertionGuard guard(builder);
+ SmallVector<scf::ForOp> loops;
+ offsets.resize(loopRanges.size());
+ sizes.resize(loopRanges.size());
+
+ // The tile size to use (to avoid out of bounds access) is minimum of
+ // `tileSize` and `ub - iv`, where `iv` is the induction variable
+ // of the tiled loop.
+ AffineExpr s0, s1, d0;
+ bindDims(builder.getContext(), d0);
+ bindSymbols(builder.getContext(), s0, s1);
+ AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext());
+
+ for (auto loopRange : llvm::enumerate(loopRanges)) {
+ // No loops if tile size is zero. Set offset and size to the loop
+ // offset and size.
+ if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) {
+ offsets[loopRange.index()] = loopRange.value().offset;
+ sizes[loopRange.index()] = loopRange.value().size;
+ continue;
+ }
+
+ auto loop = builder.create<scf::ForOp>(
+ loc, loopRange.value().offset, loopRange.value().size,
+ tileSizeVals[loopRange.index()], ValueRange{},
+ [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
+ ValueRange /*iterArgs*/) {
+ Value boundedTileSize = builder.create<AffineMinOp>(
+ bodyLoc, minMap,
+ ValueRange{iv, tileSizeVals[loopRange.index()],
+ loopRange.value().size});
+ sizes[loopRange.index()] = boundedTileSize;
+ builder.create<scf::YieldOp>(loc);
+ });
+ offsets[loopRange.index()] = loop.getInductionVar();
+ loops.push_back(loop);
+ builder.setInsertionPoint(loop.getBody()->getTerminator());
+ }
+ return loops;
+}
+
+scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context,
+ scf::SCFTilingOptions options,
+ PatternBenefit benefit)
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ options(std::move(options)) {}
+
+scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName,
+ MLIRContext *context,
+ scf::SCFTilingOptions options,
+ PatternBenefit benefit)
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ options(std::move(options)) {}
+
+FailureOr<scf::SCFTilingResult>
+scf::TileUsingSCFForOp::returningMatchAndRewrite(
+ TilingInterface op, PatternRewriter &rewriter) const {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointAfter(op);
+
+ if (!options.tileSizeComputationFunction) {
+ return rewriter.notifyMatchFailure(
+ op, "missing tile size computation function");
+ }
+
+ // 1. Get the range of the loops that are represented by the operation.
+ SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
+ size_t numLoops = iterationDomain.size();
+ if (numLoops == 0) {
+ return rewriter.notifyMatchFailure(
+ op, "unable to tile op with no iteration domain");
+ }
+
+ // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
+ // skips tiling a particular dimension. This convention is significantly
+ // simpler to handle instead of adjusting affine maps to account for missing
+ // dimensions.
+ SmallVector<Value, 4> tileSizeVector =
+ options.tileSizeComputationFunction(rewriter, op);
+ if (tileSizeVector.size() < iterationDomain.size()) {
+ auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
+ tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
+ }
+
+ scf::SCFTilingResult tilingResult;
+ SmallVector<OpFoldResult> offsets, sizes;
+ {
+ // 3. Materialize an empty loop nest that iterates over the tiles. These
+ // loops for now do not return any values even if the original operation has
+ // results.
+ tilingResult.loops = generateTileLoopNest(
+ rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
+
+ LLVM_DEBUG({
+ if (!tilingResult.loops.empty()) {
+ llvm::errs() << "LoopNest shell :\n";
+ tilingResult.loops.front().dump();
+ llvm::errs() << "\n";
+ }
+ });
+
+ // 4. Generate the tiled implementation within the inner most loop.
+ if (!tilingResult.loops.empty())
+ rewriter.setInsertionPoint(
+ tilingResult.loops.back().getBody()->getTerminator());
+ SmallVector<Operation *> tiledImplementation = op.getTiledImplementation(
+ rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true);
+ if (tiledImplementation.size() != 1) {
+ return rewriter.notifyMatchFailure(
+ op, "expected tiled implementation to return a single op");
+ }
+ tilingResult.tiledOp = tiledImplementation[0];
+
+ LLVM_DEBUG({
+ if (!tilingResult.loops.empty()) {
+ llvm::errs() << "After tiled implementation :\n";
+ tilingResult.loops.front().dump();
+ llvm::errs() << "\n";
+ }
+ });
+ }
+
+ if (op->getNumResults() == 0) {
+ rewriter.eraseOp(op);
+ return tilingResult;
+ }
+
+ // 5. If the original operations has results, modify the loop nest to yield
+ // the replacement values.
+ SmallVector<Value> replacements;
+ if (tilingResult.loops.empty()) {
+ // 5a. If there were no loops, the tiled implementation results are the
+ // replacements.
+ rewriter.replaceOp(op, tilingResult.tiledOp->getResults());
+ return tilingResult;
+ }
+
+ // 5b. `scf.for` with tensor semantics requires the loop nest to yield the
+ // replacement values using destructive updates. Use the `TilingInterface`
+ // to get the position of the result tiles and use that to generate the
+ // destructive update pattern, i.e.,
+ //
+ // ```mlir
+ // scf.for %iv0 = ... {
+ // %0 = tiled_op
+ // }
+ // ```
+ //
+ // is transformed to
+ //
+ // ```mlir
+ // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. {
+ // %0 = tiled_op
+ // %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
+ // scf.yield %1
+ // }
+ // ```
+ NewYieldValueFn yieldValueFn =
+ [&](OpBuilder &b, Location loc,
+ ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
+ SmallVector<Value> yieldedValues;
+ Attribute one = b.getIndexAttr(1);
+ for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) {
+ SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes;
+ if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes,
+ resultTileOffsets,
+ resultTileSizes))) {
+ op.emitOpError("unable to get position of result ")
+ << resultNum << " of the tiled implementation";
+ return {};
+ }
+ SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(),
+ one);
+ Value yieldedValue = b.create<tensor::InsertSliceOp>(
+ op->getLoc(), tilingResult.tiledOp->getResult(resultNum),
+ newBBArgs[resultNum], resultTileOffsets, resultTileSizes,
+ resultTileStrides);
+ yieldedValues.push_back(yieldedValue);
+ }
+ return yieldedValues;
+ };
+ SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
+ rewriter, tilingResult.loops, op.getDestinationOperands(rewriter),
+ yieldValueFn);
+ for (auto loop : llvm::enumerate(tilingResult.loops)) {
+ rewriter.eraseOp(loop.value());
+ tilingResult.loops[loop.index()] = newLoops[loop.index()];
+ }
+ rewriter.replaceOp(op, tilingResult.loops.front().getResults());
+ return tilingResult;
+}
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
using namespace mlir;
return newLoop;
}
+SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
+ OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
+ ValueRange newIterOperands, NewYieldValueFn newYieldValueFn) {
+ if (loopNest.empty())
+ return {};
+ SmallVector<scf::ForOp> newLoopNest(loopNest.size());
+
+ newLoopNest.back() = replaceLoopWithNewYields(
+ builder, loopNest.back(), newIterOperands, newYieldValueFn);
+
+ for (unsigned loopDepth :
+ llvm::reverse(llvm::seq<unsigned>(0, loopNest.size() - 1))) {
+ NewYieldValueFn fn = [&](OpBuilder &innerBuilder, Location loc,
+ ArrayRef<BlockArgument> innerNewBBArgs) {
+ SmallVector<Value> newYields(
+ newLoopNest[loopDepth + 1]->getResults().take_back(
+ newIterOperands.size()));
+ return newYields;
+ };
+ newLoopNest[loopDepth] = replaceLoopWithNewYields(
+ builder, loopNest[loopDepth], newIterOperands, fn);
+ }
+ return newLoopNest;
+}
+
/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types is the type of the yielded operands of the
/// single block. This constraint makes it easy to determine the result.
--- /dev/null
+// RUN: mlir-opt -test-tiling-interface -split-input-file %s | FileCheck %s
+
+func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul {__internal_linalg_transform__ = "simple_gemm"}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK: func.func @simple_matmul(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK: %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[INIT0:.+]] = %[[ARG2]])
+// CHECK: %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[M]]]
+// CHECK: %[[INNER:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[INIT0]])
+// CHECK: %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[N]]]
+// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1]
+// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]]
+// CHECK-SAME: [0, %[[IV1]]] [%[[K]], %[[TS_X]]] [1, 1]
+// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT1]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
+// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME: outs(%[[INIT_TILE]] :
+// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[INIT1]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
+// CHECK: scf.yield %[[UPDATE]]
+// CHECK: scf.yield %[[INNER]]
+// CHECK: return %[[OUTER]]
+
+// -----
+
+func.func @simple_matmul_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
+ %arg2 : memref<?x?xf32>) {
+ linalg.matmul {__internal_linalg_transform__ = "simple_gemm_memref"}
+ ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%arg2 : memref<?x?xf32>)
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
+// CHECK: func.func @simple_matmul_memref(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
+// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
+// CHECK: %[[TS_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[M]]]
+// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
+// CHECK: %[[TS_N:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[N]]]
+// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]]
+// CHECK: %[[TS_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[C30]], %[[K]]]
+// CHECK-DAG: %[[LHS_TILE:.+]] = memref.subview %[[ARG0]]
+// CHECK-SAME: [%[[IV0]], %[[IV2]]] [%[[TS_M]], %[[TS_K]]] [1, 1]
+// CHECK-DAG: %[[RHS_TILE:.+]] = memref.subview %[[ARG1]]
+// CHECK-SAME: [%[[IV2]], %[[IV1]]] [%[[TS_K]], %[[TS_N]]] [1, 1]
+// CHECK-DAG: %[[OUT_TILE:.+]] = memref.subview %[[ARG2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_M]], %[[TS_N]]] [1, 1]
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME: outs(%[[OUT_TILE]] :
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>) {
+ %init0 = linalg.init_tensor [128, 300, 200] : tensor<128x300x200xf32>
+ %init1 = linalg.init_tensor [300, 128, 200] : tensor<300x128x200xf32>
+ %0:2 = linalg.generic {
+ indexing_maps = [#map0, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ {__internal_linalg_transform__ = "parallel_generic_transpose"}
+ ins(%arg0 : tensor<128x200x300xf32>)
+ outs(%init0, %init1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ linalg.yield %b0, %b0 : f32, f32
+ } -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>)
+ return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK: func.func @multi_result(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
+// CHECK-DAG: %[[C300:.+]] = arith.constant 300 : index
+// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [128, 300, 200]
+// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [300, 128, 200]
+// CHECK: %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
+// CHECK: %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[C128]]]
+// CHECK: %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]])
+// CHECK: %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[C300]]]
+// CHECK-DAG: %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, %[[TS_X]]] [1, 1, 1]
+// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG3]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], %[[TS_X]], 200] [1, 1, 1]
+// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG4]]
+// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [%[[TS_X]], %[[TS_Y]], 200] [1, 1, 1]
+// CHECK: %[[RESULT_TILE:.+]]:2 = linalg.generic
+// CHECK-SAME: ins(%[[ARG_TILE]] :
+// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
+// CHECK: %[[UPDATE0:.+]] = tensor.insert_slice %[[RESULT_TILE]]#0 into %[[ARG3]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], %[[TS_X]], 200] [1, 1, 1]
+// CHECK: %[[UPDATE1:.+]] = tensor.insert_slice %[[RESULT_TILE]]#1 into %[[ARG4]]
+// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [%[[TS_X]], %[[TS_Y]], 200] [1, 1, 1]
+// CHECK: scf.yield %[[UPDATE0]], %[[UPDATE1]]
+// CHECK: scf.yield %[[INNER]]#0, %[[INNER]]#1
+// CHECK: return %[[OUTER]]#0, %[[OUTER]]#1
+
+// -----
+
+func.func @conv2D(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
+ %arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.conv_2d_nhwc_hwcf {
+ strides = dense<[2, 3]> : tensor<2xi64>,
+ dilation = dense<[4, 5]> : tensor<2xi64>,
+ __internal_linalg_transform__ = "simple_conv"}
+ ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+ outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 2 - 2)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 3 - 3)>
+// CHECK: func.func @conv2D(
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[FILTER:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
+// CHECK-DAG: %[[N:.+]] = tensor.dim %[[INPUT]], %[[C0]]
+// CHECK-DAG: %[[C:.+]] = tensor.dim %[[INPUT]], %[[C3]]
+// CHECK-DAG: %[[P:.+]] = tensor.dim %[[FILTER]], %[[C0]]
+// CHECK-DAG: %[[Q:.+]] = tensor.dim %[[FILTER]], %[[C1]]
+// CHECK-DAG: %[[F:.+]] = tensor.dim %[[FILTER]], %[[C3]]
+// CHECK-DAG: %[[R:.+]] = tensor.dim %[[INIT]], %[[C1]]
+// CHECK-DAG: %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]]
+// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[P]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[INIT0:.+]] = %[[INIT]])
+// CHECK: %[[TS_P:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[P]]]
+// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[Q]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[INIT0]])
+// CHECK: %[[TS_Q:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[Q]]]
+// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C30]]
+// CHECK-SAME: iter_args(%[[INIT2:.+]] = %[[INIT1]])
+// CHECK-DAG: %[[TS_C:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[C30]], %[[C]]]
+// CHECK-DAG: %[[TS_H:.+]] = affine.apply #[[MAP3]](%[[TS_P]])[%[[R]]]
+// CHECK-DAG: %[[TS_W:.+]] = affine.apply #[[MAP4]](%[[TS_Q]])[%[[S]]]
+// CHECK-DAG: %[[INPUT_TILE:.+]] = tensor.extract_slice %[[INPUT]]
+// CHECK-SAME: [0, %[[IV0]], %[[IV1]], %[[IV2]]] [%[[N]], %[[TS_H]], %[[TS_W]], %[[TS_C]]]
+// CHECK-DAG: %[[FILTER_TILE:.+]] = tensor.extract_slice %[[FILTER]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]], %[[IV2]], 0] [%[[TS_P]], %[[TS_Q]], %[[TS_C]], %[[F]]]
+// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT2]]
+// CHECK-SAME: [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]]
+// CHECK: %[[CONV_TILE:.+]] = linalg.conv_2d_nhwc_hwcf
+// CHECK-SAME: dilation = dense<[4, 5]> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>
+// CHECK-SAME: ins(%[[INPUT_TILE]], %[[FILTER_TILE]] :
+// CHECK-SAME: outs(%[[INIT_TILE]] :
+// CHECK: tensor.insert_slice %[[CONV_TILE]] into %[[INIT2]]
+// CHECK-SAME: [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]]
add_subdirectory(Analysis)
add_subdirectory(Conversion)
add_subdirectory(Dialect)
+add_subdirectory(Interfaces)
add_subdirectory(IR)
add_subdirectory(Pass)
add_subdirectory(Reducer)
--- /dev/null
+add_subdirectory(TilingInterface)
--- /dev/null
+add_mlir_library(MLIRTilingInterfaceTestPasses
+ TestTilingInterface.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ LINK_LIBS PUBLIC
+ MLIRAffine
+ MLIRArithmetic
+ MLIRLinalg
+ MLIRLinalgTransforms
+ MLIRMemRef
+ MLIRSCF
+ MLIRSCFTransforms
+ MLIRTensor
+ )
--- /dev/null
+//===- TestTilingInterface.cpp - Test tiling using `TilingInterface` -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing tiling operations using
+// `TilingInterface`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/TileUsingInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Construct a generic pattern applied to all TilingInterface ops that verify
+/// `filter`.
+struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
+ TestTileUsingSCFForOpWithFilter(MLIRContext *context,
+ scf::SCFTilingOptions options,
+ linalg::LinalgTransformationFilter filter =
+ linalg::LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {}
+
+ /// Construct a generic pattern applied to `opName`.
+ TestTileUsingSCFForOpWithFilter(StringRef opName, MLIRContext *context,
+ scf::SCFTilingOptions options,
+ linalg::LinalgTransformationFilter filter =
+ linalg::LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {}
+
+ LogicalResult matchAndRewrite(TilingInterface op,
+ PatternRewriter &rewriter) const override {
+ if (failed(filter.checkAndNotify(rewriter, op)))
+ return failure();
+
+ FailureOr<scf::SCFTilingResult> tilingResult =
+ returningMatchAndRewrite(op, rewriter);
+ if (failed(tilingResult)) {
+ return failure();
+ }
+ filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp);
+ return success();
+ }
+
+private:
+ linalg::LinalgTransformationFilter filter;
+};
+
+struct TestTilingInterfacePass
+ : public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass)
+
+ TestTilingInterfacePass() = default;
+ TestTilingInterfacePass(const TestTilingInterfacePass &pass)
+ : PassWrapper(pass) {}
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
+ tensor::TensorDialect>();
+ linalg::registerTilingInterfaceExternalModels(registry);
+ }
+ StringRef getArgument() const final { return "test-tiling-interface"; }
+ StringRef getDescription() const final {
+ return "Test tiling using TilingInterface";
+ }
+
+ void runOnOperation() override;
+};
+} // namespace
+
+static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) {
+ auto addPatternForTiling = [&](ArrayRef<int64_t> tileSizes,
+ StringRef filterName) {
+ scf::SCFTilingOptions tilingOptions;
+ tilingOptions.setTileSizes(tileSizes);
+ linalg::LinalgTransformationFilter filter(
+ StringAttr::get(context, filterName),
+ StringAttr::get(context, "tiled"));
+ patterns.add<TestTileUsingSCFForOpWithFilter>(context, tilingOptions,
+ filter);
+ };
+ // 1. Tiling M and N dims of `linalg.matmul` on tensors.
+ addPatternForTiling({10, 20}, "simple_gemm");
+ // 2. Tiling M, N and K of `linalg.matmul` on buffers.
+ addPatternForTiling({10, 20, 30}, "simple_gemm_memref");
+ // 3. Tiling 3D parallel generic op which implements a transpose
+ addPatternForTiling({10, 0, 20}, "parallel_generic_transpose");
+ // 4. Tiling 2D conv op.
+ addPatternForTiling({0, 0, 0, 0, 10, 20, 30}, "simple_conv");
+}
+
+void TestTilingInterfacePass::runOnOperation() {
+ MLIRContext *context = &getContext();
+
+ RewritePatternSet tilingPatterns(context);
+ addTestPatterns(context, tilingPatterns);
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(tilingPatterns))))
+ return signalPassFailure();
+}
+
+namespace mlir {
+namespace test {
+void registerTestTilingInterface() {
+ PassRegistration<TestTilingInterfacePass>();
+}
+} // namespace test
+} // namespace mlir
MLIRTestRewrite
MLIRTestTransformDialect
MLIRTestTransforms
+ MLIRTilingInterfaceTestPasses
MLIRVectorTestPasses
)
endif()
void registerTestSCFUtilsPass();
void registerTestSliceAnalysisPass();
void registerTestTensorTransforms();
+void registerTestTilingInterface();
void registerTestTransformDialectInterpreterPass();
void registerTestVectorLowerings();
} // namespace test
mlir::test::registerTestSCFUtilsPass();
mlir::test::registerTestSliceAnalysisPass();
mlir::test::registerTestTensorTransforms();
+ mlir::test::registerTestTilingInterface();
mlir::test::registerTestTransformDialectInterpreterPass();
mlir::test::registerTestVectorLowerings();
}
"include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
"include/mlir/Dialect/SCF/Passes.h",
"include/mlir/Dialect/SCF/Patterns.h",
+ "include/mlir/Dialect/SCF/TileUsingInterface.h",
"include/mlir/Dialect/SCF/Transforms.h",
],
includes = ["include"],
":SCFUtils",
":Support",
":TensorDialect",
+ ":TilingInterface",
":Transforms",
"//llvm:Support",
],
exclude = [
"include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
"include/mlir/Dialect/SCF/Patterns.h",
+ "include/mlir/Dialect/SCF/TileUsingInterface.h",
"include/mlir/Dialect/SCF/Transforms.h",
],
),
"//mlir/test:TestSPIRV",
"//mlir/test:TestShapeDialect",
"//mlir/test:TestTensor",
+ "//mlir/test:TestTilingInterface",
"//mlir/test:TestTosaDialect",
"//mlir/test:TestTransformDialect",
"//mlir/test:TestTransforms",
":TensorTilingInterfaceImpl",
":TensorTransforms",
":TensorUtils",
+ ":TilingInterface",
":TransformUtils",
":Transforms",
":VectorDialect",
],
)
+cc_library(
+ name = "TestTilingInterface",
+ srcs = glob(["lib/Interfaces/TilingInterface/*.cpp"]),
+ includes = ["lib/Dialect/Test"],
+ deps = [
+ "//llvm:Support",
+ "//mlir:Affine",
+ "//mlir:ArithmeticDialect",
+ "//mlir:FuncDialect",
+ "//mlir:IR",
+ "//mlir:LinalgDialect",
+ "//mlir:LinalgTransforms",
+ "//mlir:MemRefDialect",
+ "//mlir:Pass",
+ "//mlir:SCFDialect",
+ "//mlir:SCFTransforms",
+ "//mlir:TensorDialect",
+ "//mlir:TilingInterface",
+ "//mlir:Transforms",
+ ],
+)
+
cc_library(
name = "TestPass",
srcs = glob(["lib/Pass/*.cpp"]),