"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log

import numpy as np

from mo.graph.graph import Node, Graph
from mo.middle.passes.eliminate import merge_data_nodes
from mo.middle.pattern_match import apply_pattern
from mo.ops.lin_op import Mul, Add
from mo.ops.op import Op
from mo.ops.reshape import Reshape
def convert_batch_norm(graph: Graph):
    """
    This function finds FusedBatchNorm layer (or BatchNorm for MXNet) and replaces with Mul->Add->Mul->Add sequence.

    The batch-norm inputs (gamma, beta, mean, variance) must be constants; nodes with
    non-constant weights are skipped with a warning.
    """
    for n in list(graph.nodes()):
        node = Node(graph, n)
        if node.has_valid('op') and (
                node.op == 'FusedBatchNorm' or node.op == 'BatchNorm' or node.op == 'BatchNormalization'):
            toutput = node.out_node()
            tinput = node.in_node(0)

            # Decomposition requires constant weight inputs (ports 1..N)
            if any([node.in_node(i).value is None for i in range(1, len(node.in_nodes()))]):
                log.warning('Cannot translate FusedBatchNorm {} node with non-constant weights'.format(
                    node.name if node.has_valid('name') else '<UNKNOWN>'))
                continue

            const = node.in_node(1)
            beta = node.in_node(2)
            mean = node.in_node(3)
            variance = node.in_node(4)
            eps = node.eps  # assumes the op carries an 'eps' attribute — TODO confirm against front-end extractors

            # MXNet attribute: gamma is fixed, so force the scale constant to ones
            if node.has_valid('fix_gamma') and node.fix_gamma:
                const.value.fill(1.)

            can_be_fused = False if not node.soft_get('can_be_fused') else True

            # Remove edges from FusedBN node
            graph.remove_edge(tinput.id, node.id)
            graph.remove_edge(beta.id, node.id)
            graph.remove_edge(const.id, node.id)
            graph.remove_edge(mean.id, node.id)
            graph.remove_edge(variance.id, node.id)
            graph.remove_edge(node.id, toutput.id)

            # Precompute normalization constants: y = (x - mean) / sqrt(var + eps)
            scale = 1. / np.sqrt(variance.value + eps)
            shift = (mean.value * (-1)) * scale

            # Expand dims for current layout
            broadcast_dims_cnt = len(tinput.shape) - 2 if graph.graph['layout'] == 'NCHW' else 0
            # Update values and shapes with new shape
            Op.expand_node_shape(const, broadcast_dims_cnt)
            Op.expand_node_shape(beta, broadcast_dims_cnt)

            for idx in range(broadcast_dims_cnt):
                scale = np.expand_dims(scale, axis=-1)
                shift = np.expand_dims(shift, axis=-1)

            _fused_batch_norm_decomposition(graph, tinput, toutput, const, beta, scale, shift, can_be_fused)
def _fused_batch_norm_decomposition(graph: Graph, tinput: Node, toutput: Node, gamma: Node, beta: Node,
                                    mean: np.ndarray, variance: np.ndarray, can_be_fused=True):
    """
    This is common function for TF, Caffe and MXNet
    It creates Mul->Add->Mul->Add subgraph:
        input -> Mul1(scale) -> Add1(shift) -> Mul2(gamma) -> Add2(beta) -> output

    :param mean: precomputed multiplicative constant (1/sqrt(var+eps)), despite the name
    :param variance: precomputed additive constant (-mean*scale), despite the name
    """
    # Create first Mul & Add operations
    mul1_node = Mul(graph, dict(name="Mul1_", can_be_fused=can_be_fused))
    add1_node = Add(graph, dict(name="Add1_", can_be_fused=can_be_fused))

    mul1_data = Op.create_input_data_node(graph, "data_mul_", np.array(mean))
    add1_data = Op.create_input_data_node(graph, "data_add_", np.array(variance))

    # Broadcast const from scalar
    # We can broadcast only when const.value is scalar
    if gamma.shape[0] != gamma.value.shape[0]:
        gamma.value.resize(gamma.shape)
        gamma.value.fill(gamma.value[0])

    # Create second Mul & Add
    mul2_node = Mul(graph, dict(name="Mul2_", can_be_fused=can_be_fused))
    add2_node = Add(graph, dict(name="Add2_", can_be_fused=can_be_fused))

    # Wire input -> Mul1 -> Add1 -> Mul2 -> Add2 and attach the original output data node
    add2_node.create_node_with_data(
        inputs=[mul2_node.create_node_with_data(
            inputs=[add1_node.create_node_with_data(
                inputs=[mul1_node.create_node_with_data(inputs=[tinput, mul1_data]),
                        add1_data]),
                gamma]),
            beta],
        data_nodes=toutput)
def convert_scale_shift_to_mul_add(graph: Graph):
    """
    Replaces each ScaleShift op with an equivalent Mul and/or Add pair.

    Trivial weights (all ones) and biases (all zeros) are dropped; when both are
    trivial the ScaleShift is bypassed entirely. Constant weights/biases are
    broadcast to the graph layout in place; non-constant ones get a Reshape.
    """
    nodes = graph.get_op_nodes(op='ScaleShift')
    for node in nodes:
        if node.soft_get('can_be_fused') is False:
            continue

        ports_count = len(node.in_ports())

        input_port = node.in_port(0)
        scale_port = node.in_port(1) if ports_count > 1 and not node.in_port(1).disconnected() else None
        shift_port = node.in_port(2) if ports_count > 2 and not node.in_port(2).disconnected() else None
        output_port = node.out_port(0)

        has_biases = True
        has_weights = True

        # We don't need zero biases
        if shift_port is None or (shift_port.data.get_value() is not None and all([x == 0 for x in shift_port.data.get_value()])):
            has_biases = False

        # We don't need weights with ones
        if scale_port is None or (scale_port.data.get_value() is not None and all([x == 1 for x in scale_port.data.get_value()])):
            has_weights = False

        mul_op = Mul(graph, dict(name=node.name + "/Mul_"))
        add_op = Add(graph, dict(name=node.name + "/Add_"))

        # Expand dims for current layout
        broadcast_dims_cnt = len(input_port.data.get_shape()) - 2 if graph.graph['layout'] == 'NCHW' else 0

        # In case if we have constant weights/biases we have to broadcast them according to graph layout
        # otherwise we insert Reshape with broadcast dim attribute.
        def broadcast_value(port):
            value = np.array(port.data.get_value())
            for idx in range(broadcast_dims_cnt):
                value = np.expand_dims(value, axis=-1)
            port.data.set_value(value)

        def broadcast_with_reshape(port):
            input_shape = input_port.data.get_shape()
            reshape_dims = np.zeros(len(input_shape), dtype=np.int64)
            for i in range(0, node.axis):
                reshape_dims[i] = 1
            data_shape = port.data.get_shape()
            for i in range(node.axis, node.axis + len(data_shape)):
                reshape_dims[i] = data_shape[i - node.axis]
            for i in range(node.axis + len(data_shape), len(input_shape)):
                reshape_dims[i] = 1
            reshape = Reshape(graph, dict(name=port.node.name + "/Broadcast_", dim=reshape_dims)).create_node()
            port.get_connection().set_destination(reshape.in_port(0))
            reshape.out_port(0).connect(port)

        if has_weights and scale_port.data.get_value() is not None:
            broadcast_value(scale_port)
        elif has_weights:
            broadcast_with_reshape(scale_port)

        if has_biases and shift_port.data.get_value() is not None:
            broadcast_value(shift_port)
        elif has_biases:
            broadcast_with_reshape(shift_port)

        if has_biases and has_weights:
            # Connect input->mul->out->add->out
            add_node = add_op.create_node()
            mul_node = mul_op.create_node()

            # Connect Mul operation with inputs
            input_port.get_connection().set_destination(mul_node.in_port(0))
            scale_port.get_connection().set_destination(mul_node.in_port(1))

            # Connect Add operation with inputs
            mul_node.out_port(0).connect(add_node.in_port(0))
            shift_port.get_connection().set_destination(add_node.in_port(1))

            output_port.get_connection().set_source(add_node.out_port(0))
        elif has_weights:
            # Connect input->mul->out
            mul_node = mul_op.create_node()

            # Connect Mul operation with inputs
            input_port.get_connection().set_destination(mul_node.in_port(0))
            scale_port.get_connection().set_destination(mul_node.in_port(1))

            output_port.get_connection().set_source(mul_node.out_port(0))
        elif has_biases:
            # Connect input->add->out
            add_node = add_op.create_node()

            # Connect Add operation with inputs
            input_port.get_connection().set_destination(add_node.in_port(0))
            shift_port.get_connection().set_destination(add_node.in_port(1))

            output_port.get_connection().set_source(add_node.out_port(0))
        else:
            # Both weights and biases are trivial: bypass the ScaleShift entirely
            producer_port = input_port.get_source()
            input_port.disconnect()
            output_port.get_connection().set_source(producer_port)
def _bn_to_mul_add_action(graph: Graph, match: dict):
    """
    Pattern action: replaces a matched BatchNormalization node with Mul->Add,
    reusing the mean/variance data nodes to hold the precomputed scale/shift.
    """
    # Data nodes
    tinput = match['input']
    toutput = match['output']
    mean = match['mean']
    variance = match['variance']

    # Op node
    bn_node = match['batch_norm']

    # Disconnect data nodes from
    graph.remove_edge(tinput.node, bn_node.node)
    graph.remove_edge(mean.node, bn_node.node)
    graph.remove_edge(variance.node, bn_node.node)

    graph.remove_edge(bn_node.node, toutput.node)

    # y = (x - mean) / sqrt(var + eps)  ==  x * scale + shift
    scale = 1. / np.sqrt(variance.value + bn_node.epsilon)
    shift = (mean.value * (-1)) * scale

    # Reuse the existing data nodes: 'mean' now carries scale, 'variance' carries shift
    mean.value = np.array(scale)
    variance.value = np.array(shift)

    # Expand dims for current layout
    broadcast_dims_cnt = len(tinput.shape) - 2 if graph.graph['layout'] == 'NCHW' else 0
    # Update values and shapes with new shape
    Op.expand_node_shape(mean, broadcast_dims_cnt)
    Op.expand_node_shape(variance, broadcast_dims_cnt)

    can_be_fused = False if not bn_node.soft_get('can_be_fused') else True

    mul_node = Mul(graph, dict(name="Mul_", can_be_fused=can_be_fused))
    add_node = Add(graph, dict(name="Add_", can_be_fused=can_be_fused))

    # Connect input->mul->add
    add_node.create_node_with_data(inputs=[mul_node.create_node_with_data(inputs=[tinput, mean]), variance],
                                   data_nodes=toutput)
def convert_bn_to_mul_add(graph: Graph):
    """
    Finds data->BatchNormalization->data patterns (with mean/variance constant
    inputs) and delegates replacement with Mul->Add to _bn_to_mul_add_action.
    """
    apply_pattern(
        graph,
        nodes=[
            ('input', dict(kind='data')),
            ('mean', dict(kind='data')),
            ('variance', dict(kind='data')),
            ('output', dict(kind='data')),
            ('batch_norm', dict(kind='op', op='BatchNormalization')),
        ],
        edges=[
            ('input', 'batch_norm', {'in': 0}),
            ('mean', 'batch_norm', {'in': 1}),
            ('variance', 'batch_norm', {'in': 2}),
            ('batch_norm', 'output'),
        ],
        action=_bn_to_mul_add_action
    )