From f81f0cb75a2808a67d2662f044ad07628fc9d900 Mon Sep 17 00:00:00 2001 From: bixia1 Date: Wed, 16 Nov 2022 16:28:41 -0800 Subject: [PATCH] [mlir][sparse] Split SparseTensorRewrite into PreSparsificationRewrite and PostSparsificationRewrite. Reviewed By: aartbik, wrengr Differential Revision: https://reviews.llvm.org/D138153 --- .../mlir/Dialect/SparseTensor/Transforms/Passes.h | 23 +++++--- .../mlir/Dialect/SparseTensor/Transforms/Passes.td | 37 +++++++++---- .../Pipelines/SparseTensorPipelines.cpp | 3 +- .../SparseTensor/Transforms/SparseTensorPasses.cpp | 64 +++++++++++++++------- .../Transforms/SparseTensorRewriting.cpp | 16 ++++-- .../Dialect/SparseTensor/convert_dense2sparse.mlir | 2 +- .../Dialect/SparseTensor/convert_sparse2dense.mlir | 2 +- .../SparseTensor/convert_sparse2sparse.mlir | 2 +- mlir/test/Dialect/SparseTensor/rewriting.mlir | 2 +- .../SparseTensor/rewriting_for_codegen.mlir | 2 +- .../SparseTensor/sparse_concat_codegen.mlir | 2 +- .../Dialect/SparseTensor/sparse_fill_zero.mlir | 2 +- mlir/test/Dialect/SparseTensor/sparse_reshape.mlir | 2 +- mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir | 2 +- 14 files changed, 108 insertions(+), 53 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h index badc3d0..0961b5e 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -138,16 +138,25 @@ std::unique_ptr createSparseTensorCodegenPass(bool enableBufferInitialization); //===----------------------------------------------------------------------===// -// The SparseTensorRewriting pass. +// The PreSparsificationRewriting pass. //===----------------------------------------------------------------------===// -void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT, - bool enableForeach, bool enableConvert); +void populatePreSparsificationRewriting(RewritePatternSet &patterns); -std::unique_ptr createSparseTensorRewritePass(); -std::unique_ptr createSparseTensorRewritePass(bool enableRT, - bool enableForeach = true, - bool enableConvert = true); +std::unique_ptr createPreSparsificationRewritePass(); + +//===----------------------------------------------------------------------===// +// The PostSparsificationRewriting pass. +//===----------------------------------------------------------------------===// + +void populatePostSparsificationRewriting(RewritePatternSet &patterns, + bool enableRT, bool enableForeach, + bool enableConvert); + +std::unique_ptr createPostSparsificationRewritePass(); +std::unique_ptr +createPostSparsificationRewritePass(bool enableRT, bool enableForeach = true, + bool enableConvert = true); //===----------------------------------------------------------------------===// // Other rewriting rules and passes. diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td index 74784af..32bba3a 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -11,13 +11,13 @@ include "mlir/Pass/PassBase.td" -def SparseTensorRewrite : Pass<"sparse-tensor-rewrite", "ModuleOp"> { +def PreSparsificationRewrite : Pass<"pre-sparsification-rewrite", "ModuleOp"> { let summary = "Applies sparse tensor rewriting rules prior to sparsification"; let description = [{ A pass that applies rewriting rules to sparse tensor operations prior to running the actual sparsification pass. }]; - let constructor = "mlir::createSparseTensorRewritePass()"; + let constructor = "mlir::createPreSparsificationRewritePass()"; let dependentDialects = [ "arith::ArithDialect", "bufferization::BufferizationDialect", @@ -26,14 +26,6 @@ def SparseTensorRewrite : Pass<"sparse-tensor-rewrite", "ModuleOp"> { "scf::SCFDialect", "sparse_tensor::SparseTensorDialect", ]; - let options = [ - Option<"enableRuntimeLibrary", "enable-runtime-library", "bool", - "true", "Enable runtime library for manipulating sparse tensors">, - Option<"enableForeach", "enable-foreach", "bool", - "true", "Enable rewriting rules for the foreach operator">, - Option<"enableConvert", "enable-convert", "bool", - "true", "Enable rewriting rules for the convert operator">, - ]; } def SparsificationPass : Pass<"sparsification", "ModuleOp"> { @@ -109,6 +101,31 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> { ]; } +def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> { + let summary = "Applies sparse tensor rewriting rules after sparsification"; + let description = [{ + A pass that applies rewriting rules to sparse tensor operations after + running the actual sparsification pass. + }]; + let constructor = "mlir::createPostSparsificationRewritePass()"; + let dependentDialects = [ + "arith::ArithDialect", + "bufferization::BufferizationDialect", + "linalg::LinalgDialect", + "memref::MemRefDialect", + "scf::SCFDialect", + "sparse_tensor::SparseTensorDialect", + ]; + let options = [ + Option<"enableRuntimeLibrary", "enable-runtime-library", "bool", + "true", "Enable runtime library for manipulating sparse tensors">, + Option<"enableForeach", "enable-foreach", "bool", + "true", "Enable rewriting rules for the foreach operator">, + Option<"enableConvert", "enable-convert", "bool", + "true", "Enable rewriting rules for the convert operator">, + ]; +} + def SparseTensorConversionPass : Pass<"sparse-tensor-conversion", "ModuleOp"> { let summary = "Convert sparse tensors and primitives to library calls"; let description = [{ diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp index e48a760..b816fad 100644 --- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp @@ -57,8 +57,9 @@ void mlir::sparse_tensor::buildSparseCompiler( /*analysisOnly=*/options.testBufferizationAnalysisOnly))); if (options.testBufferizationAnalysisOnly) return; - pm.addPass(createSparseTensorRewritePass(options.enableRuntimeLibrary)); + pm.addPass(createPreSparsificationRewritePass()); pm.addPass(createSparsificationPass(options.sparsificationOptions())); + pm.addPass(createPostSparsificationRewritePass(options.enableRuntimeLibrary)); if (options.enableRuntimeLibrary) { pm.addPass(createSparseTensorConversionPass( options.sparseTensorConversionOptions())); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp index da7c6ff..d1491df 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -21,8 +21,9 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { -#define GEN_PASS_DEF_SPARSETENSORREWRITE +#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE #define GEN_PASS_DEF_SPARSIFICATIONPASS +#define GEN_PASS_DEF_POSTSPARSIFICATIONREWRITE #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS #define GEN_PASS_DEF_SPARSETENSORCODEGEN #define GEN_PASS_DEF_SPARSEBUFFERREWRITE @@ -38,22 +39,17 @@ namespace { // Passes implementation. //===----------------------------------------------------------------------===// -struct SparseTensorRewritePass - : public impl::SparseTensorRewriteBase { +struct PreSparsificationRewritePass + : public impl::PreSparsificationRewriteBase { - SparseTensorRewritePass() = default; - SparseTensorRewritePass(const SparseTensorRewritePass &pass) = default; - SparseTensorRewritePass(bool enableRT, bool foreach, bool convert) { - enableRuntimeLibrary = enableRT; - enableForeach = foreach; - enableConvert = convert; - } + PreSparsificationRewritePass() = default; + PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) = + default; void runOnOperation() override { auto *ctx = &getContext(); RewritePatternSet patterns(ctx); - populateSparseTensorRewriting(patterns, enableRuntimeLibrary, enableForeach, - enableConvert); + populatePreSparsificationRewriting(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; @@ -80,6 +76,28 @@ struct SparsificationPass } }; +struct PostSparsificationRewritePass + : public impl::PostSparsificationRewriteBase< + PostSparsificationRewritePass> { + + PostSparsificationRewritePass() = default; + PostSparsificationRewritePass(const PostSparsificationRewritePass &pass) = + default; + PostSparsificationRewritePass(bool enableRT, bool foreach, bool convert) { + enableRuntimeLibrary = enableRT; + enableForeach = foreach; + enableConvert = convert; + } + + void runOnOperation() override { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + populatePostSparsificationRewriting(patterns, enableRuntimeLibrary, + enableForeach, enableConvert); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + struct SparseTensorConversionPass : public impl::SparseTensorConversionPassBase { @@ -254,15 +272,8 @@ mlir::sparseToSparseConversionStrategy(int32_t flag) { // Pass creation methods. //===----------------------------------------------------------------------===// -std::unique_ptr mlir::createSparseTensorRewritePass() { - return std::make_unique(); -} - -std::unique_ptr mlir::createSparseTensorRewritePass(bool enableRT, - bool enableForeach, - bool enableConvert) { - return std::make_unique(enableRT, enableForeach, - enableConvert); +std::unique_ptr mlir::createPreSparsificationRewritePass() { + return std::make_unique(); } std::unique_ptr mlir::createSparsificationPass() { @@ -274,6 +285,17 @@ mlir::createSparsificationPass(const SparsificationOptions &options) { return std::make_unique(options); } +std::unique_ptr mlir::createPostSparsificationRewritePass() { + return std::make_unique(); +} + +std::unique_ptr +mlir::createPostSparsificationRewritePass(bool enableRT, bool enableForeach, + bool enableConvert) { + return std::make_unique( + enableRT, enableForeach, enableConvert); +} + std::unique_ptr mlir::createSparseTensorConversionPass() { return std::make_unique(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 6f38796..29430d3 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -1021,11 +1021,17 @@ struct OutRewriter : public OpRewritePattern { //===---------------------------------------------------------------------===// // Methods that add patterns described in this file to a pattern list. //===---------------------------------------------------------------------===// -void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns, - bool enableRT, bool enableForeach, - bool enableConvert) { - patterns.add, + +void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} + +void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns, + bool enableRT, + bool enableForeach, + bool enableConvert) { + patterns.add, ReshapeRewriter>(patterns.getContext()); if (enableForeach) patterns.add(patterns.getContext()); diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir index 5c9d19b..8336a14 100644 --- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir +++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s -// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \ +// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=false" \ // RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT #SparseVector = #sparse_tensor.encoding<{ diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir index 2601986..2c5de95 100644 --- a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir +++ b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s -// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \ +// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=false" \ // RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT #SparseVector = #sparse_tensor.encoding<{ diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir index 6104750..496d594 100644 --- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir +++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir @@ -6,7 +6,7 @@ // RUN: mlir-opt %s --sparse-tensor-conversion="s2s-strategy=0" \ // RUN: --canonicalize --cse | FileCheck %s -check-prefixes=CHECK-AUTO,CHECK -// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \ +// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=false" \ // RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT #SparseVector64 = #sparse_tensor.encoding<{ diff --git a/mlir/test/Dialect/SparseTensor/rewriting.mlir b/mlir/test/Dialect/SparseTensor/rewriting.mlir index f142ecf..1744861 100755 --- a/mlir/test/Dialect/SparseTensor/rewriting.mlir +++ b/mlir/test/Dialect/SparseTensor/rewriting.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -sparse-tensor-rewrite | FileCheck %s +// RUN: mlir-opt %s -post-sparsification-rewrite | FileCheck %s #SparseVector = #sparse_tensor.encoding<{ dimLevelType = ["compressed"] diff --git a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir index 3a6cf999..94c373b 100644 --- a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" |\ +// RUN: mlir-opt %s -post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" |\ // RUN: FileCheck %s #CSR = #sparse_tensor.encoding<{ diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir index 717819b..c2a1fd8 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" \ +// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \ // RUN: --sparsification | FileCheck %s #DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir index 240a940..7571e3d 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --linalg-generalize-named-ops --sparse-tensor-rewrite --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s +// RUN: mlir-opt %s --linalg-generalize-named-ops --pre-sparsification-rewrite --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s #DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir index 94ee501..a679045 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND // RUN: mlir-opt %s --sparse-tensor-conversion --cse --canonicalize | FileCheck %s --check-prefix=CHECK-CONV -// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" \ +// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \ // RUN: --cse --canonicalize | FileCheck %s --check-prefix=CHECK-RWT #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir index d2ec5ca..b55f8cb 100755 --- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --tensor-copy-insertion --sparse-tensor-rewrite --sparsification --cse | FileCheck %s +// RUN: mlir-opt %s --tensor-copy-insertion --pre-sparsification-rewrite --sparsification --cse | FileCheck %s #SM = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> -- 2.7.4