This is a partial revert of D128615.
to_memref(to_tensor(x)) can always be folded to x. But to_tensor(to_memref(x)) cannot be folded in the general case, because writes to the intermediate memref may go unnoticed.
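
For illustration, a minimal hypothetical snippet (the SSA names are invented, not taken from this patch) showing why that fold is unsound:

    %m = bufferization.to_memref %t : memref<f32>
    memref.store %f, %m[] : memref<f32>  // write through the intermediate memref
    %u = bufferization.to_tensor %m : memref<f32>

Folding %u to %t here would drop the effect of the store. The updated sparse test below exercises exactly this pattern: the result must come from a new to_tensor of the stored-to memref, not from the original tensor argument.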
Differential Revision: https://reviews.llvm.org/D129354
}
namespace {
-/// Canonicalize bufferization.to_tensor + bufferization.to_memref.
-struct ToTensorToMemrefFolding : public OpRewritePattern<ToTensorOp> {
- using OpRewritePattern<ToTensorOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ToTensorOp toTensorOp,
- PatternRewriter &rewriter) const final {
- auto toMemrefOp = toTensorOp.getMemref().getDefiningOp<ToMemrefOp>();
- if (!toMemrefOp)
- return failure();
- rewriter.replaceOp(toTensorOp, toMemrefOp.getTensor());
- return success();
- }
-};
-
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DimOfToTensorFolder, ToTensorToMemrefFolding>(context);
+ results.add<DimOfToTensorFolder>(context);
}
//===----------------------------------------------------------------------===//
}
// CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32>
- // CHECK-NEXT: return %[[R0]], %[[T1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+ // CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32>
+ // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
}
// CHECK: scf.yield %[[VAL_84]] : f64
// CHECK: }
// CHECK: memref.store %[[VAL_86:.*]], %[[VAL_15]][] : memref<f64>
-// CHECK: return %[[VAL_0]] : tensor<f64>
+// CHECK: %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<f64>
+// CHECK: return %[[VAL_87]] : tensor<f64>
// CHECK: }
func.func @sparse_matrix_sum(%argx: tensor<f64> {linalg.inplaceable = true},
%arga: tensor<64x32xf64, #SparseMatrix>,