2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from extensions.front.sub import Sub
20 from extensions.ops.prelu import PreluOp
21 from mo.front.common.replacement import FrontReplacementSubgraph
22 from mo.graph.graph import Graph
23 from mo.middle.pattern_match import check_node_usages_out_of_match
# Front-phase subgraph replacement that detects a PReLU built from primitive ops
# (Relu, Neg, Relu, Neg, Mul, Add hanging off a common producer 'op') and, in
# replace_sub_graph, collapses it into a single PreluOp with inputs [op, gamma].
# NOTE(review): presumably the matched form is relu(x) + gamma * (-relu(-x)); only the
# ('neg_relu', 'neg_1') edge is visible in this view, so confirm against the full
# edges list.
26 class PReLU(FrontReplacementSubgraph):
# NOTE(review): the leading numbers on these lines are extraction artifacts, and the
# original lines between 26 and 31 (presumably the `enabled` flag and the
# `def pattern(self):` / `return dict(` header) are missing from this view.
# Pattern nodes: 'op' is the shared input; 'pos_relu' is the positive branch;
# 'neg' -> 'neg_relu' -> 'neg_1' is the negative branch; 'mul' scales it by gamma,
# and 'add' merges both branches.
31 nodes=[('op', dict(kind='op')),
32 ('pos_relu', dict(kind='op', op='Relu')),
33 ('neg', dict(kind='op', op='Neg')),
34 ('neg_relu', dict(kind='op', op='Relu')),
35 ('neg_1', dict(kind='op', op='Neg')),
36 ('mul', dict(kind='op', op='Mul')),
37 ('add', dict(kind='op', op='Add')),
# NOTE(review): original lines 38-43 and 45-48 (list close, the rest of `edges`, the
# dict close) are missing from this view; only this edge survives.
44 ('neg_relu', 'neg_1'),
def replace_sub_graph(self, graph: Graph, match: dict):
    """Collapse a matched PReLU subgraph into a single PreluOp node.

    The producer of the activation input ('op') and the second operand of 'mul'
    (gamma) become the two inputs of a new PReLU node, which replaces the
    final 'add' node.

    :param graph: graph in which the pattern was matched; the PReLU node is created in it.
    :param match: maps pattern node names ('op', 'pos_relu', 'neg', 'neg_relu', 'neg_1',
        'mul', 'add') to the matched graph nodes.
    :return: None. Either 'add' is replaced by the new PReLU node, or, when some
        intermediate node is used outside the match, the graph is left untouched.
    """
    # Intermediate nodes (everything except 'mul', 'op' and 'add') that
    # check_node_usages_out_of_match flags as having consumers outside the pattern.
    consumers = [n for n in match
                 if n not in ['mul', 'op', 'add'] and not check_node_usages_out_of_match(match, n)]
    if consumers:
        # The warning says the pattern won't be replaced, so actually bail out here
        # instead of falling through to the rewrite below.
        log.warning('PReLU pattern was detected. Non pattern consumers of nodes: "{}" were found. Won\'t replace'
                    ''.format(', '.join([match[n].id for n in consumers])))
        return

    # 'mul' multiplies the negative branch ('neg_1') by gamma; pick whichever input is not 'neg_1'.
    mul = match['mul']
    gamma = mul.in_node(0) if mul.in_node(1).id == match['neg_1'].id else mul.in_node(1)
    prelu_node = PreluOp(graph, {'name': '{}/PReLU'.format(match['add'].id)}).create_node([match['op'], gamma])
    match['add'].replace_node(prelu_node)
    log.debug('PReLU pattern starting from "{}" was collapsed to "{}"'.format(match['op'].id, prelu_node.id))
# Front-phase subgraph replacement that detects a PReLU expressed with Relu, Abs, Sub,
# two Mul nodes and an Add over a common producer 'op', and, in replace_sub_graph,
# collapses it into a single PreluOp with inputs [op, gamma].
# NOTE(review): presumably the matched form is relu(x) + gamma * (x - abs(x)) * c with a
# constant c on 'mul_1'; no edges are visible in this view, so confirm the wiring and
# whether the constant is checked anywhere.
62 class PReLUWithAbs(FrontReplacementSubgraph):
# NOTE(review): the leading numbers on these lines are extraction artifacts. Original lines
# 63-69 are missing from this view; they presumably hold the `enabled` flag, a transform
# ordering hook referencing the imported `Sub` class, and the `def pattern(self):` header.
70 nodes=[('op', dict(kind='op')),
71 ('relu', dict(kind='op', op='Relu')),
72 ('abs', dict(kind='op', op='Abs')),
73 ('sub', dict(kind='op', op='Sub')),
74 ('mul', dict(kind='op', op='Mul')),
75 ('mul_1', dict(kind='op', op='Mul')),
76 ('add', dict(kind='op', op='Add')),
# NOTE(review): original lines 77-89 (list close and the entire `edges` list) are missing
# from this view.
def replace_sub_graph(self, graph: Graph, match: dict):
    """Collapse a matched abs-based PReLU subgraph into a single PreluOp node.

    The producer of the activation input ('op') and the operand of 'mul' that is
    not 'sub' (gamma) become the two inputs of a new PReLU node, which replaces
    the final 'add' node.

    :param graph: graph in which the pattern was matched; the PReLU node is created in it.
    :param match: maps pattern node names ('op', 'relu', 'abs', 'sub', 'mul', 'mul_1',
        'add') to the matched graph nodes.
    :return: None. Either 'add' is replaced by the new PReLU node, or, when a
        checked intermediate node is used outside the match, the graph is left untouched.
    """
    # Only 'relu' is left to check here: unlike PReLU, 'abs' and 'sub' are exempt too.
    # NOTE(review): confirm that exempting 'abs'/'sub' from this check is intentional.
    consumers = [n for n in match
                 if n not in ['mul', 'mul_1', 'op', 'add', 'abs', 'sub']
                 and not check_node_usages_out_of_match(match, n)]
    if consumers:
        # The warning says the pattern won't be replaced, so actually bail out here
        # instead of falling through to the rewrite below.
        log.warning('PReLUWithAbs pattern was detected. Non pattern consumers of nodes: "{}" were found. Won\'t '
                    'replace '.format(', '.join([match[n].id for n in consumers])))
        return

    # 'mul' takes 'sub' and gamma as inputs; pick whichever input is not 'sub'.
    mul = match['mul']
    gamma = mul.in_node(0) if mul.in_node(1).id == match['sub'].id else mul.in_node(1)
    prelu_node = PreluOp(graph, {'name': '{}/PReLU'.format(match['add'].id)}).create_node([match['op'], gamma])
    match['add'].replace_node(prelu_node)
    log.debug('PReLUWithAbs pattern starting from "{}" was collapsed to "{}"'.format(match['op'].id, prelu_node.id))