Value shape = en.value();
SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
for (auto en2 : llvm::enumerate(map.getResults())) {
- auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
- if (!dimExpr)
- continue;
if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
<< loopDepth << "\n");
llvm_unreachable("Expect to be able to extract a shape defining loop range");
}
-/// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
-/// provides the loop range information for the fused loops. The rest are
-/// obtained from the producer itself, since they are not tiled + fused.
-static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
- const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
+/// Fuses the producer of `producerIdx` into the loop immediately enclosing
+/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
+/// is needed just before the `consumer`.
+///
+/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
+/// 2 cases:
+/// 1. Buffer case: `producerIdx` is the index of the buffer in
+/// `producer.getOutputBuffers()`.
+/// 2. Tensor case: `producerIdx` is the index of the tensor in
+/// `producer.getResults()`.
+static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
+ LinalgOp consumer, unsigned consumerIdx) {
+ Operation *shapeProducingOp =
+ consumer.getShapedOperand(consumerIdx).getDefiningOp();
+ assert((isa<SubViewOp>(shapeProducingOp) ||
+ isa<SubTensorOp>(shapeProducingOp)) &&
+ "SubviewOp or SubTensorOp expected");
+
+ // loopToOperandRangesMaps are permutations-only by construction:
+ // we can always identify a data dimension with a (at least one) loop
+ // dimension.
+ // TODO: extend this with range inference.
+ AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
+ LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
+ << ", producer map: " << producerMap << "\n");
unsigned nPar = producer.getNumParallelLoops();
unsigned nRed = producer.getNumReductionLoops();
unsigned nWin = producer.getNumWindowLoops();
SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
- for (auto fusedLoops : fusedLoopsAndRanges)
- loopRanges[fusedLoops.first] = fusedLoops.second;
+
+ // Iterate over dimensions identified by the producer map for `producerIdx`.
+ // This defines a subset of the loop ranges that we need to complete later.
+ auto loc = consumer.getLoc();
+ for (auto en : llvm::enumerate(producerMap.getResults())) {
+ unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
+ loopRanges[posInProducerLoop] =
+ isa<SubViewOp>(shapeProducingOp)
+ ? cast<SubViewOp>(shapeProducingOp)
+ .getOrCreateRanges(b, loc)[en.index()]
+ : cast<SubTensorOp>(shapeProducingOp)
+ .getOrCreateRanges(b, loc)[en.index()];
+ }
// Iterate over all dimensions. For the dimensions not identified by the
// producer map for `producerIdx`, we need to explicitly compute the shape
}
}
- return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges);
-}
-
-/// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
-/// expected to be defined by a subview op or a subtensor op.
-static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
- Value shapedOperand, unsigned dim) {
- Operation *shapeProducingOp = shapedOperand.getDefiningOp();
- if (auto subViewOp = dyn_cast<SubViewOp>(shapeProducingOp))
- return subViewOp.getOrCreateRanges(b, loc)[dim];
- if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
- return subTensorOp.getOrCreateRanges(b, loc)[dim];
- llvm_unreachable("SubviewOp or SubTensorOp expected");
-}
-
-/// Fuses the producer of `producerIdx` into the loop immediately enclosing
-/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
-/// is needed just before the `consumer.
-///
-/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
-/// 2 cases:
-/// 1. Buffer case: `producerIdx` is the index of the buffer in
-/// `producer.getOutputBuffers()`.
-/// 2. Tensor case: `producerIdx` is the index of the tensor in
-/// `producer.getResults()`.
-static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
- LinalgOp consumer, unsigned consumerIdx) {
- AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
- LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
- << ", producer map: " << producerMap << "\n");
- DenseMap<unsigned, Range> fusedLoopsAndRanges;
- Location loc = consumer.getLoc();
- Value shapedOperand = consumer.getShapedOperand(consumerIdx);
- for (auto en : llvm::enumerate(producerMap.getResults())) {
- unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
- fusedLoopsAndRanges[posInProducerLoop] =
- getRangeFromOperandShape(b, loc, shapedOperand, en.index());
- }
- return fuse(b, producer, fusedLoopsAndRanges);
+ return cloneWithLoopRanges(b, loc, producer, loopRanges);
}
// Encode structural fusion safety preconditions.
return getProjectedMap(map, projectedDims);
}
-/// Returns the mapping from iterations in the consumer that write to the same
-/// location as the iterations in the producer. To do so use
-/// - indexing map of the fused view in the consumer : consumerIndexMap
-/// - indexing map of the fused view in the producer : producerIndexMap
-/// consumerLoopToProducerLoop =
-/// inverse(producerIndexMap).compose(consumerIndexMap)
-static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
- LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
- auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
- AffineMap producerIndexingMap =
- producer.getIndexingMap(dependence.dependentOpView.operandIndex);
- auto consumer = cast<LinalgOp>(dependence.indexingOpView.op);
- AffineMap consumerIndexingMap =
- consumer.getIndexingMap(dependence.indexingOpView.operandIndex);
-
- AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
- producer.iterator_types().getValue(), producerIndexingMap);
- if (!prunedProducerIndexingMap.isPermutation())
- return None;
-
- if (consumerIndexingMap.getNumResults() !=
- prunedProducerIndexingMap.getNumResults())
- return None;
-
- LLVM_DEBUG({
- llvm::dbgs() << "\t producerMap : ";
- producerIndexingMap.print(llvm::dbgs());
- llvm::dbgs() << " pruned : ";
- prunedProducerIndexingMap.print(llvm::dbgs());
- llvm::dbgs() << "\n";
- llvm::dbgs() << "\t consumerMap : ";
- consumerIndexingMap.print(llvm::dbgs());
- llvm::dbgs() << "\n";
- });
-
- AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
- if (!invProducerIndexMap)
- return None;
-
- return invProducerIndexMap.compose(consumerIndexingMap);
-}
-
-/// Given a projected permutation `map`, returns true if the map changes the
-/// order in which the fused loop dimension appear.
-static bool doesTransposeAccess(AffineMap map,
- const std::set<unsigned> &fusableLoops) {
- Optional<unsigned> lastFusableLoop;
- for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
- return expr.cast<AffineDimExpr>().getPosition();
- })) {
- if (!fusableLoops.count(pos))
- continue;
- if (!lastFusableLoop) {
- lastFusableLoop = pos;
- continue;
- }
- if (pos <= lastFusableLoop.getValue())
- return true;
- lastFusableLoop = pos;
- }
- return false;
-}
+using FusableOpDependencesTy = llvm::MapVector<
+ Operation *,
+ SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
/// Returns the positions of the loop in `op` that can be tiled based on the
/// operations that are to be fused with it. For example, in a
/// 2. Of the parallel loops only some can be fused. Only those loops can be
/// fused such where the fusable loops iteration space only touches one tile
/// of the fused operation. This is because the producer (which is writing
-/// the fused subview) has update semantics.
+/// the fused subview) has update semantics. To compute this,
+/// a. Find the mapping from iterations in the consumer that write to the
+/// same location as the iterations in the producer. To do so use
+/// - indexing map of the fused view in the consumer : consumerIndexMap
+/// - indexing map of the fused view in the producer : producerIndexMap
+/// consumerLoopToProducerLoop =
+/// inverse(producerIndexMap).compose(consumerIndexMap)
///
/// Since an inverse computation is needed, we need to consider the projection
/// of the producerIndexMap w.r.t the parallel loops. The actual fusable loops
/// submap with only parallel loops = affine_map<(i, j) -> (j)>
/// Fused dimensions : j
static std::set<unsigned>
-collectFusableLoops(ArrayRef<LinalgOp> ops,
- const FusableOpDependencesTy &fusableDependences) {
- assert(!ops.empty());
+collectTileAndFuseLoops(LinalgOp op,
+ const FusableOpDependencesTy &fusableDependences) {
auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
return linalgOp.iterator_types()
.getValue()
.size();
};
- size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
- for (auto op : ops.drop_back()) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "Op : ";
+ op.getOperation()->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n";
+ });
+
+ size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
+ for (auto dependence : fusableDependences) {
+ linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
numOuterParallelLoops =
- std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
+ std::min(numOuterParallelLoops, getNumOuterParallelLoops(producer));
}
std::set<unsigned> fusableLoops;
auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
fusableLoops.insert(range.begin(), range.end());
-
- for (auto op : reverse(ops)) {
- for (auto dependence : fusableDependences.lookup(op)) {
- LLVM_DEBUG({
- llvm::dbgs() << "\t fusable :";
- for (unsigned i : fusableLoops)
- llvm::dbgs() << " " << i;
- llvm::dbgs() << "\n";
- });
-
- Optional<AffineMap> consumerLoopToProducerLoop =
- getConsumerLoopToProducerLoopMap(dependence);
- if (!consumerLoopToProducerLoop) {
- op.emitRemark("failed to get map from consumer loop to producer loop");
- return {};
- }
- // todo: This condition is only an implementation limitation. When fusing
- // the operation, if the accesses in the producer/consumer are transposes
- // of each other, the loop bounds for the tiled producer can be
- // manipulated accordingly. This requires some additional bookkeeping in
- // the implementation of tile+fuse that is defered to later.
- if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
- op.emitRemark("unhandled fusion when fusion requires permutation");
- return {};
- }
-
- std::set<unsigned> candidates;
- for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
- unsigned position = expr.cast<AffineDimExpr>().getPosition();
- if (fusableLoops.count(position))
- candidates.insert(position);
- }
- LLVM_DEBUG({
- llvm::dbgs() << "\t candidates :";
- for (unsigned i : candidates)
- llvm::dbgs() << " " << i;
- llvm::dbgs() << "\n";
- });
- if (candidates.empty())
- return {};
- std::swap(candidates, fusableLoops);
+ for (auto dependence : fusableDependences) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "\t fusable :";
+ for (unsigned i : fusableLoops)
+ llvm::dbgs() << " " << i;
+ llvm::dbgs() << "\n";
+ });
+ linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
+
+ assert(!dependence.second.empty() &&
+ "unexpected producer but not dependences");
+ AffineMap producerIndexingMap = producer.getIndexingMap(
+ dependence.second.front().dependentOpView.operandIndex);
+ AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
+ producer.iterator_types().getValue(), producerIndexingMap);
+ if (!prunedProducerIndexingMap.isPermutation())
+ return {};
+
+ AffineMap consumerIndexingMap = op.getIndexingMap(
+ dependence.second.front().indexingOpView.operandIndex);
+ if (consumerIndexingMap.getNumResults() !=
+ prunedProducerIndexingMap.getNumResults())
+ return {};
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "\t producerMap : ";
+ producerIndexingMap.print(llvm::dbgs());
+ llvm::dbgs() << " pruned : ";
+ prunedProducerIndexingMap.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ llvm::dbgs() << "\t consumerMap : ";
+ consumerIndexingMap.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+
+ AffineMap invProducerIndexMap =
+ inversePermutation(prunedProducerIndexingMap);
+ if (!invProducerIndexMap)
+ return {};
+
+ AffineMap consumerLoopToProducerLoop =
+ invProducerIndexMap.compose(consumerIndexingMap);
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "\t consumerLoopToProducerLoop : ";
+ consumerLoopToProducerLoop.print(llvm::dbgs());
+ });
+
+ std::set<unsigned> candidates;
+ for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) {
+ AffineDimExpr dimExpr = expr.dyn_cast<AffineDimExpr>();
+ if (!dimExpr)
+ continue;
+ unsigned position = dimExpr.getPosition();
+ if (fusableLoops.count(position))
+ candidates.insert(position);
}
+ LLVM_DEBUG({
+ llvm::dbgs() << "\t candidates :";
+ for (unsigned i : candidates)
+ llvm::dbgs() << " " << i;
+ llvm::dbgs() << "\n";
+ });
+ if (candidates.empty())
+ return {};
+ std::swap(candidates, fusableLoops);
}
return fusableLoops;
}
-/// Find all dependences that are fusable.
-FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
- ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
+/// Find all dependences that are fusable.
+static FusableOpDependencesTy
+findAllFusableDependences(LinalgOp op,
+ const LinalgDependenceGraph &dependenceGraph,
+ const LinalgFusionOptions &fusionOptions) {
FusableOpDependencesTy fusableDependences;
// TODO: Currently fusion would not be legal if the fusable dependence is to
// the same producer but different indexing map in the consumer. Fix this, but
// in the meanwhile disallow such a fusion.
DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
- for (LinalgOp op : reverse(ops)) {
- for (auto operandIndex :
- llvm::seq<unsigned>(0, op.getNumInputsAndOutputBuffers())) {
- Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
- fusableDependence =
- findFusableProducer(op, operandIndex, dependenceGraph);
- if (!fusableDependence)
- continue;
- LinalgOp producerOp =
- cast<LinalgOp>(fusableDependence->dependentOpView.op);
- // Do not fuse dependences that are to operations not in the same basic
- // block. This avoid moving fused operations across loops that might
- // themselves carry dependency making the fusion illegal.
- if (producerOp.getOperation()->getBlock() !=
- op.getOperation()->getBlock()) {
- op.emitRemark("unhandled fusion of ops in different basic blocks");
- return FusableOpDependencesTy{};
- }
- // Make sure that the indexing map of the view used for fusion in the
- // producer is a projected permutation.
- unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
- AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
- if (!producerMap.isProjectedPermutation()) {
- op.emitRemark(
- "unhandled non permutation indexing map for fused view in "
- "producer for operand at index ")
- << operandIndex;
- return FusableOpDependencesTy{};
- }
-
- unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
- AffineMap consumerMap = op.getIndexingMap(consumerIdx);
- if (!consumerMap.isProjectedPermutation()) {
- op.emitRemark(
- "unhandled case where indexing map for fused view in the consumer "
- "is "
- "not a projected permuration while fusing at index ")
- << operandIndex;
- return FusableOpDependencesTy{};
- }
+ for (auto operandIndex : fusionOptions.indicesToFuse) {
+ auto fusableDependence =
+ findFusableProducer(op, operandIndex, dependenceGraph);
+ if (!fusableDependence)
+ return FusableOpDependencesTy{};
+ LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
+ // Do not fuse dependences that are to operations not in the same basic
+ // block. This avoid moving fused operations across loops that might
+ // themselves carry dependency making the fusion illegal.
+ if (producerOp.getOperation()->getBlock() !=
+ op.getOperation()->getBlock()) {
+ op.emitRemark("unhandled fusion of ops in different basic blocks");
+ return FusableOpDependencesTy{};
+ }
+ // Make sure that the indexing map of the view used for fusion in the
+ // producer is a projected permutation.
+ unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
+ AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
+ if (!producerMap.isProjectedPermutation()) {
+ op.emitRemark("unhandled non permutation indexing map for fused view in "
+ "producer for operand at index ")
+ << operandIndex;
+ return FusableOpDependencesTy{};
+ }
- // Check if the producer is already a fusion candidate. Cannot fuse this
- // dependence if it has a different indexing map when used in the
- // consumer.
- if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
- fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
- op.emitRemark(
- "unhandled fusion to the same producer but with different "
- "indexing maps");
- return FusableOpDependencesTy{};
- }
- fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
+ unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
+ AffineMap consumerMap = op.getIndexingMap(consumerIdx);
+ if (!consumerMap.isProjectedPermutation()) {
+ op.emitRemark(
+ "unhandled case where indexing map for fused view in the consumer is "
+ "not a projected permutation while fusing at index ")
+ << operandIndex;
+ return FusableOpDependencesTy{};
+ }
- fusableDependences[producerOp.getOperation()].push_back(
- *fusableDependence);
+ // Check if the producer is already a fusion candidate. Cannot fuse this
+ // dependence if it has a different indexing map when used in the consumer.
+ if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
+ fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
+ op.emitRemark("unhandled fusion to the same producer but with different "
+ "indexing maps");
+ return FusableOpDependencesTy{};
}
+ fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
+
+ fusableDependences[producerOp.getOperation()].push_back(*fusableDependence);
}
return fusableDependences;
}
-/// Tile the fused loops in the root operation, by setting the tile sizes for
-/// all other loops to zero (those will be tiled later).
-static Optional<TiledLinalgOp> tileRootOperation(
- OpBuilder &builder, LinalgOp op, ArrayRef<Value> tileSizeVector,
- const LinalgTilingOptions &options, const std::set<unsigned> &fusedLoops) {
- SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
- auto zero = std_constant_index(0);
- for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
- if (!fusedLoops.count(i))
- tileSizes[i] = zero;
- LinalgTilingOptions tileFusedLoopsOptions = options;
- tileFusedLoopsOptions.setTileSizes(tileSizes);
- return tileLinalgOp(builder, op, tileFusedLoopsOptions);
-}
-
-/// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected
-/// to be a tiled operation such that it is valid to fuse all operations in
-/// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
-/// `tiledOp`.
-static SmallVector<LinalgOp, 1>
-fuseOperations(OpBuilder &builder, LinalgOp tiledOp,
- ArrayRef<LinalgOp> fusionCandidates,
- const FusableOpDependencesTy &fusableDependences,
- const std::set<unsigned> &fusedLoops) {
- OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPoint(tiledOp);
- DenseMap<unsigned, Range> fusedLoopsAndRanges;
- for (unsigned loop : fusedLoops) {
- ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop);
- fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
- builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
- }
-
- SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
- for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
- LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges);
- fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
- builder.setInsertionPoint(fusedOp);
- }
- return fusedOps;
+static bool isZero(Value v) {
+ if (auto cst = v.getDefiningOp<ConstantIndexOp>())
+ return cst.getValue() == 0;
+ return false;
}
template <typename LoopType>
static Optional<TiledAndFusedLinalgOps>
-tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
+tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
const LinalgDependenceGraph &dependenceGraph,
- const LinalgTilingOptions &tilingOptions) {
- if (ops.empty())
- return llvm::None;
- LinalgOp rootOp = ops.back();
- for (auto op : enumerate(ops)) {
- // TODO: Nothing in the fusion of sequence of ops is specific to
- // buffers. This check can be removed after it is tested on tensors.
- LinalgOp linalgOp = op.value();
- if (!linalgOp.hasBufferSemantics()) {
- linalgOp.emitError("tile and fuse only tested for buffer operation");
- return llvm::None;
- }
- }
- // TODO: Support interchange with tile + fuse. This might actually help do
- // better fusion.
+ const LinalgTilingOptions &tilingOptions,
+ const LinalgFusionOptions &fusionOptions) {
+ assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
+ // Some of the tiling options might not be supportable with tile and fuse.
+ // TODO: Support interchange with tile + fuse.
if (!tilingOptions.interchangeVector.empty()) {
- rootOp.emitError("unable to handle tile and fuse with interchange");
+ op.emitError("unable to handle tile and fuse with interchange");
return llvm::None;
}
- OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPoint(rootOp);
- ScopedContext scope(builder, rootOp.getLoc());
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
+ ScopedContext scope(rewriter, op.getLoc());
// Find all the producers.
FusableOpDependencesTy fusableDependences =
- findAllFusableDependences(ops, dependenceGraph);
+ findAllFusableDependences(op, dependenceGraph, fusionOptions);
if (fusableDependences.empty())
return llvm::None;
+ // Enforce the convention that "tiling by zero" skips tiling a particular
+ // dimension. This convention is significantly simpler to handle instead of
+ // adjusting affine maps to account for missing dimensions.
+ auto nLoops = op.getNumLoops();
+ SmallVector<Value, 4> tileSizeVector =
+ tilingOptions.tileSizeComputationFunction(rewriter, op);
+ if (tileSizeVector.size() < nLoops) {
+ auto zero = std_constant_index(0);
+ tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
+ }
+
TiledAndFusedLinalgOps ret;
+
// Find the loops that can be tiled and fused.
- ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
+ std::set<unsigned> tileFuseLoops =
+ collectTileAndFuseLoops(op, fusableDependences);
// If there are no fusable dependences or there are no tile+fusable loops,
// just return.
- if (ret.fusedLoopDims.empty()) {
+ if (tileFuseLoops.empty()) {
return llvm::None;
}
- // Tile the fused loops in the last operation in the list.
- SmallVector<Value, 4> tileSizeVector =
- tilingOptions.tileSizeComputationFunction(builder, rootOp);
- Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
- builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
- if (!tiledRootOp) {
- rootOp.emitError("failed to tile the fused loops");
+ // Get the tile sizes for the first and second tiling steps. For the first
+ // step the tile sizes are set to zero for the loops that aren't
+ // fused. Similarly for the second step, the tile sizes are set to zero for
+ // the loops that are fused. For example, if for the following input
+ //
+ // ```
+ // linalg.add ins(%a, %b) outs(%c)
+ // linalg.matmul ins(%d, %c) outs(%e)
+ // ```
+ //
+ // if the tile sizes of the `{i, j, k}` loops where given as `{ti, tj, tk}`
+ // respectively, and since only `j` can be tiled and fused. The tile sizes
+ // would be `{0, t_j, 0}` for the first tiling that tiles just the fusable
+ // loops. The second tiling would use tile sizes of `{t_i, 0, t_k}` to tile
+ // the tiled matmul generated by the first tiling step.
+ SmallVector<Value, 4> tileAndFuseSizes, tileSizes;
+ for (auto tileSize : enumerate(tileSizeVector)) {
+ auto zero = std_constant_index(0);
+ if (tileFuseLoops.count(tileSize.index())) {
+ tileAndFuseSizes.push_back(tileSize.value());
+ tileSizes.push_back(zero);
+ } else {
+ tileSizes.push_back(tileSize.value());
+ tileAndFuseSizes.push_back(zero);
+ }
+ }
+
+ // Tile for the loops that can be fused.
+ LinalgTilingOptions firstTilingOptions = tilingOptions;
+ firstTilingOptions.setTileSizes(tileAndFuseSizes);
+ Optional<TiledLinalgOp> firstTiledOp =
+ tileLinalgOp(rewriter, op, firstTilingOptions);
+ if (!firstTiledOp)
return llvm::None;
+ ret.op = firstTiledOp->op;
+ ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end());
+
+ rewriter.setInsertionPoint(ret.op);
+ // Fuse the operands.
+ for (auto dependence : fusableDependences) {
+ LinalgOp producerOp = cast<LinalgOp>(dependence.first);
+ unsigned producerIdx =
+ dependence.second.front().dependentOpView.operandIndex;
+ unsigned consumerIdx =
+ dependence.second.front().indexingOpView.operandIndex;
+ LinalgOp fusedOp = fuse(rewriter, producerOp,
+ producerOp.getOutputIndex(producerIdx).getValue(),
+ ret.op, consumerIdx);
+ ret.fusedProducers.push_back(fusedOp);
+ ret.originalProducers.push_back(producerOp);
+ }
+
+ if (!llvm::all_of(tileSizes, isZero)) {
+ // Tile the remaining loops of the root operation.
+ LinalgTilingOptions secondTilingOptions = tilingOptions;
+ // The distribution is done only for the tile+fused loops.
+ secondTilingOptions.distribution = llvm::None;
+ secondTilingOptions.setTileSizes(tileSizes);
+ Optional<TiledLinalgOp> secondTiledOp =
+ tileLinalgOp(rewriter, ret.op, secondTilingOptions);
+ if (!secondTiledOp)
+ return llvm::None;
+ ret.unfusedLoops.assign(secondTiledOp->loops.begin(),
+ secondTiledOp->loops.end());
+ rewriter.eraseOp(ret.op);
+ ret.op = secondTiledOp->op;
}
- ret.op = tiledRootOp->op;
- ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
- // Fuse the other operations into the fused inter-tile loops produced above.
- ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(),
- fusableDependences, ret.fusedLoopDims);
return ret;
}
Optional<TiledAndFusedLinalgOps>
-mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
+mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
const LinalgDependenceGraph &dependenceGraph,
- const LinalgTilingOptions &tilingOptions) {
+ const LinalgTilingOptions &tilingOptions,
+ const LinalgFusionOptions &fusionOptions) {
switch (tilingOptions.loopType) {
case LinalgTilingLoopType::Loops:
- return tileAndFuseLinalgOpsImpl<scf::ForOp>(builder, ops, dependenceGraph,
- tilingOptions);
+ return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph,
+ tilingOptions, fusionOptions);
case LinalgTilingLoopType::ParallelLoops:
return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
- builder, ops, dependenceGraph, tilingOptions);
+ rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
default:;
}
return llvm::None;
+++ /dev/null
-// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),canonicalize,cse" -split-input-file %s | FileCheck %s
-
-module {
- func @three_op_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
- %arg2: memref<?xf32>, %arg3 : memref<?x?xf32>) {
- %cst = constant 0.000000e+00 : f32
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %d0 = dim %arg0, %c0 : memref<?x?xf32>
- %d1 = dim %arg1, %c1 : memref<?x?xf32>
- %0 = alloc(%d0, %d1) : memref<?x?xf32>
- linalg.fill(%0, %cst) : memref<?x?xf32>, f32
- linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%0 : memref<?x?xf32>)
- linalg.generic
- {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d1)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%0, %arg2 : memref<?x?xf32>, memref<?xf32>)
- outs(%arg3 : memref<?x?xf32>) {
- ^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32) :
- %5 = addf %arg4, %arg5 : f32
- linalg.yield %5 : f32
- }
- return
- }
-}
-
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK: func @three_op_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK: %[[TEMP:.+]] = alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
-// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} {
-// CHECK-DAG: %[[SV_TEMP:.+]] = subview %[[TEMP]][%[[IV0]], %[[IV1]]]
-// CHECK-DAG: %[[SV_ARG2:.+]] = subview %[[ARG2]][%[[IV1]]]
-// CHECK-DAG: %[[SV_ARG3:.+]] = subview %[[ARG3]][%[[IV0]], %[[IV1]]]
-// CHECK-DAG: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-DAG: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[IV1]]]
-// CHECK: linalg.fill(%[[SV_TEMP]], %{{.+}})
-// CHECK: linalg.matmul
-// CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP2]]>, memref<?x?xf32, #[[MAP2]]>)
-// CHECK-SAME: outs(%[[SV_TEMP]] : memref<?x?xf32, #[[MAP2]]>)
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[SV_TEMP]], %[[SV_ARG2]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP2]]>, memref<?xf32, #[[MAP3]]>)
-// CHECK-SAME: outs(%[[SV_ARG3]] : memref<?x?xf32, #[[MAP2]]>)
-// CHECK: scf.yield
-// CHECK: }
-
-// -----
-
-module {
- func @sequence_of_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
- %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
- %arg4: memref<?x?xf32>) {
- %cst = constant 0.000000e+00 : f32
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %m = dim %arg0, %c0 : memref<?x?xf32>
- %n1 = dim %arg1, %c1 : memref<?x?xf32>
- %n2 = dim %arg2, %c1 : memref<?x?xf32>
- %n3 = dim %arg3, %c1 : memref<?x?xf32>
- %0 = alloc(%m, %n1) : memref<?x?xf32>
- %1 = alloc(%m, %n2) : memref<?x?xf32>
- linalg.fill(%0, %cst) : memref<?x?xf32>, f32
- linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%0 : memref<?x?xf32>)
- linalg.fill(%1, %cst) : memref<?x?xf32>, f32
- linalg.matmul ins(%0, %arg2 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%1 : memref<?x?xf32>)
- linalg.fill(%arg4, %cst) : memref<?x?xf32>, f32
- linalg.matmul ins(%1, %arg3 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%arg4 : memref<?x?xf32>)
- return
- }
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK: func @sequence_of_matmul
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C16:.+]] = constant 16 : index
-// CHECK-DAG: %[[M:.+]] = dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[N1:.+]] = dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[N2:.+]] = dim %[[ARG2]], %[[C1]]
-// CHECK: %[[ALLOC1:.+]] = alloc(%[[M]], %[[N1]])
-// CHECK: %[[ALLOC2:.+]] = alloc(%[[M]], %[[N2]])
-// CHECK: scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]])
-// CHECK-SAME: step (%[[C16]]) {
-// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
-// CHECK: %[[SV_ALLOC2:.+]] = subview %[[ALLOC2]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N2]]]
-// CHECK: %[[M_2:.+]] = dim %[[ARG4]], %[[C0]]
-// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
-// CHECK: %[[N3:.+]] = dim %[[ARG4]], %[[C1]]
-// CHECK: %[[SV_ARG4:.+]] = subview %[[ARG4]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]]
-// CHECK: %[[SV_ARG4_2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N3]]]
-// CHECK: %[[SV_ALLOC1:.+]] = subview %[[ALLOC1]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N1]]]
-// CHECK: %[[SV_ARG2:.+]] = subview %[[ARG2]][0, 0] [%[[N1]], %[[N2]]]
-// CHECK: %[[N0:.+]] = dim %[[ARG0]], %[[C1]]
-// CHECK: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M:.+]], %[[N0]]]
-// CHECK: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
-// CHECK: linalg.fill(%[[SV_ALLOC1]], %{{.+}})
-// CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[SV_ARG1]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
-// CHECK-SAME: outs(%[[SV_ALLOC1]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK: linalg.fill(%[[SV_ALLOC2]], %{{.+}})
-// CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[SV_ARG2]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
-// CHECK-SAME: outs(%[[SV_ALLOC2]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK: linalg.fill(%[[SV_ARG4_2]], %{{.+}})
-// CHECK: linalg.matmul ins(%[[SV_ALLOC2]], %[[ARG3]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
-// CHECK-SAME: outs(%[[SV_ARG4]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK: scf.yield
-// CHECK: }
-