Reland "[mlir][tensor] Support more cases in MergeConsecutiveExtractSlice"
authorLei Zhang <antiagainst@google.com>
Thu, 22 Sep 2022 19:07:43 +0000 (15:07 -0400)
committerLei Zhang <antiagainst@google.com>
Thu, 22 Sep 2022 21:28:50 +0000 (17:28 -0400)
This relands commit 5d4603a02d0c3e0106b10d245322b1d2072c0c3d.
It includes fixes to GCC test failures and simplification to
the implementation.

Co-authored-by: Mahesh Ravishankar <ravishankarm@google.com>
Co-authored-by: Christopher Bate <cbate@nvidia.com>
mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir

index 2ca5562..e1e6a03 100644 (file)
 namespace mlir {
 namespace tensor {
 
+/// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
+/// when combining a producer slice **into** a consumer slice.
+///
+/// This function performs the following computation:
+/// - Combined offsets = producer_offsets * consumer_strides + consumer_offsets
+/// - Combined sizes = consumer_sizes
+/// - Combined strides = producer_strides * consumer_strides
+LogicalResult
+mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
+                            ArrayRef<OpFoldResult> producerOffsets,
+                            ArrayRef<OpFoldResult> producerSizes,
+                            ArrayRef<OpFoldResult> producerStrides,
+                            const llvm::SmallBitVector &droppedProducerDims,
+                            ArrayRef<OpFoldResult> consumerOffsets,
+                            ArrayRef<OpFoldResult> consumerSizes,
+                            ArrayRef<OpFoldResult> consumerStrides,
+                            SmallVector<OpFoldResult> &combinedOffsets,
+                            SmallVector<OpFoldResult> &combinedSizes,
+                            SmallVector<OpFoldResult> &combinedStrides);
+
+/// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
+/// when combining a `producer` slice op **into** a `consumer` slice op.
+LogicalResult
+mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
+                            OffsetSizeAndStrideOpInterface producer,
+                            OffsetSizeAndStrideOpInterface consumer,
+                            const llvm::SmallBitVector &droppedProducerDims,
+                            SmallVector<OpFoldResult> &combinedOffsets,
+                            SmallVector<OpFoldResult> &combinedSizes,
+                            SmallVector<OpFoldResult> &combinedStrides);
+
 //===----------------------------------------------------------------------===//
 // Extract slice from `tensor.collapse_shape`
 //===----------------------------------------------------------------------===//
index 48977a9..a065ba2 100644 (file)
@@ -7,8 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 using namespace mlir;
 using namespace mlir::tensor;
 
-/// Adds each corresponding pair of offsets in `offsets1` and `offsets2` and
-/// returns the results.
-static SmallVector<OpFoldResult> mergeOffsets(Location loc,
-                                              ArrayRef<OpFoldResult> offsets1,
-                                              ArrayRef<OpFoldResult> offsets2,
-                                              OpBuilder &builder) {
-  SmallVector<OpFoldResult> foldedOffsets;
-  assert(offsets1.size() == offsets2.size());
-  foldedOffsets.reserve(offsets1.size());
-
-  AffineExpr dim1, dim2;
-  bindDims(builder.getContext(), dim1, dim2);
-
-  for (const auto &pair : llvm::zip(offsets1, offsets2)) {
-    auto offset0 =
-        getValueOrCreateConstantIndexOp(builder, loc, std::get<0>(pair));
-    auto offset1 =
-        getValueOrCreateConstantIndexOp(builder, loc, std::get<1>(pair));
-    auto foldedOffset =
-        makeComposedAffineApply(builder, loc, dim1 + dim2, {offset0, offset1});
-    foldedOffsets.push_back(foldedOffset.getResult());
+LogicalResult tensor::mergeOffsetsSizesAndStrides(
+    OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> producerOffsets,
+    ArrayRef<OpFoldResult> producerSizes,
+    ArrayRef<OpFoldResult> producerStrides,
+    const llvm::SmallBitVector &droppedProducerDims,
+    ArrayRef<OpFoldResult> consumerOffsets,
+    ArrayRef<OpFoldResult> consumerSizes,
+    ArrayRef<OpFoldResult> consumerStrides,
+    SmallVector<OpFoldResult> &combinedOffsets,
+    SmallVector<OpFoldResult> &combinedSizes,
+    SmallVector<OpFoldResult> &combinedStrides) {
+  combinedOffsets.resize(producerOffsets.size());
+  combinedSizes.resize(producerOffsets.size());
+  combinedStrides.resize(producerOffsets.size());
+
+  AffineExpr s0, s1, s2;
+  bindSymbols(builder.getContext(), s0, s1, s2);
+
+  unsigned consumerPos = 0;
+  for (auto i : llvm::seq<unsigned>(0, producerOffsets.size())) {
+    if (droppedProducerDims.test(i)) {
+      // For dropped dims, get the values from the producer.
+      combinedOffsets[i] = producerOffsets[i];
+      combinedSizes[i] = producerSizes[i];
+      combinedStrides[i] = producerStrides[i];
+      continue;
+    }
+    SmallVector<OpFoldResult> offsetSymbols, strideSymbols;
+    // The combined offset is computed as
+    //    producer_offset + consumer_offset * producer_strides.
+    combinedOffsets[i] = makeComposedFoldedAffineApply(
+        builder, loc, s0 * s1 + s2,
+        {consumerOffsets[consumerPos], producerStrides[i], producerOffsets[i]});
+    combinedSizes[i] = consumerSizes[consumerPos];
+    // The combined stride is computed as
+    //    consumer_stride * producer_stride.
+    combinedStrides[i] = makeComposedFoldedAffineApply(
+        builder, loc, s0 * s1,
+        {consumerStrides[consumerPos], producerStrides[i]});
+
+    consumerPos++;
   }
-  return foldedOffsets;
+  return success();
+}
+
+LogicalResult tensor::mergeOffsetsSizesAndStrides(
+    OpBuilder &builder, Location loc, OffsetSizeAndStrideOpInterface producer,
+    OffsetSizeAndStrideOpInterface consumer,
+    const llvm::SmallBitVector &droppedProducerDims,
+    SmallVector<OpFoldResult> &combinedOffsets,
+    SmallVector<OpFoldResult> &combinedSizes,
+    SmallVector<OpFoldResult> &combinedStrides) {
+  SmallVector<OpFoldResult> consumerOffsets = consumer.getMixedOffsets();
+  SmallVector<OpFoldResult> consumerSizes = consumer.getMixedSizes();
+  SmallVector<OpFoldResult> consumerStrides = consumer.getMixedStrides();
+  SmallVector<OpFoldResult> producerOffsets = producer.getMixedOffsets();
+  SmallVector<OpFoldResult> producerSizes = producer.getMixedSizes();
+  SmallVector<OpFoldResult> producerStrides = producer.getMixedStrides();
+  return tensor::mergeOffsetsSizesAndStrides(
+      builder, loc, producerOffsets, producerSizes, producerStrides,
+      droppedProducerDims, consumerOffsets, consumerSizes, consumerStrides,
+      combinedOffsets, combinedSizes, combinedStrides);
 }
 
 namespace {
@@ -53,24 +92,15 @@ struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
     if (!prevOp)
       return failure();
 
-    if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
+    SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
+    if (failed(mergeOffsetsSizesAndStrides(rewriter, nextOp.getLoc(), prevOp,
+                                           nextOp, prevOp.getDroppedDims(),
+                                           newOffsets, newSizes, newStrides)))
       return failure();
 
-    auto prevResultType = prevOp.getType().cast<ShapedType>();
-    if (prevOp.getSourceType().getRank() != prevResultType.getRank())
-      return rewriter.notifyMatchFailure(
-          prevOp, "rank-reducing producder case unimplemented");
-
-    Location loc = nextOp.getLoc();
-
-    SmallVector<OpFoldResult> prevOffsets = prevOp.getMixedOffsets();
-    SmallVector<OpFoldResult> nextOffsets = nextOp.getMixedOffsets();
-    SmallVector<OpFoldResult> foldedOffsets =
-        mergeOffsets(loc, prevOffsets, nextOffsets, rewriter);
-
-    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
-        nextOp, nextOp.getType(), prevOp.getSource(), foldedOffsets,
-        nextOp.getMixedSizes(), nextOp.getMixedStrides());
+    rewriter.replaceOpWithNewOp<ExtractSliceOp>(nextOp, nextOp.getType(),
+                                                prevOp.getSource(), newOffsets,
+                                                newSizes, newStrides);
     return success();
   }
 };
index 45a3f37..f5d77f6 100644 (file)
@@ -9,10 +9,12 @@ func.func @extract_slice_same_rank(
 
 // CHECK-LABEL: func.func @extract_slice_same_rank
 //  CHECK-SAME: (%[[SOURCE:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
-//       CHECK:   %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET0]], %[[OFFSET1]]]
+//       CHECK:   %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET1]], %[[OFFSET0]]]
 //       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][7, 9, 11, %[[OFFSET]]] [8, 16, 32, %[[SIZE1]]] [1, 1, 1, 1]
 //       CHECK:   return %[[EXTRACT]] : tensor<8x16x32x?xf32>
 
+// -----
+
 func.func @extract_slice_rank_reducing_consumer(
     %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<16x?xf32> {
   %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
@@ -23,6 +25,8 @@ func.func @extract_slice_rank_reducing_consumer(
 // CHECK-LABEL: func.func @extract_slice_rank_reducing_consumer
 //       CHECK:   tensor.extract_slice %{{.+}}[7, 9, 11, %{{.+}}] [1, 16, 1, %{{.+}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<16x?xf32>
 
+// -----
+
 func.func @extract_slice_rank_reducing_producer(
     %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x?xf32> {
   %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [1, 128, 1, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x?xf32>
@@ -30,8 +34,27 @@ func.func @extract_slice_rank_reducing_producer(
   return %1: tensor<8x?xf32>
 }
 
-//   CHECK-LABEL: func.func @extract_slice_rank_reducing_producer
-// CHECK-COUNT-2:   tensor.extract_slice
+// CHECK-LABEL: func.func @extract_slice_rank_reducing_producer
+//  CHECK-SAME: (%[[SRC:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
+//       CHECK:   %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET1]], %[[OFFSET0]]]
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[SRC]][0, 8, 2, %[[OFFSET]]] [1, 8, 1, %[[SIZE1]]] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<8x?xf32>
+//       CHECK:   return %[[EXTRACT]] : tensor<8x?xf32>
+
+// -----
+
+func.func @extract_slice_non_one_stride(
+    %src: tensor<?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index, %stride0: index, %stride1: index) -> tensor<?xf32> {
+  %0 = tensor.extract_slice %src[%offset0] [%size0] [%stride0] : tensor<?xf32> to tensor<?xf32>
+  %1 = tensor.extract_slice %0[%offset1] [%size1] [%stride1] : tensor<?xf32> to tensor<?xf32>
+  return %1: tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @extract_slice_non_one_stride
+//  CHECK-SAME: (%[[SRC:.+]]: tensor<?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index, %[[STRIDE0:.+]]: index, %[[STRIDE1:.+]]: index)
+//       CHECK:   %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>()[%[[OFFSET1]], %[[STRIDE0]], %[[OFFSET0]]]
+//       CHECK:   %[[STRIDE:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%[[STRIDE1]], %[[STRIDE0]]]
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[SRC]][%[[OFFSET]]] [%[[SIZE1]]] [%[[STRIDE]]] : tensor<?xf32> to tensor<?xf32>
+//       CHECK:   return %[[EXTRACT]] : tensor<?xf32>
 
 // -----
 
@@ -47,6 +70,8 @@ func.func @insert_slice_rank_reducing(
 //       CHECK:  %[[INSERT:.+]] = tensor.insert_slice %[[SRC]] into %[[DST]][6, 7, 8, %[[IDX]]] [1, 1, 16, 1] [1, 1, 1, 1]
 //       CHECK:  return %[[INSERT]]
 
+// -----
+
 func.func @insert_slice_rank_reducing_dynamic_shape(
     %dst: tensor<128x128x128x128xf32>, %mid: tensor<1x?x1xf32>, %src: tensor<?xf32>, %offset: index, %size: index) -> tensor<128x128x128x128xf32> {
   %0 = tensor.insert_slice %src into %mid[0, 0, 0] [1, %size, 1] [1, 1, 1] : tensor<?xf32> into tensor<1x?x1xf32>