vectorLength, enableSIMDIndex32);
}
+ // These options must be kept in sync with `SparseTensorConversionBase`.
+ PassOptions::Option<int32_t> sparseToSparse{
+ *this, "s2s-strategy",
+ desc("Set the strategy for sparse-to-sparse conversion"), init(0)};
+
+ /// Projects out the options for `createSparseTensorConversionPass`.
+ SparseTensorConversionOptions sparseTensorConversionOptions() const {
+ return SparseTensorConversionOptions(
+ sparseToSparseConversionStrategy(sparseToSparse));
+ }
+
// These options must be kept in sync with `ConvertVectorToLLVMBase`.
// TODO(wrengr): does `indexOptimizations` differ from `enableSIMDIndex32`?
PassOptions::Option<bool> reassociateFPReductions{
//
// This header file defines prototypes of all sparse tensor passes.
//
+// In general, this file takes the approach of keeping "mechanism" (the
+// actual steps of applying a transformation) completely separate from
+// "policy" (heuristics for when and where to apply transformations).
+// The only exception is `SparseToSparseConversionStrategy`; see the
+// discussion at its definition for the rationale.
+//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES_H_
// Forward.
class TypeConverter;
+//===----------------------------------------------------------------------===//
+// The Sparsification pass.
+//===----------------------------------------------------------------------===//
+
/// Defines a parallelization strategy. Any independent loop is a candidate
/// for parallelization. The loop is made parallel if (1) allowed by the
/// strategy (e.g., AnyStorageOuterLoop considers either a dense or sparse
/// Converts command-line vectorization flag to the strategy enum.
SparseVectorizationStrategy sparseVectorizationStrategy(int32_t flag);
-/// Sparsification options.
+/// Options for the Sparsification pass.
struct SparsificationOptions {
SparsificationOptions(SparseParallelizationStrategy p,
SparseVectorizationStrategy v, unsigned vl, bool e)
RewritePatternSet &patterns,
const SparsificationOptions &options = SparsificationOptions());
-/// Sets up sparse tensor conversion rules.
-void populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
- RewritePatternSet &patterns);
-
std::unique_ptr<Pass> createSparsificationPass();
std::unique_ptr<Pass>
createSparsificationPass(const SparsificationOptions &options);
+
+//===----------------------------------------------------------------------===//
+// The SparseTensorConversion pass.
+//===----------------------------------------------------------------------===//
+
+/// Defines a strategy for implementing sparse-to-sparse conversion.
+/// `kAuto` leaves it up to the compiler to automatically determine
+/// the method used. `kViaCOO` converts the source tensor to COO and
+/// then converts the COO to the target format. `kDirect` converts
+/// directly via the algorithm in <https://arxiv.org/abs/2001.02609>;
+/// however, beware that there are many formats not supported by this
+/// conversion method.
+///
+/// The presence of the `kAuto` option violates our usual goal of keeping
+/// policy completely separated from mechanism. The reason it exists is
+/// because (at present) this strategy can only be specified on a per-file
+/// basis. To see why this is a problem, note that `kDirect` cannot
+/// support certain conversions; so if there is no `kAuto` setting,
+/// then whenever a file contains a single non-`kDirect`-able conversion
+/// the user would be forced to use `kViaCOO` for all conversions in
+/// that file! In the future, instead of using this enum as a `Pass`
+/// option, we could instead move it to being an attribute on the
+/// conversion op; at which point `kAuto` would no longer be necessary.
+enum class SparseToSparseConversionStrategy { kAuto, kViaCOO, kDirect };
+
+/// Converts command-line sparse2sparse flag to the strategy enum.
+SparseToSparseConversionStrategy sparseToSparseConversionStrategy(int32_t flag);
+
+/// SparseTensorConversion options.
+struct SparseTensorConversionOptions {
+ /// Constructs options with the given sparse-to-sparse strategy.
+ SparseTensorConversionOptions(SparseToSparseConversionStrategy s2s)
+ : sparseToSparseStrategy(s2s) {}
+ /// Default-constructs with `kAuto`, deferring the choice to the compiler.
+ SparseTensorConversionOptions()
+ : SparseTensorConversionOptions(SparseToSparseConversionStrategy::kAuto) {
+ }
+ /// Strategy used when lowering sparse-to-sparse `ConvertOp`s.
+ SparseToSparseConversionStrategy sparseToSparseStrategy;
+};
+
+/// Sets up sparse tensor conversion rules.
+void populateSparseTensorConversionPatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    const SparseTensorConversionOptions &options =
+        SparseTensorConversionOptions());
+
std::unique_ptr<Pass> createSparseTensorConversionPass();
+/// As above, but constructs the pass with the given conversion options
+/// rather than the defaults.
+std::unique_ptr<Pass>
+createSparseTensorConversionPass(const SparseTensorConversionOptions &options);
//===----------------------------------------------------------------------===//
// Registration.
"sparse_tensor::SparseTensorDialect",
"vector::VectorDialect",
];
+ let options = [
+ // Must be kept in sync with `sparseToSparseConversionStrategy`, which
+ // parses this flag into the strategy enum.
+ Option<"sparseToSparse", "s2s-strategy", "int32_t", "0",
+ "Set the strategy for sparse-to-sparse conversion">,
+ ];
}
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
pm.addNestedPass<FuncOp>(createLinalgGeneralizationPass());
pm.addPass(createLinalgElementwiseOpFusionPass());
pm.addPass(createSparsificationPass(options.sparsificationOptions()));
- pm.addPass(createSparseTensorConversionPass());
+ pm.addPass(createSparseTensorConversionPass(
+ options.sparseTensorConversionOptions()));
pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
pm.addNestedPass<FuncOp>(vector::createVectorBufferizePass());
pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
+ /// Options to control sparse code generation.
+ SparseTensorConversionOptions options;
+
+public:
using OpConversionPattern::OpConversionPattern;
+ /// Constructs the pattern with the given context and conversion options.
+ SparseTensorConvertConverter(MLIRContext *context,
+ SparseTensorConversionOptions o)
+ : OpConversionPattern<ConvertOp>(context), options(o) {}
+ /// Constructs the pattern with the given type converter, context, and
+ /// conversion options.
+ SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
+ SparseTensorConversionOptions o)
+ : OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}
+
LogicalResult
matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
/// Populates the given patterns list with the conversion rules required
/// to lower sparse tensor types and operations.
-void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
+void mlir::populateSparseTensorConversionPatterns(
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
+ const SparseTensorConversionOptions &options) {
patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
SparseCastConverter, SparseTensorNewConverter,
- SparseTensorInitConverter, SparseTensorConvertConverter,
- SparseTensorReleaseConverter, SparseTensorToPointersConverter,
- SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
- SparseTensorLoadConverter, SparseTensorLexInsertConverter,
- SparseTensorExpandConverter, SparseTensorCompressConverter,
- SparseTensorOutConverter>(typeConverter, patterns.getContext());
+ SparseTensorInitConverter, SparseTensorReleaseConverter,
+ SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
+ SparseTensorToValuesConverter, SparseTensorLoadConverter,
+ SparseTensorLexInsertConverter, SparseTensorExpandConverter,
+ SparseTensorCompressConverter, SparseTensorOutConverter>(
+ typeConverter, patterns.getContext());
+ // `SparseTensorConvertConverter` is registered separately since it is
+ // the only pattern that needs the conversion options.
+ patterns.add<SparseTensorConvertConverter>(typeConverter,
+ patterns.getContext(), options);
}
struct SparseTensorConversionPass
: public SparseTensorConversionBase<SparseTensorConversionPass> {
+
+ SparseTensorConversionPass() = default;
+ SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
+ /// Initializes the `sparseToSparse` pass option from the given
+ /// strategy options.
+ SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
+ sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
+ }
+
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
target
.addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
memref::MemRefDialect, scf::SCFDialect>();
+ // Translate strategy flags to strategy options.
+ SparseTensorConversionOptions options(
+ sparseToSparseConversionStrategy(sparseToSparse));
// Populate with rules and apply rewriting rules.
populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
converter);
populateCallOpTypeConversionPattern(patterns, converter);
- populateSparseTensorConversionPatterns(converter, patterns);
+ populateSparseTensorConversionPatterns(converter, patterns, options);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
}
+SparseToSparseConversionStrategy
+mlir::sparseToSparseConversionStrategy(int32_t flag) {
+ switch (flag) {
+ default:
+ // Any unrecognized flag value (including the option default, 0)
+ // selects the automatic strategy.
+ return SparseToSparseConversionStrategy::kAuto;
+ case 1:
+ return SparseToSparseConversionStrategy::kViaCOO;
+ case 2:
+ return SparseToSparseConversionStrategy::kDirect;
+ }
+}
+
std::unique_ptr<Pass> mlir::createSparsificationPass() {
return std::make_unique<SparsificationPass>();
}
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
return std::make_unique<SparseTensorConversionPass>();
}
+
+/// As above, but constructs the pass with the given conversion options
+/// rather than the defaults.
+std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
+    const SparseTensorConversionOptions &options) {
+  return std::make_unique<SparseTensorConversionPass>(options);
+}