[mlir][Linalg] Refactor isaContractionOpInterface and surrounding utils
authorNicolas Vasilache <nicolasvasilache@users.noreply.github.com>
Tue, 27 Jun 2023 19:27:47 +0000 (19:27 +0000)
committerNicolas Vasilache <nicolasvasilache@users.noreply.github.com>
Wed, 28 Jun 2023 10:19:00 +0000 (10:19 +0000)
This is almost NFC except for the fact that:
- when multiple candidates are available we now return them in sorted order vs undetermined order previously
- the type of the transform return is relaxed an a test is added for the case where the transform does not apply

Differential Revision: https://reviews.llvm.org/D153941

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
mlir/test/Dialect/Linalg/transform-pack-greedily.mlir

index cb93e8a..0562f37 100644 (file)
@@ -36,13 +36,61 @@ bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp,
                                 ArrayRef<OpOperand *> droppedOperands);
 } // namespace detail
 
+/// Positions of a Linalg op loops that correspond to different kinds of a
+/// contraction dimension.
+struct ContractionDimensions {
+  SmallVector<unsigned, 2> batch;
+  SmallVector<unsigned, 2> m;
+  SmallVector<unsigned, 2> n;
+  SmallVector<unsigned, 2> k;
+};
+
+/// Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates
+/// that form a matmul subcomputation within `linalgOp`.
+/// These dimensions are such that:
+///   1. The m dimension is involved in an outer-product along LHS
+///      (i.e. it is a permutation on RES and LHS and does not appear in RHS).
+///   2. The n dimension is involved in an outer-product along RHS
+///      (i.e. it is a permutation on RES and RHS and does not appear in LHS).
+///   3. The k dimension appears as a permutation on LHS and RHS.
+///   4. m, n and k appear only once in any given indexing.
+///   5. Optional batch dimensions that appear in all operands are captured.
+/// This allows e.g. detecting that some contraction is embedded within
+/// `linalgOp` with some orthogonal heuristic.
+/// When multiple dimension occurrences exist that match `batch`, `m`, `n`, or
+/// `k`, indices are returned in sorted order.
+/// Returns a failure if any of `m`, `n` or `k` is empty.
+FailureOr<ContractionDimensions> inferContractionDims(LinalgOp linalgOp);
+
 /// Checks whether `linalgOp` conforms to ContractionOpInterface.
 // TODO: embed within `isa<ContractionOpInterface>` if possible / natural.
 bool isaContractionOpInterface(LinalgOp linalgOp);
 
+/// Checks whether `linalgOp` conforms to ConvolutionOpInterface.
+// TODO: embed within `isa<ConvolutionOpInterface>` if possible / natural.
+bool isaConvolutionOpInterface(LinalgOp linalgOp);
+
 namespace detail {
 
 /// Result of matching a Linalg generic against the predicates of it being a
+/// contractiom.
+enum class MatchContractionResult;
+
+/// Checks whether `op` conforms to ContractionOpInterface and populates
+/// `dimensions` with indexes of the different kinds of dimensions when
+/// present.
+// TODO: Extract a standalone `inferConvolutionDims` that can also detect
+// whether a conv pattern exists within a bigger linalg op (see
+// inferContractionDims).
+MatchContractionResult
+isContractionInterfaceImpl(Operation *op,
+                           ContractionDimensions *dimensions = nullptr);
+
+/// Returns the error message corresponding to the contraction checking return
+/// code.
+StringRef getMatchContractionMessage(MatchContractionResult res);
+
+/// Result of matching a Linalg generic against the predicates of it being a
 /// convolution.
 enum class MatchConvolutionResult;
 
@@ -58,7 +106,8 @@ struct ConvolutionDimensions {
 };
 
 /// Checks whether `op` conforms to ConvolutionOpInterface and populates
-/// `dimensions` with indexes of the different kinds of dimensions when present.
+/// `dimensions` with indexes of the different kinds of dimensions when
+/// present.
 MatchConvolutionResult
 isConvolutionInterfaceImpl(Operation *op,
                            ConvolutionDimensions *dimensions = nullptr);
index 3943e62..bccdeaa 100644 (file)
@@ -722,7 +722,7 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
                                  :$matmul_padded_sizes_next_multiple_of,
                    ConfinedAttr<DefaultValuedAttr<DenseI64ArrayAttr, "{}">,
                                  [DenseArrayCount<3>]>:$matmul_inner_dims_order);
-  let results = (outs Transform_ConcreteOpType<"linalg.generic">:$packed_op);
+  let results = (outs TransformHandleTypeInterface:$packed_op);
 
   let builders = [
     OpBuilder<(ins "Value":$target,
index ccd650c..a4c8bae 100644 (file)
@@ -35,41 +35,6 @@ namespace linalg {
 // Utilities for inferring various semantics properties of Linalg ops.
 //===----------------------------------------------------------------------===//
 
-/// Possible dimension candidates that define a contraction embedded in the
-/// indexing maps of a LinalgOp.
-struct EmbeddedContractionDimsCandidates {
-  DenseSet<int64_t> batchPos, mPos, nPos, kPos;
-};
-
-/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
-/// iterators of type `iter` that index the `opOperand` as a permutation.
-/// This is useful to infer various subcomputations on a given `linalgOp`.
-/// This is performed by looking up each result in the matching indexing map and
-/// determining whether:
-///   - It is a single AffineDimExpr.
-///   - It is the only result involving this AffineDimExpr.
-DenseSet<int64_t> findPermutationsIndexingOperand(LinalgOp linalgOp,
-                                                  OpOperand *opOperand,
-                                                  utils::IteratorType iter);
-
-/// Return true if `linalgOp` contains an embedded matmul subcomputation in its
-/// most minor dimensions.
-bool containsMostMinorMatmul(linalg::LinalgOp linalgOp);
-
-/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
-/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
-///   1. The m dimension is involved in an outer-product along LHS
-///      (i.e. it is a permutation on RES and LHS and does not appear in RHS).
-///   2. The n dimension is involved in an outer-product along RHS
-///      (i.e. it is a permutation on RES and RHS and does not appear in LHS).
-///   3. The k dimension appears as a permutation on LHS and RHS.
-///   4. m, n and k appear only once in any given indexing.
-///   5. Optional batch dimensions that appear in all operands are captured.
-/// This allows e.g. detecting that some contraction is embedded within
-/// `linalgOp` with some orthogonal heuristic.
-FailureOr<EmbeddedContractionDimsCandidates>
-inferContractionDims(linalg::LinalgOp linalgOp);
-
 //===----------------------------------------------------------------------===//
 // General utilities
 //===----------------------------------------------------------------------===//
index 1372043..e928cc7 100644 (file)
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include <algorithm>
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -112,6 +115,96 @@ static bool isAddMul(Block &block) {
   return success;
 }
 
+/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
+/// iterators of type `iter` that index the `opOperand` as a permutation.
+/// This is useful to infer various subcomputations on a given `linalgOp`.
+/// This is performed by looking up each result in the matching indexing map and
+/// determining whether:
+///   - It is a single AffineDimExpr.
+///   - It is the only result involving this AffineDimExpr.
+static DenseSet<int64_t>
+findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
+                                utils::IteratorType iter) {
+  DenseSet<int64_t> res;
+  assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
+  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
+  for (AffineExpr e : indexingMap.getResults()) {
+    if (auto d = e.dyn_cast<AffineDimExpr>()) {
+      if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
+          llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
+            return e.isFunctionOfDim(d.getPosition());
+          }) == 1)
+        res.insert(d.getPosition());
+    }
+  }
+  return res;
+}
+
+namespace {
+auto par = utils::IteratorType::parallel;
+auto red = utils::IteratorType::reduction;
+} // namespace
+
+/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
+/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
+///   1. The m dimension is involved in an outer-product along LHS
+///      (i.e. it is a permutation on RES and LHS and does not appear in RHS).
+///   2. The n dimension is involved in an outer-product along RHS
+///      (i.e. it is a permutation on RES and RHS and does not appear in LHS).
+///   3. The k dimension appears as a permutation on LHS and RHS.
+///   4. m, n and k appear only once in any given indexing.
+///   5. Optional batch dimensions that appear in all operands are captured.
+/// This allows e.g. detecting that some contraction is embedded within
+/// `linalgOp` with some orthogonal heuristic.
+FailureOr<ContractionDimensions>
+mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
+  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
+    return failure();
+
+  DenseSet<int64_t> a = findPermutationsIndexingOperand(
+      linalgOp, linalgOp.getDpsInputOperand(0), par);
+  DenseSet<int64_t> b = findPermutationsIndexingOperand(
+      linalgOp, linalgOp.getDpsInputOperand(1), par);
+  DenseSet<int64_t> c = findPermutationsIndexingOperand(
+      linalgOp, linalgOp.getDpsInitOperand(0), par);
+
+  // A & C - B are the iterators involved in an outer-product along A (the LHS).
+  DenseSet<int64_t> ac = a;
+  llvm::set_intersect(ac, c);
+  llvm::set_subtract(ac, b);
+  // B & C - A are the iterators involved in an outer-product along B (the RHS).
+  DenseSet<int64_t> bc = b;
+  llvm::set_intersect(bc, c);
+  llvm::set_subtract(bc, a);
+  // A & B & C are the "batch" dimensions.
+  DenseSet<int64_t> batches = a;
+  llvm::set_intersect(batches, b);
+  llvm::set_intersect(batches, c);
+
+  // A & B red are the reduction dimensions.
+  DenseSet<int64_t> ra = findPermutationsIndexingOperand(
+      linalgOp, linalgOp.getDpsInputOperand(0), red);
+  DenseSet<int64_t> rb = findPermutationsIndexingOperand(
+      linalgOp, linalgOp.getDpsInputOperand(1), red);
+  llvm::set_intersect(ra, rb);
+
+  if (ac.empty() || bc.empty() || ra.empty())
+    return failure();
+
+  // Return each set in sorted order.
+  ContractionDimensions dimensions{
+      SmallVector<unsigned, 2>(batches.begin(), batches.end()),
+      SmallVector<unsigned, 2>(ac.begin(), ac.end()),
+      SmallVector<unsigned, 2>(bc.begin(), bc.end()),
+      SmallVector<unsigned, 2>(ra.begin(), ra.end())};
+  std::sort(dimensions.batch.begin(), dimensions.batch.end());
+  std::sort(dimensions.m.begin(), dimensions.m.end());
+  std::sort(dimensions.n.begin(), dimensions.n.end());
+  std::sort(dimensions.k.begin(), dimensions.k.end());
+  return dimensions;
+}
+
+namespace mlir::linalg::detail {
 enum class MatchContractionResult {
   Success = 0,
   NotLinalgOp,
@@ -120,7 +213,11 @@ enum class MatchContractionResult {
   NotProjectedPermutations,
   NotAddMul
 };
-static MatchContractionResult isContractionInterfaceImpl(Operation *op) {
+} // namespace mlir::linalg::detail
+
+mlir::linalg::detail::MatchContractionResult
+mlir::linalg::detail::isContractionInterfaceImpl(
+    Operation *op, mlir::linalg::ContractionDimensions *dimensions) {
   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
   if (!linalgOp)
     return MatchContractionResult::NotLinalgOp;
@@ -139,15 +236,41 @@ static MatchContractionResult isContractionInterfaceImpl(Operation *op) {
           linalgOp->getRegion(0).front()) &&
       !isAddMul<arith::OrIOp, arith::AndIOp>(linalgOp->getRegion(0).front()))
     return MatchContractionResult::NotAddMul;
+
+  if (dimensions) {
+    FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
+    assert(succeeded(res) && "unexpected failure to infer contraction dims");
+    *dimensions = *res;
+  }
   return MatchContractionResult::Success;
 }
 
+StringRef
+mlir::linalg::detail::getMatchContractionMessage(MatchContractionResult res) {
+  switch (res) {
+  case MatchContractionResult::NotLinalgOp:
+    return "expected a LinalgOp";
+  case MatchContractionResult::WrongNumOperands:
+    return "expected op with 2 inputs and 1 output";
+  case MatchContractionResult::NoReduction:
+    return "expected at least 1 reduction";
+  case MatchContractionResult::NotProjectedPermutations:
+    return "expected indexing maps to be projected permutations";
+  case MatchContractionResult::NotAddMul:
+    return "expected add/mul op in the body";
+  case MatchContractionResult::Success:
+    return "";
+  }
+  llvm_unreachable("unhandled MatchContractionResult case");
+}
+
 bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
   if (!linalgOp)
     return false;
   Operation *op = linalgOp.getOperation();
   return isa<ContractionOpInterface>(op) ||
-         (isContractionInterfaceImpl(op) == MatchContractionResult::Success);
+         (mlir::linalg::detail::isContractionInterfaceImpl(op) ==
+          mlir::linalg::detail::MatchContractionResult::Success);
 }
 
 /// Verify that a LinalgOp `op` is a contraction.
@@ -165,16 +288,8 @@ bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
 /// constant operations that do not involve the reduction dimension(s).
 LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
   auto res = isContractionInterfaceImpl(op);
-  if (res == MatchContractionResult::NotLinalgOp)
-    return op->emitError("expected a LinalgOp");
-  if (res == MatchContractionResult::WrongNumOperands)
-    return op->emitError("expected op with 2 inputs and 1 outputs");
-  if (res == MatchContractionResult::NoReduction)
-    return op->emitError("expected at least a reduction loop");
-  if (res == MatchContractionResult::NotProjectedPermutations)
-    return op->emitError("expected all indexings to be projected permutations");
-  if (res == MatchContractionResult::NotAddMul)
-    return op->emitError("(add, mul) operations not found");
+  if (res != MatchContractionResult::Success)
+    return op->emitError(getMatchContractionMessage(res));
   return success();
 }
 
@@ -454,6 +569,11 @@ mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
   llvm_unreachable("unhandled MatchConvolutionResult case");
 }
 
+bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp) {
+  return linalg::detail::isConvolutionInterfaceImpl(linalgOp.getOperation()) ==
+         linalg::detail::MatchConvolutionResult::Success;
+}
+
 LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
   MatchConvolutionResult res = isConvolutionInterfaceImpl(op);
   if (res != MatchConvolutionResult::Success)
index 49ed546..d702e6d 100644 (file)
@@ -49,6 +49,7 @@ using namespace mlir::transform;
 #define DEBUG_TYPE "linalg-transforms"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 #define DBGSNL() (llvm::dbgs() << "\n")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 /// Attempts to apply the pattern specified as template argument to the given
 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
@@ -1227,6 +1228,8 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
 
   int64_t numLoops = linalgOp.getNumLoops();
   if (numLoops <= 2) {
+    LDBG("need 3+ loops to find a matmul to pack, got "
+         << numLoops << "\nin: " << linalgOp << "\n");
     return rewriter.notifyMatchFailure(
         linalgOp, "need 3+ loops to find a matmul to pack");
   }
@@ -1247,17 +1250,21 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
   }
 
   // 1. Infer dims that are important for matmul.
-  FailureOr<EmbeddedContractionDimsCandidates> res = inferContractionDims(linalgOp);
-  if (failed(res)) {
+  FailureOr<ContractionDimensions> maybeDimensions =
+      inferContractionDims(linalgOp);
+  if (failed(maybeDimensions)) {
+    LDBG("couldn't infer matmul iterators in: " << linalgOp << "\n");
     return rewriter.notifyMatchFailure(linalgOp,
                                        "couldn't infer matmul iterators");
   }
 
   // 2. Normalize linalgOp to an kmn-matmul-like with [red, par, par] most
-  // minor iterators. If we wanted a different normalization order, this is
-  // where it would have to plug a heuristic.
-  int64_t mPos = *(res->mPos.begin()), nPos = *(res->nPos.begin()),
-          kPos = *(res->kPos.begin());
+  // minor iterators. In cases with multiple options for m, n, k bias towards
+  // the most minor embedding.
+  // If we wanted a different normalization order, this is where it would have
+  // to plug a heuristic.
+  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
+          kPos = maybeDimensions->k.back();
   LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
              DBGS() << "Start packing generic op greedily with (m@" << mPos
                     << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
@@ -2655,71 +2662,71 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
     ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
     linalg::ForallTilingResult &tilingResult) {
   // Transform all targets one by one.
-    auto tileableOp = dyn_cast<TilingInterface>(target);
-    if (!tileableOp) {
-      DiagnosedSilenceableFailure diag =
-          transformOp.emitSilenceableError()
-          << "only TilingInterface ops are supported";
-      diag.attachNote(target->getLoc()) << "target op";
-      return diag;
-    }
-    rewriter.setInsertionPoint(tileableOp);
-    FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
-    if (!mixedNumThreads.empty()) {
-      maybeTilingResult = linalg::tileToForallOp(rewriter, tileableOp,
-                                                 mixedNumThreads, mapping);
-    } else {
-      maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
-          rewriter, tileableOp, mixedTileSizes, mapping);
-    }
+  auto tileableOp = dyn_cast<TilingInterface>(target);
+  if (!tileableOp) {
+    DiagnosedSilenceableFailure diag =
+        transformOp.emitSilenceableError()
+        << "only TilingInterface ops are supported";
+    diag.attachNote(target->getLoc()) << "target op";
+    return diag;
+  }
+  rewriter.setInsertionPoint(tileableOp);
+  FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
+  if (!mixedNumThreads.empty()) {
+    maybeTilingResult =
+        linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
+  } else {
+    maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
+        rewriter, tileableOp, mixedTileSizes, mapping);
+  }
 
-    if (failed(maybeTilingResult))
-      return transformOp.emitDefaultSilenceableFailure(tileableOp);
-    rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
+  if (failed(maybeTilingResult))
+    return transformOp.emitDefaultSilenceableFailure(tileableOp);
+  rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
 
-    tilingResult = *maybeTilingResult;
-    return DiagnosedSilenceableFailure::success();
+  tilingResult = *maybeTilingResult;
+  return DiagnosedSilenceableFailure::success();
 }
 
 DiagnosedSilenceableFailure
 transform::TileToForallOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &transformResults,
                                  transform::TransformState &state) {
-    auto transformOp = cast<TransformOpInterface>(getOperation());
-
-    // Result payload ops.
-    SmallVector<Operation *> tileOps;
-    SmallVector<Operation *> tiledOps;
-
-    // Unpack handles.
-    SmallVector<OpFoldResult> mixedNumThreads;
-    DiagnosedSilenceableFailure status =
-        getPackedNumThreads()
-            ? unpackSingleIndexResultPayloadOperations(
-                  state, transformOp, mixedNumThreads, getPackedNumThreads())
-            : unpackSingleIndexResultPayloadOperations(
-                  state, transformOp, mixedNumThreads, getMixedNumThreads());
-    if (!status.succeeded())
-      return status;
-    SmallVector<OpFoldResult> mixedTileSizes;
-    status = getPackedTileSizes()
-                 ? unpackSingleIndexResultPayloadOperations(
-                       state, transformOp, mixedTileSizes, getPackedTileSizes())
-                 : unpackSingleIndexResultPayloadOperations(
-                       state, transformOp, mixedTileSizes, getMixedTileSizes());
-    if (!status.succeeded())
-      return status;
-
-    for (Operation *target : state.getPayloadOps(getTarget())) {
-      linalg::ForallTilingResult tilingResult;
-      DiagnosedSilenceableFailure diag = tileToForallOpImpl(
-          rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
-          getMapping(), tilingResult);
-      if (!diag.succeeded())
+  auto transformOp = cast<TransformOpInterface>(getOperation());
+
+  // Result payload ops.
+  SmallVector<Operation *> tileOps;
+  SmallVector<Operation *> tiledOps;
+
+  // Unpack handles.
+  SmallVector<OpFoldResult> mixedNumThreads;
+  DiagnosedSilenceableFailure status =
+      getPackedNumThreads()
+          ? unpackSingleIndexResultPayloadOperations(
+                state, transformOp, mixedNumThreads, getPackedNumThreads())
+          : unpackSingleIndexResultPayloadOperations(
+                state, transformOp, mixedNumThreads, getMixedNumThreads());
+  if (!status.succeeded())
+    return status;
+  SmallVector<OpFoldResult> mixedTileSizes;
+  status = getPackedTileSizes()
+               ? unpackSingleIndexResultPayloadOperations(
+                     state, transformOp, mixedTileSizes, getPackedTileSizes())
+               : unpackSingleIndexResultPayloadOperations(
+                     state, transformOp, mixedTileSizes, getMixedTileSizes());
+  if (!status.succeeded())
+    return status;
+
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    linalg::ForallTilingResult tilingResult;
+    DiagnosedSilenceableFailure diag = tileToForallOpImpl(
+        rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
+        getMapping(), tilingResult);
+    if (!diag.succeeded())
       return diag;
     tileOps.push_back(tilingResult.tileOp);
     tiledOps.push_back(tilingResult.tiledOp);
-    }
+  }
 
   transformResults.set(cast<OpResult>(getForallOp()), tileOps);
   transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
index e3a9569..55da709 100644 (file)
@@ -33,7 +33,6 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Pass/Pass.h"
-#include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
@@ -141,86 +140,6 @@ static void unpackRanges(OpBuilder &builder, Location loc,
 }
 
 //===----------------------------------------------------------------------===//
-// Utilities for inferring various semantics properties of Linalg ops.
-//===----------------------------------------------------------------------===//
-
-DenseSet<int64_t> mlir::linalg::findPermutationsIndexingOperand(
-    LinalgOp linalgOp, OpOperand *opOperand, utils::IteratorType iter) {
-  DenseSet<int64_t> res;
-  assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
-  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
-  for (AffineExpr e : indexingMap.getResults()) {
-    if (auto d = e.dyn_cast<AffineDimExpr>()) {
-      if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
-          llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
-            return e.isFunctionOfDim(d.getPosition());
-          }) == 1)
-        res.insert(d.getPosition());
-    }
-  }
-  return res;
-}
-
-namespace {
-auto par = utils::IteratorType::parallel;
-auto red = utils::IteratorType::reduction;
-} // namespace
-
-bool mlir::linalg::containsMostMinorMatmul(LinalgOp linalgOp) {
-  FailureOr<EmbeddedContractionDimsCandidates> res = inferContractionDims(linalgOp);
-  if (failed(res))
-    return false;
-  int64_t numLoops = linalgOp.getNumLoops();
-  for (const DenseSet<int64_t> &s : {res->mPos, res->nPos, res->kPos}) {
-    if (s.contains(numLoops - 3) || s.contains(numLoops - 2) ||
-        s.contains(numLoops - 1))
-      continue;
-    return false;
-  }
-  return true;
-}
-
-FailureOr<EmbeddedContractionDimsCandidates>
-mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
-  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
-    return failure();
-
-  DenseSet<int64_t> a = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInputOperand(0), par);
-  DenseSet<int64_t> b = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInputOperand(1), par);
-  DenseSet<int64_t> c = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInitOperand(0), par);
-
-  // A & C - B are the iterators involved in an outer-product along A (the LHS).
-  DenseSet<int64_t> ac = a;
-  llvm::set_intersect(ac, c);
-  llvm::set_subtract(ac, b);
-  // B & C - A are the iterators involved in an outer-product along B (the RHS).
-  DenseSet<int64_t> bc = b;
-  llvm::set_intersect(bc, c);
-  llvm::set_subtract(bc, a);
-  // A & B & C are the "batch" dimensions.
-  DenseSet<int64_t> batches = a;
-  llvm::set_intersect(batches, b);
-  llvm::set_intersect(batches, c);
-
-  // A & B red are the reduction dimensions.
-  DenseSet<int64_t> ra = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInputOperand(0), red);
-  DenseSet<int64_t> rb = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInputOperand(1), red);
-  llvm::set_intersect(ra, rb);
-
-  if (ac.empty() || bc.empty() || ra.empty())
-    return failure();
-
-  // Pick the first one in each set.
-  // TODO: Better heuristic (e.g pick dims based on packing-based metric).
-  return EmbeddedContractionDimsCandidates{batches, ac, bc, ra};
-}
-
-//===----------------------------------------------------------------------===//
 // General utilities
 //===----------------------------------------------------------------------===//
 
index d67016e..8fb6ed1 100644 (file)
@@ -59,7 +59,9 @@ transform::OperationType::checkPayload(Location loc,
   for (Operation *op : payload) {
     if (opName != op->getName()) {
       DiagnosedSilenceableFailure diag =
-          emitSilenceableError(loc) << "incompatible payload operation name";
+          emitSilenceableError(loc)
+          << "incompatible payload operation name expected " << opName << " vs "
+          << op->getName() << " -> " << *op;
       diag.attachNote(op->getLoc()) << "payload operation";
       return diag;
     }
index 374c1d2..68f0706 100644 (file)
@@ -326,3 +326,25 @@ transform.sequence failures(propagate) {
       matmul_inner_dims_order = [1, 2, 0]
     : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
 }
+
+// -----
+
+!A = tensor<1023x255xf32>
+!X = tensor<255xf32>
+!Y = tensor<1023xf32>
+
+// CHECK-LABEL: @matvec_fail(
+func.func @matvec_fail(%A : !A, %x : !X, %y : !Y) -> !Y {
+  //      CHECK: linalg.matvec
+  %0 = linalg.matvec ins(%A, %x : !A, !X) outs(%y : !Y) -> !Y
+  return %0 : !Y
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %matmul = transform.structured.match ops{["linalg.matvec"]} in %module_op 
+    : (!transform.any_op) -> !transform.op<"linalg.matvec">
+  transform.structured.pack_greedily %matmul 
+      matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
+    : (!transform.op<"linalg.matvec">) -> !transform.any_op
+}