From 352d6fe1eb2214cae974c36ee0b1bbc2cc0f91e3 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 27 Mar 2023 03:48:21 -0700 Subject: [PATCH] [mlir][Linalg] NFC - Move transform utilities related to subcomputation inference to Linalg/Utils --- .../Linalg/TransformOps/LinalgTransformOps.h | 27 ---- .../Linalg/TransformOps/LinalgTransformOps.td | 26 ++-- mlir/include/mlir/Dialect/Linalg/Utils/Utils.h | 144 ++++++++++++++------- .../Linalg/TransformOps/LinalgTransformOps.cpp | 131 ++++--------------- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 83 ++++++++++++ .../Dialect/Linalg/transform-pack-greedily.mlir | 12 +- 6 files changed, 224 insertions(+), 199 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h index a15c3a3..6d7c802 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h @@ -53,33 +53,6 @@ class DialectRegistry; namespace transform { -/// Return the set of `linalgOp` iterator positions for which the indexing map -/// for `opOperand` is a permutation (i.e. an AffineDimExpr). -DenseSet findPermutationsIndexingOperand(linalg::LinalgOp linalgOp, - OpOperand *opOperand, - utils::IteratorType iter); - -/// Possible dimension candidates that define a gemm embedded in the indexing -/// maps of a LinalgOp. -struct GemmDimsForPacking { - DenseSet mPos, nPos, kPos; -}; - -/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form -/// a gemm subcomputation within `linalgOp`. These dimensions are such that: -/// 1. The m dimension is involved in an outer-product along LHS -/// (i.e. it is a permutation on RES and LHS and does not appear in RHS). -/// 2. The n dimension is involved in an outer-product along RHS -/// (i.e. it is a permutation on RES and RHS and does not appear in LHS). -/// 3. The k dimension appears as a permutation on LHS and RHS. -/// 4. m, n and k appear only once in any given indexing. -/// This allows detecting that some gemm is embedded within `linalgOp` with some -/// orthogonal heuristic. -FailureOr inferGemmDims(linalg::LinalgOp linalgOp); - -/// Return true if `linalgOp` contains an embedded gemm subcomputation. -bool containsMostMinorGemm(linalg::LinalgOp linalgOp); - /// Implementation of tiling operations using `scf.forall`. DiagnosedSilenceableFailure tileToForallOpImpl( RewriterBase &rewriter, transform::TransformState &state, diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index c58e955..e107911 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -590,14 +590,14 @@ def PackGreedilyOp : Op needs interface. let arguments = (ins TransformHandleTypeInterface:$target, - Variadic:$gemm_packed_sizes, + Variadic:$matmul_packed_sizes, DefaultValuedAttr - :$static_gemm_packed_sizes, + :$static_matmul_packed_sizes, DefaultValuedAttr - :$gemm_inner_dims_order); + :$matmul_inner_dims_order); let results = (outs Transform_ConcreteOpType<"linalg.generic">:$packed_op); let builders = [ OpBuilder<(ins "Value":$target, - "ArrayRef":$mixedGemmPackedSizes, - CArg<"ArrayRef", "{}">:$gemmDimsInnerDimsOrder)> + "ArrayRef":$mixedMatmulPackedSizes, + CArg<"ArrayRef", "{}">:$matmulDimsInnerDimsOrder)> ]; let assemblyFormat = [{ $target oilist( - `gemm_packed_sizes` `=` custom($gemm_packed_sizes, - $static_gemm_packed_sizes) - `gemm_inner_dims_order` `=` $gemm_inner_dims_order + `matmul_packed_sizes` `=` custom($matmul_packed_sizes, + $static_matmul_packed_sizes) + `matmul_inner_dims_order` `=` $matmul_inner_dims_order ) attr-dict `:` functional-type($target, results) @@ -652,7 +652,7 @@ def PackGreedilyOp : Op getMixedGemmPackedSizes(); + SmallVector getMixedMatmulPackedSizes(); }]; } diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 3c3fa70..4c23ceb 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -28,6 +28,44 @@ class ExtractSliceOp; namespace linalg { //===----------------------------------------------------------------------===// +// Utilities for inferring various semantics properties of Linalg ops. +//===----------------------------------------------------------------------===// + +/// Possible dimension candidates that define a matmul embedded in the indexing +/// maps of a LinalgOp. +struct EmbeddedMatmulDimsCandidates { + DenseSet mPos, nPos, kPos; +}; + +/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the +/// iterators of type `iter` that index the `opOperand` as a permutation. +/// This is useful to infer various subcomputations on a given `linalgOp`. +/// This is performed by looking up each result in the matching indexing map and +/// determining whether: +/// - It is a single AffineDimExpr. +/// - It is the only result involving this AffineDimExpr. +DenseSet findPermutationsIndexingOperand(LinalgOp linalgOp, + OpOperand *opOperand, + utils::IteratorType iter); + +/// Return true if `linalgOp` contains an embedded matmul subcomputation in its +/// most minor dimensions. +bool containsMostMinorMatmul(linalg::LinalgOp linalgOp); + +/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form +/// a matmul subcomputation within `linalgOp`. These dimensions are such that: +/// 1. The m dimension is involved in an outer-product along LHS +/// (i.e. it is a permutation on RES and LHS and does not appear in RHS). +/// 2. The n dimension is involved in an outer-product along RHS +/// (i.e. it is a permutation on RES and RHS and does not appear in LHS). +/// 3. The k dimension appears as a permutation on LHS and RHS. +/// 4. m, n and k appear only once in any given indexing. +/// This allows detecting that some matmul is embedded within `linalgOp` with +/// some orthogonal heuristic. +FailureOr +inferMatmulDims(linalg::LinalgOp linalgOp); + +//===----------------------------------------------------------------------===// // General utilities //===----------------------------------------------------------------------===// @@ -96,10 +134,10 @@ FailureOr getConstantUpperBoundForIndex(Value value); /// Create a tensor::PadOp that pads `source` to the size of the statically /// sized `type` whose static sizes are assumed to be greater than the dynamic -/// `source` size. The padding introduces trailing `pad` values until the target -/// size is met. If `source` is defined by one or more LinalgOps that have been -/// padded with the same value and sizes, return their padded result instead of -/// creating a tensor::PadOp. +/// `source` size. The padding introduces trailing `pad` values until the +/// target size is met. If `source` is defined by one or more LinalgOps that +/// have been padded with the same value and sizes, return their padded result +/// instead of creating a tensor::PadOp. /// /// Example: /// ``` @@ -116,8 +154,8 @@ FailureOr getConstantUpperBoundForIndex(Value value); Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold); -/// Returns a GenericOp that transposes `inputTensor` into `outputTensor` using -/// `transposeVector` to permute the `inputTensor` dimensions. +/// Returns a GenericOp that transposes `inputTensor` into `outputTensor` +/// using `transposeVector` to permute the `inputTensor` dimensions. GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor, Value outputTensor, ArrayRef transposeVector); @@ -127,12 +165,12 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor, /// or vectorize. GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to); -/// Get the reassociation maps to fold the result of a extract_slice (or source -/// of a insert_slice) operation with given offsets, and sizes to its +/// Get the reassociation maps to fold the result of a extract_slice (or +/// source of a insert_slice) operation with given offsets, and sizes to its /// rank-reduced version. This is only done for the cases where the size is 1 -/// and offset is 0. Strictly speaking the offset 0 is not required in general, -/// but non-zero offsets are not handled by SPIR-V backend at this point (and -/// potentially cannot be handled). +/// and offset is 0. Strictly speaking the offset 0 is not required in +/// general, but non-zero offsets are not handled by SPIR-V backend at this +/// point (and potentially cannot be handled). std::optional> getReassociationMapForFoldingUnitDims(ArrayRef mixedSizes); @@ -151,8 +189,9 @@ enum class LinalgTilingLoopType { ParallelLoops = 2 }; -/// Computes tile offsets, given a list of loop `ivs` and `tileSizes`. In case a -/// tile size is zero (i.e., no tiling), the corresponding offset is also zero. +/// Computes tile offsets, given a list of loop `ivs` and `tileSizes`. In case +/// a tile size is zero (i.e., no tiling), the corresponding offset is also +/// zero. SmallVector computeTileOffsets(OpBuilder &b, Location loc, ArrayRef ivs, ArrayRef tileSizes); @@ -166,15 +205,16 @@ SmallVector computeTileSizes(OpBuilder &b, Location loc, ArrayRef sizeBounds); /// Returns the list of tensor output types produced when the given structured -/// operation `op` is applied to the given `operands`. Note that `operands` are -/// not necessarily the actual operands of `op`. +/// operation `op` is applied to the given `operands`. Note that `operands` +/// are not necessarily the actual operands of `op`. SmallVector getTensorOutputTypes(LinalgOp op, ValueRange operands); /// Creates `insert_slice` ops that insert `results` back into larger tensors -/// they were originally extracted from with `extract_slice` before being passed -/// as `operands` to the given structured operation `op` or its clone. Note that -/// `operands` are not necessarily the actual operands of `op`, the operation -/// serves only as metadata container for operand types and positions. +/// they were originally extracted from with `extract_slice` before being +/// passed as `operands` to the given structured operation `op` or its clone. +/// Note that `operands` are not necessarily the actual operands of `op`, the +/// operation serves only as metadata container for operand types and +/// positions. SmallVector insertSlicesBack(OpBuilder &builder, Location loc, LinalgOp op, ValueRange operands, ValueRange results); @@ -187,8 +227,8 @@ struct SliceParameters { }; /// Computes SliceParameters for a single `valueToTile` assuming that its user -/// is being tiled with the given loop bounds `lbs` and `ubs` and the tile sizes -/// `tileSizes`. +/// is being tiled with the given loop bounds `lbs` and `ubs` and the tile +/// sizes `tileSizes`. /// /// `omitPartialTileCheck` controls whether to omit the partial/boundary tile /// condition check in cases where we statically know that it is unnecessary. @@ -219,8 +259,8 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp, /// Creates an extract_slice/subview op for a single `valueToTile` with /// `builder`. This new operation extracts a tile of `valueToTile`, starting /// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck` -/// controls whether to omit the partial/boundary tile condition check in cases -/// where we statically know that it is unnecessary. +/// controls whether to omit the partial/boundary tile condition check in +/// cases where we statically know that it is unnecessary. Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef tileSizes, AffineMap map, ArrayRef lbs, ArrayRef ubs, @@ -232,8 +272,8 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, /// nest for tiling with the given induction variables `ivs` and tile sizes /// `tileSizes`. `sizeBounds` are the iteration space bounds for *all* the /// implicit loops in `linalgOp`. `omitPartialTileCheck` controls whether to -/// omit the partial/boundary tile condition check in cases where we statically -/// know that it is unnecessary. +/// omit the partial/boundary tile condition check in cases where we +/// statically know that it is unnecessary. /// /// Note that a constant zero in `tileSizes` means no tiling at that implicit /// loop. The number of non-zero values in `tileSizes` should be equal to the @@ -254,8 +294,9 @@ void offsetIndices(RewriterBase &b, LinalgOp linalgOp, ArrayRef offests); /// A struct containing the Linalg producer before and after fusion. -/// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` op -/// before the consumer Linalg op, until enough canonicalizations have applied. +/// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` +/// op before the consumer Linalg op, until enough canonicalizations have +/// applied. struct FusionInfo { LinalgOp originalProducer; LinalgOp fusedProducer; @@ -285,19 +326,23 @@ FailureOr fuseProducerOfTensor(OpBuilder &b, /// Scheme used to distribute loops to processors. enum class DistributionMethod { /// Cyclic distribution where no assumption is made about the dynamic - /// relationship between number of processors and number of iterations of the + /// relationship between number of processors and number of iterations of + /// the /// distributed loop. Distributes the following loop /// /// scf.parallel (%iv) = (%lb) to (%ub) step (%step) /// /// to /// - /// scf.parallel(%iv)= (%lb + %procId * %step) to (%ub) step (%step * %nprocs) + /// scf.parallel(%iv)= (%lb + %procId * %step) to (%ub) step (%step * + /// %nprocs) Cyclic = 0, /// Cyclic distribution where the number of processors can be assumed to be - /// more than or equal to the number of iterations of the distributed loop. In - /// such cases, a simple in-bounds check is enough (instead of materializing a + /// more than or equal to the number of iterations of the distributed loop. + /// In + /// such cases, a simple in-bounds check is enough (instead of materializing + /// a /// loop). Distributes the following loop /// /// scf.parallel (%iv) = (%lb) to (%ub) step (%step) @@ -312,7 +357,8 @@ enum class DistributionMethod { CyclicNumProcsGeNumIters = 1, /// Cyclic distribution where the number of processors can be assumed to be - /// equal to the number of iterations of the distributed loop. In such cases, + /// equal to the number of iterations of the distributed loop. In such + /// cases, /// no bounds check is needed. Distributes the following loop /// /// scf.parallel (%iv) = (%lb) to (%ub) step (%step) @@ -339,16 +385,17 @@ using ProcInfoCallBackFn = std::function( /// Options that allow distribution of loops generated in Linalg transforms to /// processors while generating the loops. struct LinalgLoopDistributionOptions { - /// Callback function that returns the Values for processor ID (`procId`), and - /// number of processors (`nprocs`) used to execute the parallel loops. The - /// number of `{procId, nprocs}` pairs returned must be equal to the number of - /// `parallelLoopRanges` passed into the callback. The `parallelLoopRanges` - /// are ranges of the outer parallel loops of the operation that - /// do have non-zero tile sizes specified. + /// Callback function that returns the Values for processor ID (`procId`), + /// and number of processors (`nprocs`) used to execute the parallel loops. + /// The number of `{procId, nprocs}` pairs returned must be equal to the + /// number of `parallelLoopRanges` passed into the callback. The + /// `parallelLoopRanges` are ranges of the outer parallel loops of the + /// operation that do have non-zero tile sizes specified. ProcInfoCallBackFn procInfo; }; -/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`. +/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and +/// `step`. void updateBoundsForCyclicDistribution(OpBuilder &builder, Location loc, Value procId, Value nprocs, Value &lb, Value &ub, Value &step); @@ -362,15 +409,15 @@ class TileLoopNest { public: TileLoopNest(LinalgOp rootOp) : rootOp(rootOp) {} - /// Tile the root operation using the given `tileSizes` and `tileInterchange`, - /// and `tileDistribution`. + /// Tile the root operation using the given `tileSizes` and + /// `tileInterchange`, and `tileDistribution`. LogicalResult tileRootOp(OpBuilder &b, ArrayRef tileSizes, ArrayRef tileInterchange, std::optional tileDistribution); - /// Fuse the producer of `consumerOpOperand` into the tile loop nest. Returns - /// the fused producer or fails if fusion is not possible. + /// Fuse the producer of `consumerOpOperand` into the tile loop nest. + /// Returns the fused producer or fails if fusion is not possible. FailureOr fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand); /// Returns the replacement results for the original untiled root operation. @@ -426,8 +473,8 @@ struct RegionMatcher { IAdd, }; - /// Matches the given linalg op if its body is performing binary operation on - /// int or float scalar values and returns the binary op kind. + /// Matches the given linalg op if its body is performing binary operation + /// on int or float scalar values and returns the binary op kind. /// /// The linalg op's region is expected to be /// ``` @@ -445,9 +492,10 @@ struct RegionMatcher { //===----------------------------------------------------------------------===// /// Utility class used to generate nested loops with ranges described by -/// `loopRanges` and loop type described by the `iteratorTypes`. `bodyBuilderFn` -/// is used to generate the body of the innermost loop. It is passed a range -/// of loop induction variables and a range of operand values to use. +/// `loopRanges` and loop type described by the `iteratorTypes`. +/// `bodyBuilderFn` is used to generate the body of the innermost loop. It is +/// passed a range of loop induction variables and a range of operand values +/// to use. template struct GenerateLoopNest { static void doit(OpBuilder &b, Location loc, ArrayRef loopRanges, diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index e3c1429..6ee0f13 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -27,10 +27,6 @@ #include "mlir/Dialect/Transform/Utils/Utils.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/TilingInterface.h" @@ -38,9 +34,6 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" -#include "llvm/ADT/SetOperations.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" @@ -1300,94 +1293,21 @@ void transform::PackOp::getEffects( //===---------------------------------------------------------------------===// LogicalResult transform::PackGreedilyOp::verify() { - if (!isPermutationVector(getGemmInnerDimsOrder())) { - return emitOpError() << getGemmInnerDimsOrderAttrName() + if (!isPermutationVector(getMatmulInnerDimsOrder())) { + return emitOpError() << getMatmulInnerDimsOrderAttrName() << " is not a valid permutation"; } - // TODO: relax to allow empty once we have another strategy than just gemm. - if (getGemmInnerDimsOrder().size() != 3 || - getMixedGemmPackedSizes().size() != 3) { - return emitOpError() << " needs 3 entries for gemm_packed_sizes and " - << getGemmInnerDimsOrderAttrName() - << " order for the gemm strategy"; + // TODO: relax to allow empty once we have another strategy than just matmul. + if (getMatmulInnerDimsOrder().size() != 3 || + getMixedMatmulPackedSizes().size() != 3) { + return emitOpError() << " needs 3 entries for matmul_packed_sizes and " + << getMatmulInnerDimsOrderAttrName() + << " order for the matmul strategy"; } return success(); } -namespace { -auto par = utils::IteratorType::parallel; -auto red = utils::IteratorType::reduction; -} // namespace - -DenseSet transform::findPermutationsIndexingOperand( - LinalgOp linalgOp, OpOperand *opOperand, utils::IteratorType iter) { - DenseSet res; - assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner"); - AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); - for (AffineExpr e : indexingMap.getResults()) { - if (auto d = e.dyn_cast()) { - if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter && - llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) { - return e.isFunctionOfDim(d.getPosition()); - }) == 1) - res.insert(d.getPosition()); - } - } - return res; -} - -FailureOr transform::inferGemmDims(LinalgOp linalgOp) { - if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2) - return failure(); - - DenseSet a = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInputOperand(0), par); - DenseSet b = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInputOperand(1), par); - DenseSet c = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInitOperand(0), par); - - // A & C - B are the iterators involved in an outer-product along A (the LHS). - DenseSet ac = a; - llvm::set_intersect(ac, c); - llvm::set_subtract(ac, b); - // B & C - A are the iterators involved in an outer-product along B (the RHS). - DenseSet bc = b; - llvm::set_intersect(bc, c); - llvm::set_subtract(bc, a); - - // Note: if we ever need them, A & B & C would be "batch" dimensions. - - // A & B red are the reduction dimensions. - DenseSet ra = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInputOperand(0), red); - DenseSet rb = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInputOperand(1), red); - llvm::set_intersect(ra, rb); - - if (ac.empty() || bc.empty() || ra.empty()) - return failure(); - - // Pick the first one in each set. - // TODO: Better heuristic (e.g pick dims based on packing-based metric). - return GemmDimsForPacking{ac, bc, ra}; -} - -bool transform::containsMostMinorGemm(LinalgOp linalgOp) { - FailureOr res = inferGemmDims(linalgOp); - if (failed(res)) - return false; - int64_t numLoops = linalgOp.getNumLoops(); - for (const DenseSet &s : {res->mPos, res->nPos, res->kPos}) { - if (s.contains(numLoops - 3) || s.contains(numLoops - 2) || - s.contains(numLoops - 1)) - continue; - return false; - } - return true; -} - -/// Pack a LinalgOp by greedily inferring gemm dimensions (m, n, k) where m +/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m /// and n are proper parallel dimensions and k is a proper reduction /// dimension. Packing occurs by rewriting the op as a linalg.generic and /// calling linalg::pack by `mnkPackedSizes`. The order of the packed @@ -1396,17 +1316,17 @@ bool transform::containsMostMinorGemm(LinalgOp linalgOp) { /// dimensions of the operands are not permuted at this time, this is left for /// future work. static FailureOr -packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp, - ArrayRef mnkPackedSizes, - ArrayRef mnkOrder) { +packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, + ArrayRef mnkPackedSizes, + ArrayRef mnkOrder) { assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes"); assert(mnkOrder.size() == 3 && "unexpected mnkOrder size"); assert(isPermutationVector(mnkOrder) && "expected a permutation"); int64_t numLoops = linalgOp.getNumLoops(); if (numLoops <= 2) { - return rewriter.notifyMatchFailure(linalgOp, - "need 3+ loops to find a gemm to pack"); + return rewriter.notifyMatchFailure( + linalgOp, "need 3+ loops to find a matmul to pack"); } // Locally adjust the desired iterator position of mnk and packing sizes. @@ -1418,11 +1338,11 @@ packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp, for (int64_t i = 0, e = numPackedDims; i < e; ++i) packedSizes[mnkOrder[i]] = mnkPackedSizes[i]; - // 1. Infer dims that are important for gemm. - FailureOr res = inferGemmDims(linalgOp); + // 1. Infer dims that are important for matmul. + FailureOr res = inferMatmulDims(linalgOp); if (failed(res)) { return rewriter.notifyMatchFailure(linalgOp, - "couldn't infer gemm iterators"); + "couldn't infer matmul iterators"); } // 2. Normalize linalgOp to an kmn-matmul-like with [red, par, par] most @@ -1479,8 +1399,8 @@ packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp, // TODO: If we wanted to give the genericOp a name after packing, after // calling `pack` would be a good time. auto packingRes = linalg::pack(rewriter, genericOp, adjustedPackedSizes); - assert(containsMostMinorGemm(packingRes->packedLinalgOp) && - "failed to pack to a most minor gemm"); + assert(containsMostMinorMatmul(packingRes->packedLinalgOp) && + "failed to pack to a most minor matmul"); return packingRes; } @@ -1500,11 +1420,11 @@ PackGreedilyOp::apply(transform::TransformResults &transformResults, rewriter.setInsertionPointAfter(linalgOp); // Failing to pack greedily is perfectly fine. // In the future we will want to order packings according to some metric. - FailureOr packResult = packGemmGreedily( + FailureOr packResult = packMatmulGreedily( /*rewriter=*/rewriter, /*linalgOp=*/linalgOp, - /*mnkPackedSizes=*/getMixedGemmPackedSizes(), - /*mnkOrder=*/getGemmInnerDimsOrder()); + /*mnkPackedSizes=*/getMixedMatmulPackedSizes(), + /*mnkOrder=*/getMatmulInnerDimsOrder()); if (succeeded(packResult)) { results.push_back(packResult->packedLinalgOp); continue; @@ -1515,15 +1435,16 @@ PackGreedilyOp::apply(transform::TransformResults &transformResults, return DiagnosedSilenceableFailure::success(); } -SmallVector PackGreedilyOp::getMixedGemmPackedSizes() { +SmallVector PackGreedilyOp::getMixedMatmulPackedSizes() { Builder b(getContext()); - return getMixedValues(getStaticGemmPackedSizes(), getGemmPackedSizes(), b); + return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(), + b); } void transform::PackGreedilyOp::getEffects( SmallVectorImpl &effects) { transform::consumesHandle(getTarget(), effects); - transform::onlyReadsHandle(getGemmPackedSizes(), effects); + transform::onlyReadsHandle(getMatmulPackedSizes(), effects); transform::producesHandle(getPackedOp(), effects); transform::modifiesPayload(effects); } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 75f818b..572b7e4 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -33,6 +33,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Pass/Pass.h" +#include "llvm/ADT/SetOperations.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include @@ -138,6 +139,88 @@ static void unpackRanges(OpBuilder &builder, Location loc, } } +//===----------------------------------------------------------------------===// +// Utilities for inferring various semantics properties of Linalg ops. +//===----------------------------------------------------------------------===// + +DenseSet mlir::linalg::findPermutationsIndexingOperand( + LinalgOp linalgOp, OpOperand *opOperand, utils::IteratorType iter) { + DenseSet res; + assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner"); + AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); + for (AffineExpr e : indexingMap.getResults()) { + if (auto d = e.dyn_cast()) { + if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter && + llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) { + return e.isFunctionOfDim(d.getPosition()); + }) == 1) + res.insert(d.getPosition()); + } + } + return res; +} + +namespace { +auto par = utils::IteratorType::parallel; +auto red = utils::IteratorType::reduction; +} // namespace + +bool mlir::linalg::containsMostMinorMatmul(LinalgOp linalgOp) { + FailureOr res = inferMatmulDims(linalgOp); + if (failed(res)) + return false; + int64_t numLoops = linalgOp.getNumLoops(); + for (const DenseSet &s : {res->mPos, res->nPos, res->kPos}) { + if (s.contains(numLoops - 3) || s.contains(numLoops - 2) || + s.contains(numLoops - 1)) + continue; + return false; + } + return true; +} + +FailureOr +mlir::linalg::inferMatmulDims(LinalgOp linalgOp) { + if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2) + return failure(); + + DenseSet a = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(0), par); + DenseSet b = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(1), par); + DenseSet c = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInitOperand(0), par); + + // A & C - B are the iterators involved in an outer-product along A (the LHS). + DenseSet ac = a; + llvm::set_intersect(ac, c); + llvm::set_subtract(ac, b); + // B & C - A are the iterators involved in an outer-product along B (the RHS). + DenseSet bc = b; + llvm::set_intersect(bc, c); + llvm::set_subtract(bc, a); + + // Note: if we ever need them, A & B & C would be "batch" dimensions. + + // A & B red are the reduction dimensions. + DenseSet ra = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(0), red); + DenseSet rb = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(1), red); + llvm::set_intersect(ra, rb); + + if (ac.empty() || bc.empty() || ra.empty()) + return failure(); + + // Pick the first one in each set. + // TODO: Better heuristic (e.g pick dims based on packing-based metric). + return EmbeddedMatmulDimsCandidates{ac, bc, ra}; +} + +//===----------------------------------------------------------------------===// +// General utilities +//===----------------------------------------------------------------------===// + namespace mlir { namespace linalg { diff --git a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir index 42f3a6c..544f439 100644 --- a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir +++ b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir @@ -25,7 +25,7 @@ transform.sequence failures(propagate) { %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.matmul"> transform.structured.pack_greedily %matmul - gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0] + matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0] : (!transform.op<"linalg.matmul">) -> !transform.op<"linalg.generic"> } @@ -70,7 +70,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !pdl.operation): %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic"> transform.structured.pack_greedily %generic - gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0] + matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0] : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> } @@ -115,7 +115,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !pdl.operation): %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic"> transform.structured.pack_greedily %generic - gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0] + matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0] : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> } @@ -160,7 +160,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !pdl.operation): %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic"> transform.structured.pack_greedily %generic - gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0] + matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0] : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> } @@ -195,7 +195,7 @@ transform.sequence failures(propagate) { %conv = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.conv_2d_nchw_fchw"> transform.structured.pack_greedily %conv - gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0] + matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0] : (!transform.op<"linalg.conv_2d_nchw_fchw">) -> !transform.op<"linalg.generic"> } @@ -223,6 +223,6 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !pdl.operation): %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic"> transform.structured.pack_greedily %generic - gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0] + matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0] : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> } -- 2.7.4