[mlir][linalg] Improve codegen of ExtractSliceOfPadTensorSwapPattern
authorMatthias Springer <springerm@google.com>
Thu, 15 Jul 2021 02:05:12 +0000 (11:05 +0900)
committerMatthias Springer <springerm@google.com>
Thu, 15 Jul 2021 02:05:55 +0000 (11:05 +0900)
Generate simpler code in case low/high padding of the PadTensorOp is statically zero.

Differential Revision: https://reviews.llvm.org/D105529

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir

index 4c3bb41..8d3ee8f 100644 (file)
@@ -866,6 +866,9 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
   int64_t rank = padOp.getSourceType().getRank();
   for (unsigned dim = 0; dim < rank; ++dim) {
     auto low = asValue(rewriter, loc, padOp.getMixedLowPad()[dim]);
+    bool hasLowPad = getConstantIntValue(low) != static_cast<int64_t>(0);
+    auto high = asValue(rewriter, loc, padOp.getMixedHighPad()[dim]);
+    bool hasHighPad = getConstantIntValue(high) != static_cast<int64_t>(0);
     auto offset = asValue(rewriter, loc, sliceOp.getMixedOffsets()[dim]);
     auto length = asValue(rewriter, loc, sliceOp.getMixedSizes()[dim]);
     auto srcSize =
@@ -874,7 +877,9 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
     // The new amount of low padding is `low - offset`. Except for the case
     // where none of the low padding is read. In that case, the new amount of
     // low padding is zero.
-    Value newLow = max(zero, sub(low, offset));
+    //
+    // Optimization: If low = 0, then newLow = 0.
+    Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
     appendIndex(newLow, newLows, staticNewLows);
 
     // Start reading the data from position `offset - low`. Since the original
@@ -887,7 +892,10 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
     // In that case, set the offset to the end of source tensor. The new
     // ExtractSliceOp length will be zero in that case. (Effectively reading no
     // data from the source.)
-    Value newOffset = min(max(sub(offset, low), zero), srcSize);
+    //
+    // Optimization: If low = 0, then the formula can be simplified.
+    Value newOffset = hasLowPad ? min(max(sub(offset, low), zero), srcSize)
+                                : min(offset, srcSize);
     newOffsets.push_back(getAsOpFoldResult(newOffset));
 
     // The original ExtractSliceOp was reading until position `offset + length`.
@@ -906,7 +914,11 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
     // endLoc = min(max(offset - low + length, 0), srcSize)
     //
     // The new ExtractSliceOp length is `endLoc - newOffset`.
-    Value endLoc = min(max(add(sub(offset, low), length), zero), srcSize);
+    //
+    // Optimization: If low = 0, then the formula can be simplified.
+    Value endLoc = hasLowPad
+                       ? min(max(add(sub(offset, low), length), zero), srcSize)
+                       : min(add(offset, length), srcSize);
     Value newLength = sub(endLoc, newOffset);
     newLengths.push_back(getAsOpFoldResult(newLength));
 
@@ -925,7 +937,9 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
 
     // The amount of high padding is simply the number of elements remaining,
     // so that the result has the same length as the original ExtractSliceOp.
-    Value newHigh = sub(sub(length, newLength), newLow);
+    // As an optimization, if the original high padding is zero, then the new
+    // high padding must also be zero.
+    Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero;
     appendIndex(newHigh, newHighs, staticNewHighs);
 
     // Only unit stride supported.
index 362de8e..13f12d8 100644 (file)
@@ -177,3 +177,43 @@ func @dynamic_extract_size(%arg0 : tensor<?x5xf32>, %s1: index, %pad : f32) -> t
   %1 = tensor.extract_slice %0[2, 4] [%s1, 4] [1, 1] : tensor<?x13xf32> to tensor<?x4xf32>
   return %1 : tensor<?x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @dynamic_zero_low_padding
+//       CHECK:   scf.if
+//       CHECK:     tensor.generate
+//       CHECK:   else
+//       CHECK:     %[[SLICE:.*]] = tensor.extract_slice
+//       CHECK:     linalg.pad_tensor %[[SLICE]] low[0, 0]
+func @dynamic_zero_low_padding(%arg0 : tensor<?x?xf32>, %pad : f32,
+                               %o1 : index, %o2 : index,
+                               %s1 : index, %s2 : index)
+    -> tensor<?x?xf32> {
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[7, 8] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad : f32
+    } : tensor<?x?xf32> to tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[%o1, %o2] [%s1, %s2] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dynamic_zero_high_padding
+//       CHECK:   scf.if
+//       CHECK:     tensor.generate
+//       CHECK:   else
+//       CHECK:     %[[SLICE:.*]] = tensor.extract_slice
+//       CHECK:     linalg.pad_tensor %[[SLICE]] low[%{{.*}}, %{{.*}}] high[0, 0]
+func @dynamic_zero_high_padding(%arg0 : tensor<?x?xf32>, %pad : f32,
+                                %o1 : index, %o2 : index,
+                                %s1 : index, %s2 : index)
+    -> tensor<?x?xf32> {
+  %0 = linalg.pad_tensor %arg0 low[7, 8] high[0, 0] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad : f32
+    } : tensor<?x?xf32> to tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[%o1, %o2] [%s1, %s2] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}