Add a check Callback to the Pattern Paritioner (#5646)
authorMatthew Brookhart <mbrookhart@octoml.ai>
Fri, 22 May 2020 04:35:35 +0000 (21:35 -0700)
committerGitHub <noreply@github.com>
Fri, 22 May 2020 04:35:35 +0000 (13:35 +0900)
* add a check callback to the paritioner

* fix doc string

* fix unit test spelling

* add a test with types

include/tvm/relay/dataflow_matcher.h
python/tvm/relay/dataflow_pattern/__init__.py
src/relay/ir/dataflow_matcher.cc
tests/python/relay/test_dataflow_pattern.py

index 58aa640..517582b 100644 (file)
@@ -27,6 +27,7 @@
 #include <tvm/relay/dataflow_pattern.h>
 #include <tvm/relay/dataflow_pattern_functor.h>
 
+#include <string>
 #include <unordered_map>
 #include <utility>
 
@@ -87,10 +88,14 @@ Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr);
  *
  * \param pattern The pattern to match
  * \param expr The expression to patition
+ * \param attrs A set of parameter names and values to apply to the partitioned function
+ * \param check A callback function for checking more complicated properties of the matched
+ * expressions, returns true if the match is accepted and false otherwise
  *
  * \return Return the paritioned Expr.
  */
-Expr PartitionPattern(DFPattern pattern, Expr expr);
+Expr PartitionPattern(DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs,
+                      PackedFunc check);
 
 }  // namespace relay
 }  // namespace tvm
index 2582894..f8be3e2 100644 (file)
@@ -109,7 +109,7 @@ class DFPattern(Node):
         """
         return match(self, expr)
 
-    def partition(self, expr: Expr, attrs=None) -> Expr:
+    def partition(self, expr: Expr, attrs=None, check=lambda x: True) -> Expr:
         """
         Parition the expression into functions defined by this pattern
 
@@ -119,13 +119,16 @@ class DFPattern(Node):
             The expression to match.
         attrs : Optional[Dict[str, Object]]
             A dictionary of Attribute name/values to add to the paritioned function
+        check : Function
+            A function to perform more complicated checks on the matched expression.
+            Returns true if partitioning should proceed, false otherwise.
 
         Returns
         -------
         result : tvm.relay.Expr
             The Expression with matched subgraphs replaced by function calls to that subgraph
         """
-        return partition(self, expr, attrs)
+        return partition(self, expr, attrs, check)
 
     def dominates(self, parent, path=None):
         """
@@ -561,7 +564,7 @@ def rewrite(callbacks, expr: Expr) -> Expr:
 
     return ffi.rewrite(tmp, expr)
 
-def partition(pattern: DFPattern, expr: Expr, attrs=None) -> Expr:
+def partition(pattern: DFPattern, expr: Expr, attrs=None, check=lambda x: True) -> Expr:
     """
     Parition the expression into a series of functions that match the pattern
 
@@ -571,12 +574,15 @@ def partition(pattern: DFPattern, expr: Expr, attrs=None) -> Expr:
         The pattern to match
     expr : tvm.relay.Expr
         The expression to split into functions
-    expr : Optional[Dict[str, Object]]
+    attrs : Optional[Dict[str, Object]]
         A dict of attributes to apply to the partitioned function
+    check : Function
+        A function to perform more complicated checks on the matched expression.
+        Returns true if partitioning should proceed, false otherwise.
 
     Returns
     -------
     result : tvm.relay.Expr
         The Expression with matched subgraphs replaced by function calls to that subgraph
     """
-    return ffi.partition(pattern, expr, attrs)
+    return ffi.partition(pattern, expr, attrs, check)
index 4bb2b0b..980935c 100644 (file)
@@ -693,11 +693,12 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatt
 class PatternPartitioner : protected MixedModeMutator {
  public:
   Expr Partition(const DFPattern& pattern, const Expr& pre,
-                 const Map<std::string, ObjectRef>& attrs) {
+                 const Map<std::string, ObjectRef>& attrs, PackedFunc check) {
     auto grouper = PatternGrouper();
     groups_ = grouper.GroupMatches(pattern, pre);
     gid_assignments_ = grouper.GetGIDAssignments();
     attrs_ = attrs;
+    check_ = check;
     return this->VisitExpr(pre);
   }
 
@@ -718,7 +719,8 @@ class PatternPartitioner : protected MixedModeMutator {
 
   Expr DispatchVisitExpr(const Expr& pre) override {
     auto post = MixedModeMutator::DispatchVisitExpr(pre);
-    if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) {
+    if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node &&
+        static_cast<bool>(check_(pre))) {
       post = RewritePartition(groups_[gid_assignments_[pre]]);
     }
     return post;
@@ -727,16 +729,17 @@ class PatternPartitioner : protected MixedModeMutator {
   Map<std::string, ObjectRef> attrs_;
   std::vector<PatternGrouper::Group> groups_;
   std::unordered_map<Expr, int, ObjectHash, ObjectEqual> gid_assignments_;
+  PackedFunc check_;
 };
 
-Expr PartitionPattern(DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs) {
-  return PatternPartitioner().Partition(pattern, expr, attrs);
+Expr PartitionPattern(DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs,
+                      PackedFunc check) {
+  return PatternPartitioner().Partition(pattern, expr, attrs, check);
 }
 
 TVM_REGISTER_GLOBAL("relay.dataflow_pattern.partition")
-    .set_body_typed([](DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs) {
-      return PartitionPattern(pattern, expr, attrs);
-    });
+    .set_body_typed([](DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs,
+                       PackedFunc check) { return PartitionPattern(pattern, expr, attrs, check); });
 
 }  // namespace relay
 }  // namespace tvm
index 411ef0f..3a605e4 100644 (file)
@@ -17,6 +17,7 @@
 import tvm
 from tvm import relay
 from tvm.relay.dataflow_pattern import *
+from tvm.relay.testing import run_opt_pass
 import numpy as np
 
 # NB: 1 corresponds to the C++ enum that specicfies this
@@ -880,7 +881,7 @@ def test_quadruple_partition_dominator():
 def get_BN(x, var, mean, beta, gamma, eps = 1e-5):
     return gamma * (x - mean)/relay.op.sqrt(var + relay.const(eps)) + beta
 
-def test_parition_batchnorm():
+def test_partition_batchnorm():
     x = relay.var('x')
     var = relay.var('var')
     mean = relay.var('mean')
@@ -900,7 +901,7 @@ def test_parition_batchnorm():
     partitioned = BatchnormCallback().pattern.partition(BN)
     assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, beta))
 
-def test_parition_double_batchnorm():
+def test_partition_double_batchnorm():
     x = relay.var('x')
     var = relay.var('var')
     mean = relay.var('mean')
@@ -916,7 +917,7 @@ def test_parition_double_batchnorm():
     betaf = relay.var('betaf')
     gammaf = relay.var('gammaf')
     f1 = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
-    # The paritioner doesn't replace duplicates, so we use two copies of the function
+    # The partitioner doesn't replace duplicates, so we use two copies of the function
     xf2 = relay.var('xf2')
     varf2 = relay.var('varf2')
     meanf2 = relay.var('meanf2')
@@ -928,6 +929,58 @@ def test_parition_double_batchnorm():
     reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
     assert tvm.ir.structural_equal(partitioned, reference)
 
+def test_partition_check():
+    pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
+    def check(pre):
+        return pre.args[0].attrs.data_layout == "NCHW"
+
+    x = relay.var('input')
+    w = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(x, w)
+    relu = relay.op.nn.relu(conv2d)
+
+    xf = relay.var('input')
+    wf = relay.var('weight')
+    conv2df = relay.op.nn.conv2d(xf, wf)
+    reluf = relay.op.nn.relu(conv2df)
+    func = relay.Function([xf, wf], reluf).with_attr("PartitionedFromPattern", "nn.conv2d_nn.relu_")
+
+    reference = func(x, w)
+    partitioned = pattern.partition(relu, check=check)
+    assert tvm.ir.structural_equal(partitioned, reference)
+
+    conv2d = relay.op.nn.conv2d(x, w, data_layout="NHWC")
+    relu = relay.op.nn.relu(conv2d)
+    assert relu == pattern.partition(relu, check=check)
+
+def test_partition_check_types():
+    pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
+    def check(pre):
+        conv = pre.args[0]
+        return (conv.attrs.data_layout == "NCHW") and bool(conv.checked_type.shape[0] == 1)
+
+    x = relay.var('input', shape=(1, 10, 10, 10))
+    w = relay.var('weight', shape=(10, 10, 3, 3))
+    conv2d = relay.op.nn.conv2d(x, w)
+    relu = relay.op.nn.relu(conv2d)
+    relu = run_opt_pass(relu, relay.transform.InferType())
+
+    partitioned = pattern.partition(relu, check=check)
+    assert partitioned.op.attrs["PartitionedFromPattern"] == "nn.conv2d_nn.relu_"
+
+    conv2d = relay.op.nn.conv2d(x, w, data_layout="NHWC")
+    relu = relay.op.nn.relu(conv2d)
+    relu = run_opt_pass(relu, relay.transform.InferType())
+    assert relu == pattern.partition(relu, check=check)
+
+    x = relay.var('input', shape=(2, 10, 10, 10))
+    w = relay.var('weight', shape=(10, 10, 3, 3))
+    conv2d = relay.op.nn.conv2d(x, w)
+    relu = relay.op.nn.relu(conv2d)
+    relu = run_opt_pass(relu, relay.transform.InferType())
+    assert relu == pattern.partition(relu, check=check)
+
+
 if __name__ == "__main__":
     test_match_op()
     test_no_match_op()
@@ -957,6 +1010,8 @@ if __name__ == "__main__":
     test_algebraic_simplify()
     test_partition_dominator()
     test_quadruple_partition_dominator()
-    test_parition_batchnorm()
-    test_parition_double_batchnorm()
+    test_partition_batchnorm()
+    test_partition_double_batchnorm()
+    test_partition_check()
+    test_partition_check_types()