[mlir][sparse] Add options to sparse-tensor-rewrite to disable rewriting rules for...
authorbixia1 <bixia@google.com>
Mon, 17 Oct 2022 17:02:17 +0000 (10:02 -0700)
committerbixia1 <bixia@google.com>
Tue, 18 Oct 2022 16:27:32 +0000 (09:27 -0700)
This is to help simplify FileCheck tests for sparse-tensor-rewrite.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D136093

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

index fd99e4f..2230f43 100644 (file)
@@ -158,14 +158,21 @@ void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
 std::unique_ptr<Pass> createSparseTensorCodegenPass();
 
 //===----------------------------------------------------------------------===//
-// Other rewriting rules and passes.
+// The SparseTensorRewriting pass.
 //===----------------------------------------------------------------------===//
 
-void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT);
+void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT,
+                                   bool enableForeach, bool enableConvert);
 
 std::unique_ptr<Pass> createSparseTensorRewritePass();
 std::unique_ptr<Pass>
-createSparseTensorRewritePass(const SparsificationOptions &options);
+createSparseTensorRewritePass(const SparsificationOptions &options,
+                              bool enableForeach = true,
+                              bool enableConvert = true);
+
+//===----------------------------------------------------------------------===//
+// Other rewriting rules and passes.
+//===----------------------------------------------------------------------===//
 
 std::unique_ptr<Pass> createDenseBufferizationPass(
     const bufferization::OneShotBufferizationOptions &options);
index 26c78ae..eee33b0 100644 (file)
@@ -28,7 +28,11 @@ def SparseTensorRewrite : Pass<"sparse-tensor-rewrite", "ModuleOp"> {
   ];
   let options = [
     Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
-           "true", "Enable runtime library for manipulating sparse tensors">
+           "true", "Enable runtime library for manipulating sparse tensors">,
+    Option<"enableForeach", "enable-foreach", "bool",
+           "true", "Enable rewriting rules for the foreach operator">,
+    Option<"enableConvert", "enable-convert", "bool",
+           "true", "Enable rewriting rules for the convert operator">,
   ];
 }
 
index 5ea55c8..b524ac1 100644 (file)
@@ -43,14 +43,18 @@ struct SparseTensorRewritePass
 
   SparseTensorRewritePass() = default;
   SparseTensorRewritePass(const SparseTensorRewritePass &pass) = default;
-  SparseTensorRewritePass(const SparsificationOptions &options) {
+  SparseTensorRewritePass(const SparsificationOptions &options, bool foreach,
+                          bool convert) {
     enableRuntimeLibrary = options.enableRuntimeLibrary;
+    enableForeach = foreach;
+    enableConvert = convert;
   }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populateSparseTensorRewriting(patterns, enableRuntimeLibrary);
+    populateSparseTensorRewriting(patterns, enableRuntimeLibrary, enableForeach,
+                                  enableConvert);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
@@ -255,8 +259,10 @@ std::unique_ptr<Pass> mlir::createSparseTensorRewritePass() {
 }
 
 std::unique_ptr<Pass>
-mlir::createSparseTensorRewritePass(const SparsificationOptions &options) {
-  return std::make_unique<SparseTensorRewritePass>(options);
+mlir::createSparseTensorRewritePass(const SparsificationOptions &options,
+                                    bool enableForeach, bool enableConvert) {
+  return std::make_unique<SparseTensorRewritePass>(options, enableForeach,
+                                                   enableConvert);
 }
 
 std::unique_ptr<Pass> mlir::createSparsificationPass() {
index 2654887..36a564b 100644 (file)
@@ -612,11 +612,14 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
 // Methods that add patterns described in this file to a pattern list.
 //===---------------------------------------------------------------------===//
 void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
-                                         bool enableRT) {
+                                         bool enableRT, bool enableForeach,
+                                         bool /*enableConvert*/) {
   patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd,
                ReshapeRewriter<tensor::ExpandShapeOp>,
-               ReshapeRewriter<tensor::CollapseShapeOp>, ForeachRewriter>(
-      patterns.getContext());
+               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
+  if (enableForeach)
+    patterns.add<ForeachRewriter>(patterns.getContext());
+
   // TODO: If RT not enabled, rewrite concatenate ops, etc here.
   if (!enableRT)
     patterns.add<ConcatenateRewriter, NewRewriter,