[mlir][Linalg] NFC: Combine elementwise fusion test passes.
authorMahesh Ravishankar <ravishankarm@google.com>
Mon, 7 Feb 2022 17:45:28 +0000 (17:45 +0000)
committerMahesh Ravishankar <ravishankarm@google.com>
Mon, 7 Feb 2022 22:46:57 +0000 (22:46 +0000)
There are a few different test passes that check elementwise fusion in
Linalg. Consolidate them to a single pass controlled by different pass
options (in keeping with how `TestLinalgTransforms` exists).

mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
mlir/test/Dialect/Linalg/reshape_control_fusion.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
mlir/tools/mlir-opt/mlir-opt.cpp

index d81aab6..103a04d 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file | FileCheck %s
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #binary2Dpointwise = {
index 0c02ff8..a1d4288 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-push-reshape -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=push-expanding-reshape -split-input-file | FileCheck %s
 
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
index d9e440c..c4e7d55 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-linalg-control-fusion-by-expansion %s -split-input-file | FileCheck %s
+// RUN: mlir-opt -test-linalg-elementwise-fusion-patterns=control-fusion-by-expansion %s -split-input-file | FileCheck %s
 
 func @control_producer_reshape_fusion(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index
index 30bef4a..85df325 100644 (file)
@@ -58,87 +58,77 @@ struct TestLinalgElementwiseFusion
     return "Test Linalg element wise operation fusion patterns";
   }
 
-  void runOnOperation() override {
-    MLIRContext *context = &this->getContext();
-    FuncOp funcOp = this->getOperation();
-    RewritePatternSet fusionPatterns(context);
-
-    linalg::populateElementwiseOpsFusionPatterns(
-        fusionPatterns,
-        linalg::LinalgElementwiseFusionOptions()
-            .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
-
-    (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
-                                       std::move(fusionPatterns));
-  }
-};
-
-struct TestLinalgControlFuseByExpansion
-    : public PassWrapper<TestLinalgControlFuseByExpansion,
-                         OperationPass<FuncOp>> {
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
-  }
-  StringRef getArgument() const final {
-    return "test-linalg-control-fusion-by-expansion";
-  }
-  StringRef getDescription() const final {
-    return "Test controlling of fusion of elementwise ops with reshape by "
-           "expansion";
-  }
+  Option<bool>
+      fuseGenericOps(*this, "fuse-generic-ops",
+                     llvm::cl::desc("Test fusion of generic operations."),
+                     llvm::cl::init(false));
+
+  Option<bool> controlFuseByExpansion(
+      *this, "control-fusion-by-expansion",
+      llvm::cl::desc(
+          "Test controlling fusion of reshape with generic op by expansion"),
+      llvm::cl::init(false));
+
+  Option<bool>
+      pushExpandingReshape(*this, "push-expanding-reshape",
+                           llvm::cl::desc("Test linalg expand_shape -> generic "
+                                          "to generic -> expand_shape pattern"),
+                           llvm::cl::init(false));
 
   void runOnOperation() override {
     MLIRContext *context = &this->getContext();
     FuncOp funcOp = this->getOperation();
-    RewritePatternSet fusionPatterns(context);
-
-    linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
-        [](const OpResult &producer, OpOperand &consumer) {
-          if (auto collapseOp =
-                  producer.getDefiningOp<tensor::CollapseShapeOp>()) {
-            if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
-              return false;
+
+    if (fuseGenericOps) {
+      RewritePatternSet fusionPatterns(context);
+      linalg::populateElementwiseOpsFusionPatterns(
+          fusionPatterns,
+          linalg::LinalgElementwiseFusionOptions()
+              .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
+
+      (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                         std::move(fusionPatterns));
+      return;
+    }
+
+    if (controlFuseByExpansion) {
+      RewritePatternSet fusionPatterns(context);
+
+      linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
+          [](const OpResult &producer, OpOperand &consumer) {
+            if (auto collapseOp =
+                    producer.getDefiningOp<tensor::CollapseShapeOp>()) {
+              if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
+                return false;
+              }
             }
-          }
-          if (auto expandOp =
-                  dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
-            if (expandOp->hasOneUse()) {
-              OpOperand &use = *expandOp->getUses().begin();
-              auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
-              if (linalgOp && linalgOp.isOutputTensor(&use))
-                return true;
+            if (auto expandOp =
+                    dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
+              if (expandOp->hasOneUse()) {
+                OpOperand &use = *expandOp->getUses().begin();
+                auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
+                if (linalgOp && linalgOp.isOutputTensor(&use))
+                  return true;
+              }
             }
-          }
-          return linalg::skipUnitDimReshape(producer, consumer);
-        };
-
-    linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
-                                                      controlReshapeFusionFn);
-    (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
-                                       std::move(fusionPatterns));
+            return linalg::skipUnitDimReshape(producer, consumer);
+          };
+
+      linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
+                                                        controlReshapeFusionFn);
+      (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                         std::move(fusionPatterns));
+      return;
+    }
+
+    if (pushExpandingReshape) {
+      RewritePatternSet patterns(context);
+      linalg::populatePushReshapeOpsPatterns(patterns);
+      (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+    }
   }
 };
 
-struct TestPushExpandingReshape
-    : public PassWrapper<TestPushExpandingReshape, OperationPass<FuncOp>> {
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
-  }
-  StringRef getArgument() const final { return "test-linalg-push-reshape"; }
-  StringRef getDescription() const final {
-    return "Test Linalg reshape push patterns";
-  }
-
-  void runOnOperation() override {
-    MLIRContext *context = &this->getContext();
-    FuncOp funcOp = this->getOperation();
-    RewritePatternSet patterns(context);
-    linalg::populatePushReshapeOpsPatterns(patterns);
-    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
-  }
-};
 } // namespace
 
 namespace test {
index 5b09cb8..73d1b54 100644 (file)
@@ -81,10 +81,8 @@ void registerTestGenericIRVisitorsPass();
 void registerTestGenericIRVisitorsInterruptPass();
 void registerTestInterfaces();
 void registerTestLinalgCodegenStrategy();
-void registerTestLinalgControlFuseByExpansion();
 void registerTestLinalgDistribution();
 void registerTestLinalgElementwiseFusion();
-void registerTestPushExpandingReshape();
 void registerTestLinalgFusionTransforms();
 void registerTestLinalgTensorFusionTransforms();
 void registerTestLinalgTiledLoopFusionTransforms();
@@ -172,10 +170,8 @@ void registerTestPasses() {
   mlir::test::registerTestGenericIRVisitorsPass();
   mlir::test::registerTestInterfaces();
   mlir::test::registerTestLinalgCodegenStrategy();
-  mlir::test::registerTestLinalgControlFuseByExpansion();
   mlir::test::registerTestLinalgDistribution();
   mlir::test::registerTestLinalgElementwiseFusion();
-  mlir::test::registerTestPushExpandingReshape();
   mlir::test::registerTestLinalgFusionTransforms();
   mlir::test::registerTestLinalgTensorFusionTransforms();
   mlir::test::registerTestLinalgTiledLoopFusionTransforms();