Change pattern test to use TestDialect instead.
authorJacques Pienaar <jpienaar@google.com>
Tue, 28 May 2019 03:04:56 +0000 (20:04 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:04:42 +0000 (20:04 -0700)
    Verify pattern specification, added benefit, named pattern and location recording using TestDialect. Naming is verified via explicitly adding named pattern to TestPatternDriver pass. Refactoring test to verify the desired functionality rather than generated code.

--

PiperOrigin-RevId: 250205618

mlir/test/TestDialect/TestOps.td
mlir/test/TestDialect/TestPatterns.cpp
mlir/test/mlir-tblgen/pattern-benefit.td [deleted file]
mlir/test/mlir-tblgen/pattern.mlir [new file with mode: 0644]
mlir/test/mlir-tblgen/pattern.td [deleted file]

index 65a21b8..9a7ead7 100644 (file)
@@ -77,4 +77,40 @@ def SameOperandAndResultShapeOp : TEST_Op<"same_operand_and_result_shape",
   let results = (outs AnyVectorOrTensor:$res);
 }
 
+//===----------------------------------------------------------------------===//
+// Test Patterns
+//===----------------------------------------------------------------------===//
+
+def OpA : TEST_Op<"op_a"> {
+  let arguments = (ins I32:$operand, I32Attr:$attr);
+  let results = (outs I32:$result);
+}
+
+def OpB : TEST_Op<"op_b"> {
+  let arguments = (ins I32:$operand, I32Attr:$attr);
+  let results = (outs I32:$result);
+}
+
+// Test named pattern.
+def TestNamedPatternRule : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
+
+// Test with constant attr.
+def OpC : TEST_Op<"op_c">, Arguments<(ins I32:$arg)>, Results<(outs I32:$res)>;
+def : Pat<(OpC $input), (OpB $input, ConstantAttr<I32Attr, "17">:$attr)>;
+
+// Test with fused location.
+def : Pat<(OpA (OpA $input, $attr), $bttr), (OpB $input, $bttr)>;
+
+// Test added benefit.
+def OpD : TEST_Op<"op_d">, Arguments<(ins I32:$arg)>, Results<(outs I32:$res)>;
+def OpE : TEST_Op<"op_e">, Arguments<(ins I32:$arg)>, Results<(outs I32:$res)>;
+def OpF : TEST_Op<"op_f">, Arguments<(ins I32:$arg)>, Results<(outs I32:$res)>;
+def OpG : TEST_Op<"op_g">, Arguments<(ins I32:$arg)>, Results<(outs I32:$res)>;
+// Verify that bumping benefit results in selecting different op.
+def : Pat<(OpD $input), (OpE $input)>;
+def : Pat<(OpD $input), (OpF $input), [], (addBenefit 10)>;
+// Verify that patterns with more source nodes are selected before those with fewer.
+def : Pat<(OpG $input), (OpB $input, ConstantAttr<I32Attr, "20">:$attr)>;
+def : Pat<(OpG (OpG $input)), (OpB $input, ConstantAttr<I32Attr, "34">:$attr)>;
+
 #endif // TEST_OPS
index 1d5d398..a4c265b 100644 (file)
@@ -24,13 +24,19 @@ namespace {
 #include "TestPatterns.inc"
 
 struct TestPatternDriver : public FunctionPass<TestPatternDriver> {
-  void runOnFunction() {
-    mlir::OwningRewritePatternList patterns;
-    populateWithGenerated(&getContext(), &patterns);
-    applyPatternsGreedily(getFunction(), std::move(patterns));
-  }
+  void runOnFunction() override;
 };
 } // end anonymous namespace
 
+void TestPatternDriver::runOnFunction() {
+  mlir::OwningRewritePatternList patterns;
+  populateWithGenerated(&getContext(), &patterns);
+
+  // Verify named pattern is generated with expected name.
+  RewriteListBuilder<TestNamedPatternRule>::build(patterns, &getContext());
+
+  applyPatternsGreedily(getFunction(), std::move(patterns));
+}
+
 static mlir::PassRegistration<TestPatternDriver>
     pass("test-patterns", "Run test dialect patterns");
diff --git a/mlir/test/mlir-tblgen/pattern-benefit.td b/mlir/test/mlir-tblgen/pattern-benefit.td
deleted file mode 100644 (file)
index 36bc2c7..0000000
+++ /dev/null
@@ -1,34 +0,0 @@
-// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
-
-include "mlir/IR/OpBase.td"
-
-def IfEqual : Constraint<CPred<"<notused>">>;
-
-def Test_Dialect : Dialect {
-  let name = "x";
-}
-class NS_Op<string mnemonic, list<OpTrait> traits = []> :
-    Op<Test_Dialect, mnemonic, traits>;
-
-// Define ops to rewrite.
-def U: Type<CPred<"true">, "U">;
-def X_AddOp : NS_Op<"add"> {
-  let arguments = (ins U, U);
-}
-def Y_AddOp : NS_Op<"add"> {
-  let arguments = (ins U, U, U);
-}
-def Z_AddOp : NS_Op<"add"> {
-  let arguments = (ins U);
-}
-
-// Define rewrite patterns.
-def bena : Pat<(X_AddOp (X_AddOp $lhs, $rhs), $rhs), (Y_AddOp $lhs, $rhs, $rhs)>;
-
-// CHECK-LABEL: struct bena
-// CHECK: RewritePattern("x.add", {"x.add"}, 2, context) {}
-
-def benb : Pat<(X_AddOp $lhs, $rhs), (Z_AddOp $lhs), [(IfEqual $lhs, $rhs)], (addBenefit 100)>;
-
-// CHECK-LABEL: struct benb
-// CHECK: RewritePattern("x.add", {"x.add"}, 101, context) {}
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
new file mode 100644 (file)
index 0000000..22a9129
--- /dev/null
@@ -0,0 +1,30 @@
+// RUN: mlir-test-opt -test-patterns -mlir-print-debuginfo %s | FileCheck %s
+
+// CHECK-LABEL: verifyConstantAttr
+func @verifyConstantAttr(%arg0 : i32) -> i32 {
+  %0 = "test.op_c"(%arg0) : (i32) -> i32 loc("a")
+
+  // CHECK: "test.op_b"(%arg0) {attr: 17 : i32} : (i32) -> i32 loc("a")
+  return %0 : i32
+}
+
+// CHECK-LABEL: verifyFusedLocs
+func @verifyFusedLocs(%arg0 : i32) -> i32 {
+  %0 = "test.op_a"(%arg0) {attr: 10 : i32} : (i32) -> i32 loc("a")
+  %result = "test.op_a"(%0) {attr: 20 : i32} : (i32) -> i32 loc("b")
+
+  // CHECK: "test.op_b"(%arg0) {attr: 10 : i32} : (i32) -> i32 loc("a")
+  // CHECK: "test.op_b"(%arg0) {attr: 20 : i32} : (i32) -> i32 loc(fused["b", "a"])
+  return %result : i32
+}
+
+// CHECK-LABEL: verifyBenefit
+func @verifyBenefit(%arg0 : i32) -> i32 {
+  %0 = "test.op_d"(%arg0) : (i32) -> i32
+  %1 = "test.op_g"(%arg0) : (i32) -> i32
+  %2 = "test.op_g"(%1) : (i32) -> i32
+
+  // CHECK: "test.op_f"(%arg0)
+  // CHECK: "test.op_b"(%arg0) {attr: 34 : i32}
+  return %0 : i32
+}
\ No newline at end of file
diff --git a/mlir/test/mlir-tblgen/pattern.td b/mlir/test/mlir-tblgen/pattern.td
deleted file mode 100644 (file)
index d8e92cb..0000000
+++ /dev/null
@@ -1,60 +0,0 @@
-// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
-
-include "mlir/IR/OpBase.td"
-
-def Test_Dialect : Dialect {
-  let name = "";
-}
-class NS_Op<string mnemonic, list<OpTrait> traits> :
-    Op<Test_Dialect, mnemonic, traits>;
-
-def OpA : NS_Op<"op_a", []> {
-  let arguments = (ins I32:$operand, I32Attr:$attr);
-  let results = (outs I32:$result);
-}
-
-def OpB : NS_Op<"op_b", []> {
-  let arguments = (ins I32:$operand, I32Attr:$attr);
-  let results = (outs I32:$result);
-}
-
-def MyRule : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
-def MyRule2 : Pat<(OpA (OpA $input, $attr), $attr), (OpB $input, $attr)>;
-
-// Test rewrite rule naming
-// ---
-
-// CHECK: Generated from:
-// CHECK-NEXT: {{.*pattern.td.*}}
-// CHECK: struct MyRule : public RewritePattern
-
-// CHECK-LABEL: struct MyRule2 : public RewritePattern
-// CHECK: s.autogeneratedRewritePatternOps[0] = op0;
-// CHECK: s.autogeneratedRewritePatternOps[1] = op1;
-// CHECK: rewriter.getFusedLoc({
-// CHECK-SAME: s.autogeneratedRewritePatternOps[0]->getLoc()
-// CHECK-SAME: s.autogeneratedRewritePatternOps[1]->getLoc()
-
-def : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
-
-// Test basic structure generated from Pattern
-// ---
-
-// CHECK: struct GeneratedConvert0 : public RewritePattern
-
-// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("op_a", {"op_b"}, 1, context) {}
-
-// CHECK: struct MatchedState : public PatternState {
-// CHECK:   Value *input;
-// CHECK:   IntegerAttr attr;
-// CHECK: };
-
-// CHECK: PatternMatchResult match(Operation *op0) const override
-
-// CHECK: void rewrite(Operation *op, std::unique_ptr<PatternState> state,
-// CHECK:              PatternRewriter &rewriter) const override
-
-
-// CHECK: void populateWithGenerated(MLIRContext *context, OwningRewritePatternList *patterns)
-// CHECK:   patterns->push_back(llvm::make_unique<MyRule>(context));
-// CHECK:   patterns->push_back(llvm::make_unique<GeneratedConvert0>(context));