[mlir][sparse] migrate sparse rewriting to sparse transformations pass
author    Aart Bik <ajcbik@google.com>
          Fri, 15 Jul 2022 23:41:02 +0000 (16:41 -0700)
committer Aart Bik <ajcbik@google.com>
          Mon, 18 Jul 2022 16:29:22 +0000 (09:29 -0700)
The rules in the linalg file were very specific to sparse tensors, so they
find a better home under the sparse tensor dialect than under the linalg
dialect. Also moved some rewriting from sparsification itself into this new
"pre-rewriting" file.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D129910

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp [moved from mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp with 79% similarity]
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index b7300dd..71afddc 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -134,13 +134,20 @@ void populateSparseTensorConversionPatterns(
     const SparseTensorConversionOptions &options =
         SparseTensorConversionOptions());
 
-std::unique_ptr<Pass> createDenseBufferizationPass(
-    const bufferization::OneShotBufferizationOptions &options);
 std::unique_ptr<Pass> createSparseTensorConversionPass();
 std::unique_ptr<Pass>
 createSparseTensorConversionPass(const SparseTensorConversionOptions &options);
 
 //===----------------------------------------------------------------------===//
+// Other rewriting rules and passes.
+//===----------------------------------------------------------------------===//
+
+void populateSparseTensorRewriting(RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createDenseBufferizationPass(
+    const bufferization::OneShotBufferizationOptions &options);
+
+//===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//
 
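The declaration of populateSparseTensorRewriting above makes the pre-rewriting
rules available outside the sparsification pass as well. A minimal sketch of
standalone use with the standard greedy driver (the helper function itself is
hypothetical, not part of this patch):

    #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Hypothetical helper: apply only the sparse tensor pre-rewriting rules,
    // using the same greedy driver the sparsification pass itself uses.
    static void runSparsePreRewriting(mlir::Operation *op) {
      mlir::RewritePatternSet patterns(op->getContext());
      mlir::populateSparseTensorRewriting(patterns);
      (void)mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
    }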
index 5bc2740..a8112db 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -22,7 +22,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   LinalgStrategyPasses.cpp
   NamedOpConversions.cpp
   Promotion.cpp
-  SparseTensorRewriting.cpp
   Split.cpp
   SplitReduction.cpp
   Tiling.cpp
index b958046..51f0164 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1717,12 +1717,8 @@ struct LinalgElementwiseOpFusionPass
 
     // Add elementwise op fusion patterns.
     populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
-
     populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
 
-    // Add the sparse tensor rewriting patterns.
-    populateSparseTensorRewriting(patterns);
-
     // General canonicalization patterns.
     AffineApplyOp::getCanonicalizationPatterns(patterns, context);
     GenericOp::getCanonicalizationPatterns(patterns, context);
index 59e6220..8c7639b 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -52,11 +52,6 @@ void mlir::sparse_tensor::buildSparseCompiler(
     OpPassManager &pm, const SparseCompilerOptions &options) {
   // TODO(wrengr): ensure the original `pm` is for ModuleOp
   pm.addNestedPass<func::FuncOp>(createLinalgGeneralizationPass());
-  // TODO(springerm): Reactivate element-wise op fusion pass. This pass does not
-  // fit well with bufferization because it replaces unused "out" operands of
-  // LinalgOps with InitTensorOps. This would result in additional buffer
-  // allocations during bufferization.
-  // pm.addPass(createLinalgElementwiseOpFusionPass());
   pm.addPass(
       bufferization::createTensorCopyInsertionPass(getBufferizationOptions(
           /*analysisOnly=*/options.testBufferizationAnalysisOnly)));
index 76bd316..9d99d2f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   Sparsification.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
+  SparseTensorRewriting.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
index 9d94e5b..5fcb44a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -49,13 +49,17 @@ struct SparsificationPass : public SparsificationBase<SparsificationPass> {
 
   void runOnOperation() override {
     auto *ctx = &getContext();
-    RewritePatternSet patterns(ctx);
+    // Apply pre-rewriting.
+    RewritePatternSet prePatterns(ctx);
+    populateSparseTensorRewriting(prePatterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
     // Translate strategy flags to strategy options.
     SparsificationOptions options(
         sparseParallelizationStrategy(parallelization),
         sparseVectorizationStrategy(vectorization), vectorLength,
         enableSIMDIndex32, enableVLAVectorization);
-    // Apply rewriting.
+    // Apply sparsification and vector cleanup rewriting.
+    RewritePatternSet patterns(ctx);
     populateSparsificationPatterns(patterns, options);
     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
--- a/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -6,20 +6,16 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements linalg dialect rewriting specific to sparse tensors.
-//
-// Sparsity should be mostly transparent to the linalg dialect optimizations
-// (i.e., the dense and sparse take the same path). However, in some cases,
-// optimizations only make sense in the context of sparse tensors. This file
-// implements such sparsity specific rewriting rules.
+// This file implements rewriting rules that are specific to sparse tensors.
 //
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
@@ -98,6 +94,7 @@ static bool isSumOfMul(GenericOp op) {
 //===---------------------------------------------------------------------===//
 
 namespace {
+
 /// Rewriting rule that converts two kernels:
 ///
 ///      T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
@@ -114,6 +111,7 @@ namespace {
 /// a fusion may actually reduce the asymptotic complexity of the kernel,
 /// since intermediate results may be nullified.
 struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
+public:
   using OpRewritePattern<GenericOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(GenericOp op,
@@ -194,13 +192,55 @@ private:
     mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
   }
 };
+
+/// Sparse rewriting rule for reshape operator.
+template <typename ReshapeOp>
+struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
+public:
+  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto encDst = getSparseTensorEncoding(op.getResult().getType());
+    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
+    // Since a pure dense expansion is very cheap (change of view), for
+    // a sparse2dense or dense2sparse, we can simply unfuse a sparse
+    // conversion from the reshape operation itself.
+    // All other cases are handled elsewhere.
+    if (encDst && encSrc) {
+      return failure();
+    } else if (encSrc) {
+      RankedTensorType rtp =
+          op.getSrc().getType().template cast<RankedTensorType>();
+      auto denseTp =
+          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
+      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
+      op->setOperand(0, convert);
+      return success();
+    } else if (encDst) {
+      RankedTensorType rtp =
+          op.getResult().getType().template cast<RankedTensorType>();
+      auto denseTp =
+          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
+      auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
+                                                op.getReassociation());
+      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
+      rewriter.replaceOp(op, convert);
+      return success();
+    }
+    return failure();
+  }
+};
+
 } // namespace
 
 //===---------------------------------------------------------------------===//
 // Methods that add patterns described in this file to a pattern list.
 //===---------------------------------------------------------------------===//
 
-void mlir::linalg::populateSparseTensorRewriting(RewritePatternSet &patterns) {
-  auto *context = patterns.getContext();
-  patterns.add<FuseSparseMultiplyOverAdd>(context);
+void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
+  // TODO(springerm): enable FuseSparseMultiplyOverAdd
+  patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
+               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
 }
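For illustration, the unfusing performed by ReshapeRewriter in the
dense2sparse case looks as follows (a sketch; #CSR stands for an arbitrary
sparse_tensor.encoding, and the shapes are made up):

    %1 = tensor.expand_shape %0 [[0, 1]]
        : tensor<100xf64> into tensor<10x10xf64, #CSR>

becomes a cheap dense reshape followed by a sparse conversion:

    %d = tensor.expand_shape %0 [[0, 1]]
        : tensor<100xf64> into tensor<10x10xf64>
    %1 = sparse_tensor.convert %d
        : tensor<10x10xf64> to tensor<10x10xf64, #CSR>

The sparse2dense case is handled symmetrically: a sparse_tensor.convert to
the dense source type is inserted before the reshape, which then operates on
purely dense tensors.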
index 5318224..7121438 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1802,46 +1802,6 @@ private:
   SparsificationOptions options;
 };
 
-/// Sparse rewriting rule for reshape operator.
-template <typename ReshapeOp>
-struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
-public:
-  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ReshapeOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    auto encDst = getSparseTensorEncoding(op.getResult().getType());
-    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
-    // Since a pure dense expansion is very cheap (change of view), for
-    // a sparse2dense or dense2sparse, we can simply unfuse a sparse
-    // conversion from the reshape operation itself.
-    // All other cases are handled elsewhere.
-    if (encDst && encSrc) {
-      return failure();
-    } else if (encSrc) {
-      RankedTensorType rtp =
-          op.getSrc().getType().template cast<RankedTensorType>();
-      auto denseTp =
-          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
-      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
-      op->setOperand(0, convert);
-      return success();
-    } else if (encDst) {
-      RankedTensorType rtp =
-          op.getResult().getType().template cast<RankedTensorType>();
-      auto denseTp =
-          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
-      auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
-                                                op.getReassociation());
-      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
-      rewriter.replaceOp(op, convert);
-      return success();
-    }
-    return failure();
-  }
-};
-
 } // namespace
 
 /// Populates the given patterns list with rewriting rules required for
@@ -1849,6 +1809,4 @@ public:
 void mlir::populateSparsificationPatterns(
     RewritePatternSet &patterns, const SparsificationOptions &options) {
   patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
-  patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
-               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
 }
index 5058e45..73a17d1 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2115,6 +2115,7 @@ cc_library(
         ":SparseTensorDialect",
         ":SparseTensorPassIncGen",
         ":SparseTensorUtils",
+        ":Support",
         ":TensorDialect",
         ":Transforms",
         ":VectorDialect",