ArrayRef<int64_t> peeledLoops,
LinalgTilingLoopType loopType);
-/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This
-/// proceeds as follows:
-/// - Find outer parallel loops in these ops that can be fused.
-/// - Tile fusable outer parallel loops of the last operation in the sequence.
-/// - Fuse the remaining operations with the tiled operation
-///
-/// For example, consider the sequence of matmul below
-///
-/// linalg.matmul ins(%arg0, %arg1 : memref<256x32xf32>, memref<32x32xf32>)
-/// outs(%arg2 : memref<256x32xf32>)
-/// linalg.matmul ins(%arg2, %arg3 : memref<256x32xf32>, memref<32x32xf32>)
-/// outs(%arg4 : memref<256x32xf32>)
-///
-/// It is legal to fuse the RAW dependence (through %arg2) by only fusing the
-/// matmuls row-wise. For example, the fused computation for the above is shown
-/// below. The outer `scf.parallel` loop is the "fused" loop obtained by tiling
-/// along the rows of the matrix. The entire rows of the first matmul operation
-/// need to be computed before they can be used for the second matmul. The
-/// second matmul is further tiled (similar to normal tiling).
-///
-/// #map0 = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>
-/// #map1 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
-/// scf.parallel (%arg5) = (%c0) to (%c256) step (%c16) {
-/// %0 = subview %arg2[%arg5, 0] [16, 32] [1, 1]
-/// : memref<256x32xf32> to memref<16x32xf32, #map0>
-/// %1 = subview %arg4[%arg5, 0] [16, 32] [1, 1]
-/// : memref<256x32xf32> to memref<16x32xf32, #map0>
-/// %2 = subview %arg0[%arg5, 0] [16, 32] [1, 1]
-/// : memref<256x32xf32> to memref<16x32xf32, #map0>
-/// %3 = subview %arg1[0, 0] [32, 32] [1, 1]
-/// : memref<32x32xf32> to memref<32x32xf32, #map1>
-/// %4 = subview %arg3[0, 0] [32, 32] [1, 1]
-/// : memref<32x32xf32> to memref<32x32xf32, #map1>
-/// linalg.matmul
-/// ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
-/// outs(%0 : memref<16x32xf32, #map0>)
-/// linalg.matmul
-/// ins(%0, %4 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
-/// outs(%1 : memref<16x8xf32, #map0>)
-/// }
-///
-/// `tilingOptions` are used to tile the corresponding operation in `ops` (the
-/// size of the former should be same as size of the latter. Based on how
-/// tile+fuse is implemented, the fused loops are generated based on the last
-/// operation in the sequence. For example, the tile sizes for the fused loops
-/// is obtained from `tilingOptions.back()`. The following tiling options are
-/// handled differently in tile+fuse (compared to tile only)
-/// - Interchange of the tiling loops is not supported right now.
-/// - Only the fused loops are distributed.
-struct TiledAndFusedLinalgOps {
- /// Operation obtained by tiling the last operation in sequence of `ops`
- /// passed to `tileAndFuseLinalgOps`.
- LinalgOp op;
- /// The dimension of the loops that are fused.
- std::set<unsigned> fusedLoopDims;
- /// The generated fused operations (created within the fused loops).
- SmallVector<LinalgOp, 1> fusedProducers;
- /// The fused loop generated.
- SmallVector<Operation *, 4> fusedLoops;
-};
-FailureOr<TiledAndFusedLinalgOps>
-tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
- const LinalgDependenceGraph &dependenceGraph,
- const LinalgTilingOptions &tilingOptions);
-
/// Interchange the `iterator_types` and `iterator_maps` dimensions and adapts
/// the index accesses of `op`. This is an in-place transformation controlled by
/// `interchangeVector`. An empty vector is interpreted as the identity
LinalgTransformationFilter filter;
};
-struct LinalgFusionOptions {
- /// List of operands indices to use for fusion.
- llvm::SmallSet<unsigned, 1> indicesToFuse = {};
- LinalgFusionOptions &setIndicesToFuse(ArrayRef<int64_t> operands) {
- indicesToFuse.insert(operands.begin(), operands.end());
- return *this;
- }
-};
-
-struct LinalgBaseTileAndFusePattern : public RewritePattern {
- LinalgBaseTileAndFusePattern(
- StringRef opName, MLIRContext *context,
- const LinalgDependenceGraph &dependenceGraph,
- LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
- LinalgTransformationFilter originalOpMarker =
- LinalgTransformationFilter(),
- PatternBenefit benefit = 1);
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override;
-
-private:
- /// Dependence graph needed for fusion.
- const LinalgDependenceGraph &dependenceGraph;
- /// Options to control tiling.
- LinalgTilingOptions tilingOptions;
- /// Options to control fusion.
- LinalgFusionOptions fusionOptions;
- /// Marker to control application of the pattern.
- LinalgTransformationFilter filter;
- /// Marker set on the fused op after tile and fuse.
- LinalgTransformationFilter fusedOpMarker;
- /// The dependenceGraph is not modifiable, i.e. if the Linalg operations used
- /// to build the dependence graph changes then the dependenceGraph needs to be
- /// recomputed right now. To not invalidate the dependenceGraph as
- /// transformation happens, the original producer can be tagged with a filter
- /// that can be later used to delete the original operations.
- LinalgTransformationFilter originalOpMarker;
-};
-
-template <typename OpTy>
-struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
- LinalgTileAndFusePattern(
- MLIRContext *context, const LinalgDependenceGraph &dependenceGraph,
- LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
- LinalgTransformationFilter originalOpMarker =
- LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : LinalgBaseTileAndFusePattern(
- OpTy::getOperationName(), context, dependenceGraph, tilingOptions,
- fusionOptions, f, fusedOpMarker, originalOpMarker, benefit) {}
-};
-
///
/// Linalg tile and fuse tensor ops pattern.
///
consumerOpOperand.set(def);
return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
}
-
-/// Prune all dimensions that are of reduction iterator type from `map`.
-static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
- AffineMap map) {
- llvm::SmallBitVector projectedDims(iteratorTypes.size());
- for (const auto &attr : llvm::enumerate(iteratorTypes)) {
- if (!isParallelIterator(attr.value()))
- projectedDims.set(attr.index());
- }
- return getProjectedMap(map, projectedDims);
-}
-
-/// Returns the mapping from iterations in the consumer that write to the same
-/// location as the iterations in the producer. To do so use
-/// - indexing map of the fused view in the consumer : consumerIndexMap
-/// - indexing map of the fused view in the producer : producerIndexMap
-/// consumerLoopToProducerLoop =
-/// inverse(producerIndexMap).compose(consumerIndexMap)
-static FailureOr<AffineMap> getConsumerLoopToProducerLoopMap(
- LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
- auto producer = dyn_cast<LinalgOp>(dependence.getDependentOp());
- if (!producer)
- return failure();
-
- Optional<AffineMap> producerIndexingMap =
- dependence.getDependentOpViewIndexingMap();
- Optional<AffineMap> consumerIndexingMap =
- dependence.getIndexingOpViewIndexingMap();
- if (!producerIndexingMap || !consumerIndexingMap)
- return failure();
-
- AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
- producer.iterator_types().getValue(), *producerIndexingMap);
- if (!prunedProducerIndexingMap.isPermutation())
- return failure();
-
- if (consumerIndexingMap->getNumResults() !=
- prunedProducerIndexingMap.getNumResults())
- return failure();
-
- LLVM_DEBUG({
- llvm::dbgs() << "\t producerMap : ";
- producerIndexingMap->print(llvm::dbgs());
- llvm::dbgs() << " pruned : ";
- prunedProducerIndexingMap.print(llvm::dbgs());
- llvm::dbgs() << "\n";
- llvm::dbgs() << "\t consumerMap : ";
- consumerIndexingMap->print(llvm::dbgs());
- llvm::dbgs() << "\n";
- });
-
- AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
- if (!invProducerIndexMap)
- return failure();
-
- return invProducerIndexMap.compose(*consumerIndexingMap);
-}
-
-/// Given a projected permutation `map`, returns true if the map changes the
-/// order in which the fused loop dimension appear.
-static bool doesTransposeAccess(AffineMap map,
- const std::set<unsigned> &fusableLoops) {
- Optional<unsigned> lastFusableLoop;
- for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
- return expr.cast<AffineDimExpr>().getPosition();
- })) {
- if (!fusableLoops.count(pos))
- continue;
- if (!lastFusableLoop) {
- lastFusableLoop = pos;
- continue;
- }
- if (pos <= *lastFusableLoop)
- return true;
- lastFusableLoop = pos;
- }
- return false;
-}
-
-/// Returns the positions of the loop in `op` that can be tiled based on the
-/// operations that are to be fused with it. For example, in a
-///
-/// linalg.matmul ins(%a, %b : ...) outs(%c : ...)
-///
-/// if the producer of %a needs to be fused with this op, only the `i` loop of
-/// the matmul can be tiled while fusing. If producer of %a, and %b are to be
-/// fused, then no loops can be tiled while fusing. The conditions used are:
-/// 1. Only parallel loops can be used for tile + fuse. Find the number of
-/// common outer parallel loops between the op and its producers being fused.
-/// 2. Of the parallel loops only some can be fused. Only those loops can be
-/// fused such where the fusable loops iteration space only touches one tile
-/// of the fused operation. This is because the producer (which is writing
-/// the fused subview) has update semantics.
-///
-/// Since an inverse computation is needed, we need to consider the projection
-/// of the producerIndexMap w.r.t the parallel loops. The actual fusable loops
-/// are the dimensions of the consumerLoopToProducerLoop map that correspond to
-/// parallel loops and appear in the result of the map
-///
-/// Example 1:
-/// linalg.fill(%cst, %c)
-/// linalg.matmul ins(%a, %b) outs(%c)
-/// Number of parallel loops : 2
-/// producerIndexMap = affine_map<(i, j) ->(i , j)>
-/// consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
-/// consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
-/// Fused dimensions : i, j
-///
-/// Example 2:
-/// linalg.matmul ins(%a, %b) outs(%c)
-/// linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
-/// iterator_types = ["parallel", "parallel"]}
-/// ins(%c) ...
-///
-/// Number of parallel loops = 2:
-/// producerIndexMap (projected to parallel loops) =
-/// affine_map<(i, j) -> (i, j)>
-/// consumerLoopToProducerLoop2 = affine_map<(i, j) -> (j, i)>
-/// Fused dimensions : i, j
-///
-/// Example 3:
-/// memref.copy(%s, %b)
-/// linalg.matmul ins(%a, %b) outs(%c)
-///
-/// Number of parallel loops = 2
-/// produceIndexMap : affine_map<(i, j) -> (i, j)>
-/// consumerLoopToProduceLoops = affine_map<(i, j, k) -> (k, j)>
-/// submap with only parallel loops = affine_map<(i, j) -> (j)>
-/// Fused dimensions : j
-static std::set<unsigned>
-collectFusableLoops(ArrayRef<LinalgOp> ops,
- const FusableOpDependencesTy &fusableDependences) {
- assert(!ops.empty());
- auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
- return linalgOp.iterator_types()
- .getValue()
- .take_while([](Attribute attr) -> bool {
- return attr.cast<StringAttr>().getValue() ==
- getParallelIteratorTypeName();
- })
- .size();
- };
-
- size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
- for (auto op : ops.drop_back()) {
- numOuterParallelLoops =
- std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
- }
-
- std::set<unsigned> fusableLoops;
- auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
- fusableLoops.insert(range.begin(), range.end());
-
- for (auto op : reverse(ops)) {
- for (auto dependence : fusableDependences.lookup(op)) {
- LLVM_DEBUG({
- llvm::dbgs() << "\t fusable :";
- for (unsigned i : fusableLoops)
- llvm::dbgs() << " " << i;
- llvm::dbgs() << "\n";
- });
-
- Optional<AffineMap> consumerLoopToProducerLoop =
- getConsumerLoopToProducerLoopMap(dependence);
- if (!consumerLoopToProducerLoop) {
- op.emitRemark("failed to get map from consumer loop to producer loop");
- return {};
- }
- // todo: This condition is only an implementation limitation. When fusing
- // the operation, if the accesses in the producer/consumer are transposes
- // of each other, the loop bounds for the tiled producer can be
- // manipulated accordingly. This requires some additional bookkeeping in
- // the implementation of tile+fuse that is deferred to later.
- if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
- op.emitRemark("unhandled fusion when fusion requires permutation");
- return {};
- }
-
- std::set<unsigned> candidates;
- for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
- unsigned position = expr.cast<AffineDimExpr>().getPosition();
- if (fusableLoops.count(position))
- candidates.insert(position);
- }
- LLVM_DEBUG({
- llvm::dbgs() << "\t candidates :";
- for (unsigned i : candidates)
- llvm::dbgs() << " " << i;
- llvm::dbgs() << "\n";
- });
- if (candidates.empty())
- return {};
- std::swap(candidates, fusableLoops);
- }
- }
-
- return fusableLoops;
-}
-
-/// Find all dependences that are fusable.
-FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
- ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
- FusableOpDependencesTy fusableDependences;
- DenseMap<Operation *, SmallVector<AffineMap, 1>> fusedProducerIndexingMap;
- for (LinalgOp op : reverse(ops)) {
- for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
- Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
- fusableDependence = findFusableProducer(*opOperand, dependenceGraph);
- if (!fusableDependence)
- continue;
- LinalgOp producerOp =
- dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
- if (!producerOp)
- continue;
- // Do not fuse dependences that are to operations not in the same basic
- // block. This avoid moving fused operations across loops that might
- // themselves carry dependency making the fusion illegal.
- if (producerOp->getBlock() != op->getBlock())
- continue;
-
- // Make sure that the indexing map of the view used for fusion in the
- // producer is a projected permutation.
- Optional<AffineMap> producerMap =
- fusableDependence->getDependentOpViewIndexingMap();
- Optional<AffineMap> consumerMap =
- fusableDependence->getIndexingOpViewIndexingMap();
- assert(
- consumerMap &&
- "unable to find indexing map of operand/result of indexing OpView");
- fusedProducerIndexingMap[producerOp.getOperation()].push_back(
- *consumerMap);
- if (!producerMap || !producerMap->isProjectedPermutation() ||
- !consumerMap->isProjectedPermutation())
- continue;
-
- fusableDependences[producerOp.getOperation()].push_back(
- *fusableDependence);
- }
- }
- // TODO: Currently fusion would not be legal if the fusable dependence is to
- // the same producer but different indexing map in the consumer. Fix this, but
- // in the meanwhile disallow such a fusion.
- for (auto useIndexingMapsList : fusedProducerIndexingMap) {
- AffineMap map1 = useIndexingMapsList.second.front();
- for (AffineMap map2 :
- ArrayRef<AffineMap>(useIndexingMapsList.second).drop_front()) {
- if (map1 != map2) {
- fusableDependences.erase(useIndexingMapsList.first);
- break;
- }
- }
- }
- return fusableDependences;
-}
-
-/// Tile the fused loops in the root operation, by setting the tile sizes for
-/// all other loops to zero (those will be tiled later).
-static FailureOr<TiledLinalgOp>
-tileRootOperation(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizeVector,
- const LinalgTilingOptions &options,
- const std::set<unsigned> &fusedLoops) {
- SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
- auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
- for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
- if (!fusedLoops.count(i))
- tileSizes[i] = zero;
- LinalgTilingOptions tileFusedLoopsOptions = options;
- tileFusedLoopsOptions.setTileSizes(tileSizes);
- // TODO: Propagate RewriterBase everywhere.
- IRRewriter rewriter(b);
- return tileLinalgOp(rewriter, op, tileFusedLoopsOptions);
-}
-
-/// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected
-/// to be a tiled operation such that it is valid to fuse all operations in
-/// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
-/// `tiledOp`.
-static SmallVector<LinalgOp, 1>
-fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
- ArrayRef<LinalgOp> fusionCandidates,
- const FusableOpDependencesTy &fusableDependences,
- const std::set<unsigned> &fusedLoops) {
- LinalgOp tiledOp = tiledLinalgOp.op;
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPoint(tiledOp);
-
- DenseMap<unsigned, Range> fusedLoopsAndRanges;
- for (unsigned loop : fusedLoops) {
- ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
- fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
- b, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
- }
-
- SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
- DenseMap<Operation *, LinalgOp> origOpToFusedOp;
- origOpToFusedOp[rootOp.getOperation()] = tiledOp;
- for (const auto &candidate : enumerate(llvm::reverse(fusionCandidates))) {
- LinalgOp origOp = candidate.value();
- LinalgOp fusedOp = fuse(b, origOp, fusedLoopsAndRanges);
- origOpToFusedOp[origOp.getOperation()] = fusedOp;
- fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
-
- // Prepare the builder for the next insertion point.
- auto guard = llvm::make_scope_exit([&]() { b.setInsertionPoint(fusedOp); });
- if (!origOp.hasTensorSemantics())
- continue;
-
- // If the producer consumer operations are linalg operations on tensors, the
- // dependence is due to value produced (as a return tensor) by the producer
- // and used in the consumer. The returned value of the fused op needs to be
- // made the operand of the tiled/fused consumer operation. By construction
- // the value returned by the producer is the value used by the consumer.
- for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) {
- if (dependence.dependenceType !=
- LinalgDependenceGraph::DependenceType::RAW)
- continue;
-
- unsigned resultIndex = dependence.getDependentOpViewResultNum().value();
- LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp());
- if (!consumer)
- continue;
-
- Value replacementValue = fusedOp.getOperation()->getResult(resultIndex);
- consumer.getOperation()->setOperand(
- dependence.getIndexingOpViewOperandNum().value(), replacementValue);
- }
-
- // At this point, all Linalg uses of the tensors produced by `origOp` have
- // been replaced. However, there may still be "output tensor"-like uses
- // coming from WAW dependencies.
- // All these uses are iter_args of the outermost loop (TODO: add a check).
- // Such iter_args uses serve 2 purposes:
- // 1. give a shape to the output
- // 2. encode destructive updates that may be inplaceable by bufferization.
- // To keep the second type of information while letting the unfused op die
- // unused, we need to forward the producer output operand.
- if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
- for (auto &operand : forOp.getIterOpOperands()) {
- if (auto opResult = operand.get().dyn_cast<OpResult>()) {
- if (opResult.getOwner() == origOp) {
- Value output =
- origOp.getOutputOperand(opResult.getResultNumber())->get();
- assert(output.getType().isa<RankedTensorType>());
- operand.set(output);
- }
- }
- }
- }
- }
- return fusedOps;
-}
-
-static FailureOr<TiledAndFusedLinalgOps>
-tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
- const LinalgDependenceGraph &dependenceGraph,
- const LinalgTilingOptions &tilingOptions) {
- if (ops.size() < 2)
- return failure();
- LinalgOp rootOp = ops.back();
- if (!llvm::all_of(
- ops,
- [](LinalgOp linalgOp) { return linalgOp.hasBufferSemantics(); }) &&
- !llvm::all_of(ops, [](LinalgOp linalgOp) {
- return linalgOp.hasTensorSemantics();
- })) {
- rootOp.emitError(
- "unable to fuse operations that have tensor semantics with operations "
- "that have buffer semantics and viceversa.");
- return failure();
- }
- // TODO: Support interchange with tile + fuse. This might actually help do
- // better fusion.
- if (!tilingOptions.interchangeVector.empty()) {
- rootOp.emitRemark("unable to handle tile and fuse with interchange");
- return failure();
- }
-
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPoint(rootOp);
-
- // Find all the producers.
- LLVM_DEBUG(llvm::dbgs() << "findAllFusableDependences\n");
- FusableOpDependencesTy fusableDependences =
- findAllFusableDependences(ops, dependenceGraph);
- if (fusableDependences.empty()) {
- LLVM_DEBUG(llvm::dbgs() << "no fusable dependencies found\n");
- return failure();
- }
-
- TiledAndFusedLinalgOps ret;
- // Find the loops that can be tiled and fused.
- LLVM_DEBUG(llvm::dbgs() << "collectFusableLoops\n");
- ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
-
- // If there are no fusable dependences or there are no tile+fusable loops,
- // just return.
- if (ret.fusedLoopDims.empty()) {
- LLVM_DEBUG(llvm::dbgs() << "no fusable loops found\n");
- return failure();
- }
-
- // Tile the fused loops in the last operation in the list.
- SmallVector<Value, 4> tileSizeVector =
- tilingOptions.tileSizeComputationFunction(b, rootOp);
- FailureOr<TiledLinalgOp> tiledRootOp = tileRootOperation(
- b, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
- if (failed(tiledRootOp)) {
- rootOp.emitRemark("failed to tile the fused loops");
- return failure();
- }
- ret.op = tiledRootOp->op;
- ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
-
- // Fuse the other operations into the fused inter-tile loops produced above.
- ret.fusedProducers = fuseOperations(b, rootOp, *tiledRootOp, ops.drop_back(),
- fusableDependences, ret.fusedLoopDims);
-
- return ret;
-}
-
-FailureOr<TiledAndFusedLinalgOps>
-mlir::linalg::tileAndFuseLinalgOps(OpBuilder &b, ArrayRef<LinalgOp> ops,
- const LinalgDependenceGraph &dependenceGraph,
- const LinalgTilingOptions &tilingOptions) {
- switch (tilingOptions.loopType) {
- case LinalgTilingLoopType::Loops:
- case LinalgTilingLoopType::ParallelLoops:
- case LinalgTilingLoopType::TiledLoops:
- return tileAndFuseLinalgOpsImpl(b, ops, dependenceGraph, tilingOptions);
- default:;
- }
- return failure();
-}
}
}
-static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
- if (tiledOp.loops.empty())
- return tiledOp.op.getOperation()->getResults();
- return tiledOp.loops.front()->getResults();
-}
-
-static ValueRange
-getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
- if (tiledAndFusedOp.fusedLoops.empty())
- return tiledAndFusedOp.op.getOperation()->getResults();
- return tiledAndFusedOp.fusedLoops.front()->getResults();
-}
-
-mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
- StringRef opName, MLIRContext *context,
- const LinalgDependenceGraph &dependenceGraph,
- LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
- LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
- LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
- : RewritePattern(opName, benefit, context, {}),
- dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
- fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
- fusedOpMarker(std::move(fusedOpMarker)),
- originalOpMarker(std::move(originalOpMarker)) {}
-
-LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
- Operation *op, PatternRewriter &rewriter) const {
- LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
- // TODO: remove hasIndexSemantics check once index ops are supported.
- if (!linalgOp || linalgOp.hasIndexSemantics())
- return failure();
- if (failed(filter.checkAndNotify(rewriter, linalgOp)))
- return failure();
-
- DenseSet<Operation *> producers;
- producers.insert(linalgOp);
- for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
- Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
- // When looking at dependences into, indexingOp is always OpOperand. We
- // could assert, but continue if this is not the case.
- if (!operandNumber)
- continue;
- if (!fusionOptions.indicesToFuse.count(*operandNumber))
- continue;
- if (isa<LinalgOp>(dependence.getDependentOp()))
- producers.insert(dependence.getDependentOp());
- }
-
- SmallVector<LinalgOp, 1> fusionOps;
- for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
- ++it) {
- auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
- if (producerLinalgOp && producers.count(producerLinalgOp))
- fusionOps.push_back(producerLinalgOp);
- }
- fusionOps.push_back(linalgOp);
-
- SmallVector<Value, 4> tileSizes =
- tilingOptions.tileSizeComputationFunction(rewriter, op);
- LinalgTilingOptions instanceTilingOptions = tilingOptions;
- instanceTilingOptions.setTileSizes(tileSizes);
- Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
- rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
- if (!tiledAndFusedOps)
- return failure();
-
- // Tile the unfused loops;
- SmallVector<Value, 4> unfusedLoopTileSizes;
- Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
- for (const auto &tileSize : enumerate(tileSizes)) {
- if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
- unfusedLoopTileSizes.push_back(zero);
- else
- unfusedLoopTileSizes.push_back(tileSize.value());
- }
- // Tile the loop only if there is a non-zero tile size.
- if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
- unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
- if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
- if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
- return cst.value() != 0;
- return true;
- })) {
- LinalgTilingOptions unfusedTilingOptions = tilingOptions;
- unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
- FailureOr<TiledLinalgOp> unfusedTiledOp =
- tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
- if (failed(unfusedTiledOp))
- return failure();
- rewriter.replaceOp(tiledAndFusedOps->op,
- getTiledOpResult(unfusedTiledOp.value()));
- tiledAndFusedOps->op = unfusedTiledOp->op;
- }
- op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.value()));
-
- filter.replaceLinalgTransformationFilter(rewriter,
- tiledAndFusedOps->op.getOperation());
- for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
- fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
- fusedOp.getOperation());
- }
- for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
- originalOpMarker.replaceLinalgTransformationFilter(
- rewriter, origProducerOp.getOperation());
- }
- rewriter.updateRootInPlace(op, [&]() {
- originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
- });
- return success();
-}
-
/// Linalg tiling pattern.
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
MLIRContext *context, LinalgTilingOptions options,
+++ /dev/null
-// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file | FileCheck %s
-
-module {
- func.func @basic_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
- %arg2: memref<?x?xf32>) {
- %cst = arith.constant 0.000000e+00 : f32
- linalg.fill ins(%cst : f32) outs(%arg2 : memref<?x?xf32>)
- linalg.matmul {__internal_linalg_transform__ = "basic_fusion"}
- ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%arg2 : memref<?x?xf32>)
- return
- }
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 32)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 64)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
-// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 32)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 64)>
-// CHECK: func @basic_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
-// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
-// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
-// CHECK-DAG: %[[CST:.+]] = arith.constant 0.0{{.*}} : f32
-// CHECK-DAG: linalg.fill
-// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original"
-// CHECK-SAME: ins(%[[CST]]{{.*}}outs(%[[ARG2]]
-// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
-// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) =
-// CHECK-SAME: to (%[[M]], %[[N]])
-// CHECK-SAME: step (%[[C32]], %[[C64]]) {
-// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
-// CHECK: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[SV1:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[K]]]
-// CHECK: %[[K_2:.+]] = memref.dim %[[ARG1]], %[[C0]]
-// CHECK: %[[TILE_N:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N]]]
-// CHECK: %[[SV2:.+]] = memref.subview %[[ARG1]][0, %[[IV1]]]
-// CHECK-SAME: %[[K_2]], %[[TILE_N]]
-// CHECK: %[[SV3:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]]
-// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]]
-// CHECK: %[[M_2:.+]] = memref.dim %[[ARG2]], %[[C0]]
-// CHECK: %[[N_2:.+]] = memref.dim %[[ARG2]], %[[C1]]
-// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M_2]], %[[M]]]
-// CHECK: %[[TILE_N_3:.+]] = affine.min #[[MAP5]](%[[IV1]])[%[[N_2]], %[[N]]]
-// CHECK: %[[SV3_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]]
-// CHECK-SAME: [%[[TILE_M_3]], %[[TILE_N_3]]]
-// CHECK: linalg.fill
-// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer"
-// CHECK-SAME: ins(%[[CST]]{{.*}}outs(%[[SV3_2]]
-// CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
-// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]]
-// CHECK: %[[SV4:.+]] = memref.subview %[[SV1]][0, %[[IV2]]]
-// CHECK-SAME: [%[[TILE_M]], %[[TILE_K]]]
-// CHECK: %[[SV5:.+]] = memref.subview %[[SV2]][%[[IV2]], 0]
-// CHECK-SAME: [%[[TILE_K]], %[[TILE_N]]]
-// CHECK: linalg.matmul
-// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion"
-// CHECK-SAME: ins(%[[SV4]], %[[SV5]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
-// CHECK-SAME: outs(%[[SV3]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK: }
-// CHECK: }
-// CHECK: linalg.matmul
-// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original"
-
-// -----
-
-module {
- func.func @matmul_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
- %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
- %arg4: memref<?x?xf32>) {
- linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%arg2 : memref<?x?xf32>)
- linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
- ins(%arg2, %arg3 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%arg4 : memref<?x?xf32>)
- return
- }
-}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 32)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 64)>
-// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 32)>
-// CHECK: func @matmul_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
-// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
-// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
-// CHECK: linalg.matmul
-// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_original"
-// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG2]], %[[C0]]
-// CHECK: scf.parallel (%[[IV0:.+]]) =
-// CHECK-SAME: (%[[C0]]) to (%[[M]]) step (%[[C32]]) {
-// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
-// CHECK: %[[K2:.+]] = memref.dim %[[ARG2]], %[[C1]]
-// CHECK: %[[SV1:.+]] = memref.subview %[[ARG2]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[K2]]]
-// CHECK: %[[N:.+]] = memref.dim %[[ARG4]], %[[C1]]
-// CHECK: %[[SV2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N]]]
-// CHECK: %[[M_3:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M_3]], %[[M]]]
-// CHECK: %[[K1:.+]] = memref.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[SV3:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_3]], %[[K1]]]
-// CHECK: %[[SV1_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_3]], %[[K2]]]
-// CHECK: linalg.matmul
-// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer"
-// CHECK-SAME: ins(%[[SV3]], %[[ARG1]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
-// CHECK-SAME: outs(%[[SV1_2]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK: %[[N_2:.+]] = memref.dim %[[ARG3]], %[[C1]]
-// CHECK: scf.parallel (%[[IV1:.+]]) =
-// CHECK-SAME: (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) {
-// CHECK-NEXT: scf.for %[[IV2:.+]] = %[[C0]] to %[[K2]] step %[[C16]] {
-// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K2]]]
-// CHECK: %[[SV6:.+]] = memref.subview %[[SV1]][0, %[[IV2]]]
-// CHECK-SAME: [%[[TILE_M]], %[[TILE_K]]]
-// CHECK: %[[TILE_N:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N_2]]]
-// CHECK: %[[SV7:.+]] = memref.subview %[[ARG3]][%[[IV2]], %[[IV1]]]
-// CHECK-SAME: [%[[TILE_K]], %[[TILE_N]]]
-// CHECK: %[[SV8:.+]] = memref.subview %[[SV2]][0, %[[IV1]]]
-// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]]
-// CHECK: linalg.matmul
-// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion"
-// CHECK-SAME: ins(%[[SV6]], %[[SV7]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
-// CHECK-SAME: outs(%[[SV8]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: linalg.matmul
-// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_original"
-
-// -----
-
-module {
- func.func @matmul_plus_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
- %arg2: memref<?x?xf32>) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %0 = memref.dim %arg2, %c0 : memref<?x?xf32>
- %1 = memref.dim %arg2, %c1 : memref<?x?xf32>
- %2 = memref.alloc(%0, %1) : memref<?x?xf32>
- linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%2 : memref<?x?xf32>)
- linalg.generic
- {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"],
- __internal_linalg_transform__ = "transpose_fusion"}
- ins(%2, %2 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%arg2 : memref<?x?xf32>) {
- ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
- %3 = arith.addf %arg3, %arg4 : f32
- linalg.yield %3 : f32
- }
- return
- }
-}
-// CHECK: func @matmul_plus_matmul
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK: %[[T2:.+]] = memref.alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
-// CHECK: linalg.matmul
-// CHECK-SAME: after_transpose_fusion_original
-// CHECK: scf.parallel (%[[ARG3:[a-zA-Z0-9_]+]], %[[ARG4:.[a-zA-Z0-9_]+]])
-// CHECK: %[[T5:.+]] = memref.subview %[[T2]][%[[ARG3]], %[[ARG4]]]
-// CHECK: %[[T6:.+]] = memref.subview %[[ARG2]][%[[ARG3]], %[[ARG4]]]
-// CHECK: %[[T8:.+]] = memref.subview %[[ARG0]][%[[ARG3]], 0]
-// CHECK: %[[T9:.+]] = memref.subview %[[ARG1]][0, %[[ARG4]]]
-// CHECK: %[[T10:.+]] = memref.subview %[[T2]][%[[ARG3]], %[[ARG4]]]
-// CHECK: linalg.matmul
-// CHECK-SAME: after_transpose_fusion_producer
-// CHECK-SAME: ins(%[[T8]], %[[T9]]
-// CHECK-SAME: outs(%[[T10]]
-// CHECK-NOT: linalg.matmul
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[T5]], %[[T5]]
-// CHECK-SAME: outs(%[[T6]]
-// CHECK-SAME: after_transpose_fusion
-
-// -----
-
-module {
- func.func @matmul_plus_transpose_matmul(%arg0: memref<?x?xf32>,
- %arg1: memref<?x?xf32>,
- %arg2: memref<?x?xf32>) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %0 = memref.dim %arg2, %c0 : memref<?x?xf32>
- %1 = memref.dim %arg2, %c1 : memref<?x?xf32>
- %2 = memref.alloc(%0, %1) : memref<?x?xf32>
- linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%2 : memref<?x?xf32>)
- linalg.generic
- {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d1, d0)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"],
- __internal_linalg_transform__ = "transpose_fusion"}
- ins(%2, %2 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%arg2 : memref<?x?xf32>) {
- ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
- %3 = arith.addf %arg3, %arg4 : f32
- linalg.yield %3 : f32
- }
- return
- }
-}
-// CHECK-LABEL: func @matmul_plus_transpose_matmul
-// CHECK-NOT: scf.parallel
-// CHECK-NOT: scf.for
-// CHECK: linalg.matmul
-// CHECK-NOT: scf.parallel
-// CHECK-NOT: scf.for
-// CHECK: linalg.generic
-// CHECK-NOT: scf.parallel
-// CHECK-NOT: scf.for
-
-// -----
-
-#map0 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
-#map1 = affine_map<(d0)[s0] -> (64, -d0 + s0)>
-#map2 = affine_map<(d0)[s0] -> (16, -d0 + s0)>
-#map3 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-module {
- func.func @basic_no_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
- %arg2: memref<?x?xf32>) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %c32 = arith.constant 32 : index
- %c64 = arith.constant 64 : index
- %c16 = arith.constant 16 : index
- %cst = arith.constant 0.000000e+00 : f32
- linalg.fill ins(%cst : f32) outs(%arg2 : memref<?x?xf32>)
- %0 = memref.dim %arg0, %c0 : memref<?x?xf32>
- %1 = memref.dim %arg1, %c1 : memref<?x?xf32>
- %2 = memref.dim %arg0, %c1 : memref<?x?xf32>
- scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %1) step (%c32, %c64) {
- scf.for %arg5 = %c0 to %2 step %c16 {
- %3 = affine.min #map0(%arg3)[%0]
- %4 = affine.min #map1(%arg4)[%1]
- %5 = affine.min #map2(%arg5)[%2]
- %6 = memref.subview %arg0[%arg3, %arg5] [%3, %5] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
- %7 = memref.subview %arg1[%arg5, %arg4] [%5, %4] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
- %8 = memref.subview %arg2[%arg3, %arg4] [%3, %4] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
- linalg.matmul {__internal_linalg_transform__ = "basic_fusion"}
- ins(%6, %7 : memref<?x?xf32, #map3>, memref<?x?xf32, #map3>)
- outs(%8 : memref<?x?xf32, #map3>)
- }
- scf.yield
- }
- return
- }
-}
-// CHECK-LABEL: func @basic_no_fusion
-// CHECK-NOT: scf.parallel
-// CHECK: linalg.fill
-// CHECK: scf.parallel
-// CHECK: scf.for
-// CHECK-NOT: linalg.fill
-// CHECK: linalg.matmul
-
-// -----
-
-module {
- func.func @basic_conv_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
- %arg2: memref<?x?xf32>) {
- %cst = arith.constant 0.000000e+00 : f32
- linalg.fill ins(%cst : f32) outs(%arg2 : memref<?x?xf32>)
- linalg.conv_2d {__internal_linalg_transform__ = "basic_fusion"}
- ins(%arg1, %arg0 : memref<?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?xf32>)
- return
- }
-}
-// CHECK: func @basic_conv_fusion
-// CHECK: linalg.fill
-// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original"
-// CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}})
-// CHECK-SAME: {
-// CHECK: linalg.fill
-// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer"
-// CHECK: linalg.conv_2d
-// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion"
-// CHECK: }
-// CHECK: linalg.conv_2d
-// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original"
+++ /dev/null
-// RUN: mlir-opt -pass-pipeline="func.func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),resolve-shaped-type-result-dims,canonicalize,cse" -split-input-file %s | FileCheck %s
-
-module {
- func.func @three_op_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
- %arg2: memref<?xf32>, %arg3 : memref<?x?xf32>) {
- %cst = arith.constant 0.000000e+00 : f32
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %d0 = memref.dim %arg0, %c0 : memref<?x?xf32>
- %d1 = memref.dim %arg1, %c1 : memref<?x?xf32>
- %0 = memref.alloc(%d0, %d1) : memref<?x?xf32>
- linalg.fill ins(%cst : f32) outs(%0 : memref<?x?xf32>)
- linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%0 : memref<?x?xf32>)
- linalg.generic
- {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d1)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%0, %arg2 : memref<?x?xf32>, memref<?xf32>)
- outs(%arg3 : memref<?x?xf32>) {
- ^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32) :
- %5 = arith.addf %arg4, %arg5 : f32
- linalg.yield %5 : f32
- }
- return
- }
-}
-
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK: func @three_op_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK: %[[TEMP:.+]] = memref.alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
-// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} {
-// CHECK: %[[SV_TEMP_1:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]]
-// CHECK-DAG: %[[SV_ARG2:.+]] = memref.subview %[[ARG2]][%[[IV1]]]
-// CHECK-DAG: %[[SV_ARG3:.+]] = memref.subview %[[ARG3]][%[[IV0]], %[[IV1]]]
-// CHECK-DAG: %[[SV_ARG0:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-DAG: %[[SV_ARG1:.+]] = memref.subview %[[ARG1]][0, %[[IV1]]]
-// CHECK: %[[SV_TEMP_2:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]]
-// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_TEMP_1]]
-// CHECK: linalg.matmul
-// CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP2]]>, memref<?x?xf32, #[[MAP2]]>)
-// CHECK-SAME: outs(%[[SV_TEMP_2]] : memref<?x?xf32, #[[MAP2]]>)
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[SV_TEMP_1]], %[[SV_ARG2]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP2]]>, memref<?xf32, #[[MAP3]]>)
-// CHECK-SAME: outs(%[[SV_ARG3]] : memref<?x?xf32, #[[MAP2]]>)
-// CHECK: scf.yield
-// CHECK: }
-
-// -----
-
-module {
- func.func @sequence_of_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
- %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
- %arg4: memref<?x?xf32>) {
- %cst = arith.constant 0.000000e+00 : f32
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %m = memref.dim %arg0, %c0 : memref<?x?xf32>
- %n1 = memref.dim %arg1, %c1 : memref<?x?xf32>
- %n2 = memref.dim %arg2, %c1 : memref<?x?xf32>
- %n3 = memref.dim %arg3, %c1 : memref<?x?xf32>
- %0 = memref.alloc(%m, %n1) : memref<?x?xf32>
- %1 = memref.alloc(%m, %n2) : memref<?x?xf32>
- linalg.fill ins(%cst : f32) outs(%0 : memref<?x?xf32>)
- linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%0 : memref<?x?xf32>)
- linalg.fill ins(%cst : f32) outs(%1 : memref<?x?xf32>)
- linalg.matmul ins(%0, %arg2 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%1 : memref<?x?xf32>)
- linalg.fill ins(%cst : f32) outs(%arg4 : memref<?x?xf32>)
- linalg.matmul ins(%1, %arg3 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%arg4 : memref<?x?xf32>)
- return
- }
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 16)>
-
-
-// CHECK: func @sequence_of_matmul
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
-// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[N1:.+]] = memref.dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[N2:.+]] = memref.dim %[[ARG2]], %[[C1]]
-// CHECK: %[[ALLOC1:.+]] = memref.alloc(%[[M]], %[[N1]])
-// CHECK: %[[ALLOC2:.+]] = memref.alloc(%[[M]], %[[N2]])
-// CHECK: scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]])
-// CHECK-SAME: step (%[[C16]]) {
-// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
-// CHECK: %[[SV_ALLOC3:.+]] = memref.subview %[[ALLOC2]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N2]]]
-// CHECK: %[[N3:.+]] = memref.dim %[[ARG4]], %[[C1]]
-// CHECK: %[[SV_ARG4:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N3]]]
-// CHECK: %[[M_2:.+]] = memref.dim %[[ARG4]], %[[C0]]
-// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[M_2]], %[[M]]]
-// CHECK: %[[SV_ARG4_2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_3]], %[[N3]]]
-// CHECK: %[[SV_ALLOC1:.+]] = memref.subview %[[ALLOC1]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N1]]]
-// CHECK: %[[TILE_M_5:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[M]], %[[M]]]
-// CHECK: %[[N0:.+]] = memref.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[SV_ARG0:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_5]], %[[N0]]]
-// CHECK: %[[SV_ALLOC4:.+]] = memref.subview %[[ALLOC1]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_5]], %[[N1]]]
-// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_ALLOC1]]
-// CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[ARG1]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
-// CHECK-SAME: outs(%[[SV_ALLOC4]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_ALLOC3]]
-// CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[ARG2]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
-// CHECK-SAME: outs(%[[SV_ALLOC3]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_ARG4_2]]
-// CHECK: linalg.matmul ins(%[[SV_ALLOC3]], %[[ARG3]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
-// CHECK-SAME: outs(%[[SV_ARG4]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK: scf.yield
-// CHECK: }
-
-
-// -----
-
-module {
- func.func @tensor_op_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
- %arg2: tensor<?x?xf32>, %arg3: tensor<?xf32>)
- -> tensor<?x?xf32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
- %2 = tensor.dim %0, %c1 : tensor<?x?xf32>
- %3 = linalg.init_tensor [%1, %2] : tensor<?x?xf32>
- %4 = linalg.generic
- {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%0, %arg3 : tensor<?x?xf32>, tensor<?xf32>)
- outs(%3 : tensor<?x?xf32>) {
- ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
- %5 = arith.addf %arg4, %arg5 : f32
- linalg.yield %5 : f32
- } -> tensor<?x?xf32>
- return %4 : tensor<?x?xf32>
- }
-}
-// CHECK-LABEL: func @tensor_op_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?xf32>
-// CHECK: %[[INIT:.+]] = linalg.init_tensor
-// CHECK: %[[R0:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG5:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[R1:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG7:.+]] = %[[ARG5]]) -> (tensor<?x?xf32>) {
-// CHECK-DAG: %[[STARG3:.+]] = tensor.extract_slice %[[ARG3]]
-// CHECK-DAG: %[[STARG7:.+]] = tensor.extract_slice %[[ARG7]]
-// CHECK-DAG: %[[STARG0:.+]] = tensor.extract_slice %[[ARG0]]
-// CHECK-DAG: %[[STARG1:.+]] = tensor.extract_slice %[[ARG1]]
-// CHECK-DAG: %[[STARG2:.+]] = tensor.extract_slice %[[ARG2]]
-// CHECK: %[[T0:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[STARG0]], %[[STARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[STARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[T1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[T0:.+]], %[[STARG3]] : tensor<?x?xf32>, tensor<?xf32>)
-// CHECK-SAME: outs(%[[STARG7]] : tensor<?x?xf32>)
-// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[T1]] into %[[ARG7]]
-// CHECK: scf.yield %[[RESULT]]
-// CHECK: }
-// CHECK: scf.yield %[[R1]]
-// CHECK: }
-// CHECK: return %[[R0]]
-
-// -----
-
-module {
- func.func @tensor_matmul_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
- %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>,
- %arg4: tensor<?x?xf32>, %arg5: tensor<?x?xf32>,
- %arg6: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N0] * [N0, N1]
- %1 = linalg.matmul ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N1] * [N1, N2]
- %2 = linalg.matmul ins(%1, %arg5 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N2] * [N2, N3]
- return %2 : tensor<?x?xf32>
- }
-}
-
-// CHECK: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
-// CHECK: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 16)>
-
-// CHECK: func @tensor_matmul_fusion(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[M:.+]] = tensor.dim %[[ARG0]], %c0 : tensor<?x?xf32>
-// CHECK: %[[R0:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] =
-// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[TILE_M_1:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
-// CHECK: %[[N3:.+]] = tensor.dim %[[ARG8]], %[[C1]]
-// CHECK: %[[STARG6:.+]] = tensor.extract_slice %[[ARG8]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_1]], %[[N3]]]
-// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[M]], %[[M]]]
-// CHECK: %[[N2:.+]] = tensor.dim %[[ARG4]], %[[C1]]
-// CHECK: %[[STARG4:.+]] = tensor.extract_slice %[[ARG4]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_2]], %[[N2]]]
-// CHECK: %[[N0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[STARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_2]], %[[N0]]]
-// CHECK: %[[N1:.+]] = tensor.dim %[[ARG2]], %[[C1]]
-// CHECK: %[[STARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_2]], %[[N1]]]
-// CHECK: %[[T0:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[STARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>
-// CHECK-SAME: ) outs(%[[STARG2]] : tensor<?x?xf32>)
-// CHECK: %[[T1:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[T0]], %arg3 : tensor<?x?xf32>, tensor<?x?xf32>
-// CHECK-SAME: ) outs(%[[STARG4]] : tensor<?x?xf32>)
-// CHECK: %[[T2:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[T1]], %arg5 : tensor<?x?xf32>, tensor<?x?xf32>
-// CHECK-SAME: ) outs(%[[STARG6]] : tensor<?x?xf32>)
-// CHECK: %[[R1:.+]] = tensor.insert_slice %[[T2]]
-// CHECK-SAME: into %[[ARG8]][%[[IV0]], 0] [%[[TILE_M_1]], %[[N3]]]
-// CHECK: scf.yield %[[R1]] : tensor<?x?xf32>
-// CHECK: }
+++ /dev/null
-// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s
-
-module {
- func.func @matmul_fusion(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
- %AB_init: tensor<?x?xf32>, %C: tensor<?x?xf32>,
- %ABC_init: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %AB = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%AB_init : tensor<?x?xf32>) -> tensor<?x?xf32> // <MxN1> <N1xN2>
- %ABC = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
- ins(%AB, %C : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%ABC_init : tensor<?x?xf32>) -> tensor<?x?xf32> // <MxN2> <N2xN3>
- return %ABC : tensor<?x?xf32>
- }
-}
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 32)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 64)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 32)>
-
-// CHECK: func @matmul_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
-// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
-// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
-// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] =
-// CHECK-SAME: %[[C0]] to %[[M]] step %[[C32]]
-// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[M]]]
-// CHECK: %[[N3:.+]] = tensor.dim %[[ARG6]], %[[C1]]
-// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]]
-// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP5]](%[[IV0]])[%[[M]], %[[M]]]
-// CHECK: %[[N1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_3]], %[[N1]]]
-// CHECK: %[[N2_2:.+]] = tensor.dim %[[ARG2]], %[[C1]]
-// CHECK: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_3]], %[[N2_2]]]
-// CHECK: %[[LHS:.+]] = linalg.matmul
-// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer"
-// CHECK-SAME: ins(%[[ST_ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[ST_ARG2]] : tensor<?x?xf32>)
-// CHECK: %[[N2:.+]] = tensor.dim %[[ARG1]], %[[C1]]
-// CHECK: %[[N3_2:.+]] = tensor.dim %[[ARG3]], %[[C1]]
-// CHECK: %[[YIELD0:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] =
-// CHECK-SAME: %[[C0]] to %[[N3_2]] step %[[C64]]
-// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ST_ARG6]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[YIELD1:.+]] = scf.for %[[IV2:[a-zA-Z0-9]+]] =
-// CHECK-SAME: %[[C0]] to %[[N2]] step %[[C16]]
-// CHECK-SAME: iter_args(%[[ARG10:.+]] = %[[ARG8]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[TILE_N2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2]]]
-// CHECK: %[[ST_LHS:.+]] = tensor.extract_slice %[[LHS]][0, %[[IV2]]]
-// CHECK-SAME: [%[[TILE_M_3]], %[[TILE_N2]]]
-// CHECK: %[[TILE_N3:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N3_2]]]
-// CHECK: %[[ST_ARG3:.+]] = tensor.extract_slice %[[ARG3]][%[[IV2]], %[[IV1]]]
-// CHECK-SAME: [%[[TILE_N2]], %[[TILE_N3]]]
-// CHECK: %[[M_4:.+]] = tensor.dim %[[ARG10]], %[[C0]]
-// CHECK: %[[ST_ARG4:.+]] = tensor.extract_slice %[[ARG10]][0, %[[IV1]]]
-// CHECK-SAME: [%[[M_4]], %[[TILE_N3]]]
-// CHECK: %[[ST_RESULT:.+]] = linalg.matmul
-// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion"
-// CHECK-SAME: ins(%[[ST_LHS]], %[[ST_ARG3]]
-// CHECK-SAME: : tensor<?x?xf32>, tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[ST_ARG4]] : tensor<?x?xf32>)
-// CHECK: %[[UPDATE1:.+]] = tensor.insert_slice %[[ST_RESULT]]
-// CHECK-SAME: into %[[ARG10]][0, %[[IV1]]] [%[[M_4]], %[[TILE_N3]]]
-// CHECK: scf.yield %[[UPDATE1]]
-// CHECK: }
-// CHECK: scf.yield %[[YIELD1]]
-// CHECK: }
-// CHECK: %[[UPDATE0:.+]] = tensor.insert_slice %[[YIELD0]] into
-// CHECK-SAME: %[[ARG6]][%[[IV0]], 0] [%[[TILE_M_2]], %[[N3]]]
-// CHECK: scf.yield %[[UPDATE0]]
-// CHECK: }
-// CHECK: return %[[RESULT]]
-
-// -----
-
-module {
- func.func @matmul_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
- %arg2: tensor<?x?xf32>) -> tensor<?x?xf32>{
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %0 = tensor.dim %arg2, %c0 : tensor<?x?xf32>
- %1 = tensor.dim %arg2, %c1 : tensor<?x?xf32>
- %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %3 = tensor.dim %2, %c0 : tensor<?x?xf32>
- %4 = tensor.dim %2, %c1 : tensor<?x?xf32>
- %5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
- %6 = linalg.generic
- {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"],
- __internal_linalg_transform__ = "transpose_fusion"}
- ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%5 : tensor<?x?xf32>) {
- ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
- %7 = arith.addf %arg3, %arg4 : f32
- linalg.yield %7 : f32
- } -> tensor<?x?xf32>
- return %6 : tensor<?x?xf32>
- }
-}
-// CHECK: func @matmul_plus_matmul
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]]
-// CHECK-SAME: iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}})
-// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
-// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]])
-// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]]
-// CHECK: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
-// CHECK: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
-// CHECK: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
-// CHECK: %[[LHS:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]]
-// CHECK-SAME: : tensor<?x?xf32>, tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[ST_ARG2]] : tensor<?x?xf32>)
-// CHECK: %[[ST_RESULT:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[LHS]] : tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[ST_ARG6]] : tensor<?x?xf32>)
-// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]]
-// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]]
-// CHECK: scf.yield %[[UPDATE]]
-// CHECK: scf.yield %[[YIELD]]
-// CHECK: return %[[RESULT]]
-
-// -----
-
-module {
- func.func @matmul_out_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
- %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %c0 = arith.constant 0.0 : f32
- %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %1 = linalg.matmul {__internal_linalg_transform__ = "out_fusion"}
- ins(%arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
- return %1 : tensor<?x?xf32>
- }
-}
-
-// CHECK-LABEL: func @matmul_out_fusion(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK: %[[C0:.*]] = arith.constant 0.0{{.*}} : f32
-// CHECK-NOT: fill
-// CHECK: scf.for %[[I:.*]]{{.*}}iter_args(%{{.*}} = %[[ARG0]]) -> (tensor<?x?xf32>) {
-// CHECK: scf.for %[[J:.*]]
-// CHECK: %[[ST:.*]] = tensor.extract_slice %[[ARG0]]
-// CHECK: %[[ST_FILL:.*]] = linalg.fill
-// CHECK-SAME: {__internal_linalg_transform__ = "after_out_fusion_producer"}
-// CHECK-SAME: ins(%[[C0]] : f32) outs(%[[ST]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[ST_MM_RES:.*]] = scf.for %[[K:.*]]{{.*}}iter_args(%[[BB:.*]] = %[[ST_FILL]]) -> (tensor<?x?xf32>) {
-// CHECK-NOT: fill
-// CHECK: %[[ST_FILL_SUB:.*]] = tensor.extract_slice %[[BB]][0, 0]
-// CHECK: %[[ST_MM_SUB:.*]] = linalg.matmul {__internal_linalg_transform__ = "after_out_fusion"} ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ST_FILL_SUB]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[ST_MM:.*]] = tensor.insert_slice %[[ST_MM_SUB]] into %[[BB]]
-// CHECK: scf.yield %[[ST_MM]] : tensor<?x?xf32>
-// CHECK: %[[MM:.*]] = tensor.insert_slice %[[ST_MM_RES]] into {{.*}}
-// CHECK: scf.yield %[[MM]] : tensor<?x?xf32>
-
-// -----
-
-module {
- func.func @generic_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
- %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %c0 = arith.constant 0.0 : f32
- %0 = linalg.generic {
- indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%c0 : f32)
- outs(%arg0: tensor<?x?xf32>) {
- ^bb(%0: f32, %1: f32) :
- linalg.yield %0 : f32
- } -> tensor<?x?xf32>
- %1 = linalg.matmul {__internal_linalg_transform__ = "out_fusion"}
- ins(%arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
- return %1 : tensor<?x?xf32>
- }
-}
// CHECK-SAME: outs(%[[INIT_TILE_2]] :
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
// CHECK scf.yield %[[INSERT]]
+
+// -----
+
+func.func @matmul_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+ %arg2: tensor<?x?xf32>) -> tensor<?x?xf32>{
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = tensor.dim %arg2, %c0 : tensor<?x?xf32>
+ %1 = tensor.dim %arg2, %c1 : tensor<?x?xf32>
+ %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %3 = tensor.dim %2, %c0 : tensor<?x?xf32>
+ %4 = tensor.dim %2, %c1 : tensor<?x?xf32>
+ %5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
+ %6 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"],
+ __internal_linalg_transform__ = "gemm_plus_gemm_fusion"}
+ ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%5 : tensor<?x?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
+ %7 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %7 : f32
+ } -> tensor<?x?xf32>
+ return %6 : tensor<?x?xf32>
+}
+// This fuses as expected, but the gemm operation is inlined twice. It should be CSE-d but isn't today.
+
+// CHECK: func @matmul_plus_matmul
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}})
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]])
+// CHECK-DAG: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
+// CHECK-DAG: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
+// CHECK-DAG: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[LHS:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]] :
+// CHECK-SAME: outs(%[[ST_ARG2]] :
+// CHECK-DAG: %[[ST_ARG0_1:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
+// CHECK-DAG: %[[ST_ARG1_1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
+// CHECK-DAG: %[[ST_ARG2_1:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[RHS:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[ST_ARG0_1]], %[[ST_ARG1_1]] :
+// CHECK-SAME: outs(%[[ST_ARG2_1]] :
+// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[ST_RESULT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[ST_ARG6]] :
+// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]]
+// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]]
+// CHECK: scf.yield %[[UPDATE]]
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @matmul_plus_transpose_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+ %arg2: tensor<?x?xf32>) -> tensor<?x?xf32>{
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = tensor.dim %arg2, %c0 : tensor<?x?xf32>
+ %1 = tensor.dim %arg2, %c1 : tensor<?x?xf32>
+ %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %3 = tensor.dim %2, %c0 : tensor<?x?xf32>
+ %4 = tensor.dim %2, %c1 : tensor<?x?xf32>
+ %5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
+ %6 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1, d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"],
+ __internal_linalg_transform__ = "gemm_plus_gemm_fusion"}
+ ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%5 : tensor<?x?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
+ %7 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %7 : f32
+ } -> tensor<?x?xf32>
+ return %6 : tensor<?x?xf32>
+}
+// CHECK: func @matmul_plus_transpose_matmul
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}})
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]])
+// CHECK-DAG: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
+// CHECK-DAG: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
+// CHECK-DAG: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[LHS:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]]
+// CHECK-SAME: : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ST_ARG2]] : tensor<?x?xf32>)
+// CHECK-DAG: %[[STR_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0]
+// CHECK-DAG: %[[STR_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]]
+// CHECK-DAG: %[[STR_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]], %[[IV0]]]
+// CHECK: %[[RHS:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[STR_ARG0]], %[[STR_ARG1]] :
+// CHECK-SAME: outs(%[[STR_ARG2]] :
+// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[ST_RESULT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[ST_ARG6]] :
+// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]]
+// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]]
+// CHECK: scf.yield %[[UPDATE]]
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @matmul_sequence_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+ %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>, %arg4: tensor<?x?xf32>,
+ %arg5: tensor<?x?xf32>, %arg6: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N0] * [N0, N1]
+ %1 = linalg.matmul ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N1] * [N1, N2]
+ %2 = linalg.matmul
+ {__internal_linalg_transform__ = "gemm_sequence_fusion"}
+ ins(%1, %arg5 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N2] * [N2, N3]
+ return %2 : tensor<?x?xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK: func @matmul_sequence_fusion(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[N0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[ORIG_GEMM1:.+]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] :
+// CHECK-DAG: %[[N1:.+]] = tensor.dim %[[ORIG_GEMM1]], %[[C1]]
+// CHECK-DAG: %[[ORIG_GEMM2:.+]] = linalg.matmul ins(%[[ORIG_GEMM1]], %[[ARG3]] :
+// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C0]]
+// CHECK-DAG: %[[N2:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C1]]
+// CHECK-DAG: %[[N3:.+]] = tensor.dim %[[ARG5]], %[[C1]]
+// CHECK: %[[R0:.+]] = scf.for %[[IV:[a-zA-Z0-9_]+]] =
+// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
+// CHECK-DAG: %[[TILE_M:.+]] = affine.min #[[MAP]](%[[IV]])[%{{.+}}, %[[M]]]
+// CHECK-DAG: %[[SLICE_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[TILE_M]], %[[N0]]]
+// CHECK-DAG: %[[SLICE_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
+// CHECK-DAG: %[[SLICE_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV]], 0] [%[[TILE_M]], %[[N1]]]
+// CHECK-DAG: %[[TILE_GEMM1:.+]] = linalg.matmul ins(%[[SLICE_ARG0]], %[[SLICE_ARG1]] :
+// CHECK-SAME: outs(%[[SLICE_ARG2]] :
+// CHECK-DAG: %[[SLICE_ARG3:.+]] = tensor.extract_slice %[[ARG3]][0, 0] [%[[N1]], %[[N2]]]
+// CHECK-DAG: %[[SLICE_ARG4:.+]] = tensor.extract_slice %[[ARG4]][%[[IV]], 0] [%[[TILE_M]], %[[N2]]]
+// CHECK-DAG: %[[TILE_GEMM2:.+]] = linalg.matmul ins(%[[TILE_GEMM1]], %[[SLICE_ARG3]] :
+// CHECK-SAME: outs(%[[SLICE_ARG4]] :
+// CHECK-DAG: %[[SLICE_ARG5:.+]] = tensor.extract_slice %[[ARG5]][0, 0] [%[[N2]], %[[N3]]]
+// CHECK-DAG: %[[SLICE_ARG6:.+]] = tensor.extract_slice %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]]
+// CHECK-DAG: %[[TILE_GEMM3:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[TILE_GEMM2]], %[[SLICE_ARG5]] :
+// CHECK-SAME: outs(%[[SLICE_ARG6]] :
+// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[TILE_GEMM3]] into %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]]
+// CHECK: scf.yield %[[UPDATE]]
using namespace mlir;
using namespace mlir::linalg;
-/// Use this to safely fill patterns for this test, since RewritePatternSet::add
-/// forwards Rvalues only to the first pattern.
-template <typename OpTy, LinalgTilingLoopType LoopType>
-static void fillFusionPattern(MLIRContext *context,
- const LinalgDependenceGraph &dependenceGraph,
- RewritePatternSet &patterns,
- const Twine &testCase,
- ArrayRef<int64_t> tileSizes,
- ArrayRef<int64_t> indicesToFuse) {
- patterns.add<LinalgTileAndFusePattern<OpTy>>(
- context, dependenceGraph,
- LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(LoopType),
- LinalgFusionOptions().setIndicesToFuse(indicesToFuse),
- LinalgTransformationFilter(
- StringAttr::get(context, testCase + "_fusion"),
- StringAttr::get(context, "after_" + testCase + "_fusion")),
- LinalgTransformationFilter(
- ArrayRef<StringAttr>(),
- StringAttr::get(context, "after_" + testCase + "_fusion_producer")),
- LinalgTransformationFilter(
- ArrayRef<StringAttr>(),
- StringAttr::get(context, "after_" + testCase + "_fusion_original")));
-}
-
-template <LinalgTilingLoopType LoopType>
-static void fillFusionPatterns(MLIRContext *context,
- const LinalgDependenceGraph &dependenceGraph,
- RewritePatternSet &patterns) {
- fillFusionPattern<Conv2DOp, LoopType>(context, dependenceGraph, patterns,
- /*testCase=*/"basic",
- /*tileSizes=*/{32, 64, 16},
- /*indicesToFuse=*/{2});
-
- auto fillMatmulPattern = [&](const Twine &testCase,
- ArrayRef<int64_t> indicesToFuse) {
- fillFusionPattern<MatmulOp, LoopType>(context, dependenceGraph, patterns,
- testCase, /*tileSizes=*/{32, 64, 16},
- indicesToFuse);
- };
- fillMatmulPattern(/*testCase=*/"basic",
- /*indicesToFuse=*/{2});
- fillMatmulPattern(/*testCase=*/"lhs",
- /*indicesToFuse=*/{0});
- fillMatmulPattern(/*testCase=*/"out",
- /*indicesToFuse=*/{2});
- fillMatmulPattern(/*testCase=*/"rhs",
- /*indicesToFuse=*/{1});
- fillMatmulPattern(/*testCase=*/"two_operand",
- /*indicesToFuse=*/{0, 2});
-
- fillFusionPattern<GenericOp, LoopType>(context, dependenceGraph, patterns,
- /*testCase=*/"transpose",
- /*tileSizes=*/{32, 64},
- /*indicesToFuse=*/{0, 1});
-}
-
-namespace {
-template <LinalgTilingLoopType LoopType>
-struct TestLinalgFusionTransforms
- : public PassWrapper<TestLinalgFusionTransforms<LoopType>,
- OperationPass<func::FuncOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransforms)
-
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
- scf::SCFDialect>();
- }
- TestLinalgFusionTransforms() = default;
- TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
-
- void runOnOperation() override {
- MLIRContext *context = &this->getContext();
- func::FuncOp funcOp = this->getOperation();
- RewritePatternSet fusionPatterns(context);
- Aliases alias;
- LinalgDependenceGraph dependenceGraph =
- LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
- fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
- }
-};
-
-struct TestLinalgFusionTransformsParallelLoops
- : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
- TestLinalgFusionTransformsParallelLoops)
-
- StringRef getArgument() const final {
- return "test-linalg-fusion-transform-patterns";
- }
- StringRef getDescription() const final {
- return "Test Linalg fusion transformation patterns by applying them "
- "greedily.";
- }
-};
-
-struct TestLinalgFusionTransformsLoops
- : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransformsLoops)
-
- StringRef getArgument() const final {
- return "test-linalg-tensor-fusion-transform-patterns";
- }
- StringRef getDescription() const final {
- return "Test Linalg on tensor fusion transformation "
- "patterns by applying them greedily.";
- }
-};
-
-struct TestLinalgFusionTransformsTiledLoops
- : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
- TestLinalgFusionTransformsTiledLoops)
-
- StringRef getArgument() const final {
- return "test-linalg-tiled-loop-fusion-transform-patterns";
- }
- StringRef getDescription() const final {
- return "Test Linalg on tensor fusion transformation "
- "patterns by applying them greedily.";
- }
-};
-} // namespace
-
static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
OpBuilder b(f);
DenseSet<Operation *> eraseSet;
} while (succeeded(fuseLinalgOpsGreedily(getOperation())));
}
};
-
-/// Pass to test tile and fuse of sequence of operations. Intended only for
-/// testing.
-struct TestLinalgTileAndFuseSequencePass
- : public PassWrapper<TestLinalgTileAndFuseSequencePass,
- OperationPass<func::FuncOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
- TestLinalgTileAndFuseSequencePass)
-
- StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; }
- StringRef getDescription() const final {
- return "Test Linalg tiling and fusion of a sequence of Linalg operations.";
- }
- TestLinalgTileAndFuseSequencePass() = default;
- TestLinalgTileAndFuseSequencePass(
- const TestLinalgTileAndFuseSequencePass &pass)
- : PassWrapper(pass){};
-
- ListOption<int64_t> tileSizes{*this, "tile-sizes",
- llvm::cl::desc("Tile sizes to use for ops")};
-
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
- scf::SCFDialect>();
- }
-
- void runOnOperation() override {
- func::FuncOp funcOp = getOperation();
- auto &blocks = funcOp.getBody().getBlocks();
- if (!llvm::hasSingleElement(blocks)) {
- return;
- }
- SmallVector<LinalgOp, 2> linalgOps =
- llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
- Aliases aliases;
- LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
- OpBuilder builder(funcOp.getContext());
- linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops;
- if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) {
- return linalgOp.hasTensorSemantics();
- }))
- loopType = LinalgTilingLoopType::Loops;
- Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
- builder, linalgOps, dependenceGraph,
- LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType));
- if (!tileAndFuseOps)
- return signalPassFailure();
- if (linalgOps.back().hasTensorSemantics()) {
- linalgOps.back().getOperation()->replaceAllUsesWith(
- tileAndFuseOps->fusedLoops.front());
- }
- for (auto op : linalgOps)
- if (op.hasBufferSemantics())
- op.erase();
- }
-};
-
} // namespace
namespace mlir {
namespace test {
-void registerTestLinalgFusionTransforms() {
- PassRegistration<TestLinalgFusionTransformsParallelLoops>();
-}
-void registerTestLinalgTensorFusionTransforms() {
- PassRegistration<TestLinalgFusionTransformsLoops>();
-}
-void registerTestLinalgTiledLoopFusionTransforms() {
- PassRegistration<TestLinalgFusionTransformsTiledLoops>();
-}
void registerTestLinalgGreedyFusion() {
PassRegistration<TestLinalgGreedyFusion>();
}
-void registerTestLinalgTileAndFuseSequencePass() {
- PassRegistration<TestLinalgTileAndFuseSequencePass>();
-}
} // namespace test
} // namespace mlir
addPatternForTiling<
TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
context, patterns, "gemm_interchange_fusion", {10, 20}, {1, 0});
+ addPatternForTiling<
+ TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
+ context, patterns, "gemm_plus_gemm_fusion", {10, 20});
+ addPatternForTiling<
+ TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
+ context, patterns, "gemm_sequence_fusion", {10});
return;
}
}
void registerTestLastModifiedPass();
void registerTestLinalgDecomposeOps();
void registerTestLinalgElementwiseFusion();
-void registerTestLinalgFusionTransforms();
-void registerTestLinalgTensorFusionTransforms();
-void registerTestLinalgTiledLoopFusionTransforms();
void registerTestLinalgGreedyFusion();
void registerTestLinalgHoisting();
-void registerTestLinalgTileAndFuseSequencePass();
void registerTestLinalgTransforms();
void registerTestLivenessPass();
void registerTestLoopFusion();
mlir::test::registerTestLastModifiedPass();
mlir::test::registerTestLinalgDecomposeOps();
mlir::test::registerTestLinalgElementwiseFusion();
- mlir::test::registerTestLinalgFusionTransforms();
- mlir::test::registerTestLinalgTensorFusionTransforms();
- mlir::test::registerTestLinalgTiledLoopFusionTransforms();
mlir::test::registerTestLinalgGreedyFusion();
mlir::test::registerTestLinalgHoisting();
- mlir::test::registerTestLinalgTileAndFuseSequencePass();
mlir::test::registerTestLinalgTransforms();
mlir::test::registerTestLivenessPass();
mlir::test::registerTestLoopFusion();