*this, "test-bufferization-analysis-only",
desc("Run only the inplacability analysis"), init(false)};
+ PassOptions::Option<bool> enableBufferInitialization{
+ *this, "enable-buffer-initialization",
+ desc("Enable zero-initialization of memory buffers"), init(false)};
+
/// Projects out the options for `createSparsificationPass`.
SparsificationOptions sparsificationOptions() const {
return SparsificationOptions(parallelization);
std::unique_ptr<Pass> createDenseBufferizationPass(
const bufferization::OneShotBufferizationOptions &options);
-void populateSparseBufferRewriting(RewritePatternSet &patterns);
-std::unique_ptr<Pass> createSparseBufferRewritePass();
+void populateSparseBufferRewriting(RewritePatternSet &patterns,
+ bool enableBufferInitialization);
+std::unique_ptr<Pass>
+createSparseBufferRewritePass(bool enableBufferInitialization = false);
//===----------------------------------------------------------------------===//
// Registration.
"scf::SCFDialect",
"sparse_tensor::SparseTensorDialect",
];
+ let options = [
+ Option<"enableBufferInitialization", "enable-buffer-initialization", "bool",
+ "false", "Enable zero-initialization of the memory buffers">,
+ ];
}
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
options.sparseTensorConversionOptions()));
else
pm.addPass(createSparseTensorCodegenPass());
- pm.addPass(createSparseBufferRewritePass());
+ pm.addPass(createSparseBufferRewritePass(options.enableBufferInitialization));
pm.addPass(createDenseBufferizationPass(
getBufferizationOptions(/*analysisOnly=*/false)));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
struct PushBackRewriter : OpRewritePattern<PushBackOp> {
public:
using OpRewritePattern<PushBackOp>::OpRewritePattern;
+ PushBackRewriter(MLIRContext *context, bool enableInit)
+ : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
LogicalResult matchAndRewrite(PushBackOp op,
PatternRewriter &rewriter) const override {
// Rewrite push_back(buffer, value, n) to:
Value newBuffer =
rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
+ if (enableBufferInitialization) {
+ Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
+ Value fillValue = rewriter.create<arith::ConstantOp>(
+ loc, value.getType(), rewriter.getZeroAttr(value.getType()));
+ Value subBuffer = rewriter.create<memref::SubViewOp>(
+ loc, newBuffer, /*offset=*/ValueRange{newSize},
+ /*size=*/ValueRange{fillSize},
+ /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
+ rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
+ }
rewriter.create<scf::YieldOp>(loc, newBuffer);
// False branch.
rewriter.replaceOp(op, buffer);
return success();
}
+
+private:
+ bool enableBufferInitialization;
};
/// Sparse rewriting rule for the sort operator.
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//
-void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns) {
- patterns.add<PushBackRewriter, SortRewriter>(patterns.getContext());
+void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
+ bool enableBufferInitialization) {
+ patterns.add<PushBackRewriter>(patterns.getContext(),
+ enableBufferInitialization);
+ patterns.add<SortRewriter>(patterns.getContext());
}
SparseBufferRewritePass() = default;
SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
+ SparseBufferRewritePass(bool enableInit) {
+ enableBufferInitialization = enableInit;
+ }
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
- populateSparseBufferRewriting(patterns);
+ populateSparseBufferRewriting(patterns, enableBufferInitialization);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
return std::make_unique<SparseTensorCodegenPass>();
}
-std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
- return std::make_unique<SparseBufferRewritePass>();
+std::unique_ptr<Pass>
+mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
+ return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
}