[PatternLang] Conditionally Embedding Constants in Partitioned Functions (#5693)
authorMatthew Brookhart <mbrookhart@octoml.ai>
Sat, 30 May 2020 01:07:07 +0000 (18:07 -0700)
committerGitHub <noreply@github.com>
Sat, 30 May 2020 01:07:07 +0000 (10:07 +0900)
* Embed constants in the partitioned function if the pattern explicitly requests constants

fix rst

fix pylint

* improve comments based on Cody's feedback
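
A minimal sketch of the behavior this adds, assuming the tvm.relay.dataflow_pattern
API as of this commit: a pattern that explicitly requests a constant makes partition
embed the matched constant in the function body, while a wildcard still lifts it into
a function argument.

    import numpy as np
    from tvm import relay
    from tvm.relay.dataflow_pattern import ConstantPattern, is_op, wildcard

    x = relay.var("x", shape=(1, 3, 8, 8))
    w = relay.const(np.zeros((4, 3, 3, 3), dtype="float32"))
    conv = relay.op.nn.conv2d(x, w)

    # Wildcard weight: the constant is lifted into a function argument.
    lifted = is_op("nn.conv2d")(wildcard(), wildcard()).partition(conv)

    # ConstantPattern weight: the constant stays embedded in the body.
    embedded = is_op("nn.conv2d")(wildcard(), ConstantPattern()).partition(conv)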

docs/langref/relay_pattern.rst
python/tvm/relay/dataflow_pattern/__init__.py
src/relay/ir/dataflow_matcher.cc
tests/python/relay/test_dataflow_pattern.py

docs/langref/relay_pattern.rst
index 7bb7bdf..7d5deb2 100644
@@ -137,7 +137,7 @@ The next example is matching a pattern of batch_norm -> get(0) -> relu:
         pat.match(out)
 
 The next example matches a constant node based on its value. This is useful to check
-if a specific parameter in a subgraph has been bind or not.
+if a specific parameter in a subgraph has been bound or not.
 
 .. code-block:: python
 
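
A short sketch of value-based constant matching, assuming the ExprPattern
constructor exercised in the tests below:

    from tvm import relay
    from tvm.relay.dataflow_pattern import ExprPattern

    # ExprPattern compares by structural equality, so constant values matter.
    pat = ExprPattern(relay.const(4.0))
    assert pat.match(relay.const(4.0))
    assert not pat.match(relay.const(5.0))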
@@ -266,10 +266,10 @@ Attribute Pattern
 
 Check that the operator matched by the pattern has an attribute with a particular value.
 
-Input
-*****
+Variable Pattern
+****************
 
-Check that the expression is an input, i.e has no parents and is a variable.
+Check that the expression is a Relay Variable, and optionally provide a name to match against the Variable's name.
 
 
 Alternate
python/tvm/relay/dataflow_pattern/__init__.py
index e8f73ed..f1d0784 100644
@@ -318,15 +318,14 @@ class VarPattern(DFPattern):
     Parameters
     ----------
     name_hint: str
-        The name of the variable.
-        This name only acts as a hint, and is not used
-        for equality.
+        The name of the variable. Optional; if not provided,
+        the pattern will match any VarNode.
 
     type_annotation: tvm.relay.Type, optional
         The type annotation on the variable.
     """
 
-    def __init__(self, name_hint: str, type_annotation=None):
+    def __init__(self, name_hint="", type_annotation=None):
         self.__init_handle_by_constructor__(
             ffi.VarPattern, name_hint, type_annotation)
 
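
With the new default argument, a quick sketch of both uses:

    from tvm import relay
    from tvm.relay.dataflow_pattern import VarPattern

    x = relay.var("x")
    assert VarPattern("x").match(x)      # match on the variable name
    assert VarPattern().match(x)         # no name given: match any Var
    assert not VarPattern("y").match(x)  # a different name does not match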
src/relay/ir/dataflow_matcher.cc
index a7e4b37..70fce2f 100644
@@ -562,7 +562,7 @@ class PatternGrouper {
           auto matches = node_map[node->ref_];
           for (auto match : matches) {
             if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() == nullptr &&
-                match.as<FunctionNode>() == nullptr) {
+                match.as<FunctionNode>() == nullptr && !EmbedConst(match, node->ref_)) {
               inputs[match] = Var(
                   "FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number),
                   NullValue<Type>());
@@ -582,8 +582,8 @@ class PatternGrouper {
     auto extractor = MatchExtractor(inputs);
     auto body = extractor.Mutate(expr);
 
-    // Verify the pattern still holds, no longer valid if we're not embedding constants in the
-    // graph, keep here for future debug CHECK(DFPatternMatcher(body).Match(pattern_, body));
+    // Verify the pattern still holds
+    CHECK(DFPatternMatcher(body).Match(pattern_, body));
     group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
     group.name = extractor.GetName();
     // Check to make sure we aren't overlapping with another group
@@ -613,6 +613,36 @@ class PatternGrouper {
     CHECK_EQ(groups_[gid_].gid, gid_);
   }
 
+  /*! \brief EmbedConst implements rules for embedding constants into partitioned functions or
+   * lifting them into the function arguments.
+   *
+   * The rules depend on what pattern the ConstantNode matched.
+   *
+   * The basic rules are:
+   *  - If the constant matches ExprPattern(relay.const(*)) or ConstantPattern(), embed the
+   *    constant in the partitioned function.
+   *  - If the constant matched an AltPattern, recursively check the matched side of the pattern.
+   *  - For any other matching pattern (e.g., wildcard, VarPattern), lift the constant into the
+   *    arguments of the partitioned function.
+   */
+  bool EmbedConst(const Expr& expr, const DFPattern pattern) {
+    bool embed = false;
+    if (expr.as<ConstantNode>()) {
+      if (pattern.as<ConstantPatternNode>() != nullptr) {
+        embed = true;
+      } else if (auto expr_pat = pattern.as<ExprPatternNode>()) {
+        if (expr_pat->expr.as<ConstantNode>()) {
+          embed = true;
+        }
+      } else if (auto alt_pat = pattern.as<AltPatternNode>()) {
+        if (matcher_->Match(alt_pat->left, expr)) {
+          embed = EmbedConst(expr, alt_pat->left);
+        } else {
+          embed = EmbedConst(expr, alt_pat->right);
+        }
+      }
+    }
+    return embed;
+  }
   // Internal State
   DFPattern pattern_;
   std::vector<Group> groups_;
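
At the Python level, the AltPattern rule means embedding follows whichever
alternative actually matched; a hedged sketch (nn.dense is only an illustration):

    import numpy as np
    from tvm import relay
    from tvm.relay.dataflow_pattern import ConstantPattern, is_input, is_op, wildcard

    x = relay.var("x", shape=(1, 4))
    wc = relay.const(np.ones((4, 4), dtype="float32"))
    dense = relay.op.nn.dense(x, wc)

    # The weight may be an input or a constant; the constant branch matches
    # here, so EmbedConst embeds wc instead of lifting it to an argument.
    pattern = is_op("nn.dense")(wildcard(), is_input() | ConstantPattern())
    partitioned = pattern.partition(dense)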
tests/python/relay/test_dataflow_pattern.py
index 467e30b..89abb2e 100644
@@ -596,7 +596,7 @@ class BatchnormCallback(DFPatternCallback):
         self.mean = wildcard()
         self.beta = wildcard()
         self.gamma = wildcard()
-        self.eps = wildcard()
+        self.eps = ConstantPattern()
 
         self.pattern = self.gamma * (self.x - self.mean) / is_op("sqrt")(self.var + self.eps) + \
                        self.beta
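
Switching eps from wildcard() to ConstantPattern() means the callback only fires
when epsilon is a literal constant whose value can be folded into nn.batch_norm's
float attribute. A hedged sketch using the BatchnormCallback above; unfused_bn is
an illustrative stand-in for the tests' get_BN helper:

    import tvm
    from tvm import relay
    from tvm.relay.dataflow_pattern import rewrite

    def unfused_bn(x, var, mean, beta, gamma, eps):
        return gamma * (x - mean) / relay.sqrt(var + eps) + beta

    x, var, mean = relay.var("x"), relay.var("var"), relay.var("mean")
    beta, gamma = relay.var("beta"), relay.var("gamma")

    # eps as a variable no longer matches ConstantPattern, so the rewrite
    # (BatchnormCallback is the class defined above) leaves the expression alone.
    bn_var = unfused_bn(x, var, mean, beta, gamma, relay.var("eps"))
    assert tvm.ir.structural_equal(rewrite(BatchnormCallback(), bn_var), bn_var)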
@@ -765,7 +765,7 @@ def algebraic_simplify(expr):
 
     class ElwiseNullCallback(DFPatternCallback):
         def callback(self, pre, post, node_map):
-            return node_map[self.x][0] # pylint: disable=no-member
+            return node_map[self.x][0]  # pylint: disable=no-member
 
     class AddCallback(ElwiseNullCallback):
         def __init__(self):
@@ -1001,15 +1001,15 @@ def test_partition_batchnorm():
     meanf = relay.var('meanf')
     betaf = relay.var('betaf')
     gammaf = relay.var('gammaf')
-    epsf = relay.var('epsf')
     # Put the arguments in topological order for the reference
-    f = relay.Function([gammaf, xf, meanf, varf, epsf, betaf],
+    f = relay.Function([gammaf, xf, meanf, varf, betaf],
                        get_BN(xf, varf, meanf, betaf, gammaf,
-                              epsf)).with_attr("PartitionedFromPattern",
-                                               "subtract_multiply_add_sqrt_divide_add_")
+                              eps)).with_attr("PartitionedFromPattern",
+                                              "subtract_multiply_add_sqrt_divide_add_")
 
     partitioned = BatchnormCallback().pattern.partition(BN)
-    assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, eps, beta))
+    reference = f(gamma, x, mean, var, beta)
+    assert tvm.ir.structural_equal(partitioned, reference)
 
 
 def test_partition_double_batchnorm():
@@ -1028,25 +1028,23 @@ def test_partition_double_batchnorm():
     meanf = relay.var('meanf')
     betaf = relay.var('betaf')
     gammaf = relay.var('gammaf')
-    epsf = relay.var('epsf')
-    f1 = relay.Function([gammaf, xf, meanf, varf, epsf, betaf],
+    f1 = relay.Function([gammaf, xf, meanf, varf, betaf],
                         get_BN(xf, varf, meanf, betaf, gammaf,
-                               epsf)).with_attr("PartitionedFromPattern",
-                                                "subtract_multiply_add_sqrt_divide_add_")
+                               eps)).with_attr("PartitionedFromPattern",
+                                               "subtract_multiply_add_sqrt_divide_add_")
     # The partitioner doesn't replace duplicates, so we use two copies of the function
     xf2 = relay.var('xf2')
     varf2 = relay.var('varf2')
     meanf2 = relay.var('meanf2')
     betaf2 = relay.var('betaf2')
     gammaf2 = relay.var('gammaf2')
-    epsf2 = relay.var('epsf2')
-    f2 = relay.Function([gammaf2, xf2, meanf2, varf2, epsf2, betaf2],
+    f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2],
                         get_BN(xf2, varf2, meanf2, betaf2, gammaf2,
-                               epsf2)).with_attr("PartitionedFromPattern",
-                                                 "subtract_multiply_add_sqrt_divide_add_")
+                               eps)).with_attr("PartitionedFromPattern",
+                                               "subtract_multiply_add_sqrt_divide_add_")
 
     partitioned = BatchnormCallback().pattern.partition(BN2)
-    reference = f2(gamma, f1(gamma, x, mean, var, eps, beta), mean, var, eps, beta)
+    reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
     assert tvm.ir.structural_equal(partitioned, reference)
 
 
@@ -1106,6 +1104,13 @@ def test_partition_check_types():
     assert relu == pattern.partition(relu, check=check)
 
 
+def conv_bias_relu(x, w, b):
+    conv2d = relay.op.nn.conv2d(x, w)
+    bias_add = relay.op.nn.bias_add(conv2d, b)
+    relu = relay.op.nn.relu(bias_add)
+    return relu
+
+
 def test_partition_option():
     x = relay.var('x')
     w = relay.var('w')
@@ -1119,12 +1124,6 @@ def test_partition_option():
     bias = is_op('nn.bias_add')(conv2d, wildcard())
     pattern2 = bias.optional(lambda x: is_op('nn.relu')(x))
 
-    def conv_bias_relu(x, w, b):
-        conv2d = relay.op.nn.conv2d(x, w)
-        bias_add = relay.op.nn.bias_add(conv2d, b)
-        relu = relay.op.nn.relu(bias_add)
-        return relu
-
     relu = conv_bias_relu(x, w, b)
 
     xf = relay.var('x')
@@ -1153,6 +1152,69 @@ def test_match_match():
     out = rewrite(TestRewrite(), mod['tensor_concatenate_int64'])
     assert tvm.ir.structural_equal(mod['tensor_concatenate_int64'], out)
 
+
+def test_partition_constant_embedding():
+    x = relay.var('x')
+    w = relay.var('w')
+    wc = relay.const(1)
+    b = relay.var('b')
+
+    xf = relay.var('x')
+    wf = relay.var('w')
+    bf = relay.var('b')
+    embedded_func = relay.Function([xf, bf],
+                                   conv_bias_relu(xf, wc,
+                                                  bf)).with_attr("PartitionedFromPattern",
+                                                                 "nn.conv2d_nn.bias_add_nn.relu_")
+    xf = relay.var('x')
+    wf = relay.var('w')
+    bf = relay.var('b')
+    lifted_func = relay.Function([xf, wf, bf],
+                                 conv_bias_relu(xf, wf,
+                                                bf)).with_attr("PartitionedFromPattern",
+                                                               "nn.conv2d_nn.bias_add_nn.relu_")
+    relu = conv_bias_relu(x, w, b)
+    reluc = conv_bias_relu(x, wc, b)
+
+    # Check lifting of wildcard matches
+    pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), wildcard()),
+                                                    wildcard()))
+    assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
+    assert tvm.ir.structural_equal(lifted_func(x, wc, b), pattern.partition(reluc))
+
+    # Check lifting of input matches
+    pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input()),
+                                                    wildcard()))
+    assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
+    assert tvm.ir.structural_equal(reluc, pattern.partition(reluc))  # Constants are not Inputs
+
+    # Check embedding of constant matches
+    pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(),
+                                                                       ConstantPattern()),
+                                                    wildcard()))
+    assert tvm.ir.structural_equal(relu, pattern.partition(relu))
+    assert tvm.ir.structural_equal(embedded_func(x, b), pattern.partition(reluc))
+
+    # Check embedding of constant ExprPatterns
+    pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(),
+                                                                       ExprPattern(wc)),
+                                                    wildcard()))
+    assert tvm.ir.structural_equal(relu, pattern.partition(relu))
+    assert tvm.ir.structural_equal(embedded_func(x, b), pattern.partition(reluc))
+
+    # Check lifting/embedding of Alt matches
+    pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input()
+                                                                       | ConstantPattern()),
+                                                    wildcard()))
+    assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
+    assert tvm.ir.structural_equal(embedded_func(x, b), pattern.partition(reluc))
+
+    # Check lifting/embedding of Alt matches with the other ordering
+    pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(
+        wildcard(), ConstantPattern() | is_input()), wildcard()))
+    assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
+    assert tvm.ir.structural_equal(embedded_func(x, b), pattern.partition(reluc))
+
+
 if __name__ == "__main__":
     test_expr_pattern()
     test_var_pattern()
@@ -1209,3 +1271,4 @@ if __name__ == "__main__":
     test_partition_check_types()
     test_partition_option()
     test_match_match()
+    test_partition_constant_embedding()