[mlir][tensor][bufferize] Improve bufferization of DimOp/RankOp
authorMatthias Springer <springerm@google.com>
Wed, 14 Dec 2022 11:26:29 +0000 (12:26 +0100)
committerMatthias Springer <springerm@google.com>
Wed, 14 Dec 2022 11:47:46 +0000 (12:47 +0100)
The tensor operands do not bufferize to a memory read.

Differential Revision: https://reviews.llvm.org/D140007

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

index aa5a1d8..3c634e9 100644 (file)
@@ -205,7 +205,8 @@ struct DimOpInterface
                                                     tensor::DimOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
-    return true;
+    // The op reads the tensor's metadata but not its contents.
+    return false;
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -927,7 +928,8 @@ struct RankOpInterface
                                                     tensor::RankOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
-    return true;
+    // The op reads the tensor's metadata but not its contents.
+    return false;
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
index e1473c8..59fde56 100644 (file)
@@ -330,3 +330,20 @@ func.func @insert_slice_full_overwrite(%t: tensor<10xf32>, %b: tensor<10xf32>) -
   %2 = tensor.insert_slice %b into %t[0][10][1] : tensor<10xf32> into tensor<10xf32>
   return %2 : tensor<10xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @dim_not_reading(
+//  CHECK-SAME:     %[[t:.*]]: memref<?xf32
+func.func @dim_not_reading(%t: tensor<?xf32>, %f: f32, %pos: index) 
+    -> (tensor<?xf32>, index)
+{
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT: memref.alloc
+  // CHECK-NOT: memref.copy
+  //     CHECK: memref.store %{{.*}}, %[[t]]
+  %0 = tensor.insert %f into %t[%pos] : tensor<?xf32>
+  //     CHECK: memref.dim %[[t]]
+  %1 = tensor.dim %t, %c0 : tensor<?xf32>
+  return %0, %1 : tensor<?xf32>, index
+}