[mlir][bufferization] Better handling of unranked tensors in resolveTensorOpOperandCo...
authorMatthias Springer <springerm@google.com>
Mon, 30 Jan 2023 09:19:32 +0000 (10:19 +0100)
committerMatthias Springer <springerm@google.com>
Mon, 30 Jan 2023 09:20:10 +0000 (10:20 +0100)
Unranked tensors can currently not be copied. They are forced to always bufferize in-place. There is typically some other OpOperand that can bufferize out-of-place instead if needed.

Note: There is IR that cannot be bufferized with One-Shot Bufferize at the moment (see invalid test case). But it is unclear if we need to support such cases. We do not have a use case at the moment. This restriction could be loosened in the future if needed.

This change improves error handling when bufferizing IR where an unranked tensor would be copied. It also disables an optimization where an OpResult was copied instead of an OpOperand in case the OpResult is an unranked tensor (Github #60187).

Differential Revision: https://reviews.llvm.org/D142331

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir

index 488165b..95af51a 100644 (file)
@@ -156,6 +156,10 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           sets and inplace attributes will be set up accordingly before making
           any other bufferization decisions. This method will never be called on
           OpOperands that do not have a tensor type.
+
+          Note: Unranked tensor OpOperands always bufferize in-place. This could
+          be extended in the future. Unranked tensors are used with external
+          functions only.
         }],
         /*retType=*/"bool",
         /*methodName=*/"mustBufferizeInPlace",
@@ -163,7 +167,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
                       "const ::mlir::bufferization::AnalysisState &":$state),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
-          return false;
+          return opOperand.get().getType().isa<UnrankedTensorType>();
         }]
       >,
       InterfaceMethod<
index dc43017..16eedea 100644 (file)
@@ -107,6 +107,10 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
     tensor = shapedValue;
   } else if (shapedValue.getType().isa<MemRefType>()) {
     tensor = b.create<ToTensorOp>(loc, shapedValue);
+  } else if (shapedValue.getType().isa<UnrankedTensorType>() ||
+             shapedValue.getType().isa<UnrankedMemRefType>()) {
+    return getOwnerOfValue(shapedValue)
+        ->emitError("copying of unranked tensors is not implemented");
   } else {
     llvm_unreachable("expected RankedTensorType or MemRefType");
   }
@@ -175,7 +179,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
     if (state.isInPlace(opOperand))
       continue;
     if (operandType.isa<UnrankedTensorType>())
-      return op->emitError("copies of unranked tensors are not supported");
+      return op->emitError("copying of unranked tensors is not implemented");
 
     SmallVector<OpResult> aliasingOpResults =
         state.getAliasingOpResult(opOperand);
@@ -189,11 +193,14 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
 
     if (aliasingOpResults.size() == 1 &&
         !state.bufferizesToMemoryWrite(opOperand) &&
-        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
+        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1 &&
+        !aliasingOpResults.front().getType().isa<UnrankedTensorType>()) {
       // The op itself does not write but may create exactly one alias. Instead
       // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
       // be smaller than the OpOperand (e.g., in the case of an extract_slice,
-      // where the result is usually a smaller part of the source).
+      // where the result is usually a smaller part of the source). Do not apply
+      // this optimization if the OpResult is an unranked tensor (because those
+      // cannot be copied at the moment).
       outOfPlaceOpResults.push_back(aliasingOpResults.front());
       if (!state.canOmitTensorCopy(opOperand))
         copiedOpResults.insert(aliasingOpResults.front());
index 9069010..caff954 100644 (file)
@@ -1283,16 +1283,16 @@ func.func @write_to_same_alloc_tensor_out_of_place(
 
 // -----
 
-// CHECK-LABEL: func.func private @ext_func(tensor<*xf32> {bufferization.access = "read-write"})
-func.func private @ext_func(%t: tensor<*xf32>)
+// CHECK-LABEL: func.func private @ext_func(tensor<?xf32> {bufferization.access = "read-write"})
+func.func private @ext_func(%t: tensor<?xf32>)
 
 // CHECK: func.func @private_func_read_write(%{{.*}}: tensor<5xf32> {bufferization.access = "read"})
 func.func @private_func_read_write(%t: tensor<5xf32>) -> f32 {
   %c0 = arith.constant 0 : index
   // Bufferizes out-of-place because `ext_func` may modify the buffer.
   // CHECK: tensor.cast {{.*}} {__inplace_operands_attr__ = ["false"]}
-  %0 = tensor.cast %t : tensor<5xf32> to tensor<*xf32>
-  func.call @ext_func(%0) : (tensor<*xf32>) -> ()
+  %0 = tensor.cast %t : tensor<5xf32> to tensor<?xf32>
+  func.call @ext_func(%0) : (tensor<?xf32>) -> ()
   %1 = tensor.extract %t[%c0] : tensor<5xf32>
   return %1 : f32
 }
index da0fe74..2c0c8d7 100644 (file)
@@ -315,3 +315,16 @@ func.func @yield_alloc_dominance_test_2(%cst : f32, %idx : index,
   %r = tensor.extract %2[%idx2] : tensor<?xf32>
   return %r : f32
 }
+
+// -----
+
+func.func @copy_of_unranked_tensor(%t: tensor<*xf32>) -> tensor<*xf32> {
+  // Unranked tensor OpOperands always bufferize in-place. With this limitation,
+  // there is no way to bufferize this IR correctly.
+  // expected-error @+1 {{input IR has RaW conflict}}
+  func.call @maybe_writing_func(%t) : (tensor<*xf32>) -> ()
+  return %t : tensor<*xf32>
+}
+
+// This function may write to buffer(%ptr).
+func.func private @maybe_writing_func(%ptr : tensor<*xf32>)
index d25a374..1980991 100644 (file)
@@ -607,3 +607,21 @@ func.func @transfer_read(
 //       CHECK: return %[[RES]] : vector<4xf32>
   return %0 : vector<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @main(
+func.func @main() {
+  // CHECK: %[[const:.*]] = memref.get_global
+  %t = arith.constant dense<[1.0, 2.0, 3.0]> : tensor<3xf32>
+  // CHECK: %[[alloc:.*]] = memref.alloc
+  // CHECK: memref.copy %[[const]], %[[alloc]]
+  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]] : memref<3xf32> to memref<*xf32>
+  %unranked = tensor.cast %t : tensor<3xf32> to tensor<*xf32>
+  // CHECK: call @maybe_writing_func(%[[casted]])
+  func.call @maybe_writing_func(%unranked) : (tensor<*xf32>) -> ()
+  return
+}
+
+// This function may write to buffer(%ptr).
+func.func private @maybe_writing_func(%ptr : tensor<*xf32>)