[PatternLang] Add ConstantPattern (#5689)
authorCody Yu <comaniac0422@gmail.com>
Thu, 28 May 2020 23:48:07 +0000 (16:48 -0700)
committerGitHub <noreply@github.com>
Thu, 28 May 2020 23:48:07 +0000 (08:48 +0900)
* Add ConstantPattern

* update doc

docs/langref/relay_pattern.rst
include/tvm/relay/dataflow_pattern.h
include/tvm/relay/dataflow_pattern_functor.h
python/tvm/relay/dataflow_pattern/__init__.py
src/relay/ir/dataflow_matcher.cc
src/relay/ir/dataflow_pattern.cc
src/relay/ir/dataflow_pattern_functor.cc
src/relay/ir/indexed_graph.cc
tests/python/relay/test_dataflow_pattern.py

index c129544..7bb7bdf 100644 (file)
@@ -37,6 +37,11 @@ for more use cases.
 
 .. _tests/python/relay/test_dataflow_pattern.py: https://github.com/apache/incubator-tvm/blob/master/tests/python/relay/test_dataflow_pattern.py
 
+.. note::
+
+    If you cannot find a pattern node that matches the Relay node you need,
+    you are welcome to raise an issue or submit a PR to add one.
+
 Matching One of Two Ops
 ***********************
 
@@ -131,6 +136,44 @@ The next example is matching a pattern of batch_norm -> get(0) -> relu:
         out = relay.nn.relu(tuple_get_item_node)
         pat.match(out)
 
+The next example is matching a constant node regardless of its value. This is useful for checking
+whether a specific parameter in a subgraph has been bound or not.
+
+.. code-block:: python
+
+    def test_match_constant():
+        conv2d = is_op('nn.conv2d')(wildcard(), ConstantPattern())
+        pattern = is_op('nn.bias_add')(conv2d, wildcard())
+
+        x = relay.var('x', shape=(1, 3, 224, 224))
+        w = relay.var('w', shape=(3, 3, 3, 3))
+        b = relay.var('b', shape=(3, ))
+        conv2d = relay.op.nn.conv2d(x, w)
+        out = relay.op.nn.bias_add(conv2d, b)
+        func = relay.Function([x, w, b], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        # The two inputs of conv2d in the graph are VarNodes by default, so there is no match.
+        assert not pattern.match(mod['main'].body)
+
+        # The second input (weight) has been bound to constant values, so it is now a constant node.
+        mod["main"] = bind_params_by_name(mod["main"],
+                                          {'w': tvm.nd.array(np.ones(shape=(3, 3, 3, 3)))})
+        assert pattern.match(mod['main'].body)
+
+On the other hand, if you need to match a constant with a specific value, you can directly
+use ``ExprPattern``. This could be useful for algebraic simplification; a sketch of such a
+rewrite follows the example below.
+
+.. code-block:: python
+
+    def test_match_plus_zero():
+        zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0)))
+        pattern = wildcard() + zero
+
+        x = relay.Var('x')
+        y = x + relay.const(0)
+        assert pattern.match(y)
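+
+A minimal sketch of such a rewrite, modeled on the callbacks in
+`tests/python/relay/test_dataflow_pattern.py`_ (the name ``SimplifyAddZero`` is only for
+illustration, and the snippet assumes the same imports as the examples above):
+
+.. code-block:: python
+
+    class SimplifyAddZero(DFPatternCallback):
+        def __init__(self):
+            self.x = wildcard()
+            zero = ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0))
+            # Match "x + 0" (or "x + 0.0") anywhere in the expression.
+            self.pattern = self.x + zero
+
+        def callback(self, pre, post, node_map):
+            # Drop the redundant addition and return the matched "x" itself.
+            return node_map[self.x][0]
+
+    x = relay.var('x')
+    out = rewrite(SimplifyAddZero(), x + relay.const(0))
+    assert out == x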
+
 The next example is matching function nodes with a specific attribute:
 
 .. code-block:: python
index a8db51f..80a5d6f 100644 (file)
@@ -112,6 +112,24 @@ class VarPattern : public DFPattern {
 };
 
 /*!
+ * \brief A Pattern to Match a Relay Constant
+ */
+class ConstantPattern;
+/*! \brief Container for Constant */
+class ConstantPatternNode : public DFPatternNode {
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "relay.dataflow_pattern.ConstantPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ConstantPatternNode, DFPatternNode);
+};
+
+class ConstantPattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(ConstantPattern, DFPattern, ConstantPatternNode);
+};
+
+/*!
  * \brief Call corresponds to operator invocation.
  *  Corresponds to the operator in computational graph terminology.
  */
index 05c2147..a1140ae 100644 (file)
@@ -91,6 +91,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
   virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPatternDefault_(const Object* op, Args...) {
     LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
@@ -111,6 +112,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
     return vtable;
   }
@@ -134,6 +136,7 @@ class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
   void VisitDFPattern_(const TuplePatternNode* op) override;
   void VisitDFPattern_(const TypePatternNode* op) override;
   void VisitDFPattern_(const VarPatternNode* op) override;
+  void VisitDFPattern_(const ConstantPatternNode* op) override;
   void VisitDFPattern_(const WildcardPatternNode* op) override;
 
  protected:
index ca83e06..e8f73ed 100644 (file)
@@ -332,6 +332,14 @@ class VarPattern(DFPattern):
 
 
 @register_df_node
+class ConstantPattern(DFPattern):
+    """A pattern matching a Relay Constant.
+    """
+    def __init__(self):
+        self.__init_handle_by_constructor__(ffi.ConstantPattern)
+
+
+@register_df_node
 class CallPattern(DFPattern):
     """A pattern matching a function call node in Relay.
 
index 409dbc3..a7e4b37 100644 (file)
@@ -56,6 +56,7 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
   bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
 
   void ClearMap(size_t watermark);
@@ -394,6 +395,10 @@ bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& exp
   return matches;
 }
 
+bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) {
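+  // A ConstantPattern matches any relay Constant node, regardless of its value.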
+  return expr.as<ConstantNode>() != nullptr;
+}
+
 bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) {
   return true;
 }
index 826a035..2809131 100644 (file)
@@ -69,6 +69,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << ")";
     });
 
+TVM_REGISTER_NODE_TYPE(ConstantPatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ConstantPattern").set_body_typed([]() {
+  auto c = ConstantPattern(make_object<ConstantPatternNode>());
+  return c;
+});
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<ConstantPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      p->stream << "ConstantPattern()";
+    });
+
 CallPattern::CallPattern(DFPattern op, Array<DFPattern> args, Attrs attrs, Array<Type> type_args) {
   ObjectPtr<CallPatternNode> n = make_object<CallPatternNode>();
   n->op = std::move(op);
index c7c34c8..ee44bcb 100644 (file)
@@ -71,6 +71,8 @@ void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPatte
 
 void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}
 
+void DFPatternVisitor::VisitDFPattern_(const ConstantPatternNode* op) {}
+
 void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {}
 
 }  // namespace relay
index 7f7a5ff..0d4b90d 100644 (file)
@@ -270,6 +270,8 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {
 
     void VisitDFPattern_(const VarPatternNode* op, NodePtr parent) override {}
 
+    void VisitDFPattern_(const ConstantPatternNode* op, NodePtr parent) override {}
+
     void VisitDFPattern_(const WildcardPatternNode* op, NodePtr parent) override {}
   };
   return Annotator(Creator().CreateGraph(pattern)).Annotate();
index 6a66f60..5d91dcb 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=unused-wildcard-import
+import numpy as np
+
 import tvm
 from tvm import relay
+from tvm.relay.build_module import bind_params_by_name
 from tvm.relay.dataflow_pattern import *
 from tvm.relay.testing import run_opt_pass
-import numpy as np
 
 # NB: 1 corresponds to the C++ enum that specifies this
 # we lose the type safety due to the Python/C++ calling
@@ -26,21 +29,30 @@ import numpy as np
 K_ELEMWISE = 0
 K_BROADCAST = 1
 
+
 ## NODE TESTS
 def test_expr_pattern():
     ep = ExprPattern(relay.var('x', shape=(4, 1)))
     assert isinstance(ep, ExprPattern)
     assert isinstance(ep.expr, relay.Var)
 
+
 def test_var_pattern():
     v = is_input("x")
     assert isinstance(v, VarPattern)
     assert v.name == "x"
 
+
+def test_constant_pattern():
+    c = ConstantPattern()
+    assert isinstance(c, ConstantPattern)
+
+
 def test_wildcard_pattern():
     wc = wildcard()
     assert isinstance(wc, WildcardPattern)
 
+
 def test_CallPattern():
     wc1 = wildcard()
     wc2 = wildcard()
@@ -49,6 +61,7 @@ def test_CallPattern():
     assert isinstance(c.args[0], WildcardPattern)
     assert isinstance(c.args[1], WildcardPattern)
 
+
 def test_TuplePattern():
     wc1 = wildcard()
     wc2 = wildcard()
@@ -57,6 +70,7 @@ def test_TuplePattern():
     assert isinstance(t.fields[0], WildcardPattern)
     assert isinstance(t.fields[1], WildcardPattern)
 
+
 def test_TupleGetItemPattern():
     wc1 = wildcard()
     wc2 = wildcard()
@@ -67,34 +81,42 @@ def test_TupleGetItemPattern():
     assert isinstance(tgi.tuple.fields[0], WildcardPattern)
     assert isinstance(tgi.tuple.fields[1], WildcardPattern)
 
+
 def test_AltPattern():
     is_add_or_sub = is_op('add') | is_op('subtract')
     assert isinstance(is_add_or_sub, AltPattern)
 
+
 def test_TypePattern():
     ttype = relay.TensorType((10, 10), "float32")
     ty_pat = has_type(ttype)
     assert isinstance(ty_pat, TypePattern)
     assert ty_pat.type == ttype
 
+
 def test_AttrPattern():
     op = is_op('add').has_attr({"TOpPattern": K_ELEMWISE})
     assert isinstance(op, AttrPattern)
     assert op.attrs["TOpPattern"] == K_ELEMWISE
 
+
 ## MATCHER TESTS
 
+
 def test_match_op():
     assert is_op('add').match(relay.op.op.get("add"))
 
+
 def test_no_match_op():
     assert not is_op('add').match(relay.op.op.get("subtract"))
 
+
 def test_match_op_or():
     is_add_or_sub = is_op('add') | is_op('subtract')
     assert is_add_or_sub.match(relay.op.op.get("add"))
     assert is_add_or_sub.match(relay.op.op.get("subtract"))
 
+
 def test_match_call_commutive():
     x = relay.var('x')
     y = relay.var('y')
@@ -105,6 +127,7 @@ def test_match_call_commutive():
     assert mul_pattern.match(x * y)
     assert mul_pattern.match(y * x)
 
+
 def test_no_match_call_commutive():
     x = relay.var('x')
     y = relay.var('y')
@@ -115,26 +138,27 @@ def test_no_match_call_commutive():
     assert add_pattern.match(x / y)
     assert not add_pattern.match(y / x)
 
+
 def test_match_call():
     x = relay.var('x')
     y = relay.var('y')
     add_pattern = is_op('add')(wildcard(), wildcard())
     assert add_pattern.match(x + y)
 
+
 def test_no_match_call():
     x = relay.var('x')
     y = relay.var('y')
     add_pattern = is_op('add')(wildcard(), wildcard())
     assert not add_pattern.match(x - y)
 
+
 def test_match_option():
     x = relay.var('x')
     w = relay.var('w')
     b = relay.var('b')
-    pattern = is_op("nn.relu")(
-            is_op("nn.conv2d")(wildcard(), wildcard()
-                             ).optional(lambda x: is_op("nn.bias_add")(x, wildcard()))
-        )
+    pattern = is_op("nn.relu")(is_op("nn.conv2d")(
+        wildcard(), wildcard()).optional(lambda x: is_op("nn.bias_add")(x, wildcard())))
 
     conv2d = relay.op.nn.conv2d(x, w)
     relu = relay.op.nn.relu(conv2d)
@@ -159,14 +183,13 @@ def test_match_option():
     assert pattern.match(tanh2)
     assert not pattern.match(relu2)
 
+
 def test_no_match_option():
     x = relay.var('x')
     w = relay.var('w')
     b = relay.var('b')
-    pattern = is_op("nn.relu")(
-            is_op("nn.conv2d")(wildcard(), wildcard()
-                             ).optional(lambda x: is_op("nn.bias_add")(x, wildcard()))
-        )
+    pattern = is_op("nn.relu")(is_op("nn.conv2d")(
+        wildcard(), wildcard()).optional(lambda x: is_op("nn.bias_add")(x, wildcard())))
 
     conv2d = relay.op.nn.conv2d(x, w)
     relu = relay.op.tanh(conv2d)
@@ -186,46 +209,62 @@ def test_no_match_option():
     relu = relay.op.nn.relu(bias_add)
     assert not pattern.match(relu)
 
-def test_match_tuple():
-    x = relay.var('x')
-    y = relay.var('y')
-    z = relay.op.op.get("add")
-    tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add")))
-    assert tuple_pattern.match(relay.expr.Tuple((x,y,z)))
 
-def test_no_match_tuple():
-    x = relay.var('x')
-    y = relay.var('y')
-    z = relay.op.op.get("add")
-    tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add"), wildcard()))
-    assert not tuple_pattern.match(relay.expr.Tuple((x,y,z)))
+def test_match_const():
+    conv2d = is_op('nn.conv2d')(wildcard(), ConstantPattern())
+    pattern = is_op('nn.bias_add')(conv2d, wildcard())
+
+    x = relay.var('x', shape=(1, 3, 224, 224))
+    w = relay.var('w', shape=(3, 3, 3, 3))
+    b = relay.var('b', shape=(3, ))
+    conv2d = relay.op.nn.conv2d(x, w)
+    out = relay.op.nn.bias_add(conv2d, b)
+    func = relay.Function([x, w, b], out)
+    mod = tvm.IRModule.from_expr(func)
+
+    assert not pattern.match(mod['main'].body)
+    mod["main"] = bind_params_by_name(mod["main"],
+                                      {'w': tvm.nd.array(np.ones(shape=(3, 3, 3, 3)))})
+    assert pattern.match(mod['main'].body)
+
 
 def test_match_tuple():
     x = relay.var('x')
     y = relay.var('y')
     z = relay.op.op.get("add")
     tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add")))
+    assert tuple_pattern.match(relay.expr.Tuple((x, y, z)))
+
+    tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add")))
     tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1)
-    assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x,y,z)), 1))
+    assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1))
+
 
 def test_no_match_tuple():
     x = relay.var('x')
     y = relay.var('y')
     z = relay.op.op.get("add")
+    tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add"), wildcard()))
+    assert not tuple_pattern.match(relay.expr.Tuple((x, y, z)))
+
     tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add")))
     tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1)
-    assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x,y,z)), 2))
+    assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple(
+        (x, y, z)), 2))
+
 
 def test_match_type():
     x = relay.var('x', shape=(10, 10), dtype="float32")
     ty_pat = has_type(relay.TensorType((10, 10), "float32"))
     assert ty_pat.match(x)
 
+
 def test_no_match_type():
     x = relay.var('x', shape=(10, 10), dtype="int32")
     ty_pat = has_type(relay.TensorType((10, 10), "float32"))
     assert not ty_pat.match(x)
 
+
 def test_match_op_attr():
     op = is_op('add').has_attr({"TOpPattern": K_BROADCAST})
     op_pat = op(wildcard(), wildcard())
@@ -233,6 +272,7 @@ def test_match_op_attr():
     y = relay.var('y')
     assert op_pat.match(x + y)
 
+
 def test_no_match_op_attr():
     op = is_op('nn.dense').has_attr({"TOpPattern": K_ELEMWISE})
     op_pat = op(wildcard(), wildcard())
@@ -245,6 +285,7 @@ def test_no_match_op_attr():
     y = relay.var('y')
     assert not op_pat.match(x - y)
 
+
 def test_match_func_attr():
     pattern = wildcard().has_attr({"Composite": "add"})
     x = relay.var('x')
@@ -252,6 +293,7 @@ def test_match_func_attr():
     f = relay.Function([x, y], x + y).with_attr("Composite", "add")
     assert pattern.match(f)
 
+
 def test_no_match_func_attr():
     pattern = wildcard().has_attr({"Composite": "add"})
     x = relay.var('x')
@@ -262,12 +304,14 @@ def test_no_match_func_attr():
     f = relay.Function([x, y], x + y).with_attr("Composite", "conv_bias")
     assert not pattern.match(f)
 
+
 def test_match_call_attr():
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NCHW"})
     x = relay.var('x')
     y = relay.var('y')
     assert is_conv2d.match(relay.op.nn.conv2d(x, y))
 
+
 def test_no_match_call_attr():
     x = relay.var('x')
     y = relay.var('y')
@@ -278,6 +322,7 @@ def test_no_match_call_attr():
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"RandomAttr": "NCHW"})
     assert not is_conv2d.match(relay.op.nn.conv2d(x, y))
 
+
 def test_match_diamond():
     # Pattern
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
@@ -296,6 +341,7 @@ def test_match_diamond():
     # Check
     assert diamond.match(out)
 
+
 def test_no_match_diamond():
     # Pattern
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
@@ -309,12 +355,12 @@ def test_no_match_diamond():
     conv2d = relay.op.nn.conv2d(inp, weight)
     relu = relay.op.nn.relu(conv2d)
     leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
-    out = relu + leaky_relu
 
     # Check
     assert not diamond.match(leaky_relu)
     assert not diamond.match(relu)
 
+
 def test_match_fake_diamond():
     # Pattern
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
@@ -380,10 +426,11 @@ def test_match_dominator():
 
     # Check
     assert diamond.match(out)
-    
+
     # Fuzzy path/nested Diamond
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(
+        wildcard()) | is_op('add')(wildcard(), wildcard())
     reduction = is_op('add')(wildcard(), wildcard())
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
 
@@ -398,6 +445,7 @@ def test_match_dominator():
 
     assert diamond.match(out)
 
+
 def test_not_match_dominator():
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
     is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
@@ -451,36 +499,47 @@ def test_not_match_dominator():
     # Check
     assert not diamond.match(out)
 
+
 def test_rewrite():
     x = relay.var('x')
     y = relay.var('y')
     add_pattern = is_op('add')(wildcard(), wildcard())
     sub_pattern = is_op('subtract')(wildcard(), wildcard())
+
     class TestRewrite(DFPatternCallback):
         def __init__(self):
             self.pattern = add_pattern
+
         def callback(self, pre, post, node_map):
             return post.args[0] - post.args[1]
+
     out = rewrite(TestRewrite(), x + y)
     assert sub_pattern.match(out)
 
+
 def test_rewrite_func():
     x = relay.var('x')
     w = relay.var('w')
     y = relay.var('y')
     add_pattern = is_op('add')(wildcard(), wildcard())
     sub_pattern = is_op('subtract')(wildcard(), wildcard())
+
     class TestRewrite(DFPatternCallback):
         def __init__(self):
             self.pattern = add_pattern
+
         def callback(self, pre, post, node_map):
             return post.args[0] - post.args[1]
+
     inpf = relay.var("input")
     weightf = relay.var("weight")
-    func = relay.Function([inpf, weightf], relay.op.nn.relu(relay.op.nn.conv2d(inpf, weightf)), attrs=None)
-    out = rewrite(TestRewrite(), func(x,w) + y)
+    func = relay.Function([inpf, weightf],
+                          relay.op.nn.relu(relay.op.nn.conv2d(inpf, weightf)),
+                          attrs=None)
+    out = rewrite(TestRewrite(), func(x, w) + y)
     assert sub_pattern.match(out)
 
+
 def test_nested_rewrite():
     class PatternCallback(DFPatternCallback):
         def __init__(self, pattern):
@@ -510,6 +569,7 @@ def test_nested_rewrite():
 
     assert tvm.ir.structural_equal(out, new_out)
 
+
 def test_not_fuse_multi_diamond():
     # Pattern
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
@@ -528,6 +588,7 @@ def test_not_fuse_multi_diamond():
     # Check
     assert not diamond.match(out)
 
+
 class BatchnormCallback(DFPatternCallback):
     def __init__(self):
         self.x = wildcard()
@@ -536,8 +597,9 @@ class BatchnormCallback(DFPatternCallback):
         self.beta = wildcard()
         self.gamma = wildcard()
         self.eps = wildcard()
-        
-        self.pattern = self.gamma * (self.x - self.mean)/is_op("sqrt")(self.var + self.eps) + self.beta
+
+        self.pattern = self.gamma * (self.x - self.mean) / is_op("sqrt")(self.var + self.eps) + \
+                       self.beta
 
     def callback(self, pre, post, node_map):
         x = node_map[self.x][0]
@@ -546,7 +608,9 @@ class BatchnormCallback(DFPatternCallback):
         beta = node_map[self.beta][0]
         gamma = node_map[self.gamma][0]
         eps = node_map[self.eps][0]
-        return relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = eps.data.asnumpy().item())[0]
+        return relay.op.nn.batch_norm(x, gamma, beta, mean, var,
+                                      epsilon=eps.data.asnumpy().item())[0]
+
 
 def test_fuse_batchnorm():
     x = relay.var('x')
@@ -554,11 +618,14 @@ def test_fuse_batchnorm():
     mean = relay.var('mean')
     beta = relay.var('beta')
     gamma = relay.var('gamma')
-    
-    BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
+
+    BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
 
     out = rewrite(BatchnormCallback(), BN)
-    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])
+    assert tvm.ir.structural_equal(
+        out,
+        relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0])
+
 
 def test_no_fuse_batchnorm():
     x = relay.var('x')
@@ -566,75 +633,85 @@ def test_no_fuse_batchnorm():
     mean = relay.var('mean')
     beta = relay.var('beta')
     gamma = relay.var('gamma')
-    
-    fake_BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) - beta
+
+    fake_BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) - beta
 
     out = rewrite(BatchnormCallback(), fake_BN)
     assert tvm.ir.structural_equal(out, fake_BN)
 
+
 def test_fuse_double_batchnorm():
     x = relay.var('x')
     var = relay.var('var')
     mean = relay.var('mean')
     beta = relay.var('beta')
     gamma = relay.var('gamma')
-    
-    BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
-    BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
+
+    BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
+    BN2 = gamma * (BN - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
 
     out = rewrite(BatchnormCallback(), BN2)
 
-    bn = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]
-    bn2 = relay.op.nn.batch_norm(bn, gamma, beta, mean, var, epsilon = 1e-5)[0]
+    bn = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]
+    bn2 = relay.op.nn.batch_norm(bn, gamma, beta, mean, var, epsilon=1e-5)[0]
 
     assert tvm.ir.structural_equal(out, bn2)
 
+
 def test_partial_fuse_double_batchnorm():
     x = relay.var('x')
     var = relay.var('var')
     mean = relay.var('mean')
     beta = relay.var('beta')
     gamma = relay.var('gamma')
-    
-    BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) - beta
-    BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
+
+    BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) - beta
+    BN2 = gamma * (BN - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
 
     out = rewrite(BatchnormCallback(), BN2)
 
-    bn2 = relay.op.nn.batch_norm(BN, gamma, beta, mean, var, epsilon = 1e-5)[0]
+    bn2 = relay.op.nn.batch_norm(BN, gamma, beta, mean, var, epsilon=1e-5)[0]
 
     assert tvm.ir.structural_equal(out, bn2)
 
+
 def test_fuse_batchnorm_commutation():
     x = relay.var('x')
     var = relay.var('var')
     mean = relay.var('mean')
     beta = relay.var('beta')
     gamma = relay.var('gamma')
-    
+
     #commute add
-    BN = beta + gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5))
+    BN = beta + gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5))
     out = rewrite(BatchnormCallback(), BN)
-    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])
+    assert tvm.ir.structural_equal(
+        out,
+        relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0])
 
     # associate divide/multiply
-    BN = (gamma * (x - mean)) /relay.op.sqrt(var + relay.const(1e-5))  + beta
+    BN = (gamma * (x - mean)) / relay.op.sqrt(var + relay.const(1e-5)) + beta
     out = rewrite(BatchnormCallback(), BN)
-    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])
+    assert tvm.ir.structural_equal(
+        out,
+        relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0])
 
     # associate multiply/divide
-    BN = gamma * ((x - mean)/relay.op.sqrt(var + relay.const(1e-5))) + beta
+    BN = gamma * ((x - mean) / relay.op.sqrt(var + relay.const(1e-5))) + beta
     out = rewrite(BatchnormCallback(), BN)
-    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])
+    assert tvm.ir.structural_equal(
+        out,
+        relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0])
+
 
 def test_quadruple_rewrite_dominator():
     class DominatorRemovalCallback(DFPatternCallback):
         def __init__(self):
             self.inp = wildcard()
             self.weight = wildcard()
-            
             is_conv2d = is_op('nn.conv2d')(self.inp, self.weight)
-            is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
+            is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(
+                wildcard()) | is_op('add')(wildcard(), wildcard())
             reduction = is_op('add')(wildcard(), wildcard())
             self.pattern = dominates(is_conv2d, is_unary_elemwise, reduction)
 
@@ -674,7 +751,6 @@ def test_quadruple_rewrite_dominator():
     tanh = relay.op.tanh(relu)
     leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
     out = tanh + leaky_relu
-    
     one = relay.op.nn.conv2d(inp, weight)
     two = relay.op.nn.conv2d(one, weight)
     three = relay.op.nn.conv2d(two, weight)
@@ -682,18 +758,20 @@ def test_quadruple_rewrite_dominator():
 
     assert tvm.ir.structural_equal(DominatorRemovalCallback().rewrite(out), four)
 
+
 def algebraic_simplify(expr):
     zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0)))
     one = (ExprPattern(relay.const(1)) | ExprPattern(relay.const(1.0)))
+
     class ElwiseNullCallback(DFPatternCallback):
         def callback(self, pre, post, node_map):
-            return node_map[self.x][0]
+            return node_map[self.x][0] # pylint: disable=no-member
 
     class AddCallback(ElwiseNullCallback):
         def __init__(self):
             self.x = wildcard()
             self.pattern = self.x + zero
-    
+
     class SubCallback(ElwiseNullCallback):
         def __init__(self):
             self.x = wildcard()
@@ -719,17 +797,19 @@ def algebraic_simplify(expr):
             self.x = zero
             self.pattern = self.x / wildcard()
 
-    return rewrite([AddCallback(),
-                    SubCallback(),
-                    MulCallback(),
-                    DivCallback(),
-                    MulZeroCallback(),
-                    ZeroDivCallback()
-                    ], expr);
+    return rewrite([
+        AddCallback(),
+        SubCallback(),
+        MulCallback(),
+        DivCallback(),
+        MulZeroCallback(),
+        ZeroDivCallback()
+    ], expr)
+
 
 def test_algebraic_simplify():
     x = relay.Var('x')
-    y = relay.Var('y')  
+    y = relay.Var('y')
 
     one = relay.const(1)
     zero = relay.const(0)
@@ -740,23 +820,25 @@ def test_algebraic_simplify():
     assert algebraic_simplify(x + zerof) == x
     assert algebraic_simplify(zero + x) == x
     assert algebraic_simplify(zerof + x) == x
-    
+
     assert algebraic_simplify(x - zero) == x
     assert algebraic_simplify(x - zerof) == x
-    
+
     assert algebraic_simplify(x * one) == x
     assert algebraic_simplify(x * onef) == x
     assert algebraic_simplify(one * x) == x
     assert algebraic_simplify(onef * x) == x
     assert algebraic_simplify(x * zero) == zero
     assert algebraic_simplify(x * zerof) == zerof
-    
+
     assert algebraic_simplify(x / one) == x
     assert algebraic_simplify(x / onef) == x
     assert algebraic_simplify(zero / x) == zero
     assert algebraic_simplify(zerof / x) == zerof
 
-    assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), x + y)
+    assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x),
+                                   x + y)
+
 
 def test_double_partition():
     # Pattern 1
@@ -780,19 +862,30 @@ def test_double_partition():
     for pat, label in [(relu_p, "conv_bias_relu"), (bias_add_p, "conv_bias")]:
         partitioned = pat.partition(partitioned, {"Composite": label})
 
-
     inpf = relay.var("input")
     weightf = relay.var("weight")
     biasf = relay.var("bias")
-    func0 = relay.Function([inpf, weightf, biasf], relay.op.nn.relu(relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), biasf))).with_attr("Composite", "conv_bias_relu").with_attr("PartitionedFromPattern","nn.conv2d_nn.bias_add_nn.relu_")
+    func0 = relay.Function(
+        [inpf, weightf, biasf],
+        relay.op.nn.relu(relay.op.nn.bias_add(
+            relay.op.nn.conv2d(inpf, weightf),
+            biasf))).with_attr("Composite",
+                               "conv_bias_relu").with_attr("PartitionedFromPattern",
+                                                           "nn.conv2d_nn.bias_add_nn.relu_")
     inpf = relay.var("input")
     weightf = relay.var("weight")
     biasf = relay.var("bias")
-    func1 = relay.Function([inpf, weightf, biasf], relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), biasf)).with_attr("Composite", "conv_bias").with_attr("PartitionedFromPattern","nn.conv2d_nn.bias_add_")
+    func1 = relay.Function([inpf, weightf, biasf],
+                           relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf),
+                                                biasf)).with_attr("Composite",
+                                                                  "conv_bias").with_attr(
+                                                                      "PartitionedFromPattern",
+                                                                      "nn.conv2d_nn.bias_add_")
 
     expected = func1(func0(x, w, b), w2, b2)
     assert tvm.ir.structural_equal(partitioned, expected)
 
+
 def test_partition_dominator():
     # Pattern
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
@@ -803,31 +896,36 @@ def test_partition_dominator():
     # Classic Diamond
     inp = relay.var('input')
     weight = relay.var('weight')
+
     def generate_diamond(inp, weight):
         conv2d = relay.op.nn.conv2d(inp, weight)
         relu = relay.op.nn.relu(conv2d)
         relu = relay.op.nn.relu(relu)
         leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
         return relu + leaky_relu
-    out = generate_diamond(inp*inp, weight*weight)
+
+    out = generate_diamond(inp * inp, weight * weight)
     # Check
     partitioned = diamond.partition(out)
-    
+
     i = relay.Var("input")
     w = relay.Var("weight")
-    f = relay.Function([i, w], generate_diamond(i, w)).with_attr("PartitionedFromPattern","nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_")
-    assert tvm.ir.structural_equal(partitioned, f(inp*inp, weight*weight))
+    f = relay.Function([i, w], generate_diamond(i, w)).with_attr(
+        "PartitionedFromPattern", "nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_")
+    assert tvm.ir.structural_equal(partitioned, f(inp * inp, weight * weight))
+
 
 def test_quadruple_partition_dominator():
     # Pattern
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(
+        wildcard()) | is_op('add')(wildcard(), wildcard())
     reduction = is_op('add')(wildcard(), wildcard())
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
 
-
     inp = relay.var('input')
     weight = relay.var('weight')
+
     # Classic Diamond
     def classic_diamond(inp, weight):
         conv2d = relay.op.nn.conv2d(inp, weight)
@@ -863,40 +961,31 @@ def test_quadruple_partition_dominator():
         return tanh + leaky_relu
 
     partitioned = diamond.partition(
-                    nested_diamond(
-                        single_branch(
-                            deeper_diamond(
-                                classic_diamond(inp, weight),
-                                weight),
-                            weight),
-                        weight
-                    )
-                )
+        nested_diamond(single_branch(deeper_diamond(classic_diamond(inp, weight), weight), weight),
+                       weight))
 
     functions = []
     partition_names = [
         "nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_",
-        "nn.conv2d_nn.relu_nn.relu_tanh_nn.leaky_relu_add_",
-        "nn.conv2d_nn.relu_nn.relu_tanh_add_",
+        "nn.conv2d_nn.relu_nn.relu_tanh_nn.leaky_relu_add_", "nn.conv2d_nn.relu_nn.relu_tanh_add_",
         "nn.conv2d_nn.relu_add_tanh_nn.leaky_relu_add_"
     ]
     for i, f in enumerate([classic_diamond, deeper_diamond, single_branch, nested_diamond]):
         inpf = relay.var("input")
         weightf = relay.var("weight")
-        functions.append(relay.Function([inpf, weightf], f(inpf, weightf)).with_attr("PartitionedFromPattern", partition_names[i]))
-
-    reference = functions[3](
-                    functions[2](
-                        functions[1](
-                            functions[0](inp, weight),
-                            weight),
-                        weight),
-                    weight
-                )
+        functions.append(
+            relay.Function([inpf, weightf], f(inpf,
+                                              weightf)).with_attr("PartitionedFromPattern",
+                                                                  partition_names[i]))
+
+    reference = functions[3](functions[2](functions[1](functions[0](inp, weight), weight), weight),
+                             weight)
     assert tvm.ir.structural_equal(partitioned, reference)
 
+
 def get_BN(x, var, mean, beta, gamma, eps):
-    return gamma * (x - mean)/relay.op.sqrt(var + eps) + beta
+    return gamma * (x - mean) / relay.op.sqrt(var + eps) + beta
+
 
 def test_partition_batchnorm():
     x = relay.var('x')
@@ -907,7 +996,6 @@ def test_partition_batchnorm():
     eps = relay.const(1e-5)
     BN = get_BN(x, var, mean, beta, gamma, eps)
 
-    
     xf = relay.var('xf')
     varf = relay.var('varf')
     meanf = relay.var('meanf')
@@ -915,11 +1003,15 @@ def test_partition_batchnorm():
     gammaf = relay.var('gammaf')
     epsf = relay.var('epsf')
     # Put the arguments in topological order for the reference
-    f = relay.Function([gammaf, xf, meanf, varf, epsf, betaf], get_BN(xf, varf, meanf, betaf, gammaf, epsf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
+    f = relay.Function([gammaf, xf, meanf, varf, epsf, betaf],
+                       get_BN(xf, varf, meanf, betaf, gammaf,
+                              epsf)).with_attr("PartitionedFromPattern",
+                                               "subtract_multiply_add_sqrt_divide_add_")
 
     partitioned = BatchnormCallback().pattern.partition(BN)
     assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, eps, beta))
 
+
 def test_partition_double_batchnorm():
     x = relay.var('x')
     var = relay.var('var')
@@ -927,9 +1019,9 @@ def test_partition_double_batchnorm():
     beta = relay.var('beta')
     gamma = relay.var('gamma')
     eps = relay.const(1e-5)
-    
-    BN = gamma * (x - mean)/relay.op.sqrt(var + eps) + beta
-    BN2 = gamma * (BN - mean)/relay.op.sqrt(var + eps) + beta
+
+    BN = gamma * (x - mean) / relay.op.sqrt(var + eps) + beta
+    BN2 = gamma * (BN - mean) / relay.op.sqrt(var + eps) + beta
 
     xf = relay.var('xf')
     varf = relay.var('varf')
@@ -937,7 +1029,10 @@ def test_partition_double_batchnorm():
     betaf = relay.var('betaf')
     gammaf = relay.var('gammaf')
     epsf = relay.var('epsf')
-    f1 = relay.Function([gammaf, xf, meanf, varf, epsf, betaf], get_BN(xf, varf, meanf, betaf, gammaf, epsf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
+    f1 = relay.Function([gammaf, xf, meanf, varf, epsf, betaf],
+                        get_BN(xf, varf, meanf, betaf, gammaf,
+                               epsf)).with_attr("PartitionedFromPattern",
+                                                "subtract_multiply_add_sqrt_divide_add_")
     # The partitioner doesn't replace duplicates, so we use two copies of the function
     xf2 = relay.var('xf2')
     varf2 = relay.var('varf2')
@@ -945,14 +1040,19 @@ def test_partition_double_batchnorm():
     betaf2 = relay.var('betaf2')
     gammaf2 = relay.var('gammaf2')
     epsf2 = relay.var('epsf2')
-    f2 = relay.Function([gammaf2, xf2, meanf2, varf2, epsf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2, epsf2)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
+    f2 = relay.Function([gammaf2, xf2, meanf2, varf2, epsf2, betaf2],
+                        get_BN(xf2, varf2, meanf2, betaf2, gammaf2,
+                               epsf2)).with_attr("PartitionedFromPattern",
+                                                 "subtract_multiply_add_sqrt_divide_add_")
 
     partitioned = BatchnormCallback().pattern.partition(BN2)
     reference = f2(gamma, f1(gamma, x, mean, var, eps, beta), mean, var, eps, beta)
     assert tvm.ir.structural_equal(partitioned, reference)
 
+
 def test_partition_check():
     pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
+
     def check(pre):
         return pre.args[0].attrs.data_layout == "NCHW"
 
@@ -965,7 +1065,8 @@ def test_partition_check():
     wf = relay.var('weight')
     conv2df = relay.op.nn.conv2d(xf, wf)
     reluf = relay.op.nn.relu(conv2df)
-    func = relay.Function([xf, wf], reluf).with_attr("PartitionedFromPattern", "nn.conv2d_nn.relu_")
+    func = relay.Function([xf, wf], reluf).with_attr("PartitionedFromPattern",
+                                                     "nn.conv2d_nn.relu_")
 
     reference = func(x, w)
     partitioned = pattern.partition(relu, check=check)
@@ -975,8 +1076,10 @@ def test_partition_check():
     relu = relay.op.nn.relu(conv2d)
     assert relu == pattern.partition(relu, check=check)
 
+
 def test_partition_check_types():
     pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
+
     def check(pre):
         conv = pre.args[0]
         return (conv.attrs.data_layout == "NCHW") and bool(conv.checked_type.shape[0] == 1)
@@ -1002,6 +1105,7 @@ def test_partition_check_types():
     relu = run_opt_pass(relu, relay.transform.InferType())
     assert relu == pattern.partition(relu, check=check)
 
+
 def test_partition_option():
     x = relay.var('x')
     w = relay.var('w')
@@ -1020,12 +1124,15 @@ def test_partition_option():
         bias_add = relay.op.nn.bias_add(conv2d, b)
         relu = relay.op.nn.relu(bias_add)
         return relu
+
     relu = conv_bias_relu(x, w, b)
 
     xf = relay.var('x')
     wf = relay.var('w')
     bf = relay.var('b')
-    func = relay.Function([xf, wf, bf], conv_bias_relu(xf, wf, bf)).with_attr("PartitionedFromPattern","nn.conv2d_nn.bias_add_nn.relu_")
+    func = relay.Function([xf, wf, bf],
+                          conv_bias_relu(xf, wf, bf)).with_attr("PartitionedFromPattern",
+                                                                "nn.conv2d_nn.bias_add_nn.relu_")
 
     assert pattern1.match(relu)
     assert tvm.ir.structural_equal(func(x, w, b), pattern1.partition(relu))
@@ -1033,33 +1140,55 @@ def test_partition_option():
     assert pattern2.match(relu)
     assert tvm.ir.structural_equal(func(x, w, b), pattern2.partition(relu))
 
+
 if __name__ == "__main__":
+    test_expr_pattern()
+    test_var_pattern()
+    test_constant_pattern()
+    test_wildcard_pattern()
+    test_CallPattern()
+    test_TuplePattern()
+    test_TupleGetItemPattern()
+    test_AltPattern()
+    test_TypePattern()
+    test_AttrPattern()
     test_match_op()
     test_no_match_op()
     test_match_op_or()
-    test_match_call()
-    test_no_match_call()
     test_match_call_commutive()
     test_no_match_call_commutive()
+    test_match_call()
+    test_no_match_call()
+    test_match_option()
+    test_no_match_option()
+    test_match_const()
     test_match_tuple()
     test_no_match_tuple()
     test_match_type()
     test_no_match_type()
-    test_match_attr()
-    test_no_match_attr()
+    test_match_op_attr()
+    test_no_match_op_attr()
+    test_match_func_attr()
+    test_no_match_func_attr()
+    test_match_call_attr()
+    test_no_match_call_attr()
     test_match_diamond()
     test_no_match_diamond()
     test_match_fake_diamond()
+    test_match_dominator()
+    test_not_match_dominator()
     test_rewrite()
+    test_rewrite_func()
     test_nested_rewrite()
+    test_not_fuse_multi_diamond()
     test_fuse_batchnorm()
     test_no_fuse_batchnorm()
     test_fuse_double_batchnorm()
     test_partial_fuse_double_batchnorm()
     test_fuse_batchnorm_commutation()
-    test_match_dominator()
-    test_not_match_dominator()
+    test_quadruple_rewrite_dominator()
     test_algebraic_simplify()
+    test_double_partition()
     test_partition_dominator()
     test_quadruple_partition_dominator()
     test_partition_batchnorm()