[mlir][sparse] Add option enable-buffer-initialization to initialize the memory buffe...
authorbixia1 <bixia@google.com>
Mon, 7 Nov 2022 23:49:06 +0000 (15:49 -0800)
committerbixia1 <bixia@google.com>
Tue, 8 Nov 2022 17:54:33 +0000 (09:54 -0800)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D137592

mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

index 97030f5..5d51c5b 100644 (file)
@@ -63,6 +63,10 @@ struct SparseCompilerOptions
       *this, "test-bufferization-analysis-only",
       desc("Run only the inplacability analysis"), init(false)};
 
+  PassOptions::Option<bool> enableBufferInitialization{
+      *this, "enable-buffer-initialization",
+      desc("Enable zero-initialization of memory buffers"), init(false)};
+
   /// Projects out the options for `createSparsificationPass`.
   SparsificationOptions sparsificationOptions() const {
     return SparsificationOptions(parallelization);
index 5e301c4..8d704dc 100644 (file)
@@ -153,8 +153,10 @@ std::unique_ptr<Pass> createSparseTensorRewritePass(bool enableRT,
 std::unique_ptr<Pass> createDenseBufferizationPass(
     const bufferization::OneShotBufferizationOptions &options);
 
-void populateSparseBufferRewriting(RewritePatternSet &patterns);
-std::unique_ptr<Pass> createSparseBufferRewritePass();
+void populateSparseBufferRewriting(RewritePatternSet &patterns,
+                                   bool enableBufferInitialization);
+std::unique_ptr<Pass>
+createSparseBufferRewritePass(bool enableBufferInitialization = false);
 
 //===----------------------------------------------------------------------===//
 // Registration.
index 421706e..b7c4baa 100644 (file)
@@ -198,6 +198,10 @@ def SparseBufferRewrite : Pass<"sparse-buffer-rewrite", "ModuleOp"> {
     "scf::SCFDialect",
     "sparse_tensor::SparseTensorDialect",
   ];
+  let options = [
+    Option<"enableBufferInitialization", "enable-buffer-initialization", "bool",
+           "false", "Enable zero-initialization of the memory buffers">,
+  ];
 }
 
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
index a16340d..82d060e 100644 (file)
@@ -65,7 +65,7 @@ void mlir::sparse_tensor::buildSparseCompiler(
         options.sparseTensorConversionOptions()));
   else
     pm.addPass(createSparseTensorCodegenPass());
-  pm.addPass(createSparseBufferRewritePass());
+  pm.addPass(createSparseBufferRewritePass(options.enableBufferInitialization));
   pm.addPass(createDenseBufferizationPass(
       getBufferizationOptions(/*analysisOnly=*/false)));
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
index 0af92a6..d0564ca 100644 (file)
@@ -635,6 +635,8 @@ namespace {
 struct PushBackRewriter : OpRewritePattern<PushBackOp> {
 public:
   using OpRewritePattern<PushBackOp>::OpRewritePattern;
+  PushBackRewriter(MLIRContext *context, bool enableInit)
+      : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
   LogicalResult matchAndRewrite(PushBackOp op,
                                 PatternRewriter &rewriter) const override {
     // Rewrite push_back(buffer, value, n) to:
@@ -705,6 +707,16 @@ public:
 
       Value newBuffer =
           rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
+      if (enableBufferInitialization) {
+        Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
+        Value fillValue = rewriter.create<arith::ConstantOp>(
+            loc, value.getType(), rewriter.getZeroAttr(value.getType()));
+        Value subBuffer = rewriter.create<memref::SubViewOp>(
+            loc, newBuffer, /*offset=*/ValueRange{newSize},
+            /*size=*/ValueRange{fillSize},
+            /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
+        rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
+      }
       rewriter.create<scf::YieldOp>(loc, newBuffer);
 
       // False branch.
@@ -731,6 +743,9 @@ public:
     rewriter.replaceOp(op, buffer);
     return success();
   }
+
+private:
+  bool enableBufferInitialization;
 };
 
 /// Sparse rewriting rule for the sort operator.
@@ -777,6 +792,9 @@ public:
 // Methods that add patterns described in this file to a pattern list.
 //===---------------------------------------------------------------------===//
 
-void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns) {
-  patterns.add<PushBackRewriter, SortRewriter>(patterns.getContext());
+void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
+                                         bool enableBufferInitialization) {
+  patterns.add<PushBackRewriter>(patterns.getContext(),
+                                 enableBufferInitialization);
+  patterns.add<SortRewriter>(patterns.getContext());
 }
index 4a35a7f..8bc132f 100644 (file)
@@ -215,11 +215,14 @@ struct SparseBufferRewritePass
 
   SparseBufferRewritePass() = default;
   SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
+  SparseBufferRewritePass(bool enableInit) {
+    enableBufferInitialization = enableInit;
+  }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populateSparseBufferRewriting(patterns);
+    populateSparseBufferRewriting(patterns, enableBufferInitialization);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
@@ -279,6 +282,7 @@ std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
   return std::make_unique<SparseTensorCodegenPass>();
 }
 
-std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
-  return std::make_unique<SparseBufferRewritePass>();
+std::unique_ptr<Pass>
+mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
+  return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
 }