[mlir][sparse] Introducing options for the SparseTensorConversion pass
authorwren romano <2998727+wrengr@users.noreply.github.com>
Sat, 19 Mar 2022 02:10:40 +0000 (19:10 -0700)
committerwren romano <2998727+wrengr@users.noreply.github.com>
Tue, 22 Mar 2022 20:11:09 +0000 (13:11 -0700)
This is work towards: https://github.com/llvm/llvm-project/issues/51652

This differential sets up the options and threads them through everywhere, but doesn't actually use them yet.  The differential that finally makes use of them is D122061, which is the final differential in the chain that fixes bug 51652.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D122054

mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

index 86b59b6..782f3b4 100644 (file)
@@ -49,6 +49,17 @@ struct SparseCompilerOptions
                                  vectorLength, enableSIMDIndex32);
   }
 
+  // These options must be kept in sync with `SparseTensorConversionBase`.
+  PassOptions::Option<int32_t> sparseToSparse{
+      *this, "s2s-strategy",
+      desc("Set the strategy for sparse-to-sparse conversion"), init(0)};
+
+  /// Projects out the options for `createSparsificationPass`.
+  SparseTensorConversionOptions sparseTensorConversionOptions() const {
+    return SparseTensorConversionOptions(
+        sparseToSparseConversionStrategy(sparseToSparse));
+  }
+
   // These options must be kept in sync with `ConvertVectorToLLVMBase`.
   // TODO(wrengr): does `indexOptimizations` differ from `enableSIMDIndex32`?
   PassOptions::Option<bool> reassociateFPReductions{
index 96f9ea1..1888b45 100644 (file)
@@ -8,6 +8,12 @@
 //
 // This header file defines prototypes of all sparse tensor passes.
 //
+// In general, this file takes the approach of keeping "mechanism" (the
+// actual steps of applying a transformation) completely separate from
+// "policy" (heuristics for when and where to apply transformations).
+// The only exception is in `SparseToSparseConversionStrategy`; for which,
+// see further discussion there.
+//
 //===----------------------------------------------------------------------===//
 
 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES_H_
@@ -21,6 +27,10 @@ namespace mlir {
 // Forward.
 class TypeConverter;
 
+//===----------------------------------------------------------------------===//
+// The Sparsification pass.
+//===----------------------------------------------------------------------===//
+
 /// Defines a parallelization strategy. Any independent loop is a candidate
 /// for parallelization. The loop is made parallel if (1) allowed by the
 /// strategy (e.g., AnyStorageOuterLoop considers either a dense or sparse
@@ -51,7 +61,7 @@ enum class SparseVectorizationStrategy {
 /// Converts command-line vectorization flag to the strategy enum.
 SparseVectorizationStrategy sparseVectorizationStrategy(int32_t flag);
 
-/// Sparsification options.
+/// Options for the Sparsification pass.
 struct SparsificationOptions {
   SparsificationOptions(SparseParallelizationStrategy p,
                         SparseVectorizationStrategy v, unsigned vl, bool e)
@@ -71,14 +81,56 @@ void populateSparsificationPatterns(
     RewritePatternSet &patterns,
     const SparsificationOptions &options = SparsificationOptions());
 
-/// Sets up sparse tensor conversion rules.
-void populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
-                                            RewritePatternSet &patterns);
-
 std::unique_ptr<Pass> createSparsificationPass();
 std::unique_ptr<Pass>
 createSparsificationPass(const SparsificationOptions &options);
+
+//===----------------------------------------------------------------------===//
+// The SparseTensorConversion pass.
+//===----------------------------------------------------------------------===//
+
+/// Defines a strategy for implementing sparse-to-sparse conversion.
+/// `kAuto` leaves it up to the compiler to automatically determine
+/// the method used.  `kViaCOO` converts the source tensor to COO and
+/// then converts the COO to the target format.  `kDirect` converts
+/// directly via the algorithm in <https://arxiv.org/abs/2001.02609>;
+/// however, beware that there are many formats not supported by this
+/// conversion method.
+///
+/// The presence of the `kAuto` option violates our usual goal of keeping
+/// policy completely separated from mechanism.  The reason it exists is
+/// because (at present) this strategy can only be specified on a per-file
+/// basis.  To see why this is a problem, note that `kDirect` cannot
+/// support certain conversions; so if there is no `kAuto` setting,
+/// then whenever a file contains a single non-`kDirect`-able conversion
+/// the user would be forced to use `kViaCOO` for all conversions in
+/// that file!  In the future, instead of using this enum as a `Pass`
+/// option, we could instead move it to being an attribute on the
+/// conversion op; at which point `kAuto` would no longer be necessary.
+enum class SparseToSparseConversionStrategy { kAuto, kViaCOO, kDirect };
+
+/// Converts command-line sparse2sparse flag to the strategy enum.
+SparseToSparseConversionStrategy sparseToSparseConversionStrategy(int32_t flag);
+
+/// SparseTensorConversion options.
+struct SparseTensorConversionOptions {
+  SparseTensorConversionOptions(SparseToSparseConversionStrategy s2s)
+      : sparseToSparseStrategy(s2s) {}
+  SparseTensorConversionOptions()
+      : SparseTensorConversionOptions(SparseToSparseConversionStrategy::kAuto) {
+  }
+  SparseToSparseConversionStrategy sparseToSparseStrategy;
+};
+
+/// Sets up sparse tensor conversion rules.
+void populateSparseTensorConversionPatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    const SparseTensorConversionOptions &options =
+        SparseTensorConversionOptions());
+
 std::unique_ptr<Pass> createSparseTensorConversionPass();
+std::unique_ptr<Pass>
+createSparseTensorConversionPass(const SparseTensorConversionOptions &options);
 
 //===----------------------------------------------------------------------===//
 // Registration.
index 31b08af..89aacd6 100644 (file)
@@ -114,6 +114,10 @@ def SparseTensorConversion : Pass<"sparse-tensor-conversion", "ModuleOp"> {
     "sparse_tensor::SparseTensorDialect",
     "vector::VectorDialect",
   ];
+  let options = [
+    Option<"sparseToSparse", "s2s-strategy", "int32_t", "0",
+           "Set the strategy for sparse-to-sparse conversion">,
+  ];
 }
 
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
index 1f5e266..54dac3d 100644 (file)
@@ -33,7 +33,8 @@ void mlir::sparse_tensor::buildSparseCompiler(
   pm.addNestedPass<FuncOp>(createLinalgGeneralizationPass());
   pm.addPass(createLinalgElementwiseOpFusionPass());
   pm.addPass(createSparsificationPass(options.sparsificationOptions()));
-  pm.addPass(createSparseTensorConversionPass());
+  pm.addPass(createSparseTensorConversionPass(
+      options.sparseTensorConversionOptions()));
   pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
   pm.addNestedPass<FuncOp>(vector::createVectorBufferizePass());
   pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
index 17a07da..11329f6 100644 (file)
@@ -453,7 +453,18 @@ class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
 
 /// Sparse conversion rule for the convert operator.
 class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
+  /// Options to control sparse code generation.
+  SparseTensorConversionOptions options;
+
+public:
   using OpConversionPattern::OpConversionPattern;
+  SparseTensorConvertConverter(MLIRContext *context,
+                               SparseTensorConversionOptions o)
+      : OpConversionPattern<ConvertOp>(context), options(o) {}
+  SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
+                               SparseTensorConversionOptions o)
+      : OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}
+
   LogicalResult
   matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -825,14 +836,17 @@ public:
 
 /// Populates the given patterns list with conversion rules required for
 /// the sparsification of linear algebra operations.
-void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
-                                                  RewritePatternSet &patterns) {
+void mlir::populateSparseTensorConversionPatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    const SparseTensorConversionOptions &options) {
   patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
                SparseCastConverter, SparseTensorNewConverter,
-               SparseTensorInitConverter, SparseTensorConvertConverter,
-               SparseTensorReleaseConverter, SparseTensorToPointersConverter,
-               SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
-               SparseTensorLoadConverter, SparseTensorLexInsertConverter,
-               SparseTensorExpandConverter, SparseTensorCompressConverter,
-               SparseTensorOutConverter>(typeConverter, patterns.getContext());
+               SparseTensorInitConverter, SparseTensorReleaseConverter,
+               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
+               SparseTensorToValuesConverter, SparseTensorLoadConverter,
+               SparseTensorLexInsertConverter, SparseTensorExpandConverter,
+               SparseTensorCompressConverter, SparseTensorOutConverter>(
+      typeConverter, patterns.getContext());
+  patterns.add<SparseTensorConvertConverter>(typeConverter,
+                                             patterns.getContext(), options);
 }
index 2d8b858..2124aec 100644 (file)
@@ -73,6 +73,13 @@ public:
 
 struct SparseTensorConversionPass
     : public SparseTensorConversionBase<SparseTensorConversionPass> {
+
+  SparseTensorConversionPass() = default;
+  SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
+  SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
+    sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
+  }
+
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
@@ -106,11 +113,14 @@ struct SparseTensorConversionPass
     target
         .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
                          memref::MemRefDialect, scf::SCFDialect>();
+    // Translate strategy flags to strategy options.
+    SparseTensorConversionOptions options(
+        sparseToSparseConversionStrategy(sparseToSparse));
     // Populate with rules and apply rewriting rules.
     populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
                                                              converter);
     populateCallOpTypeConversionPattern(patterns, converter);
-    populateSparseTensorConversionPatterns(converter, patterns);
+    populateSparseTensorConversionPatterns(converter, patterns, options);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();
@@ -146,6 +156,18 @@ SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) {
   }
 }
 
+SparseToSparseConversionStrategy
+mlir::sparseToSparseConversionStrategy(int32_t flag) {
+  switch (flag) {
+  default:
+    return SparseToSparseConversionStrategy::kAuto;
+  case 1:
+    return SparseToSparseConversionStrategy::kViaCOO;
+  case 2:
+    return SparseToSparseConversionStrategy::kDirect;
+  }
+}
+
 std::unique_ptr<Pass> mlir::createSparsificationPass() {
   return std::make_unique<SparsificationPass>();
 }
@@ -158,3 +180,8 @@ mlir::createSparsificationPass(const SparsificationOptions &options) {
 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
   return std::make_unique<SparseTensorConversionPass>();
 }
+
+std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
+    const SparseTensorConversionOptions &options) {
+  return std::make_unique<SparseTensorConversionPass>(options);
+}