[mlir][bufferization] Do not canonicalize to_tensor(to_memref(x))
authorMatthias Springer <springerm@google.com>
Sat, 9 Jul 2022 07:15:36 +0000 (09:15 +0200)
committerMatthias Springer <springerm@google.com>
Sat, 9 Jul 2022 07:16:52 +0000 (09:16 +0200)
This is a partial revert of D128615.

to_memref(to_tensor(x)) always be folded to x. But to_tensor(to_memref(x)) cannot be folded in the general case because writes to the intermediary memref may go unnoticed.

Differential Revision: https://reviews.llvm.org/D129354

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir

index 35f6f1b6a97f452a2b52b7b72d2c0f02ab7a45dd..4ab904ea39309153cb5f68527046c80dbcc08318 100644 (file)
@@ -539,20 +539,6 @@ OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
 }
 
 namespace {
-/// Canonicalize bufferization.to_tensor + bufferization.to_memref.
-struct ToTensorToMemrefFolding : public OpRewritePattern<ToTensorOp> {
-  using OpRewritePattern<ToTensorOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ToTensorOp toTensorOp,
-                                PatternRewriter &rewriter) const final {
-    auto toMemrefOp = toTensorOp.getMemref().getDefiningOp<ToMemrefOp>();
-    if (!toMemrefOp)
-      return failure();
-    rewriter.replaceOp(toTensorOp, toMemrefOp.getTensor());
-    return success();
-  }
-};
-
 struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
 
@@ -571,7 +557,7 @@ struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
 
 void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<DimOfToTensorFolder, ToTensorToMemrefFolding>(context);
+  results.add<DimOfToTensorFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//
index 535a00706100f1caa57545e6520676c87b88e012..8e087fc0f38a412a0c4465840d0da4b619b0c74d 100644 (file)
@@ -787,7 +787,8 @@ func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
   }
 
   // CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32>
-  // CHECK-NEXT: return %[[R0]], %[[T1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+  // CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32>
+  // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
   return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
 }
 
index f24048e60e07cedc0edf758aaeacd6841e10e099..df55b8373e0eef6fa4bf16f36444ea0fddce357f 100644 (file)
 // CHECK:             scf.yield %[[VAL_84]] : f64
 // CHECK:           }
 // CHECK:           memref.store %[[VAL_86:.*]], %[[VAL_15]][] : memref<f64>
-// CHECK:           return %[[VAL_0]] : tensor<f64>
+// CHECK:           %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<f64>
+// CHECK:           return %[[VAL_87]] : tensor<f64>
 // CHECK:         }
 func.func @sparse_matrix_sum(%argx: tensor<f64> {linalg.inplaceable = true},
                          %arga: tensor<64x32xf64, #SparseMatrix>,