Revert "[mlir][tensor] Fold rank-reducing insert_slice with inverse collapse_shape"
authorMatthias Springer <springerm@google.com>
Fri, 2 Dec 2022 20:22:04 +0000 (21:22 +0100)
committerMatthias Springer <springerm@google.com>
Fri, 2 Dec 2022 20:22:04 +0000 (21:22 +0100)
This reverts commit 1522a3b7b34b41cf0b17678e4a8687797f44a3f0.

mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir

index b655df3..c1166c5 100644 (file)
@@ -49,41 +49,9 @@ struct FoldExpandOfRankReducingExtract
     return success();
   }
 };
-
-/// Fold insert_slice(collapse_shape) ops that cancel itself out.
-struct FoldInsertOfRankReducingInsert : public OpRewritePattern<InsertSliceOp> {
-  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
-                                PatternRewriter &rewriter) const override {
-    auto collapseShapeOp =
-        insertSliceOp.getSource().getDefiningOp<CollapseShapeOp>();
-    if (!collapseShapeOp)
-      return failure();
-    RankedTensorType srcType = collapseShapeOp.getSrcType();
-
-    // Only cases where the CollapseShapeOp can be folded away entirely are
-    // supported. Moreover, only simple cases where the resulting InsertSliceOp
-    // has no rank-reduction anymore are supported at the moment.
-    RankedTensorType nonReducingInsertType =
-        RankedTensorType::get(insertSliceOp.getStaticSizes(),
-                              insertSliceOp.getType().getElementType());
-    if (nonReducingInsertType != srcType)
-      return failure();
-
-    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
-    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
-        insertSliceOp, collapseShapeOp.getSrc(), insertSliceOp.getDest(),
-        mixedOffsets, mixedSizes, mixedStrides);
-    return success();
-  }
-};
 } // namespace
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldExpandOfRankReducingExtract, FoldInsertOfRankReducingInsert>(
-      patterns.getContext());
+  patterns.add<FoldExpandOfRankReducingExtract>(patterns.getContext());
 }
index 15a00a5..c81e531 100644 (file)
@@ -17,19 +17,3 @@ func.func @expand_shape_of_rank_reducing_extract(
       : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
   return %1, %2 : tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>
 }
-
-// -----
-
-// CHECK-LABEL: func @rank_reducing_insert_of_collapse_shape(
-//  CHECK-SAME:     %[[t:.*]]: tensor<?x1x1x5xf32>
-//       CHECK:   %[[insert:.*]] = tensor.insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
-//       CHECK:   return %[[insert]]
-func.func @rank_reducing_insert_of_collapse_shape(
-    %t: tensor<?x1x1x5xf32>, %d: tensor<?x?x?x?xf32>, %sz: index)
-  -> tensor<?x?x?x?xf32> {
-  %0 = tensor.collapse_shape %t [[0, 1], [2], [3]]
-      : tensor<?x1x1x5xf32> into tensor<?x1x5xf32>
-  %1 = tensor.insert_slice %0 into %d[0, 0, 0, 0][%sz, 1, 1, 5][1, 1, 1, 1]
-      : tensor<?x1x5xf32> into tensor<?x?x?x?xf32>
-  return %1 : tensor<?x?x?x?xf32>
-}