[mlir][tensor] Fold rank-reducing insert_slice with inverse collapse_shape
authorMatthias Springer <springerm@google.com>
Mon, 5 Dec 2022 08:16:05 +0000 (09:16 +0100)
committerMatthias Springer <springerm@google.com>
Mon, 5 Dec 2022 08:17:29 +0000 (09:17 +0100)
Differential Revision: https://reviews.llvm.org/D139221

mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir

index c1166c5..b655df3 100644 (file)
@@ -49,9 +49,41 @@ struct FoldExpandOfRankReducingExtract
     return success();
   }
 };
+
+/// Fold insert_slice(collapse_shape) ops that cancel itself out.
+struct FoldInsertOfRankReducingInsert : public OpRewritePattern<InsertSliceOp> {
+  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseShapeOp =
+        insertSliceOp.getSource().getDefiningOp<CollapseShapeOp>();
+    if (!collapseShapeOp)
+      return failure();
+    RankedTensorType srcType = collapseShapeOp.getSrcType();
+
+    // Only cases where the CollapseShapeOp can be folded away entirely are
+    // supported. Moreover, only simple cases where the resulting InsertSliceOp
+    // has no rank-reduction anymore are supported at the moment.
+    RankedTensorType nonReducingInsertType =
+        RankedTensorType::get(insertSliceOp.getStaticSizes(),
+                              insertSliceOp.getType().getElementType());
+    if (nonReducingInsertType != srcType)
+      return failure();
+
+    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
+    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+        insertSliceOp, collapseShapeOp.getSrc(), insertSliceOp.getDest(),
+        mixedOffsets, mixedSizes, mixedStrides);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldExpandOfRankReducingExtract>(patterns.getContext());
+  patterns.add<FoldExpandOfRankReducingExtract, FoldInsertOfRankReducingInsert>(
+      patterns.getContext());
 }
index c81e531..15a00a5 100644 (file)
@@ -17,3 +17,19 @@ func.func @expand_shape_of_rank_reducing_extract(
       : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
   return %1, %2 : tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing_insert_of_collapse_shape(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x1x1x5xf32>
+//       CHECK:   %[[insert:.*]] = tensor.insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
+//       CHECK:   return %[[insert]]
+func.func @rank_reducing_insert_of_collapse_shape(
+    %t: tensor<?x1x1x5xf32>, %d: tensor<?x?x?x?xf32>, %sz: index)
+  -> tensor<?x?x?x?xf32> {
+  %0 = tensor.collapse_shape %t [[0, 1], [2], [3]]
+      : tensor<?x1x1x5xf32> into tensor<?x1x5xf32>
+  %1 = tensor.insert_slice %0 into %d[0, 0, 0, 0][%sz, 1, 1, 5][1, 1, 1, 1]
+      : tensor<?x1x5xf32> into tensor<?x?x?x?xf32>
+  return %1 : tensor<?x?x?x?xf32>
+}