[PatternLang] Support any index matching for TupleGetItem (#5909)
authorCody Yu <comaniac0422@gmail.com>
Wed, 24 Jun 2020 18:44:23 +0000 (11:44 -0700)
committerGitHub <noreply@github.com>
Wed, 24 Jun 2020 18:44:23 +0000 (11:44 -0700)
* support any index matching

* update doc

docs/langref/relay_pattern.rst
python/tvm/relay/dataflow_pattern/__init__.py
src/relay/ir/dataflow_matcher.cc
tests/python/relay/test_dataflow_pattern.py

index b237f14..6cacff2 100644 (file)
@@ -148,7 +148,7 @@ Since there are not call nodes, we need to use specific pattern nodes to match t
         tuple_pattern = is_tuple((wildcard(), wildcard(), wildcard()))
         assert tuple_pattern.match(relay.expr.Tuple((x,y,z)))
 
-The next example is matching a pattern of batch_norm -> get(0) -> relu:
+The next example is matching a pattern of batch_norm -> get(0) -> relu. Note that you can also use `is_tuple_get_item(bn_node)` to match a `TupleGetItem` node with any index.
 
 .. code-block:: python
 
@@ -280,7 +280,7 @@ The high level design is to introduce a language of patterns for now we propose
             | is_expr(expr)
             | is_op(op_name)
             | is_tuple()
-            | is_tuple_get_item()
+            | is_tuple_get_item(pattern, index = None)
             | pattern1 `|` pattern2
             | dominates(parent_pattern, path_pattern, child_pattern)
 
index 915842c..317d28e 100644 (file)
@@ -293,7 +293,7 @@ def is_tuple(fields: tvm.ir.container.Array) -> "DFPattern":
     return TuplePattern(fields)
 
 
-def is_tuple_get_item(tuple_value: "DFPattern", index: int) -> "DFPattern":
+def is_tuple_get_item(tuple_value: "DFPattern", index: Optional[int] = None) -> "DFPattern":
     """
     Syntatic sugar for creating an ExprPattern.
 
@@ -302,8 +302,8 @@ def is_tuple_get_item(tuple_value: "DFPattern", index: int) -> "DFPattern":
     tuple_value: tvm.relay.dataflow_pattern.DFPattern
         The input tuple expression.
 
-    index: int
-        The index.
+    index: Optional[int]
+        The index to match; Default (None) to match a TupleGetItem with any index.
 
     Returns
     -------
@@ -555,12 +555,13 @@ class TupleGetItemPattern(DFPattern):
     tuple_value: tvm.relay.dataflow_pattern.DFPattern
         The input tuple expression.
 
-    index: int
-        The index.
+    index: Optional[int]
+        The index to match; Default (None) to match a TupleGetItem with any index.
     """
 
-    def __init__(self, tuple_value: "DFPattern", index: int):
-        self.__init_handle_by_constructor__(ffi.TupleGetItemPattern, tuple_value, index)
+    def __init__(self, tuple_value: "DFPattern", index: Optional[int] = None):
+        match_index = index if index is not None else -1
+        self.__init_handle_by_constructor__(ffi.TupleGetItemPattern, tuple_value, match_index)
 
 
 @register_df_node
index c9bf11e..d33891a 100644 (file)
@@ -359,7 +359,7 @@ bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& ex
 bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
   bool matches = false;
   if (const auto* tuple_get_item_node = expr.as<TupleGetItemNode>()) {
-    matches = (op->index == tuple_get_item_node->index) &&
+    matches = (op->index == -1 || op->index == tuple_get_item_node->index) &&
               VisitDFPattern(op->tuple, tuple_get_item_node->tuple);
   }
   return matches;
index 9727e53..4fce473 100644 (file)
@@ -253,6 +253,11 @@ def test_match_tuple():
     tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1)
     assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1))
 
+    tuple_get_item_pattern = is_tuple_get_item(tuple_pattern) # Match any index
+    assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 0))
+    assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1))
+    assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 2))
+
 
 def test_no_match_tuple():
     x = relay.var('x')