}
};
+/// Propagate a tensor.unpack operation through a tensor.pad, i.e. rewrite
+/// `pad(unpack(x))` into `unpack(pad(x))`. The new pad works on the packed
+/// source, so its `low` and `high` lists are extended with as many zero
+/// entries as there are point loops (inner tiles).
+///
+/// The pattern does not apply when:
+///   - the pad source is not produced by a tensor.unpack;
+///   - the pad result type is not statically shaped;
+///   - any padded dimension is also a tiled dimension of the unpack;
+///   - the padding value is not a constant.
+struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::UnPackOp unpackOp =
+        padOp.getSource().getDefiningOp<tensor::UnPackOp>();
+    if (!unpackOp)
+      return failure();
+
+    // The destination of the new unpack is created below from the result
+    // type's shape alone, with no dynamic size operands. A dynamic dimension
+    // could therefore not be materialized; bail out before creating any op.
+    RankedTensorType resultType = padOp.getResultType();
+    if (!resultType.hasStaticShape())
+      return failure();
+
+    Location loc = padOp.getLoc();
+    // Bail out if one of the padded dimensions is a tiled one.
+    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
+    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
+    llvm::SmallBitVector innerDims(paddedDims.size());
+    for (int64_t dim : innerDimsPos)
+      innerDims.set(dim);
+    if (paddedDims.anyCommon(innerDims))
+      return failure();
+
+    Value paddingVal = padOp.getConstantPaddingValue();
+    if (!paddingVal)
+      return failure();
+
+    // If we have `outer_dims_perm` we need to adjust the padded dimensions so
+    // that they line up with the outer dimensions of the packed source.
+    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
+    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
+    if (!outerDimsPerm.empty()) {
+      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
+      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
+    }
+    // Add zero padding for the point loops: tiled dimensions are never padded
+    // (checked above), so the trailing inner-tile dimensions stay untouched.
+    size_t pointLoopsSize = innerDimsPos.size();
+    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
+    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
+
+    // Pad the packed source directly; the result type is inferred.
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
+        paddingVal, padOp.getNofold());
+
+    // Inject the tensor.unpack right after the packed padOp. Its destination
+    // has the shape and element type of the original pad result.
+    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
+        loc, resultType.getShape(), resultType.getElementType());
+
+    Value replacement = rewriter.create<tensor::UnPackOp>(
+        loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
+        unpackOp.getMixedTiles(), outerDimsPerm);
+    rewriter.replaceOp(padOp, replacement);
+    return success();
+  }
+};
+
} // namespace
void mlir::linalg::populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns) {
- patterns.insert<BubbleUpPackOpThroughElemGenericOpPattern,
- PushDownUnPackOpThroughElemGenericOp>(patterns.getContext());
+  // Register the pack/unpack propagation patterns, including the one that
+  // pushes a tensor.unpack down through a tensor.pad.
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<BubbleUpPackOpThroughElemGenericOpPattern,
+               PushDownUnPackOpThroughElemGenericOp,
+               PushDownUnPackThroughPadOp>(ctx);
}
// CHECK-SAME: outs(%[[DEST]]
// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+
+// -----
+
+// Pads only the two untiled spatial dimensions (1 and 2) of the unpacked
+// tensor; the tiled dimension 3 is left alone, so the pad is expected to be
+// moved above the unpack.
+func.func @pad_valid_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58x58x64xf32> {
+  %zero = arith.constant 0.000000e+00 : f32
+  %dest = tensor.empty() : tensor<1x56x56x64xf32>
+  %unpacked = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %dest : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+  %result = tensor.pad %unpacked low[0, 1, 1, 0] high[0, 1, 1, 0] {
+  ^bb0(%i: index, %j: index, %k: index, %l: index):
+    tensor.yield %zero : f32
+  } : tensor<1x56x56x64xf32> to tensor<1x58x58x64xf32>
+  return %result : tensor<1x58x58x64xf32>
+}
+
+// CHECK: func.func @pad_valid_propagation(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>
+
+// -----
+
+// Same as the previous case, but dimension 0 additionally receives a low
+// padding of 1. Dimension 0 is not tiled either, so the pad is still expected
+// to be moved above the unpack.
+func.func @pad_valid_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<2x58x58x64xf32> {
+  %zero = arith.constant 0.000000e+00 : f32
+  %dest = tensor.empty() : tensor<1x56x56x64xf32>
+  %unpacked = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %dest : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+  %result = tensor.pad %unpacked low[1, 1, 1, 0] high[0, 1, 1, 0] {
+  ^bb0(%i: index, %j: index, %k: index, %l: index):
+    tensor.yield %zero : f32
+  } : tensor<1x56x56x64xf32> to tensor<2x58x58x64xf32>
+  return %result : tensor<2x58x58x64xf32>
+}
+
+// CHECK: func.func @pad_valid_propagation(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[1, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x58x58x64xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[EMPTY]] : tensor<2x2x58x58x32xf32> -> tensor<2x58x58x64xf32>
+
+// -----
+
+// Negative case: dimension 3 is both padded and tiled by the unpack, so the
+// ops are expected to stay in their original order.
+func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58x58x66xf32> {
+  %zero = arith.constant 0.000000e+00 : f32
+  %dest = tensor.empty() : tensor<1x56x56x64xf32>
+  %unpacked = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %dest : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+  %result = tensor.pad %unpacked low[0, 1, 1, 1] high[0, 1, 1, 1] {
+  ^bb0(%i: index, %j: index, %k: index, %l: index):
+    tensor.yield %zero : f32
+  } : tensor<1x56x56x64xf32> to tensor<1x58x58x66xf32>
+  return %result : tensor<1x58x58x66xf32>
+}
+
+// CHECK: func.func @pad_along_unpacked_dim(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]