2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
18 from collections import deque
23 from mo.front.extractor import add_attrs_props
24 from mo.middle.passes.eliminate import graph_clean_up
25 from mo.utils.graph import pseudo_topological_sort
26 from mo.ops.lin_op import Mul, Add
27 from mo.middle.passes.eliminate import merge_data_nodes
28 from mo.ops.op import Op
29 from mo.graph.graph import Node, Graph
30 from mo.middle.passes.fusing.helpers import backward_bfs, forward_bfs, get_tensor_id, get_value_id
def _fuse_linear_sequence(graph: Graph, start_node: Node):
    """
    Collapse a chain of consecutive Mul/Add operations that starts at ``start_node`` into at most one Mul
    followed by one Add (Mul->Add).

    The chain is extended while the output data node of its last operation has exactly one consumer and that
    consumer is a Mul/Add for which ``get_value_id`` finds an input and whose ``can_be_fused`` attribute is set.
    The values of those inputs are folded into one multiplier and one bias, so the whole chain is expressed
    as ``x * mul + add``.

    :param graph: Graph that is modified in place
    :param start_node: The first operation of the sequence
    :return: True if the sequence was replaced; False if there is nothing to fuse (a single operation, or a
             Mul followed by an Add, which is already the target form)
    """
    # NOTE(review): the loop header, break/else/return lines and a few initialisations were not visible in
    # the reviewed listing and are restored from how the surrounding code uses them -- confirm against VCS.

    # Collect the longest fusable chain that begins at start_node.
    fnodes = [start_node]
    while True:
        data_node = fnodes[-1].out_node()
        # The chain may only continue through a result that feeds exactly one consumer.
        if len(data_node.out_nodes()) != 1:
            break
        consumer = data_node.out_node()
        if (consumer.op in ['Mul', 'Add'] and get_value_id(consumer) is not None
                and consumer.soft_get('can_be_fused') == True):  # noqa: E712 - loose comparison kept on purpose
            fnodes.append(consumer)
        else:
            break

    if len(fnodes) == 1 or (len(fnodes) == 2 and fnodes[0].op == 'Mul' and fnodes[1].op == 'Add'):
        return False

    input_shape = start_node.in_node(get_tensor_id(start_node)).shape

    # Number of size-1 dimensions given to the neutral starting constants.
    init_dims_cnt = len(input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 1

    mul = np.ones([1] * init_dims_cnt)
    add = np.zeros([1] * init_dims_cnt)

    first_mul_name = None
    first_add_name = None

    for node in fnodes:
        const_node = get_value_id(node)
        if node.op == 'Mul':
            if first_mul_name is None:
                first_mul_name = node.name
            # (x * mul + add) * v == x * (mul * v) + (add * v)
            mul = mul * node.in_node(const_node).value
            add = add * node.in_node(const_node).value
        elif node.op == 'Add':
            if first_add_name is None:
                first_add_name = node.name
            add = add + node.in_node(const_node).value

    # If mul is scalar we broadcast it to biases shape
    if mul.shape != add.shape and len(mul.shape) == 1 and mul.shape[0] == 1:
        mul = np.repeat(mul, add.shape[0])

    # The fused chain must not change the tensor shape, otherwise the output data node cannot be reused.
    assert np.array_equal(fnodes[0].in_node(get_tensor_id(fnodes[0])).shape, fnodes[-1].out_node().shape)

    mul_node = Mul(graph, dict(name=first_mul_name + '/Fused_Mul_' if first_mul_name is not None else ''))
    add_node = Add(graph, dict(name=first_add_name + '/Fused_Add_' if first_add_name is not None else ''))

    in_node = fnodes[0].in_node(get_tensor_id(fnodes[0]))
    out_node = fnodes[-1].out_node()

    # Detach the chain from its surroundings so that in_node and out_node survive the clean-up below.
    graph.remove_edge(in_node.id, fnodes[0].id)
    graph.remove_edge(fnodes[-1].id, out_node.id)

    # Remove deleted subgraph
    for node in fnodes:
        for tmp_node in node.in_nodes().values():
            # Remove node only if it has one consumer (for case with shared weights)
            if len(tmp_node.out_nodes()) == 1:
                graph.remove_node(tmp_node.id)
        for tmp_node in node.out_nodes().values():
            graph.remove_node(tmp_node.id)
        graph.remove_node(node.id)

    # Four cases considered below:
    #     1. Mul and Add have valid values (mul value != 1 and add value != 0)
    #     2. Only Mul has valid values, so we add only Mul node
    #     3. Only Add has valid values, so we add only Add node
    #     4. When Mul and Add has not valid values we just merge two data nodes
    mul_is_valid = bool(np.any(mul != 1))
    add_is_valid = bool(np.any(add != 0))

    if add_is_valid and mul_is_valid:
        data_mul = Op.create_input_data_node(graph, "data_mul_", np.array(mul))
        data_add = Op.create_input_data_node(graph, "data_add_", np.array(add))
        add_node.create_node_with_data(inputs=[mul_node.create_node_with_data([in_node, data_mul]), data_add],
                                       data_nodes=out_node)
    elif mul_is_valid:
        data_mul = Op.create_input_data_node(graph, "data_mul_", np.array(mul))
        mul_node.create_node_with_data(inputs=[in_node, data_mul], data_nodes=out_node)
    elif add_is_valid:
        data_add = Op.create_input_data_node(graph, "data_add_", np.array(add))
        add_node.create_node_with_data(inputs=[in_node, data_add], data_nodes=out_node)
    else:
        merge_data_nodes(graph, out_node, in_node)
        graph.remove_node(in_node.id)

    log.debug('Fused {} operations'.format(len(fnodes)))
    return True
def fuse_mul_add_sequence(graph: Graph):
    """
    Find each fusable Mul/Add operation and pass it to _fuse_linear_sequence, where the full sequence that
    starts at that operation is found and fused.

    An operation is a candidate when its ``op`` is Mul or Add, ``get_value_id`` finds an input for it and its
    ``can_be_fused`` attribute is set.

    :param graph: Graph that is modified in place
    """
    # NOTE(review): the initialisation of is_fused and whatever consumes it were not visible in the reviewed
    # listing; repeating the pass until nothing is fused is the reconstruction used here -- confirm against VCS.
    while True:
        is_fused = False
        # Walk a snapshot of the order: a successful fusion removes nodes from the graph.
        for idx in list(pseudo_topological_sort(graph)):
            # Skip ids whose nodes were removed by a fusion earlier in this pass.
            if not graph.has_node(idx):
                continue
            node = Node(graph, idx)
            if (node.soft_get('op') in ['Mul', 'Add'] and get_value_id(node) is not None
                    and node.soft_get('can_be_fused') == True):  # noqa: E712 - loose comparison kept on purpose
                is_fused |= _fuse_linear_sequence(graph, node)
        if not is_fused:
            break