This is to help simplify FileCheck tests for sparse-tensor-rewrite.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D136093
std::unique_ptr<Pass> createSparseTensorCodegenPass();
//===----------------------------------------------------------------------===//
-// Other rewriting rules and passes.
+// The SparseTensorRewriting pass.
//===----------------------------------------------------------------------===//
-void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT);
+void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT,
+ bool enableForeach, bool enableConvert);
std::unique_ptr<Pass> createSparseTensorRewritePass();
std::unique_ptr<Pass>
-createSparseTensorRewritePass(const SparsificationOptions &options);
+createSparseTensorRewritePass(const SparsificationOptions &options,
+ bool enableForeach = true,
+ bool enableConvert = true);
+
+//===----------------------------------------------------------------------===//
+// Other rewriting rules and passes.
+//===----------------------------------------------------------------------===//
std::unique_ptr<Pass> createDenseBufferizationPass(
const bufferization::OneShotBufferizationOptions &options);
];
let options = [
Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
- "true", "Enable runtime library for manipulating sparse tensors">
+ "true", "Enable runtime library for manipulating sparse tensors">,
+ Option<"enableForeach", "enable-foreach", "bool",
+ "true", "Enable rewriting rules for the foreach operator">,
+ Option<"enableConvert", "enable-convert", "bool",
+ "true", "Enable rewriting rules for the convert operator">,
];
}
SparseTensorRewritePass() = default;
SparseTensorRewritePass(const SparseTensorRewritePass &pass) = default;
- SparseTensorRewritePass(const SparsificationOptions &options) {
+ SparseTensorRewritePass(const SparsificationOptions &options, bool foreach,
+ bool convert) {
enableRuntimeLibrary = options.enableRuntimeLibrary;
+ enableForeach = foreach;
+ enableConvert = convert;
}
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
- populateSparseTensorRewriting(patterns, enableRuntimeLibrary);
+ populateSparseTensorRewriting(patterns, enableRuntimeLibrary, enableForeach,
+ enableConvert);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
}
std::unique_ptr<Pass>
-mlir::createSparseTensorRewritePass(const SparsificationOptions &options) {
- return std::make_unique<SparseTensorRewritePass>(options);
+mlir::createSparseTensorRewritePass(const SparsificationOptions &options,
+ bool enableForeach, bool enableConvert) {
+ return std::make_unique<SparseTensorRewritePass>(options, enableForeach,
+ enableConvert);
}
std::unique_ptr<Pass> mlir::createSparsificationPass() {
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//
void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
- bool enableRT) {
+ bool enableRT, bool enableForeach,
+ bool /*enableConvert*/) {
patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd,
ReshapeRewriter<tensor::ExpandShapeOp>,
- ReshapeRewriter<tensor::CollapseShapeOp>, ForeachRewriter>(
- patterns.getContext());
+ ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
+ if (enableForeach)
+ patterns.add<ForeachRewriter>(patterns.getContext());
+
// TODO: If RT not enabled, rewrite concatenate ops, etc here.
if (!enableRT)
patterns.add<ConcatenateRewriter, NewRewriter,