[POC][PatternLang]Remove constants from partitioned functions (#5663)
authorMatthew Brookhart <mbrookhart@octoml.ai>
Tue, 26 May 2020 17:26:31 +0000 (10:26 -0700)
committerGitHub <noreply@github.com>
Tue, 26 May 2020 17:26:31 +0000 (02:26 +0900)
* remove constants from partitioned functions

* remove print statements

src/relay/ir/dataflow_matcher.cc
tests/python/relay/test_dataflow_pattern.py

index 2f25733..dd9d806 100644 (file)
@@ -557,7 +557,7 @@ class PatternGrouper {
           auto matches = node_map[node->ref_];
           for (auto match : matches) {
             if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() == nullptr &&
-                match.as<FunctionNode>() == nullptr && match.as<ConstantNode>() == nullptr) {
+                match.as<FunctionNode>() == nullptr) {
               inputs[match] = Var(
                   "FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number),
                   NullValue<Type>());
@@ -577,8 +577,8 @@ class PatternGrouper {
     auto extractor = MatchExtractor(inputs);
     auto body = extractor.Mutate(expr);
 
-    // Verify the pattern still holds
-    CHECK(DFPatternMatcher(body).Match(pattern_, body));
+    // Verify the pattern still holds, no longer valid if we're not embedding constants in the
+    // graph, keep here for future debug CHECK(DFPatternMatcher(body).Match(pattern_, body));
     group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
     group.name = extractor.GetName();
     // Check to make sure we aren't overlapping with another group
index 17c8df4..ed90873 100644 (file)
@@ -878,8 +878,8 @@ def test_quadruple_partition_dominator():
                 )
     assert tvm.ir.structural_equal(partitioned, reference)
 
-def get_BN(x, var, mean, beta, gamma, eps = 1e-5):
-    return gamma * (x - mean)/relay.op.sqrt(var + relay.const(eps)) + beta
+def get_BN(x, var, mean, beta, gamma, eps):
+    return gamma * (x - mean)/relay.op.sqrt(var + eps) + beta
 
 def test_partition_batchnorm():
     x = relay.var('x')
@@ -887,7 +887,8 @@ def test_partition_batchnorm():
     mean = relay.var('mean')
     beta = relay.var('beta')
     gamma = relay.var('gamma')
-    BN = get_BN(x, var, mean, beta, gamma)
+    eps = relay.const(1e-5)
+    BN = get_BN(x, var, mean, beta, gamma, eps)
 
     
     xf = relay.var('xf')
@@ -895,11 +896,12 @@ def test_partition_batchnorm():
     meanf = relay.var('meanf')
     betaf = relay.var('betaf')
     gammaf = relay.var('gammaf')
+    epsf = relay.var('epsf')
     # Put the arguments in toplogological order for the reference
-    f = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
+    f = relay.Function([gammaf, xf, meanf, varf, epsf, betaf], get_BN(xf, varf, meanf, betaf, gammaf, epsf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
 
     partitioned = BatchnormCallback().pattern.partition(BN)
-    assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, beta))
+    assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, eps, beta))
 
 def test_partition_double_batchnorm():
     x = relay.var('x')
@@ -907,26 +909,29 @@ def test_partition_double_batchnorm():
     mean = relay.var('mean')
     beta = relay.var('beta')
     gamma = relay.var('gamma')
+    eps = relay.const(1e-5)
     
-    BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
-    BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
+    BN = gamma * (x - mean)/relay.op.sqrt(var + eps) + beta
+    BN2 = gamma * (BN - mean)/relay.op.sqrt(var + eps) + beta
 
     xf = relay.var('xf')
     varf = relay.var('varf')
     meanf = relay.var('meanf')
     betaf = relay.var('betaf')
     gammaf = relay.var('gammaf')
-    f1 = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
+    epsf = relay.var('epsf')
+    f1 = relay.Function([gammaf, xf, meanf, varf, epsf, betaf], get_BN(xf, varf, meanf, betaf, gammaf, epsf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
     # The partitioner doesn't replace duplicates, so we use two copies of the function
     xf2 = relay.var('xf2')
     varf2 = relay.var('varf2')
     meanf2 = relay.var('meanf2')
     betaf2 = relay.var('betaf2')
     gammaf2 = relay.var('gammaf2')
-    f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
+    epsf2 = relay.var('epsf2')
+    f2 = relay.Function([gammaf2, xf2, meanf2, varf2, epsf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2, epsf2)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
 
     partitioned = BatchnormCallback().pattern.partition(BN2)
-    reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
+    reference = f2(gamma, f1(gamma, x, mean, var, eps, beta), mean, var, eps, beta)
     assert tvm.ir.structural_equal(partitioned, reference)
 
 def test_partition_check():