#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include <utility>
Optional<ArrayRef<OpFoldResult>> nominalTileSizes,
SmallVector<OpFoldResult> &tiledOffsets,
SmallVector<OpFoldResult> &tiledSizes) {
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+
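+  // All IR below is created at the start of the body of `foreachThreadOp`, so
+  // that the computed tile offsets and sizes can use its thread indices.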
ValueRange threadIds = foreachThreadOp.getThreadIndices();
SmallVector<OpFoldResult> nonZeroNumThreads =
llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
Optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
Location loc = op->getLoc();
OpBuilder::InsertionGuard g(b);
+
SmallVector<Range> loopRanges = op.getIterationDomain(b);
if (loopRanges.empty())
return op->emitOpError("expected non-empty loop ranges");
Operation *tiledOp = nullptr;
- // Create the ForeachThreadOp. We don't use the lambda body-builder
+ // 1. Create the ForeachThreadOp. We don't use the lambda body-builder
// version because we require the use of RewriterBase in the body, so we
// manually move the insertion point to the body below.
scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
loc, dest, ValueRange(materializedNonZeroNumThreads), mapping);
- // Fill out the ForeachThreadOp body.
- b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+ // 2. Fill out the ForeachThreadOp body.
SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
calculateTileOffsetsAndSizes(b, loc, foreachThreadOp, numThreads, loopRanges,
omitTileOffsetBoundsCheck, nominalTileSizes,
tiledOffsets, tiledSizes);
- // Clone the tileable op and update its destination operands to use the output
- // bbArgs of the ForeachThreadOp.
+ // 3. Clone the tileable op and update its destination operands to use the
+ // output bbArgs of the ForeachThreadOp.
ArrayRef<BlockArgument> destBbArgs =
foreachThreadOp.getOutputBlockArguments();
- Operation *clonedOp = b.clone(*op.getOperation());
- auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
- if (destinationStyleOp) {
- for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands()) {
- auto *it = llvm::find(dest, outOperand->get());
- assert(it != dest.end() && "dest operand not found in dest");
- unsigned destNum = std::distance(dest.begin(), it);
- outOperand->set(destBbArgs[destNum]);
+ {
+ // 3.a. RAII guard, inserting within foreachThreadOp, before terminator.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(foreachThreadOp.getTerminator());
+ Operation *clonedOp = b.clone(*op.getOperation());
+ auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
+ if (destinationStyleOp) {
+ for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands()) {
+ auto *it = llvm::find(dest, outOperand->get());
+ assert(it != dest.end() && "dest operand not found in dest");
+ unsigned destNum = std::distance(dest.begin(), it);
+ outOperand->set(destBbArgs[destNum]);
+ }
}
- }
- // Tile the cloned op and delete the clone.
- SmallVector<Operation *> tiledOps =
- cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
- tiledSizes);
- b.eraseOp(clonedOp);
- assert(tiledOps.size() == 1 && "expected a single produced tiled op");
- tiledOp = tiledOps.front();
+ // 4. Tile the cloned op and delete the clone.
+ SmallVector<Operation *> tiledOps =
+ cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
+ tiledSizes);
+ b.eraseOp(clonedOp);
+ assert(tiledOps.size() == 1 && "expected a single produced tiled op");
+ tiledOp = tiledOps.front();
+ }
+ // 5. Parallel insert back into the result tensor.
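+  // `tiledOp` lives in the body of `foreachThreadOp`; its results feed the
+  // tensor.parallel_insert_slice ops created below.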
auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");
- OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
tilingInterfaceOp->getResults(), destBbArgs)) {
- b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
+ // 5.a. Partial subset information is inserted just before the terminator.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(foreachThreadOp.getTerminator());
+
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
tiledSizes, resultOffsets,
resultSizes)))
return op->emitOpError("output offsets couldn't be calculated");
SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
+
+ // 5.b. Parallel insertions are inserted at the end of the combining
+ // terminator.
b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody());
b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
std::get<2>(it), resultOffsets,
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
const LinalgTilingOptions &options) {
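+  // RAII guard: restore the caller's insertion point when this function
+  // returns.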
+ OpBuilder::InsertionGuard g(b);
+
auto nLoops = op.getNumLoops();
// Initial tile sizes may be too big, only take the first nLoops.
tileSizes = tileSizes.take_front(nLoops);
Optional<ArrayAttr> mapping) {
Location loc = op.getLoc();
OpBuilder::InsertionGuard g(b);
+
// Ops implementing PartialReductionOpInterface are expected to implement
// TilingInterface.
+ // TODO: proper core mechanism to tie interfaces together.
auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
+
+  // Ops implementing PartialReductionOpInterface are not necessarily expected
+  // to implement TilingInterface. This cast is currently unsafe.
+  // TODO: proper core mechanism to tie interfaces together.
+  // TODO: this function requires a pair of interfaces.
+ auto destinationStyleOp =
+ dyn_cast<DestinationStyleOpInterface>(op.getOperation());
+ if (!destinationStyleOp)
+ return b.notifyMatchFailure(op, "not a destination style op");
+
+  // Actually, this currently only works for Linalg ops.
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
+ if (!linalgOp)
+ return b.notifyMatchFailure(op, "not a linalg op");
+
SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
if (op->getNumResults() != 1)
return b.notifyMatchFailure(
op, "don't support ops with multiple results for now");
+
SmallVector<utils::IteratorType> iterators =
tilingInterfaceOp.getLoopIteratorTypes();
SmallVector<unsigned> redDims;
- cast<linalg::LinalgOp>(op.getOperation()).getReductionDims(redDims);
+ linalgOp.getReductionDims(redDims);
if (redDims.size() != 1)
return b.notifyMatchFailure(
op, "only support ops with one reduction dimension.");
return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
"many elements as number of threads");
int reductionDim = static_cast<int>(redDims.front());
- // 1. create the inital tensor value.
+
+  // 1. Create the initial tensor value.
FailureOr<Operation *> identityTensor =
op.generateInitialTensorForPartialReduction(b, loc, numThreads,
reductionDim);
loc, identityTensor.value()->getResults(),
ValueRange(materializedNonZeroNumThreads), mapping);
- // 3. calculate the tile offsets and sizes.
- b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+ // 3. Calculate the tile offsets and sizes for the subsequent loop that will
+ // be nested under `foreachThreadOp`.
SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
calculateTileOffsetsAndSizes(
b, loc, foreachThreadOp, numThreads, iterationDomain,
// 4. Clone the tileable op and update its destination operands to use the
// output bbArgs of the ForeachThreadOp.
+ ValueRange tilingResults;
ArrayRef<BlockArgument> destBbArgs =
foreachThreadOp.getOutputBlockArguments();
- Operation *clonedOp = b.clone(*op.getOperation());
- b.setInsertionPointToStart(foreachThreadOp.getBody(0));
- auto destinationStyleOp = cast<DestinationStyleOpInterface>(clonedOp);
- for (OpOperand *initOperand : destinationStyleOp.getDpsInitOperands()) {
- auto *it = llvm::find(dest, initOperand->get());
- assert(it != dest.end() && "dest operand not found in dest");
- unsigned destNum = std::distance(dest.begin(), it);
- SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
- SmallVector<OpFoldResult> outOffsets(numThreads.size(), b.getIndexAttr(0));
- SmallVector<OpFoldResult> sizes = tiledSizes;
- sizes[reductionDim] = b.getIndexAttr(1);
- outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
- // TODO: use SubsetExtractOpInterface once it is available.
- Value patial = b.create<tensor::ExtractSliceOp>(
- loc, initOperand->get().getType().cast<RankedTensorType>(),
- destBbArgs[destNum], outOffsets, sizes, strides);
- initOperand->set(patial);
- }
- b.setInsertionPoint(clonedOp);
+ {
+ // 4.a. RAII guard, inserting within foreachThreadOp, before terminator.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(foreachThreadOp.getTerminator());
+
+ SmallVector<Value> tiledDpsInitOperands;
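+    // For each init operand, extract the per-thread slice of the matching
+    // output bbArg: size 1 along the reduction dimension, at an offset equal
+    // to the thread id.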
+ for (OpOperand *initOperand : destinationStyleOp.getDpsInitOperands()) {
+ auto *it = llvm::find(dest, initOperand->get());
+ assert(it != dest.end() && "dest operand not found in dest");
+ unsigned destNum = std::distance(dest.begin(), it);
+ SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
+ SmallVector<OpFoldResult> outOffsets(numThreads.size(),
+ b.getIndexAttr(0));
+ SmallVector<OpFoldResult> sizes = tiledSizes;
+ sizes[reductionDim] = b.getIndexAttr(1);
+ outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
+ // TODO: use SubsetExtractOpInterface once it is available.
+ tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
+ loc, initOperand->get().getType().cast<RankedTensorType>(),
+ destBbArgs[destNum], outOffsets, sizes, strides));
+ }
- // 5. Tile the cloned op and delete the clone.
- if (tileSizes.empty()) {
- SmallVector<Operation *> tiledOps =
- cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
- tiledSizes);
- assert(tiledOps.size() == 1 && "expected a single produced tiled op");
- tiledOp = tiledOps.front();
- } else {
- LinalgTilingOptions options;
- auto tiled = tileLinalgOpImpl<scf::ForOp>(b, cast<LinalgOp>(clonedOp),
- tileSizes, options);
- SmallVector<Value> ids = foreachThreadOp.getThreadIndices();
- mapLoopToProcessorIds(cast<scf::ForOp>(tiled->loops.back()), ids,
- materializedNonZeroNumThreads);
- assert(tiled->loops.size() == 1 && "expected a single produced loop");
- tiledOp = tiled->loops.front();
+ // 4.b. Clone the op and update init operands.
+ // We cannot use a BlockAndValueMapping here because it can replace
+ // different OpOperands with the same value.
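+    // E.g., two init operands that originally hold the same tensor value must
+    // be remapped to two distinct tiled slices.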
+ Operation *clonedOp = b.clone(*op.getOperation());
+ b.updateRootInPlace(clonedOp, [&]() {
+ for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal(
+ cast<DestinationStyleOpInterface>(clonedOp).getDpsInitOperands(),
+ tiledDpsInitOperands)) {
+ initOperandPtr->set(tiledInitValue);
+ }
+ });
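+    // updateRootInPlace notifies the rewriter that `clonedOp` is modified
+    // in place.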
+
+ // 5. Tile the cloned op and delete the clone.
+ if (tileSizes.empty()) {
+ SmallVector<Operation *> tiledOps =
+ cast<TilingInterface>(clonedOp).getTiledImplementation(
+ b, tiledOffsets, tiledSizes);
+ assert(tiledOps.size() == 1 && "expected a single produced tiled op");
+ tiledOp = tiledOps.front();
+ tilingResults = tiledOp->getResults();
+ } else {
+ LinalgTilingOptions options;
+ FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
+ b, cast<LinalgOp>(clonedOp), tileSizes, options);
+ if (failed(maybeTiled))
+ return b.notifyMatchFailure(op, "failed tileLinalgOpImpl");
+
+ SmallVector<Value> ids = foreachThreadOp.getThreadIndices();
+ mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids,
+ materializedNonZeroNumThreads);
+ assert(maybeTiled->loops.size() == 1 &&
+ "expected a single produced loop");
+ tiledOp = maybeTiled->op;
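+      // With explicit tile sizes, the values to insert back into the bbArgs
+      // are produced by the generated loop, not by the tiled op itself.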
+ tilingResults = maybeTiled->loops.front()->getResults();
+ }
+
+ b.eraseOp(clonedOp);
}
- b.eraseOp(clonedOp);
// 6. Insert the partial reductions back into a new tensor.
- b.setInsertionPointAfter(tiledOp);
- OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
- for (auto [index, result, bbArg] :
- llvm::zip(llvm::seq<unsigned>(0, dest.size()), tiledOp->getResults(),
- destBbArgs)) {
- b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
+ for (auto [index, result, bbArg] : llvm::zip(
+ llvm::seq<unsigned>(0, dest.size()), tilingResults, destBbArgs)) {
+ // 6.a. Partial subset information is inserted just before the terminator.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(foreachThreadOp.getTerminator());
+
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(tilingInterfaceOp.getResultTilePosition(
b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes)))
resultOffsetsRank.push_back(resultOffsets[offIdx++]);
resultSizesRank.push_back(resultSizes[sizeIdx++]);
}
-
SmallVector<OpFoldResult> strides(resultSizesRank.size(),
b.getIndexAttr(1));
+
+ // 6.b. Parallel insertions are inserted at the end of the combining
+ // terminator.
b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody());
b.create<tensor::ParallelInsertSliceOp>(
loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
}
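+
+  // At this point, the body of `foreachThreadOp` is expected to look roughly
+  // like the sketch below (assuming a single output and a single reduction
+  // dimension):
+  //   scf.foreach_thread (%tid) in (%num_threads) shared_outs(%o = %init) {
+  //     %slice = tensor.extract_slice %o[.., %tid, ..] [.., 1, ..] [..]
+  //     %partial = <tiled op>(.., %slice)
+  //     scf.foreach_thread.perform_concurrently {
+  //       tensor.parallel_insert_slice %partial into %o[..] [..] [..]
+  //     }
+  //   }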
+
// 7. Merge the partial reductions.
b.setInsertionPointAfter(foreachThreadOp);
Operation *mergeOp =
op.mergeReductions(b, loc, foreachThreadOp->getResults(), reductionDim);
b.replaceOp(op, mergeOp->getResults());
+
+ // 8. Return.
ForeachThreadReductionTilingResult results;
results.initialOp = identityTensor.value();
results.loops = foreachThreadOp;