FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
tensor::PackOp packOp) {
// 1. Filter out NYI cases.
- if (!packOp.getOuterDimsPerm().empty())
- return rewriter.notifyMatchFailure(packOp, "outer dims perm NYI");
-
auto packedTensorType =
packOp->getResultTypes().front().cast<RankedTensorType>();
if (!packedTensorType.hasStaticShape()) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(packOp);
- // 2. Compute the permutation vector to move the last `numPackedDims` into the
- // `innerPosDims` of a shape of rank `packedRank`.
+ // 2. Compute the permutation vector to shuffle packed shape into the shape
+ // before any outer or inner permutations have been applied. The permutation
+ // can be obtained from two permutations:
+ // a) Compute the permutation vector to move the last `numPackedDims` into
+ // the `innerPosDims` of a shape of rank `packedRank`.
+ // b) Compute the permutation vector to move outer dims if the pack op
+ // has outer_dims_perm.
+ // Apply (b) permutation on (a) permutation to get the final permutation.
int64_t numPackedDims = packOp.getInnerDimsPos().size();
int64_t packedRank = packedTensorType.getRank();
auto lastDims = llvm::to_vector(
llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
PackingMetadata packingMetadata = computePackingMetadata(
packedTensorType.getRank(), packOp.getInnerDimsPos());
- SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
+ SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
packedRank, lastDims, packingMetadata.insertPositions);
+ // NOTE(review): start from the metadata's outer positions and, when the op
+ // carries an outer_dims_perm, permute them so that the single transpose
+ // created below realizes both the inner-tile and outer-dim reorderings.
+ SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
+ ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
+ if (!outerPerm.empty())
+ applyPermutationToVector(outerPos, outerPerm);
+ SmallVector<int64_t> outerPositionPerm = computePermutationVector(
+ packedRank, packingMetadata.outerPositions, outerPos);
+
+ // Compose (a) and (b): first move the inner tiles into place, then apply
+ // the outer-dim permutation on top of that.
+ SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
+ applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
+
// 3. Compute the stripMinedShape: this is the packed shape before any outer
// or inner permutations have been applied.
SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
- applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
+ applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
// 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
LLVM_DEBUG(
DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
DBGS() << "insertPositions: ");
+ DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions,
+ DBGS() << "outerPositions: ");
DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
DBGS() << "packedShape: ");
DBGSNL();
- llvm::interleaveComma(lastDimsToInsertPositionsPerm,
- DBGS() << "lastDimsToInsertPositionsPerm: ");
+ llvm::interleaveComma(outerPositionPerm, DBGS() << "outerPositionPerm: ");
+ DBGSNL(); llvm::interleaveComma(innerPositionsPerm,
+ DBGS() << "innerPositionsPerm: ");
+ DBGSNL();
+ llvm::interleaveComma(packedToStripMinedShapePerm,
+ DBGS() << "packedToStripMinedShapePerm: ");
DBGSNL(); llvm::interleaveComma(
packingMetadata.reassociations, DBGS() << "reassociations: ",
[&](ReassociationIndices ri) {
padOp.getResult(), packingMetadata.reassociations);
// 6. Transpose stripMinedShape to packedShape.
- SmallVector<int64_t> insertPositionsToLastDimsPerm = computePermutationVector(
- packedRank, packingMetadata.insertPositions, lastDims);
+ // NOTE(review): the strip-mined-to-packed transpose is now simply the
+ // inverse of the packed-to-strip-mined permutation computed in step 2.
+ SmallVector<int64_t> transpPerm =
+ invertPermutationVector(packedToStripMinedShapePerm);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
- loc, reshapeOp.getResult(), packOp.getDest(),
- insertPositionsToLastDimsPerm);
+ loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
DBGS() << "reshape op: " << reshapeOp; DBGSNL();
- llvm::interleaveComma(insertPositionsToLastDimsPerm,
- DBGS() << "insertPositionsToLastDimsPerm: ");
+ llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
// 7. Replace packOp by transposeOp.
!transform.op<"tensor.collapse_shape">,
!transform.op<"tensor.extract_slice">)
}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_with_outer_dims_perm(
+func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
+ %dest: tensor<200x4x16x100x16x32xi32>)
+ -> tensor<200x4x16x100x16x32xi32> {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+ // CHECK: : tensor<100x200x128x256xi32> to tensor<100x200x128x256xi32>
+ // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]]
+ // CHECK-SAME: : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32>
+ // CHECK: linalg.transpose
+ // CHECK-SAME: ins(%{{.*}} : tensor<100x200x4x32x16x16xi32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<200x4x16x100x16x32xi32>)
+ // CHECK-SAME: permutation = [1, 2, 4, 0, 5, 3]
+ // Tiling dim 3 by 16 and dim 2 by 32 yields outer shape 100x200x4x16; the
+ // outer_dims_perm [1, 2, 3, 0] reorders it to 200x4x16x100, giving the
+ // 200x4x16x100x16x32 destination. 128 % 32 == 0 and 256 % 16 == 0, so the
+ // tensor.pad above is a no-op (source and padded types match).
+ %0 = tensor.pack %src
+ outer_dims_perm = [1, 2, 3, 0]
+ inner_dims_pos = [3, 2]
+ inner_tiles = [16, 32]
+ into %dest : tensor<100x200x128x256xi32> -> tensor<200x4x16x100x16x32xi32>
+ return %0 : tensor<200x4x16x100x16x32xi32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+ : (!pdl.operation) -> !transform.op<"tensor.pack">
+ // Lower the matched pack op; the three results are handles to the
+ // generated tensor.pad, tensor.expand_shape, and linalg.transpose ops.
+ transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_with_pad_and_outer_dims_perm(
+func.func @pack_with_pad_and_outer_dims_perm(%src: tensor<100x200x127x255xi32>,
+ %dest: tensor<200x4x16x100x16x32xi32>)
+ -> tensor<200x4x16x100x16x32xi32> {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+ // CHECK: : tensor<100x200x127x255xi32> to tensor<100x200x128x256xi32>
+ // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]]
+ // CHECK-SAME: : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32>
+ // CHECK: linalg.transpose
+ // CHECK-SAME: ins(%{{.*}} : tensor<100x200x4x32x16x16xi32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<200x4x16x100x16x32xi32>)
+ // CHECK-SAME: permutation = [1, 2, 4, 0, 5, 3]
+ // Same shapes as @pack_with_outer_dims_perm, but dims 2 and 3 (127, 255)
+ // are not multiples of their tiles (32, 16), so the lowering must pad them
+ // up to 128 and 256 with the supplied padding_value before expanding.
+ %cst_0 = arith.constant 0 : i32
+ %0 = tensor.pack %src
+ padding_value(%cst_0 : i32)
+ outer_dims_perm = [1, 2, 3, 0]
+ inner_dims_pos = [3, 2]
+ inner_tiles = [16, 32]
+ into %dest : tensor<100x200x127x255xi32> -> tensor<200x4x16x100x16x32xi32>
+ return %0 : tensor<200x4x16x100x16x32xi32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+ : (!pdl.operation) -> !transform.op<"tensor.pack">
+ // Lower the matched pack op; the three results are handles to the
+ // generated tensor.pad, tensor.expand_shape, and linalg.transpose ops.
+ transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_as_pad_with_outer_dims_perm(
+func.func @pack_as_pad_with_outer_dims_perm(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+ %cst_0 = arith.constant 0.0 : f32
+
+ // Each inner tile spans its entire (padded) dimension (129 -> 136,
+ // 47 -> 64, 16, 16), so every outer dim is 1 and outer_dims_perm permutes
+ // only unit dims. No expand_shape/transpose is needed in this case:
+ // tensor.pack is lowered to tensor.pad + tensor.insert_slice
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+ // CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
+ // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
+ // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
+ // offsets.
+ // CHECK-SAME: [0, 0, 0, 0, 0, 0, 0, 0]
+ // sizes.
+ // CHECK-SAME: [1, 1, 1, 1, 136, 64, 16, 16]
+ // strides multipliers.
+ // CHECK-SAME: [1, 1, 1, 1, 1, 1, 1, 1]
+ // CHECK-SAME: : tensor<136x64x16x16xf32> into tensor<1x1x1x1x136x64x16x16xf32>
+ // CHECK: return %[[RES]]
+ %pack = tensor.pack %arg0
+ padding_value(%cst_0 : f32)
+ outer_dims_perm = [1, 2, 3, 0]
+ inner_dims_pos = [0, 1, 2, 3]
+ inner_tiles = [136, 64, 16, 16]
+ into %arg1 : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
+ return %pack : tensor<1x1x1x1x136x64x16x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+ : (!pdl.operation) -> !transform.op<"tensor.pack">
+ // NOTE(review): the declared result types still name expand_shape and
+ // transpose even though this pack-as-pad case lowers to pad + insert_slice;
+ // confirm lower_pack tolerates this in the specialized lowering path.
+ transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}