"""
Copyright (c) 2017-2019 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging as log

import numpy as np

from extensions.front.div import Div
from extensions.front.squared_difference import SquaredDifference
from extensions.front.sub import Sub
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.graph.graph import Node, Graph
from mo.ops.op import Op
def _same_indices(value, expected) -> bool:
    """Return True if the reduction-axes constant `value` equals `expected` element-wise.

    Both sides are flattened first, so a 0-d value or a length mismatch yields a
    plain False. The former ``all(value == expected)`` form did not: comparing
    arrays of different lengths element-wise does not give a per-axis result
    (numpy either returns a scalar, which ``all()`` rejects, or raises).
    """
    return bool(np.array_equal(np.ravel(value), np.ravel(expected)))


class MVNUnrolled(FrontReplacementSubgraph):
    """Fuse an unrolled mean-variance normalization subgraph into one MVN op.

    Matched structure (node names as used in the pattern):
        truediv = sub / pow(add(variance, eps), p)
        sub      = ? - mean
        variance = Mean(SquaredDifference(?, StopGradient(mean)))
    The replacement feeds ``mean``'s data input (port 0) into the new MVN node,
    i.e. it assumes every ``?`` above is that same tensor.
    NOTE(review): the pattern itself does not verify that assumption.
    """

    def run_before(self):
        # The pattern matches raw SquaredDifference / Div / Sub ops, so this
        # replacer must run before the front replacers of the same names,
        # which presumably rewrite those op types.
        return [SquaredDifference, Div, Sub]

    def pattern(self):
        log.debug('Enabled MVN replacement')
        return dict(
            nodes=[
                ('mean', dict(kind='op', op='Mean')),
                ('stop_grad', dict(kind='op', op='StopGradient')),
                ('sqdiff', dict(kind='op', op='SquaredDifference')),
                ('variance', dict(kind='op', op='Mean')),
                ('add', dict(kind='op', op='Add')),
                ('pow', dict(kind='op', op='Pow')),
                ('sub', dict(kind='op', op='Sub')),
                ('truediv', dict(kind='op', op='Div')),
            ],
            edges=[
                ('mean', 'stop_grad', {'in': 0}),
                ('stop_grad', 'sqdiff', {'in': 1}),
                ('sqdiff', 'variance', {'in': 0}),
                ('mean', 'sub', {'in': 1}),
                # replace_sub_graph takes epsilon from whichever 'add' input is
                # not 'variance', so 'variance' must be an input of 'add'
                # (either port).
                ('variance', 'add'),
                ('add', 'pow', {'in': 0}),
                ('pow', 'truediv', {'in': 1}),
                ('sub', 'truediv', {'in': 0}),
            ])

    @staticmethod
    def replace_sub_graph(graph: Graph, match: dict):
        """Replace the matched subgraph with a single MVN node.

        The MVN node gets five inputs:
        0. the data input of 'mean';
        1. the mean reduction indices;
        2. the variance reduction indices;
        3. the power;
        4. epsilon.
        Inputs 1-4 are validated and then disconnected by ``MVNUnrolled.infer``.
        """
        MVN = Op.get_op_class_by_name('MVN')

        mvn = MVN(graph, dict(
            name=match['truediv'].name + '/MVN_',
            # Axes the reductions must cover: [1, 2] for an NHWC graph,
            # otherwise [2, 3].
            required_reduction_indices=[1, 2] if graph.graph['layout'] == 'NHWC' else [2, 3]
        ))
        # Keep MVN's own infer function so the validating infer can delegate to it.
        mvn.attrs['old_infer'] = mvn.attrs['infer']
        mvn.attrs['infer'] = __class__.infer

        add = match['add']
        # 'add' combines variance with epsilon; epsilon is the other input.
        eps_port = 1 if add.in_node(0).id == match['variance'].id else 0
        eps = add.in_node(eps_port)

        new_subgraph = mvn.create_node([
            match['mean'].in_node(0),      # data tensor
            match['mean'].in_node(1),      # mean reduction indices
            match['variance'].in_node(1),  # variance reduction indices
            match['pow'].in_node(1),       # power
            eps,                           # epsilon
        ])

        match['truediv'].replace_node(new_subgraph)

    @staticmethod
    def infer(node: Node):
        """Validate and strip MVN's auxiliary inputs, then run MVN's original infer.

        The checks are:
        - inputs 1 and 2 (reduction indices) are constants equal to
          ``node.required_reduction_indices``;
        - inputs 3 (power) and 4 (epsilon) are constants;
        - the power equals 0.5.

        On the first failed check a warning is logged and the function returns
        without touching the node. Otherwise epsilon is stored in
        ``node['eps']``, the edges from inputs 1-4 are removed, and
        ``node['old_infer']`` is called.
        """
        mean_axes, variance_axes, power, eps = [node.in_node(port) for port in range(1, 5)]

        if not (mean_axes.has_valid('value') and variance_axes.has_valid('value')):
            log.warning('Reduction indices for mean and variance for MVN node {} are not constants'.format(node.name))
            return

        required = node.required_reduction_indices
        if not (_same_indices(mean_axes.value, required) and _same_indices(variance_axes.value, required)):
            log.warning('Reduction indices for mean {} and variance {} do not match required ones {}'.format(
                mean_axes.value,
                variance_axes.value,
                required
            ))
            return

        if not (power.has_valid('value') and eps.has_valid('value')):
            log.warning('Power or/and epsilon values for MVN node {} are not constants'.format(node.name))
            return

        if power.value != 0.5:
            log.warning('Power for MVN node {} ({}) is not equal to 0.5'.format(node.name, power.value))
            return

        node['eps'] = eps.value

        # The auxiliary inputs now live in attributes (required_reduction_indices,
        # eps), so detach them and leave only the data input connected.
        for aux in (mean_axes, variance_axes, power, eps):
            node.graph.remove_edge(aux.id, node.id)

        # Delegate to the infer function MVN had before replace_sub_graph swapped it.
        node['old_infer'](node)