[mlir][Linalg] NFC - Move transform utilities related to subcomputation inference...
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Mon, 27 Mar 2023 10:48:21 +0000 (03:48 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Mon, 27 Mar 2023 10:48:51 +0000 (03:48 -0700)
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/test/Dialect/Linalg/transform-pack-greedily.mlir

index a15c3a3..6d7c802 100644 (file)
@@ -53,33 +53,6 @@ class DialectRegistry;
 
 namespace transform {
 
-/// Return the set of `linalgOp` iterator positions for which the indexing map
-/// for `opOperand` is a permutation (i.e. an AffineDimExpr).
-DenseSet<int64_t> findPermutationsIndexingOperand(linalg::LinalgOp linalgOp,
-                                                  OpOperand *opOperand,
-                                                  utils::IteratorType iter);
-
-/// Possible dimension candidates that define a gemm embedded in the indexing
-/// maps of a LinalgOp.
-struct GemmDimsForPacking {
-  DenseSet<int64_t> mPos, nPos, kPos;
-};
-
-/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
-/// a gemm subcomputation within `linalgOp`. These dimensions are such that:
-///   1. The m dimension is involved in an outer-product along LHS
-///      (i.e. it is a permutation on RES and LHS and does not appear in RHS).
-///   2. The n dimension is involved in an outer-product along RHS
-///      (i.e. it is a permutation on RES and RHS and does not appear in LHS).
-///   3. The k dimension appears as a permutation on LHS and RHS.
-///   4. m, n and k appear only once in any given indexing.
-/// This allows detecting that some gemm is embedded within `linalgOp` with some
-/// orthogonal heuristic.
-FailureOr<GemmDimsForPacking> inferGemmDims(linalg::LinalgOp linalgOp);
-
-/// Return true if `linalgOp` contains an embedded gemm subcomputation.
-bool containsMostMinorGemm(linalg::LinalgOp linalgOp);
-
 /// Implementation of tiling operations using `scf.forall`.
 DiagnosedSilenceableFailure tileToForallOpImpl(
     RewriterBase &rewriter, transform::TransformState &state,
index c58e955..e107911 100644 (file)
@@ -590,14 +590,14 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
 
     Different packing strategies are applied in order, when one applies
     successfully, the transform returns:
-      1. Gemm packing: Try to infer a gemm operation embedded in the target op.
+      1. Matmul packing: Try to infer a matmul operation embedded in the target op.
          Specifically, this looks for 2 parallel dimensions that participate in
          an outer-product and 1 reduction dimension.
-         These dimensions are referred as (m, n, k) to match canonical gemm
+         These dimensions are referred as (m, n, k) to match canonical matmul
          terminology.
-         The packed sizes for (m, n, k) are specified by `gemm_packed_sizes`.
+         The packed sizes for (m, n, k) are specified by `matmul_packed_sizes`.
          The ordering of the packed dimensions (mm, nn, kk) is specified by the
-         `gemm_inner_dims_order` attribute.
+         `matmul_inner_dims_order` attribute.
 
     Packing occurs as follows:
       1. Find the dimensions to pack according to the strategy.
@@ -624,25 +624,25 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
 
   // TODO: Transform_ConcreteOpType<linalg::LinalgOp> needs interface.
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   Variadic<PDL_Operation>:$gemm_packed_sizes,
+                   Variadic<PDL_Operation>:$matmul_packed_sizes,
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">
-                     :$static_gemm_packed_sizes,
+                     :$static_matmul_packed_sizes,
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">
-                     :$gemm_inner_dims_order);
+                     :$matmul_inner_dims_order);
   let results = (outs Transform_ConcreteOpType<"linalg.generic">:$packed_op);
 
   let builders = [
     OpBuilder<(ins "Value":$target,
-                   "ArrayRef<OpFoldResult>":$mixedGemmPackedSizes,
-                   CArg<"ArrayRef<int64_t>", "{}">:$gemmDimsInnerDimsOrder)>
+                   "ArrayRef<OpFoldResult>":$mixedMatmulPackedSizes,
+                   CArg<"ArrayRef<int64_t>", "{}">:$matmulDimsInnerDimsOrder)>
   ];
 
   let assemblyFormat = [{
     $target
     oilist(
-      `gemm_packed_sizes` `=` custom<DynamicIndexList>($gemm_packed_sizes,
-                                                       $static_gemm_packed_sizes)
-      `gemm_inner_dims_order` `=` $gemm_inner_dims_order
+      `matmul_packed_sizes` `=` custom<DynamicIndexList>($matmul_packed_sizes,
+                                                       $static_matmul_packed_sizes)
+      `matmul_inner_dims_order` `=` $matmul_inner_dims_order
     )
     attr-dict
     `:` functional-type($target, results)
@@ -652,7 +652,7 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
   let extraClassDeclaration = [{
     /// Returns the list of tile sizes, which may be static (Attribute) or
     /// dynamic (Value).
-    SmallVector<OpFoldResult> getMixedGemmPackedSizes();
+    SmallVector<OpFoldResult> getMixedMatmulPackedSizes();
   }];
 }
 
index 3c3fa70..4c23ceb 100644 (file)
@@ -28,6 +28,44 @@ class ExtractSliceOp;
 namespace linalg {
 
 //===----------------------------------------------------------------------===//
+// Utilities for inferring various semantics properties of Linalg ops.
+//===----------------------------------------------------------------------===//
+
+/// Possible dimension candidates that define a matmul embedded in the indexing
+/// maps of a LinalgOp.
+struct EmbeddedMatmulDimsCandidates {
+  DenseSet<int64_t> mPos, nPos, kPos;
+};
+
+/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
+/// iterators of type `iter` that index the `opOperand` as a permutation.
+/// This is useful to infer various subcomputations on a given `linalgOp`.
+/// This is performed by looking up each result in the matching indexing map and
+/// determining whether:
+///   - It is a single AffineDimExpr.
+///   - It is the only result involving this AffineDimExpr.
+DenseSet<int64_t> findPermutationsIndexingOperand(LinalgOp linalgOp,
+                                                  OpOperand *opOperand,
+                                                  utils::IteratorType iter);
+
+/// Return true if `linalgOp` contains an embedded matmul subcomputation in its
+/// most minor dimensions.
+bool containsMostMinorMatmul(linalg::LinalgOp linalgOp);
+
+/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
+/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
+///   1. The m dimension is involved in an outer-product along LHS
+///      (i.e. it is a permutation on RES and LHS and does not appear in RHS).
+///   2. The n dimension is involved in an outer-product along RHS
+///      (i.e. it is a permutation on RES and RHS and does not appear in LHS).
+///   3. The k dimension appears as a permutation on LHS and RHS.
+///   4. m, n and k appear only once in any given indexing.
+/// This allows detecting that some matmul is embedded within `linalgOp` with
+/// some orthogonal heuristic.
+FailureOr<EmbeddedMatmulDimsCandidates>
+inferMatmulDims(linalg::LinalgOp linalgOp);
+
+//===----------------------------------------------------------------------===//
 // General utilities
 //===----------------------------------------------------------------------===//
 
@@ -96,10 +134,10 @@ FailureOr<int64_t> getConstantUpperBoundForIndex(Value value);
 
 /// Create a tensor::PadOp that pads `source` to the size of the statically
 /// sized `type` whose static sizes are assumed to be greater than the dynamic
-/// `source` size. The padding introduces trailing `pad` values until the target
-/// size is met. If `source` is defined by one or more LinalgOps that have been
-/// padded with the same value and sizes, return their padded result instead of
-/// creating a tensor::PadOp.
+/// `source` size. The padding introduces trailing `pad` values until the
+/// target size is met. If `source` is defined by one or more LinalgOps that
+/// have been padded with the same value and sizes, return their padded result
+/// instead of creating a tensor::PadOp.
 ///
 /// Example:
 /// ```
@@ -116,8 +154,8 @@ FailureOr<int64_t> getConstantUpperBoundForIndex(Value value);
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                             Value source, Value pad, bool nofold);
 
-/// Returns a GenericOp that transposes `inputTensor` into `outputTensor` using
-/// `transposeVector` to permute the `inputTensor` dimensions.
+/// Returns a GenericOp that transposes `inputTensor` into `outputTensor`
+/// using `transposeVector` to permute the `inputTensor` dimensions.
 GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
                           Value outputTensor,
                           ArrayRef<int64_t> transposeVector);
@@ -127,12 +165,12 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
 /// or vectorize.
 GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
 
-/// Get the reassociation maps to fold the result of a extract_slice (or source
-/// of a insert_slice) operation with given offsets, and sizes to its
+/// Get the reassociation maps to fold the result of a extract_slice (or
+/// source of a insert_slice) operation with given offsets, and sizes to its
 /// rank-reduced version. This is only done for the cases where the size is 1
-/// and offset is 0. Strictly speaking the offset 0 is not required in general,
-/// but non-zero offsets are not handled by SPIR-V backend at this point (and
-/// potentially cannot be handled).
+/// and offset is 0. Strictly speaking the offset 0 is not required in
+/// general, but non-zero offsets are not handled by SPIR-V backend at this
+/// point (and potentially cannot be handled).
 std::optional<SmallVector<ReassociationIndices>>
 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
 
@@ -151,8 +189,9 @@ enum class LinalgTilingLoopType {
   ParallelLoops = 2
 };
 
-/// Computes tile offsets, given a list of loop `ivs` and `tileSizes`. In case a
-/// tile size is zero (i.e., no tiling), the corresponding offset is also zero.
+/// Computes tile offsets, given a list of loop `ivs` and `tileSizes`. In case
+/// a tile size is zero (i.e., no tiling), the corresponding offset is also
+/// zero.
 SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
                                              ArrayRef<OpFoldResult> ivs,
                                              ArrayRef<OpFoldResult> tileSizes);
@@ -166,15 +205,16 @@ SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
                                            ArrayRef<OpFoldResult> sizeBounds);
 
 /// Returns the list of tensor output types produced when the given structured
-/// operation `op` is applied to the given `operands`. Note that `operands` are
-/// not necessarily the actual operands of `op`.
+/// operation `op` is applied to the given `operands`. Note that `operands`
+/// are not necessarily the actual operands of `op`.
 SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands);
 
 /// Creates `insert_slice` ops that insert `results` back into larger tensors
-/// they were originally extracted from with `extract_slice` before being passed
-/// as `operands` to the given structured operation `op` or its clone. Note that
-/// `operands` are not necessarily the actual operands of `op`, the operation
-/// serves only as metadata container for operand types and positions.
+/// they were originally extracted from with `extract_slice` before being
+/// passed as `operands` to the given structured operation `op` or its clone.
+/// Note that `operands` are not necessarily the actual operands of `op`, the
+/// operation serves only as metadata container for operand types and
+/// positions.
 SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                     LinalgOp op, ValueRange operands,
                                     ValueRange results);
@@ -187,8 +227,8 @@ struct SliceParameters {
 };
 
 /// Computes SliceParameters for a single `valueToTile` assuming that its user
-/// is being tiled with the given loop bounds `lbs` and `ubs` and the tile sizes
-/// `tileSizes`.
+/// is being tiled with the given loop bounds `lbs` and `ubs` and the tile
+/// sizes `tileSizes`.
 ///
 /// `omitPartialTileCheck` controls whether to omit the partial/boundary tile
 /// condition check in cases where we statically know that it is unnecessary.
@@ -219,8 +259,8 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
 /// Creates an extract_slice/subview op for a single `valueToTile` with
 /// `builder`. This new operation extracts a tile of `valueToTile`, starting
 /// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
-/// controls whether to omit the partial/boundary tile condition check in cases
-/// where we statically know that it is unnecessary.
+/// controls whether to omit the partial/boundary tile condition check in
+/// cases where we statically know that it is unnecessary.
 Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                      ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                      ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
@@ -232,8 +272,8 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
 /// nest for tiling with the given induction variables `ivs` and tile sizes
 /// `tileSizes`. `sizeBounds` are the iteration space bounds for *all* the
 /// implicit loops in `linalgOp`. `omitPartialTileCheck` controls whether to
-/// omit the partial/boundary tile condition check in cases where we statically
-/// know that it is unnecessary.
+/// omit the partial/boundary tile condition check in cases where we
+/// statically know that it is unnecessary.
 ///
 /// Note that a constant zero in `tileSizes` means no tiling at that implicit
 /// loop. The number of non-zero values in `tileSizes` should be equal to the
@@ -254,8 +294,9 @@ void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
                    ArrayRef<OpFoldResult> offests);
 
 /// A struct containing the Linalg producer before and after fusion.
-/// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` op
-/// before the consumer Linalg op, until enough canonicalizations have applied.
+/// When operating on tensors, `fusedProducer` may feed into a `tensor.cast`
+/// op before the consumer Linalg op, until enough canonicalizations have
+/// applied.
 struct FusionInfo {
   LinalgOp originalProducer;
   LinalgOp fusedProducer;
@@ -285,19 +326,23 @@ FailureOr<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
 /// Scheme used to distribute loops to processors.
 enum class DistributionMethod {
   /// Cyclic distribution where no assumption is made about the dynamic
-  /// relationship between number of processors and number of iterations of the
+  /// relationship between number of processors and number of iterations of
+  /// the
   /// distributed loop. Distributes the following loop
   ///
   /// scf.parallel (%iv) = (%lb) to (%ub) step (%step)
   ///
   /// to
   ///
-  /// scf.parallel(%iv)= (%lb + %procId * %step) to (%ub) step (%step * %nprocs)
+  /// scf.parallel(%iv)= (%lb + %procId * %step) to (%ub) step (%step *
+  /// %nprocs)
   Cyclic = 0,
 
   /// Cyclic distribution where the number of processors can be assumed to be
-  /// more than or equal to the number of iterations of the distributed loop. In
-  /// such cases, a simple in-bounds check is enough (instead of materializing a
+  /// more than or equal to the number of iterations of the distributed loop.
+  /// In
+  /// such cases, a simple in-bounds check is enough (instead of materializing
+  /// a
   /// loop). Distributes the following loop
   ///
   /// scf.parallel (%iv) = (%lb) to (%ub) step (%step)
@@ -312,7 +357,8 @@ enum class DistributionMethod {
   CyclicNumProcsGeNumIters = 1,
 
   /// Cyclic distribution where the number of processors can be assumed to be
-  ///  equal to the number of iterations of the distributed loop. In such cases,
+  ///  equal to the number of iterations of the distributed loop. In such
+  ///  cases,
   ///  no bounds check is needed. Distributes the following loop
   ///
   /// scf.parallel (%iv) = (%lb) to (%ub) step (%step)
@@ -339,16 +385,17 @@ using ProcInfoCallBackFn = std::function<SmallVector<ProcInfo>(
 /// Options that allow distribution of loops generated in Linalg transforms to
 /// processors while generating the loops.
 struct LinalgLoopDistributionOptions {
-  /// Callback function that returns the Values for processor ID (`procId`), and
-  /// number of processors (`nprocs`) used to execute the parallel loops. The
-  /// number of `{procId, nprocs}` pairs returned must be equal to the number of
-  /// `parallelLoopRanges` passed into the callback. The `parallelLoopRanges`
-  /// are ranges of the outer parallel loops of the operation that
-  /// do have non-zero tile sizes specified.
+  /// Callback function that returns the Values for processor ID (`procId`),
+  /// and number of processors (`nprocs`) used to execute the parallel loops.
+  /// The number of `{procId, nprocs}` pairs returned must be equal to the
+  /// number of `parallelLoopRanges` passed into the callback. The
+  /// `parallelLoopRanges` are ranges of the outer parallel loops of the
+  /// operation that do have non-zero tile sizes specified.
   ProcInfoCallBackFn procInfo;
 };
 
-/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
+/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and
+/// `step`.
 void updateBoundsForCyclicDistribution(OpBuilder &builder, Location loc,
                                        Value procId, Value nprocs, Value &lb,
                                        Value &ub, Value &step);
@@ -362,15 +409,15 @@ class TileLoopNest {
 public:
   TileLoopNest(LinalgOp rootOp) : rootOp(rootOp) {}
 
-  /// Tile the root operation using the given `tileSizes` and `tileInterchange`,
-  /// and `tileDistribution`.
+  /// Tile the root operation using the given `tileSizes` and
+  /// `tileInterchange`, and `tileDistribution`.
   LogicalResult
   tileRootOp(OpBuilder &b, ArrayRef<int64_t> tileSizes,
              ArrayRef<int64_t> tileInterchange,
              std::optional<LinalgLoopDistributionOptions> tileDistribution);
 
-  /// Fuse the producer of `consumerOpOperand` into the tile loop nest. Returns
-  /// the fused producer or fails if fusion is not possible.
+  /// Fuse the producer of `consumerOpOperand` into the tile loop nest.
+  /// Returns the fused producer or fails if fusion is not possible.
   FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand);
 
   /// Returns the replacement results for the original untiled root operation.
@@ -426,8 +473,8 @@ struct RegionMatcher {
     IAdd,
   };
 
-  /// Matches the given linalg op if its body is performing binary operation on
-  /// int or float scalar values and returns the binary op kind.
+  /// Matches the given linalg op if its body is performing binary operation
+  /// on int or float scalar values and returns the binary op kind.
   ///
   /// The linalg op's region is expected to be
   /// ```
@@ -445,9 +492,10 @@ struct RegionMatcher {
 //===----------------------------------------------------------------------===//
 
 /// Utility class used to generate nested loops with ranges described by
-/// `loopRanges` and loop type described by the `iteratorTypes`. `bodyBuilderFn`
-/// is used to generate the body of the innermost loop. It is passed a range
-/// of loop induction variables and a range of operand values to use.
+/// `loopRanges` and loop type described by the `iteratorTypes`.
+/// `bodyBuilderFn` is used to generate the body of the innermost loop. It is
+/// passed a range of loop induction variables and a range of operand values
+/// to use.
 template <typename LoopTy>
 struct GenerateLoopNest {
   static void doit(OpBuilder &b, Location loc, ArrayRef<Range> loopRanges,
index e3c1429..6ee0f13 100644 (file)
 #include "mlir/Dialect/Transform/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/TilingInterface.h"
@@ -38,9 +34,6 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
-#include "llvm/ADT/SetOperations.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 
@@ -1300,94 +1293,21 @@ void transform::PackOp::getEffects(
 //===---------------------------------------------------------------------===//
 
 LogicalResult transform::PackGreedilyOp::verify() {
-  if (!isPermutationVector(getGemmInnerDimsOrder())) {
-    return emitOpError() << getGemmInnerDimsOrderAttrName()
+  if (!isPermutationVector(getMatmulInnerDimsOrder())) {
+    return emitOpError() << getMatmulInnerDimsOrderAttrName()
                          << " is not a valid permutation";
   }
-  // TODO: relax to allow empty once we have another strategy than just gemm.
-  if (getGemmInnerDimsOrder().size() != 3 ||
-      getMixedGemmPackedSizes().size() != 3) {
-    return emitOpError() << " needs 3 entries for gemm_packed_sizes and "
-                         << getGemmInnerDimsOrderAttrName()
-                         << " order for the gemm strategy";
+  // TODO: relax to allow empty once we have another strategy than just matmul.
+  if (getMatmulInnerDimsOrder().size() != 3 ||
+      getMixedMatmulPackedSizes().size() != 3) {
+    return emitOpError() << " needs 3 entries for matmul_packed_sizes and "
+                         << getMatmulInnerDimsOrderAttrName()
+                         << " order for the matmul strategy";
   }
   return success();
 }
 
-namespace {
-auto par = utils::IteratorType::parallel;
-auto red = utils::IteratorType::reduction;
-} // namespace
-
-DenseSet<int64_t> transform::findPermutationsIndexingOperand(
-    LinalgOp linalgOp, OpOperand *opOperand, utils::IteratorType iter) {
-  DenseSet<int64_t> res;
-  assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
-  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
-  for (AffineExpr e : indexingMap.getResults()) {
-    if (auto d = e.dyn_cast<AffineDimExpr>()) {
-      if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
-          llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
-            return e.isFunctionOfDim(d.getPosition());
-          }) == 1)
-        res.insert(d.getPosition());
-    }
-  }
-  return res;
-}
-
-FailureOr<GemmDimsForPacking> transform::inferGemmDims(LinalgOp linalgOp) {
-  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
-    return failure();
-
-  DenseSet<int64_t> a = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInputOperand(0), par);
-  DenseSet<int64_t> b = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInputOperand(1), par);
-  DenseSet<int64_t> c = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInitOperand(0), par);
-
-  // A & C - B are the iterators involved in an outer-product along A (the LHS).
-  DenseSet<int64_t> ac = a;
-  llvm::set_intersect(ac, c);
-  llvm::set_subtract(ac, b);
-  // B & C - A are the iterators involved in an outer-product along B (the RHS).
-  DenseSet<int64_t> bc = b;
-  llvm::set_intersect(bc, c);
-  llvm::set_subtract(bc, a);
-
-  // Note: if we ever need them, A & B & C would be "batch" dimensions.
-
-  // A & B red are the reduction dimensions.
-  DenseSet<int64_t> ra = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInputOperand(0), red);
-  DenseSet<int64_t> rb = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInputOperand(1), red);
-  llvm::set_intersect(ra, rb);
-
-  if (ac.empty() || bc.empty() || ra.empty())
-    return failure();
-
-  // Pick the first one in each set.
-  // TODO: Better heuristic (e.g pick dims based on packing-based metric).
-  return GemmDimsForPacking{ac, bc, ra};
-}
-
-bool transform::containsMostMinorGemm(LinalgOp linalgOp) {
-  FailureOr<GemmDimsForPacking> res = inferGemmDims(linalgOp);
-  if (failed(res))
-    return false;
-  int64_t numLoops = linalgOp.getNumLoops();
-  for (const DenseSet<int64_t> &s : {res->mPos, res->nPos, res->kPos}) {
-    if (s.contains(numLoops - 3) || s.contains(numLoops - 2) ||
-        s.contains(numLoops - 1))
-      continue;
-    return false;
-  }
-  return true;
-}
-
-/// Pack a LinalgOp by greedily inferring gemm dimensions (m, n, k) where m
+/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
 /// and n are proper parallel dimensions and k is a proper reduction
 /// dimension. Packing occurs by rewriting the op as a linalg.generic and
 /// calling linalg::pack by `mnkPackedSizes`. The order of the packed
@@ -1396,17 +1316,17 @@ bool transform::containsMostMinorGemm(LinalgOp linalgOp) {
 /// dimensions of the operands are not permuted at this time, this is left for
 /// future work.
 static FailureOr<PackResult>
-packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
-                 ArrayRef<OpFoldResult> mnkPackedSizes,
-                 ArrayRef<int64_t> mnkOrder) {
+packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
+                   ArrayRef<OpFoldResult> mnkPackedSizes,
+                   ArrayRef<int64_t> mnkOrder) {
   assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
   assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
   assert(isPermutationVector(mnkOrder) && "expected a permutation");
 
   int64_t numLoops = linalgOp.getNumLoops();
   if (numLoops <= 2) {
-    return rewriter.notifyMatchFailure(linalgOp,
-                                       "need 3+ loops to find a gemm to pack");
+    return rewriter.notifyMatchFailure(
+        linalgOp, "need 3+ loops to find a matmul to pack");
   }
 
   // Locally adjust the desired iterator position of mnk and packing sizes.
@@ -1418,11 +1338,11 @@ packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
   for (int64_t i = 0, e = numPackedDims; i < e; ++i)
     packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
 
-  // 1. Infer dims that are important for gemm.
-  FailureOr<GemmDimsForPacking> res = inferGemmDims(linalgOp);
+  // 1. Infer dims that are important for matmul.
+  FailureOr<EmbeddedMatmulDimsCandidates> res = inferMatmulDims(linalgOp);
   if (failed(res)) {
     return rewriter.notifyMatchFailure(linalgOp,
-                                       "couldn't infer gemm iterators");
+                                       "couldn't infer matmul iterators");
   }
 
   // 2. Normalize linalgOp to an kmn-matmul-like with [red, par, par] most
@@ -1479,8 +1399,8 @@ packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
   // TODO: If we wanted to give the genericOp a name after packing, after
   // calling `pack` would be a good time.
   auto packingRes = linalg::pack(rewriter, genericOp, adjustedPackedSizes);
-  assert(containsMostMinorGemm(packingRes->packedLinalgOp) &&
-         "failed to pack to a most minor gemm");
+  assert(containsMostMinorMatmul(packingRes->packedLinalgOp) &&
+         "failed to pack to a most minor matmul");
   return packingRes;
 }
 
@@ -1500,11 +1420,11 @@ PackGreedilyOp::apply(transform::TransformResults &transformResults,
     rewriter.setInsertionPointAfter(linalgOp);
     // Failing to pack greedily is perfectly fine.
     // In the future we will want to order packings according to some metric.
-    FailureOr<PackResult> packResult = packGemmGreedily(
+    FailureOr<PackResult> packResult = packMatmulGreedily(
         /*rewriter=*/rewriter,
         /*linalgOp=*/linalgOp,
-        /*mnkPackedSizes=*/getMixedGemmPackedSizes(),
-        /*mnkOrder=*/getGemmInnerDimsOrder());
+        /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
+        /*mnkOrder=*/getMatmulInnerDimsOrder());
     if (succeeded(packResult)) {
       results.push_back(packResult->packedLinalgOp);
       continue;
@@ -1515,15 +1435,16 @@ PackGreedilyOp::apply(transform::TransformResults &transformResults,
   return DiagnosedSilenceableFailure::success();
 }
 
-SmallVector<OpFoldResult> PackGreedilyOp::getMixedGemmPackedSizes() {
+SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
   Builder b(getContext());
-  return getMixedValues(getStaticGemmPackedSizes(), getGemmPackedSizes(), b);
+  return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
+                        b);
 }
 
 void transform::PackGreedilyOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   transform::consumesHandle(getTarget(), effects);
-  transform::onlyReadsHandle(getGemmPackedSizes(), effects);
+  transform::onlyReadsHandle(getMatmulPackedSizes(), effects);
   transform::producesHandle(getPackedOp(), effects);
   transform::modifiesPayload(effects);
 }
index 75f818b..572b7e4 100644 (file)
@@ -33,6 +33,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Pass/Pass.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
@@ -138,6 +139,88 @@ static void unpackRanges(OpBuilder &builder, Location loc,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Utilities for inferring various semantics properties of Linalg ops.
+//===----------------------------------------------------------------------===//
+
+DenseSet<int64_t> mlir::linalg::findPermutationsIndexingOperand(
+    LinalgOp linalgOp, OpOperand *opOperand, utils::IteratorType iter) {
+  DenseSet<int64_t> res;
+  assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
+  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
+  for (AffineExpr e : indexingMap.getResults()) {
+    if (auto d = e.dyn_cast<AffineDimExpr>()) {
+      if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
+          llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
+            return e.isFunctionOfDim(d.getPosition());
+          }) == 1)
+        res.insert(d.getPosition());
+    }
+  }
+  return res;
+}
+
+namespace {
+auto par = utils::IteratorType::parallel;
+auto red = utils::IteratorType::reduction;
+} // namespace
+
+bool mlir::linalg::containsMostMinorMatmul(LinalgOp linalgOp) {
+  FailureOr<EmbeddedMatmulDimsCandidates> res = inferMatmulDims(linalgOp);
+  if (failed(res))
+    return false;
+  int64_t numLoops = linalgOp.getNumLoops();
+  for (const DenseSet<int64_t> &s : {res->mPos, res->nPos, res->kPos}) {
+    if (s.contains(numLoops - 3) || s.contains(numLoops - 2) ||
+        s.contains(numLoops - 1))
+      continue;
+    return false;
+  }
+  return true;
+}
+
+FailureOr<EmbeddedMatmulDimsCandidates>
+mlir::linalg::inferMatmulDims(LinalgOp linalgOp) {
+  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
+    return failure();
+
+  DenseSet<int64_t> a = findPermutationsIndexingOperand(
+      linalgOp, linalgOp.getDpsInputOperand(0), par);
+  DenseSet<int64_t> b = findPermutationsIndexingOperand(
+      linalgOp, linalgOp.getDpsInputOperand(1), par);
+  DenseSet<int64_t> c = findPermutationsIndexingOperand(
+      linalgOp, linalgOp.getDpsInitOperand(0), par);
+
+  // A & C - B are the iterators involved in an outer-product along A (the LHS).
+  DenseSet<int64_t> ac = a;
+  llvm::set_intersect(ac, c);
+  llvm::set_subtract(ac, b);
+  // B & C - A are the iterators involved in an outer-product along B (the RHS).
+  DenseSet<int64_t> bc = b;
+  llvm::set_intersect(bc, c);
+  llvm::set_subtract(bc, a);
+
+  // Note: if we ever need them, A & B & C would be "batch" dimensions.
+
+  // A & B red are the reduction dimensions.
+  DenseSet<int64_t> ra = findPermutationsIndexingOperand(
+      linalgOp, linalgOp.getDpsInputOperand(0), red);
+  DenseSet<int64_t> rb = findPermutationsIndexingOperand(
+      linalgOp, linalgOp.getDpsInputOperand(1), red);
+  llvm::set_intersect(ra, rb);
+
+  if (ac.empty() || bc.empty() || ra.empty())
+    return failure();
+
+  // Pick the first one in each set.
+  // TODO: Better heuristic (e.g pick dims based on packing-based metric).
+  return EmbeddedMatmulDimsCandidates{ac, bc, ra};
+}
+
+//===----------------------------------------------------------------------===//
+// General utilities
+//===----------------------------------------------------------------------===//
+
 namespace mlir {
 namespace linalg {
 
index 42f3a6c..544f439 100644 (file)
@@ -25,7 +25,7 @@ transform.sequence failures(propagate) {
   %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op 
     : (!pdl.operation) -> !transform.op<"linalg.matmul">
   transform.structured.pack_greedily %matmul 
-      gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
+      matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
     : (!transform.op<"linalg.matmul">) -> !transform.op<"linalg.generic">
 }
 
@@ -70,7 +70,7 @@ transform.sequence failures(propagate) {
 ^bb1(%module_op: !pdl.operation):
   %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
   transform.structured.pack_greedily %generic
-      gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
+      matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
     : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
 }
 
@@ -115,7 +115,7 @@ transform.sequence failures(propagate) {
 ^bb1(%module_op: !pdl.operation):
   %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
   transform.structured.pack_greedily %generic
-      gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
+      matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
     : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
 }
 
@@ -160,7 +160,7 @@ transform.sequence failures(propagate) {
 ^bb1(%module_op: !pdl.operation):
   %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
   transform.structured.pack_greedily %generic
-      gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
+      matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
     : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
 }
 
@@ -195,7 +195,7 @@ transform.sequence failures(propagate) {
   %conv = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %module_op 
     : (!pdl.operation) -> !transform.op<"linalg.conv_2d_nchw_fchw">
   transform.structured.pack_greedily %conv
-      gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
+      matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
     : (!transform.op<"linalg.conv_2d_nchw_fchw">) -> !transform.op<"linalg.generic">
 }
 
@@ -223,6 +223,6 @@ transform.sequence failures(propagate) {
 ^bb1(%module_op: !pdl.operation):
   %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
   transform.structured.pack_greedily %generic
-      gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
+      matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
     : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
 }