[mlir][Linalg] Fix hoist padding through scf.for iter_arg
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 4 Apr 2023 15:17:19 +0000 (08:17 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Thu, 13 Apr 2023 12:21:38 +0000 (05:21 -0700)
Previously, hoisting through an iter_arg would mistakenly yield the unpadded value and
cast it to the padded value.

This was incorrect and resulted in out-of-bounds accesses.
The correct formulation is to yield the padded value and extract a smaller dynamic slice
out of it.

Differential Revision: https://reviews.llvm.org/D148173

mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir

index 5386420..7a6c58a 100644 (file)
@@ -24,6 +24,7 @@
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/Support/Debug.h"
 
@@ -153,9 +154,9 @@ struct HoistPaddingAnalysis {
   bool isValid() { return valid.has_value() && valid.value(); }
   bool isInvalid() { return valid.has_value() && !valid.value(); }
 
-  /// Footprint of the packedTensor, computed from the packingLoops.
-  SmallVector<Value> getPackedTensorSizes(RewriterBase &rewriter,
-                                          Location loc) const;
+  /// Footprint of the hoistedPackedTensor, computed from the packingLoops.
+  SmallVector<Value> getHoistedPackedTensorSizes(RewriterBase &rewriter,
+                                                 Location loc) const;
 
   /// Performs optional hoisting to enable hoist padding to occur. This may be
   /// necessary when `sliceOp` is not defined outside of the outermost enclosing
@@ -450,8 +451,8 @@ LogicalResult HoistPaddingAnalysis::dropNonIndexDependencies() {
 }
 
 SmallVector<Value>
-HoistPaddingAnalysis::getPackedTensorSizes(RewriterBase &rewriter,
-                                           Location loc) const {
+HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
+                                                  Location loc) const {
   SmallVector<Value> dynamicTensorSizes;
 
   // Upper bound the packing loop lengths to size the packed tensor. Taking
@@ -525,7 +526,8 @@ static Value buildLoopIterationCount(RewriterBase &rewriter, scf::ForOp outer,
 // Build a packing loop nest by iteratively traversing the backward slice and
 // clone the operations, iteratively stepping into the loops that we encounter.
 // The implementation proceeds in a stack-like fashion:
-//   1. Iteratively clone and step into the loops, pushing the `packedTensor`
+//   1. Iteratively clone and step into the loops, pushing the
+//   `hoistedPackedTensor`
 //      deeper in the stack.
 //   2. At the innermost loop level, create a GenericOp if `transposeVector` is
 //      non-empty.
@@ -537,7 +539,7 @@ static PackingResult buildPackingLoopNestImpl(
     ArrayRef<int64_t> transposeVector, RankedTensorType transposedTensorType,
     tensor::EmptyOp emptyOp, const HoistPaddingAnalysis &analysis) {
   SmallVector<OpFoldResult> offsets, sizes, strides;
-  SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
+  SmallVector<Value> clonedLoopIvs, leadingHoistedPackedTensorIndexings;
 
   scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
 
@@ -558,14 +560,14 @@ static PackingResult buildPackingLoopNestImpl(
     bbArg = operand.get().dyn_cast<BlockArgument>();
   }
 
-  // Step 1. iteratively clone loops and push `packedTensor`.
-  Value packedTensor = emptyOp.getResult();
+  // Step 1. iteratively clone loops and push `hoistedPackedTensor`.
+  Value hoistedPackedTensor = emptyOp.getResult();
   OpBuilder::InsertionGuard g(rewriter);
   for (Operation *op : analysis.backwardSlice) {
-    // Specifically sit out in the extract_slice(packedTensor) case: this is
-    // the piece we seek to replace.
+    // Specifically sit out in the extract_slice(hoistedPackedTensor) case: this
+    // is the piece we seek to replace.
     if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
-      if (bvm.lookupOrDefault(sliceOp.getSource()) == packedTensor) {
+      if (bvm.lookupOrDefault(sliceOp.getSource()) == hoistedPackedTensor) {
         LLVM_DEBUG(DBGS() << "--Skip: " << sliceOp << "\n");
         continue;
       }
@@ -579,11 +581,12 @@ static PackingResult buildPackingLoopNestImpl(
       continue;
     }
 
-    // Create a packing loop that takes `packedTensor` as iteration argument.
+    // Create a packing loop that takes `hoistedPackedTensor` as iteration
+    // argument.
     auto clonedForOp = rewriter.create<scf::ForOp>(
         loc, bvm.lookupOrDefault(forOp.getLowerBound()),
         bvm.lookupOrDefault(forOp.getUpperBound()),
-        bvm.lookupOrDefault(forOp.getStep()), packedTensor);
+        bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor);
 
     // Map the induction var, region args and results to the `clonedForOp`.
     bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
@@ -600,16 +603,18 @@ static PackingResult buildPackingLoopNestImpl(
     // Assert the loop-independent iteration count can be computed.
     if (!loopIndependentIterationCount)
       llvm_unreachable("loop independence prerequisite not met");
-    leadingPackedTensorIndexings.push_back(loopIndependentIterationCount);
-    packedTensor = clonedForOp.getRegionIterArgs().front();
+    leadingHoistedPackedTensorIndexings.push_back(
+        loopIndependentIterationCount);
+    hoistedPackedTensor = clonedForOp.getRegionIterArgs().front();
   }
 
   // Step 2. Construct offsets, sizes and strides for the innermost level of the
   // packing loop.
   int64_t nPackedLoops = clonedLoopIvs.size();
   // offsets = [clonedLoopIvs, 0 .. 0].
-  offsets = SmallVector<OpFoldResult>{leadingPackedTensorIndexings.begin(),
-                                      leadingPackedTensorIndexings.end()};
+  offsets =
+      SmallVector<OpFoldResult>{leadingHoistedPackedTensorIndexings.begin(),
+                                leadingHoistedPackedTensorIndexings.end()};
   offsets.append(paddedRank, rewriter.getIndexAttr(0));
   // sizes = [1 .. 1, transposedShape].
   sizes = SmallVector<OpFoldResult>(nPackedLoops, rewriter.getIndexAttr(1));
@@ -627,7 +632,8 @@ static PackingResult buildPackingLoopNestImpl(
   Value paddedTensor = bvm.lookup(opToHoist.getResult());
   if (!transposeVector.empty()) {
     Value outputTensor = rewriter.create<tensor::ExtractSliceOp>(
-        loc, transposedTensorType, packedTensor, offsets, sizes, strides);
+        loc, transposedTensorType, hoistedPackedTensor, offsets, sizes,
+        strides);
     maybeTransposeOp = makeTransposeOp(rewriter, loc, paddedTensor,
                                        outputTensor, transposeVector);
     paddedTensor = maybeTransposeOp.getResult(0);
@@ -638,7 +644,7 @@ static PackingResult buildPackingLoopNestImpl(
     // Step 4. Create InsertSliceOp at the innermost loop level, inserting an
     // optionally transposed padded slice into the packed tensor.
     Value inserted = rewriter.create<tensor::InsertSliceOp>(
-        loc, paddedTensor, packedTensor, offsets, sizes, strides);
+        loc, paddedTensor, hoistedPackedTensor, offsets, sizes, strides);
 
     // Step 5. Iteratively pop the stack and propagate the yield.
     Value valueToYield = inserted;
@@ -655,7 +661,7 @@ static PackingResult buildPackingLoopNestImpl(
       sizes,
       strides,
       clonedLoopIvs,
-      leadingPackedTensorIndexings,
+      leadingHoistedPackedTensorIndexings,
       maybeTransposeOp,
       cast<tensor::PadOp>(bvm.lookup(opToHoist.getResult()).getDefiningOp())};
 }
@@ -688,7 +694,7 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
   SmallVector<int64_t> packedShape(nPackedLoops, ShapedType::kDynamic);
   // TODO: go grab dims when needed, atm tensor::PadOp yields a static tensor.
   llvm::append_range(packedShape, transposedTensorType->getShape());
-  auto packedTensorType = RankedTensorType::get(
+  auto hoistedPackedTensorType = RankedTensorType::get(
       packedShape, transposedTensorType->getElementType());
 
   // Set the insertion point right before the outer loop and start packing.
@@ -696,10 +702,10 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(outerLoop);
   SmallVector<Value> dynamicTensorSizes =
-      analysis.getPackedTensorSizes(rewriter, loc);
+      analysis.getHoistedPackedTensorSizes(rewriter, loc);
   auto emptyOp = rewriter.create<tensor::EmptyOp>(
-      loc, packedTensorType.getShape(), packedTensorType.getElementType(),
-      dynamicTensorSizes);
+      loc, hoistedPackedTensorType.getShape(),
+      hoistedPackedTensorType.getElementType(), dynamicTensorSizes);
 
   return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector,
                                   *transposedTensorType, emptyOp, analysis);
@@ -727,14 +733,71 @@ FailureOr<PackingResult> mlir::linalg::detail::buildPackingLoopNest(
 // hoistPaddingOnTensors Implementation.
 //===----------------------------------------------------------------------===//
 
-// If the original consumer of `sliceOp` was a `forOp` (i.e. through an iter
-// arg), propagate the `packedTensor` value through the same iter arg.
-// TODO: for multiple loops we need to track the use to the innermost loop.
-static Value padThroughLoopIterArg(RewriterBase &rewriter, Value packedTensor,
-                                   tensor::ExtractSliceOp sliceOp,
-                                   scf::ForOp forOp) {
+/// Return true if we can walk back the use-def chain from `extractSliceOp` to
+/// expectedSource going through DestinationStyleOpInterface inits only.
+/// This is a poor man's analysis that is sufficient to check the extractSliceOp
+/// the matches tensor.pad we want to hoist.
+/// In the future, it will be easier to ensure this with a matching symmetric
+/// tensor.unpad op.
+static bool tracesBackToExpectedValue(tensor::ExtractSliceOp extractSliceOp,
+                                      Value expectedSource) {
+  LLVM_DEBUG(DBGS() << "Start tracesBackToExpectedValue on: " << extractSliceOp
+                    << "\n");
+  LLVM_DEBUG(DBGS() << "--with extractSlice: " << extractSliceOp << "\n");
+  Value source = extractSliceOp.getSource();
+  LLVM_DEBUG(DBGS() << "--with starting source: " << source << "\n");
+  while (source && source != expectedSource) {
+    auto destOp =
+        dyn_cast_or_null<DestinationStyleOpInterface>(source.getDefiningOp());
+    if (!destOp)
+      break;
+    LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n");
+    source =
+        destOp.getDpsInitOperand(source.cast<OpResult>().getResultNumber())
+            ->get();
+  }
+  LLVM_DEBUG(DBGS() << "--final source: " << source << "\n");
+  LLVM_DEBUG(DBGS() << "--expected source: " << expectedSource << "\n");
+  return source == expectedSource;
+}
+
+/// If the original consumer of `outerSliceOp` was a `forOp` (i.e. through an
+/// iter arg), propagate the `hoistedPackedTensor` value through the same iter
+/// arg.
+/// TODO: for multiple loops we need to track the use to the innermost loop.
+///
+/// Match:
+/// ```
+///   %outerSliceOp = tensor.extract_slice ..
+///   %f = scf.for ... iter_args(%arg0 = %outerSliceOp) {
+///     %hoistedPackedTensor = tensor.pad %arg0
+///     %1 = compute %hoistedPackedTensor
+///     %2 = tensor.extract_slice %1
+///     scf.yield %2
+///   }
+/// ```
+///
+/// and rewrite as:
+/// ```
+///   %outerSliceOp = tensor.extract_slice ..
+///   %hoistedPackedTensor = tensor.pad %outerSliceOp
+///   %f = scf.for ... iter_args(%arg0 = %hoistedPackedTensor) {
+///     %1 = compute %arg0
+///     scf.yield %1
+///   }
+///   %2 = tensor.extract_slice %forOp
+/// ```
+///
+/// Return null when no rewrite happened.
+static tensor::ExtractSliceOp
+padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
+                      Value hoistedPackedTensor,
+                      tensor::ExtractSliceOp outerSliceOp, scf::ForOp forOp) {
+  LLVM_DEBUG(DBGS() << "Start padThroughLoopIterArg on: " << forOp << "\n");
+  LLVM_DEBUG(DBGS() << "--paddedValueBeforeHoisting: "
+                    << paddedValueBeforeHoisting << "\n");
   OpOperand *pUse = nullptr;
-  for (OpOperand &use : sliceOp->getUses()) {
+  for (OpOperand &use : outerSliceOp->getUses()) {
     if (use.getOwner() == forOp) {
       assert(!pUse && "Multiple slice uses in the for loop");
       pUse = &use;
@@ -742,20 +805,67 @@ static Value padThroughLoopIterArg(RewriterBase &rewriter, Value packedTensor,
   }
   assert(pUse && "No slice use in the for loop");
   OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPointAfter(packedTensor.getDefiningOp());
-  Value casted = rewriter.create<tensor::CastOp>(
-      packedTensor.getLoc(), pUse->get().getType(), packedTensor);
+  rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
 
-  std::optional<unsigned> operandNumber =
+  std::optional<unsigned> maybeOperandNumber =
       forOp.getIterArgNumberForOpOperand(*pUse);
-  assert(operandNumber.has_value() && "expected a proper iter arg number");
+  assert(maybeOperandNumber.has_value() && "expected a proper iter arg number");
+
+  int64_t operandNumber = maybeOperandNumber.value();
+  auto yieldOp = cast<scf::YieldOp>(forOp.getBody(0)->getTerminator());
+  auto yieldingExtractSliceOp = yieldOp->getOperand(operandNumber)
+                                    .getDefiningOp<tensor::ExtractSliceOp>();
+  if (!yieldingExtractSliceOp)
+    return tensor::ExtractSliceOp();
+
+  // Poor man's analysis sufficient to ensure extractSlice matches tensor.pad.
+  // In the future, it will be easier to ensure this with a matching symmetric
+  // tensor.unpad op.
+  if (!tracesBackToExpectedValue(yieldingExtractSliceOp,
+                                 paddedValueBeforeHoisting))
+    return tensor::ExtractSliceOp();
 
   SmallVector<Value> initArgs = forOp.getInitArgs();
-  initArgs[operandNumber.value()] = casted;
-  rewriter.startRootUpdate(forOp);
-  forOp.getInitArgsMutable().assign(initArgs);
-  rewriter.finalizeRootUpdate(forOp);
-  return forOp.getRegionIterArgForOpOperand(*pUse);
+  initArgs[operandNumber] = hoistedPackedTensor;
+  SmallVector<Value> yieldOperands = yieldOp.getOperands();
+  yieldOperands[operandNumber] = yieldingExtractSliceOp.getSource();
+
+  int64_t numOriginalForOpResults = initArgs.size();
+  LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults
+                    << "\n");
+  tensor::ExtractSliceOp extracted;
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointAfter(forOp);
+    extracted = rewriter.create<tensor::ExtractSliceOp>(
+        hoistedPackedTensor.getLoc(), hoistedPackedTensor,
+        outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(),
+        outerSliceOp.getMixedStrides());
+    rewriter.replaceAllUsesWith(forOp.getResult(operandNumber), extracted);
+  }
+  scf::ForOp newForOp =
+      replaceLoopWithNewYields(rewriter, forOp, initArgs, yieldOperands);
+
+  LLVM_DEBUG(DBGS() << "newForOp results: " << newForOp.getNumResults()
+                    << "\n");
+  LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n");
+  LLVM_DEBUG(DBGS() << "with result #"
+                    << numOriginalForOpResults + operandNumber
+                    << " of forOp, giving us: " << extracted << "\n");
+  rewriter.startRootUpdate(extracted);
+  extracted.getSourceMutable().assign(
+      newForOp.getResult(numOriginalForOpResults + operandNumber));
+  rewriter.finalizeRootUpdate(extracted);
+
+  LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
+                    << "\n");
+  LLVM_DEBUG(DBGS() << "with region iter arg #"
+                    << numOriginalForOpResults + operandNumber << "\n");
+  rewriter.replaceAllUsesWith(
+      paddedValueBeforeHoisting,
+      newForOp.getRegionIterArg(numOriginalForOpResults + operandNumber));
+
+  return extracted;
 }
 
 /// Produce a tensor extracted from the packingResult. This can be used as a
@@ -781,7 +891,7 @@ static Value replaceByPackingResult(RewriterBase &rewriter,
   scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
   ArrayRef<scf::ForOp> packingLoops = analysis.packingLoops;
 
-  Value packedTensor;
+  Value hoistedPackedTensor;
   SmallVector<Value> loopIterationCounts;
   SmallVector<OpFoldResult> offsets(nPackedLoops + paddedRank,
                                     rewriter.getIndexAttr(0));
@@ -798,29 +908,29 @@ static Value replaceByPackingResult(RewriterBase &rewriter,
     // offsets = [maybe_leading_ivs = originalLoopIvs, 0 .. 0].
     std::copy(loopIterationCounts.begin(), loopIterationCounts.end(),
               offsets.begin());
-    packedTensor =
+    hoistedPackedTensor =
         scf::getForInductionVarOwner(packingResult.clonedLoopIvs.front())
             ->getResult(0);
   } else {
     // If no loops were created, this is just hoisting without packing.
-    packedTensor = bvm.lookup(opToHoist.getResult());
+    hoistedPackedTensor = bvm.lookup(opToHoist.getResult());
   }
 
-  LLVM_DEBUG(DBGS() << "packedTensor: " << packedTensor << "\n");
+  LLVM_DEBUG(DBGS() << "hoistedPackedTensor: " << hoistedPackedTensor << "\n");
 
   // If the consumer of `padOp` was a `forOp`, propagate through iter args.
   scf::ForOp forOp = analysis.padConsumingForOp;
   if (forOp) {
-    packedTensor =
-        padThroughLoopIterArg(rewriter, packedTensor, analysis.sliceOp, forOp);
+    return padThroughLoopIterArg(rewriter, opToHoist, hoistedPackedTensor,
+                                 analysis.sliceOp, forOp);
   }
 
   // offsets = [maybe_leading_ivs, 0 .. 0].
   // sizes = [1 .. 1, transposedShape] (defined above).
   // strides = [1 .. 1] (defined above)
   return rewriter.create<tensor::ExtractSliceOp>(
-      loc, transposedTensorType, packedTensor, offsets, packingResult.sizes,
-      packingResult.strides);
+      loc, transposedTensorType, hoistedPackedTensor, offsets,
+      packingResult.sizes, packingResult.strides);
 }
 
 FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(
index fd0d309..871163a 100644 (file)
@@ -161,12 +161,13 @@ func.func @pad_and_hoist_init(
   //      CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>) {
   //      CHECK:   %[[PADDED:.*]] = tensor.pad %{{.*}} 
   //      CHECK:     : tensor<?x25xf32> to tensor<5x25xf32>
-  //      CHECK:   scf.for %{{.*}} iter_args(%[[INNER_PADDED:[0-9a-zA-Z]*]] = %[[PADDED]]) -> (tensor<5x25xf32>)
+  //      CHECK:   %[[SCF_YIELD:.*]] = scf.for %{{.*}} iter_args(%[[INNER_PADDED:[0-9a-zA-Z]*]] = %[[PADDED]]) -> (tensor<5x25xf32>)
   //      CHECK:     %[[RES:.*]] = linalg.matmul {{.*}} outs(%[[INNER_PADDED]]
   // CHECK-SAME:       : tensor<5x25xf32>
   //      CHECK:     scf.yield %[[RES]] : tensor<5x25xf32>
-  //      CHECK:   %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<5x25xf32> to tensor<?x25xf32>
-  //      CHECK:   tensor.insert_slice %[[CAST]] into %{{.*}}[%{{.*}}, 0] [%{{.*}}, 25] [1, 1]
+  //      CHECK:   %[[EXTRACTED:.*]] = tensor.extract_slice %[[SCF_YIELD]][%{{.*}}, 0] [%{{.*}}, 25] [1, 1]
+  // CHECK-SAME:     : tensor<5x25xf32> to tensor<?x25xf32>
+  //      CHECK:   tensor.insert_slice %[[EXTRACTED]] into %{{.*}}[%{{.*}}, 0] [%{{.*}}, 25] [1, 1]
   // CHECK-SAME:     : tensor<?x25xf32> into tensor<24x25xf32>
   %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
   func.return %0 : tensor<24x25xf32>