From: Quinn Dawkins Date: Mon, 20 Feb 2023 19:49:38 +0000 (-0500) Subject: [mlir][tensor] Fix transpose permutation in tensor.pack generalization pattern X-Git-Tag: upstream/17.0.6~16748 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=bbf1d80d67db5076e4cd02caa754c0688239fe76;p=platform%2Fupstream%2Fllvm.git [mlir][tensor] Fix transpose permutation in tensor.pack generalization pattern The generalization pattern for tensor.pack was inverting the innerDimsPos permutation when normalizing. Thus, the transpose op produced by the generalization would be incorrect. Differential Revision: https://reviews.llvm.org/D144425 --- diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 3b2cd0d..96c29e4 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -578,6 +578,13 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite( // 2. Transpose the tile to match the inner tile order. SmallVector perm = getPackUnpackNormalizedInnerPerm(srcRank, packOp.getInnerDimsPos()); + // The permutation is inverted when normalizing so invert back to match the + // ordering in the pack op. + perm = invertPermutationVector(perm); + + LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"; + llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL();); + SmallVector transpShape = readShape; applyPermutationToVector(transpShape, perm); diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir index 3feb165..8e9b77e 100644 --- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir +++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir @@ -58,3 +58,21 @@ func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32 // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] // CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] // CHECK: return %[[INSERT]] + +// ----- + +func.func @simple_CHW_to_CHWhwc(%arg0: tensor<3x5x7xf32>, %arg1: tensor<1x1x1x5x7x3xf32>) -> tensor<1x1x1x5x7x3xf32> { + %0 = tensor.pack %arg0 inner_dims_pos = [1, 2, 0] inner_tiles = [5, 7, 3] into %arg1 : tensor<3x5x7xf32> -> tensor<1x1x1x5x7x3xf32> + return %0 : tensor<1x1x1x5x7x3xf32> +} +// CHECK-LABEL: func.func @simple_CHW_to_CHWhwc +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<5x7x3xf32> +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[SRC]] : tensor<3x5x7xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<5x7x3xf32>) +// CHECK-SAME: permutation = [1, 2, 0] +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] +// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 5, 7, 3] [1, 1, 1, 1, 1, 1] +// CHECK: return %[[INSERT]]