add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
- add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu"))
# merged function
sub_node = relay.subtract(in_1, in_2)
mul_node = relay.multiply(add_node, sub_node)
add_sub_mul = relay.Function([in_1, in_2], mul_node)
- add_sub_mul = add_sub_mul.set_attribute("Primitive",
- tir.IntImm("int32", 1))
add_sub_mul = add_sub_mul.set_attribute("Composite",
tir.StringImm("add_sub_mul"))
sub_node_1 = relay.subtract(in_3, in_4)
mul_node_1 = relay.multiply(add_node_1, sub_node_1)
add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1)
- add_sub_mul_1 = add_sub_mul_1.set_attribute("Primitive",
- tir.IntImm("int32", 1))
add_sub_mul_1 = add_sub_mul_1.set_attribute("Composite",
tir.StringImm("add_sub_mul"))
add_node_1 = relay.add(in_1, add_node)
add_node_2 = relay.add(add_node_1, add_node)
add_add_add = relay.Function([in_1, in_2], add_node_2)
- add_add_add = add_add_add.set_attribute("Primitive",
- tir.IntImm("int32", 1))
add_add_add = add_add_add.set_attribute("Composite",
tir.StringImm("add_add_add"))
bias_node = relay.nn.bias_add(conv_node, in_3)
r = relay.nn.relu(bias_node)
conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
- conv_bias_add_relu = conv_bias_add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
conv_bias_add_relu = conv_bias_add_relu.set_attribute("Composite",
tir.StringImm("conv2d_bias_relu"))
add_node = relay.add(in_4, in_5)
r = relay.nn.relu(add_node)
add_relu = relay.Function([in_4, in_5], r)
- add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu"))
# merged function
out = relay.abs(out)
out = relay.nn.relu(out)
merged_func = relay.Function([x, y], out)
- merged_func = merged_func.set_attribute('Primitive', tir.IntImm('int32', 1))
merged_func = merged_func.set_attribute('Composite',
tir.StringImm(composite_name))
ret = relay.Call(merged_func, [input_1, input_2])
y = relay.var('y')
branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
func_1 = relay.Function([x, y], branch_1)
- func_1 = func_1.set_attribute('Primitive', tir.IntImm('int32', 1))
func_1 = func_1.set_attribute('Composite', tir.StringImm("add_sub_mul"))
call_1 = relay.Call(func_1, [input_1, input_2])
x1 = relay.var('x1')
y1 = relay.var('y1')
branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
func_2 = relay.Function([x1, y1], branch_2)
- func_2 = func_2.set_attribute('Primitive', tir.IntImm('int32', 1))
func_2 = func_2.set_attribute('Composite', tir.StringImm("add_sub_mul"))
call_2 = relay.Call(func_2, [input_1, input_2])
out = relay.multiply(call_1, call_2)
add_relu_1 = relay.add(x, y)
add_relu_1 = relay.nn.relu(add_relu_1)
add_relu_1 = relay.Function([x, y], add_relu_1)
- add_relu_1 = add_relu_1.set_attribute('Primitive', tir.IntImm('int32', 1))
add_relu_1 = add_relu_1.set_attribute('Composite', tir.StringImm('add_relu'))
add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]])
x1 = relay.var('x1')
add_relu_2 = relay.add(x1, y1)
add_relu_2 = relay.nn.relu(add_relu_2)
add_relu_2 = relay.Function([x1, y1], add_relu_2)
- add_relu_2 = add_relu_2.set_attribute('Primitive', tir.IntImm('int32', 1))
add_relu_2 = add_relu_2.set_attribute('Composite', tir.StringImm('add_relu'))
add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]])
x2 = relay.var('x2')
sub = relay.subtract(x2, y2)
add_sub_mul = relay.multiply(add, sub)
add_sub_mul = relay.Function([x2, y2], add_sub_mul)
- add_sub_mul = add_sub_mul.set_attribute('Primitive', tir.IntImm('int32', 1))
add_sub_mul = add_sub_mul.set_attribute('Composite', tir.StringImm('add_sub_mul'))
add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2])
return relay.Function(inputs, add_sub_mul_call)
add_relu = relay.add(x, y)
add_relu = relay.nn.relu(add_relu)
add_relu = relay.Function([x, y], add_relu)
- add_relu = add_relu.set_attribute('Primitive', tir.IntImm('int32', 1))
add_relu = add_relu.set_attribute('Composite', tir.StringImm('add_relu'))
add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]])
add_relu_calls.append(add_relu_call)