Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / mvn_unrolled.py
1 """
2  Copyright (c) 2017-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 from extensions.front.squared_difference import SquaredDifference
20 from extensions.front.sub import Sub
21 from mo.front.common.replacement import FrontReplacementSubgraph
22 from mo.graph.graph import Node, Graph
23 from extensions.front.div import Div
24 from mo.ops.op import Op
25
26
27 class MVNUnrolled(FrontReplacementSubgraph):
28     enabled = True
29
30     def run_before(self):
31         return [SquaredDifference, Div, Sub]
32
33     def pattern(self):
34         log.debug('Enabled MVN replacement')
35         return dict(
36             nodes=[
37                 ('mean', dict(kind='op', op='Mean')),
38                 ('stop_grad', dict(kind='op', op='StopGradient')),
39                 ('sqdiff', dict(kind='op', op='SquaredDifference')),
40                 ('variance', dict(kind='op', op='Mean')),
41                 ('add', dict(kind='op', op='Add')),
42                 ('pow', dict(kind='op', op='Pow')),
43                 ('sub', dict(kind='op', op='Sub')),
44                 ('truediv', dict(kind='op', op='Div')),
45             ],
46             edges=[
47                 ('mean', 'stop_grad', {'in': 0}),
48                 ('stop_grad', 'sqdiff', {'in': 1}),
49                 ('sqdiff', 'variance', {'in': 0}),
50                 ('mean', 'sub', {'in': 1}),
51                 ('variance', 'add'),
52                 ('add', 'pow', {'in': 0}),
53                 ('pow', 'truediv', {'in': 1}),
54                 ('sub', 'truediv', {'in': 0}),
55             ])
56
57     @staticmethod
58     def replace_sub_graph(graph: Graph, match: dict):
59         MVN = Op.get_op_class_by_name('MVN')
60
61         mvn = MVN(graph, dict(
62             name=match['truediv'].name + '/MVN_',
63             required_reduction_indices=[1, 2] if graph.graph['layout'] == 'NHWC' else [2, 3]
64         ))
65         mvn.attrs['old_infer'] = mvn.attrs['infer']
66         mvn.attrs['infer'] = __class__.infer
67
68         mean_reduction = match['mean'].in_node(1)
69         variance_reduction = match['variance'].in_node(1)
70         pow2 = match['pow'].in_node(1)
71         eps = match['add'].in_node(0 if match['add'].in_node(0).id != match['variance'].id else 1)
72
73         new_subgraph = mvn.create_node([match['mean'].in_node(0), mean_reduction, variance_reduction, pow2, eps])
74
75         match['truediv'].replace_node(new_subgraph)
76
77     @staticmethod
78     def infer(node: Node):
79         if not (node.in_node(1).has_valid('value') and node.in_node(2).has_valid('value')):
80             log.warning('Reduction indices for mean and variance for MVN node {} are not constants'.format(node.name))
81             return
82
83         if not (all(node.in_node(1).value == node.required_reduction_indices) and
84                     all(node.in_node(2).value == node.required_reduction_indices)):
85             log.warning('Reduction indices for mean {} and variance {} do not match required ones {}'.format(
86                 node.in_node(1).value,
87                 node.in_node(2).value,
88                 node.required_reduction_indices
89             ))
90             return
91         
92         if not (node.in_node(3).has_valid('value') and node.in_node(4).has_valid('value')):
93             log.warning('Power or/and epsilon values for MVN node {} are not constants'.format(node.name))
94             return
95
96         if node.in_node(3).value != 0.5:
97             log.warning('Power for MVN node {} ({}) is not equal to 0.5'.format(node.name, node.in_node(3).value))
98             return
99
100         node['eps'] = node.in_node(4).value
101
102         for i in range(1, 5):
103             node.graph.remove_edge(node.in_node(i).id, node.id)
104         node.old_infer(node)