Add a test case for `applyPatternsAndFoldGreedily` to support the revert of 59bbc7a08
authorMehdi Amini <joker.eph@gmail.com>
Fri, 1 Apr 2022 05:52:25 +0000 (05:52 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Fri, 1 Apr 2022 06:17:07 +0000 (06:17 +0000)
This shows that pushing constant to the right in a commutative op leads
to `applyPatternsAndFoldGreedily` to converge without applying all the
patterns.

Differential Revision: https://reviews.llvm.org/D122870

mlir/test/Transforms/test-operation-folder-commutative.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestPatterns.cpp

diff --git a/mlir/test/Transforms/test-operation-folder-commutative.mlir b/mlir/test/Transforms/test-operation-folder-commutative.mlir
new file mode 100644 (file)
index 0000000..9640e79
--- /dev/null
@@ -0,0 +1,11 @@
+// RUN: mlir-opt --pass-pipeline="func.func(test-patterns)" %s | FileCheck %s
+
+// CHECK-LABEL: func @test_reorder_constants_and_match
+func @test_reorder_constants_and_match(%arg0 : i32) -> (i32) {
+  // CHECK: %[[CST:.+]] = arith.constant 43
+  %cst = arith.constant 43 : i32
+  // CHECK: return %[[CST]]
+  %y = "test.op_commutative2"(%cst, %arg0) : (i32, i32) -> i32
+  %x = "test.op_commutative2"(%y, %arg0) : (i32, i32) -> i32
+  return %x : i32
+}
index 5bb397353af28b41c787fdca208b9205b2db2e59..b157d5da742a57ca998b6d063dbd06985eb8646a 100644 (file)
@@ -1089,6 +1089,11 @@ def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
   let results = (outs I32);
 }
 
+def TestCommutative2Op : TEST_Op<"op_commutative2", [Commutative]> {
+  let arguments = (ins I32:$op1, I32:$op2);
+  let results = (outs I32);
+}
+
 def TestIdempotentTraitOp
  : TEST_Op<"op_idempotent_trait",
            [SameOperandsAndResultType, NoSideEffect, Idempotent]> {
index 28a988412f69eeda3220d88b08922bf90dc4dfda..bd80ea793e59740f650e11e9ece3cccadf2c21cc 100644 (file)
@@ -124,18 +124,40 @@ public:
   }
 };
 
+/// This pattern matches test.op_commutative2 with the first operand being
+/// another test.op_commutative2 with a constant on the right side and fold it
+/// away by propagating it as its result. This is intend to check that patterns
+/// are applied after the commutative property moves constant to the right.
+struct FolderCommutativeOp2WithConstant
+    : public OpRewritePattern<TestCommutative2Op> {
+public:
+  using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TestCommutative2Op op,
+                                PatternRewriter &rewriter) const override {
+    auto operand =
+        dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
+    if (!operand)
+      return failure();
+    Attribute constInput;
+    if (!matchPattern(operand->getOperand(1), m_Constant(&constInput)))
+      return failure();
+    rewriter.replaceOp(op, operand->getOperand(1));
+    return success();
+  }
+};
+
 struct TestPatternDriver
     : public PassWrapper<TestPatternDriver, OperationPass<FuncOp>> {
   StringRef getArgument() const final { return "test-patterns"; }
   StringRef getDescription() const final { return "Run test dialect patterns"; }
   void runOnOperation() override {
     mlir::RewritePatternSet patterns(&getContext());
-    populateWithGenerated(patterns);
 
     // Verify named pattern is generated with expected name.
     patterns.add<FoldingPattern, TestNamedPatternRule,
-                 FolderInsertBeforePreviouslyFoldedConstantPattern>(
-        &getContext());
+                 FolderInsertBeforePreviouslyFoldedConstantPattern,
+                 FolderCommutativeOp2WithConstant>(&getContext());
 
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }