2 Copyright (c) 2017-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from extensions.front.squared_difference import SquaredDifference
22 from mo.front.common.replacement import FrontReplacementSubgraph
23 from mo.graph.graph import Node, Graph
24 from mo.ops.eltwise import Eltwise
25 from mo.ops.op import Op
# NOTE(review): this excerpt of the MVN front replacement is damaged. Every
# line starts with a stray number (apparently its line number in the original
# file), all indentation is gone, and the gaps in that numbering show that
# lines are missing -- e.g. the `def` lines of the first two methods, the
# assignment of `fbn`, the statements under each guard `if`, and several
# closing brackets. Restore this class from version control before editing or
# executing it; the comments below describe only what the surviving lines show.
#
# MVN: a FrontReplacementSubgraph that matches a FusedBatchNorm whose mean and
# variance inputs are computed in-graph (Mean -> StopGradient ->
# SquaredDifference -> Mean, each mean squeezed) and replaces it with an MVN
# op followed by Eltwise 'mul' and 'sum' nodes (presumably
# Add(Mul(MVN(x, mean_axes, variance_axes), gamma), beta) -- the nesting is
# partly missing from this excerpt).
28 class MVN(FrontReplacementSubgraph):
# Body of an ordering hook whose `def` line is missing (presumably run_before
# or run_after -- TODO confirm). The pattern below matches
# op='SquaredDifference', so this presumably has to run before the
# SquaredDifference replacement rewrites that op.
32 return [SquaredDifference]
# Pattern definition. Its `def` line and the surrounding
# `return dict(nodes=[...], edges=[...])` lines are missing from this excerpt.
35 log.debug('Enabled MVN replacement')
# Pattern nodes, keyed by the names that replace_sub_graph looks up in `match`.
38 ('mean', dict(op='Mean')),
39 ('stop_grad', dict(op='StopGradient')),
40 ('sqdiff', dict(op='SquaredDifference')),
41 ('variance', dict(op='Mean')),
42 ('squeeze_mean', dict(op='Squeeze')),
43 ('squeeze_variance', dict(op='Squeeze')),
44 ('fbn', dict(op='FusedBatchNorm')),
# Pattern edges as (src, dst, attrs); presumably 'in' is the destination's
# input port. The mean reaches SquaredDifference port 1 through StopGradient,
# and the squeezed mean/variance feed FusedBatchNorm ports 3 and 4.
47 ('mean', 'stop_grad', {'in': 0}),
48 ('stop_grad', 'sqdiff', {'in': 1}),
49 ('sqdiff', 'variance', {'in': 0}),
50 ('mean', 'squeeze_mean', {'in': 0}),
51 ('variance', 'squeeze_variance', {'in': 0}),
52 ('squeeze_mean', 'fbn', {'in': 3}),
53 ('squeeze_variance', 'fbn', {'in': 4}),
# Replace the matched FusedBatchNorm with an MVN + Mul + Add subgraph.
#
# Args:
#     graph: graph being transformed; the new MVN and Eltwise nodes are
#         created in it.
#     match: mapping from pattern node names ('mean', 'sqdiff', 'variance',
#         'fbn', ...) to the matched nodes.
56 def replace_sub_graph(self, graph: Graph, match: dict):
# NOTE(review): `fbn` is used below, but its assignment is missing from this
# excerpt; presumably `fbn = match['fbn']` -- confirm.
# `input` is the tensor on FusedBatchNorm port 0; the name shadows the builtin.
58 input = fbn.in_node(0)
59 log.debug('Found potential MVN pattern after {} with name {}'.format(input.op, input.name))
# Only rewrite when the first Mean and the SquaredDifference both read the same
# node that FusedBatchNorm normalizes. The early-exit statement under this `if`
# is missing from this excerpt (presumably `return`).
60 if input.id != match['mean'].in_node(0).id or input.id != match['sqdiff'].in_node(0).id:
63 log.debug('Confirmed MVN pattern after {} with name {}'.format(input.op, input.name))
# Fetch the registered 'MVN' operation class. This local name shadows the
# enclosing replacement class `MVN` for the rest of the method.
64 MVN = Op.get_op_class_by_name('MVN')
# Expected reduction axes: [1, 2] when data_format is b'NHWC', otherwise
# [2, 3] (i.e. H and W for NHWC / NCHW; any non-NHWC format is treated as
# channels-first).
# NOTE(review): a line between `name=` and `required_reduction_indices=` is
# missing (numbering jumps 67 -> 69); an MVN attribute may have been lost here.
66 mvn = MVN(graph, dict(
67 name=fbn.name + '/MVN_',
69 required_reduction_indices=[1, 2] if fbn.data_format == b'NHWC' else [2, 3]
# Keep the op's own infer as 'old_infer' and install this class's `infer`. That
# function checks the reduction-axes inputs and removes their edges.
# `__class__` is the implicit reference to the enclosing class that Python
# provides to methods that mention it.
71 mvn.attrs['old_infer'] = mvn.attrs['infer']
72 mvn.attrs['infer'] = __class__.infer
# Eltwise 'mul' and 'sum' nodes; presumably they apply gamma and beta.
74 mul = Eltwise(graph, dict(operation='mul', name=fbn.name + '/Mul_'))
75 add = Eltwise(graph, dict(operation='sum', name=fbn.name + '/Add_'))
# FusedBatchNorm inputs on ports 1 and 2 (named gamma and beta here).
77 input_gamma = fbn.in_node(1)
78 input_beta = fbn.in_node(2)
# Port 1 of each Mean node; presumably the reduction-axes constants that
# `infer` validates later.
80 mean_reduction = match['mean'].in_node(1)
81 variance_reduction = match['variance'].in_node(1)
# Build the replacement around MVN(input, mean_reduction, variance_reduction).
# The lines that open the Mul, pass input_gamma / input_beta and close the
# lists are missing from this excerpt (numbering 84, 86-89).
83 new_subgraph = add.create_node([
85 mvn.create_node([input, mean_reduction, variance_reduction]),
# Swap the FusedBatchNorm node for the new subgraph.
# NOTE(review): the matched Mean/StopGradient/SquaredDifference/Squeeze nodes
# are not removed here; presumably they become dead and are cleaned up later.
90 fbn.replace_node(new_subgraph)
# Infer function that replace_sub_graph installs on the MVN node. The line
# before this `def` is missing (presumably `@staticmethod`, because it takes
# only `node` and is looked up as `__class__.infer`).
#
# It checks that inputs 1 and 2 (mean and variance reduction axes) are
# constants equal to `node.required_reduction_indices`, warning otherwise,
# and then removes those two input edges. Anything after the last visible
# line (presumably a call to the saved `old_infer`) is outside this excerpt.
93 def infer(node: Node):
94 if not (node.in_node(1).has_valid('value') and node.in_node(2).has_valid('value')):
95 log.warning('Reduction indices for mean and variance for MVN node {} are not constants'.format(node.name))
# (The early-exit statement after the warning above is missing from this excerpt.)
# NOTE(review): this check is element-wise and order-sensitive ([2, 1] will not
# match [1, 2]). If .value is a numpy array whose length cannot broadcast
# against the required list, `==` may not return an array and `all(...)` can
# fail. Consider np.array_equal. Assumes .value is array-like -- confirm.
98 if not (all(node.in_node(1).value == node.required_reduction_indices) and
99 all(node.in_node(2).value == node.required_reduction_indices)):
100 log.warning('Reduction indices for mean {} and variance {} do not match required ones {}'.format(
101 node.in_node(1).value,
102 node.in_node(2).value,
103 node.required_reduction_indices
# (The closing `))` and the early exit after this warning are missing from this excerpt.)
# Drop the reduction-axes inputs (ports 2 and 1) so only the data input remains.
107 node.graph.remove_edge(node.in_node(2).id, node.id)
108 node.graph.remove_edge(node.in_node(1).id, node.id)