[mlir][bufferization][sparse] put restriction on sparse tensor allocation
authorAart Bik <ajcbik@google.com>
Thu, 23 Jun 2022 19:17:01 +0000 (12:17 -0700)
committerAart Bik <ajcbik@google.com>
Fri, 24 Jun 2022 17:58:43 +0000 (10:58 -0700)
Putting some direct use restrictions on tensor allocations in the
sparse case enables the use of simplifying assumptions in the
bufferization analysis.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D128463

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
mlir/test/Dialect/Bufferization/invalid.mlir
mlir/test/Dialect/SparseTensor/conversion.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index db4ed5e..6b92904 100644 (file)
@@ -49,13 +49,17 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     Both dense and sparse tensor types are supported. The result of a
     `bufferization.alloc_tensor` is a tensor value that can be used like any
     other tensor value. In practice, it is often used as the "out" operand of
-    another op. E.g.:
+    another op. Sparse tensor allocations should always be used in a local
+    construction operation and never escape the function boundary directly.
+
+    Example:
 
     ```mlir
     %c = bufferization.alloc_tensor [%d1, %d2] : tensor<?x?xf32, #SparseMatrix>
     %0 = linalg.matmul
       ins(%a, %b: tensor<?x?xf32, #SparseMatrix>, tensor<?x?xf32, #SparseMatrix>)
       outs(%c: tensor<?x?xf32, #SparseMatrix>) -> tensor<?x?xf32, #SparseMatrix>
+    return %0 : tensor<?x?xf32, #SparseMatrix>
     ```
   }];
 
index c76245f..37ff8e9 100644 (file)
@@ -9,8 +9,10 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Matchers.h"
 
@@ -250,6 +252,16 @@ LogicalResult AllocTensorOp::verify() {
            << getType().getNumDynamicDims() << " dynamic sizes";
   if (getCopy() && getCopy().getType() != getType())
     return emitError("expected that `copy` and return type match");
+
+  // For sparse tensor allocation, we require that none of its
+  // uses escapes the function boundary directly.
+  if (sparse_tensor::getSparseTensorEncoding(getType())) {
+    for (auto &use : getOperation()->getUses())
+      if (isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
+              use.getOwner()))
+        return emitError("sparse tensor allocation should not escape function");
+  }
+
   return success();
 }
 
index 2efaaa5..2e2f9fd 100644 (file)
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
   MLIRDialect
   MLIRFuncDialect
   MLIRIR
+  MLIRSparseTensorDialect
   MLIRTensorDialect
   MLIRMemRefDialect
   )
index 02ee6d4..28e7be0 100644 (file)
@@ -54,4 +54,28 @@ func.func @escape_attr_non_bufferizable(%m0: memref<?xf32>) {
   // expected-error @+1{{'bufferization.escape' only valid on bufferizable ops}}
   %0 = memref.cast %m0 {bufferization.escape = [true]} : memref<?xf32> to memref<10xf32>
   return
-}
\ No newline at end of file
+}
+
+// -----
+
+#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+
+func.func @sparse_alloc_direct_return() -> tensor<20x40xf32, #DCSR> {
+  // expected-error @+1{{sparse tensor allocation should not escape function}}
+  %0 = bufferization.alloc_tensor() : tensor<20x40xf32, #DCSR>
+  return %0 : tensor<20x40xf32, #DCSR>
+}
+
+// -----
+
+#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+
+func.func private @foo(tensor<20x40xf32, #DCSR>) -> ()
+
+func.func @sparse_alloc_call() {
+  // expected-error @+1{{sparse tensor allocation should not escape function}}
+  %0 = bufferization.alloc_tensor() : tensor<20x40xf32, #DCSR>
+  call @foo(%0) : (tensor<20x40xf32, #DCSR>) -> ()
+  return
+}
+
index d9b3ed1..950452e 100644 (file)
@@ -136,7 +136,8 @@ func.func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func.func @sparse_init(%arg0: index, %arg1: index) -> tensor<?x?xf64, #SparseMatrix> {
   %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #SparseMatrix>
-  return %0 : tensor<?x?xf64, #SparseMatrix>
+  %1 = sparse_tensor.load %0 : tensor<?x?xf64, #SparseMatrix>
+  return %1 : tensor<?x?xf64, #SparseMatrix>
 }
 
 // CHECK-LABEL: func @sparse_release(
@@ -580,6 +581,7 @@ func.func @sparse_out2(%arg0: tensor<?x?x?xf32, #SparseTensor>, %arg1: !llvm.ptr
 func.func @sparse_and_dense_init(%arg0: index, %arg1: index)
            -> (tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>) {
   %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #SparseMatrix>
-  %1 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64>
-  return %0, %1 : tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>
+  %1 = sparse_tensor.load %0 : tensor<?x?xf64, #SparseMatrix>
+  %2 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64>
+  return %1, %2 : tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>
 }
index dc036ce..80b4b8a 100644 (file)
@@ -8958,6 +8958,7 @@ cc_library(
         ":IR",
         ":InferTypeOpInterface",
         ":MemRefDialect",
+        ":SparseTensorDialect",
         ":Support",
         ":TensorDialect",
         "//llvm:Support",