# specific language governing permissions and limitations
# under the License.
"""The Relay Pattern Language and tooling."""
-from tvm.relay.expr import RelayExpr as Expr
+# pylint: disable=no-member
+from typing import Callable, Dict, List, Optional
+
import tvm._ffi
-from ...ir.base import Node
+from tvm.relay.expr import RelayExpr as Expr
+
+from ... import _ffi as tvm_ffi
from ...ir import make_node
+from ...ir.base import Node
from ...runtime import Object
-from ... import _ffi as tvm_ffi
from ..op import get
from . import _ffi as ffi
def __truediv__(self, other):
return is_op("divide")(self, other)
- def has_attr(self, attrs):
+ def has_attr(self, attrs: Dict[str, Object]):
"""
Add an attribute constraint to this pattern
attrs = make_node("DictAttrs", **attrs)
return AttrPattern(self, attrs)
- def has_type(self, ttype):
+ def has_type(self, ttype: tvm.ir.type.Type):
"""
Add a type constraint to this pattern
Parameters
----------
- ttype: tvm.relay.Type
+ ttype: tvm.ir.type.Type
The type to match
Returns
"""
return match(self, expr)
- def partition(self, expr: Expr, attrs=None, check=lambda x: True) -> Expr:
+ def partition(self,
+ expr: Expr,
+ attrs: Optional[Dict[str, Object]] = None,
+ check: Callable[[Expr], bool] = lambda x: True) -> Expr:
"""
Parition the expression into functions defined by this pattern
The expression to match.
attrs : Optional[Dict[str, Object]]
A dictionary of Attribute name/values to add to the paritioned function
- check : Function
+ check : Callable[[Expr], bool]
A function to perform more complicated checks on the matched expression.
Returns true if partitioning should proceed, false otherwise.
"""
return partition(self, expr, attrs, check)
-    def dominates(self, parent, path=None):
+    def dominates(self, parent: "DFPattern", path: Optional["DFPattern"] = None):
    """
-    Create a dominator for this pattern
+    Create a dominator for this pattern.
    Parameters
    ----------
    Returns
    -------
    result: tvm.relay.dataflow_pattern.DFPattern
-        The resulting DominatorPattern
+        The resulting DominatorPattern.
    """
    if path is None:
        path = wildcard()
    return DominatorPattern(parent, path, self)
-    def optional(self, option_constructor):
+    def optional(self, option_constructor: Callable[["DFPattern"], "DFPattern"]):
    """
-    Create a optional user of this pattern
+    Create an optional user of this pattern.
    Parameters
    ----------
    return self | option_constructor(self)
-def is_input(name: str = "") -> DFPattern:
+def is_var(name: str = "") -> "DFPattern":
    """
-    Syntatic sugar for creating an optionally named VarPattern
+    Syntactic sugar for creating an optionally named VarPattern.
    Parameters
    ----------
    name: str
-        The name of the input pattern to match
+        The name of the input pattern to match.
    Returns
    -------
    result: tvm.relay.dataflow_pattern.DFPattern
-        The resulting InputPattern
+        The resulting pattern.
    """
    return VarPattern(name)
-def is_op(op_name: str) -> DFPattern:
+def is_constant() -> "DFPattern":
+    """
+    Syntactic sugar for creating a ConstantPattern.
+
+    Returns
+    -------
+    result: tvm.relay.dataflow_pattern.DFPattern
+        The resulting pattern.
+    """
+    return ConstantPattern()
+
+
+def is_expr(expr: Expr) -> "DFPattern":
+    """
+    Syntactic sugar for creating an ExprPattern.
+
+    Parameters
+    ----------
+    expr: Expr
+        The Relay expression to match.
+
+    Returns
+    -------
+    result: tvm.relay.dataflow_pattern.DFPattern
+        The resulting pattern.
+    """
+    return ExprPattern(expr)
+
+
+def is_op(op_name: str) -> "DFPattern":
    """
-    Syntatic sugar for creating an operator ExprPattern
+    Syntactic sugar for creating an operator ExprPattern.
    Parameters
    ----------
    return ExprPattern(op)
-def wildcard() -> DFPattern:
+def is_tuple(fields: tvm.ir.container.Array) -> "DFPattern":
    """
-    Syntatic sugar for creating a WildcardPattern
+    Syntactic sugar for creating a TuplePattern.
+
+    Parameters
+    ----------
+    fields : Array[tvm.relay.dataflow_pattern.DFPattern]
+        The fields in the tuple.
    Returns
    -------
    result: tvm.relay.dataflow_pattern.DFPattern
-        The resulting WildcardPattern
+        The resulting pattern.
+    """
+    return TuplePattern(fields)
+
+
+def is_tuple_get_item(tuple_value: "DFPattern", index: int) -> "DFPattern":
+    """
+    Syntactic sugar for creating a TupleGetItemPattern.
+
+    Parameters
+    ----------
+    tuple_value: tvm.relay.dataflow_pattern.DFPattern
+        The input tuple expression.
+
+    index: int
+        The index.
+
+    Returns
+    -------
+    result: tvm.relay.dataflow_pattern.DFPattern
+        The resulting pattern.
+    """
+    return TupleGetItemPattern(tuple_value, index)
+
+
+def wildcard() -> "DFPattern":
+    """
+    Syntactic sugar for creating a WildcardPattern.
+
+    Returns
+    -------
+    result: tvm.relay.dataflow_pattern.DFPattern
+        The resulting pattern.
+    """
    return WildcardPattern()
-def has_type(ttype, pattern: DFPattern = None) -> DFPattern:
+def has_type(ttype, pattern: Optional["DFPattern"] = None) -> "DFPattern":
    """
    Syntatic sugar for creating a TypePattern
    pattern: tvm.relay.dataflow_pattern.DFPattern
        The pattern that needs type annotation
-    ttype: tvm.relay.Type
+    ttype: tvm.ir.type.Type
        The type to match
    Returns
    return TypePattern(pattern, ttype)
-def has_attr(attrs, pattern=None) -> DFPattern:
+def has_attr(attrs, pattern=None) -> "DFPattern":
"""
Syntatic sugar for creating an AttrPattern
return pattern.has_attr(attrs)
-def dominates(parent: DFPattern, path: DFPattern, child: DFPattern) -> DFPattern:
+def dominates(parent: "DFPattern", path: "DFPattern", child: "DFPattern") -> "DFPattern":
"""
Syntatic sugar for creating an Dominator pattern
Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
- The resulting DominatorPattern
+ The resulting DominatorPattern.
"""
return DominatorPattern(parent, path, child)
-def match(pattern: DFPattern, expr: Expr) -> bool:
+def match(pattern: "DFPattern", expr: Expr) -> bool:
"""
Match a pattern to an expression
The name of the variable. Optional, if not provided,
the pattern will match any VarNode.
- type_annotation: tvm.relay.Type, optional
+ type_annotation: tvm.ir.type.Type, optional
The type annotation on the variable.
"""
- def __init__(self, name_hint="", type_annotation=None):
- self.__init_handle_by_constructor__(
- ffi.VarPattern, name_hint, type_annotation)
+ def __init__(self, name_hint: str = "", type_annotation: Optional[tvm.ir.type.Type] = None):
+ self.__init_handle_by_constructor__(ffi.VarPattern, name_hint, type_annotation)
@register_df_node
args: List[realy.dataflow_pattern.DFPattern]
The arguments to the call.
- attrs: Optional[tvm.Attrs]
+ attrs: Optional[tvm.ir.attrs.Attrs]
Attributes to the call, can be None
- type_args: Optional[List[tvm.relay.Type]]
+ type_args: Optional[List[tvm.ir.type.Type]]
The additional type arguments, this is only
used in advanced usecase of template functions.
"""
- def __init__(self, op, args, attrs=None, type_args=None):
+ def __init__(self,
+ op: "DFPattern",
+ args: List["DFPattern"],
+ attrs: Optional[tvm.ir.attrs.Attrs] = None,
+ type_args: Optional[List[tvm.ir.type.Type]] = None):
if not type_args:
type_args = []
- self.__init_handle_by_constructor__(
- ffi.CallPattern, op, args, attrs, type_args)
+ self.__init_handle_by_constructor__(ffi.CallPattern, op, args, attrs, type_args)
@register_df_node
Parameters
----------
- fields : List[tvm.relay.dataflow_pattern.DFPattern]
+ fields : Array[tvm.relay.dataflow_pattern.DFPattern]
The fields in the tuple.
"""
- def __init__(self, fields):
+ def __init__(self, fields: tvm.ir.container.Array):
self.__init_handle_by_constructor__(ffi.TuplePattern, fields)
- def __getitem__(self, index):
+ def __getitem__(self, index: int):
if index >= len(self):
raise IndexError("TuplePattern index out of range")
return self.fields[index]
The index.
"""
- def __init__(self, tuple_value: DFPattern, index):
- self.__init_handle_by_constructor__(
- ffi.TupleGetItemPattern, tuple_value, index)
+ def __init__(self, tuple_value: "DFPattern", index: int):
+ self.__init_handle_by_constructor__(ffi.TupleGetItemPattern, tuple_value, index)
@register_df_node
Parameters
----------
left: tvm.relay.dataflow_pattern.DFPattern
- One possible matching Pattern
+ One possible matching pattern.
right: tvm.relay.dataflow_pattern.DFPattern
- One possible matching Pattern
+ One possible matching pattern.
"""
- def __init__(self, left: DFPattern, right: DFPattern):
- self.__init_handle_by_constructor__(
- ffi.AltPattern, left, right)
+ def __init__(self, left: "DFPattern", right: "DFPattern"):
+ self.__init_handle_by_constructor__(ffi.AltPattern, left, right)
@register_df_node
Parameters
----------
pattern: tvm.relay.dataflow_pattern.DFPattern
- The input pattern that needs type annotation
+ The input pattern that needs type annotation.
- ttype: tvm.relay.Type
- The type to match
+ ttype: tvm.ir.type.Type
+ The type to match.
"""
- def __init__(self, pattern: DFPattern, ttype):
- self.__init_handle_by_constructor__(
- ffi.TypePattern, pattern, ttype)
+ def __init__(self, pattern: "DFPattern", ttype: tvm.ir.type.Type):
+ self.__init_handle_by_constructor__(ffi.TypePattern, pattern, ttype)
@register_df_node
class AttrPattern(DFPattern):
"""Get match an expression with a certain attributes.
- Currently only supports Op Attributes, not call Attributes
+ Currently only supports Op Attributes, not call Attributes.
Parameters
----------
pattern: tvm.relay.dataflow_pattern.DFPattern
The input pattern.
- attrs: tvm.Attrs
- The attributes to match
+ attrs: tvm.ir.attrs.Attrs
+ The attributes to match.
"""
- def __init__(self, pattern: DFPattern, attrs):
- self.__init_handle_by_constructor__(
- ffi.AttrPattern, pattern, attrs)
+ def __init__(self, pattern: "DFPattern", attrs: tvm.ir.attrs.Attrs):
+ self.__init_handle_by_constructor__(ffi.AttrPattern, pattern, attrs)
@register_df_node
----------
parent: tvm.relay.dataflow_pattern.DFPattern
The parent, i.e., the single node which produces something,
- later aggregated by the child
+ later aggregated by the child.
path: tvm.relay.dataflow_pattern.DFPattern
The fuzzy path pattern between parent and child,
- typically matches elementwise ops
+ typically matches elementwise ops.
child: tvm.relay.dataflow_pattern.DFPattern
The last node in the domination which is the end user
- for all nodes in the path and the parent
+ for all nodes in the path and the parent.
"""
- def __init__(self, parent: DFPattern, path: DFPattern, child: DFPattern):
- self.__init_handle_by_constructor__(
- ffi.DominatorPattern, parent, path, child)
+ def __init__(self, parent: "DFPattern", path: "DFPattern", child: "DFPattern"):
+ self.__init_handle_by_constructor__(ffi.DominatorPattern, parent, path, child)
class DFPatternCallback:
- """A Callback for Pattern Rewriting
+ """A Callback for Pattern Rewriting.
When rewrite is called on this DFPatternCallback, the backend will find matches for the
pattern, call the callback function, and replace the matched expression with whatever
Returns
-------
result : tvm.relay.Expr
- The Expression with matched subgraphs rewritten by the callbacks
+ The Expression with matched subgraphs rewritten by the callbacks.
"""
return rewrite(self, expr)
- def callback(self, pre, post, node_map):
+ def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Expr:
"""
Callback function to use when we found a match to the pattern
The matching expression from the original graph.
post : tvm.relay.Expr
The matching expression with rewritten inputs
- node_map : Map(DFPattern, List(Expr))
+ node_map : tvm.ir.container.Map[DFPattern, List[Expr]]
The map between patterns and matched expressions
Returns
class _DFPatternCallback(Object):
"""C++ implemenation"""
def __init__(self, pattern, callback):
- self.__init_handle_by_constructor__(
- ffi.DFPatternCallback, pattern, callback)
+ self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback)
def rewrite(callbacks, expr: Expr) -> Expr:
"""
- Rewrite expression with the given callbacks
+ Rewrite expression with the given callbacks.
Parameters
----------
Returns
-------
result : tvm.relay.Expr
- The Expression with matched subgraphs rewritten by the callbacks
+ The Expression with matched subgraphs rewritten by the callbacks.
"""
if isinstance(callbacks, DFPatternCallback):
tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback)]
return ffi.rewrite(tmp, expr)
-def partition(pattern: DFPattern, expr: Expr, attrs=None, check=lambda x: True) -> Expr:
+
+def partition(pattern: "DFPattern",
+ expr: Expr,
+ attrs: Optional[Dict[str, Object]] = None,
+ check: Callable[[Expr], bool] = lambda x: True) -> Expr:
"""
Parition the expression into a series of functions that match the pattern
The expression to split into functions
attrs : Optional[Dict[str, Object]]
A dict of attributes to apply to the partitioned function
- check : Function
+ check : Callable[[Expr], bool]
A function to perform more complicated checks on the matched expression.
Returns true if partitioning should proceed, false otherwise.
## NODE TESTS
def test_expr_pattern():
- ep = ExprPattern(relay.var('x', shape=(4, 1)))
+ ep = is_expr(relay.var('x', shape=(4, 1)))
assert isinstance(ep, ExprPattern)
assert isinstance(ep.expr, relay.Var)
def test_var_pattern():
- v = is_input("x")
+ v = is_var("x")
assert isinstance(v, VarPattern)
assert v.name == "x"
def test_constant_pattern():
- c = ConstantPattern()
+ c = is_constant()
assert isinstance(c, ConstantPattern)
def test_TuplePattern():
wc1 = wildcard()
wc2 = wildcard()
- t = TuplePattern([wc1, wc2])
+ t = is_tuple([wc1, wc2])
assert isinstance(t, TuplePattern)
assert isinstance(t.fields[0], WildcardPattern)
assert isinstance(t.fields[1], WildcardPattern)
def test_TupleGetItemPattern():
wc1 = wildcard()
wc2 = wildcard()
- t = TuplePattern([wc1, wc2])
- tgi = TupleGetItemPattern(t, 1)
+ t = is_tuple([wc1, wc2])
+ tgi = is_tuple_get_item(t, 1)
assert isinstance(tgi, TupleGetItemPattern)
assert isinstance(tgi.tuple, TuplePattern)
assert isinstance(tgi.tuple.fields[0], WildcardPattern)
def test_match_call_commutive():
x = relay.var('x')
y = relay.var('y')
- add_pattern = is_op('add')(is_input("x"), is_input("y"))
+ add_pattern = is_op('add')(is_var("x"), is_var("y"))
assert add_pattern.match(x + y)
assert add_pattern.match(y + x)
- mul_pattern = is_op('multiply')(is_input("x"), is_input("y"))
+ mul_pattern = is_op('multiply')(is_var("x"), is_var("y"))
assert mul_pattern.match(x * y)
assert mul_pattern.match(y * x)
def test_no_match_call_commutive():
x = relay.var('x')
y = relay.var('y')
- add_pattern = is_op('subtract')(is_input("x"), is_input("y"))
+ add_pattern = is_op('subtract')(is_var("x"), is_var("y"))
assert add_pattern.match(x - y)
assert not add_pattern.match(y - x)
- add_pattern = is_op('divide')(is_input("x"), is_input("y"))
+ add_pattern = is_op('divide')(is_var("x"), is_var("y"))
assert add_pattern.match(x / y)
assert not add_pattern.match(y / x)
def test_match_const():
- conv2d = is_op('nn.conv2d')(wildcard(), ConstantPattern())
+ conv2d = is_op('nn.conv2d')(wildcard(), is_constant())
pattern = is_op('nn.bias_add')(conv2d, wildcard())
x = relay.var('x', shape=(1, 3, 224, 224))
x = relay.var('x')
y = relay.var('y')
z = relay.op.op.get("add")
- tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add")))
+ tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add")))
assert tuple_pattern.match(relay.expr.Tuple((x, y, z)))
- tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add")))
- tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1)
+ tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add")))
+ tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1)
assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1))
x = relay.var('x')
y = relay.var('y')
z = relay.op.op.get("add")
- tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add"), wildcard()))
+ tuple_pattern = is_tuple((is_var('x'), wildcard(), is_op("add"), wildcard()))
assert not tuple_pattern.match(relay.expr.Tuple((x, y, z)))
- tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add")))
- tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1)
+ tuple_pattern = is_tuple((is_var('x'), wildcard(), is_op("add")))
+ tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1)
assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple(
(x, y, z)), 2))
self.mean = wildcard()
self.beta = wildcard()
self.gamma = wildcard()
- self.eps = ConstantPattern()
+ self.eps = is_constant()
self.pattern = self.gamma * (self.x - self.mean) / is_op("sqrt")(self.var + self.eps) + \
self.beta
def algebraic_simplify(expr):
- zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0)))
- one = (ExprPattern(relay.const(1)) | ExprPattern(relay.const(1.0)))
+ zero = (is_expr(relay.const(0)) | is_expr(relay.const(0.0)))
+ one = (is_expr(relay.const(1)) | is_expr(relay.const(1.0)))
class ElwiseNullCallback(DFPatternCallback):
def callback(self, pre, post, node_map):
assert tvm.ir.structural_equal(lifted_func(x, wc, b), pattern.partition(reluc))
# Check lifting of input matches
- pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input()),
+ pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_var()),
wildcard()))
assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
assert tvm.ir.structural_equal(reluc, pattern.partition(reluc)) #Constants are not Inputs
# Check embedding of constant matches
- pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(),
- ConstantPattern()),
+ pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_constant()),
wildcard()))
assert tvm.ir.structural_equal(relu, pattern.partition(relu))
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))
# Check embedding of constant ExprPatterns
- pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(),
- ExprPattern(wc)),
+ pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_expr(wc)),
wildcard()))
assert tvm.ir.structural_equal(relu, pattern.partition(relu))
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))
# Check lifting/embedding of Alt matches
- pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input()
- | ConstantPattern()),
- wildcard()))
+ pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(
+ wildcard(), is_var() | is_constant()), wildcard()))
assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))
# Check lifting/embedding of Alt matches with the other ordering
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(
- wildcard(), ConstantPattern() | is_input()), wildcard()))
+ wildcard(), is_constant() | is_var()), wildcard()))
assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))