[mlir][Linalg] NFC - Refactor hoist padding transform and separate packing creation...
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Mon, 3 Apr 2023 15:17:24 +0000 (08:17 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 4 Apr 2023 09:45:15 +0000 (02:45 -0700)
This revision is in prevision for upcoming improvements.

Differential Revision: https://reviews.llvm.org/D147458

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/SCF/Utils/Utils.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
mlir/lib/Dialect/Tensor/Utils/Utils.cpp
mlir/test/Dialect/Linalg/transform-op-hoist-pad-build-packing-loop-nest.mlir [new file with mode: 0644]
mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir

index 3b80712..415c799 100644 (file)
@@ -794,9 +794,51 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
 // HoistPadOp
 //===----------------------------------------------------------------------===//
 
+def HoistPadBuildPackingLoopNestOp :
+    Op<Transform_Dialect,
+       "structured.hoist_pad.build_packing_loop_nest",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Helper transform used to hoist a tensor.pad target operation. This operation
+    creates the packing loop nest required by the hoist_pad operation and makes
+    that functionality available independently.
+
+    TODO: In the future, we should consider rewriting as a tensor.pack after
+    hoisting since this abstraction is now available.
+
+    #### Return modes
+
+    This operation ignores non-tensor.pad ops and drops them in the result.
+    If any non-tensor.pad is passed, the transform emits a silenceable failure.
+
+    The return handle points to only the subset of successfully created packing
+    loop nests, which can be empty.
+  }];
+
+  // Also allow any !pdl.operation for simpler composition. Non-tensor.pad ops
+  // will be dropped from the results.
+  let arguments =
+    (ins TransformHandleTypeInterface:$target,
+         TransformHandleTypeInterface:$loop,
+         DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$transpose);
+  let results = (outs TransformHandleTypeInterface:$packing_loop);
+
+  let assemblyFormat = [{
+    $target
+    `above` $loop
+    (`,` `transpose` `by` $transpose^)?
+    attr-dict
+    `:` functional-type(operands, results)
+  }];
+  let hasVerifier = 1;
+}
+
 def HoistPadOp : Op<Transform_Dialect, "structured.hoist_pad",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformOpInterface, TransformEachOpTrait]> {
+    [FunctionalStyleTransformOpTrait,
+     MemoryEffectsOpInterface,
+     TransformOpInterface,
+     TransformEachOpTrait]> {
   let description = [{
     Hoist the tensor.pad target operation by at most the given number of loops.
     Optionally apply the transpose attribute to the inner dimensions.
index 8e38de4..05f1090 100644 (file)
@@ -362,6 +362,26 @@ rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                   ArrayRef<Attribute> paddingValues,
                   ArrayRef<bool> packPaddings, LinalgOp &paddedOp);
 
+namespace detail {
+
+/// Helper struct to hold the results of building a packing loop nest.
+struct PackingResult {
+  SmallVector<OpFoldResult> offsets, sizes, strides;
+  SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
+  GenericOp maybeTransposeOp;
+  tensor::PadOp hoistedPadOp;
+};
+
+/// Build the packing loop nest required to hoist `opToHoist` above
+/// `outermostEnclosingForOp`.
+/// The loop nest is built just before `outermostEnclosingForOp`.
+FailureOr<PackingResult>
+buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist,
+                     scf::ForOp outermostEnclosingForOp,
+                     ArrayRef<int64_t> transposeVector);
+
+} // namespace detail
+
 /// Mechanically hoist padding operations on tensors by `numLoops` into a new,
 /// generally larger tensor. This achieves packing of multiple padding ops into
 /// a larger tensor. On success, `opToHoist` is replaced by the cloned version
index 4d4baa9..0b251a6 100644 (file)
@@ -54,6 +54,18 @@ scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
                                     ValueRange newIterOperands,
                                     const NewYieldValueFn &newYieldValuesFn,
                                     bool replaceIterOperandsUsesInLoop = true);
+// Simpler API if the new yields are just a list of values that can be
+// determined ahead of time.
+inline scf::ForOp
+replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
+                         ValueRange newIterOperands, ValueRange newYields,
+                         bool replaceIterOperandsUsesInLoop = true) {
+  auto fn = [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+    return SmallVector<Value>(newYields.begin(), newYields.end());
+  };
+  return replaceLoopWithNewYields(builder, loop, newIterOperands, fn,
+                                  replaceIterOperandsUsesInLoop);
+}
 
 /// Update a perfectly nested loop nest to yield new values from the innermost
 /// loop and propagating it up through the loop nest. This function
index e818069..56972a9 100644 (file)
@@ -1757,6 +1757,60 @@ LogicalResult transform::PadOp::verify() {
 // HoistPadOp
 //===---------------------------------------------------------------------===//
 
+DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
+    transform::TransformResults &transformResults,
+    transform::TransformState &state) {
+  ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
+  ArrayRef<Operation *> loopOps = state.getPayloadOps(getLoop());
+  if (targetOps.size() != 1 || loopOps.size() != 1) {
+    return emitDefiniteFailure()
+           << "requires exactly one target and one loop handle (got "
+           << targetOps.size() << " and " << loopOps.size() << ")";
+  }
+
+  auto padOp = dyn_cast_or_null<tensor::PadOp>(targetOps.front());
+  auto loopOp = dyn_cast_or_null<scf::ForOp>(loopOps.front());
+  if (!padOp || !loopOp)
+    return emitDefiniteFailure() << "requires exactly 2 non-null handles";
+
+  IRRewriter rewriter(getContext());
+  FailureOr<linalg::detail::PackingResult> result =
+      linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
+                                           getTranspose());
+  if (failed(result))
+    return emitDefiniteFailure() << "could not build packing loop nest";
+
+  if (result->clonedLoopIvs.empty()) {
+    transformResults.set(getPackingLoop().cast<OpResult>(),
+                         result->hoistedPadOp.getOperation());
+    return DiagnosedSilenceableFailure::success();
+  }
+  auto outerPackedLoop =
+      scf::getForInductionVarOwner(result->clonedLoopIvs.front());
+  transformResults.set(getPackingLoop().cast<OpResult>(),
+                       outerPackedLoop.getOperation());
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
+  ArrayRef<int64_t> transpose = getTranspose();
+  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
+  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
+                           transpose.end())) {
+    return emitOpError() << "expects transpose to be a permutation, found "
+                         << getTranspose();
+  }
+  return success();
+}
+
+void transform::HoistPadBuildPackingLoopNestOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+  transform::onlyReadsHandle(getLoop(), effects);
+  transform::producesHandle(getPackingLoop(), effects);
+  transform::modifiesPayload(effects);
+}
+
 DiagnosedSilenceableFailure
 transform::HoistPadOp::applyToOne(tensor::PadOp target,
                                   transform::ApplyToEachResultList &results,
index 614efca..d2d3f58 100644 (file)
@@ -33,6 +33,7 @@ using llvm::dbgs;
 
 using namespace mlir;
 using namespace mlir::linalg;
+using namespace mlir::linalg::detail;
 
 #ifndef NDEBUG
 static bool debugPrintLoopInShortForm(Operation *op) {
@@ -61,6 +62,73 @@ static void debugPrintBackwardSlice(SetVector<Operation *> &backwardSlice) {
              DBGS() << "\n";);
 }
 
+/// Return at most nLevels of immediately enclosing scf::ForOp loops.
+/// Stops at the first parent that is not an scf::ForOp.
+/// Multi-loops such as scf.parallel or linalg.tiled_loop are not modeled atm.
+/// Control-flow and other containing ops with regions are not modeled atm.
+static void
+getAtMostNEnclosingLoops(tensor::PadOp padOp, int nLevels,
+                         SmallVector<scf::ForOp> &reverseEnclosingLoops) {
+  scf::ForOp outermostEnclosingForOp = nullptr;
+  Operation *nextEnclosingOp = padOp->getParentOp();
+  while (nLevels-- > 0 &&
+         (outermostEnclosingForOp = dyn_cast<scf::ForOp>(nextEnclosingOp))) {
+    LLVM_DEBUG(DBGS() << "loops: ";
+               debugPrintLoopInShortForm(outermostEnclosingForOp);
+               dbgs() << "\n");
+    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
+    nextEnclosingOp = outermostEnclosingForOp->getParentOp();
+  }
+}
+
+/// Return at most nLevels of immediately enclosing scf::ForOp loops.
+/// Stops at the first parent that is not an scf::ForOp.
+/// Multi-loops such as scf.parallel or linalg.tiled_loop are not modeled atm.
+/// Control-flow and other containing ops with regions are not modeled atm.
+static void
+getEnclosingLoopsUntil(tensor::PadOp padOp, scf::ForOp untilLoop,
+                       SmallVector<scf::ForOp> &reverseEnclosingLoops) {
+  scf::ForOp outermostEnclosingForOp = nullptr;
+  Operation *nextEnclosingOp = padOp->getParentOp();
+  while (outermostEnclosingForOp != untilLoop &&
+         (outermostEnclosingForOp = dyn_cast<scf::ForOp>(nextEnclosingOp))) {
+    LLVM_DEBUG(DBGS() << "loops: ";
+               debugPrintLoopInShortForm(outermostEnclosingForOp);
+               dbgs() << "\n");
+    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
+    nextEnclosingOp = outermostEnclosingForOp->getParentOp();
+  }
+}
+
+// Get all the ops in the backwards slice starting from `padOp` and that
+// are dominated by the outermost enclosing loop.
+// This also requires tracking ops defining values used in the region but
+// defined above.
+static void computeBackwardSlice(tensor::PadOp padOp,
+                                 scf::ForOp outermostEnclosingForOp,
+                                 SetVector<Operation *> &backwardSlice) {
+  DominanceInfo domInfo(outermostEnclosingForOp);
+  auto filter = [&](Operation *op) {
+    return domInfo.dominates(outermostEnclosingForOp, op) &&
+           !padOp->isProperAncestor(op);
+  };
+  // First, add the ops required to compute the region to the backwardSlice.
+  SetVector<Value> valuesDefinedAbove;
+  getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(),
+                            valuesDefinedAbove);
+  for (Value v : valuesDefinedAbove) {
+    getBackwardSlice(v, &backwardSlice, filter, /*inclusive=*/true);
+  }
+  // Then, add the backward slice from padOp itself.
+  getBackwardSlice(padOp.getOperation(), &backwardSlice, filter,
+                   /*inclusive=*/true);
+}
+
+//===----------------------------------------------------------------------===//
+// HoistPaddingAnalysis Implementation.
+//===----------------------------------------------------------------------===//
+
+namespace {
 /// Analysis class to support tensor::PadOp hoisting across multiple enclosing
 /// loops. The failure conditions are:
 ///   1. Pad op has a use that is not an input of a LinalgOp.
@@ -76,36 +144,48 @@ static void debugPrintBackwardSlice(SetVector<Operation *> &backwardSlice) {
 ///      the outermost enclosing scf::ForOp.
 ///   8. There is no enclosing scf::ForOp that indexes the padded data.
 /// Other cases succeed and will trigger hoisting of the pad op.
-struct HoistingAnalysis {
-  HoistingAnalysis(tensor::PadOp padOp, int numLoops);
+struct HoistPaddingAnalysis {
+  HoistPaddingAnalysis(tensor::PadOp padOp, int numLoops);
+  HoistPaddingAnalysis(tensor::PadOp padOp, scf::ForOp outermostEnclosingForOp);
 
-  bool isValid() { return valid; }
+  bool isValid() { return valid.has_value() && valid.value(); }
+  bool isInvalid() { return valid.has_value() && !valid.value(); }
 
   /// Footprint of the packedTensor, computed from the packingLoops.
-  SmallVector<Value> getPackedTensorSizes(RewriterBase &b, Location loc);
+  SmallVector<Value> getPackedTensorSizes(RewriterBase &rewriter,
+                                          Location loc) const;
 
-  /// The outermost loop, determined by `nLevels` above which `padOp` will
-  /// be hoisted.
-  scf::ForOp outermostEnclosingForOp;
+  /// Performs optional hoisting to enable hoist padding to occur. This may be
+  /// necessary when `sliceOp` is not defined outside of the outermost enclosing
+  /// loop we want to hoist above.
+  ///
+  /// Example:
+  /// ```
+  /// %source = linalg.fill(%cst, %arg0)
+  /// // %source is available for packing here!
+  /// scf.for %i
+  ///   scf.for %j
+  ///     scf.for %k
+  ///       %slice = tensor.extract_slice %source [%i, %j]
+  ///       %padded_slice = tensor.pad %slice
+  /// ```
+  void enableHoistPadding(RewriterBase &rewriter);
 
-  /// Backward slice rooted at `padOp` and nested under
-  /// `outermostEnclosingForOp`.
-  SetVector<Operation *> backwardSlice;
+  /// Common analysis builder to finalize the construction of the analysis once
+  /// optional `enableHoistPadding` has run.
+  /// `reverseEnclosingLoops.back()` is the loop to hoist above.
+  void finalizeHoistPaddingAnalysis();
 
-  /// The scf::ForOp immediately enclosing `padOp` such that:
-  ///  1. they are nested under `outermostEnclosingForOp` (inclusive)
-  ///  2. whose induction variable is used, directly or indirectly, in the
-  ///     computation of `padOp`.
-  /// The span of these loops determines the footprint of the packed tensor.
-  SmallVector<scf::ForOp> packingLoops;
+private:
+  /// Encodes whether the analysis is valid and hoisting can proceed.
+  std::optional<bool> valid;
 
-  /// The ExtractSliceOp that feeds the PadOp we want to hoist.
-  tensor::ExtractSliceOp sliceOp;
+  /// The padOp to hoist.
+  tensor::PadOp opToHoist;
 
-  /// If non-empty, this is the unique scf::ForOp that consumes the `sliceOp`.
-  scf::ForOp padConsumingForOp;
+  /// Immediately enclosing loops considered for hoisting padding.
+  SmallVector<scf::ForOp> reverseEnclosingLoops;
 
-private:
   /// Drop any non-index dependencies of `padOp` and `sliceOp` from
   /// `backwardSlice`. The method follows the use-def chains of the index
   /// operands consumed by `padOp` and `sliceOp` and drops the operations
@@ -130,96 +210,92 @@ private:
   /// ```
   /// dropNonIndexDependencies(%padded_slice, %slice)
   /// removes [scf.for %k, linalg.fill(%cst, %arg1)] from backwardSlice.
-  LogicalResult dropNonIndexDependencies(tensor::PadOp padOp);
+  LogicalResult dropNonIndexDependencies();
 
-  /// Encodes whether the analysis is valid and hoisting can proceed.
-  bool valid;
-};
+public:
+  /// The outermost loop, determined by `nLevels` above which `padOp` will
+  /// be hoisted.
+  scf::ForOp outermostEnclosingForOp;
 
-/// Return at most nLevels of immediately enclosing scf::ForOp loops.
-/// Stops at the first parent that is not an scf::ForOp.
-/// Multi-loops such as scf.parallel or linalg.tiled_loop are not modeled atm.
-/// Control-flow and other containing ops with regions are not modeled atm.
-static void
-getAtMostNEnclosingLoops(tensor::PadOp padOp, int nLevels,
-                         SmallVector<scf::ForOp> &reverseEnclosingLoops) {
-  scf::ForOp outermostEnclosingForOp = nullptr;
-  Operation *nextEnclosingOp = padOp->getParentOp();
-  while (nLevels-- > 0 &&
-         (outermostEnclosingForOp = dyn_cast<scf::ForOp>(nextEnclosingOp))) {
-    LLVM_DEBUG(DBGS() << "loops: ";
-               debugPrintLoopInShortForm(outermostEnclosingForOp);
-               dbgs() << "\n");
-    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
-    nextEnclosingOp = outermostEnclosingForOp->getParentOp();
-  }
-}
+  /// Backward slice rooted at `padOp` and nested under
+  /// `outermostEnclosingForOp`.
+  SetVector<Operation *> backwardSlice;
 
-// Get all the ops in the backwards slice starting from `padOp` and that
-// are dominated by the outermost enclosing loop.
-// This also requires tracking ops defining values used in the region but
-// defined above.
-static void computeBackwardSlice(tensor::PadOp padOp,
-                                 scf::ForOp outermostEnclosingForOp,
-                                 SetVector<Operation *> &backwardSlice) {
-  DominanceInfo domInfo(outermostEnclosingForOp);
-  auto filter = [&](Operation *op) {
-    return domInfo.dominates(outermostEnclosingForOp, op) &&
-           !padOp->isProperAncestor(op);
-  };
-  // First, add the ops required to compute the region to the backwardSlice.
-  SetVector<Value> valuesDefinedAbove;
-  getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(),
-                            valuesDefinedAbove);
-  for (Value v : valuesDefinedAbove) {
-    getBackwardSlice(v, &backwardSlice, filter, /*inclusive=*/true);
-  }
-  // Then, add the backward slice from padOp itself.
-  getBackwardSlice(padOp.getOperation(), &backwardSlice, filter,
-                   /*inclusive=*/true);
-}
+  /// The scf::ForOp immediately enclosing `padOp` such that:
+  ///  1. they are nested under `outermostEnclosingForOp` (inclusive)
+  ///  2. whose induction variable is used, directly or indirectly, in the
+  ///     computation of `padOp`.
+  /// The span of these loops determines the footprint of the packed tensor.
+  SmallVector<scf::ForOp> packingLoops;
+
+  /// The ExtractSliceOp that feeds the PadOp we want to hoist.
+  tensor::ExtractSliceOp sliceOp;
+
+  /// If non-empty, this is the unique scf::ForOp that consumes the `sliceOp`.
+  scf::ForOp padConsumingForOp;
+};
 
-HoistingAnalysis::HoistingAnalysis(tensor::PadOp padOp, int numLoops) {
-  valid = false;
+} // namespace
 
+HoistPaddingAnalysis::HoistPaddingAnalysis(tensor::PadOp padOp, int numLoops)
+    : valid(std::nullopt), opToHoist(padOp) {
   // Get at most `numLoops` of immediately enclosing loops.
-  SmallVector<scf::ForOp> reverseEnclosingLoops;
-  getAtMostNEnclosingLoops(padOp, numLoops, reverseEnclosingLoops);
+  getAtMostNEnclosingLoops(opToHoist, numLoops, reverseEnclosingLoops);
   if (reverseEnclosingLoops.empty()) {
     LLVM_DEBUG(DBGS() << "--No immediately enclosing loop -> Skip\n");
+    valid = false;
     return;
   }
-
   outermostEnclosingForOp = reverseEnclosingLoops.back();
+  sliceOp = opToHoist.getSource().getDefiningOp<tensor::ExtractSliceOp>();
+  if (!sliceOp) {
+    LLVM_DEBUG(DBGS() << "--Cannot find the extract slice op -> Skip\n");
+    valid = false;
+    return;
+  }
+}
 
-  // Get the `sliceOp` that defines the source tensor of `padOp` and
-  // check its source is defined outside of the outermost loop. This check
-  // ensures the padded data is available for packing before entering the
-  // outermost enclosing loop.
-  //
-  // Example:
-  // ```
-  // %source = linalg.fill(%cst, %arg0)
-  // // %source is available for packing here!
-  // scf.for %i
-  //   scf.for %j
-  //     scf.for %k
-  //       %slice = tensor.extract_slice %source [%i, %j]
-  //       %padded_slice = tensor.pad %slice
-  // ```
-  sliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
+HoistPaddingAnalysis::HoistPaddingAnalysis(tensor::PadOp padOp,
+                                           scf::ForOp outermostEnclosingForOp)
+    : valid(std::nullopt), opToHoist(padOp) {
+  // Get enclosing loops until outermostEnclosingForOp.
+  getEnclosingLoopsUntil(opToHoist, outermostEnclosingForOp,
+                         reverseEnclosingLoops);
+  if (reverseEnclosingLoops.empty()) {
+    LLVM_DEBUG(DBGS() << "--No immediately enclosing loop -> Skip\n");
+    valid = false;
+    return;
+  }
+  this->outermostEnclosingForOp = reverseEnclosingLoops.back();
+  if (this->outermostEnclosingForOp != outermostEnclosingForOp) {
+    LLVM_DEBUG(DBGS() << "--Unexpected outermost enclosing loop -> Skip\n");
+    valid = false;
+    return;
+  }
+  sliceOp = opToHoist.getSource().getDefiningOp<tensor::ExtractSliceOp>();
   if (!sliceOp) {
     LLVM_DEBUG(DBGS() << "--Cannot find the extract slice op -> Skip\n");
+    valid = false;
     return;
   }
+}
+
+void HoistPaddingAnalysis::enableHoistPadding(RewriterBase &rewriter) {
+  if (isInvalid())
+    return;
   // If the padded data is not yet available before entering the outermost
   // enclosing loop, try to apply hoisting on this outermost loop.
   // TODO: we may want finer-grained hoisting of only that particular `sliceOp`.
-  IRRewriter rewriter(outermostEnclosingForOp->getContext());
   if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) {
     outermostEnclosingForOp =
         hoistRedundantSubsetExtractInsert(rewriter, outermostEnclosingForOp);
   }
+}
+
+void HoistPaddingAnalysis::finalizeHoistPaddingAnalysis() {
+  if (isInvalid())
+    return;
+
   if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) {
     LLVM_DEBUG(DBGS() << "--outermostEnclosingForOp:\n"
                       << outermostEnclosingForOp << "\n"
@@ -227,6 +303,7 @@ HoistingAnalysis::HoistingAnalysis(tensor::PadOp padOp, int numLoops) {
                       << "--sliceOp.getSource(): " << sliceOp.getSource()
                       << "\n");
     LLVM_DEBUG(DBGS() << "----Source not defined outside of loops -> Skip\n");
+    valid = false;
     return;
   }
   if (sliceOp->hasOneUse()) {
@@ -236,23 +313,27 @@ HoistingAnalysis::HoistingAnalysis(tensor::PadOp padOp, int numLoops) {
   // Check the region of `padOp` depends on a constant only. Adding hoisting
   // support for arbitrary padding regions would require cloning all
   // dependencies captured by the padding region.
-  Value paddingValue = padOp.getConstantPaddingValue();
+  Value paddingValue = opToHoist.getConstantPaddingValue();
   if (!paddingValue ||
       !isa_and_nonnull<arith::ConstantOp>(paddingValue.getDefiningOp())) {
     LLVM_DEBUG(DBGS() << "Cannot find constant padding value -> Skip\n");
+    valid = false;
     return;
   }
 
-  computeBackwardSlice(padOp, outermostEnclosingForOp, backwardSlice);
-  if (backwardSlice.size() <= 1)
+  computeBackwardSlice(opToHoist, outermostEnclosingForOp, backwardSlice);
+  if (backwardSlice.size() <= 1) {
+    valid = false;
     return;
+  }
 
   debugPrintBackwardSlice(backwardSlice);
   // Remove all ops in the backward slice that are not used to index
   // the padded tensor. In particular, keep `padOp`, `sliceOp`, and
   // the loop and affine operations used for the index computation.
-  if (failed(dropNonIndexDependencies(padOp))) {
+  if (failed(dropNonIndexDependencies())) {
     LLVM_DEBUG(DBGS() << "--Cannot dropNonIndexDependencies -> Skip\n");
+    valid = false;
     return;
   }
   debugPrintBackwardSlice(backwardSlice);
@@ -281,7 +362,7 @@ HoistingAnalysis::HoistingAnalysis(tensor::PadOp padOp, int numLoops) {
   valid = true;
 }
 
-LogicalResult HoistingAnalysis::dropNonIndexDependencies(tensor::PadOp padOp) {
+LogicalResult HoistPaddingAnalysis::dropNonIndexDependencies() {
   // Set of all values used for index computation.
   SetVector<Value> indexEdges;
 
@@ -300,7 +381,7 @@ LogicalResult HoistingAnalysis::dropNonIndexDependencies(tensor::PadOp padOp) {
     });
   };
 
-  // Starting from `padOp` and `sliceOp` walk the use-def edges of index
+  // Starting from `opToHoist` and `sliceOp` walk the use-def edges of index
   // type in `backwardSlice`. Add the index operands of an operation to
   // `indexEdges` and remove all operations from `backwardSlice` that are not
   // part of the index computation.
@@ -322,9 +403,9 @@ LogicalResult HoistingAnalysis::dropNonIndexDependencies(tensor::PadOp padOp) {
   // backwardSlice = backwardSlice / [linalg.fill(%cst, %arg1), scf.for %k]
   SetVector<Operation *> operationsToRemove;
   for (Operation *op : llvm::reverse(backwardSlice)) {
-    // Add the index operands of `padOp` and `sliceOp` to start the
+    // Add the index operands of `opToHoist` and `sliceOp` to start the
     // exploration of the index computation.
-    if (op == padOp || op == sliceOp) {
+    if (op == opToHoist || op == sliceOp) {
       addIndexOperandsToIndexEdges(op);
       continue;
     }
@@ -358,7 +439,7 @@ LogicalResult HoistingAnalysis::dropNonIndexDependencies(tensor::PadOp padOp) {
       continue;
     }
     // Remove all other operations not used by the index computation. An
-    // exception are constant operations that may be used by `padOp`.
+    // exception are constant operations that may be used by `opToHoist`.
     if (!isa<arith::ConstantOp>(op))
       operationsToRemove.insert(op);
   }
@@ -367,7 +448,8 @@ LogicalResult HoistingAnalysis::dropNonIndexDependencies(tensor::PadOp padOp) {
 }
 
 SmallVector<Value>
-HoistingAnalysis::getPackedTensorSizes(RewriterBase &rewriter, Location loc) {
+HoistPaddingAnalysis::getPackedTensorSizes(RewriterBase &rewriter,
+                                           Location loc) const {
   SmallVector<Value> dynamicTensorSizes;
 
   // Upper bound the packing loop lengths to size the packed tensor. Taking
@@ -403,6 +485,10 @@ static bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) {
   return outer.isDefinedOutsideOfLoop(v) || matchPattern(v, m_Constant());
 }
 
+//===----------------------------------------------------------------------===//
+// buildPackingLoopNest Implementation.
+//===----------------------------------------------------------------------===//
+
 /// Return the current iteration number in the loop (iv - lb).ceilDiv(step).
 /// The returned Value is guaranteed not to depend on any loop comprised in
 /// [`outer`, `forOp`].
@@ -423,12 +509,6 @@ static Value buildLoopIterationCount(RewriterBase &rewriter, scf::ForOp outer,
       loc, (iv - lb).ceilDiv(step), ValueRange{ivVal, lbVal, stepVal});
 }
 
-struct PackingLoopNestResult {
-  SmallVector<OpFoldResult> offsets, sizes, strides;
-  SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
-  GenericOp maybeTransposeOp;
-};
-
 // Build a packing loop nest by iteratively traversing the backward slice and
 // clone the operations, iteratively stepping into the loops that we encounter.
 // The implementation proceeds in a stack-like fashion:
@@ -439,10 +519,10 @@ struct PackingLoopNestResult {
 //   3. At the innermost loop level, create a InsertSliceOp.
 //   4. Iteratively pop and yield the result of the InsertSliceOp across the
 //      cloned loops.
-static PackingLoopNestResult buildPackingLoopNest(
+static PackingResult buildPackingLoopNestImpl(
     RewriterBase &rewriter, IRMapping &bvm, tensor::PadOp opToHoist,
     ArrayRef<int64_t> transposeVector, RankedTensorType transposedTensorType,
-    tensor::EmptyOp emptyOp, const HoistingAnalysis &analysis) {
+    tensor::EmptyOp emptyOp, const HoistPaddingAnalysis &analysis) {
   SmallVector<OpFoldResult> offsets, sizes, strides;
   SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
 
@@ -452,8 +532,21 @@ static PackingLoopNestResult buildPackingLoopNest(
   RankedTensorType paddedTensorType = opToHoist.getResultType();
   int paddedRank = paddedTensorType.getRank();
 
-  Value packedTensor = emptyOp.getResult();
+  // Step 0. Populate bvm with opToHoist.getSource if relevant.
+  BlockArgument bbArg = opToHoist.getSource().dyn_cast<BlockArgument>();
+  while (bbArg) {
+    auto forOp = dyn_cast<scf::ForOp>(bbArg.getOwner()->getParentOp());
+    if (!forOp)
+      break;
+    if (forOp != outerLoop && !outerLoop->isAncestor(forOp))
+      break;
+    OpOperand &operand = forOp.getOpOperandForRegionIterArg(bbArg);
+    bvm.map(bbArg, operand.get());
+    bbArg = operand.get().dyn_cast<BlockArgument>();
+  }
+
   // Step 1. iteratively clone loops and push `packedTensor`.
+  Value packedTensor = emptyOp.getResult();
   OpBuilder::InsertionGuard g(rewriter);
   for (Operation *op : analysis.backwardSlice) {
     // Specifically sit out in the extract_slice(packedTensor) case: this is
@@ -544,14 +637,83 @@ static PackingLoopNestResult buildPackingLoopNest(
     }
   }
 
-  return PackingLoopNestResult{offsets,
-                               sizes,
-                               strides,
-                               clonedLoopIvs,
-                               leadingPackedTensorIndexings,
-                               maybeTransposeOp};
+  return PackingResult{
+      offsets,
+      sizes,
+      strides,
+      clonedLoopIvs,
+      leadingPackedTensorIndexings,
+      maybeTransposeOp,
+      cast<tensor::PadOp>(bvm.lookup(opToHoist.getResult()).getDefiningOp())};
 }
 
+/// Build the packing loop nest required to hoist `opToHoist` above
+/// `outermostEnclosingForOp`.
+/// The loop nest is built just before `outermostEnclosingForOp`.
+static FailureOr<PackingResult> buildPackingLoopNestImpl(
+    RewriterBase &rewriter, IRMapping &bvm, tensor::PadOp opToHoist,
+    ArrayRef<int64_t> transposeVector, const HoistPaddingAnalysis &analysis) {
+  // Update actual number of loops, which may be smaller.
+  int nPackedLoops = analysis.packingLoops.size();
+  LLVM_DEBUG(DBGS() << "\n";
+             DBGS() << "Func:\n"
+                    << *opToHoist->getParentOfType<func::FuncOp>() << "\n";
+             DBGS() << "Start hoisting above " << nPackedLoops << " loops\n");
+
+  Location loc = opToHoist->getLoc();
+  RankedTensorType paddedTensorType = opToHoist.getResultType();
+
+  // Compute the type of the transposed padded tensor.
+  FailureOr<RankedTensorType> transposedTensorType =
+      tensor::computeTransposedType(paddedTensorType, transposeVector);
+  if (failed(transposedTensorType)) {
+    LLVM_DEBUG(DBGS() << "--Could not compute transposed type -> Skip\n");
+    return failure();
+  }
+
+  // Create the packed tensor<?x?x..? x transposedShape>.
+  SmallVector<int64_t> packedShape(nPackedLoops, ShapedType::kDynamic);
+  // TODO: go grab dims when needed, atm tensor::PadOp yields a static tensor.
+  llvm::append_range(packedShape, transposedTensorType->getShape());
+  auto packedTensorType = RankedTensorType::get(
+      packedShape, transposedTensorType->getElementType());
+
+  // Set the insertion point right before the outer loop and start packing.
+  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(outerLoop);
+  SmallVector<Value> dynamicTensorSizes =
+      analysis.getPackedTensorSizes(rewriter, loc);
+  auto emptyOp = rewriter.create<tensor::EmptyOp>(
+      loc, packedTensorType.getShape(), packedTensorType.getElementType(),
+      dynamicTensorSizes);
+
+  return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector,
+                                  *transposedTensorType, emptyOp, analysis);
+}
+
+/// Build the packing loop nest required to hoist `opToHoist` above
+/// `outermostEnclosingForOp`.
+/// The loop nest is built just before `outermostEnclosingForOp`.
+FailureOr<PackingResult> mlir::linalg::detail::buildPackingLoopNest(
+    RewriterBase &rewriter, tensor::PadOp opToHoist,
+    scf::ForOp outermostEnclosingForOp, ArrayRef<int64_t> transposeVector) {
+  HoistPaddingAnalysis analysis(opToHoist, outermostEnclosingForOp);
+  analysis.enableHoistPadding(rewriter);
+  analysis.finalizeHoistPaddingAnalysis();
+  if (!analysis.isValid()) {
+    LLVM_DEBUG(DBGS() << "--Analysis failed -> Skip\n");
+    return failure();
+  }
+  IRMapping bvm;
+  return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector,
+                                  analysis);
+}
+
+//===----------------------------------------------------------------------===//
+// hoistPaddingOnTensors Implementation.
+//===----------------------------------------------------------------------===//
+
 // If the original consumer of `sliceOp` was a `forOp` (i.e. through an iter
 // arg), propagate the `packedTensor` value through the same iter arg.
 // TODO: for multiple loops we need to track the use to the innermost loop.
@@ -574,6 +736,7 @@ static Value padThroughLoopIterArg(RewriterBase &rewriter, Value packedTensor,
   std::optional<unsigned> operandNumber =
       forOp.getIterArgNumberForOpOperand(*pUse);
   assert(operandNumber.has_value() && "expected a proper iter arg number");
+
   SmallVector<Value> initArgs = forOp.getInitArgs();
   initArgs[operandNumber.value()] = casted;
   rewriter.startRootUpdate(forOp);
@@ -584,10 +747,12 @@ static Value padThroughLoopIterArg(RewriterBase &rewriter, Value packedTensor,
 
 /// Produce a tensor extracted from the packingResult. This can be used as a
 /// replacement for `opToHoist` in callers.
-static Value replaceByPackingLoopNestResult(
-    RewriterBase &rewriter, const IRMapping &bvm, tensor::PadOp opToHoist,
-    RankedTensorType transposedTensorType, const HoistingAnalysis &analysis,
-    const PackingLoopNestResult &packingResult) {
+static Value replaceByPackingResult(RewriterBase &rewriter,
+                                    const IRMapping &bvm,
+                                    tensor::PadOp opToHoist,
+                                    RankedTensorType transposedTensorType,
+                                    const HoistPaddingAnalysis &analysis,
+                                    const PackingResult &packingResult) {
   // The replacement occurs under a single insertion point within the original
   // loop, just before opToHoist.
   OpBuilder::InsertionGuard g(rewriter);
@@ -625,13 +790,7 @@ static Value replaceByPackingLoopNestResult(
             ->getResult(0);
   } else {
     // If no loops were created, this is just hoisting without packing.
-    auto padOp =
-        cast<tensor::PadOp>(bvm.lookup(opToHoist.getResult()).getDefiningOp());
-    tensor::ExtractSliceOp sliceOp = analysis.sliceOp;
-    rewriter.startRootUpdate(padOp);
-    padOp.getSourceMutable().assign(sliceOp.getResult());
-    rewriter.finalizeRootUpdate(padOp);
-    packedTensor = padOp;
+    packedTensor = bvm.lookup(opToHoist.getResult());
   }
 
   LLVM_DEBUG(DBGS() << "packedTensor: " << packedTensor << "\n");
@@ -657,59 +816,39 @@ FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(
     SmallVectorImpl<GenericOp> &transposeOps) {
   LLVM_DEBUG(DBGS() << "\n"; DBGS() << " Try to hoist " << *(opToHoist) << "\n";
              DBGS() << " by " << numLoops << " loops\n");
-  HoistingAnalysis analysis(opToHoist, numLoops);
+
+  HoistPaddingAnalysis analysis(opToHoist, numLoops);
+  analysis.enableHoistPadding(rewriter);
+  analysis.finalizeHoistPaddingAnalysis();
   if (!analysis.isValid()) {
     LLVM_DEBUG(DBGS() << "--Analysis failed -> Skip\n");
     return failure();
   }
 
-  // Update actual number of loops, which may be smaller.
-  int nPackedLoops = analysis.packingLoops.size();
-  LLVM_DEBUG(DBGS() << "\n";
-             DBGS() << "Func:\n"
-                    << *opToHoist->getParentOfType<func::FuncOp>() << "\n";
-             DBGS() << "Start hoisting above " << nPackedLoops << " loops\n");
-
-  Location loc = opToHoist->getLoc();
-  RankedTensorType paddedTensorType = opToHoist.getResultType();
-
-  // Compute the type of the transposed padded tensor.
-  FailureOr<RankedTensorType> transposedTensorType =
-      tensor::computeTransposedType(paddedTensorType, transposeVector);
-  if (failed(transposedTensorType)) {
-    LLVM_DEBUG(DBGS() << "--Could not compute transposed type -> Skip\n");
+  /// Construct the packing loop nest.
+  IRMapping bvm;
+  FailureOr<PackingResult> packingResult = buildPackingLoopNestImpl(
+      rewriter, bvm, opToHoist, transposeVector, analysis);
+  if (failed(packingResult)) {
+    LLVM_DEBUG(DBGS() << "--buildPackingLoopNestImpl failed -> Skip\n");
     return failure();
   }
 
-  // Create the packed tensor<?x?x..? x transposedShape>.
-  SmallVector<int64_t> packedShape(nPackedLoops, ShapedType::kDynamic);
-  // TODO: go grab dims when needed, atm tensor::PadOp yields a static tensor.
-  llvm::append_range(packedShape, transposedTensorType->getShape());
-  auto packedTensorType = RankedTensorType::get(
-      packedShape, transposedTensorType->getElementType());
-
-  // Set the insertion point right before the outer loop and start packing.
-  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(outerLoop);
-  SmallVector<Value> dynamicTensorSizes =
-      analysis.getPackedTensorSizes(rewriter, loc);
-  auto emptyOp = rewriter.create<tensor::EmptyOp>(
-      loc, packedTensorType.getShape(), packedTensorType.getElementType(),
-      dynamicTensorSizes);
-
-  /// Construct the packing loop nest.
-  IRMapping bvm;
-  PackingLoopNestResult packingResult =
-      buildPackingLoopNest(rewriter, bvm, opToHoist, transposeVector,
-                           *transposedTensorType, emptyOp, analysis);
   if (!transposeVector.empty())
-    transposeOps.push_back(packingResult.maybeTransposeOp);
+    transposeOps.push_back(packingResult->maybeTransposeOp);
+
+  FailureOr<RankedTensorType> transposedTensorType =
+      tensor::computeTransposedType(opToHoist.getResultType(), transposeVector);
+  assert(succeeded(transposedTensorType) && "unexpected failure in type");
 
   // Now the packed tensor is ready, replace the original padding op by a
   // 1x..x1 slice [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
-  Value newResult = replaceByPackingLoopNestResult(
-      rewriter, bvm, opToHoist, *transposedTensorType, analysis, packingResult);
+  Value newResult =
+      replaceByPackingResult(rewriter, bvm, opToHoist, *transposedTensorType,
+                             analysis, *packingResult);
+
+  Location loc = opToHoist->getLoc();
+  RankedTensorType paddedTensorType = opToHoist.getResultType();
   if (!transposeVector.empty()) {
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newResult.getDefiningOp());
@@ -729,8 +868,7 @@ FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(
              << "\n");
 
   // Make the newly cloned `opToHoist` available to the caller.
-  hoistedOp =
-      cast<tensor::PadOp>(bvm.lookup(opToHoist.getResult()).getDefiningOp());
+  hoistedOp = packingResult->hoistedPadOp;
 
   LLVM_DEBUG(DBGS() << "--SUCCESS\n");
   return newResult;
index 4c09c54..81d39a7 100644 (file)
@@ -87,6 +87,7 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
                                     ArrayRef<int64_t> transposeVector) {
   if (transposeVector.empty())
     return rankedTensorType;
+
   if (!isPermutationVector(transposeVector) ||
       transposeVector.size() != static_cast<size_t>(rankedTensorType.getRank()))
     return failure();
diff --git a/mlir/test/Dialect/Linalg/transform-op-hoist-pad-build-packing-loop-nest.mlir b/mlir/test/Dialect/Linalg/transform-op-hoist-pad-build-packing-loop-nest.mlir
new file mode 100644 (file)
index 0000000..99090ac
--- /dev/null
@@ -0,0 +1,174 @@
+// BUILD-PACKING-LOOP-NEST only checks the creation of packing code but does not connect it.
+// Do not run canonicalization as it would be DCE'd away.
+// RUN: mlir-opt --test-transform-dialect-interpreter -split-input-file --verify-diagnostics %s | FileCheck %s --check-prefix=BUILD-PACKING-LOOP-NEST
+
+func.func @pad_and_hoist_rhs(
+  %arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)
+     -> tensor<24x25xf32>
+{
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+
+  %matmul_l1, %loops_l1 = transform.structured.tile_to_scf_for %matmul [5]
+
+  %matmul_padded = transform.structured.pad %matmul_l1 {
+    padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
+    padding_dimensions=[0, 1, 2]
+  }
+
+  // In this case, the pad op is actually empty: we only tile the first dimension
+  // and it does not have an impact on the RHS operand.
+  %pad = transform.get_producer_of_operand %matmul_padded[1]
+    : (!pdl.operation) -> !pdl.operation
+
+  // expected-error @below {{requires exactly 2 non-null handles}}
+  transform.structured.hoist_pad.build_packing_loop_nest %pad above %loops_l1
+     : (!pdl.operation, !pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+func.func @pad_and_hoist_init(
+  %arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)
+     -> tensor<24x25xf32>
+{
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+
+  %matmul_l1, %loops_l1 = transform.structured.tile_to_scf_for %matmul [5]
+
+  %matmul_padded = transform.structured.pad %matmul_l1 {
+    padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
+    padding_dimensions=[0, 1, 2]
+  }
+
+  %pad = transform.get_producer_of_operand %matmul_padded[2]
+    : (!pdl.operation) -> !pdl.operation
+
+  // We do not know yet how to hoist the init.
+  // expected-error @below {{could not build packing loop nest}}
+  transform.structured.hoist_pad.build_packing_loop_nest %pad above %loops_l1
+     : (!pdl.operation, !pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+//     BUILD-PACKING-LOOP-NEST-LABEL: pad_and_hoist_lhs
+func.func @pad_and_hoist_lhs(
+  %arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)
+     -> tensor<24x25xf32>
+{
+  //     BUILD-PACKING-LOOP-NEST: %[[PACKED:.*]] = scf.for %{{.*}} -> (tensor<?x5x12xf32>) {
+  //     BUILD-PACKING-LOOP-NEST:   tensor.pad %{{.*}} 
+  //     BUILD-PACKING-LOOP-NEST:     : tensor<?x12xf32> to tensor<5x12xf32>
+  //     BUILD-PACKING-LOOP-NEST:   tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, 0, 0] [1, 5, 12] [1, 1, 1] 
+  // BUILD-PACKING-LOOP-NEST-SAME:   : tensor<5x12xf32> into tensor<?x5x12xf32>
+  //     BUILD-PACKING-LOOP-NEST: scf.for %{{.*}} -> (tensor<24x25xf32>)
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+
+  %matmul_l1, %loops_l1 = transform.structured.tile_to_scf_for %matmul [5]
+
+  %matmul_padded = transform.structured.pad %matmul_l1 {
+    padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
+    padding_dimensions=[0, 1, 2]
+  }
+
+  %pad = transform.get_producer_of_operand %matmul_padded[0]
+    : (!pdl.operation) -> !pdl.operation
+
+  transform.structured.hoist_pad.build_packing_loop_nest %pad above %loops_l1
+     : (!pdl.operation, !pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+//     BUILD-PACKING-LOOP-NEST-LABEL: pad_and_hoist_lhs_transpose
+func.func @pad_and_hoist_lhs_transpose(
+  %arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)
+     -> tensor<24x25xf32>
+{
+  //     BUILD-PACKING-LOOP-NEST: %[[PACKED:.*]] = scf.for %{{.*}} -> (tensor<?x12x5xf32>) {
+  //     BUILD-PACKING-LOOP-NEST:   tensor.pad %{{.*}} 
+  //     BUILD-PACKING-LOOP-NEST:     : tensor<?x12xf32> to tensor<5x12xf32>
+  //     BUILD-PACKING-LOOP-NEST:   linalg.generic
+  //     BUILD-PACKING-LOOP-NEST:     -> tensor<12x5xf32>
+  //     BUILD-PACKING-LOOP-NEST:   tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, 0, 0] [1, 12, 5] [1, 1, 1] 
+  // BUILD-PACKING-LOOP-NEST-SAME:   : tensor<12x5xf32> into tensor<?x12x5xf32>
+  //     BUILD-PACKING-LOOP-NEST: scf.for %{{.*}} -> (tensor<24x25xf32>)
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+
+  %matmul_l1, %loops_l1 = transform.structured.tile_to_scf_for %matmul [5]
+
+  %matmul_padded = transform.structured.pad %matmul_l1 {
+    padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
+    padding_dimensions=[0, 1, 2]
+  }
+
+  %pad = transform.get_producer_of_operand %matmul_padded[0]
+    : (!pdl.operation) -> !pdl.operation
+
+  transform.structured.hoist_pad.build_packing_loop_nest %pad above %loops_l1, transpose by [1, 0]
+     : (!pdl.operation, !pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+//     BUILD-PACKING-LOOP-NEST-LABEL: pad_and_hoist_init
+func.func @pad_and_hoist_init(
+  %arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)
+     -> tensor<24x25xf32>
+{
+
+  //      BUILD-PACKING-LOOP-NEST: scf.for %{{.*}} -> (tensor<24x25xf32>) {
+  //      BUILD-PACKING-LOOP-NEST:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice
+  //      BUILD-PACKING-LOOP-NEST:   %[[PADDED:.*]] = tensor.pad %[[EXTRACTED_SLICE]] 
+  //      BUILD-PACKING-LOOP-NEST:     : tensor<?x25xf32> to tensor<5x25xf32>
+  //      BUILD-PACKING-LOOP-NEST:   scf.for %{{.*}} iter_args({{.*}} = %[[EXTRACTED_SLICE]]) -> (tensor<24x25xf32>, tensor<?x25xf32>) {
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+
+  %matmul_l1, %loops_l1:2 = transform.structured.tile_to_scf_for %matmul [5, 0, 7]
+
+  %matmul_padded = transform.structured.pad %matmul_l1 {
+    padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
+    padding_dimensions=[0, 1, 2]
+  }
+
+  %pad = transform.get_producer_of_operand %matmul_padded[2]
+    : (!pdl.operation) -> !pdl.operation
+
+  transform.structured.hoist_pad.build_packing_loop_nest %pad above %loops_l1#1
+     : (!pdl.operation, !pdl.operation) -> !pdl.operation
+}
index 02e4698..fd0d309 100644 (file)
@@ -68,7 +68,7 @@ transform.sequence failures(propagate) {
 
 // -----
 
-//     CHECK-LABEL: pad_and_hoist_lhs
+//     CHECK-LABEL: pad_and_hoist_lhs(
 func.func @pad_and_hoist_lhs(
   %arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)
      -> tensor<24x25xf32>