[mlir] Vectorize linalg.pad_tensor consumed by transfer_write
authorMatthias Springer <springerm@google.com>
Mon, 14 Jun 2021 01:16:22 +0000 (10:16 +0900)
committerMatthias Springer <springerm@google.com>
Mon, 14 Jun 2021 01:17:23 +0000 (10:17 +0900)
Vectorize linalg.pad_tensor without generating a linalg.init_tensor when consumed by a transfer_write.

Differential Revision: https://reviews.llvm.org/D103137

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir

index 955ac11..bf48f0e 100644 (file)
@@ -784,6 +784,141 @@ struct PadTensorOpVectorizationWithTransferReadPattern
   }
 };
 
+/// Rewrite use of PadTensorOp result in TransferWriteOp.
+/// This pattern rewrites TransferWriteOps that write to a padded tensor value,
+/// where the same amount of padding is immediately removed again after the
+/// write. In such cases, the TransferWriteOp can write to the non-padded tensor
+/// value and apply out-of-bounds masking. E.g.:
+/// ```
+/// %0 = subtensor ...[...] [%s0, %s1] [1, 1] : tensor<...> to tensor<?x?xf32>
+/// %1 = linalg.pad_tensor %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
+/// %2 = vector.transfer_write %vec, %1[...]
+///     : vector<17x5xf32>, tensor<17x5xf32>
+/// %r = subtensor %2[0, 0] [%s0, %s1] [1, 1]
+///     : tensor<17x5xf32> to tensor<?x?xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %0 = subtensor ...[...] [%s0, %s1] [1, 1] : tensor<...> to tensor<?x?xf32>
+/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, tensor<?x?xf32>
+/// ```
+/// Note: It is important that the SubTensorOp %r resizes the result of the
+/// TransferWriteOp to the same size as the input of the TensorPadOp (or an even
+/// smaller size). Otherwise, %r's new (dynamic) dimensions would differ from
+/// %r's old dimensions.
+///
+/// This rewrite is possible if:
+/// - Low padding is static 0.
+/// - `xferOp` has exactly one use, which is a SubTensorOp. This SubTensorOp
+///   trims the same amount of padding that was added beforehand.
+/// - Single, scalar padding value.
+struct PadTensorOpVectorizationWithTransferWritePattern
+    : public VectorizePadTensorOpUserPattern<vector::TransferWriteOp> {
+  using VectorizePadTensorOpUserPattern<vector::TransferWriteOp>
+      ::VectorizePadTensorOpUserPattern;
+
+  LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
+                            vector::TransferWriteOp xferOp) const override {
+    // Low padding must be static 0.
+    if (!padOp.hasZeroLowPad()) return failure();
+    // Pad value must be a constant.
+    auto padValue = padOp.getConstantPaddingValue();
+    if (!padValue) return failure();
+    // TransferWriteOp result must be directly consumed by a SubTensorOp.
+    if (!xferOp->hasOneUse()) return failure();
+    auto trimPadding = dyn_cast<SubTensorOp>(*xferOp->user_begin());
+    if (!trimPadding) return failure();
+    // Only static zero offsets supported when trimming padding.
+    if (!trimPadding.hasZeroOffset()) return failure();
+    // trimPadding must remove the amount of padding that was added earlier.
+    if (!hasSameTensorSize(padOp.source(), trimPadding)) return failure();
+
+    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
+    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        xferOp, padOp.source().getType(), xferOp.vector(), padOp.source(),
+        xferOp.indices(), xferOp.permutation_mapAttr(), xferOp.mask(),
+        rewriter.getBoolArrayAttr(inBounds));
+    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
+
+    return success();
+  }
+
+  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
+  /// i.e., same dimensions.
+  ///
+  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
+  /// dimensions, this function tries to infer the (static) tensor size by
+  /// looking at the defining op and utilizing op-specific knowledge.
+  ///
+  /// This is a conservative analysis. In case equal tensor sizes cannot be
+  /// proven statically, this analysis returns `false` even though the tensor
+  /// sizes may turn out to be equal at runtime.
+  bool hasSameTensorSize(Value beforePadding, SubTensorOp afterTrimming) const {
+    // If the input to PadTensorOp is a CastOp, try with with both CastOp result
+    // and CastOp operand.
+    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
+      if (hasSameTensorSize(castOp.source(), afterTrimming)) return true;
+
+    auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
+    auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
+    // Only RankedTensorType supported.
+    if (!t1 || !t2) return false;
+    // Rank of both values must be the same.
+    if (t1.getRank() != t2.getRank()) return false;
+
+    // All static dimensions must be the same. Mixed cases (e.g., dimension
+    // static in `t1` but dynamic in `t2`) are not supported.
+    for (unsigned i = 0; i < t1.getRank(); ++i) {
+      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
+        return false;
+      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
+        return false;
+    }
+
+    // Nothing more to check if all dimensions are static.
+    if (t1.getNumDynamicDims() == 0) return true;
+
+    // All dynamic sizes must be the same. The only supported case at the moment
+    // is when `beforePadding` is a SubTensorOp (or a cast thereof).
+
+    // Apart from CastOp, only SubTensorOp is supported.
+    auto beforeSubtensor = beforePadding.getDefiningOp<SubTensorOp>();
+    if (!beforeSubtensor) return false;
+
+    assert(static_cast<size_t>(t1.getRank())
+           == beforeSubtensor.getMixedSizes().size());
+    assert(static_cast<size_t>(t2.getRank())
+           == afterTrimming.getMixedSizes().size());
+
+    for (unsigned i = 0; i < t1.getRank(); ++i) {
+      // Skip static dimensions.
+      if (!t1.isDynamicDim(i)) continue;
+      auto size1 = beforeSubtensor.getMixedSizes()[i];
+      auto size2 = afterTrimming.getMixedSizes()[i];
+
+      // Case 1: Same value or same constant int.
+      if (isEqualConstantIntOrValue(size1, size2)) continue;
+
+      // Other cases: Take a deeper look at defining ops of values.
+      auto v1 = size1.dyn_cast<Value>();
+      auto v2 = size2.dyn_cast<Value>();
+      if (!v1 || !v2) return false;
+
+      // Case 2: Both values are identical AffineMinOps. (Should not happen if
+      // CSE is run.)
+      auto minOp1 = v1.getDefiningOp<AffineMinOp>();
+      auto minOp2 = v2.getDefiningOp<AffineMinOp>();
+      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap()
+          && minOp1.operands() == minOp2.operands()) continue;
+
+      // Add additional cases as needed.
+    }
+
+    // All tests passed.
+    return true;
+  }
+};
+
 /// Rewrite use of PadTensorOp result in SubtensorInsertOp. E.g.:
 /// ```
 /// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
@@ -807,8 +942,8 @@ struct PadTensorOpVectorizationWithTransferReadPattern
 /// - Single, scalar padding value.
 struct PadTensorOpVectorizationWithSubTensorInsertPattern
     : public VectorizePadTensorOpUserPattern<SubTensorInsertOp> {
-  using VectorizePadTensorOpUserPattern<
-      SubTensorInsertOp>::VectorizePadTensorOpUserPattern;
+  using VectorizePadTensorOpUserPattern<SubTensorInsertOp>
+      ::VectorizePadTensorOpUserPattern;
 
   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
                             SubTensorInsertOp insertOp) const override {
@@ -864,6 +999,7 @@ void mlir::linalg::populatePadTensorOpVectorizationPatterns(
       patterns.getContext(), baseBenefit);
   // Try these specialized patterns first before resorting to the generic one.
   patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
+               PadTensorOpVectorizationWithTransferWritePattern,
                PadTensorOpVectorizationWithSubTensorInsertPattern>(
       patterns.getContext(), baseBenefit.getBenefit() + 1);
 }
index 04c3d84..bc5a2fe 100644 (file)
@@ -580,6 +580,54 @@ func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> {
 
 // -----
 
+
+// CHECK-LABEL: func @pad_and_transfer_write_static
+//  CHECK-SAME:     %[[ARG0:.*]]: tensor<5x6xf32>, %[[ARG1:.*]]: vector<7x9xf32>
+//   CHECK-NOT:   linalg.pad_tensor
+//       CHECK:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[RESULT:.*]] = vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<5x6xf32>
+//       CHECK:   return %[[RESULT]]
+func @pad_and_transfer_write_static(
+    %arg0: tensor<5x6xf32>, %arg1: vector<7x9xf32>) -> tensor<5x6xf32> {
+  %c0 = constant 0 : index
+  %c5 = constant 5.0 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[5, 7] {
+    ^bb0(%arg2: index, %arg3: index):
+      linalg.yield %c5 : f32
+  } : tensor<5x6xf32> to tensor<10x13xf32>
+  %1 = vector.transfer_write %arg1, %0[%c0, %c0]
+      : vector<7x9xf32>, tensor<10x13xf32>
+  %2 = subtensor %1[0, 0] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
+  return %2 : tensor<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pad_and_transfer_write_dynamic_static
+//  CHECK-SAME:     %[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: vector<7x9xf32>, %[[SIZE:.*]]: index, %[[PADDING:.*]]: index
+//   CHECK-NOT:   linalg.pad_tensor
+//       CHECK:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[SUB:.*]] = subtensor %[[ARG0]][0, 0] [%[[SIZE]], 6] [1, 1] : tensor<?x?xf32> to tensor<?x6xf32>
+//       CHECK:   %[[RESULT:.*]] = vector.transfer_write %[[ARG1]], %[[SUB]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<?x6xf32>
+//       CHECK:   return %[[RESULT]]
+func @pad_and_transfer_write_dynamic_static(
+    %arg0: tensor<?x?xf32>, %arg1: vector<7x9xf32>, %size: index, %padding: index) -> tensor<?x6xf32> {
+  %c0 = constant 0 : index
+  %c5 = constant 5.0 : f32
+  %s = subtensor %arg0[0, 0] [%size, 6] [1, 1]
+      : tensor<?x?xf32> to tensor<?x6xf32>
+  %0 = linalg.pad_tensor %s low[0, 0] high[%padding, 7] {
+    ^bb0(%arg2: index, %arg3: index):
+      linalg.yield %c5 : f32
+  } : tensor<?x6xf32> to tensor<?x13xf32>
+  %1 = vector.transfer_write %arg1, %0[%c0, %c0]
+      : vector<7x9xf32>, tensor<?x13xf32>
+  %2 = subtensor %1[0, 0] [%size, 6] [1, 1] : tensor<?x13xf32> to tensor<?x6xf32>
+  return %2 : tensor<?x6xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @pad_and_subtensor_insert
 //  CHECK-SAME:     %[[ARG0:.*]]: tensor<5x6xf32>, %[[ARG1:.*]]: tensor<12x13xf32>
 //   CHECK-NOT:   linalg.pad_tensor