[mlir][linalg] Expose lowerPack and lowerUnPack utils.
author    Hanhan Wang <hanchung@google.com>
          Fri, 21 Apr 2023 20:25:33 +0000 (13:25 -0700)
committer Hanhan Wang <hanchung@google.com>
          Fri, 21 Apr 2023 22:23:16 +0000 (15:23 -0700)
Reviewed By: qcolombet

Differential Revision: https://reviews.llvm.org/D148867

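Since lowerPack and lowerUnPack are now declared in the linalg namespace, downstream code can drive the same lowering directly instead of going through the transform dialect ops. A minimal sketch of such a caller, assuming a hypothetical pattern name (only the linalg::lowerPack signature and the LowerPackResult type come from this patch):

  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"

  using namespace mlir;

  // Hypothetical downstream pattern driving the exposed util.
  struct LowerPackPattern : public OpRewritePattern<tensor::PackOp> {
    using OpRewritePattern<tensor::PackOp>::OpRewritePattern;

    LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                  PatternRewriter &rewriter) const override {
      // lowerPack replaces packOp itself on success, so the pattern only
      // needs to forward the match status.
      FailureOr<linalg::LowerPackResult> res =
          linalg::lowerPack(rewriter, packOp);
      return success(succeeded(res));
    }
  };
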
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 52982c3..73e830d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -907,6 +907,27 @@ FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
     GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
     RewriterBase &rewriter);
 
+struct LowerPackResult {
+  tensor::PadOp padOp;
+  tensor::ExpandShapeOp expandShapeOp;
+  linalg::TransposeOp transposeOp;
+};
+
+/// Rewrite pack as pad + reshape + transpose.
+FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
+                                     tensor::PackOp packOp);
+
+struct LowerUnPackOpResult {
+  tensor::EmptyOp emptyOp;
+  linalg::TransposeOp transposeOp;
+  tensor::CollapseShapeOp collapseShapeOp;
+  tensor::ExtractSliceOp extractSliceOp;
+};
+
+/// Rewrite unpack as empty + transpose + reshape + extract_slice.
+FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
+                                           tensor::UnPackOp unPackOp);
+
 /// Struct to hold the result of a `pack` call.
 struct PackResult {
   SmallVector<tensor::PackOp> packOps;
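
Note that the members of both result structs are nullable: for a pad-like pack (packOp.isLikePad()) only padOp is created, and for an unpad-like unpack (unPackOp.isLikeUnPad()) only extractSliceOp is, as the definitions below show. A hedged fragment, assuming rewriter and packOp are in scope:

  FailureOr<linalg::LowerPackResult> res = linalg::lowerPack(rewriter, packOp);
  if (succeeded(res) && !res->transposeOp) {
    // Degenerate pad-like case: the pack was lowered to pad + insert_slice,
    // so expandShapeOp and transposeOp are both null.
  }
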
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dbadabd..f113d3a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -748,126 +748,6 @@ LogicalResult transform::InterchangeOp::verify() {
 // LowerPackOp
 //===----------------------------------------------------------------------===//
 
-struct LowerPackResult {
-  tensor::PadOp padOp;
-  tensor::ExpandShapeOp expandShapeOp;
-  linalg::TransposeOp transposeOp;
-};
-
-/// Rewrite pack as pad + reshape + transpose.
-static FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
-                                            tensor::PackOp packOp) {
-  // 1. Filter out NYI cases.
-  if (!packOp.getOuterDimsPerm().empty())
-    return rewriter.notifyMatchFailure(packOp, "outer dims perm NYI");
-
-  auto packedTensorType =
-      packOp->getResultTypes().front().cast<RankedTensorType>();
-  if (!packedTensorType.hasStaticShape()) {
-    return rewriter.notifyMatchFailure(
-        packOp,
-        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
-  }
-
-  Location loc = packOp->getLoc();
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(packOp);
-
-  // 2. Compute the permutation vector to move the last `numPackedDims` into the
-  // `innerPosDims` of a shape of rank `packedRank`.
-  int64_t numPackedDims = packOp.getInnerDimsPos().size();
-  int64_t packedRank = packedTensorType.getRank();
-  auto lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
-  PackingMetadata packingMetadata = computePackingMetadata(
-      packedTensorType.getRank(), packOp.getInnerDimsPos());
-  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
-      packedRank, lastDims, packingMetadata.insertPositions);
-
-  // 3. Compute the stripMinedShape: this is the packed shape before any outer
-  // or inner permutations have been applied.
-  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
-  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
-
-  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
-  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
-      packingMetadata.reassociations);
-  Value paddingValue = packOp.getPaddingValue();
-  if (!paddingValue) {
-    paddingValue = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
-  }
-  auto padOp =
-      tensor::createPadHighOp(collapsed, packOp.getSource(), paddingValue,
-                              /*nofold=*/false, loc, rewriter);
-
-  LLVM_DEBUG(
-      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
-                                                DBGS() << "insertPositions: ");
-      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
-                                      DBGS() << "packedShape: ");
-      DBGSNL();
-      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
-                            DBGS() << "lastDimsToInsertPositionsPerm: ");
-      DBGSNL(); llvm::interleaveComma(
-          packingMetadata.reassociations, DBGS() << "reassociations: ",
-          [&](ReassociationIndices ri) {
-            llvm::interleaveComma(ri, llvm::dbgs() << "|");
-          });
-      DBGSNL();
-      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
-      DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
-
-  if (packOp.isLikePad()) {
-    // This pack is just a plain pad.
-    // Just insert the pad in the higher ranked tensor.
-    auto emptyOp =
-        rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
-    // Offsets.
-    SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
-    // Strides.
-    SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> sizes =
-        getMixedDimensions(rewriter, loc, packOp.getDest());
-
-    auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-        loc, /*source=*/padOp, /*dest=*/emptyOp,
-        /*offsets=*/zeros, sizes,
-        /*strides=*/ones);
-
-    LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
-
-    rewriter.replaceOp(packOp, insertSliceOp->getResults());
-
-    return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
-                           /*transposeOp=*/nullptr};
-  }
-  // 5. Expand from the padded result to the stripMinedShape.
-  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
-      loc,
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
-      padOp.getResult(), packingMetadata.reassociations);
-
-  // 6. Transpose stripMinedShape to packedShape.
-  SmallVector<int64_t> insertPositionsToLastDimsPerm = computePermutationVector(
-      packedRank, packingMetadata.insertPositions, lastDims);
-  auto transposeOp = rewriter.create<linalg::TransposeOp>(
-      loc, reshapeOp.getResult(), packOp.getDest(),
-      insertPositionsToLastDimsPerm);
-
-  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
-             DBGS() << "reshape op: " << reshapeOp; DBGSNL();
-             llvm::interleaveComma(insertPositionsToLastDimsPerm,
-                                   DBGS() << "insertPositionsToLastDimsPerm: ");
-             DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
-
-  // 7. Replace packOp by transposeOp.
-  rewriter.replaceOp(packOp, transposeOp->getResults());
-
-  return LowerPackResult{padOp, reshapeOp, transposeOp};
-}
-
 DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
     tensor::PackOp target, transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
@@ -889,115 +769,6 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
 // LowerUnPackOp
 //===----------------------------------------------------------------------===//
 
-struct LowerUnPackOpResult {
-  tensor::EmptyOp emptyOp;
-  linalg::TransposeOp transposeOp;
-  tensor::CollapseShapeOp collapseShapeOp;
-  tensor::ExtractSliceOp extractSliceOp;
-};
-
-/// Rewrite pack as empty + transpose + reshape + extract_slice.
-static FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
-                                                  tensor::UnPackOp unPackOp) {
-  // 1. Filter out NYI cases.
-  if (!unPackOp.getOuterDimsPerm().empty())
-    return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
-
-  RankedTensorType packedTensorType = unPackOp.getSourceType();
-  if (!packedTensorType.hasStaticShape()) {
-    return rewriter.notifyMatchFailure(
-        unPackOp,
-        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
-  }
-
-  Location loc = unPackOp->getLoc();
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(unPackOp);
-
-  int64_t packedRank = packedTensorType.getRank();
-
-  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
-  auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
-  if (unPackOp.isLikeUnPad()) {
-    // This unpack is just a plain unpad.
-    // Just extract the slice from the higher ranked tensor.
-    ArrayRef<int64_t> destShape = destTensorType.getShape();
-    // The inner dimensions stay the same as the destination tensor, but the
-    // outer ones are additional 1s.
-    SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
-    sizes.append(getMixedDimensions(rewriter, loc, unPackOp.getDest()));
-
-    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
-        loc, destTensorType, unPackOp.getSource(),
-        SmallVector<OpFoldResult>(packedRank, zero), sizes,
-        SmallVector<OpFoldResult>(packedRank, one));
-
-    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
-
-    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
-                               /*reshapeOp=*/nullptr, extractSliceOp};
-  }
-  // 2. Compute the permutation vector to move the last `numPackedDims` into
-  // the `innerPosDims` of a shape of rank `packedRank`.
-  int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
-  auto lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
-  PackingMetadata packingMetadata =
-      computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
-  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
-      packedRank, lastDims, packingMetadata.insertPositions);
-
-  // 3. Compute the stripMinedShape: this is the packed shape without outer and
-  // inner permutations.
-  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
-  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
-
-  // 4. Transpose packedShape to stripMinedShape.
-  RankedTensorType stripMinedTensorType =
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
-  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
-      stripMinedTensorType, packingMetadata.reassociations);
-  auto emptyOp =
-      rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType, ValueRange{});
-  auto transposeOp = rewriter.create<linalg::TransposeOp>(
-      loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
-
-  LLVM_DEBUG(
-      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
-                                                DBGS() << "insertPositions: ");
-      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
-                                      DBGS() << "packedShape: ");
-      DBGSNL();
-      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
-                            DBGS() << "lastDimsToInsertPositionsPerm: ");
-      DBGSNL(); llvm::interleaveComma(
-          packingMetadata.reassociations, DBGS() << "reassociations: ",
-          [&](ReassociationIndices ri) {
-            llvm::interleaveComma(ri, llvm::dbgs() << "|");
-          });
-      DBGSNL();
-      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
-      DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
-
-  // 5. Collapse from the stripMinedShape to the padded result.
-  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
-      loc, collapsedType, transposeOp->getResult(0),
-      packingMetadata.reassociations);
-
-  // 6. ExtractSlice
-  int64_t destRank = destTensorType.getRank();
-  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
-      loc, destTensorType, reshapeOp->getResult(0),
-      SmallVector<OpFoldResult>(destRank, zero),
-      tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
-      SmallVector<OpFoldResult>(destRank, one));
-
-  // 7. Replace unPackOp by extractSliceOp.
-  rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
-
-  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
-}
-
 DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
     tensor::UnPackOp target, transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
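
With the static functions removed above, the transform op bodies delegate to the shared utils. The actual applyToOne implementation is outside this hunk; a plausible sketch of the delegation, using only APIs visible at this revision:

  IRRewriter rewriter(target->getContext());
  FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
  if (failed(res))
    return emitDefaultSilenceableFailure(target);
  transformResults.push_back(res->emptyOp);
  transformResults.push_back(res->transposeOp);
  transformResults.push_back(res->collapseShapeOp);
  transformResults.push_back(res->extractSliceOp);
  return DiagnosedSilenceableFailure::success();
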
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 166f426..4d5ef0e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -477,6 +477,220 @@ private:
 
 } // namespace
 
+FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
+                                             tensor::PackOp packOp) {
+  // 1. Filter out NYI cases.
+  if (!packOp.getOuterDimsPerm().empty())
+    return rewriter.notifyMatchFailure(packOp, "outer dims perm NYI");
+
+  auto packedTensorType =
+      packOp->getResultTypes().front().cast<RankedTensorType>();
+  if (!packedTensorType.hasStaticShape()) {
+    return rewriter.notifyMatchFailure(
+        packOp,
+        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
+  }
+
+  Location loc = packOp->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
+  // 2. Compute the permutation vector to move the last `numPackedDims` into the
+  // `innerPosDims` of a shape of rank `packedRank`.
+  int64_t numPackedDims = packOp.getInnerDimsPos().size();
+  int64_t packedRank = packedTensorType.getRank();
+  auto lastDims = llvm::to_vector(
+      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+  PackingMetadata packingMetadata = computePackingMetadata(
+      packedTensorType.getRank(), packOp.getInnerDimsPos());
+  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
+      packedRank, lastDims, packingMetadata.insertPositions);
+
+  // 3. Compute the stripMinedShape: this is the packed shape before any outer
+  // or inner permutations have been applied.
+  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
+  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
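+  // Example: packing tensor<5x7xf32> with inner_dims_pos = [0] and
+  // inner_tiles = [2] gives packedShape [3, 7, 2]; the tile dim sits at the
+  // last position (2) and its insert position is 1, so the permutation is
+  // [0, 2, 1] and stripMinedShape becomes [3, 2, 7].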
+
+  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
+  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
+      packingMetadata.reassociations);
+  Value paddingValue = packOp.getPaddingValue();
+  if (!paddingValue) {
+    paddingValue = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
+  }
+  auto padOp =
+      tensor::createPadHighOp(collapsed, packOp.getSource(), paddingValue,
+                              /*nofold=*/false, loc, rewriter);
+
+  LLVM_DEBUG(
+      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
+                                                DBGS() << "insertPositions: ");
+      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
+                                      DBGS() << "packedShape: ");
+      DBGSNL();
+      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
+                            DBGS() << "lastDimsToInsertPositionsPerm: ");
+      DBGSNL(); llvm::interleaveComma(
+          packingMetadata.reassociations, DBGS() << "reassociations: ",
+          [&](ReassociationIndices ri) {
+            llvm::interleaveComma(ri, llvm::dbgs() << "|");
+          });
+      DBGSNL();
+      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
+      DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
+
+  if (packOp.isLikePad()) {
+    // This pack is just a plain pad.
+    // Just insert the pad in the higher ranked tensor.
+    auto emptyOp =
+        rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
+    // Offsets.
+    SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
+    // Strides.
+    SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> sizes =
+        getMixedDimensions(rewriter, loc, packOp.getDest());
+
+    auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+        loc, /*source=*/padOp, /*dest=*/emptyOp,
+        /*offsets=*/zeros, sizes,
+        /*strides=*/ones);
+
+    LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
+
+    rewriter.replaceOp(packOp, insertSliceOp->getResults());
+
+    return LowerPackResult{padOp, /*expandShapeOp=*/nullptr,
+                           /*transposeOp=*/nullptr};
+  }
+  // 5. Expand from the padded result to the stripMinedShape.
+  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
+      loc,
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
+      padOp.getResult(), packingMetadata.reassociations);
+
+  // 6. Transpose stripMinedShape to packedShape.
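+  // This is the inverse of lastDimsToInsertPositionsPerm computed in step 2.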
+  SmallVector<int64_t> insertPositionsToLastDimsPerm = computePermutationVector(
+      packedRank, packingMetadata.insertPositions, lastDims);
+  auto transposeOp = rewriter.create<linalg::TransposeOp>(
+      loc, reshapeOp.getResult(), packOp.getDest(),
+      insertPositionsToLastDimsPerm);
+
+  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
+             DBGS() << "reshape op: " << reshapeOp; DBGSNL();
+             llvm::interleaveComma(insertPositionsToLastDimsPerm,
+                                   DBGS() << "insertPositionsToLastDimsPerm: ");
+             DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
+
+  // 7. Replace packOp by transposeOp.
+  rewriter.replaceOp(packOp, transposeOp->getResults());
+
+  return LowerPackResult{padOp, reshapeOp, transposeOp};
+}
+
+FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
+                                                   tensor::UnPackOp unPackOp) {
+  // 1. Filter out NYI cases.
+  if (!unPackOp.getOuterDimsPerm().empty())
+    return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
+
+  RankedTensorType packedTensorType = unPackOp.getSourceType();
+  if (!packedTensorType.hasStaticShape()) {
+    return rewriter.notifyMatchFailure(
+        unPackOp,
+        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
+  }
+
+  Location loc = unPackOp->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(unPackOp);
+
+  int64_t packedRank = packedTensorType.getRank();
+
+  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
+  auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
+  if (unPackOp.isLikeUnPad()) {
+    // This unpack is just a plain unpad.
+    // Just extract the slice from the higher ranked tensor.
+    ArrayRef<int64_t> destShape = destTensorType.getShape();
+    // The inner dimensions stay the same as the destination tensor, but the
+    // outer ones are additional 1s.
+    SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
+    sizes.append(getMixedDimensions(rewriter, loc, unPackOp.getDest()));
+
+    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, destTensorType, unPackOp.getSource(),
+        SmallVector<OpFoldResult>(packedRank, zero), sizes,
+        SmallVector<OpFoldResult>(packedRank, one));
+
+    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
+
+    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
+                               /*collapseShapeOp=*/nullptr, extractSliceOp};
+  }
+  // 2. Compute the permutation vector to move the last `numPackedDims` into
+  // the `innerPosDims` of a shape of rank `packedRank`.
+  int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
+  auto lastDims = llvm::to_vector(
+      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+  PackingMetadata packingMetadata =
+      computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
+  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
+      packedRank, lastDims, packingMetadata.insertPositions);
+
+  // 3. Compute the stripMinedShape: this is the packed shape without outer and
+  // inner permutations.
+  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
+  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
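+  // Example: for a packed tensor<3x7x2xf32> with inner_dims_pos = [0], the
+  // permutation is [0, 2, 1] and stripMinedShape is [3, 2, 7], which
+  // collapses to [6, 7] before the final extract_slice trims it back to 5x7.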
+
+  // 4. Transpose packedShape to stripMinedShape.
+  RankedTensorType stripMinedTensorType =
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      stripMinedTensorType, packingMetadata.reassociations);
+  auto emptyOp =
+      rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType, ValueRange{});
+  auto transposeOp = rewriter.create<linalg::TransposeOp>(
+      loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
+
+  LLVM_DEBUG(
+      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
+                                                DBGS() << "insertPositions: ");
+      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
+                                      DBGS() << "packedShape: ");
+      DBGSNL();
+      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
+                            DBGS() << "lastDimsToInsertPositionsPerm: ");
+      DBGSNL(); llvm::interleaveComma(
+          packingMetadata.reassociations, DBGS() << "reassociations: ",
+          [&](ReassociationIndices ri) {
+            llvm::interleaveComma(ri, llvm::dbgs() << "|");
+          });
+      DBGSNL();
+      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
+      DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
+
+  // 5. Collapse from the stripMinedShape to the padded result.
+  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
+      loc, collapsedType, transposeOp->getResult(0),
+      packingMetadata.reassociations);
+
+  // 6. Extract the slice of the destination shape from the collapsed result.
+  int64_t destRank = destTensorType.getRank();
+  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+      loc, destTensorType, reshapeOp->getResult(0),
+      SmallVector<OpFoldResult>(destRank, zero),
+      tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
+      SmallVector<OpFoldResult>(destRank, one));
+
+  // 7. Replace unPackOp by extractSliceOp.
+  rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
+
+  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
+}
+
 SmallVector<int64_t>
 PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
   SmallVector<int64_t> res;
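
To make the permutation bookkeeping in steps 2-3 concrete, here is a standalone trace of the tensor<5x7xf32> example from the inline comments. This is plain C++ with a stand-in for applyPermutationToVector (out[i] = in[perm[i]]), not MLIR code:

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Stand-in for applyPermutationToVector: out[i] = in[perm[i]].
  static std::vector<int64_t> applyPerm(const std::vector<int64_t> &in,
                                        const std::vector<int64_t> &perm) {
    std::vector<int64_t> out(in.size());
    for (size_t i = 0; i < in.size(); ++i)
      out[i] = in[perm[i]];
    return out;
  }

  int main() {
    // Packed shape of tensor<3x7x2xf32>; the single inner tile dim is last.
    std::vector<int64_t> packedShape = {3, 7, 2};
    // Permutation moving the last dim to insert position 1.
    std::vector<int64_t> lastDimsToInsertPositionsPerm = {0, 2, 1};
    std::vector<int64_t> stripMined =
        applyPerm(packedShape, lastDimsToInsertPositionsPerm);
    assert((stripMined == std::vector<int64_t>{3, 2, 7}));
    // Collapsing reassociation [[0, 1], [2]] yields [6, 7]: the 5x7 source
    // padded up to 6x7. The inverse permutation (also [0, 2, 1] here)
    // transposes [3, 2, 7] back to the packed shape [3, 7, 2].
    return 0;
  }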