[mlir][tensor] Fold rank-reducing extract_slice with inverse expand_shape
authorMatthias Springer <springerm@google.com>
Mon, 5 Dec 2022 08:15:52 +0000 (09:15 +0100)
committerMatthias Springer <springerm@google.com>
Mon, 5 Dec 2022 08:17:24 +0000 (09:17 +0100)
Differential Revision: https://reviews.llvm.org/D139220

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp [new file with mode: 0644]
mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

index 13ff67e..267972e 100644 (file)
@@ -36,6 +36,10 @@ FailureOr<Value> replaceExtractSliceWithTiledProducer(
 void populateMergeConsecutiveInsertExtractSlicePatterns(
     RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that fold `tensor.expand_shape` and
+/// `tensor.collapse_shape` into other ops.
+void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
+
 } // namespace tensor
 } // namespace mlir
 
index 75216e7..08a0d5a 100644 (file)
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   Bufferize.cpp
   ExtractSliceFromReshapeUtils.cpp
   MergeConsecutiveInsertExtractSlicePatterns.cpp
+  ReshapePatterns.cpp
   SplitPaddingPatterns.cpp
   SwapExtractSliceWithProducerPatterns.cpp
 
@@ -26,4 +27,4 @@ add_mlir_dialect_library(MLIRTensorTransforms
   MLIRTensorDialect
   MLIRTilingInterface
   MLIRTransforms
-  )
+)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
new file mode 100644 (file)
index 0000000..c1166c5
--- /dev/null
@@ -0,0 +1,57 @@
+//===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "mlir-tensor-split-padding"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+/// Fold expand_shape(extract_slice) ops that cancel itself out.
+struct FoldExpandOfRankReducingExtract
+    : public OpRewritePattern<ExpandShapeOp> {
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType resultType = expandShapeOp.getResultType();
+    auto extractSliceOp =
+        expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
+    if (!extractSliceOp)
+      return failure();
+    RankedTensorType srcType = extractSliceOp.getSourceType();
+
+    // Only cases where the ExpandShapeOp can be folded away entirely are
+    // supported. Moreover, only simple cases where the resulting ExtractSliceOp
+    // has no rank-reduction anymore are supported at the moment.
+    RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
+        srcType, extractSliceOp.getStaticOffsets(),
+        extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
+    if (nonReducingExtractType != resultType)
+      return failure();
+
+    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
+    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+        expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
+        mixedStrides);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FoldExpandOfRankReducingExtract>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
new file mode 100644 (file)
index 0000000..c81e531
--- /dev/null
@@ -0,0 +1,19 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-reassociative-reshape-folding %s | FileCheck %s
+
+// CHECK-LABEL: func @expand_shape_of_rank_reducing_extract(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x?xf32>
+//   CHECK-DAG:   %[[extract1:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
+//   CHECK-DAG:   %[[extract2:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
+//       CHECK:   return %[[extract1]], %[[extract2]]
+func.func @expand_shape_of_rank_reducing_extract(
+    %t: tensor<?x?x?x?xf32>, %idx: index)
+  -> (tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>)
+{
+  %0 = tensor.extract_slice %t[0, 0, 0, 0][%idx, 1, 1, 5][1, 1, 1, 1]
+      : tensor<?x?x?x?xf32> to tensor<?x1x5xf32>
+  %1 = tensor.expand_shape %0 [[0], [1, 2], [3]]
+      : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
+  %2 = tensor.expand_shape %0 [[0, 1], [2], [3]]
+      : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
+  return %1, %2 : tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>
+}
index fa3cdb7..1802387 100644 (file)
@@ -65,6 +65,11 @@ struct TestTensorTransforms
                      "with loop nest"),
       llvm::cl::init(false)};
 
+  Option<bool> testReassociativeReshapeFolding{
+      *this, "test-reassociative-reshape-folding",
+      llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
+      llvm::cl::init(false)};
+
   Option<bool> useForeach{
       *this, "use-foreach",
       llvm::cl::desc(
@@ -74,6 +79,12 @@ struct TestTensorTransforms
 };
 } // namespace
 
+static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateReassociativeReshapeFoldingPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 static void applySplitPaddingPatterns(Operation *rootOp) {
   RewritePatternSet patterns(rootOp->getContext());
   tensor::populateSplitPaddingPatterns(patterns);
@@ -251,6 +262,8 @@ void TestTensorTransforms::runOnOperation() {
     applyFoldConstantExtractSlicePatterns(rootOp);
   if (testFoldConsecutiveInsertExtractSlice)
     applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
+  if (testReassociativeReshapeFolding)
+    applyReassociativeReshapeFoldingPatterns(rootOp);
   if (testRewriteExtractSliceWithTiledCollapseShape) {
     if (failed(
             applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))