[mlir][bufferization] Allow to_memref ops in One-Shot Analysis
authorMatthias Springer <me@m-sp.org>
Tue, 20 Jun 2023 15:53:44 +0000 (17:53 +0200)
committerMatthias Springer <me@m-sp.org>
Wed, 21 Jun 2023 06:42:25 +0000 (08:42 +0200)
bufferization.to_memref ops are allowed in One-Shot Bufferize, but they are treated conservatively: in the absence of a memref analysis, we have to assume that the result buffer is read and written.

Note: to_memref cannot introduce any future aliases that would have to be considered during One-Shot Bufferize, because only to_tensor ops with the `restrict` attribute are supported. Such tensors are guaranteed to not alias with any other buffer after bufferization.

Differential Revision: https://reviews.llvm.org/D153365

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir

index 534bc26..726b6b5 100644 (file)
@@ -413,12 +413,6 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
       return true;
     }
 
-    bool mustBufferizeInPlace(OpOperand &opOperand,
-                              const AnalysisState &state) const {
-      // ToMemrefOps always bufferize inplace.
-      return true;
-    }
-
     AliasingOpResultList getAliasingOpResults(
         OpOperand &opOperand, const AnalysisState &state) const {
       return {};
index 34959aa..4a5052d 100644 (file)
@@ -938,13 +938,6 @@ static LogicalResult checkAliasInfoConsistency(Operation *op,
     if (!options.isOpAllowed(op.getOperation()))
       return WalkResult::advance();
 
-    // Input IR may not contain any ToMemrefOps. These are not supported because
-    // the analysis cannot follow the data flow through memrefs.
-    if (isa<ToMemrefOp>(op.getOperation())) {
-      op->emitError("to_memref ops are not supported by One-Shot Analysis");
-      return WalkResult::interrupt();
-    }
-
     // Input IR may not contain any ToTensorOps without the "restrict"
     // attribute. Such tensors may alias any other tensor, which is currently
     // not handled in the analysis.
index a2d47f0..071ec6f 100644 (file)
@@ -231,23 +231,6 @@ func.func @main() -> tensor<4xi32> {
 
 // -----
 
-func.func @to_memref_op_unsupported(
-    %t1: tensor<?xf32> {bufferization.writable = true}, %idx1: index,
-    %idx2: index, %idx3: index, %v1: vector<5xf32>) -> (vector<5xf32>, vector<5xf32>) {
-
-  // expected-error @+1 {{to_memref ops are not supported by One-Shot Analysis}}
-  %0 = bufferization.to_memref %t1 : memref<?xf32>
-
-  // Read from both.
-  %cst = arith.constant 0.0 : f32
-  %r1 = vector.transfer_read %t1[%idx3], %cst : tensor<?xf32>, vector<5xf32>
-  %r2 = vector.transfer_read %0[%idx3], %cst : memref<?xf32>, vector<5xf32>
-
-  return %r1, %r2 : vector<5xf32>, vector<5xf32>
-}
-
-// -----
-
 func.func @to_tensor_op_unsupported(%m: memref<?xf32>, %idx: index) -> (f32) {
   // expected-error @+1 {{to_tensor ops without `restrict` are not supported by One-Shot Analysis}}
   %0 = bufferization.to_tensor %m : memref<?xf32>
index 4103a4c..c2f88c6 100644 (file)
@@ -636,3 +636,26 @@ func.func @call_llvm_func() {
   llvm.call @llvm_func() : () -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @to_memref_op_unsupported(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32,
+func.func @to_memref_op_unsupported(
+    %t1: tensor<?xf32> {bufferization.writable = true}, %idx1: index,
+    %idx2: index, %idx3: index, %v1: vector<5xf32>) -> (vector<5xf32>) {
+
+  // Insert a copy because we cannot analyze what happens with the result of a
+  // to_memref op.
+  // CHECK: %[[alloc:.*]] = memref.alloc
+  // CHECK: memref.copy %[[arg0]], %[[alloc]]
+  %0 = bufferization.to_memref %t1 : memref<?xf32>
+  // CHECK: "test.foo"(%[[alloc]])
+  "test.foo"(%0) : (memref<?xf32>) -> ()
+
+  // CHECK: vector.transfer_read %[[arg0]]
+  %cst = arith.constant 0.0 : f32
+  %r1 = vector.transfer_read %t1[%idx3], %cst : tensor<?xf32>, vector<5xf32>
+
+  return %r1 : vector<5xf32>
+}