[mlir][Linalg] Adding a greedy packing transform dialect op.
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Thu, 26 Jan 2023 20:30:54 +0000 (12:30 -0800)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Thu, 26 Jan 2023 22:07:17 +0000 (14:07 -0800)
This PR adds a `pack_greedily` transform operation that infers the packing for gemm
subcomputations embedded within in any LinalgOp and packs accordingly.
A normalization step guarantees that we get the innermost op dimensions in one of `8`
possible `(m, n, k)` orders, specified as a parameter, from which we can emit all
packed forms.

The current implementation takes an arbitrary LinalgOp and tries to pack it along
the specified dimensions with specified sizes and inner dim permutation.

This achieves a new level of normalization and generalization for any `n-D`
LinalgOp that contains a gemm embedded within it:
we will always see a predictable packed form for any of these ops.

Differential Revision: https://reviews.llvm.org/D142661

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-pack-greedily.mlir [new file with mode: 0644]

index 30506fa..bd2fe3b 100644 (file)
@@ -370,7 +370,7 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
 
 def PackOp : Op<Transform_Dialect, "structured.pack", [
                 DeclareOpInterfaceMethods<TransformOpInterface>,
-                DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,]> {
+                DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let description = [{
     Pack a LinalgOp by applying a data tiling transformation on the op and
     packing the operands according to the `packed_sizes` specification.
@@ -454,6 +454,84 @@ def PackOp : Op<Transform_Dialect, "structured.pack", [
 }
 
 //===----------------------------------------------------------------------===//
+// PackGreedilyOp
+//===----------------------------------------------------------------------===//
+def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
+                        DeclareOpInterfaceMethods<TransformOpInterface>,
+                        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let description = [{
+    Target a Linalg op and rewrite it into packed LinalgOp form by trying to
+    infer whether a known suboperation is embedded
+
+    Different packing strategies are applied in order, when one applies
+    successfully, the transform returns:
+      1. Gemm packing: Try to infer a gemm operation embedded in the target op.
+         Specifically, this looks for 2 parallel dimensions that participate in
+         an outer-product and 1 reduction dimension.
+         These dimensions are referred as (m, n, k) to match canonical gemm
+         terminology.
+         The packed sizes for (m, n, k) are specified by `gemm_packed_sizes`.
+         The ordering of the packed dimensions (mm, nn, kk) is specified by the
+         `gemm_inner_dims_order` attribute.
+
+    Packing occurs as follows:
+      1. Find the dimensions to pack according to the strategy.
+      2. The target is converted to linalg.generic form.
+      3. An interchange transform is applied to isolate the dimensions to pack as
+         the most minor indexing dimensions of the linalg.generic. The most minor
+         dimensions are themselves ordered according to `inner_dims_order`.
+      4. Packing is performed by `packed_sizes` and following `inner_dims_order`.
+
+    By normalizing the most minor dimensions to `inner_dims_order`, the transform
+    guarantees that packing immediates generates inner dimensions in a desirable
+    layout.
+
+    Outer dimension layout permutations are not controlled by this transform op
+    at the moment and can be obtained by composing with the pack_transpose
+    transformation.
+
+    #### Return modes
+
+    This operation ignores non-Linalg ops and drops them in the return.
+    It returns the list of packed Linalg ops or the original op when all available
+    packing strategies failed to apply.
+  }];
+
+  // TODO: Transform_ConcreteOpType<linalg::LinalgOp> needs interface.
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                   Variadic<PDL_Operation>:$gemm_packed_sizes,
+                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">
+                     :$static_gemm_packed_sizes,
+                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">
+                     :$gemm_inner_dims_order);
+  let results = (outs Transform_ConcreteOpType<"linalg.generic">:$packed_op);
+
+  let builders = [
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<OpFoldResult>":$mixedGemmPackedSizes,
+                   CArg<"ArrayRef<int64_t>", "{}">:$gemmDimsInnerDimsOrder)>
+  ];
+
+  let assemblyFormat = [{
+    $target
+    oilist(
+      `gemm_packed_sizes` `=` custom<DynamicIndexList>($gemm_packed_sizes,
+                                                       $static_gemm_packed_sizes)
+      `gemm_inner_dims_order` `=` $gemm_inner_dims_order
+    )
+    attr-dict
+    `:` functional-type($target, results)
+  }];
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns the list of tile sizes, which may be static (Attribute) or
+    /// dynamic (Value).
+    SmallVector<OpFoldResult> getMixedGemmPackedSizes();
+  }];
+}
+
+//===----------------------------------------------------------------------===//
 // PackTransposeOp
 //===----------------------------------------------------------------------===//
 def PackTransposeOp : Op<Transform_Dialect, "structured.pack_transpose", [
index 554328c..2060dd7 100644 (file)
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/Debug.h"
 
@@ -134,7 +136,7 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations(
 //===----------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure
-transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
+transform::DecomposeOp::applyToOne(LinalgOp target,
                                    transform::ApplyToEachResultList &results,
                                    transform::TransformState &state) {
 #define DOWNSCALE(trans)                                                       \
@@ -642,7 +644,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
 //===----------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure
-transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
+transform::GeneralizeOp::applyToOne(LinalgOp target,
                                     transform::ApplyToEachResultList &results,
                                     transform::TransformState &state) {
   // Exit early if no transformation is needed.
@@ -663,7 +665,7 @@ transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
 //===----------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure
-transform::InterchangeOp::applyToOne(linalg::GenericOp target,
+transform::InterchangeOp::applyToOne(GenericOp target,
                                      transform::ApplyToEachResultList &results,
                                      transform::TransformState &state) {
   ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
@@ -730,7 +732,7 @@ transform::MatchOp::apply(transform::TransformResults &results,
     if (getInterface().has_value()) {
       auto iface = getInterface().value();
       if (iface == transform::MatchInterfaceEnum::LinalgOp &&
-          !isa<linalg::LinalgOp>(op))
+          !isa<LinalgOp>(op))
         return;
       if (iface == transform::MatchInterfaceEnum::TilingInterface &&
           isa<TilingInterface>(op))
@@ -885,7 +887,7 @@ void transform::PackOp::build(OpBuilder &builder, OperationState &result,
   // attributes for multiple variadic operands. In the absence of this, horrible
   // bugs ensue.
   Type linalgOpHType = transform::OperationType::get(
-      builder.getContext(), linalg::GenericOp::getOperationName());
+      builder.getContext(), GenericOp::getOperationName());
   build(builder, result,
         /*resultType=*/linalgOpHType,
         /*target=*/target,
@@ -908,7 +910,7 @@ transform::PackOp::apply(transform::TransformResults &transformResults,
     return DiagnosedSilenceableFailure::success();
   }
   // Fail on multi-op handles.
-  auto linalgOp = dyn_cast<linalg::LinalgOp>(targetOps.front());
+  auto linalgOp = dyn_cast<LinalgOp>(targetOps.front());
   if (targetOps.size() != 1 || !linalgOp) {
     return emitSilenceableError()
            << "requires target to map to exactly 1 LinalgOp (got "
@@ -947,6 +949,268 @@ void transform::PackOp::getEffects(
 }
 
 //===---------------------------------------------------------------------===//
+// PackGreedilyOp.
+//===---------------------------------------------------------------------===//
+
+LogicalResult transform::PackGreedilyOp::verify() {
+  if (!isPermutationVector(getGemmInnerDimsOrder())) {
+    return emitOpError() << getGemmInnerDimsOrderAttrName()
+                         << " is not a valid permutation";
+  }
+  // TODO: relax to allow empty once we have another strategy than just gemm.
+  if (getGemmInnerDimsOrder().size() != 3 ||
+      getMixedGemmPackedSizes().size() != 3) {
+    return emitOpError() << " needs 3 entries for gemm_packed_sizes and "
+                         << getGemmInnerDimsOrderAttrName()
+                         << " order for the gemm strategy";
+  }
+  return success();
+}
+
+namespace {
+auto par = utils::IteratorType::parallel;
+auto red = utils::IteratorType::reduction;
+} // namespace
+
+/// Return the set of AffineDimExpr
+static DenseSet<int64_t>
+findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
+                                utils::IteratorType iter) {
+  DenseSet<int64_t> res;
+  assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
+  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
+  for (AffineExpr e : indexingMap.getResults()) {
+    if (auto d = e.dyn_cast<AffineDimExpr>()) {
+      if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
+          llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
+            return e.isFunctionOfDim(d.getPosition());
+          }) == 1)
+        res.insert(d.getPosition());
+    }
+  }
+  return res;
+}
+
+struct GemmDimsForPacking {
+  int64_t mPos, nPos, kPos;
+};
+/// Greedily look for 2 parallel (m and n) and 1 reduction (k) dimension that
+/// form a gemm. Such dimensions are such that:
+///   1. The m dimension is involved in an outer-product along LHS
+///      (i.e. it is a permutation on RES and LHS and does not appear in RHS).
+///   2. The n dimension is involved in an outer-product along RHS
+///      (i.e. it is a permutation on RES and RHS and does not appear in LHS).
+///   3. The k dimension appears as a permutation on LHS and RHS.
+///   4. m, n and k appear only once in any given indexing.
+///
+/// This allows detecting that some gemm is embedded within `linalgOp`.
+///
+/// When multiple possibilities for selecting m, n and k appear, we just pick
+/// an arbitrary one (i.e. the first in a DenseSet).
+// TODO: Better heuristic (e.g pick dims based on packing-based metric).
+static FailureOr<GemmDimsForPacking> getGemmDims(LinalgOp linalgOp) {
+  assert(linalgOp.getNumDpsInits() == 1 && "wrong number of dps inits");
+  assert(linalgOp.getNumDpsInputs() == 2 && "wrong number of dps inputs");
+
+  DenseSet<int64_t> a = findPermutationsIndexingOperand(
+      linalgOp, linalgOp.getDpsInputOperand(0), par);
+  DenseSet<int64_t> b = findPermutationsIndexingOperand(
+      linalgOp, linalgOp.getDpsInputOperand(1), par);
+  DenseSet<int64_t> c = findPermutationsIndexingOperand(
+      linalgOp, linalgOp.getDpsInitOperand(0), par);
+
+  // A & C - B are the iterators involved in an outer-product along A (the LHS).
+  DenseSet<int64_t> ac = a;
+  llvm::set_intersect(ac, c);
+  llvm::set_subtract(ac, b);
+  // B & C - A are the iterators involved in an outer-product along B (the RHS).
+  DenseSet<int64_t> bc = b;
+  llvm::set_intersect(bc, c);
+  llvm::set_subtract(bc, a);
+
+  // Note: if we ever need them, A & B & C would be "batch" dimensions.
+
+  // A & B red are the reduction dimensions.
+  DenseSet<int64_t> ra = findPermutationsIndexingOperand(
+      linalgOp, linalgOp.getDpsInputOperand(0), red);
+  DenseSet<int64_t> rb = findPermutationsIndexingOperand(
+      linalgOp, linalgOp.getDpsInputOperand(1), red);
+  llvm::set_intersect(ra, rb);
+
+  if (ac.empty() || bc.empty() || ra.empty())
+    return failure();
+
+  // Pick the first one in each set.
+  // TODO: Better heuristic (e.g pick dims based on packing-based metric).
+  return GemmDimsForPacking{*ac.begin(), *bc.begin(), *ra.begin()};
+}
+
+/// Return a permutation vector of size permSize that would result in moving
+/// positions into desiredPositions.
+///
+/// For example, permSize == 5, positions = {2, 4}, desiredPositions = {1, 0}
+/// would result in a {4, 2, 0, 1, 3} permutation vector.
+static SmallVector<int64_t>
+computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
+                         ArrayRef<int64_t> desiredPositions) {
+  SmallVector<int64_t> res(permSize, -1);
+  DenseSet<int64_t> seen;
+  for (auto [pos, desiredPos] : llvm::zip(positions, desiredPositions)) {
+    res[desiredPos] = pos;
+    seen.insert(pos);
+  }
+  int64_t nextPos = 0;
+  for (int64_t &entry : res) {
+    if (entry != -1)
+      continue;
+    while (seen.contains(nextPos))
+      ++nextPos;
+    entry = nextPos;
+    ++nextPos;
+  }
+  return res;
+}
+
+/// Pack a LinalgOp by greedily inferring gemm dimensions (m, n, k)
+/// where m and n are proper parallel dimensions and k is a proper reduction
+/// dimension.
+/// Packing occurs by rewriting the op as a linalg.generic and calling
+/// linalg::pack by `mnkPackedSizes`.
+/// The order of the packed dimensions is customizable: the `mnkOrder` is a
+/// permutation of {0, 1, 2} to reorder {m, n, k} into one of the 8 possible
+/// forms.
+/// The outer dimensions of the operands are not permuted at this time, this is
+/// left for future work.
+static FailureOr<LinalgOp>
+packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
+                 ArrayRef<OpFoldResult> mnkPackedSizes,
+                 ArrayRef<int64_t> mnkOrder) {
+  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
+  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
+  assert(isPermutationVector(mnkOrder) && "expected a permutation");
+
+  int64_t numLoops = linalgOp.getNumLoops();
+  if (numLoops <= 2) {
+    return rewriter.notifyMatchFailure(linalgOp,
+                                       "need 3+ loops to find a gemm to pack");
+  }
+
+  // Locally adjust the desired iterator position of mnk and packing sizes.
+  int64_t numPackedDims = mnkPackedSizes.size();
+  SmallVector<int64_t> mmnnkkPos(numPackedDims);
+  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
+    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
+  SmallVector<OpFoldResult> packedSizes(mnkPackedSizes.size());
+  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
+    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
+
+  // 1. Infer dims that are important for gemm.
+  FailureOr<GemmDimsForPacking> res = getGemmDims(linalgOp);
+  if (failed(res)) {
+    return rewriter.notifyMatchFailure(linalgOp,
+                                       "couldn't infer gemm iterators");
+  }
+
+  // 2. Normalize linalgOp to an kmn-matmul-like with [red, par, par] most
+  // minor iterators. If we wanted a different normalization order, this is
+  // where it would have to start.
+  int64_t mPos = res->mPos, nPos = res->nPos, kPos = res->kPos;
+  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
+             DBGS() << "Start packing generic op greedily with (m@" << mPos
+                    << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
+                    << "\n";);
+
+  // 2.a. Rewrite as a generic.
+  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
+  if (!genericOp) {
+    FailureOr<GenericOp> generalizeResult =
+        generalizeNamedOp(rewriter, linalgOp);
+    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
+    genericOp = *generalizeResult;
+  }
+
+  // 2.b. Interchange to move the dimensions (k, m, n) as most-minor iterators.
+  // Note that this only normalized the iteration order and does not change the
+  // indexings of any operand.
+  SmallVector<int64_t> permutation =
+      computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
+  LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
+  // Sign .. unsigned pollution.
+  SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
+  FailureOr<GenericOp> interchangeResult =
+      interchangeGenericOp(rewriter, genericOp, unsignedPerm);
+  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
+  genericOp = *interchangeResult;
+  LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);
+
+  // At this point, the op iterators are normalized to {leading, k, m, n}.
+  // The layouts induced by packing will always be:
+  //   - LHS{leading_lhs, kk, mm}
+  //   - RHS{leading_rhs, kk, nn}
+  //   - RES{leading_res, mm, nn}
+  // If we wanted to change the packed order, we would reorder (k, m, n) to
+  // something else above.
+  //
+  // Additional permutations of the outer dims of the operands (i.e.
+  // leading_lhs, leading_rhs and leading_res) could follow by computing the
+  // desired outerPerm for each operand.
+  // This is left for future work.
+
+  // Add leading zeros to match numLoops.
+  SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
+                                                rewriter.getIndexAttr(0));
+  llvm::append_range(adjustedPackedSizes, packedSizes);
+
+  // TODO: If we wanted to give the genericOp a name after packing, after
+  // calling `pack` would be a good time.
+  return linalg::pack(rewriter, genericOp, adjustedPackedSizes);
+}
+
+DiagnosedSilenceableFailure
+PackGreedilyOp::apply(transform::TransformResults &transformResults,
+                      transform::TransformState &state) {
+  ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
+
+  SmallVector<Operation *> results;
+  IRRewriter rewriter(getContext());
+  for (Operation *op : targetOps) {
+    auto linalgOp = dyn_cast<LinalgOp>(op);
+    if (!linalgOp)
+      continue;
+    // linalgOp will be replaced and the insertion point may be invalidated if
+    // we set it before -> set it after.
+    rewriter.setInsertionPointAfter(linalgOp);
+    // Failing to pack greedily is perfectly fine.
+    // In the future we will want to order packings according to some metric.
+    FailureOr<LinalgOp> gemm = packGemmGreedily(
+        /*rewriter=*/rewriter,
+        /*linalgOp=*/linalgOp,
+        /*mnkPackedSizes=*/getMixedGemmPackedSizes(),
+        /*mnkOrder=*/getGemmInnerDimsOrder());
+    if (succeeded(gemm)) {
+      results.push_back(*gemm);
+      continue;
+    }
+    results.push_back(linalgOp);
+  }
+  transformResults.set(getPackedOp().cast<OpResult>(), results);
+  return DiagnosedSilenceableFailure::success();
+}
+
+SmallVector<OpFoldResult> PackGreedilyOp::getMixedGemmPackedSizes() {
+  Builder b(getContext());
+  return getMixedValues(getStaticGemmPackedSizes(), getGemmPackedSizes(), b);
+}
+
+void transform::PackGreedilyOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::consumesHandle(getTarget(), effects);
+  transform::onlyReadsHandle(getGemmPackedSizes(), effects);
+  transform::producesHandle(getPackedOp(), effects);
+  transform::modifiesPayload(effects);
+}
+
+//===---------------------------------------------------------------------===//
 // PackTransposeOp
 //===---------------------------------------------------------------------===//
 
@@ -1030,7 +1294,7 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
     return emitSilenceableError() << "requires target to map to a "
                                      "tensor.pack or tensor.unpack";
   }
-  LinalgOp linalgOpTarget = dyn_cast<linalg::LinalgOp>(linalgOps.front());
+  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(linalgOps.front());
   if (!linalgOpTarget)
     return emitSilenceableError() << "requires a LinalgOp target";
 
@@ -1102,7 +1366,7 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
 //===---------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure
-transform::PadOp::applyToOne(linalg::LinalgOp target,
+transform::PadOp::applyToOne(LinalgOp target,
                              transform::ApplyToEachResultList &results,
                              transform::TransformState &state) {
   // Convert the integer packing flags to booleans.
@@ -1214,7 +1478,7 @@ LogicalResult transform::PadOp::verify() {
 //===----------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure
-transform::PromoteOp::applyToOne(linalg::LinalgOp target,
+transform::PromoteOp::applyToOne(LinalgOp target,
                                  transform::ApplyToEachResultList &results,
                                  transform::TransformState &state) {
   LinalgPromotionOptions promotionOptions;
@@ -1308,7 +1572,7 @@ LogicalResult transform::ReplaceOp::verify() {
 //===----------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure
-transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
+transform::ScalarizeOp::applyToOne(LinalgOp target,
                                    transform::ApplyToEachResultList &results,
                                    transform::TransformState &state) {
   scf::SCFTilingOptions tilingOptions;
@@ -1560,7 +1824,7 @@ void transform::SplitReductionOp::build(
 }
 
 DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
-    linalg::LinalgOp target, transform::ApplyToEachResultList &results,
+    LinalgOp target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   ControlSplitReductionFn splitFn = [&](LinalgOp) {
     return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
@@ -1605,7 +1869,7 @@ void transform::TileReductionUsingScfOp::build(
 }
 
 DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
-    linalg::LinalgOp target, transform::ApplyToEachResultList &results,
+    LinalgOp target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   TrivialPatternRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
@@ -1649,7 +1913,7 @@ void transform::TileReductionUsingForeachThreadOp::build(
 
 DiagnosedSilenceableFailure
 transform::TileReductionUsingForeachThreadOp::applyToOne(
-    linalg::LinalgOp target, transform::ApplyToEachResultList &results,
+    LinalgOp target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   TrivialPatternRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
diff --git a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
new file mode 100644 (file)
index 0000000..42f3a6c
--- /dev/null
@@ -0,0 +1,228 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+!A_mk = tensor<1023x255xf32>
+!B_kn = tensor<255x127xf32>
+!C_mn = tensor<1023x127xf32>
+
+// Normalized dims are:                     ( k,  m,  n)(kk, mm, nn)
+// CHECK-DAG: #[[$mk_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
+// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+
+// CHECK-LABEL: @matmul_mk_kn_mn(
+func.func @matmul_mk_kn_mn(%A : !A_mk, %B : !B_kn, %C : !C_mn) -> !C_mn {
+  //      CHECK: linalg.generic
+  // CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]]
+  // CHECK-SAME:   ["reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]} 
+  // CHECK-SAME:   ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<8x8x32x16xf32>)
+  // CHECK-SAME:  outs(%{{.*}} : tensor<128x8x8x16xf32>)
+  %0 = linalg.matmul ins(%A, %B : !A_mk, !B_kn) outs(%C : !C_mn) -> !C_mn
+  return %0 : !C_mn
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op 
+    : (!pdl.operation) -> !transform.op<"linalg.matmul">
+  transform.structured.pack_greedily %matmul 
+      gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
+    : (!transform.op<"linalg.matmul">) -> !transform.op<"linalg.generic">
+}
+
+// -----
+
+!A_mk = tensor<1023x255xf32>
+!B_nk = tensor<127x255xf32>
+!C_nm = tensor<127x1023xf32>
+
+#mkn_accesses = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#mkn_trait = {
+  indexing_maps = #mkn_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// Normalized dims are:                     ( k,  m,  n)(kk, mm, nn)
+// CHECK-DAG: #[[$km_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
+// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
+// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
+
+// CHECK-LABEL: @matmul_mk_nk_nm(
+func.func @matmul_mk_nk_nm(%A : !A_mk, %B : !B_nk, %C : !C_nm) -> !C_nm {
+  //      CHECK: linalg.generic
+  // CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]]
+  // CHECK-SAME:   ["reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]} 
+  // CHECK-SAME:   ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<8x8x32x16xf32>)
+  // CHECK-SAME:  outs(%{{.*}} : tensor<8x128x8x16xf32>)
+  %0 = linalg.generic #mkn_trait ins(%A, %B : !A_mk, !B_nk) outs(%C : !C_nm) {
+    ^bb0(%a: f32, %b: f32, %c: f32):
+      %d = arith.mulf %a, %b : f32
+      %e = arith.addf %c, %d : f32
+      linalg.yield %e : f32
+  } -> !C_nm
+  return %0 : !C_nm
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
+  transform.structured.pack_greedily %generic
+      gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
+    : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
+}
+
+// -----
+
+!A_mk = tensor<1023x255xf32>
+!B_nk = tensor<127x255xf32>
+!C_nm = tensor<127x1023xf32>
+
+#mkn_accesses = [
+  affine_map<(k, m, n) -> (m, k)>,
+  affine_map<(k, m, n) -> (n, k)>,
+  affine_map<(k, m, n) -> (n, m)>
+]
+#mkn_trait = {
+  indexing_maps = #mkn_accesses,
+  iterator_types = ["reduction", "parallel", "parallel"]
+}
+
+// Normalized dims are:                     ( k,  m,  n)(kk, mm, nn)
+// CHECK-DAG: #[[$mk_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
+// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
+// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
+
+// CHECK-LABEL: @matmul_mk_nk_nm_transposed(
+func.func @matmul_mk_nk_nm_transposed(%A : !A_mk, %B : !B_nk, %C : !C_nm) -> !C_nm {
+  //      CHECK: linalg.generic
+  // CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]]
+  // CHECK-SAME:   ["reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]} 
+  // CHECK-SAME:   ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<8x8x32x16xf32>)
+  // CHECK-SAME:  outs(%{{.*}} : tensor<8x128x8x16xf32>)
+  %0 = linalg.generic #mkn_trait ins(%A, %B : !A_mk, !B_nk) outs(%C : !C_nm) {
+    ^bb0(%a: f32, %b: f32, %c: f32):
+      %d = arith.mulf %a, %b : f32
+      %e = arith.addf %c, %d : f32
+      linalg.yield %e : f32
+  } -> !C_nm
+  return %0 : !C_nm
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
+  transform.structured.pack_greedily %generic
+      gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
+    : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
+}
+
+// -----
+
+!A_bmkm2 = tensor<42x1023x255x33xf32>
+!B_nkb = tensor<127x255x42xf32>
+!C_nbm = tensor<127x42x1023xf32>
+
+#mkn_accesses = [
+  affine_map<(k, m, n, b, m2) -> (b, m, k, m2)>,
+  affine_map<(k, m, n, b, m2) -> (n, k, b)>,
+  affine_map<(k, m, n, b, m2) -> (n, b, m)>
+]
+#mkn_trait = {
+  indexing_maps = #mkn_accesses,
+  iterator_types = ["reduction", "parallel", "parallel", "parallel", "parallel"]
+}
+
+// Normalized dims are:                        ( ?,  ?,  k,  m,  n)(kk, mm, nn)
+// CHECK-DAG: #[[$bmkm2_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d3, d2, d1, d5, d6)>
+// CHECK-DAG:   #[[$nkb_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d2, d0, d5, d7)>
+// CHECK-DAG:   #[[$nbm_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d0, d3, d6, d7)>
+
+// CHECK-LABEL: @contraction_bmkm2_nkb_nbm(
+func.func @contraction_bmkm2_nkb_nbm(%A : !A_bmkm2, %B : !B_nkb, %C : !C_nbm) -> !C_nbm {
+  //      CHECK: linalg.generic
+  // CHECK-SAME: indexing_maps = [#[[$bmkm2_kkmm]], #[[$nkb_kknn]], #[[$nbm_mmnn]]]
+  // CHECK-SAME:   ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]} 
+  // CHECK-SAME:   ins(%{{.*}} : tensor<42x128x8x33x32x8xf32>, tensor<8x8x42x32x16xf32>)
+  // CHECK-SAME:  outs(%{{.*}} : tensor<8x42x128x8x16xf32>)
+  %0 = linalg.generic #mkn_trait ins(%A, %B : !A_bmkm2, !B_nkb) outs(%C : !C_nbm) {
+    ^bb0(%a: f32, %b: f32, %c: f32):
+      %d = arith.mulf %a, %b : f32
+      %e = arith.addf %c, %d : f32
+      linalg.yield %e : f32
+  } -> !C_nbm
+  return %0 : !C_nbm
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
+  transform.structured.pack_greedily %generic
+      gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
+    : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
+}
+
+// -----
+
+// Conv linguo:                          h   w  kh  kw   c   n   f  cc  nn  ff
+// Normalized dims are:                ( ?,  ?,  ?,  ?,  k,  m,  n)(kk, mm, nn)
+//                                                                                   n   c   h + kh   w + kw  cc  nn
+// CHECK-DAG: #[[$M1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d5, d4, d0 + d2, d1 + d3, d7, d8)>
+//                                                                                   f   c  kh  kw  cc  ff
+// CHECK-DAG: #[[$M2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d6, d4, d2, d3, d7, d9)>
+//                                                                                   n   f   h   w  nn  ff
+// CHECK-DAG: #[[$M3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d5, d6, d0, d1, d8, d9)>
+
+// CHECK-LABEL: @conv_2d_nchw_fchw
+func.func @conv_2d_nchw_fchw(%arg0: tensor<?x47x16x16xf32>, %arg2: tensor<?x16x14x14xf32>) -> tensor<?x16x14x14xf32> {
+  %c0 = arith.constant dense<0.1> : tensor<16x47x3x3xf32>
+  //      CHECK: linalg.generic
+  // CHECK-SAME: indexing_maps = [#[[$M1]], #[[$M2]], #[[$M3]]]
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "reduction", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]
+  // CHECK-SAME:  ins(%{{.*}} : tensor<?x2x16x16x32x8xf32>, tensor<1x2x3x3x32x16xf32>)
+  // CHECK-SAME: outs(%{{.*}} : tensor<?x1x14x14x8x16xf32>)
+  %0 = linalg.conv_2d_nchw_fchw
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+     ins(%arg0, %c0: tensor<?x47x16x16xf32>, tensor<16x47x3x3xf32>)
+    outs(%arg2: tensor<?x16x14x14xf32>) -> tensor<?x16x14x14xf32>
+  return %0 : tensor<?x16x14x14xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %conv = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %module_op 
+    : (!pdl.operation) -> !transform.op<"linalg.conv_2d_nchw_fchw">
+  transform.structured.pack_greedily %conv
+      gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
+    : (!transform.op<"linalg.conv_2d_nchw_fchw">) -> !transform.op<"linalg.generic">
+}
+
+
+// -----
+
+// These should fail to pack for now as they don't contain a contraction.
+// CHECK-LABEL: @reduce_and_map
+func.func @reduce_and_map(%arg0: tensor<10x100xf32>,
+    %arg1: tensor<10x100xf32>, %output: tensor<10xf32>) -> tensor<10xf32> {
+  %map_init = tensor.empty() : tensor<10x100xf32>
+  // CHECK: linalg.map
+  %mapped = linalg.map { arith.addf }
+              ins(%arg0, %arg1 : tensor<10x100xf32>, tensor<10x100xf32>)
+              outs(%map_init : tensor<10x100xf32>)
+  // CHECK: linalg.reduce
+  %res = linalg.reduce { arith.addf }
+           ins(%mapped: tensor<10x100xf32>)
+           outs(%output: tensor<10xf32>)
+           dimensions = [1]
+  return %res : tensor<10xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
+  transform.structured.pack_greedily %generic
+      gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
+    : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
+}