"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log
from collections import deque

import numpy as np

from mo.front.common.partial_infer.utils import int64_array
from mo.front.extractor import add_attrs_props
from mo.graph.graph import Node, Graph
from mo.middle.passes.eliminate import graph_clean_up
from mo.utils.graph import pseudo_topological_sort
from mo.ops.lin_op import Mul, Add
from mo.ops.op import Op
from mo.middle.passes.fusing.helpers import backward_bfs, forward_bfs, get_tensor_id, get_value_id
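
# This pass folds constant Mul/Add nodes into adjacent Convolution/Deconvolution/
# FullyConnected nodes by rewriting their weights and biases in place, so the
# elementwise operation disappears from the graph entirely.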


def _fuse_mul(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True):
    """
    Fuses a Mul node into the given Convolution/FC nodes and then removes the Mul node.
    If backward is True, the Mul follows the Convolution/FC nodes;
    if backward is False, the Convolution/FC nodes follow the Mul node.
    :param graph: graph to operate on
    :param node: Mul node to fuse
    :param fuse_nodes: Convolution/FC nodes to fuse the Mul into
    :param backward: fusion direction
    """
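    # Why this rewrite is sound: Mul by a constant per-channel scale s is linear, so
    #   (Conv(x, W) + b) * s == Conv(x, W * s) + b * s   (backward: scale output channels and bias)
    #   Conv(x * s, W) == Conv(x, W * s)                 (forward: scale input channels)
    # E.g. a Mul by s = [2, 3] after a conv with two output channels is absorbed
    # by doubling the weights of channel 0 and tripling those of channel 1.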
    is_fused = False
    const_id, tensor_id = get_value_id(node), get_tensor_id(node)

    if const_id is None or tensor_id is None:
        log.warning('Cannot do fuse_mul for node {} because this node has wrong inputs'.format(node.id))
        return False

    for fuse_node in fuse_nodes:
        if fuse_node.soft_get('can_be_fused') == False:
            log.warning('Node {} can\'t be used in fusing due to user specified attr can_be_fused = False'.format(fuse_node.id))
            return False

        if len(fuse_node.in_nodes()) < 2:
            log.warning('Node {} has no weights node'.format(fuse_node.id))
            return False

        if not fuse_node.has_valid('layout'):
            log.warning('Node {} has no layout attr'.format(fuse_node.id))
            return False

        weights_node = fuse_node.in_node(1)

        if not weights_node.has_valid('output_channel_dim') or not weights_node.has_valid('input_channel_dim'):
            log.warning('Cannot do fuse_mul for node {} because there is no field '
                        'output_channel_dim and/or input_channel_dim in weights.'.format(fuse_node.soft_get('name')))
            return False

        inp_ch, out_ch = weights_node.input_channel_dim, weights_node.output_channel_dim
        if max(inp_ch, out_ch) >= len(weights_node.shape):
            log.warning('Node {} has wrong weights shape'.format(fuse_node.id))
            return False
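
    # All candidates passed validation; fold the scale into each node's weights.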
    for fuse_node in fuse_nodes:
        weights_node = fuse_node.in_node(1)
        value = np.array(node.in_node(const_id).value)

        value = np.squeeze(value)

        # TODO : ch_dim should be equal to node.in_node(1).value.shape
        # We will multiply weights according to the output/input channel dimension
        ch_dim = weights_node.output_channel_dim if backward else weights_node.input_channel_dim
        shape = np.array([weights_node.shape[ch_dim]])

        # Broadcast a scalar value to a per-channel vector
        if value.size == 1:
            value = np.full(shape, value.item())

        # Common broadcast for forward fusion
        if not backward:
            cnt = int(shape[-1] / value.shape[0])
            if fuse_node.layout == 'NCHW':
                tmp = []
                for val in value:
                    tmp = np.concatenate((tmp, np.repeat(val, cnt)))
                value = np.array(tmp)
            else:
                value = np.tile(value, cnt)

        # Expand dims for multiplication (ex. [38] to [38, 1, 1])
        wdims_number = weights_node.dims_number
        for x in range(wdims_number - ch_dim - 1):
            shape = np.append(shape, 1)

        mul_val = np.array(value)
        value = np.reshape(value, shape)

        # Weights multiplication
        weights_node.value = weights_node.value * value

        # If we fuse in the backward direction, the biases must be multiplied too, if they exist
        if backward and len(fuse_node.in_nodes()) == 3:
            conv_bias = fuse_node.in_node(2)
            conv_bias.value = conv_bias.value * np.squeeze(mul_val)

        log.debug('Fused: {} to {}'.format(node.name, fuse_node.name))
        is_fused = True

    if is_fused:
        # Delete Mul node
        out_node = node.out_node()
        op_data_node = node.in_node(tensor_id)
        op_const_node = node.in_node(const_id)
        op_node = op_data_node.in_node(0)
        graph.remove_edge(node.id, out_node.id)
        graph.remove_edge(op_node.id, op_data_node.id)
        graph.remove_edge(op_const_node.id, node.id)
        # Connect nodes after deleting
        graph.add_edge(op_node.id, out_node.id, out=0)
        for idx in reversed(range(len(op_data_node.out_nodes()))):
            out_data = op_data_node.out_nodes()[idx]
            edge_attrs = graph.get_edge_data(op_data_node.id, out_data.id)[0]
            if out_data.id != node.id:
                graph.remove_edge(op_data_node.id, out_data.id)
                graph.add_edges_from([(out_node.id, out_data.id, edge_attrs)])

    return is_fused


def _fuse_add(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True):
    """
    Fuses an Add node into the given Convolution/FC nodes and then removes the Add node.
    If a Convolution/FC node has no bias, one is created.
    :param graph: graph to operate on
    :param node: Add node to fuse
    :param fuse_nodes: Convolution/FC nodes to fuse the Add into
    :param backward: fusion direction
    """
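    # Why this rewrite is sound: for a constant shift c,
    #   (Conv(x, W) + b) + c == Conv(x, W) + (b + c)    (backward: add c to the bias)
    #   FC(x + c, W) == FC(x, W) + W @ c                (forward: fold W @ c into the bias)
    # The forward case needs a matrix product with the weights, which is why
    # fuse_linear_ops() below restricts forward Add fusion to FullyConnected.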
    is_fused = False
    const_id, tensor_id = get_value_id(node), get_tensor_id(node)

    if const_id is None or tensor_id is None:
        log.warning('Cannot do fuse_add for node {} because this node has wrong inputs'.format(node.id))
        return False

    # if len(node.in_node(const_id).shape) > 2 or any([x == 0 for x in node.in_node(const_id).shape]):
    #     log.warning('Cannot do fuse_add for node {} because this node has wrong shape'.format(node.id))
    #     return False

    for fuse_node in fuse_nodes:
        if fuse_node.soft_get('can_be_fused') == False:
            log.warning('Node {} can\'t be used in fusing due to user specified attr can_be_fused = False'.format(fuse_node.id))
            return False

        if not fuse_node.has_valid('layout'):
            log.warning('Node {} has no layout attr'.format(fuse_node.id))
            return False

        if len(fuse_node.in_nodes()) < 2:
            log.warning('Node {} has no weights node'.format(fuse_node.id))
            return False
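
    # All candidates passed validation; fold the shift into each node's bias.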
    for fuse_node in fuse_nodes:
        value = np.array(node.in_node(const_id).value)

        # If forward, broadcast value
        if not backward:
            cnt = int(fuse_node.in_node(1).shape[-1] / node.in_node(const_id).shape[0])
            if fuse_node.layout == 'NCHW':
                tmp = []
                for val in value:
                    tmp = np.concatenate((tmp, np.repeat(val, cnt)))
                value = np.array(tmp)
            else:
                value = np.tile(value, cnt)

        value = np.squeeze(value)

        # Create a bias data node if it does not exist
        if len(fuse_node.in_nodes()) <= 2:
            bias_data = graph.unique_id("bias_data")
            data_type = fuse_node.in_node(1).data_type

            # Broadcast if scalar
            if value.size == 1:
                ch_dim = fuse_node.in_node(1).output_channel_dim if backward else fuse_node.in_node(1).input_channel_dim
                vshape = fuse_node.in_node(1).shape[ch_dim]
                value = np.full(vshape, value.item())

            if not backward:
                value = np.dot(fuse_node.in_node(1).value, value)

            shape = int64_array(value.shape)

            graph.add_node(bias_data, **add_attrs_props(
                dict(kind='data', precision="FP32", name=bias_data, value=value, shape=shape, data_type=data_type)))
            graph.add_edges_from([(bias_data, fuse_node.id, {'in': 2, 'bin': 'biases'})])
            fuse_node['bias_term'] = True
        else:
            if not backward:
                fuse_node.in_node(2).value += np.dot(fuse_node.in_node(1).value, value)
            else:
                fuse_node.in_node(2).value += value

        log.debug('Fused: {} to {}'.format(node.name, fuse_node.name))
        is_fused = True

    if is_fused:
        # Delete Add node
        out_node = node.out_node()
        op_data_node = node.in_node(tensor_id)
        op_const_node = node.in_node(const_id)
        op_node = op_data_node.in_node(0)
        graph.remove_edge(node.id, out_node.id)
        graph.remove_edge(op_node.id, op_data_node.id)
        graph.remove_edge(op_const_node.id, node.id)
        # Connect nodes after deleting
        graph.add_edge(op_node.id, out_node.id, out=0)
        for idx in reversed(range(len(op_data_node.out_nodes()))):
            out_data = op_data_node.out_nodes()[idx]
            edge_attrs = graph.get_edge_data(op_data_node.id, out_data.id)[0]
            if out_data.id != node.id:
                graph.remove_edge(op_data_node.id, out_data.id)
                graph.add_edges_from([(out_node.id, out_data.id, edge_attrs)])

    return is_fused


def fuse_linear_ops(graph: Graph):
    """
    Fuses linear operations (Mul, Add) into Convolution/FC nodes, first in the
    backward direction and then in the forward direction.
    """
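    # Backward pass: walk the graph in topological order and fold each Mul/Add
    # into the Convolution/Deconvolution/FullyConnected nodes that feed it.
    # Forward pass: walk in reverse order and fold each Mul/Add into the nodes
    # it feeds; forward Add fusion is restricted to FullyConnected because it
    # requires folding W @ c into the bias (see _fuse_add above).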
    fuse_count = 0

    # Fusion in backward direction
    nodes = pseudo_topological_sort(graph)
    for idx in nodes:
        node = Node(graph, idx)
        is_fused = False

        # Fuse Mul to Convolution/FC
        if node.soft_get('op') == 'Mul' and get_value_id(node) is not None and node.soft_get('can_be_fused') == True:
            fuse_nodes = backward_bfs(node, [], ['Convolution', 'Deconvolution', 'FullyConnected'])
            is_fused = _fuse_mul(graph, node, fuse_nodes)

        # Fuse Add to Convolution/FC
        if node.soft_get('op') == 'Add' and get_value_id(node) is not None and node.soft_get('can_be_fused') == True:
            fuse_nodes = backward_bfs(node, [], ['Convolution', 'Deconvolution', 'FullyConnected'])
            is_fused = _fuse_add(graph, node, fuse_nodes)

        fuse_count += is_fused

    # Fusion in forward direction
    nodes = pseudo_topological_sort(graph, reverse=True)
    for idx in nodes:
        node = Node(graph, idx)
        is_fused = False

        # Fuse Mul to Convolution/FC
        if node.soft_get('op') == 'Mul' and get_value_id(node) is not None and node.soft_get('can_be_fused') == True:
            fuse_nodes = forward_bfs(node, [], ['Convolution', 'Deconvolution', 'FullyConnected'])
            is_fused = _fuse_mul(graph, node, fuse_nodes, False)

        # Fuse Add to FullyConnected only
        if node.soft_get('op') == 'Add' and get_value_id(node) is not None and node.soft_get('can_be_fused') == True:
            fuse_nodes = forward_bfs(node, [], ['FullyConnected'])
            is_fused = _fuse_add(graph, node, fuse_nodes, False)

        fuse_count += is_fused

    log.debug("Fused {} nodes".format(fuse_count))
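
# A minimal usage sketch (assumptions: 'graph' stands for an already-built and
# shape-inferred Model Optimizer Graph; the surrounding conversion pipeline
# normally invokes this pass for you):
#
#     fuse_linear_ops(graph)
#     graph_clean_up(graph)  # drop data nodes orphaned by the edge rewiring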