From 9a3d60e0d30a3a659f4040e3c424d82115c4219e Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Thu, 23 Jun 2022 12:17:01 -0700 Subject: [PATCH] [mlir][bufferization][sparse] put restriction on sparse tensor allocation Putting some direct use restrictions on tensor allocations in the sparse case enables the use of simplifying assumptions in the bufferization analysis. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D128463 --- .../Dialect/Bufferization/IR/BufferizationOps.td | 6 ++++- .../Dialect/Bufferization/IR/BufferizationOps.cpp | 12 ++++++++++ mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt | 1 + mlir/test/Dialect/Bufferization/invalid.mlir | 26 +++++++++++++++++++++- mlir/test/Dialect/SparseTensor/conversion.mlir | 8 ++++--- utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 + 6 files changed, 49 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td index db4ed5e..6b92904 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -49,13 +49,17 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor", Both dense and sparse tensor types are supported. The result of a `bufferization.alloc_tensor` is a tensor value that can be used like any other tensor value. In practice, it is often used as the "out" operand of - another op. E.g.: + another op. Sparse tensor allocations should always be used in a local + construction operation and never escape the function boundary directly. 
+ + Example: ```mlir %c = bufferization.alloc_tensor [%d1, %d2] : tensor<?x?xf32> %0 = linalg.matmul ins(%a, %b: tensor<?x?xf32>, tensor<?x?xf32>) outs(%c: tensor<?x?xf32>) -> tensor<?x?xf32> + return %0 : tensor<?x?xf32> ``` }]; diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index c76245f..37ff8e9 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -9,8 +9,10 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" @@ -250,6 +252,16 @@ LogicalResult AllocTensorOp::verify() { << getType().getNumDynamicDims() << " dynamic sizes"; if (getCopy() && getCopy().getType() != getType()) return emitError("expected that `copy` and return type match"); + + // For sparse tensor allocation, we require that none of its + // uses escapes the function boundary directly.
+ if (sparse_tensor::getSparseTensorEncoding(getType())) { + for (auto &use : getOperation()->getUses()) + if (isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>( + use.getOwner())) + return emitError("sparse tensor allocation should not escape function"); + } + return success(); } diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt index 2efaaa5..2e2f9fd 100644 --- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt @@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect MLIRDialect MLIRFuncDialect MLIRIR + MLIRSparseTensorDialect MLIRTensorDialect MLIRMemRefDialect ) diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir index 02ee6d4..28e7be0 100644 --- a/mlir/test/Dialect/Bufferization/invalid.mlir +++ b/mlir/test/Dialect/Bufferization/invalid.mlir @@ -54,4 +54,28 @@ func.func @escape_attr_non_bufferizable(%m0: memref<?xf32>) { // expected-error @+1{{'bufferization.escape' only valid on bufferizable ops}} %0 = memref.cast %m0 {bufferization.escape = [true]} : memref<?xf32> to memref<10xf32> return -} \ No newline at end of file +} + +// ----- + +#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> + +func.func @sparse_alloc_direct_return() -> tensor<20x40xf32, #DCSR> { + // expected-error @+1{{sparse tensor allocation should not escape function}} + %0 = bufferization.alloc_tensor() : tensor<20x40xf32, #DCSR> + return %0 : tensor<20x40xf32, #DCSR> +} + +// ----- + +#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> + +func.func private @foo(tensor<20x40xf32, #DCSR>) -> () + +func.func @sparse_alloc_call() { + // expected-error @+1{{sparse tensor allocation should not escape function}} + %0 = bufferization.alloc_tensor() : tensor<20x40xf32, #DCSR> + call @foo(%0) : (tensor<20x40xf32, #DCSR>) -> () + return +} + diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir
b/mlir/test/Dialect/SparseTensor/conversion.mlir index d9b3ed1..950452eb 100644 --- a/mlir/test/Dialect/SparseTensor/conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion.mlir @@ -136,7 +136,8 @@ func.func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> func.func @sparse_init(%arg0: index, %arg1: index) -> tensor<?x?xf64, #SparseMatrix> { %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #SparseMatrix> - return %0 : tensor<?x?xf64, #SparseMatrix> + %1 = sparse_tensor.load %0 : tensor<?x?xf64, #SparseMatrix> + return %1 : tensor<?x?xf64, #SparseMatrix> } // CHECK-LABEL: func @sparse_release( @@ -580,6 +581,7 @@ func.func @sparse_out2(%arg0: tensor<?x?x?xf32, #SparseTensor>, %arg1: !llvm.ptr<i8> func.func @sparse_and_dense_init(%arg0: index, %arg1: index) -> (tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>) { %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #SparseMatrix> - %1 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64> - return %0, %1 : tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64> + %1 = sparse_tensor.load %0 : tensor<?x?xf64, #SparseMatrix> + %2 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64> + return %1, %2 : tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64> } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index dc036ce..80b4b8a 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8958,6 +8958,7 @@ cc_library( ":IR", ":InferTypeOpInterface", ":MemRefDialect", + ":SparseTensorDialect", ":Support", ":TensorDialect", "//llvm:Support", -- 2.7.4