[mlir][sparse] add create-sparse-deallocs options to match the create-deallocs in...
authorPeiming Liu <peiming@google.com>
Mon, 27 Mar 2023 22:56:52 +0000 (22:56 +0000)
committerPeiming Liu <peiming@google.com>
Mon, 27 Mar 2023 23:18:32 +0000 (23:18 +0000)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D147010

mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
mlir/test/Dialect/SparseTensor/codegen_sparse_dealloc.mlir [new file with mode: 0644]

index 09758c9..f45b64d 100644 (file)
@@ -73,6 +73,16 @@ struct SparseCompilerOptions
       *this, "enable-buffer-initialization",
       desc("Enable zero-initialization of memory buffers"), init(false)};
 
+  PassOptions::Option<bool> createSparseDeallocs{
+      *this, "create-sparse-deallocs",
+      desc("Specify if the temporary sparse buffer created by the sparse "
+           "compiler should be deallocated. For compatibility with core "
+           "bufferization passes. "
+           "It only takes effect when enable-runtime-library=false, otherwise "
+           "the memory storage for sparse tensors are managed by the runtime "
+           "library. See also create-deallocs for BufferizationOption."),
+      init(true)};
+
   PassOptions::Option<int32_t> vectorLength{
       *this, "vl", desc("Set the vector length (0 disables vectorization)"),
       init(0)};
index 734ffa0..8ef6381 100644 (file)
@@ -132,11 +132,13 @@ public:
 /// Sets up sparse tensor conversion rules.
 void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                          RewritePatternSet &patterns,
+                                         bool createSparseDeallocs,
                                          bool enableBufferInitialization);
 
 std::unique_ptr<Pass> createSparseTensorCodegenPass();
 std::unique_ptr<Pass>
-createSparseTensorCodegenPass(bool enableBufferInitialization);
+createSparseTensorCodegenPass(bool createSparseDeallocs,
+                              bool enableBufferInitialization);
 
 //===----------------------------------------------------------------------===//
 // The PreSparsificationRewriting pass.
@@ -180,8 +182,9 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
     const bufferization::OneShotBufferizationOptions &bufferizationOptions,
     const SparsificationOptions &sparsificationOptions,
     const SparseTensorConversionOptions &sparseTensorConversionOptions,
-    bool enableRuntimeLibrary, bool enableBufferInitialization,
-    unsigned vectorLength, bool enableVLAVectorization, bool enableSIMDIndex32);
+    bool createSparseDeallocs, bool enableRuntimeLibrary,
+    bool enableBufferInitialization, unsigned vectorLength,
+    bool enableVLAVectorization, bool enableSIMDIndex32);
 
 void populateSparseBufferRewriting(RewritePatternSet &patterns,
                                    bool enableBufferInitialization);
index 0df55d9..faf50fe 100644 (file)
@@ -220,6 +220,13 @@ def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> {
   let options = [
     Option<"enableBufferInitialization", "enable-buffer-initialization", "bool",
            "false", "Enable zero-initialization of the memory buffers">,
+    Option<"createSparseDeallocs", "create-sparse-deallocs", "bool",
+           "true", "Specify if the temporary sparse buffer created by the sparse "
+                   "compiler should be deallocated. For compatibility with core "
+                   "bufferization passes. "
+                   "It only takes effect when enable-runtime-library=false, otherwise "
+                   "the memory storage for sparse tensors are managed by the runtime "
+                   "library. See also create-deallocs for BufferizationOption.">,
   ];
 }
 
index 44ccf05..47c1601 100644 (file)
@@ -56,8 +56,8 @@ void mlir::sparse_tensor::buildSparseCompiler(
   pm.addPass(createSparsificationAndBufferizationPass(
       getBufferizationOptions(options.testBufferizationAnalysisOnly),
       options.sparsificationOptions(), options.sparseTensorConversionOptions(),
-      options.enableRuntimeLibrary, options.enableBufferInitialization,
-      options.vectorLength,
+      options.createSparseDeallocs, options.enableRuntimeLibrary,
+      options.enableBufferInitialization, options.vectorLength,
       /*enableVLAVectorization=*/options.armSVE,
       /*enableSIMDIndex32=*/options.force32BitVectorIndices));
   if (options.testBufferizationAnalysisOnly)
index ae6e40f..cc7524c 100644 (file)
@@ -780,6 +780,11 @@ class SparseTensorDeallocConverter
     : public OpConversionPattern<bufferization::DeallocTensorOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
+  SparseTensorDeallocConverter(TypeConverter &typeConverter,
+                               MLIRContext *context, bool createDeallocs)
+      : OpConversionPattern(typeConverter, context),
+        createDeallocs(createDeallocs) {}
+
   LogicalResult
   matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -787,16 +792,22 @@ public:
     if (!enc)
       return failure();
 
-    // Replace the sparse tensor deallocation with field deallocations.
-    Location loc = op.getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
-    for (auto input : desc.getMemRefFields())
-      // Deallocate every buffer used to store the sparse tensor handler.
-      rewriter.create<memref::DeallocOp>(loc, input);
-
+    // If user requests not to deallocate sparse tensors, simply erase the
+    // operation.
+    if (createDeallocs) {
+      // Replace the sparse tensor deallocation with field deallocations.
+      Location loc = op.getLoc();
+      auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+      for (auto input : desc.getMemRefFields())
+        // Deallocate every buffer used to store the sparse tensor handler.
+        rewriter.create<memref::DeallocOp>(loc, input);
+    }
     rewriter.eraseOp(op);
     return success();
   }
+
+private:
+  bool createDeallocs;
 };
 
 /// Sparse codegen rule for tensor rematerialization.
@@ -1492,13 +1503,12 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
 /// the sparsification of linear algebra operations.
 void mlir::populateSparseTensorCodegenPatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    bool enableBufferInitialization) {
+    bool createSparseDeallocs, bool enableBufferInitialization) {
   patterns.add<SparsePackOpConverter, SparseUnpackOpConverter,
                SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
-               SparseCastConverter, SparseTensorDeallocConverter,
-               SparseExtractSliceConverter, SparseTensorLoadConverter,
-               SparseExpandConverter, SparseCompressConverter,
-               SparseInsertConverter,
+               SparseCastConverter, SparseExtractSliceConverter,
+               SparseTensorLoadConverter, SparseExpandConverter,
+               SparseCompressConverter, SparseInsertConverter,
                SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                             StorageSpecifierKind::DimOffset>,
                SparseSliceGetterOpConverter<ToSliceStrideOp,
@@ -1508,6 +1518,8 @@ void mlir::populateSparseTensorCodegenPatterns(
                SparseConvertConverter, SparseNewOpConverter,
                SparseNumberOfEntriesConverter>(typeConverter,
                                                patterns.getContext());
+  patterns.add<SparseTensorDeallocConverter>(
+      typeConverter, patterns.getContext(), createSparseDeallocs);
   patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
                                            enableBufferInitialization);
 }
index 2163dcb..f39ead1 100644 (file)
@@ -181,7 +181,8 @@ struct SparseTensorCodegenPass
 
   SparseTensorCodegenPass() = default;
   SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
-  SparseTensorCodegenPass(bool enableInit) {
+  SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
+    createSparseDeallocs = createDeallocs;
     enableBufferInitialization = enableInit;
   }
 
@@ -232,8 +233,8 @@ struct SparseTensorCodegenPass
                                                                    converter);
     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                          target);
-    populateSparseTensorCodegenPatterns(converter, patterns,
-                                        enableBufferInitialization);
+    populateSparseTensorCodegenPatterns(
+        converter, patterns, createSparseDeallocs, enableBufferInitialization);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();
@@ -378,8 +379,10 @@ std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
 }
 
 std::unique_ptr<Pass>
-mlir::createSparseTensorCodegenPass(bool enableBufferInitialization) {
-  return std::make_unique<SparseTensorCodegenPass>(enableBufferInitialization);
+mlir::createSparseTensorCodegenPass(bool createSparseDeallocs,
+                                    bool enableBufferInitialization) {
+  return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
+                                                   enableBufferInitialization);
 }
 
 std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
index 48ba2dd..be1846e 100644 (file)
@@ -56,12 +56,13 @@ public:
       const bufferization::OneShotBufferizationOptions &bufferizationOptions,
       const SparsificationOptions &sparsificationOptions,
       const SparseTensorConversionOptions &sparseTensorConversionOptions,
-      bool enableRuntimeLibrary, bool enableBufferInitialization,
-      unsigned vectorLength, bool enableVLAVectorization,
-      bool enableSIMDIndex32)
+      bool createSparseDeallocs, bool enableRuntimeLibrary,
+      bool enableBufferInitialization, unsigned vectorLength,
+      bool enableVLAVectorization, bool enableSIMDIndex32)
       : bufferizationOptions(bufferizationOptions),
         sparsificationOptions(sparsificationOptions),
         sparseTensorConversionOptions(sparseTensorConversionOptions),
+        createSparseDeallocs(createSparseDeallocs),
         enableRuntimeLibrary(enableRuntimeLibrary),
         enableBufferInitialization(enableBufferInitialization),
         vectorLength(vectorLength),
@@ -147,7 +148,8 @@ public:
         pm.addPass(
             createSparseTensorConversionPass(sparseTensorConversionOptions));
       } else {
-        pm.addPass(createSparseTensorCodegenPass(enableBufferInitialization));
+        pm.addPass(createSparseTensorCodegenPass(createSparseDeallocs,
+                                                 enableBufferInitialization));
         pm.addPass(createSparseBufferRewritePass(enableBufferInitialization));
         pm.addPass(createStorageSpecifierToLLVMPass());
       }
@@ -164,6 +166,7 @@ private:
   bufferization::OneShotBufferizationOptions bufferizationOptions;
   SparsificationOptions sparsificationOptions;
   SparseTensorConversionOptions sparseTensorConversionOptions;
+  bool createSparseDeallocs;
   bool enableRuntimeLibrary;
   bool enableBufferInitialization;
   unsigned vectorLength;
@@ -178,13 +181,13 @@ std::unique_ptr<Pass> mlir::createSparsificationAndBufferizationPass(
     const bufferization::OneShotBufferizationOptions &bufferizationOptions,
     const SparsificationOptions &sparsificationOptions,
     const SparseTensorConversionOptions &sparseTensorConversionOptions,
-    bool enableRuntimeLibrary, bool enableBufferInitialization,
-    unsigned vectorLength, bool enableVLAVectorization,
-    bool enableSIMDIndex32) {
+    bool createSparseDeallocs, bool enableRuntimeLibrary,
+    bool enableBufferInitialization, unsigned vectorLength,
+    bool enableVLAVectorization, bool enableSIMDIndex32) {
   return std::make_unique<
       mlir::sparse_tensor::SparsificationAndBufferizationPass>(
       bufferizationOptions, sparsificationOptions,
-      sparseTensorConversionOptions, enableRuntimeLibrary,
+      sparseTensorConversionOptions, createSparseDeallocs, enableRuntimeLibrary,
       enableBufferInitialization, vectorLength, enableVLAVectorization,
       enableSIMDIndex32);
 }
diff --git a/mlir/test/Dialect/SparseTensor/codegen_sparse_dealloc.mlir b/mlir/test/Dialect/SparseTensor/codegen_sparse_dealloc.mlir
new file mode 100644 (file)
index 0000000..65ed5dd
--- /dev/null
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false" \
+// RUN:    --sparse-tensor-codegen=create-sparse-deallocs=false \
+// RUN:    --canonicalize --cse | FileCheck %s -check-prefix=CHECK-NO-DEALLOC
+
+// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false" \
+// RUN:    --sparse-tensor-codegen=create-sparse-deallocs=true \
+// RUN:    --canonicalize --cse | FileCheck %s -check-prefix=CHECK-DEALLOC
+
+#CSR = #sparse_tensor.encoding<{ dimLevelType = ["dense", "compressed"]}>
+#CSC = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "compressed"],
+  dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
+
+//
+// No memref.dealloc is user-requested so
+// CHECK-NO-DEALLOC-LABEL: @sparse_convert_permuted
+// CHECK-NO-DEALLOC-NOT: memref.dealloc
+//
+// Otherwise memref.dealloc is created to free temporary sparse buffers.
+// CHECK-DEALLOC-LABEL: @sparse_convert_permuted
+// CHECK-DEALLOC: memref.dealloc
+//
+func.func @sparse_convert_permuted(%arg0: tensor<?x?xf32, #CSR>) -> tensor<?x?xf32, #CSC> {
+  %0 = sparse_tensor.convert %arg0 : tensor<?x?xf32, #CSR> to tensor<?x?xf32, #CSC>
+  return %0 : tensor<?x?xf32, #CSC>
+}