Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / prelu.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 from extensions.front.sub import Sub
20 from extensions.ops.prelu import PreluOp
21 from mo.front.common.replacement import FrontReplacementSubgraph
22 from mo.graph.graph import Graph
23 from mo.middle.pattern_match import check_node_usages_out_of_match
24
25
26 class PReLU(FrontReplacementSubgraph):
27     enabled = True
28
29     def pattern(self):
30         return dict(
31             nodes=[('op', dict(kind='op')),
32                    ('pos_relu', dict(kind='op', op='Relu')),
33                    ('neg', dict(kind='op', op='Neg')),
34                    ('neg_relu', dict(kind='op', op='Relu')),
35                    ('neg_1', dict(kind='op', op='Neg')),
36                    ('mul', dict(kind='op', op='Mul')),
37                    ('add', dict(kind='op', op='Add')),
38                    ],
39             edges=[
40                 ('op', 'pos_relu'),
41                 ('op', 'neg'),
42                 ('pos_relu', 'add'),
43                 ('neg', 'neg_relu'),
44                 ('neg_relu', 'neg_1'),
45                 ('neg_1', 'mul'),
46                 ('mul', 'add')
47             ]
48         )
49
50     def replace_sub_graph(self, graph: Graph, match: dict):
51         consumers = [n for n in match if n not in ['mul', 'op', 'add'] and not check_node_usages_out_of_match(match, n)]
52         if consumers:
53             log.warning('PReLU pattern was detected. Non pattern consumers of nodes: "{}" were found. Won\'t replace'
54                         ''.format(', '.join([match[n].id for n in consumers])))
55             return
56         gamma = match['mul'].in_node(0) if match['mul'].in_node(1).id == match['neg_1'].id else match['mul'].in_node(1)
57         prelu_node = PreluOp(graph, {'name': '{}/PReLU'.format(match['add'].id)}).create_node([match['op'], gamma])
58         match['add'].replace_node(prelu_node)
59         log.debug('PReLU pattern starting from "{}" was collapsed to "{}"'.format(match['op'].id, prelu_node.id))
60
61
62 class PReLUWithAbs(FrontReplacementSubgraph):
63     enabled = True
64
65     def run_before(self):
66         return [Sub]
67
68     def pattern(self):
69         return dict(
70             nodes=[('op', dict(kind='op')),
71                    ('relu', dict(kind='op', op='Relu')),
72                    ('abs', dict(kind='op', op='Abs')),
73                    ('sub', dict(kind='op', op='Sub')),
74                    ('mul', dict(kind='op', op='Mul')),
75                    ('mul_1', dict(kind='op', op='Mul')),
76                    ('add', dict(kind='op', op='Add')),
77                    ],
78             edges=[
79                 ('op', 'relu'),
80                 ('op', 'abs'),
81                 ('op', 'sub'),
82                 ('abs', 'sub'),
83                 ('sub', 'mul'),
84                 ('mul', 'mul_1'),
85                 ('relu', 'add'),
86                 ('mul_1', 'add'),
87             ]
88         )
89
90     def replace_sub_graph(self, graph: Graph, match: dict):
91         consumers = [n for n in match if
92                      n not in ['mul', 'mul_1', 'op', 'add', 'abs', 'sub'] and not check_node_usages_out_of_match(match,
93                                                                                                                  n)]
94         if consumers:
95             log.warning('PReLUWithAbs pattern was detected. Non pattern consumers of nodes: "{}" were found. Won\'t '
96                         'replace '.format(', '.join([match[n].id for n in consumers])))
97             return
98         gamma = match['mul'].in_node(0) if match['mul'].in_node(1).id == match['sub'].id else match['mul'].in_node(1)
99         prelu_node = PreluOp(graph, {'name': '{}/PReLU'.format(match['add'].id)}).create_node([match['op'], gamma])
100         match['add'].replace_node(prelu_node)
101         log.debug('PReLUWithAbs pattern starting from "{}" was collapsed to "{}"'.format(match['op'].id, prelu_node.id))