return "Test Linalg element wise operation fusion patterns";
}
- void runOnOperation() override {
- MLIRContext *context = &this->getContext();
- FuncOp funcOp = this->getOperation();
- RewritePatternSet fusionPatterns(context);
-
- linalg::populateElementwiseOpsFusionPatterns(
- fusionPatterns,
- linalg::LinalgElementwiseFusionOptions()
- .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
-
- (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
- std::move(fusionPatterns));
- }
-};
-
-struct TestLinalgControlFuseByExpansion
- : public PassWrapper<TestLinalgControlFuseByExpansion,
- OperationPass<FuncOp>> {
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry
- .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
- }
- StringRef getArgument() const final {
- return "test-linalg-control-fusion-by-expansion";
- }
- StringRef getDescription() const final {
- return "Test controlling of fusion of elementwise ops with reshape by "
- "expansion";
- }
+ Option<bool>
+ fuseGenericOps(*this, "fuse-generic-ops",
+ llvm::cl::desc("Test fusion of generic operations."),
+ llvm::cl::init(false));
+
+ Option<bool> controlFuseByExpansion(
+ *this, "control-fusion-by-expansion",
+ llvm::cl::desc(
+ "Test controlling fusion of reshape with generic op by expansion"),
+ llvm::cl::init(false));
+
+ Option<bool>
+ pushExpandingReshape(*this, "push-expanding-reshape",
+ llvm::cl::desc("Test linalg expand_shape -> generic "
+ "to generic -> expand_shape pattern"),
+ llvm::cl::init(false));
void runOnOperation() override {
MLIRContext *context = &this->getContext();
FuncOp funcOp = this->getOperation();
- RewritePatternSet fusionPatterns(context);
-
- linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
- [](const OpResult &producer, OpOperand &consumer) {
- if (auto collapseOp =
- producer.getDefiningOp<tensor::CollapseShapeOp>()) {
- if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
- return false;
+
+ if (fuseGenericOps) {
+ RewritePatternSet fusionPatterns(context);
+ linalg::populateElementwiseOpsFusionPatterns(
+ fusionPatterns,
+ linalg::LinalgElementwiseFusionOptions()
+ .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
+
+ (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
+ std::move(fusionPatterns));
+ return;
+ }
+
+ if (controlFuseByExpansion) {
+ RewritePatternSet fusionPatterns(context);
+
+ linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
+ [](const OpResult &producer, OpOperand &consumer) {
+ if (auto collapseOp =
+ producer.getDefiningOp<tensor::CollapseShapeOp>()) {
+ if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
+ return false;
+ }
}
- }
- if (auto expandOp =
- dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
- if (expandOp->hasOneUse()) {
- OpOperand &use = *expandOp->getUses().begin();
- auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
- if (linalgOp && linalgOp.isOutputTensor(&use))
- return true;
+ if (auto expandOp =
+ dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
+ if (expandOp->hasOneUse()) {
+ OpOperand &use = *expandOp->getUses().begin();
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
+ if (linalgOp && linalgOp.isOutputTensor(&use))
+ return true;
+ }
}
- }
- return linalg::skipUnitDimReshape(producer, consumer);
- };
-
- linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
- controlReshapeFusionFn);
- (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
- std::move(fusionPatterns));
+ return linalg::skipUnitDimReshape(producer, consumer);
+ };
+
+ linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
+ controlReshapeFusionFn);
+ (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
+ std::move(fusionPatterns));
+ return;
+ }
+
+ if (pushExpandingReshape) {
+ RewritePatternSet patterns(context);
+ linalg::populatePushReshapeOpsPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+ }
}
};
-struct TestPushExpandingReshape
- : public PassWrapper<TestPushExpandingReshape, OperationPass<FuncOp>> {
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry
- .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
- }
- StringRef getArgument() const final { return "test-linalg-push-reshape"; }
- StringRef getDescription() const final {
- return "Test Linalg reshape push patterns";
- }
-
- void runOnOperation() override {
- MLIRContext *context = &this->getContext();
- FuncOp funcOp = this->getOperation();
- RewritePatternSet patterns(context);
- linalg::populatePushReshapeOpsPatterns(patterns);
- (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
- }
-};
} // namespace
namespace test {
void registerTestGenericIRVisitorsInterruptPass();
void registerTestInterfaces();
void registerTestLinalgCodegenStrategy();
-void registerTestLinalgControlFuseByExpansion();
void registerTestLinalgDistribution();
void registerTestLinalgElementwiseFusion();
-void registerTestPushExpandingReshape();
void registerTestLinalgFusionTransforms();
void registerTestLinalgTensorFusionTransforms();
void registerTestLinalgTiledLoopFusionTransforms();
mlir::test::registerTestGenericIRVisitorsPass();
mlir::test::registerTestInterfaces();
mlir::test::registerTestLinalgCodegenStrategy();
- mlir::test::registerTestLinalgControlFuseByExpansion();
mlir::test::registerTestLinalgDistribution();
mlir::test::registerTestLinalgElementwiseFusion();
- mlir::test::registerTestPushExpandingReshape();
mlir::test::registerTestLinalgFusionTransforms();
mlir::test::registerTestLinalgTensorFusionTransforms();
mlir::test::registerTestLinalgTiledLoopFusionTransforms();