const SparseTensorConversionOptions &options =
SparseTensorConversionOptions());
-std::unique_ptr<Pass> createDenseBufferizationPass(
- const bufferization::OneShotBufferizationOptions &options);
std::unique_ptr<Pass> createSparseTensorConversionPass();
std::unique_ptr<Pass>
createSparseTensorConversionPass(const SparseTensorConversionOptions &options);
//===----------------------------------------------------------------------===//
+// Other rewriting rules and passes.
+//===----------------------------------------------------------------------===//
+
+void populateSparseTensorRewriting(RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createDenseBufferizationPass(
+ const bufferization::OneShotBufferizationOptions &options);
+
+//===----------------------------------------------------------------------===//
// Registration.
//===----------------------------------------------------------------------===//
LinalgStrategyPasses.cpp
NamedOpConversions.cpp
Promotion.cpp
- SparseTensorRewriting.cpp
Split.cpp
SplitReduction.cpp
Tiling.cpp
// Add elementwise op fusion patterns.
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
-
populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
- // Add the sparse tensor rewriting patterns.
- populateSparseTensorRewriting(patterns);
-
// General canonicalization patterns.
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
GenericOp::getCanonicalizationPatterns(patterns, context);
OpPassManager &pm, const SparseCompilerOptions &options) {
// TODO(wrengr): ensure the original `pm` is for ModuleOp
pm.addNestedPass<func::FuncOp>(createLinalgGeneralizationPass());
- // TODO(springerm): Reactivate element-wise op fusion pass. This pass does not
- // fit well with bufferization because it replaces unused "out" operands of
- // LinalgOps with InitTensorOps. This would result in additional buffer
- // allocations during bufferization.
- // pm.addPass(createLinalgElementwiseOpFusionPass());
pm.addPass(
bufferization::createTensorCopyInsertionPass(getBufferizationOptions(
/*analysisOnly=*/options.testBufferizationAnalysisOnly)));
Sparsification.cpp
SparseTensorConversion.cpp
SparseTensorPasses.cpp
+ SparseTensorRewriting.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
void runOnOperation() override {
auto *ctx = &getContext();
- RewritePatternSet patterns(ctx);
+ // Apply pre-rewriting.
+ RewritePatternSet prePatterns(ctx);
+ populateSparseTensorRewriting(prePatterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
// Translate strategy flags to strategy options.
SparsificationOptions options(
sparseParallelizationStrategy(parallelization),
sparseVectorizationStrategy(vectorization), vectorLength,
enableSIMDIndex32, enableVLAVectorization);
- // Apply rewriting.
+ // Apply sparsification and vector cleanup rewriting.
+ RewritePatternSet patterns(ctx);
populateSparsificationPatterns(patterns, options);
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//
//===----------------------------------------------------------------------===//
//
-// This file implements linalg dialect rewriting specific to sparse tensors.
-//
-// Sparsity should be mostly transparent to the linalg dialect optimizations
-// (i.e., the dense and sparse take the same path). However, in some cases,
-// optimizations only make sense in the context of sparse tensors. This file
-// implements such sparsity specific rewriting rules.
+// This file implements rewriting rules that are specific to sparse tensors.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
//===---------------------------------------------------------------------===//
namespace {
+
/// Rewriting rule that converts two kernels:
///
/// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
/// a fusion may actually reduce the asymptotic complexity of the kernel,
/// since intermediate results may be nullified.
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
+public:
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp op,
mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
}
};
+
+/// Sparse rewriting rule for reshape operator.
+template <typename ReshapeOp>
+struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
+public:
+ using OpRewritePattern<ReshapeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ReshapeOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ auto encDst = getSparseTensorEncoding(op.getResult().getType());
+ auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
+ // Since a pure dense expansion is very cheap (change of view), for
+ // a sparse2dense or dense2sparse, we can simply unfuse a sparse
+ // conversion from the reshape operation itself.
+ // All other cases are handled elsewhere.
+ if (encDst && encSrc) {
+ return failure();
+ } else if (encSrc) {
+ RankedTensorType rtp =
+ op.getSrc().getType().template cast<RankedTensorType>();
+ auto denseTp =
+ RankedTensorType::get(rtp.getShape(), rtp.getElementType());
+ auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
+ op->setOperand(0, convert);
+ return success();
+ } else if (encDst) {
+ RankedTensorType rtp =
+ op.getResult().getType().template cast<RankedTensorType>();
+ auto denseTp =
+ RankedTensorType::get(rtp.getShape(), rtp.getElementType());
+ auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
+ op.getReassociation());
+ Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
+ rewriter.replaceOp(op, convert);
+ return success();
+ }
+ return failure();
+ }
+};
+
} // namespace
//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//
-void mlir::linalg::populateSparseTensorRewriting(RewritePatternSet &patterns) {
- auto *context = patterns.getContext();
- patterns.add<FuseSparseMultiplyOverAdd>(context);
+void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
+ // TODO(springerm): enable FuseSparseMultiplyOverAdd
+ patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
+ ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
}
SparsificationOptions options;
};
-/// Sparse rewriting rule for reshape operator.
-template <typename ReshapeOp>
-struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
-public:
- using OpRewritePattern<ReshapeOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ReshapeOp op,
- PatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
- auto encDst = getSparseTensorEncoding(op.getResult().getType());
- auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
- // Since a pure dense expansion is very cheap (change of view), for
- // a sparse2dense or dense2sparse, we can simply unfuse a sparse
- // conversion from the reshape operation itself.
- // All other cases are handled elsewhere.
- if (encDst && encSrc) {
- return failure();
- } else if (encSrc) {
- RankedTensorType rtp =
- op.getSrc().getType().template cast<RankedTensorType>();
- auto denseTp =
- RankedTensorType::get(rtp.getShape(), rtp.getElementType());
- auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
- op->setOperand(0, convert);
- return success();
- } else if (encDst) {
- RankedTensorType rtp =
- op.getResult().getType().template cast<RankedTensorType>();
- auto denseTp =
- RankedTensorType::get(rtp.getShape(), rtp.getElementType());
- auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
- op.getReassociation());
- Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
- rewriter.replaceOp(op, convert);
- return success();
- }
- return failure();
- }
-};
-
} // namespace
/// Populates the given patterns list with rewriting rules required for
void mlir::populateSparsificationPatterns(
RewritePatternSet &patterns, const SparsificationOptions &options) {
patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
- patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
- ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
}
":SparseTensorDialect",
":SparseTensorPassIncGen",
":SparseTensorUtils",
+ ":Support",
":TensorDialect",
":Transforms",
":VectorDialect",