--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/pass/merge_composite.cc
+ * \brief Merges expressions matching patterns into functions marked
+ * as 'composite'. This is primarily intended to be used alongside the
+ * external codegen infrastructure to support the case where multiple
+ * Relay operators map to a single external operator.
+ */
+
+#include <tvm/te/operation.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace merge_composite {
+
+class MergeCompositeWrapper : public ExprMutator {
+ public:
+ explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern)
+ : pattern_name_(pattern_name), pattern_(pattern) {}
+
+ Expr ExtractPattern(const Var& pattern, const Expr& root,
+ Map<std::string, Array<Expr>>* var_map) {
+ if (var_map->find(pattern->name_hint()) == var_map->end()) {
+ // if we haven't encountered this var yet, make a new free var and associate
+ // it with the value at 'root'
+ auto free_var = VarNode::make(pattern->name_hint(), Type());
+ var_map->Set(pattern->name_hint(), Array<Expr>({free_var, root}));
+ return std::move(free_var);
+ } else {
+ // if we have encountered this var already, return the free var that was created
+ auto vars = (*var_map)[pattern->name_hint()];
+ auto free_var = vars[0];
+ auto graph_expr = vars[1];
+      // make sure they both map to the same node in the graph before reusing the free var
+      if (graph_expr != root) {
+        return Expr();
+      }
+      return free_var;
+ }
+ }
+
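+  // A constant in the pattern is not matched structurally; whatever expression sits at
+  // 'root' is used directly in the extracted subgraph.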
+ Expr ExtractPattern(const Constant& pattern, const Expr& root,
+ Map<std::string, Array<Expr>>* var_map) {
+ return root;
+ }
+
+ /*!
+   * \brief Try to extract a given pattern from a graph as a subgraph.
+ * \param pattern The pattern to extract.
+ * \param root The graph to extract from.
+ * \param var_map A map between free vars in the subgraph and nodes in the graph.
+ * \return The extracted subgraph.
+ *
+ * \note How does this work?
+ *
+   * A pattern consists of a Relay expression containing only operator call nodes, constants
+   * and free variables. The free variables indicate where the pattern can 'attach' in your
+   * graph. This function takes the final call node of the pattern and the call node currently
+   * being traversed in the Relay graph. It traverses the pattern in lockstep with the call
+   * node from the graph (referred to as the 'root' node here) to check that they are
+   * identical. If at any point they differ, an empty expression is returned to signify that
+   * the extraction failed. If a free var is reached in the pattern, the corresponding value
+   * in the root is associated with the name of the free var (via var_map) so that, when we
+   * construct the composite function, the inputs match up correctly with the rest of the
+   * graph. On success, the return value is a new Relay expression ready to be wrapped into a
+   * composite function.
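+   *
+   * As an illustrative sketch (using the 'add_relu' pattern from the unit tests): matching
+   * the pattern relu(add(%x, %y)) against the graph call relu(add(%a, %b)) succeeds, the
+   * function returns relu(add(%x, %y)) rebuilt from fresh free vars, and var_map records
+   * {"x": [%x, %a], "y": [%y, %b]} so that the composite function can later be called
+   * with (%a, %b).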
+ */
+ Expr ExtractPattern(const Call& pattern, const Call& root,
+ Map<std::string, Array<Expr>>* var_map) {
+ // check to make sure both calls are to operators (not functions)
+ if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
+ return Expr();
+ if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name)
+ return Expr();
+
+ unsigned int i = 0;
+ Array<Expr> new_args;
+ for (const auto& arg : pattern->args) {
+ Expr new_arg;
+ if (arg->IsInstance<CallNode>()) {
+ // fail if the root argument is not also a call node
+ if (!root->args[i]->IsInstance<CallNode>()) {
+ return Expr();
+ }
+ // if it's a call node, recursively call this function
+ new_arg = ExtractPattern(Downcast<Call>(arg),
+ Downcast<Call>(root->args[i]),
+ var_map);
+ } else if (arg->IsInstance<VarNode>()) {
+ // if there's a var in the pattern, it must be a free var
+ // so call the function to update the var_map
+ new_arg = ExtractPattern(Downcast<Var>(arg),
+ root->args[i],
+ var_map);
+ } else if (arg->IsInstance<ConstantNode>()) {
+ // if there's a constant, simply get the corresponding
+ // value of the constant from the root
+ new_arg = ExtractPattern(Downcast<Constant>(arg),
+ root->args[i],
+ var_map);
+ }
+ if (!new_arg.defined()) {
+ return Expr();
+ }
+ new_args.push_back(new_arg);
+ i++;
+ }
+ return CallNode::make(root->op, new_args, root->attrs);
+ }
+
+ Expr VisitExpr_(const CallNode* cn) {
+ Call call = GetRef<Call>(cn);
+ if (call->op->IsInstance<FunctionNode>()) {
+ Function func = Downcast<Function>(call->op);
+ CHECK(func.defined());
+ const auto name_node = FunctionGetAttr(func, attr::kComposite).as<tir::StringImmNode>();
+ // don't step into existing composite functions
+ if (name_node && name_node->value != "") {
+ tvm::Array<tvm::relay::Expr> new_args;
+ for (const auto& arg : call->args) {
+ auto new_e = this->Mutate(arg);
+ new_args.push_back(new_e);
+ }
+ return CallNode::make(call->op, new_args, call->attrs);
+ }
+ }
+
+ Expr expr = ExprMutator::VisitExpr_(cn);
+ call = Downcast<Call>(expr);
+ if (!call->op->IsInstance<OpNode>())
+ return std::move(call);
+
+ // only call patterns are supported
+ Call pattern = Downcast<Call>(pattern_);
+ CHECK(pattern.defined());
+ Map<std::string, Array<Expr>> args_map;
+ auto extract = ExtractPattern(pattern, call, &args_map);
+ if (extract.defined()) {
+ auto free_vars = FreeVars(extract);
+ // make the composite function
+ auto f = FunctionNode::make(free_vars, extract, call->checked_type_, {}, Attrs());
+ f = FunctionSetAttr(f, attr::kComposite, tir::StringImmNode::make(pattern_name_));
+ f = FunctionSetAttr(f, attr::kPrimitive, tvm::Integer(1));
+ // find the expressions associated with the free vars using the args_map
+ // this tells us which expressions should be given as inputs to the composite function
+ Array<Expr> args;
+ for (const auto& free_var : free_vars) {
+ args.push_back(args_map[free_var->name_hint()][1]);
+ }
+ auto new_call = CallNode::make(f, args);
+ return std::move(new_call);
+ }
+ return std::move(call);
+ }
+
+ private:
+ /*! \brief The name of the pattern to match */
+ std::string pattern_name_;
+ /*! \brief The pattern to match */
+ Expr pattern_;
+};
+
+Expr MergeComposite(const Expr& expr,
+ const Array<tir::StringImm>& pattern_names, const Array<Expr>& patterns) {
+ CHECK_EQ(pattern_names.size(), patterns.size());
+ Expr merged_expr = expr;
+ // merge the patterns one-by-one in order
+ for (size_t i = 0; i < patterns.size(); i++) {
+ std::string pattern_name = pattern_names[i]->value;
+ Expr pattern = patterns[i];
+ merged_expr = MergeCompositeWrapper(pattern_name, pattern).Mutate(merged_expr);
+ }
+ return merged_expr;
+}
+
+} // namespace merge_composite
+
+namespace transform {
+
+Pass MergeComposite(const tvm::Array<tir::StringImm>& pattern_names,
+ const tvm::Array<Expr>& patterns) {
+ runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(
+ relay::merge_composite::MergeComposite(f, pattern_names, patterns));
+ };
+ auto func_pass = CreateFunctionPass(pass_func, 0, "MergeComposite", {});
+ return func_pass;
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.MergeComposite")
+.set_body_typed(MergeComposite);
+
+} // namespace transform
+
+} // namespace relay
+} // namespace tvm
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import expr
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+"""
+The merge composite pass is designed to merge multiple Relay operators that
+match a given pattern into a single Relay function.
+
+For example, suppose we have the graph:
+
+ conv2d
+ | (merge composite pass)
+ bias_add ====> conv2d_bias_relu
+ | (our target)
+ relu
+
+Our Relay IR before the pass:
+ fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+ %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+ %0 = nn.conv2d(%data, %kernel, kernel_size=[1, 1])
+ /* ty=Tensor[(1, 256, 28, 28), float32] */;
+ %1 = nn.bias_add(%0, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+ nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+ }
+
+Our Relay IR after the pass:
+ fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+ %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+ %2 = fn (%x: Tensor[(1, 512, 28, 28), float32], %y: Tensor[(256, 512, 1, 1), float32],
+ %z: Tensor[(256), float32], Primitive=1, Composite="conv2d_bias_relu") ->
+ Tensor[(1, 256, 28, 28), float32] {
+ %0 = nn.conv2d(%x, %y, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+ %1 = nn.bias_add(%0, %z) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+ nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+ };
+ %2(%data, %kernel, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */
+ }
+
+As you can see in the second Relay example, the pattern we specified has been wrapped
+in a function. The function is then called, producing the same result as the first Relay
+example.
+
+One convenient use for this pass is to offload multiple operators to a single external
+codegen function.
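+
+To run the pass, supply a pattern table: a list of (name, pattern) pairs. A minimal sketch
+(using helpers defined below; `func` stands for the Relay function to transform):
+
+    pattern_table = [
+        ("conv2d_bias_relu", make_conv_bias_relu_pattern())
+    ]
+    result = run_opt_pass(func, relay.transform.MergeComposite(pattern_table))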
+"""
+
+
+def make_add_sub_mul_pattern():
+ """Create a pattern to match the following graph.
+
+ add sub
+ \ /
+ \ /
+ mul
+ """
+ x = relay.var('x')
+ y = relay.var('y')
+ add_node = relay.add(x, y)
+ sub_node = relay.subtract(x, y)
+ mul_node = relay.multiply(add_node, sub_node)
+ return mul_node
+
+
+def make_add_relu_pattern():
+ """Create a pattern to match the following graph.
+
+ add
+ |
+ relu
+ """
+ x = relay.var('x')
+ y = relay.var('y')
+ add_node = relay.add(x, y)
+ r = relay.nn.relu(add_node)
+ return r
+
+
+def make_conv_bias_relu_pattern():
+ """Create a pattern to match the following graph.
+
+ conv2d
+ |
+ bias_add
+ |
+ relu
+ """
+ x = relay.var('x')
+ y = relay.var('y')
+ z = relay.var('z')
+ conv_node = relay.nn.conv2d(x, y)
+ bias_node = relay.nn.bias_add(conv_node, z)
+ r = relay.nn.relu(bias_node)
+ return r
+
+
+def test_simple_merge():
+    """Test that a composite function is correctly produced from a simple graph.
+
+    We would expect the pattern `make_add_relu_pattern` to be merged
+    into a single op `add_relu`.
+
+ a b
+ \ / a b
+ add ====> \ /
+ | add_relu
+ relu
+
+ """
+ pattern_table = [
+ ("add_relu", make_add_relu_pattern())
+ ]
+
+ def before():
+ a = relay.var('a', shape=(10, 10))
+ b = relay.var('b', shape=(10, 10))
+ add_node = relay.add(a, b)
+ r = relay.nn.relu(add_node)
+ return relay.Function([a, b], r)
+
+ def expected():
+ a = relay.var('a', shape=(10, 10))
+ b = relay.var('b', shape=(10, 10))
+
+ # add_relu function
+ in_1 = relay.var('in_1', shape=(10, 10))
+ in_2 = relay.var('in_2', shape=(10, 10))
+ add_node = relay.add(in_1, in_2)
+ relu_node = relay.nn.relu(add_node)
+ add_relu = relay.Function([in_1, in_2], relu_node)
+
+ # merged function
+ r = relay.Call(add_relu, [a, b])
+ return relay.Function([a, b], r)
+
+ result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+ assert not relay.analysis.free_vars(result)
+ expected = run_opt_pass(expected(), relay.transform.InferType())
+ assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_branch_merge():
+    """Test that a composite function is correctly produced from a branching graph.
+
+ We would expect the pattern `make_add_sub_mul_pattern` to be merged
+ into a single op `add_sub_mul`.
+
+ a b a b
+ \/ \/
+ add sub a b
+ \ / \/
+ \ / add_sub_mul
+ mul c |
+ / \ \ |
+ c / c | ====> add_sub_mul
+ \/ \/ |
+ add sub |
+ \ / relu
+ \ /
+ mul
+ |
+ |
+ relu
+ """
+
+ pattern_table = [
+ ("add_sub_mul", make_add_sub_mul_pattern())
+ ]
+
+ def before():
+ a = relay.var('a', shape=(10, 10))
+ b = relay.var('b', shape=(10, 10))
+ c = relay.var('c', shape=(10, 10))
+ add_node = relay.add(a, b)
+ sub_node = relay.subtract(a, b)
+ mul_node = relay.multiply(add_node, sub_node)
+ add_node_2 = relay.add(c, mul_node)
+ sub_node_2 = relay.subtract(c, mul_node)
+ mul_node_2 = relay.multiply(add_node_2, sub_node_2)
+ r = relay.nn.relu(mul_node_2)
+ return relay.Function([a, b, c], r)
+
+ def expected():
+ a = relay.var('a', shape=(10, 10))
+ b = relay.var('b', shape=(10, 10))
+ c = relay.var('c', shape=(10, 10))
+
+ # add_sub_mul function
+ in_1 = relay.var('in_1', shape=(10, 10))
+ in_2 = relay.var('in_2', shape=(10, 10))
+ add_node = relay.add(in_1, in_2)
+ sub_node = relay.subtract(in_1, in_2)
+ mul_node = relay.multiply(add_node, sub_node)
+ add_sub_mul = relay.Function([in_1, in_2], mul_node)
+
+ # merged function
+ add_sub_mul_1 = relay.Call(add_sub_mul, [a, b])
+ add_sub_mul_2 = relay.Call(add_sub_mul, [c, add_sub_mul_1])
+ r = relay.nn.relu(add_sub_mul_2)
+ return relay.Function([a, b, c], r)
+
+ result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+ assert not relay.analysis.free_vars(result)
+ expected = run_opt_pass(expected(), relay.transform.InferType())
+ assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_multiple_patterns():
+    """Test that different patterns are merged correctly in the graph.
+
+ We would expect the pattern `make_conv_bias_relu_pattern` to be merged
+ into a single op `conv_bias_relu`. We would also expect `make_add_relu_pattern`
+ to be merged into a single op `add_relu`.
+
+ data kernel
+ \ /
+ \ /
+ conv2d data kernel bias
+ | \ | /
+ | bias conv2d_bias_relu
+ | / |
+ bias_add ====> | a
+ | | /
+ relu a add_relu
+ \ / |
+ add | b
+ | | /
+ relu b mul
+ | /
+ mul
+ """
+ pattern_table = [
+ ("conv2d_bias_relu", make_conv_bias_relu_pattern()),
+ ("add_relu", make_add_relu_pattern())
+ ]
+
+ def before():
+ data = relay.var('data', shape=(1, 512, 28, 28))
+ kernel = relay.var('kernel', shape=(256, 512, 1, 1))
+ bias = relay.var('bias', shape=(256,))
+ a = relay.var('a', shape=(1, 256, 28, 28))
+ b = relay.var('b', shape=(1, 256, 28, 28))
+
+ conv_node = relay.nn.conv2d(data,
+ kernel,
+ kernel_size=(1, 1),
+ padding=(0, 0),
+ strides=(1, 1))
+
+ bias_node = relay.nn.bias_add(conv_node, bias)
+ relu_node = relay.nn.relu(bias_node)
+ add_node = relay.add(relu_node, a)
+ relu_node_2 = relay.nn.relu(add_node)
+ r = relay.multiply(relu_node_2, b)
+ return relay.Function([data, kernel, bias, a, b], r)
+
+ def expected():
+ data = relay.var('data', shape=(1, 512, 28, 28))
+ kernel = relay.var('kernel', shape=(256, 512, 1, 1))
+ bias = relay.var('bias', shape=(256,))
+ a = relay.var('a', shape=(1, 256, 28, 28))
+ b = relay.var('b', shape=(1, 256, 28, 28))
+
+ # conv_bias_relu function
+ in_1 = relay.var('in_1', shape=(1, 512, 28, 28))
+ in_2 = relay.var('in_2', shape=(256, 512, 1, 1))
+ in_3 = relay.var('in_3', shape=(256,))
+
+ conv_node = relay.nn.conv2d(in_1,
+ in_2,
+ kernel_size=(1, 1),
+ padding=(0, 0),
+ strides=(1, 1))
+
+ bias_node = relay.nn.bias_add(conv_node, in_3)
+ r = relay.nn.relu(bias_node)
+ conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
+
+ # add_relu function
+ in_4 = relay.var('in_4', shape=(1, 256, 28, 28))
+ in_5 = relay.var('in_5', shape=(1, 256, 28, 28))
+ add_node = relay.add(in_4, in_5)
+ r = relay.nn.relu(add_node)
+ add_relu = relay.Function([in_4, in_5], r)
+
+ # merged function
+ conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias])
+ add_relu_1 = relay.Call(add_relu, [conv_bias_add_relu_1, a])
+ r = relay.multiply(add_relu_1, b)
+ return relay.Function([data, kernel, bias, a, b], r)
+
+ result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+ assert not relay.analysis.free_vars(result)
+ expected = run_opt_pass(expected(), relay.transform.InferType())
+ assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_merge_order():
+ """Test that patterns are merged in the order they exist in the pattern table.
+
+ There can be cases where one pattern is a subgraph of another, in which case
+    it is not clear which match should take priority. The priority should come
+    from the order in which the patterns are declared in the pattern table:
+    earlier patterns are merged with higher priority than later ones.
+
+ A: B: C:
+ add add abs
+ | | |
+ abs abs relu
+ |
+ relu
+
+ """
+
+ def pattern_A():
+ x = relay.var('x')
+ y = relay.var('y')
+ out = relay.add(x, y)
+ out = relay.abs(out)
+ out = relay.nn.relu(out)
+ return out
+
+ def pattern_B():
+ x = relay.var('x')
+ y = relay.var('y')
+ out = relay.add(x, y)
+ out = relay.abs(out)
+ return out
+
+ def pattern_C():
+ x = relay.var('x')
+ out = relay.abs(x)
+        out = relay.nn.relu(out)
+ return out
+
+ def before():
+ input_1 = relay.var('input_1', shape=(10, 10))
+ input_2 = relay.var('input_2', shape=(10, 10))
+ out = relay.add(input_1, input_2)
+ out = relay.abs(out)
+ out = relay.nn.relu(out)
+ return relay.Function([input_1, input_2], out)
+
+    def after_A_priority(composite_name='A'):
+ input_1 = relay.var('input_1', shape=(10, 10))
+ input_2 = relay.var('input_2', shape=(10, 10))
+ x = relay.var('x')
+ y = relay.var('y')
+ out = relay.add(x, y)
+ out = relay.abs(out)
+ out = relay.nn.relu(out)
+ merged_func = relay.Function([x, y], out)
+ merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1))
+        merged_func = merged_func.set_attribute('Composite', expr.StringImm(composite_name))
+ ret = relay.Call(merged_func, [input_1, input_2])
+ return relay.Function([input_1, input_2], ret)
+
+ def after_B_priority():
+ input_1 = relay.var('input_1', shape=(10, 10))
+ input_2 = relay.var('input_2', shape=(10, 10))
+ x = relay.var('x')
+ y = relay.var('y')
+ out = relay.add(x, y)
+ out = relay.abs(out)
+ merged_func = relay.Function([x, y], out)
+ merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1))
+ merged_func = merged_func.set_attribute('Composite', expr.StringImm('B'))
+ merged_call = relay.Call(merged_func, [input_1, input_2])
+ ret = relay.nn.relu(merged_call)
+ return relay.Function([input_1, input_2], ret)
+
+ def after_C_priority():
+ input_1 = relay.var('input_1', shape=(10, 10))
+ input_2 = relay.var('input_2', shape=(10, 10))
+ add = relay.add(input_1, input_2)
+ x = relay.var('x')
+ out = relay.abs(x)
+ out = relay.nn.relu(out)
+ merged_func = relay.Function([x], out)
+ merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1))
+ merged_func = merged_func.set_attribute('Composite', expr.StringImm('C'))
+ ret = relay.Call(merged_func, [add])
+ return relay.Function([input_1, input_2], ret)
+
+ # check A highest priority
+ pattern_table = [
+ ("A", pattern_A()),
+ ("B", pattern_B()),
+ ("C", pattern_C()),
+ ]
+ result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+ assert not relay.analysis.free_vars(result)
+ expected = run_opt_pass(after_A_priority(), relay.transform.InferType())
+ assert relay.analysis.alpha_equal(result, expected)
+
+    # check B highest priority (the first entry in the table is named 'B')
+ pattern_table = [
+ ("B", pattern_A()),
+ ("C", pattern_B()),
+ ("A", pattern_C()),
+ ]
+ result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+ assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(after_A_priority('B'), relay.transform.InferType())
+ assert relay.analysis.alpha_equal(result, expected)
+
+    # check C highest priority (the first entry in the table is named 'C')
+ pattern_table = [
+ ("C", pattern_A()),
+ ("A", pattern_B()),
+ ("B", pattern_C()),
+ ]
+ result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+ assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(after_A_priority('C'), relay.transform.InferType())
+ assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_parallel_merge():
+    """Test that parallel patterns relying on the same inputs are correctly merged.
+
+    The test graph is difficult to draw out as ASCII art. It is essentially two parallel
+    add-sub-mul units which both consume input_1 and input_2, with their results being
+    multiplied to give the output. We expect both parallel branches to be merged and both
+    to still consume the same input variables, input_1 and input_2."""
+
+ def before():
+ input_1 = relay.var('input_1', shape=(10, 10))
+ input_2 = relay.var('input_2', shape=(10, 10))
+ branch_1_add = relay.add(input_1, input_2)
+ branch_1_sub = relay.subtract(input_1, input_2)
+ branch_1 = relay.multiply(branch_1_add, branch_1_sub)
+ branch_2_add = relay.add(input_1, input_2)
+ branch_2_sub = relay.subtract(input_1, input_2)
+ branch_2 = relay.multiply(branch_2_add, branch_2_sub)
+ out = relay.multiply(branch_1, branch_2)
+ return relay.Function([input_1, input_2], out)
+
+ def after():
+ input_1 = relay.var('input_1', shape=(10, 10))
+ input_2 = relay.var('input_2', shape=(10, 10))
+ x = relay.var('x')
+ y = relay.var('y')
+ branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
+ func_1 = relay.Function([x, y], branch_1)
+ call_1 = relay.Call(func_1, [input_1, input_2])
+ x1 = relay.var('x1')
+ y1 = relay.var('y1')
+ branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
+ func_2 = relay.Function([x1, y1], branch_2)
+ call_2 = relay.Call(func_2, [input_1, input_2])
+ out = relay.multiply(call_1, call_2)
+ return relay.Function([input_1, input_2], out)
+
+ pattern_table = [
+ ("add_sub_mul", make_add_sub_mul_pattern())
+ ]
+ result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+ assert not relay.analysis.free_vars(result)
+ expected = run_opt_pass(after(), relay.transform.InferType())
+ assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_multiple_input_subgraphs():
+ """Test the case when multiple input subgraphs feed into another subgraph.
+
+ (1) (2) (3) (4)
+ add add add add
+ | | | |
+ relu relu relu relu
+ \ / \ /
+ \ / \ /
+ add sub
+ \ /
+ \ /
+ \ /
+ mul
+
+ ----> When 1=3 and 2=4 (Case 'A')
+
+ add_relu add_relu
+ \ /
+ \ /
+ add_sub_mul
+
+ ----> When 1!=3 and 2!=4 (Case 'B')
+
+ add_relu add_relu add_relu add_relu
+ \ / \ /
+ \ / \ /
+ add sub
+ \ /
+ -------- -----
+ \ /
+ mul
+
+    The difference in behaviour comes from the fact that add_sub_mul expects the
+    inputs to add and sub to be identical (the same two Relay expressions), so when there
+    are four independent inputs the pattern should not be merged.
+ """
+
+ def before():
+ before_funcs = {}
+ inputs = [relay.var('input_' + str(i), shape=(10, 10)) for i in range(8)]
+ add_relu_1 = relay.add(inputs[0], inputs[1])
+ add_relu_1 = relay.nn.relu(add_relu_1)
+ add_relu_2 = relay.add(inputs[2], inputs[3])
+ add_relu_2 = relay.nn.relu(add_relu_2)
+ add_relu_3 = relay.add(inputs[4], inputs[5])
+ add_relu_3 = relay.nn.relu(add_relu_3)
+ add_relu_4 = relay.add(inputs[6], inputs[7])
+ add_relu_4 = relay.nn.relu(add_relu_4)
+ add = relay.add(add_relu_1, add_relu_2)
+ sub = relay.subtract(add_relu_3, add_relu_4)
+ out = relay.multiply(add, sub)
+ before_funcs['B'] = relay.Function(inputs, out)
+ sub = relay.subtract(add_relu_1, add_relu_2)
+ out = relay.multiply(add, sub)
+ before_funcs['A'] = relay.Function(inputs[:4], out)
+ return before_funcs
+
+ def after_A():
+ inputs = [relay.var('input_' + str(i), shape=(10, 10)) for i in range(4)]
+ x = relay.var('x')
+ y = relay.var('y')
+ add_relu_1 = relay.add(x, y)
+ add_relu_1 = relay.nn.relu(add_relu_1)
+ add_relu_1 = relay.Function([x, y], add_relu_1)
+ add_relu_1 = add_relu_1.set_attribute('Primitive', expr.IntImm('int32', 1))
+ add_relu_1 = add_relu_1.set_attribute('Composite', expr.StringImm('add_relu'))
+ add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]])
+ x1 = relay.var('x1')
+ y1 = relay.var('y1')
+ add_relu_2 = relay.add(x1, y1)
+ add_relu_2 = relay.nn.relu(add_relu_2)
+ add_relu_2 = relay.Function([x1, y1], add_relu_2)
+ add_relu_2 = add_relu_2.set_attribute('Primitive', expr.IntImm('int32', 1))
+ add_relu_2 = add_relu_2.set_attribute('Composite', expr.StringImm('add_relu'))
+ add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]])
+ x2 = relay.var('x2')
+ y2 = relay.var('y2')
+ add = relay.add(x2, y2)
+ sub = relay.subtract(x2, y2)
+ add_sub_mul = relay.multiply(add, sub)
+ add_sub_mul = relay.Function([x2, y2], add_sub_mul)
+ add_sub_mul = add_sub_mul.set_attribute('Primitive', expr.IntImm('int32', 1))
+ add_sub_mul = add_sub_mul.set_attribute('Composite', expr.StringImm('add_sub_mul'))
+ add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2])
+ return relay.Function(inputs, add_sub_mul_call)
+
+ def after_B():
+ inputs = [relay.var('input_' + str(i), shape=(10, 10)) for i in range(8)]
+ add_relu_calls = []
+ for i in range(4):
+ x = relay.var('x' + str(i))
+            y = relay.var('y' + str(i))
+ add_relu = relay.add(x, y)
+ add_relu = relay.nn.relu(add_relu)
+ add_relu = relay.Function([x, y], add_relu)
+ add_relu = add_relu.set_attribute('Primitive', expr.IntImm('int32', 1))
+ add_relu = add_relu.set_attribute('Composite', expr.StringImm('add_relu'))
+ add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]])
+ add_relu_calls.append(add_relu_call)
+
+ add = relay.add(add_relu_calls[0], add_relu_calls[1])
+ sub = relay.subtract(add_relu_calls[2], add_relu_calls[3])
+ out = relay.multiply(add, sub)
+ return relay.Function(inputs, out)
+
+ pattern_table = [
+ ("add_sub_mul", make_add_sub_mul_pattern()),
+ ("add_relu", make_add_relu_pattern())
+ ]
+ # check case 'A'
+ result = run_opt_pass(before()['A'], relay.transform.MergeComposite(pattern_table))
+ assert not relay.analysis.free_vars(result)
+ expected = run_opt_pass(after_A(), relay.transform.InferType())
+ assert relay.analysis.alpha_equal(result, expected)
+
+ # check case 'B'
+ result = run_opt_pass(before()['B'], relay.transform.MergeComposite(pattern_table))
+ assert not relay.analysis.free_vars(result)
+ expected = run_opt_pass(after_B(), relay.transform.InferType())
+ assert relay.analysis.alpha_equal(result, expected)
+
+
+if __name__ == "__main__":
+ test_simple_merge()
+ test_branch_merge()
+ test_multiple_patterns()
+ test_merge_order()
+ test_parallel_merge()
+ test_multiple_input_subgraphs()
\ No newline at end of file