[mlir][bufferization] Add restrict and writable attrs to to_tensor
authorMatthias Springer <springerm@google.com>
Wed, 15 Feb 2023 08:51:42 +0000 (09:51 +0100)
committerMatthias Springer <springerm@google.com>
Wed, 15 Feb 2023 09:04:54 +0000 (10:04 +0100)
`restrict` is similar to the C++ restrict keyword. Results of `to_tensor` that have the `restrict` attribute are guaranteed to not alias any other `to_tensor` result (after bufferization).

Note: Since `to_memref` ops are not supported by One-Shot Bufferize and all bufferizable ops follow DPS rules (i.e., the buffer of the result is the buffer of an operand or an alias thereof), the buffer of a `to_tensor` op that has the `restrict` attribute is always an entirely "new" buffer that is not aliasing with the future buffer of any tensor value in the entire program. This makes such `to_tensor` ops "safe" from a bufferization perspective; they cannot cause RaW conflicts.

Differential Revision: https://reviews.llvm.org/D144021

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
mlir/test/Dialect/Bufferization/ops.mlir

index 0982700..d4a8161 100644 (file)
@@ -272,9 +272,8 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
   ]> {
   let summary = "memref to tensor operation";
   let description = [{
-    Create a tensor from a `memref`, making an independent copy of the element
-    data. The result value is a tensor whose shape and element type match the
-    memref operand.
+    An operation that creates a tensor from a `memref`. The result value is a
+    tensor whose shape and element type match the memref operand.
 
     The opposite of this op is `to_memref`. Together, these two ops are
     useful for source/target materializations when doing type conversions
@@ -284,15 +283,39 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     ```mlir
     // Produces a value of tensor<4x?xf32> type.
-    %12 = bufferization.to_tensor %10 : memref<4x?xf32, #layout, memspace0>
+    %t = bufferization.to_tensor %m : memref<4x?xf32, #layout, 0>
     ```
 
-    If tensor load is used in the bufferization steps, mutating the source
-    buffer after loading leads to undefined behavior.
+    If the `writable` unit attribute is set, the produced tensor is considered
+    "writable" during bufferization. Otherwise, every OpOperand that bufferizes
+    to a write to the future buffer of the resulting tensor (or an alias
+    thereof) will bufferize out-of-place to prevent emitting any writes to
+    `memref` during bufferization.
+
+    If the given memref does not alias with any other memref passed to another
+    `to_tensor` op, the `restrict` unit attribute can be set. Only such
+    operations are supported by One-Shot Bufferize. (Otherwise, potential memref
+    aliasing relationships would have to be captured in One-Shot Bufferize.)
+
+    Example:
+
+    ```
+    %t = bufferization.to_tensor %m restrict writable : memref<4xf32>
+
+    // %t is writable, so the tensor.insert may bufferize in-place in the
+    // absence of other conflicts.
+    %r = tensor.insert %f into %t[%idx] : tensor<4xf32>
+    ```
+
+    `to_tensor` ops are not bufferized. They are expected to fold away after
+    bufferization. If there are non-bufferizable ops in the IR and
+    `allowUnknownOps` is set, they may be part of the resulting IR and not fold
+    away. However, such IR is no longer bufferizable with One-Shot Bufferize.
   }];
 
   let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
-                       "the reference to load from", [MemRead]>:$memref);
+                       "the reference to load from", [MemRead]>:$memref,
+                       UnitAttr:$restrict, UnitAttr:$writable);
   let results = (outs AnyTensor:$result);
 
   let extraClassDeclaration = [{
@@ -308,30 +331,13 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     // BufferizableOpInterface implementation
     //===------------------------------------------------------------------===//
 
-    // ToTensorOp conceptually loads a tensor from a memory location. The
-    // One-Shot analysis has no information about the memref that is loaded from
-    // by ToTensorOp. We have to assume that the loaded tensor may after
-    // bufferization potentially alias with any other bufferized tensor. Since
-    // ToTensorOp and ToMemrefOp have no aliasing OpOperand/OpResult pairs, this
-    // cannot be encoded directly in the analysis. However, declaring ToTensorOp
-    // results as not writable enforces a buffer copy and has the same effect.
-
     LogicalResult bufferize(RewriterBase &rewriter,
                             const BufferizationOptions &options) const {
-      // to_tensor cannot be bufferized. However, other ops that are using
-      // to_tensor's result will eventually be bufferized. At that point, they
-      // will start using to_tensor's memref operand. Once all users of
-      // to_tensor are bufferized, the op will not have any users anymore and
-      // DCE away. In case of partial bufferization, to_memref(to_tensor(x))
-      // constructs may be left over. These are folded by the canonicalizer or
-      // FinalizingBufferize.
+      // to_tensor/to_memref pairs fold away after bufferization.
       return success();
     }
 
-    bool isWritable(Value value, const AnalysisState &state) const {
-      // It is unknown whether the memref operand is writable or not.
-      return false;
-    }
+    bool isWritable(Value value, const AnalysisState &state);
 
     FailureOr<BaseMemRefType> getBufferType(
         Value value, const BufferizationOptions &options,
@@ -340,7 +346,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     }
   }];
 
-  let assemblyFormat = "$memref attr-dict `:` type($memref)";
+  let assemblyFormat = [{
+    $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
+      `:` type($memref)
+  }];
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
@@ -362,19 +371,19 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
   ]> {
   let summary = "tensor to memref cast operation";
   let description = [{
-    Casts a tensor to a memref.
+    An operation that returns the future buffer of a `tensor`.
 
     ```mlir
-    // Result type is memref<4x?xf32, #layout, 42>
-    %12 = bufferization.to_memref %10 : memref<4x?xf32, #layout, 42>
+    // Result type is memref<4x?xf32, #layout, 0>
+    %m = bufferization.to_memref %t : memref<4x?xf32, #layout, 0>
     ```
 
-    Note, that mutating the result of the `to_memref` operation leads to
-    undefined behavior.
-
     This operation is a specialized variant of the built-in
-    `unrealized_conversion_cast` and is intended for use in the context of
-    gradual bufferization.
+    `unrealized_conversion_cast` and is used to make sure that the IR stays
+    valid at any point during the bufferization.
+
+    IR that contains `to_memref` ops cannot be bufferized with One-Shot
+    Bufferize.
   }];
 
   let arguments = (ins AnyTensor:$tensor);
index 232e7b6..66e1807 100644 (file)
@@ -561,6 +561,10 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
 // ToTensorOp
 //===----------------------------------------------------------------------===//
 
+bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
+  return getWritable();
+}
+
 OpFoldResult ToTensorOp::fold(FoldAdaptor) {
   if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
     // Approximate alias analysis by conservatively folding only when no there
index a6730cb..db7d453 100644 (file)
@@ -942,10 +942,21 @@ static LogicalResult checkAliasInfoConsistency(Operation *op,
     // Input IR may not contain any ToMemrefOps. These are not supported because
     // the analysis cannot follow the data flow through memrefs.
     if (isa<ToMemrefOp>(op.getOperation())) {
-      op->emitError("to_memref ops not supported during One-Shot Analysis");
+      op->emitError("to_memref ops are not supported by One-Shot Analysis");
       return WalkResult::interrupt();
     }
 
+    // Input IR may not contain any ToTensorOps without the "restrict"
+    // attribute. Such tensors may alias any other tensor, which is currently
+    // not handled in the analysis.
+    if (auto toTensorOp = dyn_cast<ToTensorOp>(op.getOperation())) {
+      if (!toTensorOp.getRestrict()) {
+        op->emitError("to_tensor ops without `restrict` are not supported by "
+                      "One-Shot Analysis");
+        return WalkResult::interrupt();
+      }
+    }
+
     for (OpOperand &opOperand : op->getOpOperands()) {
       if (opOperand.get().getType().isa<TensorType>()) {
         if (wouldCreateReadAfterWriteInterference(
index caff954..5284f57 100644 (file)
@@ -1057,9 +1057,9 @@ func.func @main_func(%A : tensor<?xf32> {bufferization.writable = true},
 
 // CHECK-LABEL: func @to_tensor_op_not_writable
 func.func @to_tensor_op_not_writable(%m: memref<?xf32>, %v:  vector<5xf32>,
-                                %idx1: index, %idx2: index)
+                                     %idx1: index, %idx2: index)
     -> vector<10xf32> {
-  %0 = bufferization.to_tensor %m : memref<?xf32>
+  %0 = bufferization.to_tensor %m restrict : memref<?xf32>
 
   // Write to the tensor. Cannot be inplace due to tensor_load.
   //      CHECK: vector.transfer_write
index 759f4f3..189ef6b 100644 (file)
@@ -231,14 +231,11 @@ func.func @main() -> tensor<4xi32> {
 
 // -----
 
-func.func @to_memref_op_is_writing(
+func.func @to_memref_op_unsupported(
     %t1: tensor<?xf32> {bufferization.writable = true}, %idx1: index,
     %idx2: index, %idx3: index, %v1: vector<5xf32>) -> (vector<5xf32>, vector<5xf32>) {
-  // This is a RaW conflict because to_memref is an inplace write and %t1 is
-  // read further down. This will likely have to change with partial
-  // bufferization.
 
-  // expected-error @+1 {{to_memref ops not supported during One-Shot Analysis}}
+  // expected-error @+1 {{to_memref ops are not supported by One-Shot Analysis}}
   %0 = bufferization.to_memref %t1 : memref<?xf32>
 
   // Read from both.
@@ -251,6 +248,16 @@ func.func @to_memref_op_is_writing(
 
 // -----
 
+func.func @to_tensor_op_unsupported(%m: memref<?xf32>, %idx: index) -> (f32) {
+  // expected-error @+1 {{to_tensor ops without `restrict` are not supported by One-Shot Analysis}}
+  %0 = bufferization.to_tensor %m : memref<?xf32>
+
+  %1 = tensor.extract %0[%idx] : tensor<?xf32>
+  return %1 : f32
+}
+
+// -----
+
 // expected-error @+2 {{failed to bufferize op}}
 // expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}}
 func.func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
index 4cb25a6..63fad87 100644 (file)
@@ -23,7 +23,7 @@ func.func @test_to_memref(%arg0: tensor<?xi64>, %arg1: tensor<*xi64>)
 
 // CHECK-LABEL: func @test_to_tensor
 func.func @test_to_tensor(%buf : memref<2xf32>) -> tensor<2xf32> {
-  %tensor = bufferization.to_tensor %buf : memref<2xf32>
+  %tensor = bufferization.to_tensor %buf restrict writable : memref<2xf32>
   return %tensor : tensor<2xf32>
 }