[mlir][utils] Fix invalid reshapes in ComposeCollapseOfExpandOp
authorMatthias Springer <springerm@google.com>
Wed, 23 Nov 2022 10:56:07 +0000 (11:56 +0100)
committerMatthias Springer <springerm@google.com>
Wed, 23 Nov 2022 12:52:00 +0000 (13:52 +0100)
Do not generate CollapseShapeOps/ExpandShapeOps that have the same source and result shape. Generate casts instead. Such reshapes became invalid with D138498.

Differential Revision: https://reviews.llvm.org/D138557

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir

index dba055d..760a5aa 100644 (file)
@@ -225,7 +225,7 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
 //
 /// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
 /// `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy>
+template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
 struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
   using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
@@ -250,8 +250,7 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
     SmallVector<ReassociationIndices, 4> higherRankReassociation,
         lowerRankReassociation;
 
-    bool isResultCollapsed = srcRank > resultRank;
-    if (isResultCollapsed) {
+    if (srcRank > resultRank) {
       higherRankReassociation = expandOp.getReassociationIndices();
       lowerRankReassociation = collapseOp.getReassociationIndices();
     } else {
@@ -274,12 +273,20 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
       }
       composedReassociation.push_back(composedIndices);
     }
-    if (isResultCollapsed)
+    if (srcRank > resultRank) {
       rewriter.replaceOpWithNewOp<CollapseOpTy>(
           collapseOp, resultType, expandOp.getSrc(), composedReassociation);
-    else
+    } else if (srcRank < resultRank) {
       rewriter.replaceOpWithNewOp<ExpandOpTy>(
           collapseOp, resultType, expandOp.getSrc(), composedReassociation);
+    } else {
+      // Collapses/expansions that do not change the rank are not allowed. Use
+      // a cast instead.
+      assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
+             "expected same shape");
+      rewriter.replaceOpWithNewOp<CastOpTy>(collapseOp, resultType,
+                                            expandOp.getSrc());
+    }
     return success();
   }
 };
index 2bbb57e..503c8ae 100644 (file)
@@ -2447,7 +2447,7 @@ public:
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
   results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
               CollapseShapeOpMemRefCastFolder>(context);
 }
 
index 36e3aad..23af46c 100644 (file)
@@ -1586,7 +1586,7 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
   results
       .add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-           ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+           ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
            FoldReshapeWithConstant<CollapseShapeOp>,
            FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
           context);
index d34d0d9..d9710b7 100644 (file)
@@ -859,3 +859,19 @@ func.func @memref_realloc_dead(%src : memref<2xf32>, %v : f32) -> memref<2xf32>{
   memref.store %v, %0[%i2] : memref<4xf32>
   return %src : memref<2xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @collapse_expand_fold_to_cast(
+//  CHECK-SAME:     %[[m:.*]]: memref<?xf32, strided<[1]>, 3>
+//       CHECK:   %[[casted:.*]] = memref.cast %[[m]] : memref<?xf32, strided<[1]>, 3> to memref<?xf32, 3
+//       CHECK:   return %[[casted]]
+func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>)
+    -> (memref<?xf32, 3>)
+{
+  %0 = memref.expand_shape %m [[0, 1]]
+      : memref<?xf32, strided<[1]>, 3> into memref<1x?xf32, 3>
+  %1 = memref.collapse_shape %0 [[0, 1]]
+      : memref<1x?xf32, 3> into memref<?xf32, 3>
+  return %1 : memref<?xf32, 3>
+}
index c9e662f..92e329d 100644 (file)
@@ -1666,3 +1666,15 @@ func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
   %1 = tensor.dim %0, %c1 : tensor<?x?xf32>
   return %1 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @collapse_expand_fold_to_cast(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>
+//       CHECK:   return %[[t]]
+func.func @collapse_expand_fold_to_cast(%t: tensor<?xf32>) -> (tensor<?xf32>)
+{
+  %0 = tensor.expand_shape %t [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  return %1 : tensor<?xf32>
+}