Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / fusing / fuse_linear_ops.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18 from collections import deque
19
20 import numpy as np
21
22 from mo.front.common.partial_infer.utils import int64_array
23 from mo.front.extractor import add_attrs_props
24 from mo.graph.graph import Node, Graph
25 from mo.middle.passes.eliminate import graph_clean_up
26 from mo.utils.graph import pseudo_topological_sort
27 from mo.ops.lin_op import Mul, Add
28 from mo.ops.op import Op
29 from mo.middle.passes.fusing.helpers import backward_bfs, forward_bfs, get_tensor_id, get_value_id
30
31
32 def _fuse_mul(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True):
33     """
34     This function takes Mul node and array of convolution/fc nodes for further fusion
35     Parameters
36     ----------
37     x : bool
38         If backward is False, that means that Convolution/FC goes after Mul node
39         else means that Mul goes after Convolutions/FC
40         :param backward:
41         :param fuse_nodes:
42         :param node:
43         :param graph:
44     """
45     is_fused = False
46     const_id, tensor_id = get_value_id(node), get_tensor_id(node)
47
48     if const_id is None or tensor_id is None:
49         log.warning('Cannot do fuse_mul for node {} because this node has wrong inputs'.format(node.id))
50         return False
51
52     for fuse_node in fuse_nodes:
53         if fuse_node.soft_get('can_be_fused') == False:
54             log.warning('Node {} can\'t be used in fusing due to user specified attr can_be_fused = False'.format(fuse_node.id))
55             return False
56
57         if len(fuse_node.in_nodes()) < 2:
58             log.warning('Node {} has no weights node'.format(fuse_node.id))
59             return False
60
61         if not fuse_node.has_valid('layout'):
62             log.warning('Node {} has no layout attr'.format(fuse_node.id))
63             return False
64
65         weights_node = fuse_node.in_node(1)
66
67         if not weights_node.has_valid('output_channel_dim') or not weights_node.has_valid('input_channel_dim'):
68             log.warning(
69                 'Cannot do fuse_mul for node {} because there is no field ' +
70                 'output_channel_dim and/or input_channel_dim in weights.'
71                 .format(fuse_node.soft_get('name'))
72             )
73             return False
74
75         inp_ch, out_ch = weights_node.input_channel_dim, weights_node.output_channel_dim
76         if max(inp_ch, out_ch) >= len(weights_node.shape):
77             log.warning('Node {} has wrong weights shape'.format(fuse_node.id))
78             return False
79
80     for fuse_node in fuse_nodes:
81         weights_node = fuse_node.in_node(1)
82         value = np.array(node.in_node(const_id).value)
83
84         value = np.squeeze(value)
85
86         # TODO : ch_dim should be equal to node.in_node(1).value.shape
87         # We will multiply weights according output/input channel dimension
88         ch_dim = weights_node.output_channel_dim if backward else weights_node.input_channel_dim
89         shape = np.array([weights_node.shape[ch_dim]])
90
91         # Scalar broadcast
92         if value.size == 1:
93             value = np.full(shape, value.item())
94
95         # Common broadcast for forward fusion
96         if not backward:
97             cnt = shape[-1] / value.shape[0]
98             if fuse_node.layout == 'NCHW':
99                 tmp = []
100                 for val in value:
101                     tmp = np.concatenate((tmp, np.repeat(val, cnt)))
102                 value = np.array(tmp)
103             else:
104                 value = np.tile(value, int(cnt))
105
106         # Expand dims for multiplication (ex. [38] to [38, 1, 1])
107         wdims_number = weights_node.dims_number
108         for x in range(wdims_number - ch_dim - 1):
109             shape = np.append(shape, 1)
110
111         mul_val = np.array(value)
112         value = np.reshape(value, shape)
113
114         # Weights multiplication
115         weights_node.value = weights_node.value * value
116
117         # If we fuse in backward direction we should multiply biases if they exists
118         if backward and len(fuse_node.in_nodes()) == 3:
119             conv_bias = fuse_node.in_node(2)
120             conv_bias.value = conv_bias.value * np.squeeze(mul_val)
121         log.debug('Fused: {} to {}'.format(node.name, fuse_node.name))
122         is_fused = True
123
124     if is_fused:
125         # Delete Mul node
126         out_node = node.out_node()
127         op_data_node = node.in_node(tensor_id)
128         op_const_node = node.in_node(const_id)
129         op_node = op_data_node.in_node(0)
130         graph.remove_edge(node.id, out_node.id)
131         graph.remove_edge(op_node.id, op_data_node.id)
132         graph.remove_edge(op_const_node.id, node.id)
133         # Connect nodes after deleting
134         graph.add_edge(op_node.id, out_node.id, out=0)
135         for idx in reversed(range(len(op_data_node.out_nodes()))):
136             out_data = op_data_node.out_nodes()[idx]
137             edge_attrs = graph.get_edge_data(op_data_node.id, out_data.id)[0]
138             if not out_data.id is node.id:
139                 graph.remove_edge(op_data_node.id, out_data.id)
140                 graph.add_edges_from([(out_node.id, out_data.id, edge_attrs)])
141
142     return is_fused
143
144
145 def _fuse_add(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True):
146     """
147     This function takes Add node and Convolution/FC nodes for further fusion and then deletes Add node
148     In case if Convolution/FC Bias absence it will be created
149     """
150     is_fused = False
151     const_id, tensor_id = get_value_id(node), get_tensor_id(node)
152
153     if const_id is None or tensor_id is None:
154         log.warning('Cannot do fuse_add for node {} because this node has wrong inputs'.format(node.id))
155         return False
156
157     # if len(node.in_node(const_id).shape) > 2 or any([x == 0 for x in node.in_node(const_id).shape]):
158     #     log.warning('Cannot do fuse_add for node {} because this node has wrong shape'.format(node.id))
159     #     return False
160
161     for fuse_node in fuse_nodes:
162         if fuse_node.soft_get('can_be_fused') == False:
163             log.warning('Node {} can\'t be used in fusing due to user specified attr can_be_fused = False'.format(fuse_node.id))
164             return False
165         if not fuse_node.has_valid('layout'):
166             log.warning('Node {} has no layout attr'.format(fuse_node.id))
167             return False
168         if len(fuse_node.in_nodes()) < 2:
169             log.warning('Node {} has no weights node'.format(fuse_node.id))
170             return False
171
172     for fuse_node in fuse_nodes:
173         value = np.array(node.in_node(const_id).value)
174
175         # If forward, broadcast value
176         if not backward:
177             cnt = fuse_node.in_node(1).shape[-1] / node.in_node(const_id).shape[0]
178             if fuse_node.layout == 'NCHW':
179                 tmp = []
180                 for val in value:
181                     tmp = np.concatenate((tmp, np.repeat(val, cnt)))
182                 value = np.array(tmp)
183             else:
184                 value = np.tile(value, int(cnt))
185
186         value = np.squeeze(value)
187
188         # Create BIAS data node if not exists
189         if len(fuse_node.in_nodes()) <= 2:
190             bias_data = graph.unique_id("bias_data")
191             data_type = fuse_node.in_node(1).data_type
192             # Broadcast if scalar
193             if value.size == 1:
194                 id = fuse_node.in_node(1).output_channel_dim if backward else fuse_node.in_node(1).input_channel_dim
195                 vshape = fuse_node.in_node(1).shape[id]
196                 value = np.full(vshape, value.item())
197
198             if not backward:
199                 value = np.dot(fuse_node.in_node(1).value, value)
200
201             shape = int64_array(value.shape)
202
203             graph.add_node(bias_data, **add_attrs_props(
204                 dict(kind='data', precision="FP32", name=bias_data, value=value, shape=shape, data_type=data_type)))
205             graph.add_edges_from([(bias_data, fuse_node.id, {'in': 2, 'bin': 'biases'})])
206             fuse_node['bias_term'] = True
207         else:
208             if not backward:
209                 fuse_node.in_node(2).value += np.dot(fuse_node.in_node(1).value, value)
210             else:
211                 fuse_node.in_node(2).value += value
212
213         log.debug('Fused: {} to {}'.format(node.name, fuse_node.name))
214         is_fused = True
215
216     if is_fused:
217         # Delete Add node
218         out_node = node.out_node()
219         op_data_node = node.in_node(tensor_id)
220         op_const_node = node.in_node(const_id)
221         op_node = op_data_node.in_node(0)
222         graph.remove_edge(node.id, out_node.id)
223         graph.remove_edge(op_node.id, op_data_node.id)
224         graph.remove_edge(op_const_node.id, node.id)
225         # Connect nodes after deleting
226         graph.add_edge(op_node.id, out_node.id, out=0)
227         for idx in reversed(range(len(op_data_node.out_nodes()))):
228             out_data = op_data_node.out_nodes()[idx]
229             edge_attrs = graph.get_edge_data(op_data_node.id, out_data.id)[0]
230             if not out_data.id is node.id:
231                 graph.remove_edge(op_data_node.id, out_data.id)
232                 graph.add_edges_from([(out_node.id, out_data.id, edge_attrs)])
233
234     return is_fused
235
236
237 def fuse_linear_ops(graph: Graph):
238     """
239     This function makes fusing of linear operations (Mul,Add) to Convolution/FC.
240     """
241     fuse_count = 0
242
243     # Fusion in backward direction
244     nodes = pseudo_topological_sort(graph)
245     for idx in nodes:
246         node = Node(graph, idx)
247         is_fused = False
248
249         # Fuse Mul to Convolution/FC
250         if node.soft_get('op') == 'Mul' and get_value_id(node) is not None and node.soft_get('can_be_fused') == True:
251             fuse_nodes = backward_bfs(node, [], ['Convolution', 'Deconvolution', 'FullyConnected'])
252             is_fused = _fuse_mul(graph, node, fuse_nodes)
253
254         # Fuse Add to Convolution/FC
255         if node.soft_get('op') == 'Add' and get_value_id(node) is not None and node.soft_get('can_be_fused') == True:
256             fuse_nodes = backward_bfs(node, [], ['Convolution', 'Deconvolution', 'FullyConnected'])
257             is_fused = _fuse_add(graph, node, fuse_nodes)
258
259         fuse_count += is_fused
260
261     # Fusion in forward direction
262     nodes = pseudo_topological_sort(graph, reverse=True)
263     for idx in nodes:
264         node = Node(graph, idx)
265         is_fused = False
266
267         # Fuse Mul to Convolution/FC
268         if node.soft_get('op') == 'Mul' and get_value_id(node) is not None and node.soft_get('can_be_fused') == True:
269             fuse_nodes = forward_bfs(node, [], ['Convolution', 'Deconvolution', 'FullyConnected'])
270             is_fused = _fuse_mul(graph, node, fuse_nodes, False)
271
272         # Fuse Add to Convolution/FC
273         if node.soft_get('op') == 'Add' and get_value_id(node) is not None and node.soft_get('can_be_fused') == True:
274             fuse_nodes = forward_bfs(node, [], ['FullyConnected'])
275             is_fused = _fuse_add(graph, node, fuse_nodes, False)
276
277         fuse_count += is_fused
278
279     log.debug("Fused {} nodes".format(fuse_count))