[RELAY] Remove primitive attribute from composite function (#5014)
authorlhutton1 <35535092+lhutton1@users.noreply.github.com>
Tue, 10 Mar 2020 08:10:07 +0000 (08:10 +0000)
committerGitHub <noreply@github.com>
Tue, 10 Mar 2020 08:10:07 +0000 (17:10 +0900)
* A composite function should not be primitive since we still may need to perform passes on it.

Change-Id: If62d06d265234861a6ec0df7749dc1c339c1055c

src/relay/pass/merge_composite.cc
tests/python/relay/test_pass_merge_composite.py

index 4e1094b..162bf3a 100644 (file)
@@ -168,7 +168,6 @@ class MergeCompositeWrapper : public ExprMutator {
       // make the composite function
       auto f = FunctionNode::make(free_vars, extract, call->checked_type_, {}, Attrs());
       f = FunctionSetAttr(f, attr::kComposite, tir::StringImmNode::make(pattern_name_));
-      f = FunctionSetAttr(f, attr::kPrimitive, tvm::Integer(1));
       // find the expressions associated with the free vars using the args_map
       // this tells us which expressions should be given as inputs to the composite function
       Array<Expr> args;
index b96a89b..bcf61a0 100644 (file)
@@ -164,7 +164,6 @@ def test_simple_merge():
         add_node = relay.add(in_1, in_2)
         relu_node = relay.nn.relu(add_node)
         add_relu = relay.Function([in_1, in_2], relu_node)
-        add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
         add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu"))
 
         # merged function
@@ -230,8 +229,6 @@ def test_branch_merge():
         sub_node = relay.subtract(in_1, in_2)
         mul_node = relay.multiply(add_node, sub_node)
         add_sub_mul = relay.Function([in_1, in_2], mul_node)
-        add_sub_mul = add_sub_mul.set_attribute("Primitive",
-                                                tir.IntImm("int32", 1))
         add_sub_mul = add_sub_mul.set_attribute("Composite",
                                                 tir.StringImm("add_sub_mul"))
 
@@ -242,8 +239,6 @@ def test_branch_merge():
         sub_node_1 = relay.subtract(in_3, in_4)
         mul_node_1 = relay.multiply(add_node_1, sub_node_1)
         add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1)
-        add_sub_mul_1 = add_sub_mul_1.set_attribute("Primitive",
-                                                    tir.IntImm("int32", 1))
         add_sub_mul_1 = add_sub_mul_1.set_attribute("Composite",
                                                     tir.StringImm("add_sub_mul"))
 
@@ -304,8 +299,6 @@ def test_reuse_call_merge():
         add_node_1 = relay.add(in_1, add_node)
         add_node_2 = relay.add(add_node_1, add_node)
         add_add_add = relay.Function([in_1, in_2], add_node_2)
-        add_add_add = add_add_add.set_attribute("Primitive",
-                                                tir.IntImm("int32", 1))
         add_add_add = add_add_add.set_attribute("Composite",
                                                 tir.StringImm("add_add_add"))
 
@@ -390,7 +383,6 @@ def test_multiple_patterns():
         bias_node = relay.nn.bias_add(conv_node, in_3)
         r = relay.nn.relu(bias_node)
         conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
-        conv_bias_add_relu = conv_bias_add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
         conv_bias_add_relu = conv_bias_add_relu.set_attribute("Composite",
                                                               tir.StringImm("conv2d_bias_relu"))
 
@@ -400,7 +392,6 @@ def test_multiple_patterns():
         add_node = relay.add(in_4, in_5)
         r = relay.nn.relu(add_node)
         add_relu = relay.Function([in_4, in_5], r)
-        add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
         add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu"))
 
         # merged function
@@ -470,7 +461,6 @@ def test_merge_order():
         out = relay.abs(out)
         out = relay.nn.relu(out)
         merged_func = relay.Function([x, y], out)
-        merged_func = merged_func.set_attribute('Primitive', tir.IntImm('int32', 1))
         merged_func = merged_func.set_attribute('Composite',
                                                 tir.StringImm(composite_name))
         ret = relay.Call(merged_func, [input_1, input_2])
@@ -537,14 +527,12 @@ def test_parallel_merge():
         y = relay.var('y')
         branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
         func_1 = relay.Function([x, y], branch_1)
-        func_1 = func_1.set_attribute('Primitive', tir.IntImm('int32', 1))
         func_1 = func_1.set_attribute('Composite', tir.StringImm("add_sub_mul"))
         call_1 = relay.Call(func_1, [input_1, input_2])
         x1 = relay.var('x1')
         y1 = relay.var('y1')
         branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
         func_2 = relay.Function([x1, y1], branch_2)
-        func_2 = func_2.set_attribute('Primitive', tir.IntImm('int32', 1))
         func_2 = func_2.set_attribute('Composite', tir.StringImm("add_sub_mul"))
         call_2 = relay.Call(func_2, [input_1, input_2])
         out = relay.multiply(call_1, call_2)
@@ -624,7 +612,6 @@ def test_multiple_input_subgraphs():
         add_relu_1 = relay.add(x, y)
         add_relu_1 = relay.nn.relu(add_relu_1)
         add_relu_1 = relay.Function([x, y], add_relu_1)
-        add_relu_1 = add_relu_1.set_attribute('Primitive', tir.IntImm('int32', 1))
         add_relu_1 = add_relu_1.set_attribute('Composite', tir.StringImm('add_relu'))
         add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]])
         x1 = relay.var('x1')
@@ -632,7 +619,6 @@ def test_multiple_input_subgraphs():
         add_relu_2 = relay.add(x1, y1)
         add_relu_2 = relay.nn.relu(add_relu_2)
         add_relu_2 = relay.Function([x1, y1], add_relu_2)
-        add_relu_2 = add_relu_2.set_attribute('Primitive', tir.IntImm('int32', 1))
         add_relu_2 = add_relu_2.set_attribute('Composite', tir.StringImm('add_relu'))
         add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]])
         x2 = relay.var('x2')
@@ -641,7 +627,6 @@ def test_multiple_input_subgraphs():
         sub = relay.subtract(x2, y2)
         add_sub_mul = relay.multiply(add, sub)
         add_sub_mul = relay.Function([x2, y2], add_sub_mul)
-        add_sub_mul = add_sub_mul.set_attribute('Primitive', tir.IntImm('int32', 1))
         add_sub_mul = add_sub_mul.set_attribute('Composite', tir.StringImm('add_sub_mul'))
         add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2])
         return relay.Function(inputs, add_sub_mul_call)
@@ -655,7 +640,6 @@ def test_multiple_input_subgraphs():
             add_relu = relay.add(x, y)
             add_relu = relay.nn.relu(add_relu)
             add_relu = relay.Function([x, y], add_relu)
-            add_relu = add_relu.set_attribute('Primitive', tir.IntImm('int32', 1))
             add_relu = add_relu.set_attribute('Composite', tir.StringImm('add_relu'))
             add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]])
             add_relu_calls.append(add_relu_call)