tuple_pattern = is_tuple((wildcard(), wildcard(), wildcard()))
assert tuple_pattern.match(relay.expr.Tuple((x,y,z)))
-The next example is matching a pattern of batch_norm -> get(0) -> relu:
+The next example is matching a pattern of batch_norm -> get(0) -> relu. Note that you can also use `is_tuple_get_item(bn_node)` to match a `TupleGetItem` node with any index.
.. code-block:: python
| is_expr(expr)
| is_op(op_name)
| is_tuple()
- | is_tuple_get_item()
+ | is_tuple_get_item(pattern, index = None)
| pattern1 `|` pattern2
| dominates(parent_pattern, path_pattern, child_pattern)
return TuplePattern(fields)
-def is_tuple_get_item(tuple_value: "DFPattern", index: int) -> "DFPattern":
+def is_tuple_get_item(tuple_value: "DFPattern", index: Optional[int] = None) -> "DFPattern":
"""
Syntatic sugar for creating an ExprPattern.
tuple_value: tvm.relay.dataflow_pattern.DFPattern
The input tuple expression.
- index: int
- The index.
+ index: Optional[int]
+ The index to match; Default (None) to match a TupleGetItem with any index.
Returns
-------
tuple_value: tvm.relay.dataflow_pattern.DFPattern
The input tuple expression.
- index: int
- The index.
+ index: Optional[int]
+ The index to match; Default (None) to match a TupleGetItem with any index.
"""
- def __init__(self, tuple_value: "DFPattern", index: int):
- self.__init_handle_by_constructor__(ffi.TupleGetItemPattern, tuple_value, index)
+ def __init__(self, tuple_value: "DFPattern", index: Optional[int] = None):
+ match_index = index if index is not None else -1
+ self.__init_handle_by_constructor__(ffi.TupleGetItemPattern, tuple_value, match_index)
@register_df_node
bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
bool matches = false;
if (const auto* tuple_get_item_node = expr.as<TupleGetItemNode>()) {
- matches = (op->index == tuple_get_item_node->index) &&
+ matches = (op->index == -1 || op->index == tuple_get_item_node->index) &&
VisitDFPattern(op->tuple, tuple_get_item_node->tuple);
}
return matches;
tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1)
assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1))
+ tuple_get_item_pattern = is_tuple_get_item(tuple_pattern) # Match any index
+ assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 0))
+ assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1))
+ assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 2))
+
def test_no_match_tuple():
x = relay.var('x')