[MLIR][Tensor] Introduce a pattern to propagate `tensor.unpack` through `tensor.pad`
author Lorenzo Chelini <l.chelini@icloud.com>
Mon, 13 Feb 2023 13:44:03 +0000 (14:44 +0100)
committer Lorenzo Chelini <l.chelini@icloud.com>
Wed, 15 Feb 2023 07:48:55 +0000 (08:48 +0100)
Introduce a pattern to 'push down' a `tensor.unpack` through a
`tensor.pad`. The propagation happens only if the padding does not touch
any of the tiled (inner) dimensions of the unpack.
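
For example (shapes taken from the first test below), the rewrite turns

  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
      inner_dims_pos = [3] inner_tiles = [32] into %0
      : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
  %padded = tensor.pad %1 low[0, 1, 1, 0] high[0, 1, 1, 0] {...}
      : tensor<1x56x56x64xf32> to tensor<1x58x58x64xf32>

into

  %padded = tensor.pad %arg0 low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0] {...}
      : tensor<1x2x56x56x32xf32> to tensor<1x2x58x58x32xf32>
  %2 = tensor.empty() : tensor<1x58x58x64xf32>
  %unpacked = tensor.unpack %padded outer_dims_perm = [0, 3, 1, 2]
      inner_dims_pos = [3] inner_tiles = [32] into %2
      : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>

where the pad amounts are permuted by `outer_dims_perm` and extended with
zeros for the inner tile dimension.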

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D143907

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
mlir/test/Dialect/Linalg/data-layout-propagation.mlir

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 1b6d1d2..bf5e64b 100644
@@ -465,10 +465,69 @@ struct PushDownUnPackOpThroughElemGenericOp
   }
 };
 
+/// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
+/// append one zero-padding entry to `low` and `high` for each point loop
+/// (i.e., for each inner tile of the unpack).
+struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::UnPackOp unpackOp =
+        padOp.getSource().getDefiningOp<tensor::UnPackOp>();
+    if (!unpackOp)
+      return failure();
+
+    Location loc = padOp.getLoc();
+    // Bail out if one of the padded dimensions is a tiled one.
+    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
+    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
+    llvm::SmallBitVector innerDims(paddedDims.size());
+    for (int64_t dim : innerDimsPos)
+      innerDims.flip(dim);
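+    // E.g., with `inner_dims_pos` = [3], a pad along dimension 3 would have to
+    // pad inside each tile, which this rewrite cannot express.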
+    if (paddedDims.anyCommon(innerDims))
+      return failure();
+
+    Value paddingVal = padOp.getConstantPaddingValue();
+    if (!paddingVal)
+      return failure();
+
+    // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
+    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
+    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
+    if (!outerDimsPerm.empty()) {
+      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
+      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
+    }
+    // Add zero padding for the point loops.
+    size_t pointLoopsSize = innerDimsPos.size();
+    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
+    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
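+    // E.g., `low` = [0, 1, 1, 0] with `outer_dims_perm` = [0, 3, 1, 2] becomes
+    // [0, 0, 1, 1]; appending one zero for the point loop gives
+    // [0, 0, 1, 1, 0].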
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
+        paddingVal, padOp.getNofold());
+
+    // Inject the tensor.unpack right after the new pad on the packed layout,
+    // unpacking into a freshly created destination tensor.
+    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
+        loc, padOp.getResultType().getShape(),
+        padOp.getResultType().getElementType());
+
+    Value replacement = rewriter.create<tensor::UnPackOp>(
+        loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
+        unpackOp.getMixedTiles(), outerDimsPerm);
+    rewriter.replaceOp(padOp, replacement);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::linalg::populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns) {
-  patterns.insert<BubbleUpPackOpThroughElemGenericOpPattern,
-                  PushDownUnPackOpThroughElemGenericOp>(patterns.getContext());
+  patterns
+      .insert<BubbleUpPackOpThroughElemGenericOpPattern,
+              PushDownUnPackOpThroughElemGenericOp, PushDownUnPackThroughPadOp>(
+          patterns.getContext());
 }
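
For reference, a minimal sketch of how this pattern set is typically driven
from a pass; the pass class `MyPropagationPass` is hypothetical, while
`applyPatternsAndFoldGreedily` is the standard MLIR greedy rewrite driver,
not part of this change:

  // #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
  void MyPropagationPass::runOnOperation() { // hypothetical pass
    RewritePatternSet patterns(&getContext());
    linalg::populateDataLayoutPropagationPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns))))
      signalPassFailure();
  }
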
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index b699b3d..32190ca 100644
@@ -471,4 +471,70 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK-SAME:  outs(%[[DEST]]
 // CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
 // CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
-// CHECK-SAME:  into %[[ARG0_UNPACK_EMPTY]] 
+// CHECK-SAME:  into %[[ARG0_UNPACK_EMPTY]]
+
+// -----
+
+func.func @pad_valid_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58x58x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+  %padded = tensor.pad %1 low[0, 1, 1, 0] high[0, 1, 1, 0] {
+    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<1x56x56x64xf32> to tensor<1x58x58x64xf32>
+  return %padded : tensor<1x58x58x64xf32>
+}
+
+// CHECK: func.func @pad_valid_propagation(
+// CHECK-SAME:  %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]] 
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:  into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>
+
+// -----
+
+func.func @pad_valid_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<2x58x58x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+  %padded = tensor.pad %1 low[1, 1, 1, 0] high[0, 1, 1, 0] {
+    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<1x56x56x64xf32> to tensor<2x58x58x64xf32>
+  return %padded : tensor<2x58x58x64xf32>
+}
+
+// CHECK: func.func @pad_valid_propagation(
+// CHECK-SAME:  %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[1, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x58x58x64xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:  into %[[EMPTY]] : tensor<2x2x58x58x32xf32> -> tensor<2x58x58x64xf32>
+
+// -----
+
+func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58x58x66xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+  %padded = tensor.pad %1 low[0, 1, 1, 1] high[0, 1, 1, 1] {
+    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<1x56x56x64xf32> to tensor<1x58x58x66xf32>
+  return %padded : tensor<1x58x58x66xf32>
+}
+
+// CHECK: func.func @pad_along_unpacked_dim(
+// CHECK-SAME:  %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]] 
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:  into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]