From 43162d669d466c69bc6a64771b2fb8441f0a1c69 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 1 Jun 2020 17:14:33 -0700 Subject: [PATCH] [PatternLang] Simplify Pattern API Implementations (#5703) * Add syntatic sugar; include pattern to API docs * fix doc warnings --- docs/api/python/index.rst | 1 + docs/api/python/relay/dataflow_pattern.rst | 25 +++ docs/langref/relay_pattern.rst | 21 ++- python/tvm/relay/dataflow_pattern/__init__.py | 230 +++++++++++++++++--------- tests/python/relay/test_dataflow_pattern.py | 55 +++--- 5 files changed, 219 insertions(+), 113 deletions(-) create mode 100644 docs/api/python/relay/dataflow_pattern.rst diff --git a/docs/api/python/index.rst b/docs/api/python/index.rst index 50d7a3d..bee6e56 100644 --- a/docs/api/python/index.rst +++ b/docs/api/python/index.rst @@ -37,6 +37,7 @@ Python API relay/transform relay/analysis relay/backend + relay/dataflow_pattern relay/testing autotvm rpc diff --git a/docs/api/python/relay/dataflow_pattern.rst b/docs/api/python/relay/dataflow_pattern.rst new file mode 100644 index 0000000..fe1d4e9 --- /dev/null +++ b/docs/api/python/relay/dataflow_pattern.rst @@ -0,0 +1,25 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.relay.dataflow_pattern +-------------------------- + +.. automodule:: tvm.relay.dataflow_pattern + :members: + :imported-members: + :exclude-members: Object, Node + :autosummary: diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst index 7d5deb2..962dcc6 100644 --- a/docs/langref/relay_pattern.rst +++ b/docs/langref/relay_pattern.rst @@ -114,7 +114,7 @@ Since there are not call nodes, we need to use specific pattern nodes to match t x = relay.var('x') y = relay.var('y') z = relay.var('z') - tuple_pattern = TuplePattern((wildcard(), wildcard(), wildcard())) + tuple_pattern = is_tuple((wildcard(), wildcard(), wildcard())) assert tuple_pattern.match(relay.expr.Tuple((x,y,z))) The next example is matching a pattern of batch_norm -> get(0) -> relu: @@ -123,7 +123,7 @@ The next example is matching a pattern of batch_norm -> get(0) -> relu: def test_match_tuple_get_item(): bn_node = is_op('nn.batch_norm')(wildcard(), wildcard(), wildcard(), wildcard(), wildcard()) - tuple_get_item_node = TupleGetItemPattern(bn_node, 0) + tuple_get_item_node = is_tuple_get_item(bn_node, 0) pat = is_op('nn.relu')(tuple_get_item_node) x = relay.var('x', shape=(1, 8)) @@ -142,7 +142,7 @@ if a specific parameter in a subgraph has been bound or not. .. code-block:: python def test_match_constant(): - conv2d = is_op('nn.conv2d')(wildcard(), ConstantPattern()) + conv2d = is_op('nn.conv2d')(wildcard(), is_constant()) pattern = is_op('nn.bias_add')(conv2d, wildcard()) x = relay.var('x', shape=(1, 3, 224, 224)) @@ -162,12 +162,12 @@ if a specific parameter in a subgraph has been bound or not. assert pattern.match(mod['main'].body) On the other hand, if you need to match the constant with a specific value, you can directly -use ``ExprPattern``. This could be useful for algebraic simplify. +use ``is_expr``. This could be useful for algebraic simplify. .. code-block:: python def test_match_plus_zero(): - zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0))) + zero = (is_expr(relay.const(0)) | is_expr(relay.const(0.0))) pattern = wildcard() + zero x = relay.Var('x') @@ -193,7 +193,7 @@ The next example is matching a diamond with two inputs at the top of the diamond def test_match_diamond(): # Pattern - is_conv2d = is_op('nn.conv2d')(is_input(), is_input()) + is_conv2d = is_op('nn.conv2d')(is_var(), is_var()) path1 = is_op('nn.relu')(is_conv2d) path2 = is_op('nn.leaky_relu')(is_conv2d) diamond = is_op('add')(path1, path2) @@ -213,7 +213,7 @@ The final example is matching diamonds with a post-dominator relationship. We em def test_match_dom_diamond(): # Pattern - is_conv2d = is_op('nn.conv2d')(is_input(), is_input()) + is_conv2d = is_op('nn.conv2d')(is_var(), is_var()) reduction = is_op('add')(wildcard(), wildcard()) diamond = dominates(is_conv2d, is_elemwise, reduction) @@ -240,7 +240,12 @@ The high level design is to introduce a language of patterns for now we propose | pattern(pattern1, ... patternN) | has_type(pattern, type) | has_attr(pattern, attrs) - | is_input(name) + | is_var(name) + | is_constant() + | is_expr(expr) + | is_op(op_name) + | is_tuple() + | is_tuple_get_item() | pattern1 `|` pattern2 | dominates(parent_pattern, path_pattern, child_pattern) diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index f1d0784..e6a1a5e 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -15,12 +15,16 @@ # specific language governing permissions and limitations # under the License. """The Relay Pattern Language and tooling.""" -from tvm.relay.expr import RelayExpr as Expr +# pylint: disable=no-member +from typing import Callable, Dict, List, Optional + import tvm._ffi -from ...ir.base import Node +from tvm.relay.expr import RelayExpr as Expr + +from ... import _ffi as tvm_ffi from ...ir import make_node +from ...ir.base import Node from ...runtime import Object -from ... import _ffi as tvm_ffi from ..op import get from . import _ffi as ffi @@ -61,7 +65,7 @@ class DFPattern(Node): def __truediv__(self, other): return is_op("divide")(self, other) - def has_attr(self, attrs): + def has_attr(self, attrs: Dict[str, Object]): """ Add an attribute constraint to this pattern @@ -77,13 +81,13 @@ class DFPattern(Node): attrs = make_node("DictAttrs", **attrs) return AttrPattern(self, attrs) - def has_type(self, ttype): + def has_type(self, ttype: tvm.ir.type.Type): """ Add a type constraint to this pattern Parameters ---------- - ttype: tvm.relay.Type + ttype: tvm.ir.type.Type The type to match Returns @@ -109,7 +113,10 @@ class DFPattern(Node): """ return match(self, expr) - def partition(self, expr: Expr, attrs=None, check=lambda x: True) -> Expr: + def partition(self, + expr: Expr, + attrs: Optional[Dict[str, Object]] = None, + check: Callable[[Expr], bool] = lambda x: True) -> Expr: """ Parition the expression into functions defined by this pattern @@ -119,7 +126,7 @@ class DFPattern(Node): The expression to match. attrs : Optional[Dict[str, Object]] A dictionary of Attribute name/values to add to the paritioned function - check : Function + check : Callable[[Expr], bool] A function to perform more complicated checks on the matched expression. Returns true if partitioning should proceed, false otherwise. @@ -130,9 +137,9 @@ class DFPattern(Node): """ return partition(self, expr, attrs, check) - def dominates(self, parent, path=None): + def dominates(self, parent: "DFPattern", path: "DFPattern" = None): """ - Create a dominator for this pattern + Create a dominator for this pattern. Parameters ---------- @@ -144,15 +151,15 @@ class DFPattern(Node): Returns ------- result: tvm.relay.dataflow_pattern.DFPattern - The resulting DominatorPattern + The resulting DominatorPattern. """ if path is None: path = wildcard() return DominatorPattern(parent, path, self) - def optional(self, option_constructor): + def optional(self, option_constructor: Callable[["DFPattern"], "DFPattern"]): """ - Create a optional user of this pattern + Create a optional user of this pattern. Parameters ---------- @@ -168,26 +175,60 @@ class DFPattern(Node): return self | option_constructor(self) -def is_input(name: str = "") -> DFPattern: +def is_var(name: str = "") -> "DFPattern": """ - Syntatic sugar for creating an optionally named VarPattern + Syntatic sugar for creating an optionally named VarPattern. Parameters ---------- name: str - The name of the input pattern to match + The name of the input pattern to match. Returns ------- result: tvm.relay.dataflow_pattern.DFPattern - The resulting InputPattern + The resulting pattern. """ return VarPattern(name) -def is_op(op_name: str) -> DFPattern: +def is_constant() -> "DFPattern": + """ + Syntatic sugar for creating a ConstantPattern. + + Parameters + ---------- + name: str + The name of the input pattern to match. + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting pattern. + """ + return ConstantPattern() + + +def is_expr(expr: Expr) -> "DFPattern": + """ + Syntatic sugar for creating an ExprPattern. + + Parameters + ---------- + expr: Expr + The Relay expression to match. + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting pattern. + """ + return ExprPattern(expr) + + +def is_op(op_name: str) -> "DFPattern": """ - Syntatic sugar for creating an operator ExprPattern + Syntatic sugar for creating an operator ExprPattern. Parameters ---------- @@ -203,19 +244,56 @@ def is_op(op_name: str) -> DFPattern: return ExprPattern(op) -def wildcard() -> DFPattern: +def is_tuple(fields: tvm.ir.container.Array) -> "DFPattern": """ - Syntatic sugar for creating a WildcardPattern + Syntatic sugar for creating an ExprPattern. + + Parameters + ---------- + fields : Array[tvm.relay.dataflow_pattern.DFPattern] + The fields in the tuple. Returns ------- result: tvm.relay.dataflow_pattern.DFPattern - The resulting WildcardPattern + The resulting pattern. + """ + return TuplePattern(fields) + + +def is_tuple_get_item(tuple_value: "DFPattern", index: int) -> "DFPattern": + """ + Syntatic sugar for creating an ExprPattern. + + Parameters + ---------- + tuple_value: tvm.relay.dataflow_pattern.DFPattern + The input tuple expression. + + index: int + The index. + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting pattern. + """ + return TupleGetItemPattern(tuple_value, index) + + +def wildcard() -> "DFPattern": + """ + Syntatic sugar for creating a WildcardPattern. + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting pattern. """ return WildcardPattern() -def has_type(ttype, pattern: DFPattern = None) -> DFPattern: +def has_type(ttype, pattern: "DFPattern" = None) -> "DFPattern": """ Syntatic sugar for creating a TypePattern @@ -224,7 +302,7 @@ def has_type(ttype, pattern: DFPattern = None) -> DFPattern: pattern: tvm.relay.dataflow_pattern.DFPattern The pattern that needs type annotation - ttype: tvm.relay.Type + ttype: tvm.ir.type.Type The type to match Returns @@ -237,7 +315,7 @@ def has_type(ttype, pattern: DFPattern = None) -> DFPattern: return TypePattern(pattern, ttype) -def has_attr(attrs, pattern=None) -> DFPattern: +def has_attr(attrs, pattern=None) -> "DFPattern": """ Syntatic sugar for creating an AttrPattern @@ -259,7 +337,7 @@ def has_attr(attrs, pattern=None) -> DFPattern: return pattern.has_attr(attrs) -def dominates(parent: DFPattern, path: DFPattern, child: DFPattern) -> DFPattern: +def dominates(parent: "DFPattern", path: "DFPattern", child: "DFPattern") -> "DFPattern": """ Syntatic sugar for creating an Dominator pattern @@ -275,12 +353,12 @@ def dominates(parent: DFPattern, path: DFPattern, child: DFPattern) -> DFPattern Returns ------- result: tvm.relay.dataflow_pattern.DFPattern - The resulting DominatorPattern + The resulting DominatorPattern. """ return DominatorPattern(parent, path, child) -def match(pattern: DFPattern, expr: Expr) -> bool: +def match(pattern: "DFPattern", expr: Expr) -> bool: """ Match a pattern to an expression @@ -321,13 +399,12 @@ class VarPattern(DFPattern): The name of the variable. Optional, if not provided, the pattern will match any VarNode. - type_annotation: tvm.relay.Type, optional + type_annotation: tvm.ir.type.Type, optional The type annotation on the variable. """ - def __init__(self, name_hint="", type_annotation=None): - self.__init_handle_by_constructor__( - ffi.VarPattern, name_hint, type_annotation) + def __init__(self, name_hint: str = "", type_annotation: Optional[tvm.ir.type.Type] = None): + self.__init_handle_by_constructor__(ffi.VarPattern, name_hint, type_annotation) @register_df_node @@ -350,19 +427,22 @@ class CallPattern(DFPattern): args: List[realy.dataflow_pattern.DFPattern] The arguments to the call. - attrs: Optional[tvm.Attrs] + attrs: Optional[tvm.ir.attrs.Attrs] Attributes to the call, can be None - type_args: Optional[List[tvm.relay.Type]] + type_args: Optional[List[tvm.ir.type.Type]] The additional type arguments, this is only used in advanced usecase of template functions. """ - def __init__(self, op, args, attrs=None, type_args=None): + def __init__(self, + op: "DFPattern", + args: List["DFPattern"], + attrs: Optional[tvm.ir.attrs.Attrs] = None, + type_args: Optional[List[tvm.ir.type.Type]] = None): if not type_args: type_args = [] - self.__init_handle_by_constructor__( - ffi.CallPattern, op, args, attrs, type_args) + self.__init_handle_by_constructor__(ffi.CallPattern, op, args, attrs, type_args) @register_df_node @@ -371,14 +451,14 @@ class TuplePattern(DFPattern): Parameters ---------- - fields : List[tvm.relay.dataflow_pattern.DFPattern] + fields : Array[tvm.relay.dataflow_pattern.DFPattern] The fields in the tuple. """ - def __init__(self, fields): + def __init__(self, fields: tvm.ir.container.Array): self.__init_handle_by_constructor__(ffi.TuplePattern, fields) - def __getitem__(self, index): + def __getitem__(self, index: int): if index >= len(self): raise IndexError("TuplePattern index out of range") return self.fields[index] @@ -403,9 +483,8 @@ class TupleGetItemPattern(DFPattern): The index. """ - def __init__(self, tuple_value: DFPattern, index): - self.__init_handle_by_constructor__( - ffi.TupleGetItemPattern, tuple_value, index) + def __init__(self, tuple_value: "DFPattern", index: int): + self.__init_handle_by_constructor__(ffi.TupleGetItemPattern, tuple_value, index) @register_df_node @@ -415,14 +494,13 @@ class AltPattern(DFPattern): Parameters ---------- left: tvm.relay.dataflow_pattern.DFPattern - One possible matching Pattern + One possible matching pattern. right: tvm.relay.dataflow_pattern.DFPattern - One possible matching Pattern + One possible matching pattern. """ - def __init__(self, left: DFPattern, right: DFPattern): - self.__init_handle_by_constructor__( - ffi.AltPattern, left, right) + def __init__(self, left: "DFPattern", right: "DFPattern"): + self.__init_handle_by_constructor__(ffi.AltPattern, left, right) @register_df_node @@ -441,34 +519,32 @@ class TypePattern(DFPattern): Parameters ---------- pattern: tvm.relay.dataflow_pattern.DFPattern - The input pattern that needs type annotation + The input pattern that needs type annotation. - ttype: tvm.relay.Type - The type to match + ttype: tvm.ir.type.Type + The type to match. """ - def __init__(self, pattern: DFPattern, ttype): - self.__init_handle_by_constructor__( - ffi.TypePattern, pattern, ttype) + def __init__(self, pattern: "DFPattern", ttype: tvm.ir.type.Type): + self.__init_handle_by_constructor__(ffi.TypePattern, pattern, ttype) @register_df_node class AttrPattern(DFPattern): """Get match an expression with a certain attributes. - Currently only supports Op Attributes, not call Attributes + Currently only supports Op Attributes, not call Attributes. Parameters ---------- pattern: tvm.relay.dataflow_pattern.DFPattern The input pattern. - attrs: tvm.Attrs - The attributes to match + attrs: tvm.ir.attrs.Attrs + The attributes to match. """ - def __init__(self, pattern: DFPattern, attrs): - self.__init_handle_by_constructor__( - ffi.AttrPattern, pattern, attrs) + def __init__(self, pattern: "DFPattern", attrs: tvm.ir.attrs.Attrs): + self.__init_handle_by_constructor__(ffi.AttrPattern, pattern, attrs) @register_df_node @@ -479,22 +555,21 @@ class DominatorPattern(DFPattern): ---------- parent: tvm.relay.dataflow_pattern.DFPattern The parent, i.e., the single node which produces something, - later aggregated by the child + later aggregated by the child. path: tvm.relay.dataflow_pattern.DFPattern The fuzzy path pattern between parent and child, - typically matches elementwise ops + typically matches elementwise ops. child: tvm.relay.dataflow_pattern.DFPattern The last node in the domination which is the end user - for all nodes in the path and the parent + for all nodes in the path and the parent. """ - def __init__(self, parent: DFPattern, path: DFPattern, child: DFPattern): - self.__init_handle_by_constructor__( - ffi.DominatorPattern, parent, path, child) + def __init__(self, parent: "DFPattern", path: "DFPattern", child: "DFPattern"): + self.__init_handle_by_constructor__(ffi.DominatorPattern, parent, path, child) class DFPatternCallback: - """A Callback for Pattern Rewriting + """A Callback for Pattern Rewriting. When rewrite is called on this DFPatternCallback, the backend will find matches for the pattern, call the callback function, and replace the matched expression with whatever @@ -515,11 +590,11 @@ class DFPatternCallback: Returns ------- result : tvm.relay.Expr - The Expression with matched subgraphs rewritten by the callbacks + The Expression with matched subgraphs rewritten by the callbacks. """ return rewrite(self, expr) - def callback(self, pre, post, node_map): + def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Expr: """ Callback function to use when we found a match to the pattern @@ -529,7 +604,7 @@ class DFPatternCallback: The matching expression from the original graph. post : tvm.relay.Expr The matching expression with rewritten inputs - node_map : Map(DFPattern, List(Expr)) + node_map : tvm.ir.container.Map[DFPattern, List[Expr]] The map between patterns and matched expressions Returns @@ -542,13 +617,12 @@ class DFPatternCallback: class _DFPatternCallback(Object): """C++ implemenation""" def __init__(self, pattern, callback): - self.__init_handle_by_constructor__( - ffi.DFPatternCallback, pattern, callback) + self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback) def rewrite(callbacks, expr: Expr) -> Expr: """ - Rewrite expression with the given callbacks + Rewrite expression with the given callbacks. Parameters ---------- @@ -560,7 +634,7 @@ def rewrite(callbacks, expr: Expr) -> Expr: Returns ------- result : tvm.relay.Expr - The Expression with matched subgraphs rewritten by the callbacks + The Expression with matched subgraphs rewritten by the callbacks. """ if isinstance(callbacks, DFPatternCallback): tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback)] @@ -571,7 +645,11 @@ def rewrite(callbacks, expr: Expr) -> Expr: return ffi.rewrite(tmp, expr) -def partition(pattern: DFPattern, expr: Expr, attrs=None, check=lambda x: True) -> Expr: + +def partition(pattern: "DFPattern", + expr: Expr, + attrs: Optional[Dict[str, Object]] = None, + check: Callable[[Expr], bool] = lambda x: True) -> Expr: """ Parition the expression into a series of functions that match the pattern @@ -583,7 +661,7 @@ def partition(pattern: DFPattern, expr: Expr, attrs=None, check=lambda x: True) The expression to split into functions attrs : Optional[Dict[str, Object]] A dict of attributes to apply to the partitioned function - check : Function + check : Callable[[Expr], bool] A function to perform more complicated checks on the matched expression. Returns true if partitioning should proceed, false otherwise. diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 89abb2e..8d67db5 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -32,19 +32,19 @@ K_BROADCAST = 1 ## NODE TESTS def test_expr_pattern(): - ep = ExprPattern(relay.var('x', shape=(4, 1))) + ep = is_expr(relay.var('x', shape=(4, 1))) assert isinstance(ep, ExprPattern) assert isinstance(ep.expr, relay.Var) def test_var_pattern(): - v = is_input("x") + v = is_var("x") assert isinstance(v, VarPattern) assert v.name == "x" def test_constant_pattern(): - c = ConstantPattern() + c = is_constant() assert isinstance(c, ConstantPattern) @@ -65,7 +65,7 @@ def test_CallPattern(): def test_TuplePattern(): wc1 = wildcard() wc2 = wildcard() - t = TuplePattern([wc1, wc2]) + t = is_tuple([wc1, wc2]) assert isinstance(t, TuplePattern) assert isinstance(t.fields[0], WildcardPattern) assert isinstance(t.fields[1], WildcardPattern) @@ -74,8 +74,8 @@ def test_TuplePattern(): def test_TupleGetItemPattern(): wc1 = wildcard() wc2 = wildcard() - t = TuplePattern([wc1, wc2]) - tgi = TupleGetItemPattern(t, 1) + t = is_tuple([wc1, wc2]) + tgi = is_tuple_get_item(t, 1) assert isinstance(tgi, TupleGetItemPattern) assert isinstance(tgi.tuple, TuplePattern) assert isinstance(tgi.tuple.fields[0], WildcardPattern) @@ -120,10 +120,10 @@ def test_match_op_or(): def test_match_call_commutive(): x = relay.var('x') y = relay.var('y') - add_pattern = is_op('add')(is_input("x"), is_input("y")) + add_pattern = is_op('add')(is_var("x"), is_var("y")) assert add_pattern.match(x + y) assert add_pattern.match(y + x) - mul_pattern = is_op('multiply')(is_input("x"), is_input("y")) + mul_pattern = is_op('multiply')(is_var("x"), is_var("y")) assert mul_pattern.match(x * y) assert mul_pattern.match(y * x) @@ -131,10 +131,10 @@ def test_match_call_commutive(): def test_no_match_call_commutive(): x = relay.var('x') y = relay.var('y') - add_pattern = is_op('subtract')(is_input("x"), is_input("y")) + add_pattern = is_op('subtract')(is_var("x"), is_var("y")) assert add_pattern.match(x - y) assert not add_pattern.match(y - x) - add_pattern = is_op('divide')(is_input("x"), is_input("y")) + add_pattern = is_op('divide')(is_var("x"), is_var("y")) assert add_pattern.match(x / y) assert not add_pattern.match(y / x) @@ -211,7 +211,7 @@ def test_no_match_option(): def test_match_const(): - conv2d = is_op('nn.conv2d')(wildcard(), ConstantPattern()) + conv2d = is_op('nn.conv2d')(wildcard(), is_constant()) pattern = is_op('nn.bias_add')(conv2d, wildcard()) x = relay.var('x', shape=(1, 3, 224, 224)) @@ -232,11 +232,11 @@ def test_match_tuple(): x = relay.var('x') y = relay.var('y') z = relay.op.op.get("add") - tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add"))) + tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add"))) assert tuple_pattern.match(relay.expr.Tuple((x, y, z))) - tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add"))) - tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1) + tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add"))) + tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1) assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1)) @@ -244,11 +244,11 @@ def test_no_match_tuple(): x = relay.var('x') y = relay.var('y') z = relay.op.op.get("add") - tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add"), wildcard())) + tuple_pattern = is_tuple((is_var('x'), wildcard(), is_op("add"), wildcard())) assert not tuple_pattern.match(relay.expr.Tuple((x, y, z))) - tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add"))) - tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1) + tuple_pattern = is_tuple((is_var('x'), wildcard(), is_op("add"))) + tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1) assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple( (x, y, z)), 2)) @@ -596,7 +596,7 @@ class BatchnormCallback(DFPatternCallback): self.mean = wildcard() self.beta = wildcard() self.gamma = wildcard() - self.eps = ConstantPattern() + self.eps = is_constant() self.pattern = self.gamma * (self.x - self.mean) / is_op("sqrt")(self.var + self.eps) + \ self.beta @@ -760,8 +760,8 @@ def test_quadruple_rewrite_dominator(): def algebraic_simplify(expr): - zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0))) - one = (ExprPattern(relay.const(1)) | ExprPattern(relay.const(1.0))) + zero = (is_expr(relay.const(0)) | is_expr(relay.const(0.0))) + one = (is_expr(relay.const(1)) | is_expr(relay.const(1.0))) class ElwiseNullCallback(DFPatternCallback): def callback(self, pre, post, node_map): @@ -1182,35 +1182,32 @@ def test_partition_constant_embedding(): assert tvm.ir.structural_equal(lifted_func(x, wc, b), pattern.partition(reluc)) # Check lifting of input matches - pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input()), + pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_var()), wildcard())) assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu)) assert tvm.ir.structural_equal(reluc, pattern.partition(reluc)) #Constants are not Inputs # Check embedding of constant matches - pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), - ConstantPattern()), + pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_constant()), wildcard())) assert tvm.ir.structural_equal(relu, pattern.partition(relu)) assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc)) # Check embedding of constant ExprPatterns - pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), - ExprPattern(wc)), + pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_expr(wc)), wildcard())) assert tvm.ir.structural_equal(relu, pattern.partition(relu)) assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc)) # Check lifting/embedding of Alt matches - pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input() - | ConstantPattern()), - wildcard())) + pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')( + wildcard(), is_var() | is_constant()), wildcard())) assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu)) assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc)) # Check lifting/embedding of Alt matches with the other ordering pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')( - wildcard(), ConstantPattern() | is_input()), wildcard())) + wildcard(), is_constant() | is_var()), wildcard())) assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu)) assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc)) -- 2.7.4