[mlir][tensor][linalg] Add a pattern that generalizes tensor.unpack op.
authorHanhan Wang <hanchung@google.com>
Fri, 16 Dec 2022 21:03:10 +0000 (13:03 -0800)
committerHanhan Wang <hanchung@google.com>
Tue, 20 Dec 2022 01:52:10 +0000 (17:52 -0800)
The pattern generalizes a tensor::UnPackOp into a sequence of tensor +
Linalg ops, when the outer dims are all 1s. It uses the trick of
rank-reduced tensor.extract_slice to get the tile; transpose the tile;
extract sub tile for incomplete cases if needed; use tensor.insert_slice
to insert it to the destination tensor.

Reviewed By: tyb0807, chelini

Differential Revision: https://reviews.llvm.org/D140254

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

index 2919c65..cc2b63d 100644 (file)
@@ -882,6 +882,16 @@ struct GeneralizeOuterUnitDimsPackOpPattern
                                 PatternRewriter &rewriter) const override;
 };
 
+/// Rewrites a tensor::UnPackOp into a sequence of rank-reduced extract_slice op
+/// + transpose op + insert_slice op, where the tensor::UnPackOp has outer dims
+/// being all 1s.
+struct GeneralizeOuterUnitDimsUnPackOpPattern
+    : public OpRewritePattern<tensor::UnPackOp> {
+  using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 /// Populates `patterns` with patterns that vectorize tensor.pad.
 /// These patterns are meant to apply in a complementary fashion. Benefits
 /// are used to encode a certain ordering of pattern application. To avoid
index 77ea7af..20ee259 100644 (file)
@@ -512,6 +512,17 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                  /*nofold=*/false, loc, builder);
 }
 
+static SmallVector<int64_t>
+getPackUnpackNormalizedInnerPerm(int rank, ArrayRef<int64_t> innerDimsPos) {
+  constexpr int64_t kNonTiledMarker = -1;
+  SmallVector<int64_t> vec(rank, kNonTiledMarker);
+  for (auto [index, value] : llvm::enumerate(innerDimsPos))
+    vec[value] = index;
+  SmallVector<int64_t> perm = llvm::to_vector(llvm::make_filter_range(
+      vec, [&](int64_t v) { return v != kNonTiledMarker; }));
+  return perm;
+}
+
 LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
   // TODO: support the case that outer dimensions are not all 1s A
@@ -556,14 +567,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
       loc, readType, input, readOffsets, readSizes, readStrides);
 
   // 2. Transpose the tile to match the inner tile order.
-  constexpr int64_t kNonTiledMarker = -1;
-  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
-  SmallVector<int64_t> vec(srcRank, kNonTiledMarker);
-  for (auto [index, value] : llvm::enumerate(innerDimsPos))
-    vec[value] = index;
-  SmallVector<int64_t> perm = llvm::to_vector(llvm::make_filter_range(
-      vec, [&](int64_t v) { return v != kNonTiledMarker; }));
-
+  SmallVector<int64_t> perm =
+      getPackUnpackNormalizedInnerPerm(srcRank, packOp.getInnerDimsPos());
   SmallVector<int64_t> transpShape = readShape;
   applyPermutationToVector<int64_t>(transpShape, perm);
 
@@ -587,6 +592,81 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   return success();
 }
 
+LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
+    tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const {
+  int64_t srcRank = unpackOp.getSourceRank();
+  int64_t destRank = unpackOp.getDestRank();
+  ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
+  if (llvm::any_of(srcShape.take_front(destRank),
+                   [](int64_t val) { return val != 1; })) {
+    return rewriter.notifyMatchFailure(
+        unpackOp, "require the outer dimension of the result are all 1s");
+  }
+
+  // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
+  Location loc = unpackOp.getLoc();
+  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
+  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
+  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
+
+  auto mixedTiles = unpackOp.getMixedTiles();
+  SmallVector<OpFoldResult> readSizes(destRank, oneIdxAttr);
+  readSizes.append(mixedTiles.begin(), mixedTiles.end());
+
+  // Explicitly create the type for extract_slice op because the inner tile
+  // size could be 1. We want to represent the whole inner tile in this case.
+  ArrayRef<int64_t> readShape = srcShape.drop_front(destRank);
+  Type elemType = unpackOp.getSourceType().getElementType();
+  auto readType = RankedTensorType::get(readShape, elemType);
+  Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
+      loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides);
+
+  // 2. Transpose the tile to match the outer corresponding tile order.
+  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
+  SmallVector<int64_t> perm =
+      getPackUnpackNormalizedInnerPerm(srcRank, innerDimsPos);
+  SmallVector<int64_t> transpShape(readShape);
+  applyPermutationToVector<int64_t>(transpShape, perm);
+
+  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+  auto transposedOp =
+      rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
+
+  // 3. Handle in-complete tiles if needed. It truncates trailing data from the
+  // transposed tile.
+  int numLoops = transpShape.size();
+  SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
+  SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
+  SmallVector<OpFoldResult> tileSizes;
+  for (int dim : innerDimsPos)
+    tileSizes.push_back(getAsOpFoldResult(
+        rewriter.createOrFold<tensor::DimOp>(loc, unpackOp.getDest(), dim)));
+
+  applyPermutationToVector<OpFoldResult>(tileSizes, perm);
+  auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
+      loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);
+
+  // 4. Insert the result to the destination tensor.
+  SmallVector<OpFoldResult> writeSizes;
+  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
+  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
+  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+      unpackOp.getDimAndTileMapping();
+  for (int i = 0, idx = 0; i < destRank; ++i) {
+    if (dimAndTileMapping.count(i))
+      writeSizes.push_back(tileSizes[idx++]);
+    else
+      writeSizes.push_back(oneIdxAttr);
+  }
+  auto insert = rewriter.create<tensor::InsertSliceOp>(
+      loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
+      writeStrides);
+  rewriter.replaceOp(unpackOp, insert.getResult());
+
+  return success();
+}
+
 // The following are patterns for downscaling convolution ops with size-1
 // window dimensions.
 //
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
new file mode 100644 (file)
index 0000000..f8e223d
--- /dev/null
@@ -0,0 +1,168 @@
+// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-tensor-unpack"  %s | FileCheck %s
+
+func.func @simple_KCRSsr_to_KCRS(%arg0: tensor<1x1x1x1x8x32xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x1x1x8x32xf32> -> tensor<1x1x32x8xf32>
+  return %0 : tensor<1x1x32x8xf32>
+}
+// CHECK-LABEL: func.func @simple_KCRSsr_to_KCRS
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
+// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
+// CHECK-SAME:      ins(%[[TILE]] : tensor<8x32xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<32x8xf32>)
+// CHECK-SAME:      permutation = [1, 0]
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME:      [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
+// CHECK:         return %[[INSERT]]
+
+// -----
+
+func.func @simple_unpack_and_extract_slice(%input: tensor<1x1x8x2xf32>, %output: tensor<5x1xf32>) -> tensor<5x1xf32> {
+  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<1x1x8x2xf32> -> tensor<5x1xf32>
+  return %0 : tensor<5x1xf32>
+}
+// CHECK-LABEL: func.func @simple_unpack_and_extract_slice
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
+// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
+// CHECK-SAME:      ins(%[[TILE]] : tensor<8x2xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<8x2xf32>)
+// CHECK-SAME:      permutation = [0, 1]
+//                They have the same type, so the insert_slice op is folded
+//                away.
+// CHECK:         %[[SLICE:.+]] = tensor.extract_slice %[[TRANSP]][0, 0] [5, 1] [1, 1]
+// CHECK:         return %[[SLICE]]
+
+// -----
+
+func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32>) -> tensor<32x8xf32>{
+  %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<1x1x32x8xf32> -> tensor<32x8xf32>
+  return %0 : tensor<32x8xf32>
+}
+// CHECK-LABEL: func.func @simple_CNnc_to_NC
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
+// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
+// CHECK-SAME:      ins(%[[TILE]] : tensor<32x8xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<32x8xf32>)
+// CHECK-SAME:      permutation = [0, 1]
+//                They have the same type, so the insert_slice op is folded
+//                away.
+// CHECK:         return %[[TRANSP]]
+
+// -----
+
+// RUN: mlir-opt -split-input-file --test-transform-dialect-interpreter --canonicalize --test-linalg-transform-patterns="test-generalize-tensor-unpack"  %s | FileCheck %s --check-prefix=CHECK-TRANS
+
+func.func @KCRSsr_to_KCRS(%arg0: tensor<1x1x4x8x8x32xf32>, %arg1: tensor<1x1x128x64xf32>) -> tensor<1x1x128x64xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x4x8x8x32xf32> -> tensor<1x1x128x64xf32>
+  return %0 : tensor<1x1x128x64xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
+    %1, %loops:4 = transform.structured.tile_to_scf_for %0 [1, 1, 32, 8]
+}
+// CHECK-TRANS-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
+// CHECK-TRANS-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK-TRANS:       func.func @KCRSsr_to_KCRS
+// CHECK-TRANS-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-TRANS-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-TRANS:         %{{.+}} = scf.for %[[R:[a-zA-Z0-9]+]] =
+// CHECK-TRANS:           %{{.+}} = scf.for %[[S:[a-zA-Z0-9]+]] =
+// CHECK-TRANS:             %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]])
+// CHECK-TRANS:             %[[IN_S:.+]] = affine.apply #[[MAP1]](%[[S]])
+// CHECK-TRANS:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK-TRANS-SAME:          [0, 0, %[[IN_R]], %[[IN_S]], 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
+// CHECK-TRANS:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
+// CHECK-TRANS-SAME:          [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x8x32xf32> to tensor<8x32xf32>
+// CHECK-TRANS:             %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
+// CHECK-TRANS:             %[[TRANSP:.+]] =  linalg.transpose
+// CHECK-TRANS-SAME:          ins(%[[TILE]]
+// CHECK-TRANS-SAME:          outs(%[[EMPTY]]
+// CHECK-TRANS-SAME:          permutation = [1, 0]
+// CHECK-TRANS:             %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
+
+// -----
+
+func.func @unpack_and_extract_slice(%arg0: tensor<2x8x8x2xf32>, %arg1: tensor<13x15xf32>) -> tensor<13x15xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : tensor<2x8x8x2xf32> -> tensor<13x15xf32>
+  return %0 : tensor<13x15xf32>
+}
+// CHECK-TRANS-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (-d0 + 13, 8)>
+// CHECK-TRANS-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 15, 2)>
+// CHECK-TRANS-DAG:   #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK-TRANS-DAG:   #[[MAP3:.+]] = affine_map<(d0) -> (d0 floordiv 2)>
+// CHECK-TRANS:       func.func @unpack_and_extract_slice
+// CHECK-TRANS-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-TRANS-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-TRANS:         %{{.+}} = scf.for %[[I:[a-zA-Z0-9]+]] =
+// CHECK-TRANS:           %[[OUT_I_SZ:.+]] = affine.min #[[MAP0]](%[[I]])
+// CHECK-TRANS:           %{{.+}} = scf.for %[[J:[a-zA-Z0-9]+]] =
+// CHECK-TRANS:             %[[OUT_J_SZ:.+]] = affine.min #[[MAP1]](%[[J]])
+// CHECK-TRANS:             %[[IN_I:.+]] = affine.apply #[[MAP2]](%[[I]])
+// CHECK-TRANS:             %[[IN_J:.+]] = affine.apply #[[MAP3]](%[[J]])
+// CHECK-TRANS:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK-TRANS-SAME:          [%[[IN_I]], %[[IN_J]], 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
+// CHECK-TRANS:             %[[ITER_SLICE:.+]] = tensor.extract_slice %{{[a-zA-Z0-9]+}}
+// CHECK-TRANS-SAME:          [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]]
+// CHECK-TRANS:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
+// CHECK-TRANS-SAME:          [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] : tensor<1x1x8x2xf32> to tensor<8x2xf32>
+// CHECK-TRANS:             %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
+// CHECK-TRANS:             %[[TRANSP:.+]] =  linalg.transpose
+// CHECK-TRANS-SAME:          ins(%[[TILE]] : tensor<8x2xf32>)
+// CHECK-TRANS-SAME:          outs(%[[EMPTY]] : tensor<8x2xf32>)
+// CHECK-TRANS-SAME:          permutation = [0, 1]
+// CHECK-TRANS:             %[[UNPACK_TILE:.+]] = tensor.extract_slice %[[TRANSP]]
+// CHECK-TRANS-SAME:          [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
+// CHECK-TRANS:             %[[INSERT1:.+]] = tensor.insert_slice %[[UNPACK_TILE]] into %[[ITER_SLICE]]
+// CHECK-TRANS-SAME:          [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
+// CHECK-TRANS:             %[[INSERT2:.+]] = tensor.insert_slice %[[INSERT1]] into %{{[a-zA-Z0-9]+}}
+// CHECK-TRANS-SAME:          [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [8, 2]
+}
+
+// -----
+
+func.func @CKkc_to_KC(%arg0: tensor<32x4x32x8xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> {
+  %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<32x4x32x8xf32> -> tensor<128x256xf32>
+  return %0 : tensor<128x256xf32>
+}
+// CHECK-TRANS-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
+// CHECK-TRANS-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK-TRANS:       func.func @CKkc_to_KC
+// CHECK-TRANS-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-TRANS-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-TRANS:         %{{.+}} = scf.for %[[K:[a-zA-Z0-9]+]] =
+// CHECK-TRANS:           %{{.+}} = scf.for %[[C:[a-zA-Z0-9]+]] =
+// CHECK-TRANS:             %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]])
+// CHECK-TRANS:             %[[IN_C:.+]] = affine.apply #[[MAP1]](%[[C]])
+// CHECK-TRANS:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK-TRANS-SAME:          [%[[IN_C]], %[[IN_K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
+// CHECK-TRANS:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
+// CHECK-TRANS-SAME:          [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32>
+// CHECK-TRANS:             %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
+// CHECK-TRANS:             %[[TRANSP:.+]] =  linalg.transpose
+// CHECK-TRANS-SAME:          ins(%[[TILE]]
+// CHECK-TRANS-SAME:          outs(%[[EMPTY]]
+// CHECK-TRANS-SAME:          permutation = [0, 1]
+// CHECK-TRANS:             %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}}
+// CHECK-TRANS-SAME:          [%[[K]], %[[C]]] [32, 8] [1, 1]
+
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [32, 8]
+}
index 3375276..5ce43ff 100644 (file)
@@ -81,9 +81,15 @@ struct TestLinalgTransforms
       llvm::cl::init(false)};
   Option<bool> testGeneralizeTensorPackOp{
       *this, "test-generalize-tensor-pack",
-      llvm::cl::desc("Test transform that generalize pack ops into a sequence "
+      llvm::cl::desc("Test transform that generalizes pack ops into a sequence "
                      "of tensor and Linalg ops"),
       llvm::cl::init(false)};
+  Option<bool> testGeneralizeTensorUnPackOp{
+      *this, "test-generalize-tensor-unpack",
+      llvm::cl::desc(
+          "Test transform that generalizes unpack ops into a sequence "
+          "of tensor and Linalg ops"),
+      llvm::cl::init(false)};
   Option<bool> testSwapSubTensorPadTensor{
       *this, "test-swap-subtensor-padtensor",
       llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into "
@@ -176,6 +182,12 @@ static void applyGeneralizeTensorPackPatterns(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyGeneralizeTensorUnPackPatterns(func::FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  patterns.add<GeneralizeOuterUnitDimsUnPackOpPattern>(funcOp.getContext());
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
@@ -220,6 +232,8 @@ void TestLinalgTransforms::runOnOperation() {
     return applyGeneralizePadTensorPatterns(getOperation());
   if (testGeneralizeTensorPackOp)
     return applyGeneralizeTensorPackPatterns(getOperation());
+  if (testGeneralizeTensorUnPackOp)
+    return applyGeneralizeTensorUnPackPatterns(getOperation());
   if (testSwapSubTensorPadTensor)
     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
   if (testBubbleUpExtractSliceOpPattern)