From: bixia1 Date: Mon, 17 Oct 2022 17:02:17 +0000 (-0700) Subject: [mlir][sparse] Add options to sparse-tensor-rewrite to disable rewriting rules for... X-Git-Tag: upstream/17.0.6~30219 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c1864ab9534080ee77a016bca24dd9a318bc6d7e;p=platform%2Fupstream%2Fllvm.git [mlir][sparse] Add options to sparse-tensor-rewrite to disable rewriting rules for operators foreach and convert. This is to help simplify FileCheck tests for sparse-tensor-rewrite. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D136093 --- diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h index fd99e4f..2230f43 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -158,14 +158,21 @@ void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter, std::unique_ptr createSparseTensorCodegenPass(); //===----------------------------------------------------------------------===// -// Other rewriting rules and passes. +// The SparseTensorRewriting pass. //===----------------------------------------------------------------------===// -void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT); +void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT, + bool enableForeach, bool enableConvert); std::unique_ptr createSparseTensorRewritePass(); std::unique_ptr -createSparseTensorRewritePass(const SparsificationOptions &options); +createSparseTensorRewritePass(const SparsificationOptions &options, + bool enableForeach = true, + bool enableConvert = true); + +//===----------------------------------------------------------------------===// +// Other rewriting rules and passes. +//===----------------------------------------------------------------------===// std::unique_ptr createDenseBufferizationPass( const bufferization::OneShotBufferizationOptions &options); diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td index 26c78ae..eee33b0 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -28,7 +28,11 @@ def SparseTensorRewrite : Pass<"sparse-tensor-rewrite", "ModuleOp"> { ]; let options = [ Option<"enableRuntimeLibrary", "enable-runtime-library", "bool", - "true", "Enable runtime library for manipulating sparse tensors"> + "true", "Enable runtime library for manipulating sparse tensors">, + Option<"enableForeach", "enable-foreach", "bool", + "true", "Enable rewriting rules for the foreach operator">, + Option<"enableConvert", "enable-convert", "bool", + "true", "Enable rewriting rules for the convert operator">, ]; } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp index 5ea55c8..b524ac1 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -43,14 +43,18 @@ struct SparseTensorRewritePass SparseTensorRewritePass() = default; SparseTensorRewritePass(const SparseTensorRewritePass &pass) = default; - SparseTensorRewritePass(const SparsificationOptions &options) { + SparseTensorRewritePass(const SparsificationOptions &options, bool foreach, + bool convert) { enableRuntimeLibrary = options.enableRuntimeLibrary; + enableForeach = foreach; + enableConvert = convert; } void runOnOperation() override { auto *ctx = &getContext(); RewritePatternSet patterns(ctx); - populateSparseTensorRewriting(patterns, enableRuntimeLibrary); + populateSparseTensorRewriting(patterns, enableRuntimeLibrary, enableForeach, + enableConvert); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; @@ -255,8 +259,10 @@ std::unique_ptr mlir::createSparseTensorRewritePass() { } std::unique_ptr -mlir::createSparseTensorRewritePass(const SparsificationOptions &options) { - return std::make_unique(options); +mlir::createSparseTensorRewritePass(const SparsificationOptions &options, + bool enableForeach, bool enableConvert) { + return std::make_unique(options, enableForeach, + enableConvert); } std::unique_ptr mlir::createSparsificationPass() { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 2654887..36a564b 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -612,11 +612,14 @@ struct NewRewriter : public OpRewritePattern { // Methods that add patterns described in this file to a pattern list. //===---------------------------------------------------------------------===// void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns, - bool enableRT) { + bool enableRT, bool enableForeach, + bool /*enableConvert*/) { patterns.add, - ReshapeRewriter, ForeachRewriter>( - patterns.getContext()); + ReshapeRewriter>(patterns.getContext()); + if (enableForeach) + patterns.add(patterns.getContext()); + // TODO: If RT not enabled, rewrite concatenate ops, etc here. if (!enableRT) patterns.add