[mlir][linalg] Add support for lowering pack with outer_dims_perm.
author     Hanhan Wang <hanchung@google.com>
           Fri, 21 Apr 2023 20:32:54 +0000 (13:32 -0700)
committer  Hanhan Wang <hanchung@google.com>
           Mon, 24 Apr 2023 17:39:37 +0000 (10:39 -0700)
Reviewed By: chelini, qcolombet

Differential Revision: https://reviews.llvm.org/D148845
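
With outer_dims_perm supported, lowerPack still produces the same pad + expand_shape +
transpose sequence; the outer permutation is folded into the permutation of the final
linalg.transpose. A rough sketch of the lowering, taken from the first new test case below
(value names are illustrative and the no-op pad body is elided):

  %0 = tensor.pack %src outer_dims_perm = [1, 2, 3, 0]
         inner_dims_pos = [3, 2] inner_tiles = [16, 32]
         into %dest : tensor<100x200x128x256xi32> -> tensor<200x4x16x100x16x32xi32>

becomes, roughly:

  %padded = tensor.pad %src low[0, 0, 0, 0] ...
      : tensor<100x200x128x256xi32> to tensor<100x200x128x256xi32>
  %expanded = tensor.expand_shape %padded [[0], [1], [2, 3], [4, 5]]
      : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32>
  %packed = linalg.transpose ins(%expanded : tensor<100x200x4x32x16x16xi32>)
                             outs(%dest : tensor<200x4x16x100x16x32xi32>)
                             permutation = [1, 2, 4, 0, 5, 3]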

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
mlir/test/Dialect/Linalg/transform-lower-pack.mlir

diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 3d1ae9c7121b6ecba1d011764bc3ebbcc87b7370..43a260427cf5e4cce5d7b0112ef1ac3cd6ad6b62 100644
@@ -522,6 +522,7 @@ getSimplifyCollapseShapeWithRankReducingSliceInfo(
 
 struct PackingMetadata {
   SmallVector<int64_t> insertPositions;
+  SmallVector<int64_t> outerPositions;
   SmallVector<ReassociationIndices> reassociations;
 };
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 4d5ef0edc9f8ac49f81d3e3fbc6f962daf09fbca..4f3f2dc0c734b76f1347e248b79cf5fae7d672c7 100644
@@ -480,9 +480,6 @@ private:
 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                              tensor::PackOp packOp) {
   // 1. Filter out NYI cases.
-  if (!packOp.getOuterDimsPerm().empty())
-    return rewriter.notifyMatchFailure(packOp, "outer dims perm NYI");
-
   auto packedTensorType =
       packOp->getResultTypes().front().cast<RankedTensorType>();
   if (!packedTensorType.hasStaticShape()) {
@@ -495,21 +492,37 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(packOp);
 
-  // 2. Compute the permutation vector to move the last `numPackedDims` into the
-  // `innerPosDims` of a shape of rank `packedRank`.
+  // 2. Compute the permutation vector that shuffles the packed shape into the
+  // shape it has before any outer or inner permutations are applied. This
+  // permutation is the composition of two permutations:
+  //   a) the permutation that moves the last `numPackedDims` into the
+  //      `innerPosDims` of a shape of rank `packedRank`;
+  //   b) the permutation that moves the outer dims if the pack op has
+  //      outer_dims_perm.
+  // Apply permutation (b) to permutation (a) to get the final permutation.
   int64_t numPackedDims = packOp.getInnerDimsPos().size();
   int64_t packedRank = packedTensorType.getRank();
   auto lastDims = llvm::to_vector(
       llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
   PackingMetadata packingMetadata = computePackingMetadata(
       packedTensorType.getRank(), packOp.getInnerDimsPos());
-  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
+  SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
       packedRank, lastDims, packingMetadata.insertPositions);
 
+  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
+  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
+  if (!outerPerm.empty())
+    applyPermutationToVector(outerPos, outerPerm);
+  SmallVector<int64_t> outerPositionPerm = computePermutationVector(
+      packedRank, packingMetadata.outerPositions, outerPos);
+
+  SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
+  applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
+
   // 3. Compute the stripMinedShape: this is the packed shape before any outer
   // or inner permutations have been applied.
   SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
-  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
+  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
 
   // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
   RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
@@ -527,11 +540,17 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   LLVM_DEBUG(
       DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
                                                 DBGS() << "insertPositions: ");
+      DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions,
+                                      DBGS() << "outerPositions: ");
       DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                       DBGS() << "packedShape: ");
       DBGSNL();
-      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
-                            DBGS() << "lastDimsToInsertPositionsPerm: ");
+      llvm::interleaveComma(outerPositionPerm, DBGS() << "outerPositionPerm: ");
+      DBGSNL(); llvm::interleaveComma(innerPositionsPerm,
+                                      DBGS() << "innerPositionsPerm: ");
+      DBGSNL();
+      llvm::interleaveComma(packedToStripMinedShapePerm,
+                            DBGS() << "packedToStripMinedShapePerm: ");
       DBGSNL(); llvm::interleaveComma(
           packingMetadata.reassociations, DBGS() << "reassociations: ",
           [&](ReassociationIndices ri) {
@@ -572,16 +591,14 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
       padOp.getResult(), packingMetadata.reassociations);
 
   // 6. Transpose stripMinedShape to packedShape.
-  SmallVector<int64_t> insertPositionsToLastDimsPerm = computePermutationVector(
-      packedRank, packingMetadata.insertPositions, lastDims);
+  SmallVector<int64_t> transpPerm =
+      invertPermutationVector(packedToStripMinedShapePerm);
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
-      loc, reshapeOp.getResult(), packOp.getDest(),
-      insertPositionsToLastDimsPerm);
+      loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
 
   LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
              DBGS() << "reshape op: " << reshapeOp; DBGSNL();
-             llvm::interleaveComma(insertPositionsToLastDimsPerm,
-                                   DBGS() << "insertPositionsToLastDimsPerm: ");
+             llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
              DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
 
   // 7. Replace packOp by transposeOp.
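
To make the permutation composition in step 2 above concrete, here is a small standalone C++
sketch that traces the first new test case. Note that makePerm and apply are simplified local
stand-ins for computePermutationVector and applyPermutationToVector, and the metadata values
are hand-derived, so this is an illustration rather than the actual MLIR helpers:

// Standalone trace of the packed-shape -> strip-mined-shape permutation for:
//   packed shape    = 200x4x16x100x16x32
//   inner_dims_pos  = [3, 2], inner_tiles = [16, 32]
//   outer_dims_perm = [1, 2, 3, 0]
#include <cassert>
#include <cstdint>
#include <vector>

// Simplified stand-in for computePermutationVector: build a permutation of
// size `permSize` with res[desired[i]] = positions[i]; the remaining slots get
// the unused values in increasing order.
static std::vector<int64_t> makePerm(int64_t permSize,
                                     const std::vector<int64_t> &positions,
                                     const std::vector<int64_t> &desired) {
  std::vector<int64_t> res(permSize, -1);
  std::vector<bool> used(permSize, false);
  for (size_t i = 0; i < positions.size(); ++i) {
    res[desired[i]] = positions[i];
    used[positions[i]] = true;
  }
  int64_t next = 0;
  for (int64_t &entry : res) {
    if (entry != -1)
      continue;
    while (used[next])
      ++next;
    entry = next++;
  }
  return res;
}

// Simplified stand-in for applyPermutationToVector: out[i] = in[perm[i]].
static std::vector<int64_t> apply(const std::vector<int64_t> &in,
                                  const std::vector<int64_t> &perm) {
  std::vector<int64_t> out;
  for (int64_t p : perm)
    out.push_back(in[p]);
  return out;
}

int main() {
  // Hand-derived computePackingMetadata(6, [3, 2]) results.
  std::vector<int64_t> insertPositions = {5, 3};
  std::vector<int64_t> outerPositions = {0, 1, 2, 4};

  // (a) Move the last numPackedDims dims into insertPositions.
  std::vector<int64_t> innerPositionsPerm = makePerm(6, {4, 5}, insertPositions);
  assert((innerPositionsPerm == std::vector<int64_t>{0, 1, 2, 5, 3, 4}));

  // (b) Account for outer_dims_perm = [1, 2, 3, 0].
  std::vector<int64_t> outerPos = apply(outerPositions, {1, 2, 3, 0});
  std::vector<int64_t> outerPositionPerm = makePerm(6, outerPositions, outerPos);
  assert((outerPositionPerm == std::vector<int64_t>{4, 0, 1, 3, 2, 5}));

  // Compose (b) over (a): this is packedToStripMinedShapePerm.
  std::vector<int64_t> packedToStripMined = apply(innerPositionsPerm, outerPositionPerm);
  assert((packedToStripMined == std::vector<int64_t>{3, 0, 1, 5, 2, 4}));

  // Strip-mined shape 100x200x4x32x16x16, as in the expand_shape CHECK below.
  std::vector<int64_t> stripMined = apply({200, 4, 16, 100, 16, 32}, packedToStripMined);
  assert((stripMined == std::vector<int64_t>{100, 200, 4, 32, 16, 16}));

  // Its inverse, [1, 2, 4, 0, 5, 3], is the linalg.transpose permutation.
  return 0;
}

Inverting [3, 0, 1, 5, 2, 4] gives [1, 2, 4, 0, 5, 3], which is exactly the permutation checked
on the linalg.transpose in the new tests below.
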
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 383c77f3b7340023afea729bb718ae3f49d85c42..18646f598bceeb3e7f3f3f2bf3d65f9e026ec531 100644
@@ -480,6 +480,7 @@ PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
                            res.insertPositions.end());
   res.reassociations.reserve(packedRank);
   for (int64_t i = 1; i <= packedRank; ++i) {
+    res.outerPositions.push_back(i - 1);
     if (!posSet.contains(i)) {
       res.reassociations.push_back(ReassociationIndices{i - 1});
       continue;
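
For a concrete picture of the new field, a hand-derived trace (not tool output) for
packedRank = 6 and inner_dims_pos = [3, 2], as in the first new test case:

  insertPositions = [5, 3]
  outerPositions  = [0, 1, 2, 4]   // positions of the outer dims in the strip-mined shape
  reassociations  = [[0], [1], [2, 3], [4, 5]]

The reassociations match the tensor.expand_shape reassociation checked in the test below, and
outerPositions is the vector that lowerPack permutes by outer_dims_perm in its step 2.
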
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 83141ec75aba1cf290f23598427c374ac2abddbc..40f9f3e0761e75593b6751eb639186db87be9466 100644
@@ -183,3 +183,104 @@ transform.sequence failures(propagate) {
         !transform.op<"tensor.collapse_shape">,
         !transform.op<"tensor.extract_slice">)
 }
+
+// -----
+
+// CHECK-LABEL: func.func @pack_with_outer_dims_perm(
+func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
+                                     %dest: tensor<200x4x16x100x16x32xi32>)
+    -> tensor<200x4x16x100x16x32xi32> {
+  //      CHECK: %[[C0:.*]] = arith.constant 0 : index
+  //      CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  //      CHECK:   : tensor<100x200x128x256xi32> to tensor<100x200x128x256xi32>
+  //      CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]]
+  // CHECK-SAME:   : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32>
+  //      CHECK: linalg.transpose
+  // CHECK-SAME:   ins(%{{.*}} : tensor<100x200x4x32x16x16xi32>)
+  // CHECK-SAME:   outs(%{{.*}} : tensor<200x4x16x100x16x32xi32>)
+  // CHECK-SAME:   permutation = [1, 2, 4, 0, 5, 3]
+  %0 = tensor.pack %src
+    outer_dims_perm = [1, 2, 3, 0]
+    inner_dims_pos = [3, 2]
+    inner_tiles = [16, 32]
+    into %dest : tensor<100x200x128x256xi32> -> tensor<200x4x16x100x16x32xi32>
+  return %0 : tensor<200x4x16x100x16x32xi32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+    : (!pdl.operation) -> !transform.op<"tensor.pack">
+  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+    -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_with_pad_and_outer_dims_perm(
+func.func @pack_with_pad_and_outer_dims_perm(%src: tensor<100x200x127x255xi32>,
+                                             %dest: tensor<200x4x16x100x16x32xi32>)
+    -> tensor<200x4x16x100x16x32xi32> {
+  //      CHECK: %[[C0:.*]] = arith.constant 0 : index
+  //      CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  //      CHECK:   : tensor<100x200x127x255xi32> to tensor<100x200x128x256xi32>
+  //      CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]]
+  // CHECK-SAME:   : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32>
+  //      CHECK: linalg.transpose
+  // CHECK-SAME:   ins(%{{.*}} : tensor<100x200x4x32x16x16xi32>)
+  // CHECK-SAME:   outs(%{{.*}} : tensor<200x4x16x100x16x32xi32>)
+  // CHECK-SAME:   permutation = [1, 2, 4, 0, 5, 3]
+  %cst_0 = arith.constant 0 : i32
+  %0 = tensor.pack %src
+    padding_value(%cst_0 : i32)
+    outer_dims_perm = [1, 2, 3, 0]
+    inner_dims_pos = [3, 2]
+    inner_tiles = [16, 32]
+    into %dest : tensor<100x200x127x255xi32> -> tensor<200x4x16x100x16x32xi32>
+  return %0 : tensor<200x4x16x100x16x32xi32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+    : (!pdl.operation) -> !transform.op<"tensor.pack">
+  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+    -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_as_pad_with_outer_dims_perm(
+func.func @pack_as_pad_with_outer_dims_perm(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+
+  // tensor.pack is lowered to tensor.pad + tensor.insert_slice
+  //      CHECK: %[[C0:.*]] = arith.constant 0 : index
+  //      CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  //      CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
+  //      CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
+  //      CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
+  // offsets.
+  // CHECK-SAME:   [0, 0, 0, 0, 0, 0, 0, 0]
+  // sizes.
+  // CHECK-SAME:   [1, 1, 1, 1, 136, 64, 16, 16]
+  // strides multipliers.
+  // CHECK-SAME:   [1, 1, 1, 1, 1, 1, 1, 1]
+  // CHECK-SAME:   : tensor<136x64x16x16xf32> into tensor<1x1x1x1x136x64x16x16xf32>
+  //      CHECK: return %[[RES]]
+  %pack = tensor.pack %arg0
+    padding_value(%cst_0 : f32)
+    outer_dims_perm = [1, 2, 3, 0]
+    inner_dims_pos = [0, 1, 2, 3]
+    inner_tiles = [136, 64, 16, 16]
+    into %arg1 : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
+  return %pack :  tensor<1x1x1x1x136x64x16x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+    : (!pdl.operation) -> !transform.op<"tensor.pack">
+  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+    -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}