[mlir][tensor] Merge consecutive insert_slice/extract_slice ops
authorLei Zhang <antiagainst@google.com>
Tue, 20 Sep 2022 23:52:19 +0000 (19:52 -0400)
committerLei Zhang <antiagainst@google.com>
Tue, 20 Sep 2022 23:52:56 +0000 (19:52 -0400)
Consecutive tensor.insert_slice/tensor.extract_slice can be
created for the case like tiling convolution and then downsizing
2-D convolutions into 1-D ones. It hinders further transformations.
So adding these patterns to clean it up.

Given that bufferization is sensitive and have requirements over
the IR structure (see https://reviews.llvm.org/D132666),
these patterns are put in Transforms/ with separate entry points
for explicit collection.

Reviewed By: ThomasRaoux, mravishankar

Differential Revision: https://reviews.llvm.org/D133871

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp [new file with mode: 0644]
mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

index 28c22ae..13ff67e 100644 (file)
@@ -29,6 +29,13 @@ void populateSplitPaddingPatterns(RewritePatternSet &patterns,
 FailureOr<Value> replaceExtractSliceWithTiledProducer(
     OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
 
+/// Collects patterns to merge consecutive tensor.insert_slice/extract_slice
+/// into one. These patterns are in in this separate entry point because the
+/// bufferization is sensitive over IR structure, particularly those
+/// tensor.extract_slice and tensor.insert_slice ops for creating the slices.
+void populateMergeConsecutiveInsertExtractSlicePatterns(
+    RewritePatternSet &patterns);
+
 } // namespace tensor
 } // namespace mlir
 
index 0b200e0..73bab56 100644 (file)
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   ExtractSliceFromReshape.cpp
+  MergeConsecutiveInsertExtractSlicePatterns.cpp
   SplitPadding.cpp
   SwapExtractSliceWithProducer.cpp
 
diff --git a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
new file mode 100644 (file)
index 0000000..48977a9
--- /dev/null
@@ -0,0 +1,117 @@
+//===- MergeConsecutiveInsertExtractSlicePatterns.cpp ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+/// Adds each corresponding pair of offsets in `offsets1` and `offsets2` and
+/// returns the results.
+static SmallVector<OpFoldResult> mergeOffsets(Location loc,
+                                              ArrayRef<OpFoldResult> offsets1,
+                                              ArrayRef<OpFoldResult> offsets2,
+                                              OpBuilder &builder) {
+  SmallVector<OpFoldResult> foldedOffsets;
+  assert(offsets1.size() == offsets2.size());
+  foldedOffsets.reserve(offsets1.size());
+
+  AffineExpr dim1, dim2;
+  bindDims(builder.getContext(), dim1, dim2);
+
+  for (const auto &pair : llvm::zip(offsets1, offsets2)) {
+    auto offset0 =
+        getValueOrCreateConstantIndexOp(builder, loc, std::get<0>(pair));
+    auto offset1 =
+        getValueOrCreateConstantIndexOp(builder, loc, std::get<1>(pair));
+    auto foldedOffset =
+        makeComposedAffineApply(builder, loc, dim1 + dim2, {offset0, offset1});
+    foldedOffsets.push_back(foldedOffset.getResult());
+  }
+  return foldedOffsets;
+}
+
+namespace {
+/// Merges consecutive tensor.extract_slice ops into one.
+struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractSliceOp nextOp,
+                                PatternRewriter &rewriter) const override {
+    auto prevOp = nextOp.getSource().getDefiningOp<ExtractSliceOp>();
+    if (!prevOp)
+      return failure();
+
+    if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
+      return failure();
+
+    auto prevResultType = prevOp.getType().cast<ShapedType>();
+    if (prevOp.getSourceType().getRank() != prevResultType.getRank())
+      return rewriter.notifyMatchFailure(
+          prevOp, "rank-reducing producder case unimplemented");
+
+    Location loc = nextOp.getLoc();
+
+    SmallVector<OpFoldResult> prevOffsets = prevOp.getMixedOffsets();
+    SmallVector<OpFoldResult> nextOffsets = nextOp.getMixedOffsets();
+    SmallVector<OpFoldResult> foldedOffsets =
+        mergeOffsets(loc, prevOffsets, nextOffsets, rewriter);
+
+    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
+        nextOp, nextOp.getType(), prevOp.getSource(), foldedOffsets,
+        nextOp.getMixedSizes(), nextOp.getMixedStrides());
+    return success();
+  }
+};
+
+/// Merges consecutive tensor.insert_slice ops into one.
+struct MergeConsecutiveInsertSlice : public OpRewritePattern<InsertSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertSliceOp nextOp,
+                                PatternRewriter &rewriter) const override {
+    auto prevOp = nextOp.getSource().getDefiningOp<InsertSliceOp>();
+    if (!prevOp)
+      return failure();
+
+    if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
+      return failure();
+
+    // The first insert_slice op should be rank reducing to make sure we cover
+    // the full source tensor to be inserted in the second insert_slice op.
+    SliceVerificationResult result =
+        isRankReducedType(prevOp.getDestType(), prevOp.getSourceType());
+    if (result != SliceVerificationResult::Success)
+      return failure();
+
+    // Dynamic dimensions can pass rank reducing check in the above, e.g,
+    // inserting <?xf32> into <1x?x1xf32>. For such cases we cannot be certain
+    // the dynamic size covers the full tensor.
+    if (!prevOp.getSourceType().hasStaticShape() ||
+        !prevOp.getDestType().hasStaticShape())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<InsertSliceOp>(
+        nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(),
+        nextOp.getMixedSizes(), nextOp.getMixedStrides());
+    return success();
+  }
+};
+} // namespace
+
+void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<MergeConsecutiveExtractSlice, MergeConsecutiveInsertSlice>(
+      patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
new file mode 100644 (file)
index 0000000..45a3f37
--- /dev/null
@@ -0,0 +1,58 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-consecutive-insert-extract-slice -canonicalize -mlir-print-local-scope %s | FileCheck %s
+
+func.func @extract_slice_same_rank(
+    %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x16x32x?xf32> {
+  %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
+  %1 = tensor.extract_slice %0[7, 8, 9, %offset1] [8, 16, 32, %size1] [1, 1, 1, 1] : tensor<128x128x128x?xf32> to tensor<8x16x32x?xf32>
+  return %1: tensor<8x16x32x?xf32>
+}
+
+// CHECK-LABEL: func.func @extract_slice_same_rank
+//  CHECK-SAME: (%[[SOURCE:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
+//       CHECK:   %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET0]], %[[OFFSET1]]]
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][7, 9, 11, %[[OFFSET]]] [8, 16, 32, %[[SIZE1]]] [1, 1, 1, 1]
+//       CHECK:   return %[[EXTRACT]] : tensor<8x16x32x?xf32>
+
+func.func @extract_slice_rank_reducing_consumer(
+    %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<16x?xf32> {
+  %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
+  %1 = tensor.extract_slice %0[7, 8, 9, %offset1] [1, 16, 1, %size1] [1, 1, 1, 1] : tensor<128x128x128x?xf32> to tensor<16x?xf32>
+  return %1: tensor<16x?xf32>
+}
+
+// CHECK-LABEL: func.func @extract_slice_rank_reducing_consumer
+//       CHECK:   tensor.extract_slice %{{.+}}[7, 9, 11, %{{.+}}] [1, 16, 1, %{{.+}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<16x?xf32>
+
+func.func @extract_slice_rank_reducing_producer(
+    %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x?xf32> {
+  %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [1, 128, 1, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x?xf32>
+  %1 = tensor.extract_slice %0[7, %offset1] [8, %size1] [1, 1] : tensor<128x?xf32> to tensor<8x?xf32>
+  return %1: tensor<8x?xf32>
+}
+
+//   CHECK-LABEL: func.func @extract_slice_rank_reducing_producer
+// CHECK-COUNT-2:   tensor.extract_slice
+
+// -----
+
+func.func @insert_slice_rank_reducing(
+    %dst: tensor<128x128x128x128xf32>, %mid: tensor<1x16x1xf32>, %src: tensor<16xf32>, %offset: index) -> tensor<128x128x128x128xf32> {
+  %0 = tensor.insert_slice %src into %mid[0, 0, 0] [1, 16, 1] [1, 1, 1] : tensor<16xf32> into tensor<1x16x1xf32>
+  %1 = tensor.insert_slice %0 into %dst[6, 7, 8, %offset] [1, 1, 16, 1] [1, 1, 1, 1] : tensor<1x16x1xf32> into tensor<128x128x128x128xf32>
+  return %1: tensor<128x128x128x128xf32>
+}
+
+// CHECK-LABEL: func.func @insert_slice_rank_reducing
+//  CHECK-SAME: (%[[DST:.+]]: tensor<128x128x128x128xf32>, %{{.+}}: tensor<1x16x1xf32>, %[[SRC:.+]]: tensor<16xf32>, %[[IDX:.+]]: index)
+//       CHECK:  %[[INSERT:.+]] = tensor.insert_slice %[[SRC]] into %[[DST]][6, 7, 8, %[[IDX]]] [1, 1, 16, 1] [1, 1, 1, 1]
+//       CHECK:  return %[[INSERT]]
+
+func.func @insert_slice_rank_reducing_dynamic_shape(
+    %dst: tensor<128x128x128x128xf32>, %mid: tensor<1x?x1xf32>, %src: tensor<?xf32>, %offset: index, %size: index) -> tensor<128x128x128x128xf32> {
+  %0 = tensor.insert_slice %src into %mid[0, 0, 0] [1, %size, 1] [1, 1, 1] : tensor<?xf32> into tensor<1x?x1xf32>
+  %1 = tensor.insert_slice %0 into %dst[6, 7, 8, %offset] [1, 1, %size, 1] [1, 1, 1, 1] : tensor<1x?x1xf32> into tensor<128x128x128x128xf32>
+  return %1: tensor<128x128x128x128xf32>
+}
+
+//   CHECK-LABEL: func.func @insert_slice_rank_reducing_dynamic_shape
+// CHECK-COUNT-2:   tensor.insert_slice
index 5dd5d76..e06607c 100644 (file)
@@ -53,6 +53,12 @@ struct TestTensorTransforms
       llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
       llvm::cl::init(false)};
 
+  Option<bool> testFoldConsecutiveInsertExtractSlice{
+      *this, "test-fold-consecutive-insert-extract-slice",
+      llvm::cl::desc(
+          "Test folding consecutive tensor.insert_slice/tensor.extract_slice"),
+      llvm::cl::init(false)};
+
   Option<bool> testRewriteExtractSliceWithTiledCollapseShape{
       *this, "test-rewrite-extract-slice-from-collapse-shape",
       llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape "
@@ -90,6 +96,12 @@ static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) {
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+static void applyFoldConsecutiveInsertExtractSlicePatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 namespace {
 /// Base pattern to rewrite  a `tensor.collapse_shape -> tensor.extract_slice`.
 /// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -233,6 +245,8 @@ void TestTensorTransforms::runOnOperation() {
     applySplitPaddingPatterns(rootOp);
   if (testFoldConstantExtractSlice)
     applyFoldConstantExtractSlicePatterns(rootOp);
+  if (testFoldConsecutiveInsertExtractSlice)
+    applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
   if (testRewriteExtractSliceWithTiledCollapseShape) {
     if (failed(
             applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))