Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / fusing / decomposition.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 import networkx as nx
20 import numpy as np
21
22 from mo.graph.graph import Node, Graph
23 from mo.middle.passes.eliminate import merge_data_nodes
24 from mo.middle.pattern_match import apply_pattern
25 from mo.ops.lin_op import Mul, Add
26 from mo.ops.op import Op
27 from mo.ops.reshape import Reshape
28
29
30 def convert_batch_norm(graph: Graph):
31     """
32     This function finds FusedBatchNorm layer (or BatchNorm for MXNet) and replaces with Mul->Add->Mul->Add sequence.
33     """
34     for n in list(graph.nodes()):
35         node = Node(graph, n)
36         if node.has_valid('op') and (
37                 node.op == 'FusedBatchNorm' or node.op == 'BatchNorm' or node.op == 'BatchNormalization'):
38             toutput = node.out_node()
39             tinput = node.in_node(0)
40
41             if any([node.in_node(i).value is None for i in range(1, len(node.in_nodes()))]):
42                 log.warning('Cannot translate FusedBatchNorm {} node with non-constant weights'.format(
43                     node.name if node.has_valid('name') else '<UNKNOWN>'))
44                 continue
45
46             const = node.in_node(1)
47             beta = node.in_node(2)
48             mean = node.in_node(3)
49             variance = node.in_node(4)
50             eps = node.eps
51
52             if node.has_valid('fix_gamma') and node.fix_gamma:
53                 const.value.fill(1.)
54
55             can_be_fused = False if not node.soft_get('can_be_fused') else True
56
57             # Remove edges from FusedBN node
58             graph.remove_edge(tinput.id, node.id)
59             graph.remove_edge(beta.id, node.id)
60             graph.remove_edge(const.id, node.id)
61             graph.remove_edge(mean.id, node.id)
62             graph.remove_edge(variance.id, node.id)
63             graph.remove_edge(node.id, toutput.id)
64
65             scale = 1. / np.sqrt(variance.value + eps)
66             shift = (mean.value * (-1)) * scale
67
68             # Expand dims for current layout
69             broadcast_dims_cnt = len(tinput.shape) - 2 if graph.graph['layout'] == 'NCHW' else 0
70             # Update values and shapes with new shape
71             Op.expand_node_shape(const, broadcast_dims_cnt)
72             Op.expand_node_shape(beta, broadcast_dims_cnt)
73
74             for idx in range(broadcast_dims_cnt):
75                 scale = np.expand_dims(scale, axis=-1)
76                 shift = np.expand_dims(shift, axis=-1)
77
78             _fused_batch_norm_decomposition(graph, tinput, toutput, const, beta, scale, shift, can_be_fused)
79
80
81 def _fused_batch_norm_decomposition(graph: Graph, tinput: Node, toutput: Node, gamma: Node, beta: Node,
82                                     mean: np.ndarray, variance: np.ndarray, can_be_fused=True):
83     """
84     This is common function for TF, Caffe and MXNet
85     It creates Mul->Add->Mul->Add subgraph
86     """
87     shape = tinput.shape
88
89     # Create first Mul & Add operations
90     mul1_node = Mul(graph, dict(name="Mul1_", can_be_fused=can_be_fused))
91     add1_node = Add(graph, dict(name="Add1_", can_be_fused=can_be_fused))
92
93     mul1_data = Op.create_input_data_node(graph, "data_mul_", np.array(mean))
94     add1_data = Op.create_input_data_node(graph, "data_add_", np.array(variance))
95
96     # Broadcast const from scalar
97     # We can broadcast only when const.value is scalar
98     if gamma.shape[0] != gamma.value.shape[0]:
99         gamma.value.resize(gamma.shape)
100         gamma.value.fill(gamma.value[0])
101
102     # Create second Mul & Add
103     mul2_node = Mul(graph, dict(name="Mul2_", can_be_fused=can_be_fused))
104     add2_node = Add(graph, dict(name="Add2_", can_be_fused=can_be_fused))
105
106     add2_node.create_node_with_data(
107         inputs=[mul2_node.create_node_with_data(
108             inputs=[add1_node.create_node_with_data(
109                 inputs=[mul1_node.create_node_with_data(inputs=[tinput, mul1_data]),
110                         add1_data]),
111                 gamma]),
112             beta],
113         data_nodes=toutput)
114
115
116 def convert_scale_shift_to_mul_add(graph: Graph):
117     nodes = graph.get_op_nodes(op='ScaleShift')
118     for node in nodes:
119         if node.soft_get('can_be_fused') is False:
120             continue
121
122         ports_count = len(node.in_ports())
123
124         input_port = node.in_port(0)
125         scale_port = node.in_port(1) if ports_count > 1 and not node.in_port(1).disconnected() else None
126         shift_port = node.in_port(2) if ports_count > 2 and not node.in_port(2).disconnected() else None
127         output_port = node.out_port(0)
128
129         has_biases = True
130         has_weights = True
131
132         # We don't need zero biases
133         if shift_port is None or (shift_port.data.get_value() is not None and all([x == 0 for x in shift_port.data.get_value()])):
134             has_biases = False
135
136         # We don't need weights with ones
137         if scale_port is None or (scale_port.data.get_value() is not None and all([x == 1 for x in scale_port.data.get_value()])):
138             has_weights = False
139
140         mul_op = Mul(graph, dict(name=node.name + "/Mul_"))
141         add_op = Add(graph, dict(name=node.name + "/Add_"))
142
143         # Expand dims for current layout
144         broadcast_dims_cnt = len(input_port.data.get_shape()) - 2 if graph.graph['layout'] == 'NCHW' else 0
145
146         # In case if we have constant weights/biases we have to broadcast them according to graph layout
147         # otherwise we insert Reshape with broadcast dim attribute.
148         def broadcast_value(port):
149             value = np.array(port.data.get_value())
150             for idx in range(broadcast_dims_cnt):
151                 value = np.expand_dims(value, axis=-1)
152             port.data.set_value(value)
153
154         def broadcast_with_reshape(port):
155             input_shape = input_port.data.get_shape()
156             reshape_dims = np.zeros(len(input_shape), dtype=np.int64)
157             for i in range(0, node.axis):
158                 reshape_dims[i] = 1
159             data_shape = port.data.get_shape()
160             for i in range(node.axis, node.axis + len(data_shape)):
161                 reshape_dims[i] = data_shape[i - node.axis]
162             for i in range(node.axis + len(data_shape), len(input_shape)):
163                 reshape_dims[i] = 1
164             reshape = Reshape(graph, dict(name=port.node.name + "/Broadcast_", dim=reshape_dims)).create_node()
165             port.get_connection().set_destination(reshape.in_port(0))
166             reshape.out_port(0).connect(port)
167
168         if has_weights and scale_port.data.get_value() is not None:
169             broadcast_value(scale_port)
170         elif has_weights:
171             broadcast_with_reshape(scale_port)
172
173         if has_biases and shift_port.data.get_value() is not None:
174             broadcast_value(shift_port)
175         elif has_biases:
176             broadcast_with_reshape(shift_port)
177
178         if has_biases and has_weights:
179             # Connect input->mul->out->add->out
180             add_node = add_op.create_node()
181             mul_node = mul_op.create_node()
182
183             # Connect Mul operation with inputs
184             input_port.get_connection().set_destination(mul_node.in_port(0))
185             scale_port.get_connection().set_destination(mul_node.in_port(1))
186
187             # Connect Add operation with inputs
188             mul_node.out_port(0).connect(add_node.in_port(0))
189             shift_port.get_connection().set_destination(add_node.in_port(1))
190
191             output_port.get_connection().set_source(add_node.out_port(0))
192         elif has_weights:
193             # Connect input->mul->out
194             mul_node = mul_op.create_node()
195
196             # Connect Mul operation with inputs
197             input_port.get_connection().set_destination(mul_node.in_port(0))
198             scale_port.get_connection().set_destination(mul_node.in_port(1))
199
200             output_port.get_connection().set_source(mul_node.out_port(0))
201         elif has_biases:
202             # Connect input->add->out
203             add_node = add_op.create_node()
204
205             # Connect Add operation with inputs
206             input_port.get_connection().set_destination(add_node.in_port(0))
207             shift_port.get_connection().set_destination(add_node.in_port(1))
208
209             output_port.get_connection().set_source(add_node.out_port(0))
210         else:
211             # Connect input->out
212             producer_port = input_port.get_source()
213             input_port.disconnect()
214             output_port.get_connection().set_source(producer_port)
215
216
217 def _bn_to_mul_add_action(graph: Graph, match: dict):
218     # Data nodes
219     tinput = match['input']
220     toutput = match['output']
221     mean = match['mean']
222     variance = match['variance']
223
224     # Op node
225     bn_node = match['batch_norm']
226
227     # Disconnect data nodes from
228     graph.remove_edge(tinput.node, bn_node.node)
229     graph.remove_edge(mean.node, bn_node.node)
230     graph.remove_edge(variance.node, bn_node.node)
231
232     graph.remove_edge(bn_node.node, toutput.node)
233
234     scale = 1. / np.sqrt(variance.value + bn_node.epsilon)
235     shift = (mean.value * (-1)) * scale
236
237     mean.value = np.array(scale)
238     variance.value = np.array(shift)
239
240     # Expand dims for current layout
241     broadcast_dims_cnt = len(tinput.shape) - 2 if graph.graph['layout'] == 'NCHW' else 0
242     # Update values and shapes with new shape
243     Op.expand_node_shape(mean, broadcast_dims_cnt)
244     Op.expand_node_shape(variance, broadcast_dims_cnt)
245
246     can_be_fused = False if not bn_node.soft_get('can_be_fused') else True
247
248     mul_node = Mul(graph, dict(name="Mul_", can_be_fused=can_be_fused))
249     add_node = Add(graph, dict(name="Add_", can_be_fused=can_be_fused))
250
251     # Connect input->mul->add
252     add_node.create_node_with_data(inputs=[mul_node.create_node_with_data(inputs=[tinput, mean]), variance],
253                                    data_nodes=toutput)
254
255
256 def convert_bn_to_mul_add(graph: Graph):
257     apply_pattern(
258         graph,
259         nodes=[
260             ('input', dict(kind='data')),
261             ('mean', dict(kind='data')),
262             ('variance', dict(kind='data')),
263             ('output', dict(kind='data')),
264             ('batch_norm', dict(kind='op', op='BatchNormalization')),
265         ],
266         edges=[
267             ('input', 'batch_norm', {'in': 0}),
268             ('mean', 'batch_norm', {'in': 1}),
269             ('variance', 'batch_norm', {'in': 2}),
270             ('batch_norm', 'output'),
271         ],
272         action=_bn_to_mul_add_action
273     )