From 485190df95f98c51c3f4a4ab4db96127cdc9ce78 Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Fri, 15 Jul 2022 20:11:23 +0000 Subject: [PATCH] [mlir][Linalg] Deprecate `tileAndFuseLinalgOps` method and associated patterns. The `tileAndFuseLinalgOps` is a legacy approach for tiling + fusion of Linalg operations. Since it was also intended to work on operations with buffer operands, this method had fairly complex logic to make sure tile and fuse was correct even with side-effecting linalg ops. While complex, it still wasnt robust enough. This patch deprecates this method and thereby deprecating the tiling + fusion method for ops with buffer semantics. Note that the core transformation to do fusion of a producer with a tiled consumer still exists. The deprecation here only removes methods that auto-magically tried to tile and fuse correctly in presence of side-effects. The `tileAndFuseLinalgOps` also works with operations with tensor semantics. There are at least two other ways the same functionality exists. 1) The `tileConsumerAndFuseProducers` method. This does a similar transformation, but using a slightly different logic to automatically figure out the legal tile + fuse code. Note that this is also to be deprecated soon. 2) The prefered way uses the `TilingInterface` for tile + fuse, and relies on the caller to set the tiling options correctly to ensure that the generated code is correct. As proof that (2) is equivalent to the functionality provided by `tileAndFuseLinalgOps`, relevant tests have been moved to use the interface, where the test driver sets the tile sizes appropriately to generate the expected code. Differential Revision: https://reviews.llvm.org/D129901 --- .../mlir/Dialect/Linalg/Transforms/Transforms.h | 121 ------ mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 433 --------------------- mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 111 ------ mlir/test/Dialect/Linalg/fusion-pattern.mlir | 307 --------------- mlir/test/Dialect/Linalg/fusion-sequence.mlir | 252 ------------ .../test/Dialect/Linalg/fusion-tensor-pattern.mlir | 193 --------- .../tile-and-fuse-using-interface.mlir | 171 ++++++++ .../Dialect/Linalg/TestLinalgFusionTransforms.cpp | 193 --------- .../TilingInterface/TestTilingInterface.cpp | 6 + mlir/tools/mlir-opt/mlir-opt.cpp | 8 - 10 files changed, 177 insertions(+), 1618 deletions(-) delete mode 100644 mlir/test/Dialect/Linalg/fusion-pattern.mlir delete mode 100644 mlir/test/Dialect/Linalg/fusion-sequence.mlir delete mode 100644 mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 68a41fd..c81a7ee 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -169,71 +169,6 @@ void peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res, ArrayRef peeledLoops, LinalgTilingLoopType loopType); -/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This -/// proceeds as follows: -/// - Find outer parallel loops in these ops that can be fused. -/// - Tile fusable outer parallel loops of the last operation in the sequence. -/// - Fuse the remaining operations with the tiled operation -/// -/// For example, consider the sequence of matmul below -/// -/// linalg.matmul ins(%arg0, %arg1 : memref<256x32xf32>, memref<32x32xf32>) -/// outs(%arg2 : memref<256x32xf32>) -/// linalg.matmul ins(%arg2, %arg3 : memref<256x32xf32>, memref<32x32xf32>) -/// outs(%arg4 : memref<256x32xf32>) -/// -/// It is legal to fuse the RAW dependence (through %arg2) by only fusing the -/// matmuls row-wise. For example, the fused computation for the above is shown -/// below. The outer `scf.parallel` loop is the "fused" loop obtained by tiling -/// along the rows of the matrix. The entire rows of the first matmul operation -/// need to be computed before they can be used for the second matmul. The -/// second matmul is further tiled (similar to normal tiling). -/// -/// #map0 = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)> -/// #map1 = affine_map<(d0, d1) -> (d0 * 32 + d1)> -/// scf.parallel (%arg5) = (%c0) to (%c256) step (%c16) { -/// %0 = subview %arg2[%arg5, 0] [16, 32] [1, 1] -/// : memref<256x32xf32> to memref<16x32xf32, #map0> -/// %1 = subview %arg4[%arg5, 0] [16, 32] [1, 1] -/// : memref<256x32xf32> to memref<16x32xf32, #map0> -/// %2 = subview %arg0[%arg5, 0] [16, 32] [1, 1] -/// : memref<256x32xf32> to memref<16x32xf32, #map0> -/// %3 = subview %arg1[0, 0] [32, 32] [1, 1] -/// : memref<32x32xf32> to memref<32x32xf32, #map1> -/// %4 = subview %arg3[0, 0] [32, 32] [1, 1] -/// : memref<32x32xf32> to memref<32x32xf32, #map1> -/// linalg.matmul -/// ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>) -/// outs(%0 : memref<16x32xf32, #map0>) -/// linalg.matmul -/// ins(%0, %4 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>) -/// outs(%1 : memref<16x8xf32, #map0>) -/// } -/// -/// `tilingOptions` are used to tile the corresponding operation in `ops` (the -/// size of the former should be same as size of the latter. Based on how -/// tile+fuse is implemented, the fused loops are generated based on the last -/// operation in the sequence. For example, the tile sizes for the fused loops -/// is obtained from `tilingOptions.back()`. The following tiling options are -/// handled differently in tile+fuse (compared to tile only) -/// - Interchange of the tiling loops is not supported right now. -/// - Only the fused loops are distributed. -struct TiledAndFusedLinalgOps { - /// Operation obtained by tiling the last operation in sequence of `ops` - /// passed to `tileAndFuseLinalgOps`. - LinalgOp op; - /// The dimension of the loops that are fused. - std::set fusedLoopDims; - /// The generated fused operations (created within the fused loops). - SmallVector fusedProducers; - /// The fused loop generated. - SmallVector fusedLoops; -}; -FailureOr -tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef ops, - const LinalgDependenceGraph &dependenceGraph, - const LinalgTilingOptions &tilingOptions); - /// Interchange the `iterator_types` and `iterator_maps` dimensions and adapts /// the index accesses of `op`. This is an in-place transformation controlled by /// `interchangeVector`. An empty vector is interpreted as the identity @@ -847,62 +782,6 @@ private: LinalgTransformationFilter filter; }; -struct LinalgFusionOptions { - /// List of operands indices to use for fusion. - llvm::SmallSet indicesToFuse = {}; - LinalgFusionOptions &setIndicesToFuse(ArrayRef operands) { - indicesToFuse.insert(operands.begin(), operands.end()); - return *this; - } -}; - -struct LinalgBaseTileAndFusePattern : public RewritePattern { - LinalgBaseTileAndFusePattern( - StringRef opName, MLIRContext *context, - const LinalgDependenceGraph &dependenceGraph, - LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions, - LinalgTransformationFilter f = LinalgTransformationFilter(), - LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(), - LinalgTransformationFilter originalOpMarker = - LinalgTransformationFilter(), - PatternBenefit benefit = 1); - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override; - -private: - /// Dependence graph needed for fusion. - const LinalgDependenceGraph &dependenceGraph; - /// Options to control tiling. - LinalgTilingOptions tilingOptions; - /// Options to control fusion. - LinalgFusionOptions fusionOptions; - /// Marker to control application of the pattern. - LinalgTransformationFilter filter; - /// Marker set on the fused op after tile and fuse. - LinalgTransformationFilter fusedOpMarker; - /// The dependenceGraph is not modifiable, i.e. if the Linalg operations used - /// to build the dependence graph changes then the dependenceGraph needs to be - /// recomputed right now. To not invalidate the dependenceGraph as - /// transformation happens, the original producer can be tagged with a filter - /// that can be later used to delete the original operations. - LinalgTransformationFilter originalOpMarker; -}; - -template -struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern { - LinalgTileAndFusePattern( - MLIRContext *context, const LinalgDependenceGraph &dependenceGraph, - LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions, - LinalgTransformationFilter f = LinalgTransformationFilter(), - LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(), - LinalgTransformationFilter originalOpMarker = - LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : LinalgBaseTileAndFusePattern( - OpTy::getOperationName(), context, dependenceGraph, tilingOptions, - fusionOptions, f, fusedOpMarker, originalOpMarker, benefit) {} -}; - /// /// Linalg tile and fuse tensor ops pattern. /// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index a707691..91089a3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -460,436 +460,3 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult, consumerOpOperand.set(def); return FusionInfo{cast(producerOpResult.getOwner()), fusedProducer}; } - -/// Prune all dimensions that are of reduction iterator type from `map`. -static AffineMap pruneReductionDimsFromMap(ArrayRef iteratorTypes, - AffineMap map) { - llvm::SmallBitVector projectedDims(iteratorTypes.size()); - for (const auto &attr : llvm::enumerate(iteratorTypes)) { - if (!isParallelIterator(attr.value())) - projectedDims.set(attr.index()); - } - return getProjectedMap(map, projectedDims); -} - -/// Returns the mapping from iterations in the consumer that write to the same -/// location as the iterations in the producer. To do so use -/// - indexing map of the fused view in the consumer : consumerIndexMap -/// - indexing map of the fused view in the producer : producerIndexMap -/// consumerLoopToProducerLoop = -/// inverse(producerIndexMap).compose(consumerIndexMap) -static FailureOr getConsumerLoopToProducerLoopMap( - LinalgDependenceGraph::LinalgDependenceGraphElem dependence) { - auto producer = dyn_cast(dependence.getDependentOp()); - if (!producer) - return failure(); - - Optional producerIndexingMap = - dependence.getDependentOpViewIndexingMap(); - Optional consumerIndexingMap = - dependence.getIndexingOpViewIndexingMap(); - if (!producerIndexingMap || !consumerIndexingMap) - return failure(); - - AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap( - producer.iterator_types().getValue(), *producerIndexingMap); - if (!prunedProducerIndexingMap.isPermutation()) - return failure(); - - if (consumerIndexingMap->getNumResults() != - prunedProducerIndexingMap.getNumResults()) - return failure(); - - LLVM_DEBUG({ - llvm::dbgs() << "\t producerMap : "; - producerIndexingMap->print(llvm::dbgs()); - llvm::dbgs() << " pruned : "; - prunedProducerIndexingMap.print(llvm::dbgs()); - llvm::dbgs() << "\n"; - llvm::dbgs() << "\t consumerMap : "; - consumerIndexingMap->print(llvm::dbgs()); - llvm::dbgs() << "\n"; - }); - - AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap); - if (!invProducerIndexMap) - return failure(); - - return invProducerIndexMap.compose(*consumerIndexingMap); -} - -/// Given a projected permutation `map`, returns true if the map changes the -/// order in which the fused loop dimension appear. -static bool doesTransposeAccess(AffineMap map, - const std::set &fusableLoops) { - Optional lastFusableLoop; - for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) { - return expr.cast().getPosition(); - })) { - if (!fusableLoops.count(pos)) - continue; - if (!lastFusableLoop) { - lastFusableLoop = pos; - continue; - } - if (pos <= *lastFusableLoop) - return true; - lastFusableLoop = pos; - } - return false; -} - -/// Returns the positions of the loop in `op` that can be tiled based on the -/// operations that are to be fused with it. For example, in a -/// -/// linalg.matmul ins(%a, %b : ...) outs(%c : ...) -/// -/// if the producer of %a needs to be fused with this op, only the `i` loop of -/// the matmul can be tiled while fusing. If producer of %a, and %b are to be -/// fused, then no loops can be tiled while fusing. The conditions used are: -/// 1. Only parallel loops can be used for tile + fuse. Find the number of -/// common outer parallel loops between the op and its producers being fused. -/// 2. Of the parallel loops only some can be fused. Only those loops can be -/// fused such where the fusable loops iteration space only touches one tile -/// of the fused operation. This is because the producer (which is writing -/// the fused subview) has update semantics. -/// -/// Since an inverse computation is needed, we need to consider the projection -/// of the producerIndexMap w.r.t the parallel loops. The actual fusable loops -/// are the dimensions of the consumerLoopToProducerLoop map that correspond to -/// parallel loops and appear in the result of the map -/// -/// Example 1: -/// linalg.fill(%cst, %c) -/// linalg.matmul ins(%a, %b) outs(%c) -/// Number of parallel loops : 2 -/// producerIndexMap = affine_map<(i, j) ->(i , j)> -/// consumerIndexMap = affine_map<(i, j, k) -> (i, j)> -/// consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)> -/// Fused dimensions : i, j -/// -/// Example 2: -/// linalg.matmul ins(%a, %b) outs(%c) -/// linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ... -/// iterator_types = ["parallel", "parallel"]} -/// ins(%c) ... -/// -/// Number of parallel loops = 2: -/// producerIndexMap (projected to parallel loops) = -/// affine_map<(i, j) -> (i, j)> -/// consumerLoopToProducerLoop2 = affine_map<(i, j) -> (j, i)> -/// Fused dimensions : i, j -/// -/// Example 3: -/// memref.copy(%s, %b) -/// linalg.matmul ins(%a, %b) outs(%c) -/// -/// Number of parallel loops = 2 -/// produceIndexMap : affine_map<(i, j) -> (i, j)> -/// consumerLoopToProduceLoops = affine_map<(i, j, k) -> (k, j)> -/// submap with only parallel loops = affine_map<(i, j) -> (j)> -/// Fused dimensions : j -static std::set -collectFusableLoops(ArrayRef ops, - const FusableOpDependencesTy &fusableDependences) { - assert(!ops.empty()); - auto getNumOuterParallelLoops = [](LinalgOp linalgOp) { - return linalgOp.iterator_types() - .getValue() - .take_while([](Attribute attr) -> bool { - return attr.cast().getValue() == - getParallelIteratorTypeName(); - }) - .size(); - }; - - size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back()); - for (auto op : ops.drop_back()) { - numOuterParallelLoops = - std::min(numOuterParallelLoops, getNumOuterParallelLoops(op)); - } - - std::set fusableLoops; - auto range = llvm::seq(0, numOuterParallelLoops); - fusableLoops.insert(range.begin(), range.end()); - - for (auto op : reverse(ops)) { - for (auto dependence : fusableDependences.lookup(op)) { - LLVM_DEBUG({ - llvm::dbgs() << "\t fusable :"; - for (unsigned i : fusableLoops) - llvm::dbgs() << " " << i; - llvm::dbgs() << "\n"; - }); - - Optional consumerLoopToProducerLoop = - getConsumerLoopToProducerLoopMap(dependence); - if (!consumerLoopToProducerLoop) { - op.emitRemark("failed to get map from consumer loop to producer loop"); - return {}; - } - // todo: This condition is only an implementation limitation. When fusing - // the operation, if the accesses in the producer/consumer are transposes - // of each other, the loop bounds for the tiled producer can be - // manipulated accordingly. This requires some additional bookkeeping in - // the implementation of tile+fuse that is deferred to later. - if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) { - op.emitRemark("unhandled fusion when fusion requires permutation"); - return {}; - } - - std::set candidates; - for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) { - unsigned position = expr.cast().getPosition(); - if (fusableLoops.count(position)) - candidates.insert(position); - } - LLVM_DEBUG({ - llvm::dbgs() << "\t candidates :"; - for (unsigned i : candidates) - llvm::dbgs() << " " << i; - llvm::dbgs() << "\n"; - }); - if (candidates.empty()) - return {}; - std::swap(candidates, fusableLoops); - } - } - - return fusableLoops; -} - -/// Find all dependences that are fusable. -FusableOpDependencesTy mlir::linalg::findAllFusableDependences( - ArrayRef ops, const LinalgDependenceGraph &dependenceGraph) { - FusableOpDependencesTy fusableDependences; - DenseMap> fusedProducerIndexingMap; - for (LinalgOp op : reverse(ops)) { - for (OpOperand *opOperand : op.getInputAndOutputOperands()) { - Optional - fusableDependence = findFusableProducer(*opOperand, dependenceGraph); - if (!fusableDependence) - continue; - LinalgOp producerOp = - dyn_cast(fusableDependence->getDependentOp()); - if (!producerOp) - continue; - // Do not fuse dependences that are to operations not in the same basic - // block. This avoid moving fused operations across loops that might - // themselves carry dependency making the fusion illegal. - if (producerOp->getBlock() != op->getBlock()) - continue; - - // Make sure that the indexing map of the view used for fusion in the - // producer is a projected permutation. - Optional producerMap = - fusableDependence->getDependentOpViewIndexingMap(); - Optional consumerMap = - fusableDependence->getIndexingOpViewIndexingMap(); - assert( - consumerMap && - "unable to find indexing map of operand/result of indexing OpView"); - fusedProducerIndexingMap[producerOp.getOperation()].push_back( - *consumerMap); - if (!producerMap || !producerMap->isProjectedPermutation() || - !consumerMap->isProjectedPermutation()) - continue; - - fusableDependences[producerOp.getOperation()].push_back( - *fusableDependence); - } - } - // TODO: Currently fusion would not be legal if the fusable dependence is to - // the same producer but different indexing map in the consumer. Fix this, but - // in the meanwhile disallow such a fusion. - for (auto useIndexingMapsList : fusedProducerIndexingMap) { - AffineMap map1 = useIndexingMapsList.second.front(); - for (AffineMap map2 : - ArrayRef(useIndexingMapsList.second).drop_front()) { - if (map1 != map2) { - fusableDependences.erase(useIndexingMapsList.first); - break; - } - } - } - return fusableDependences; -} - -/// Tile the fused loops in the root operation, by setting the tile sizes for -/// all other loops to zero (those will be tiled later). -static FailureOr -tileRootOperation(OpBuilder &b, LinalgOp op, ArrayRef tileSizeVector, - const LinalgTilingOptions &options, - const std::set &fusedLoops) { - SmallVector tileSizes(tileSizeVector.begin(), tileSizeVector.end()); - auto zero = b.create(op.getLoc(), 0); - for (unsigned i = 0, e = tileSizes.size(); i != e; ++i) - if (!fusedLoops.count(i)) - tileSizes[i] = zero; - LinalgTilingOptions tileFusedLoopsOptions = options; - tileFusedLoopsOptions.setTileSizes(tileSizes); - // TODO: Propagate RewriterBase everywhere. - IRRewriter rewriter(b); - return tileLinalgOp(rewriter, op, tileFusedLoopsOptions); -} - -/// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected -/// to be a tiled operation such that it is valid to fuse all operations in -/// `fusionCandidates`, i.e. move the operation within the inter-tile loops of -/// `tiledOp`. -static SmallVector -fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp, - ArrayRef fusionCandidates, - const FusableOpDependencesTy &fusableDependences, - const std::set &fusedLoops) { - LinalgOp tiledOp = tiledLinalgOp.op; - OpBuilder::InsertionGuard guard(b); - b.setInsertionPoint(tiledOp); - - DenseMap fusedLoopsAndRanges; - for (unsigned loop : fusedLoops) { - ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true); - fusedLoopsAndRanges[loop] = getRangeFromOperandShape( - b, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension); - } - - SmallVector fusedOps(fusionCandidates.size()); - DenseMap origOpToFusedOp; - origOpToFusedOp[rootOp.getOperation()] = tiledOp; - for (const auto &candidate : enumerate(llvm::reverse(fusionCandidates))) { - LinalgOp origOp = candidate.value(); - LinalgOp fusedOp = fuse(b, origOp, fusedLoopsAndRanges); - origOpToFusedOp[origOp.getOperation()] = fusedOp; - fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp; - - // Prepare the builder for the next insertion point. - auto guard = llvm::make_scope_exit([&]() { b.setInsertionPoint(fusedOp); }); - if (!origOp.hasTensorSemantics()) - continue; - - // If the producer consumer operations are linalg operations on tensors, the - // dependence is due to value produced (as a return tensor) by the producer - // and used in the consumer. The returned value of the fused op needs to be - // made the operand of the tiled/fused consumer operation. By construction - // the value returned by the producer is the value used by the consumer. - for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) { - if (dependence.dependenceType != - LinalgDependenceGraph::DependenceType::RAW) - continue; - - unsigned resultIndex = dependence.getDependentOpViewResultNum().value(); - LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp()); - if (!consumer) - continue; - - Value replacementValue = fusedOp.getOperation()->getResult(resultIndex); - consumer.getOperation()->setOperand( - dependence.getIndexingOpViewOperandNum().value(), replacementValue); - } - - // At this point, all Linalg uses of the tensors produced by `origOp` have - // been replaced. However, there may still be "output tensor"-like uses - // coming from WAW dependencies. - // All these uses are iter_args of the outermost loop (TODO: add a check). - // Such iter_args uses serve 2 purposes: - // 1. give a shape to the output - // 2. encode destructive updates that may be inplaceable by bufferization. - // To keep the second type of information while letting the unfused op die - // unused, we need to forward the producer output operand. - if (auto forOp = dyn_cast(tiledLinalgOp.loops.front())) { - for (auto &operand : forOp.getIterOpOperands()) { - if (auto opResult = operand.get().dyn_cast()) { - if (opResult.getOwner() == origOp) { - Value output = - origOp.getOutputOperand(opResult.getResultNumber())->get(); - assert(output.getType().isa()); - operand.set(output); - } - } - } - } - } - return fusedOps; -} - -static FailureOr -tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef ops, - const LinalgDependenceGraph &dependenceGraph, - const LinalgTilingOptions &tilingOptions) { - if (ops.size() < 2) - return failure(); - LinalgOp rootOp = ops.back(); - if (!llvm::all_of( - ops, - [](LinalgOp linalgOp) { return linalgOp.hasBufferSemantics(); }) && - !llvm::all_of(ops, [](LinalgOp linalgOp) { - return linalgOp.hasTensorSemantics(); - })) { - rootOp.emitError( - "unable to fuse operations that have tensor semantics with operations " - "that have buffer semantics and viceversa."); - return failure(); - } - // TODO: Support interchange with tile + fuse. This might actually help do - // better fusion. - if (!tilingOptions.interchangeVector.empty()) { - rootOp.emitRemark("unable to handle tile and fuse with interchange"); - return failure(); - } - - OpBuilder::InsertionGuard guard(b); - b.setInsertionPoint(rootOp); - - // Find all the producers. - LLVM_DEBUG(llvm::dbgs() << "findAllFusableDependences\n"); - FusableOpDependencesTy fusableDependences = - findAllFusableDependences(ops, dependenceGraph); - if (fusableDependences.empty()) { - LLVM_DEBUG(llvm::dbgs() << "no fusable dependencies found\n"); - return failure(); - } - - TiledAndFusedLinalgOps ret; - // Find the loops that can be tiled and fused. - LLVM_DEBUG(llvm::dbgs() << "collectFusableLoops\n"); - ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences); - - // If there are no fusable dependences or there are no tile+fusable loops, - // just return. - if (ret.fusedLoopDims.empty()) { - LLVM_DEBUG(llvm::dbgs() << "no fusable loops found\n"); - return failure(); - } - - // Tile the fused loops in the last operation in the list. - SmallVector tileSizeVector = - tilingOptions.tileSizeComputationFunction(b, rootOp); - FailureOr tiledRootOp = tileRootOperation( - b, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims); - if (failed(tiledRootOp)) { - rootOp.emitRemark("failed to tile the fused loops"); - return failure(); - } - ret.op = tiledRootOp->op; - ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); - - // Fuse the other operations into the fused inter-tile loops produced above. - ret.fusedProducers = fuseOperations(b, rootOp, *tiledRootOp, ops.drop_back(), - fusableDependences, ret.fusedLoopDims); - - return ret; -} - -FailureOr -mlir::linalg::tileAndFuseLinalgOps(OpBuilder &b, ArrayRef ops, - const LinalgDependenceGraph &dependenceGraph, - const LinalgTilingOptions &tilingOptions) { - switch (tilingOptions.loopType) { - case LinalgTilingLoopType::Loops: - case LinalgTilingLoopType::ParallelLoops: - case LinalgTilingLoopType::TiledLoops: - return tileAndFuseLinalgOpsImpl(b, ops, dependenceGraph, tilingOptions); - default:; - } - return failure(); -} diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index f90a7c0..582c2ce 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -350,117 +350,6 @@ void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res, } } -static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) { - if (tiledOp.loops.empty()) - return tiledOp.op.getOperation()->getResults(); - return tiledOp.loops.front()->getResults(); -} - -static ValueRange -getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) { - if (tiledAndFusedOp.fusedLoops.empty()) - return tiledAndFusedOp.op.getOperation()->getResults(); - return tiledAndFusedOp.fusedLoops.front()->getResults(); -} - -mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern( - StringRef opName, MLIRContext *context, - const LinalgDependenceGraph &dependenceGraph, - LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions, - LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker, - LinalgTransformationFilter originalOpMarker, PatternBenefit benefit) - : RewritePattern(opName, benefit, context, {}), - dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)), - fusionOptions(std::move(fusionOptions)), filter(std::move(f)), - fusedOpMarker(std::move(fusedOpMarker)), - originalOpMarker(std::move(originalOpMarker)) {} - -LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite( - Operation *op, PatternRewriter &rewriter) const { - LinalgOp linalgOp = dyn_cast(op); - // TODO: remove hasIndexSemantics check once index ops are supported. - if (!linalgOp || linalgOp.hasIndexSemantics()) - return failure(); - if (failed(filter.checkAndNotify(rewriter, linalgOp))) - return failure(); - - DenseSet producers; - producers.insert(linalgOp); - for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) { - Optional operandNumber = dependence.getIndexingOpViewOperandNum(); - // When looking at dependences into, indexingOp is always OpOperand. We - // could assert, but continue if this is not the case. - if (!operandNumber) - continue; - if (!fusionOptions.indicesToFuse.count(*operandNumber)) - continue; - if (isa(dependence.getDependentOp())) - producers.insert(dependence.getDependentOp()); - } - - SmallVector fusionOps; - for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie; - ++it) { - auto producerLinalgOp = dyn_cast(&(*it)); - if (producerLinalgOp && producers.count(producerLinalgOp)) - fusionOps.push_back(producerLinalgOp); - } - fusionOps.push_back(linalgOp); - - SmallVector tileSizes = - tilingOptions.tileSizeComputationFunction(rewriter, op); - LinalgTilingOptions instanceTilingOptions = tilingOptions; - instanceTilingOptions.setTileSizes(tileSizes); - Optional tiledAndFusedOps = tileAndFuseLinalgOps( - rewriter, fusionOps, dependenceGraph, instanceTilingOptions); - if (!tiledAndFusedOps) - return failure(); - - // Tile the unfused loops; - SmallVector unfusedLoopTileSizes; - Value zero = rewriter.create(op->getLoc(), 0); - for (const auto &tileSize : enumerate(tileSizes)) { - if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index())) - unfusedLoopTileSizes.push_back(zero); - else - unfusedLoopTileSizes.push_back(tileSize.value()); - } - // Tile the loop only if there is a non-zero tile size. - if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops()) - unfusedLoopTileSizes.resize(linalgOp.getNumLoops()); - if (llvm::any_of(unfusedLoopTileSizes, [](Value val) { - if (auto cst = val.getDefiningOp()) - return cst.value() != 0; - return true; - })) { - LinalgTilingOptions unfusedTilingOptions = tilingOptions; - unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes); - FailureOr unfusedTiledOp = - tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions); - if (failed(unfusedTiledOp)) - return failure(); - rewriter.replaceOp(tiledAndFusedOps->op, - getTiledOpResult(unfusedTiledOp.value())); - tiledAndFusedOps->op = unfusedTiledOp->op; - } - op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.value())); - - filter.replaceLinalgTransformationFilter(rewriter, - tiledAndFusedOps->op.getOperation()); - for (auto fusedOp : tiledAndFusedOps->fusedProducers) { - fusedOpMarker.replaceLinalgTransformationFilter(rewriter, - fusedOp.getOperation()); - } - for (auto origProducerOp : ArrayRef(fusionOps).drop_back()) { - originalOpMarker.replaceLinalgTransformationFilter( - rewriter, origProducerOp.getOperation()); - } - rewriter.updateRootInPlace(op, [&]() { - originalOpMarker.replaceLinalgTransformationFilter(rewriter, op); - }); - return success(); -} - /// Linalg tiling pattern. mlir::linalg::LinalgTilingPattern::LinalgTilingPattern( MLIRContext *context, LinalgTilingOptions options, diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir deleted file mode 100644 index 787eff6..0000000 --- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir +++ /dev/null @@ -1,307 +0,0 @@ -// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file | FileCheck %s - -module { - func.func @basic_fusion(%arg0: memref, %arg1: memref, - %arg2: memref) { - %cst = arith.constant 0.000000e+00 : f32 - linalg.fill ins(%cst : f32) outs(%arg2 : memref) - linalg.matmul {__internal_linalg_transform__ = "basic_fusion"} - ins(%arg0, %arg1 : memref, memref) - outs(%arg2 : memref) - return - } -} - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 32)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 64)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)> -// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 32)> -// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 64)> -// CHECK: func @basic_fusion -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index -// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index -// CHECK-DAG: %[[CST:.+]] = arith.constant 0.0{{.*}} : f32 -// CHECK-DAG: linalg.fill -// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original" -// CHECK-SAME: ins(%[[CST]]{{.*}}outs(%[[ARG2]] -// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]] -// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]] -// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = -// CHECK-SAME: to (%[[M]], %[[N]]) -// CHECK-SAME: step (%[[C32]], %[[C64]]) { -// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] -// CHECK: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]] -// CHECK: %[[SV1:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[K]]] -// CHECK: %[[K_2:.+]] = memref.dim %[[ARG1]], %[[C0]] -// CHECK: %[[TILE_N:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N]]] -// CHECK: %[[SV2:.+]] = memref.subview %[[ARG1]][0, %[[IV1]]] -// CHECK-SAME: %[[K_2]], %[[TILE_N]] -// CHECK: %[[SV3:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]] -// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]] -// CHECK: %[[M_2:.+]] = memref.dim %[[ARG2]], %[[C0]] -// CHECK: %[[N_2:.+]] = memref.dim %[[ARG2]], %[[C1]] -// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M_2]], %[[M]]] -// CHECK: %[[TILE_N_3:.+]] = affine.min #[[MAP5]](%[[IV1]])[%[[N_2]], %[[N]]] -// CHECK: %[[SV3_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]] -// CHECK-SAME: [%[[TILE_M_3]], %[[TILE_N_3]]] -// CHECK: linalg.fill -// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer" -// CHECK-SAME: ins(%[[CST]]{{.*}}outs(%[[SV3_2]] -// CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] { -// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]] -// CHECK: %[[SV4:.+]] = memref.subview %[[SV1]][0, %[[IV2]]] -// CHECK-SAME: [%[[TILE_M]], %[[TILE_K]]] -// CHECK: %[[SV5:.+]] = memref.subview %[[SV2]][%[[IV2]], 0] -// CHECK-SAME: [%[[TILE_K]], %[[TILE_N]]] -// CHECK: linalg.matmul -// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion" -// CHECK-SAME: ins(%[[SV4]], %[[SV5]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV3]] : memref) -// CHECK: } -// CHECK: } -// CHECK: linalg.matmul -// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original" - -// ----- - -module { - func.func @matmul_fusion(%arg0: memref, %arg1: memref, - %arg2: memref, %arg3: memref, - %arg4: memref) { - linalg.matmul ins(%arg0, %arg1 : memref, memref) - outs(%arg2 : memref) - linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"} - ins(%arg2, %arg3 : memref, memref) - outs(%arg4 : memref) - return - } -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 32)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 64)> -// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 32)> -// CHECK: func @matmul_fusion -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: memref -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index -// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index -// CHECK: linalg.matmul -// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_original" -// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG2]], %[[C0]] -// CHECK: scf.parallel (%[[IV0:.+]]) = -// CHECK-SAME: (%[[C0]]) to (%[[M]]) step (%[[C32]]) { -// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] -// CHECK: %[[K2:.+]] = memref.dim %[[ARG2]], %[[C1]] -// CHECK: %[[SV1:.+]] = memref.subview %[[ARG2]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[K2]]] -// CHECK: %[[N:.+]] = memref.dim %[[ARG4]], %[[C1]] -// CHECK: %[[SV2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[N]]] -// CHECK: %[[M_3:.+]] = memref.dim %[[ARG0]], %[[C0]] -// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M_3]], %[[M]]] -// CHECK: %[[K1:.+]] = memref.dim %[[ARG0]], %[[C1]] -// CHECK: %[[SV3:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_3]], %[[K1]]] -// CHECK: %[[SV1_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_3]], %[[K2]]] -// CHECK: linalg.matmul -// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer" -// CHECK-SAME: ins(%[[SV3]], %[[ARG1]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV1_2]] : memref) -// CHECK: %[[N_2:.+]] = memref.dim %[[ARG3]], %[[C1]] -// CHECK: scf.parallel (%[[IV1:.+]]) = -// CHECK-SAME: (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) { -// CHECK-NEXT: scf.for %[[IV2:.+]] = %[[C0]] to %[[K2]] step %[[C16]] { -// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K2]]] -// CHECK: %[[SV6:.+]] = memref.subview %[[SV1]][0, %[[IV2]]] -// CHECK-SAME: [%[[TILE_M]], %[[TILE_K]]] -// CHECK: %[[TILE_N:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N_2]]] -// CHECK: %[[SV7:.+]] = memref.subview %[[ARG3]][%[[IV2]], %[[IV1]]] -// CHECK-SAME: [%[[TILE_K]], %[[TILE_N]]] -// CHECK: %[[SV8:.+]] = memref.subview %[[SV2]][0, %[[IV1]]] -// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]] -// CHECK: linalg.matmul -// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion" -// CHECK-SAME: ins(%[[SV6]], %[[SV7]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV8]] : memref) -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: linalg.matmul -// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_original" - -// ----- - -module { - func.func @matmul_plus_matmul(%arg0: memref, %arg1: memref, - %arg2: memref) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = memref.dim %arg2, %c0 : memref - %1 = memref.dim %arg2, %c1 : memref - %2 = memref.alloc(%0, %1) : memref - linalg.matmul ins(%arg0, %arg1 : memref, memref) - outs(%2 : memref) - linalg.generic - {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"], - __internal_linalg_transform__ = "transpose_fusion"} - ins(%2, %2 : memref, memref) - outs(%arg2 : memref) { - ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) : - %3 = arith.addf %arg3, %arg4 : f32 - linalg.yield %3 : f32 - } - return - } -} -// CHECK: func @matmul_plus_matmul -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref -// CHECK: %[[T2:.+]] = memref.alloc(%{{.*}}, %{{.*}}) : memref -// CHECK: linalg.matmul -// CHECK-SAME: after_transpose_fusion_original -// CHECK: scf.parallel (%[[ARG3:[a-zA-Z0-9_]+]], %[[ARG4:.[a-zA-Z0-9_]+]]) -// CHECK: %[[T5:.+]] = memref.subview %[[T2]][%[[ARG3]], %[[ARG4]]] -// CHECK: %[[T6:.+]] = memref.subview %[[ARG2]][%[[ARG3]], %[[ARG4]]] -// CHECK: %[[T8:.+]] = memref.subview %[[ARG0]][%[[ARG3]], 0] -// CHECK: %[[T9:.+]] = memref.subview %[[ARG1]][0, %[[ARG4]]] -// CHECK: %[[T10:.+]] = memref.subview %[[T2]][%[[ARG3]], %[[ARG4]]] -// CHECK: linalg.matmul -// CHECK-SAME: after_transpose_fusion_producer -// CHECK-SAME: ins(%[[T8]], %[[T9]] -// CHECK-SAME: outs(%[[T10]] -// CHECK-NOT: linalg.matmul -// CHECK: linalg.generic -// CHECK-SAME: ins(%[[T5]], %[[T5]] -// CHECK-SAME: outs(%[[T6]] -// CHECK-SAME: after_transpose_fusion - -// ----- - -module { - func.func @matmul_plus_transpose_matmul(%arg0: memref, - %arg1: memref, - %arg2: memref) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = memref.dim %arg2, %c0 : memref - %1 = memref.dim %arg2, %c1 : memref - %2 = memref.alloc(%0, %1) : memref - linalg.matmul ins(%arg0, %arg1 : memref, memref) - outs(%2 : memref) - linalg.generic - {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d1, d0)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"], - __internal_linalg_transform__ = "transpose_fusion"} - ins(%2, %2 : memref, memref) - outs(%arg2 : memref) { - ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) : - %3 = arith.addf %arg3, %arg4 : f32 - linalg.yield %3 : f32 - } - return - } -} -// CHECK-LABEL: func @matmul_plus_transpose_matmul -// CHECK-NOT: scf.parallel -// CHECK-NOT: scf.for -// CHECK: linalg.matmul -// CHECK-NOT: scf.parallel -// CHECK-NOT: scf.for -// CHECK: linalg.generic -// CHECK-NOT: scf.parallel -// CHECK-NOT: scf.for - -// ----- - -#map0 = affine_map<(d0)[s0] -> (32, -d0 + s0)> -#map1 = affine_map<(d0)[s0] -> (64, -d0 + s0)> -#map2 = affine_map<(d0)[s0] -> (16, -d0 + s0)> -#map3 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -module { - func.func @basic_no_fusion(%arg0: memref, %arg1: memref, - %arg2: memref) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c16 = arith.constant 16 : index - %cst = arith.constant 0.000000e+00 : f32 - linalg.fill ins(%cst : f32) outs(%arg2 : memref) - %0 = memref.dim %arg0, %c0 : memref - %1 = memref.dim %arg1, %c1 : memref - %2 = memref.dim %arg0, %c1 : memref - scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %1) step (%c32, %c64) { - scf.for %arg5 = %c0 to %2 step %c16 { - %3 = affine.min #map0(%arg3)[%0] - %4 = affine.min #map1(%arg4)[%1] - %5 = affine.min #map2(%arg5)[%2] - %6 = memref.subview %arg0[%arg3, %arg5] [%3, %5] [1, 1] : memref to memref - %7 = memref.subview %arg1[%arg5, %arg4] [%5, %4] [1, 1] : memref to memref - %8 = memref.subview %arg2[%arg3, %arg4] [%3, %4] [1, 1] : memref to memref - linalg.matmul {__internal_linalg_transform__ = "basic_fusion"} - ins(%6, %7 : memref, memref) - outs(%8 : memref) - } - scf.yield - } - return - } -} -// CHECK-LABEL: func @basic_no_fusion -// CHECK-NOT: scf.parallel -// CHECK: linalg.fill -// CHECK: scf.parallel -// CHECK: scf.for -// CHECK-NOT: linalg.fill -// CHECK: linalg.matmul - -// ----- - -module { - func.func @basic_conv_fusion(%arg0: memref, %arg1: memref, - %arg2: memref) { - %cst = arith.constant 0.000000e+00 : f32 - linalg.fill ins(%cst : f32) outs(%arg2 : memref) - linalg.conv_2d {__internal_linalg_transform__ = "basic_fusion"} - ins(%arg1, %arg0 : memref, memref) outs(%arg2 : memref) - return - } -} -// CHECK: func @basic_conv_fusion -// CHECK: linalg.fill -// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original" -// CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}}) -// CHECK-SAME: { -// CHECK: linalg.fill -// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer" -// CHECK: linalg.conv_2d -// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion" -// CHECK: } -// CHECK: linalg.conv_2d -// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original" diff --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir deleted file mode 100644 index ffe8580..0000000 --- a/mlir/test/Dialect/Linalg/fusion-sequence.mlir +++ /dev/null @@ -1,252 +0,0 @@ -// RUN: mlir-opt -pass-pipeline="func.func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),resolve-shaped-type-result-dims,canonicalize,cse" -split-input-file %s | FileCheck %s - -module { - func.func @three_op_fusion(%arg0: memref, %arg1: memref, - %arg2: memref, %arg3 : memref) { - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %d0 = memref.dim %arg0, %c0 : memref - %d1 = memref.dim %arg1, %c1 : memref - %0 = memref.alloc(%d0, %d1) : memref - linalg.fill ins(%cst : f32) outs(%0 : memref) - linalg.matmul ins(%arg0, %arg1 : memref, memref) - outs(%0 : memref) - linalg.generic - {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d1)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%0, %arg2 : memref, memref) - outs(%arg3 : memref) { - ^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32) : - %5 = arith.addf %arg4, %arg5 : f32 - linalg.yield %5 : f32 - } - return - } -} - -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> -// CHECK: func @three_op_fusion -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref -// CHECK: %[[TEMP:.+]] = memref.alloc(%{{.*}}, %{{.*}}) : memref -// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} { -// CHECK: %[[SV_TEMP_1:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]] -// CHECK-DAG: %[[SV_ARG2:.+]] = memref.subview %[[ARG2]][%[[IV1]]] -// CHECK-DAG: %[[SV_ARG3:.+]] = memref.subview %[[ARG3]][%[[IV0]], %[[IV1]]] -// CHECK-DAG: %[[SV_ARG0:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0] -// CHECK-DAG: %[[SV_ARG1:.+]] = memref.subview %[[ARG1]][0, %[[IV1]]] -// CHECK: %[[SV_TEMP_2:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]] -// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_TEMP_1]] -// CHECK: linalg.matmul -// CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV_TEMP_2]] : memref) -// CHECK: linalg.generic -// CHECK-SAME: ins(%[[SV_TEMP_1]], %[[SV_ARG2]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV_ARG3]] : memref) -// CHECK: scf.yield -// CHECK: } - -// ----- - -module { - func.func @sequence_of_matmul(%arg0: memref, %arg1: memref, - %arg2: memref, %arg3: memref, - %arg4: memref) { - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %m = memref.dim %arg0, %c0 : memref - %n1 = memref.dim %arg1, %c1 : memref - %n2 = memref.dim %arg2, %c1 : memref - %n3 = memref.dim %arg3, %c1 : memref - %0 = memref.alloc(%m, %n1) : memref - %1 = memref.alloc(%m, %n2) : memref - linalg.fill ins(%cst : f32) outs(%0 : memref) - linalg.matmul ins(%arg0, %arg1 : memref, memref) - outs(%0 : memref) - linalg.fill ins(%cst : f32) outs(%1 : memref) - linalg.matmul ins(%0, %arg2 : memref, memref) - outs(%1 : memref) - linalg.fill ins(%cst : f32) outs(%arg4 : memref) - linalg.matmul ins(%1, %arg3 : memref, memref) - outs(%arg4 : memref) - return - } -} - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 16)> - - -// CHECK: func @sequence_of_matmul -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: memref -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index -// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]] -// CHECK-DAG: %[[N1:.+]] = memref.dim %[[ARG1]], %[[C1]] -// CHECK-DAG: %[[N2:.+]] = memref.dim %[[ARG2]], %[[C1]] -// CHECK: %[[ALLOC1:.+]] = memref.alloc(%[[M]], %[[N1]]) -// CHECK: %[[ALLOC2:.+]] = memref.alloc(%[[M]], %[[N2]]) -// CHECK: scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]]) -// CHECK-SAME: step (%[[C16]]) { -// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] -// CHECK: %[[SV_ALLOC3:.+]] = memref.subview %[[ALLOC2]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[N2]]] -// CHECK: %[[N3:.+]] = memref.dim %[[ARG4]], %[[C1]] -// CHECK: %[[SV_ARG4:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[N3]]] -// CHECK: %[[M_2:.+]] = memref.dim %[[ARG4]], %[[C0]] -// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[M_2]], %[[M]]] -// CHECK: %[[SV_ARG4_2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_3]], %[[N3]]] -// CHECK: %[[SV_ALLOC1:.+]] = memref.subview %[[ALLOC1]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[N1]]] -// CHECK: %[[TILE_M_5:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[M]], %[[M]]] -// CHECK: %[[N0:.+]] = memref.dim %[[ARG0]], %[[C1]] -// CHECK: %[[SV_ARG0:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_5]], %[[N0]]] -// CHECK: %[[SV_ALLOC4:.+]] = memref.subview %[[ALLOC1]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_5]], %[[N1]]] -// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_ALLOC1]] -// CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[ARG1]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV_ALLOC4]] : memref) -// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_ALLOC3]] -// CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[ARG2]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV_ALLOC3]] : memref) -// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_ARG4_2]] -// CHECK: linalg.matmul ins(%[[SV_ALLOC3]], %[[ARG3]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV_ARG4]] : memref) -// CHECK: scf.yield -// CHECK: } - - -// ----- - -module { - func.func @tensor_op_fusion(%arg0: tensor, %arg1: tensor, - %arg2: tensor, %arg3: tensor) - -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) - outs(%arg2 : tensor) -> tensor - %1 = tensor.dim %0, %c0 : tensor - %2 = tensor.dim %0, %c1 : tensor - %3 = linalg.init_tensor [%1, %2] : tensor - %4 = linalg.generic - {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%0, %arg3 : tensor, tensor) - outs(%3 : tensor) { - ^bb0(%arg4: f32, %arg5: f32, %arg6: f32): - %5 = arith.addf %arg4, %arg5 : f32 - linalg.yield %5 : f32 - } -> tensor - return %4 : tensor - } -} -// CHECK-LABEL: func @tensor_op_fusion -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[INIT:.+]] = linalg.init_tensor -// CHECK: %[[R0:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG5:.+]] = %[[INIT]]) -> (tensor) { -// CHECK: %[[R1:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG7:.+]] = %[[ARG5]]) -> (tensor) { -// CHECK-DAG: %[[STARG3:.+]] = tensor.extract_slice %[[ARG3]] -// CHECK-DAG: %[[STARG7:.+]] = tensor.extract_slice %[[ARG7]] -// CHECK-DAG: %[[STARG0:.+]] = tensor.extract_slice %[[ARG0]] -// CHECK-DAG: %[[STARG1:.+]] = tensor.extract_slice %[[ARG1]] -// CHECK-DAG: %[[STARG2:.+]] = tensor.extract_slice %[[ARG2]] -// CHECK: %[[T0:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[STARG0]], %[[STARG1]] : tensor, tensor) -// CHECK-SAME: outs(%[[STARG2]] : tensor) -> tensor -// CHECK: %[[T1:.+]] = linalg.generic -// CHECK-SAME: ins(%[[T0:.+]], %[[STARG3]] : tensor, tensor) -// CHECK-SAME: outs(%[[STARG7]] : tensor) -// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[T1]] into %[[ARG7]] -// CHECK: scf.yield %[[RESULT]] -// CHECK: } -// CHECK: scf.yield %[[R1]] -// CHECK: } -// CHECK: return %[[R0]] - -// ----- - -module { - func.func @tensor_matmul_fusion(%arg0: tensor, %arg1: tensor, - %arg2: tensor, %arg3: tensor, - %arg4: tensor, %arg5: tensor, - %arg6: tensor) -> tensor { - %0 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) - outs(%arg2 : tensor) -> tensor // [M, N0] * [N0, N1] - %1 = linalg.matmul ins(%0, %arg3 : tensor, tensor) - outs(%arg4 : tensor) -> tensor // [M, N1] * [N1, N2] - %2 = linalg.matmul ins(%1, %arg5 : tensor, tensor) - outs(%arg6 : tensor) -> tensor // [M, N2] * [N2, N3] - return %2 : tensor - } -} - -// CHECK: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)> -// CHECK: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 16)> - -// CHECK: func @tensor_matmul_fusion( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: tensor) -> tensor { -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK: %[[M:.+]] = tensor.dim %[[ARG0]], %c0 : tensor -// CHECK: %[[R0:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] = -// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor) { -// CHECK: %[[TILE_M_1:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] -// CHECK: %[[N3:.+]] = tensor.dim %[[ARG8]], %[[C1]] -// CHECK: %[[STARG6:.+]] = tensor.extract_slice %[[ARG8]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_1]], %[[N3]]] -// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[M]], %[[M]]] -// CHECK: %[[N2:.+]] = tensor.dim %[[ARG4]], %[[C1]] -// CHECK: %[[STARG4:.+]] = tensor.extract_slice %[[ARG4]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_2]], %[[N2]]] -// CHECK: %[[N0:.+]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK: %[[STARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_2]], %[[N0]]] -// CHECK: %[[N1:.+]] = tensor.dim %[[ARG2]], %[[C1]] -// CHECK: %[[STARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_2]], %[[N1]]] -// CHECK: %[[T0:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[STARG0]], %[[ARG1]] : tensor, tensor -// CHECK-SAME: ) outs(%[[STARG2]] : tensor) -// CHECK: %[[T1:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[T0]], %arg3 : tensor, tensor -// CHECK-SAME: ) outs(%[[STARG4]] : tensor) -// CHECK: %[[T2:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[T1]], %arg5 : tensor, tensor -// CHECK-SAME: ) outs(%[[STARG6]] : tensor) -// CHECK: %[[R1:.+]] = tensor.insert_slice %[[T2]] -// CHECK-SAME: into %[[ARG8]][%[[IV0]], 0] [%[[TILE_M_1]], %[[N3]]] -// CHECK: scf.yield %[[R1]] : tensor -// CHECK: } diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir deleted file mode 100644 index 56f4c9d..0000000 --- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir +++ /dev/null @@ -1,193 +0,0 @@ -// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s - -module { - func.func @matmul_fusion(%A: tensor, %B: tensor, - %AB_init: tensor, %C: tensor, - %ABC_init: tensor) -> tensor { - %AB = linalg.matmul ins(%A, %B : tensor, tensor) - outs(%AB_init : tensor) -> tensor // - %ABC = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"} - ins(%AB, %C : tensor, tensor) - outs(%ABC_init : tensor) -> tensor // - return %ABC : tensor - } -} -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 32)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 64)> -// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 32)> - -// CHECK: func @matmul_fusion -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor - -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index -// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index -// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = -// CHECK-SAME: %[[C0]] to %[[M]] step %[[C32]] -// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor) { -// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[M]]] -// CHECK: %[[N3:.+]] = tensor.dim %[[ARG6]], %[[C1]] -// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]] -// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP5]](%[[IV0]])[%[[M]], %[[M]]] -// CHECK: %[[N1:.+]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_3]], %[[N1]]] -// CHECK: %[[N2_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] -// CHECK: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_3]], %[[N2_2]]] -// CHECK: %[[LHS:.+]] = linalg.matmul -// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer" -// CHECK-SAME: ins(%[[ST_ARG0]], %[[ARG1]] : tensor, tensor) -// CHECK-SAME: outs(%[[ST_ARG2]] : tensor) -// CHECK: %[[N2:.+]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK: %[[N3_2:.+]] = tensor.dim %[[ARG3]], %[[C1]] -// CHECK: %[[YIELD0:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = -// CHECK-SAME: %[[C0]] to %[[N3_2]] step %[[C64]] -// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ST_ARG6]]) -> (tensor) { -// CHECK: %[[YIELD1:.+]] = scf.for %[[IV2:[a-zA-Z0-9]+]] = -// CHECK-SAME: %[[C0]] to %[[N2]] step %[[C16]] -// CHECK-SAME: iter_args(%[[ARG10:.+]] = %[[ARG8]]) -> (tensor) { -// CHECK: %[[TILE_N2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2]]] -// CHECK: %[[ST_LHS:.+]] = tensor.extract_slice %[[LHS]][0, %[[IV2]]] -// CHECK-SAME: [%[[TILE_M_3]], %[[TILE_N2]]] -// CHECK: %[[TILE_N3:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N3_2]]] -// CHECK: %[[ST_ARG3:.+]] = tensor.extract_slice %[[ARG3]][%[[IV2]], %[[IV1]]] -// CHECK-SAME: [%[[TILE_N2]], %[[TILE_N3]]] -// CHECK: %[[M_4:.+]] = tensor.dim %[[ARG10]], %[[C0]] -// CHECK: %[[ST_ARG4:.+]] = tensor.extract_slice %[[ARG10]][0, %[[IV1]]] -// CHECK-SAME: [%[[M_4]], %[[TILE_N3]]] -// CHECK: %[[ST_RESULT:.+]] = linalg.matmul -// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion" -// CHECK-SAME: ins(%[[ST_LHS]], %[[ST_ARG3]] -// CHECK-SAME: : tensor, tensor) -// CHECK-SAME: outs(%[[ST_ARG4]] : tensor) -// CHECK: %[[UPDATE1:.+]] = tensor.insert_slice %[[ST_RESULT]] -// CHECK-SAME: into %[[ARG10]][0, %[[IV1]]] [%[[M_4]], %[[TILE_N3]]] -// CHECK: scf.yield %[[UPDATE1]] -// CHECK: } -// CHECK: scf.yield %[[YIELD1]] -// CHECK: } -// CHECK: %[[UPDATE0:.+]] = tensor.insert_slice %[[YIELD0]] into -// CHECK-SAME: %[[ARG6]][%[[IV0]], 0] [%[[TILE_M_2]], %[[N3]]] -// CHECK: scf.yield %[[UPDATE0]] -// CHECK: } -// CHECK: return %[[RESULT]] - -// ----- - -module { - func.func @matmul_plus_matmul(%arg0: tensor, %arg1: tensor, - %arg2: tensor) -> tensor{ - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = tensor.dim %arg2, %c0 : tensor - %1 = tensor.dim %arg2, %c1 : tensor - %2 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) - outs(%arg2 : tensor) -> tensor - %3 = tensor.dim %2, %c0 : tensor - %4 = tensor.dim %2, %c1 : tensor - %5 = linalg.init_tensor [%3, %4] : tensor - %6 = linalg.generic - {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"], - __internal_linalg_transform__ = "transpose_fusion"} - ins(%2, %2 : tensor, tensor) - outs(%5 : tensor) { - ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) : - %7 = arith.addf %arg3, %arg4 : f32 - linalg.yield %7 : f32 - } -> tensor - return %6 : tensor - } -} -// CHECK: func @matmul_plus_matmul -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] -// CHECK-SAME: iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}}) -// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]] -// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]]) -// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]] -// CHECK: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] -// CHECK: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] -// CHECK: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]] -// CHECK: %[[LHS:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]] -// CHECK-SAME: : tensor, tensor) -// CHECK-SAME: outs(%[[ST_ARG2]] : tensor) -// CHECK: %[[ST_RESULT:.+]] = linalg.generic -// CHECK-SAME: ins(%[[LHS]] : tensor) -// CHECK-SAME: outs(%[[ST_ARG6]] : tensor) -// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]] -// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]] -// CHECK: scf.yield %[[UPDATE]] -// CHECK: scf.yield %[[YIELD]] -// CHECK: return %[[RESULT]] - -// ----- - -module { - func.func @matmul_out_fusion(%arg0: tensor, %arg1: tensor, - %arg2: tensor) -> tensor { - %c0 = arith.constant 0.0 : f32 - %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor) -> tensor - %1 = linalg.matmul {__internal_linalg_transform__ = "out_fusion"} - ins(%arg1, %arg2 : tensor, tensor) - outs(%0 : tensor) -> tensor - return %1 : tensor - } -} - -// CHECK-LABEL: func @matmul_out_fusion( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[C0:.*]] = arith.constant 0.0{{.*}} : f32 -// CHECK-NOT: fill -// CHECK: scf.for %[[I:.*]]{{.*}}iter_args(%{{.*}} = %[[ARG0]]) -> (tensor) { -// CHECK: scf.for %[[J:.*]] -// CHECK: %[[ST:.*]] = tensor.extract_slice %[[ARG0]] -// CHECK: %[[ST_FILL:.*]] = linalg.fill -// CHECK-SAME: {__internal_linalg_transform__ = "after_out_fusion_producer"} -// CHECK-SAME: ins(%[[C0]] : f32) outs(%[[ST]] : tensor) -> tensor -// CHECK: %[[ST_MM_RES:.*]] = scf.for %[[K:.*]]{{.*}}iter_args(%[[BB:.*]] = %[[ST_FILL]]) -> (tensor) { -// CHECK-NOT: fill -// CHECK: %[[ST_FILL_SUB:.*]] = tensor.extract_slice %[[BB]][0, 0] -// CHECK: %[[ST_MM_SUB:.*]] = linalg.matmul {__internal_linalg_transform__ = "after_out_fusion"} ins(%{{.*}}, %{{.*}} : tensor, tensor) outs(%[[ST_FILL_SUB]] : tensor) -> tensor -// CHECK: %[[ST_MM:.*]] = tensor.insert_slice %[[ST_MM_SUB]] into %[[BB]] -// CHECK: scf.yield %[[ST_MM]] : tensor -// CHECK: %[[MM:.*]] = tensor.insert_slice %[[ST_MM_RES]] into {{.*}} -// CHECK: scf.yield %[[MM]] : tensor - -// ----- - -module { - func.func @generic_plus_matmul(%arg0: tensor, %arg1: tensor, - %arg2: tensor) -> tensor { - %c0 = arith.constant 0.0 : f32 - %0 = linalg.generic { - indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>], - iterator_types = ["parallel", "parallel"]} - ins(%c0 : f32) - outs(%arg0: tensor) { - ^bb(%0: f32, %1: f32) : - linalg.yield %0 : f32 - } -> tensor - %1 = linalg.matmul {__internal_linalg_transform__ = "out_fusion"} - ins(%arg1, %arg2 : tensor, tensor) - outs(%0 : tensor) -> tensor - return %1 : tensor - } -} diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir index 81e2bfb..d1ca2d2 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir @@ -230,3 +230,174 @@ func.func @interchange_matmul_fusion(%arg0 : tensor, %arg1 : tensor, %arg1: tensor, + %arg2: tensor) -> tensor{ + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = tensor.dim %arg2, %c0 : tensor + %1 = tensor.dim %arg2, %c1 : tensor + %2 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + %3 = tensor.dim %2, %c0 : tensor + %4 = tensor.dim %2, %c1 : tensor + %5 = linalg.init_tensor [%3, %4] : tensor + %6 = linalg.generic + {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"], + __internal_linalg_transform__ = "gemm_plus_gemm_fusion"} + ins(%2, %2 : tensor, tensor) + outs(%5 : tensor) { + ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) : + %7 = arith.addf %arg3, %arg4 : f32 + linalg.yield %7 : f32 + } -> tensor + return %6 : tensor +} +// This fuses as expected but the gemm operation is inlined twice. It should be CSE-d but isnt today. + +// CHECK: func @matmul_plus_matmul +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor +// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] +// CHECK-SAME: iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}}) +// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]] +// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]]) +// CHECK-DAG: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] +// CHECK-DAG: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] +// CHECK-DAG: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]] +// CHECK: %[[LHS:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]] : +// CHECK-SAME: outs(%[[ST_ARG2]] : +// CHECK-DAG: %[[ST_ARG0_1:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] +// CHECK-DAG: %[[ST_ARG1_1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] +// CHECK-DAG: %[[ST_ARG2_1:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]] +// CHECK: %[[RHS:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[ST_ARG0_1]], %[[ST_ARG1_1]] : +// CHECK-SAME: outs(%[[ST_ARG2_1]] : +// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]] +// CHECK: %[[ST_RESULT:.+]] = linalg.generic +// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : +// CHECK-SAME: outs(%[[ST_ARG6]] : +// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]] +// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]] +// CHECK: scf.yield %[[UPDATE]] +// CHECK: scf.yield %[[YIELD]] +// CHECK: return %[[RESULT]] + +// ----- + +func.func @matmul_plus_transpose_matmul(%arg0: tensor, %arg1: tensor, + %arg2: tensor) -> tensor{ + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = tensor.dim %arg2, %c0 : tensor + %1 = tensor.dim %arg2, %c1 : tensor + %2 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + %3 = tensor.dim %2, %c0 : tensor + %4 = tensor.dim %2, %c1 : tensor + %5 = linalg.init_tensor [%3, %4] : tensor + %6 = linalg.generic + {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d1, d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"], + __internal_linalg_transform__ = "gemm_plus_gemm_fusion"} + ins(%2, %2 : tensor, tensor) + outs(%5 : tensor) { + ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) : + %7 = arith.addf %arg3, %arg4 : f32 + linalg.yield %7 : f32 + } -> tensor + return %6 : tensor +} +// CHECK: func @matmul_plus_transpose_matmul +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor +// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] +// CHECK-SAME: iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}}) +// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]] +// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]]) +// CHECK-DAG: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] +// CHECK-DAG: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] +// CHECK-DAG: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]] +// CHECK: %[[LHS:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]] +// CHECK-SAME: : tensor, tensor) +// CHECK-SAME: outs(%[[ST_ARG2]] : tensor) +// CHECK-DAG: %[[STR_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] +// CHECK-DAG: %[[STR_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]] +// CHECK-DAG: %[[STR_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]], %[[IV0]]] +// CHECK: %[[RHS:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[STR_ARG0]], %[[STR_ARG1]] : +// CHECK-SAME: outs(%[[STR_ARG2]] : +// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]] +// CHECK: %[[ST_RESULT:.+]] = linalg.generic +// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : +// CHECK-SAME: outs(%[[ST_ARG6]] : +// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]] +// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]] +// CHECK: scf.yield %[[UPDATE]] +// CHECK: scf.yield %[[YIELD]] +// CHECK: return %[[RESULT]] + +// ----- + +func.func @matmul_sequence_fusion(%arg0: tensor, %arg1: tensor, + %arg2: tensor, %arg3: tensor, %arg4: tensor, + %arg5: tensor, %arg6: tensor) -> tensor { + %0 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor // [M, N0] * [N0, N1] + %1 = linalg.matmul ins(%0, %arg3 : tensor, tensor) + outs(%arg4 : tensor) -> tensor // [M, N1] * [N1, N2] + %2 = linalg.matmul + {__internal_linalg_transform__ = "gemm_sequence_fusion"} + ins(%1, %arg5 : tensor, tensor) + outs(%arg6 : tensor) -> tensor // [M, N2] * [N2, N3] + return %2 : tensor +} + +// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK: func @matmul_sequence_fusion( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: tensor) -> tensor { +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[N0:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[ORIG_GEMM1:.+]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] : +// CHECK-DAG: %[[N1:.+]] = tensor.dim %[[ORIG_GEMM1]], %[[C1]] +// CHECK-DAG: %[[ORIG_GEMM2:.+]] = linalg.matmul ins(%[[ORIG_GEMM1]], %[[ARG3]] : +// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C0]] +// CHECK-DAG: %[[N2:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C1]] +// CHECK-DAG: %[[N3:.+]] = tensor.dim %[[ARG5]], %[[C1]] +// CHECK: %[[R0:.+]] = scf.for %[[IV:[a-zA-Z0-9_]+]] = +// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor) { +// CHECK-DAG: %[[TILE_M:.+]] = affine.min #[[MAP]](%[[IV]])[%{{.+}}, %[[M]]] +// CHECK-DAG: %[[SLICE_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[TILE_M]], %[[N0]]] +// CHECK-DAG: %[[SLICE_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, 0] [%[[N0]], %[[N1]]] +// CHECK-DAG: %[[SLICE_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV]], 0] [%[[TILE_M]], %[[N1]]] +// CHECK-DAG: %[[TILE_GEMM1:.+]] = linalg.matmul ins(%[[SLICE_ARG0]], %[[SLICE_ARG1]] : +// CHECK-SAME: outs(%[[SLICE_ARG2]] : +// CHECK-DAG: %[[SLICE_ARG3:.+]] = tensor.extract_slice %[[ARG3]][0, 0] [%[[N1]], %[[N2]]] +// CHECK-DAG: %[[SLICE_ARG4:.+]] = tensor.extract_slice %[[ARG4]][%[[IV]], 0] [%[[TILE_M]], %[[N2]]] +// CHECK-DAG: %[[TILE_GEMM2:.+]] = linalg.matmul ins(%[[TILE_GEMM1]], %[[SLICE_ARG3]] : +// CHECK-SAME: outs(%[[SLICE_ARG4]] : +// CHECK-DAG: %[[SLICE_ARG5:.+]] = tensor.extract_slice %[[ARG5]][0, 0] [%[[N2]], %[[N3]]] +// CHECK-DAG: %[[SLICE_ARG6:.+]] = tensor.extract_slice %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]] +// CHECK-DAG: %[[TILE_GEMM3:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[TILE_GEMM2]], %[[SLICE_ARG5]] : +// CHECK-SAME: outs(%[[SLICE_ARG6]] : +// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[TILE_GEMM3]] into %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]] +// CHECK: scf.yield %[[UPDATE]] diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp index 67651a9..c5b27c5 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp @@ -23,130 +23,6 @@ using namespace mlir; using namespace mlir::linalg; -/// Use this to safely fill patterns for this test, since RewritePatternSet::add -/// forwards Rvalues only to the first pattern. -template -static void fillFusionPattern(MLIRContext *context, - const LinalgDependenceGraph &dependenceGraph, - RewritePatternSet &patterns, - const Twine &testCase, - ArrayRef tileSizes, - ArrayRef indicesToFuse) { - patterns.add>( - context, dependenceGraph, - LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(LoopType), - LinalgFusionOptions().setIndicesToFuse(indicesToFuse), - LinalgTransformationFilter( - StringAttr::get(context, testCase + "_fusion"), - StringAttr::get(context, "after_" + testCase + "_fusion")), - LinalgTransformationFilter( - ArrayRef(), - StringAttr::get(context, "after_" + testCase + "_fusion_producer")), - LinalgTransformationFilter( - ArrayRef(), - StringAttr::get(context, "after_" + testCase + "_fusion_original"))); -} - -template -static void fillFusionPatterns(MLIRContext *context, - const LinalgDependenceGraph &dependenceGraph, - RewritePatternSet &patterns) { - fillFusionPattern(context, dependenceGraph, patterns, - /*testCase=*/"basic", - /*tileSizes=*/{32, 64, 16}, - /*indicesToFuse=*/{2}); - - auto fillMatmulPattern = [&](const Twine &testCase, - ArrayRef indicesToFuse) { - fillFusionPattern(context, dependenceGraph, patterns, - testCase, /*tileSizes=*/{32, 64, 16}, - indicesToFuse); - }; - fillMatmulPattern(/*testCase=*/"basic", - /*indicesToFuse=*/{2}); - fillMatmulPattern(/*testCase=*/"lhs", - /*indicesToFuse=*/{0}); - fillMatmulPattern(/*testCase=*/"out", - /*indicesToFuse=*/{2}); - fillMatmulPattern(/*testCase=*/"rhs", - /*indicesToFuse=*/{1}); - fillMatmulPattern(/*testCase=*/"two_operand", - /*indicesToFuse=*/{0, 2}); - - fillFusionPattern(context, dependenceGraph, patterns, - /*testCase=*/"transpose", - /*tileSizes=*/{32, 64}, - /*indicesToFuse=*/{0, 1}); -} - -namespace { -template -struct TestLinalgFusionTransforms - : public PassWrapper, - OperationPass> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransforms) - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - TestLinalgFusionTransforms() = default; - TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {} - - void runOnOperation() override { - MLIRContext *context = &this->getContext(); - func::FuncOp funcOp = this->getOperation(); - RewritePatternSet fusionPatterns(context); - Aliases alias; - LinalgDependenceGraph dependenceGraph = - LinalgDependenceGraph::buildDependenceGraph(alias, funcOp); - fillFusionPatterns(context, dependenceGraph, fusionPatterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)); - } -}; - -struct TestLinalgFusionTransformsParallelLoops - : public TestLinalgFusionTransforms { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( - TestLinalgFusionTransformsParallelLoops) - - StringRef getArgument() const final { - return "test-linalg-fusion-transform-patterns"; - } - StringRef getDescription() const final { - return "Test Linalg fusion transformation patterns by applying them " - "greedily."; - } -}; - -struct TestLinalgFusionTransformsLoops - : public TestLinalgFusionTransforms { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransformsLoops) - - StringRef getArgument() const final { - return "test-linalg-tensor-fusion-transform-patterns"; - } - StringRef getDescription() const final { - return "Test Linalg on tensor fusion transformation " - "patterns by applying them greedily."; - } -}; - -struct TestLinalgFusionTransformsTiledLoops - : public TestLinalgFusionTransforms { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( - TestLinalgFusionTransformsTiledLoops) - - StringRef getArgument() const final { - return "test-linalg-tiled-loop-fusion-transform-patterns"; - } - StringRef getDescription() const final { - return "Test Linalg on tensor fusion transformation " - "patterns by applying them greedily."; - } -}; -} // namespace - static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) { OpBuilder b(f); DenseSet eraseSet; @@ -236,82 +112,13 @@ struct TestLinalgGreedyFusion } while (succeeded(fuseLinalgOpsGreedily(getOperation()))); } }; - -/// Pass to test tile and fuse of sequence of operations. Intended only for -/// testing. -struct TestLinalgTileAndFuseSequencePass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( - TestLinalgTileAndFuseSequencePass) - - StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; } - StringRef getDescription() const final { - return "Test Linalg tiling and fusion of a sequence of Linalg operations."; - } - TestLinalgTileAndFuseSequencePass() = default; - TestLinalgTileAndFuseSequencePass( - const TestLinalgTileAndFuseSequencePass &pass) - : PassWrapper(pass){}; - - ListOption tileSizes{*this, "tile-sizes", - llvm::cl::desc("Tile sizes to use for ops")}; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - func::FuncOp funcOp = getOperation(); - auto &blocks = funcOp.getBody().getBlocks(); - if (!llvm::hasSingleElement(blocks)) { - return; - } - SmallVector linalgOps = - llvm::to_vector<2>(blocks.front().getOps()); - Aliases aliases; - LinalgDependenceGraph dependenceGraph(aliases, linalgOps); - OpBuilder builder(funcOp.getContext()); - linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops; - if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) { - return linalgOp.hasTensorSemantics(); - })) - loopType = LinalgTilingLoopType::Loops; - Optional tileAndFuseOps = tileAndFuseLinalgOps( - builder, linalgOps, dependenceGraph, - LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType)); - if (!tileAndFuseOps) - return signalPassFailure(); - if (linalgOps.back().hasTensorSemantics()) { - linalgOps.back().getOperation()->replaceAllUsesWith( - tileAndFuseOps->fusedLoops.front()); - } - for (auto op : linalgOps) - if (op.hasBufferSemantics()) - op.erase(); - } -}; - } // namespace namespace mlir { namespace test { -void registerTestLinalgFusionTransforms() { - PassRegistration(); -} -void registerTestLinalgTensorFusionTransforms() { - PassRegistration(); -} -void registerTestLinalgTiledLoopFusionTransforms() { - PassRegistration(); -} void registerTestLinalgGreedyFusion() { PassRegistration(); } -void registerTestLinalgTileAndFuseSequencePass() { - PassRegistration(); -} } // namespace test } // namespace mlir diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp index 214a405..5c603a5 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -191,6 +191,12 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context, addPatternForTiling< TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( context, patterns, "gemm_interchange_fusion", {10, 20}, {1, 0}); + addPatternForTiling< + TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( + context, patterns, "gemm_plus_gemm_fusion", {10, 20}); + addPatternForTiling< + TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( + context, patterns, "gemm_sequence_fusion", {10}); return; } } diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index f8fa245..78e26de 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -88,12 +88,8 @@ void registerTestInterfaces(); void registerTestLastModifiedPass(); void registerTestLinalgDecomposeOps(); void registerTestLinalgElementwiseFusion(); -void registerTestLinalgFusionTransforms(); -void registerTestLinalgTensorFusionTransforms(); -void registerTestLinalgTiledLoopFusionTransforms(); void registerTestLinalgGreedyFusion(); void registerTestLinalgHoisting(); -void registerTestLinalgTileAndFuseSequencePass(); void registerTestLinalgTransforms(); void registerTestLivenessPass(); void registerTestLoopFusion(); @@ -187,12 +183,8 @@ void registerTestPasses() { mlir::test::registerTestLastModifiedPass(); mlir::test::registerTestLinalgDecomposeOps(); mlir::test::registerTestLinalgElementwiseFusion(); - mlir::test::registerTestLinalgFusionTransforms(); - mlir::test::registerTestLinalgTensorFusionTransforms(); - mlir::test::registerTestLinalgTiledLoopFusionTransforms(); mlir::test::registerTestLinalgGreedyFusion(); mlir::test::registerTestLinalgHoisting(); - mlir::test::registerTestLinalgTileAndFuseSequencePass(); mlir::test::registerTestLinalgTransforms(); mlir::test::registerTestLivenessPass(); mlir::test::registerTestLoopFusion(); -- 2.7.4