[DRR] Introduce `$_` to ignore op argument match
authorLei Zhang <antiagainst@google.com>
Mon, 2 Dec 2019 15:54:23 +0000 (07:54 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 2 Dec 2019 15:54:50 +0000 (07:54 -0800)
Right now op argument matching in DRR is position-based, meaning we need to
specify N arguments for an op with N ODS-declared argument. This can be annoying
when we don't want to capture all the arguments. `$_` is to remedy the situation.

PiperOrigin-RevId: 283339992

mlir/g3doc/DeclarativeRewrites.md
mlir/lib/TableGen/Pattern.cpp
mlir/test/lib/TestDialect/TestOps.td
mlir/test/mlir-tblgen/pattern.mlir
mlir/tools/mlir-tblgen/RewriterGen.cpp

index 2d9fb5b..5117ebd 100644 (file)
@@ -144,6 +144,13 @@ Also note that we only need to add `TypeConstraint` or `AttributeConstraint`
 when we need to further limit the match criteria. If all valid cases to the op
 are acceptable, then we can leave the constraint unspecified.
 
+`$_` is a special symbol to mean ignore capturing an argument. For example,
+`def : Pat<(AOp $_, $b), ...>` means only `$b` is interesting to capture and
+will be referenced later in result patterns. It's still possible to place
+additional constraints even if the symbol is not to be captured; for such case,
+you can simply use just the `TypeConstraint` or `AttributeConstraint` without a
+bound symbol, for example, `def : Pat<(AOp $a, F32Attr), ...>`.
+
 #### Matching DAG of operations
 
 To match an DAG of ops, use nested `dag` objects:
index d3c1ddd..098dba3 100644 (file)
@@ -564,7 +564,8 @@ void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
       // We can only bind symbols to op arguments in source pattern. Those
       // symbols are referenced in result patterns.
       auto treeArgName = tree.getArgName(i);
-      if (!treeArgName.empty()) {
+      // `$_` is a special symbol meaning ignore the current argument.
+      if (!treeArgName.empty() && treeArgName != "_") {
         LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
                                 << treeArgName << '\n');
         if (!infoMap.bindOpArgument(treeArgName, op, i)) {
index 6bb0cbc..6952eaa 100644 (file)
@@ -479,6 +479,18 @@ def OpJ : TEST_Op<"op_j">, Arguments<(ins)>, Results<(outs I32)>;
 def OpK : TEST_Op<"op_k">, Arguments<(ins)>, Results<(outs I32)>;
 def : Pat<(OpJ), (OpK)>;
 
+// Test `$_` for ignoring op argument match.
+def TestIgnoreArgMatchSrcOp : TEST_Op<"ignore_arg_match_src"> {
+  let arguments = (ins
+    AnyType:$a, AnyType:$b, AnyType:$c,
+    AnyAttr:$d, AnyAttr:$e, AnyAttr:$f);
+}
+def TestIgnoreArgMatchDstOp : TEST_Op<"ignore_arg_match_dst"> {
+  let arguments = (ins AnyType:$b, AnyAttr:$f);
+}
+def : Pat<(TestIgnoreArgMatchSrcOp $_, $b, I32, I64Attr:$_, $_, $f),
+          (TestIgnoreArgMatchDstOp $b, $f)>;
+
 def OpInterleavedOperandAttribute1 : TEST_Op<"interleaved_operand_attr1"> {
   let arguments = (ins
     I32:$input1,
index 1f6da05..67ea2fd 100644 (file)
@@ -24,6 +24,22 @@ func @verifyZeroArg() -> i32 {
   return %0 : i32
 }
 
+// CHECK-LABEL: testIgnoreArgMatch
+// CHECK-SAME: (%{{[a-z0-9]*}}: i32, %[[ARG1:[a-z0-9]*]]: i32
+func @testIgnoreArgMatch(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: f32) {
+  // CHECK: "test.ignore_arg_match_dst"(%[[ARG1]]) {f = 15 : i64}
+  "test.ignore_arg_match_src"(%arg0, %arg1, %arg2) {d = 42, e = 24, f = 15} : (i32, i32, i32) -> ()
+
+  // CHECK: test.ignore_arg_match_src
+  // Not match because wrong type for $c.
+  "test.ignore_arg_match_src"(%arg0, %arg1, %arg3) {d = 42, e = 24, f = 15} : (i32, i32, f32) -> ()
+
+  // CHECK: test.ignore_arg_match_src
+  // Not match because wrong type for $f.
+  "test.ignore_arg_match_src"(%arg0, %arg1, %arg2) {d = 42 : i32, e = 24, f = 15} : (i32, i32, i32) -> ()
+  return
+}
+
 // CHECK-LABEL: verifyInterleavedOperandAttribute
 // CHECK-SAME:    %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
 func @verifyInterleavedOperandAttribute(%arg0: i32, %arg1: i32) {
index d2776e0..a2cace7 100644 (file)
@@ -315,7 +315,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
 
   // Capture the value
   auto name = tree.getArgName(argIndex);
-  if (!name.empty()) {
+  // `$_` is a special symbol to ignore op argument matching.
+  if (!name.empty() && name != "_") {
     // We need to subtract the number of attributes before this operand to get
     // the index in the operand list.
     auto numPrevAttrs = std::count_if(
@@ -329,6 +330,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
 
 void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
                                         int indent) {
+
   Operator &op = tree.getDialectOp(opMap);
   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
   const auto &attr = namedAttr->attr;
@@ -371,7 +373,8 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
 
   // Capture the value
   auto name = tree.getArgName(argIndex);
-  if (!name.empty()) {
+  // `$_` is a special symbol to ignore op argument matching.
+  if (!name.empty() && name != "_") {
     os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
   }