[mlir][tensor] Fix transpose permutation in tensor.pack generalization pattern
author: Quinn Dawkins <quinn@nod-labs.com>
Mon, 20 Feb 2023 19:49:38 +0000 (14:49 -0500)
committer: Quinn Dawkins <quinn@nod-labs.com>
Wed, 22 Feb 2023 19:49:49 +0000 (14:49 -0500)
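The generalization pattern for tensor.pack was inverting the
innerDimsPos permutation when normalizing it. As a result, the
transpose op produced by the generalization used the wrong
permutation. Invert the normalized permutation back so that it matches
the ordering in the pack op, and add a test with a three-dimensional
inner_dims_pos to cover this case.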

Differential Revision: https://reviews.llvm.org/D144425

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir

index 3b2cd0d..96c29e4 100644 (file)
@@ -578,6 +578,13 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   // 2. Transpose the tile to match the inner tile order.
   SmallVector<int64_t> perm =
       getPackUnpackNormalizedInnerPerm(srcRank, packOp.getInnerDimsPos());
+  // The permutation is inverted when normalizing so invert back to match the
+  // ordering in the pack op.
+  perm = invertPermutationVector(perm);
+
+  LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
+             llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
+
   SmallVector<int64_t> transpShape = readShape;
   applyPermutationToVector<int64_t>(transpShape, perm);
 
index 3feb165..8e9b77e 100644 (file)
@@ -58,3 +58,21 @@ func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32
 // CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
+
+// -----
+
+func.func @simple_CHW_to_CHWhwc(%arg0: tensor<3x5x7xf32>, %arg1: tensor<1x1x1x5x7x3xf32>) -> tensor<1x1x1x5x7x3xf32> {
+  %0 = tensor.pack %arg0 inner_dims_pos = [1, 2, 0] inner_tiles = [5, 7, 3] into %arg1 : tensor<3x5x7xf32> -> tensor<1x1x1x5x7x3xf32>
+  return %0 : tensor<1x1x1x5x7x3xf32>
+}
+// CHECK-LABEL: func.func @simple_CHW_to_CHWhwc
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<5x7x3xf32>
+// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
+// CHECK-SAME:      ins(%[[SRC]] : tensor<3x5x7xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<5x7x3xf32>)
+// CHECK-SAME:      permutation = [1, 2, 0]
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 5, 7, 3] [1, 1, 1, 1, 1, 1]
+// CHECK:         return %[[INSERT]]