Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / kaldi / add_permute_after_convolution.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 from collections import deque
17
18 import numpy as np
19
20 from extensions.front.kaldi.add_reshape_around_convolution import ReplaceConvolutionReshape
21 from extensions.middle.TensorIteratorMerge import op_type
22 from mo.front.common.replacement import FrontReplacementSubgraph
23 from mo.graph.graph import Node, Graph
24 from mo.ops.permute import Permute
25
26
class ReplaceConvolutionPermute(FrontReplacementSubgraph):
    """
    Insert a Permute (order 0, 3, 2, 1) after a Convolution layer — or after the
    Pooling/Activation sequence that follows it — so the layout matches what a
    downstream ScaleShift or FullyConnected layer expects.

    **IMPORTANT**: This pass must run after inserting Reshapes around Poolings
    and Convolutions.

       For example, suppose we have the graph:

           Convolution -> [Pooling | Activation -> Pooling | Pooling -> Activation | Activation]* -> ... -> (ScaleShift | FullyConnected)

           **NOTE**: Remember the Reshapes around Poolings and Convolutions;
                     they are omitted here for simplicity.
           **NOTE**: The [Pooling | Activation ...] sequence after the
                     Convolution is optional.

           This pass converts the graph to:

           Convolution -> * -> Permute (order 0, 3, 2, 1) -> Next_Layer -> ... -> (ScaleShift | FullyConnected)

    """
    enabled = True

    def pattern(self):
        # Anchor on weighted layers that consume the convolution output;
        # the Convolutions feeding them are found by a backward traversal.
        return dict(
            nodes=[
                ('target_node', dict(op=lambda x: x in ['ScaleShift', 'FullyConnected']))
            ],
            edges=[]
        )

    def replace_sub_graph(self, graph: Graph, match: dict):
        """Find every Convolution feeding the matched node (searching backwards,
        without crossing other weighted layers) and insert a Permute after the
        Pooling/Activation tail that follows each Convolution."""
        target_node = match['target_node']
        nodes_with_weights = self.dfs(graph, target_node.name,
                                      ('Convolution', 'FullyConnected', 'ScaleShift'), True)
        convolution_nodes = [node for node in nodes_with_weights
                            if Node(graph, node).op == 'Convolution']
        for convolution_node in convolution_nodes:
            target_node = self.search_target_node(Node(graph, convolution_node))
            permute_op = Permute(graph, {'order': np.array([0, 3, 2, 1])})
            permute_node = permute_op.add_node({'name': '{}/Permute'.format(target_node.name)})
            target_node.insert_node_after(permute_node, 0)

    def run_after(self):
        # Local import to avoid a circular dependency at module load time.
        from extensions.front.kaldi.add_reshape_around_pooling import ReplacePoolingReshape
        return [ReplaceConvolutionReshape, ReplacePoolingReshape]

    @staticmethod
    def search_target_node(node: Node):
        """Return the node after which the Permute must be inserted: the
        Convolution itself if no Pooling/Activation follows, otherwise the last
        node of the (at most two-element) Pooling/Activation sequence,
        skipping any Reshapes in between."""
        target_node = ReplaceConvolutionPermute.skip_reshapes(node)
        sequence_layers = ['Pooling', 'Activation']
        if target_node.op not in sequence_layers:
            return node
        if target_node.op == 'Activation':
            # Accept the sequence in either order:
            # Pooling -> Activation or Activation -> Pooling.
            sequence_layers.reverse()
        if target_node.op == sequence_layers[0]:
            next_node = ReplaceConvolutionPermute.skip_reshapes(target_node)
            if next_node.op == sequence_layers[1]:
                target_node = next_node

        return target_node

    @staticmethod
    def skip_reshapes(node: Node):
        """Return the first non-Reshape node downstream of `node`.
        Assumes a single-consumer chain (out_node() at port 0) — matches the
        Kaldi topologies this pass targets."""
        next_node = node.out_node()
        while next_node.op == 'Reshape':
            next_node = next_node.out_node()
        return next_node

    @staticmethod
    def dfs(graph: Graph, node_name: str, stop_nodes: tuple, reverse: bool = False) -> list:
        """Traverse the graph from `node_name` (backwards over in-edges when
        `reverse` is True, forwards over out-edges otherwise) and collect the
        first encountered nodes whose op is in `stop_nodes`, without crossing
        them.

        :param graph: graph to traverse
        :param node_name: name of the start node
        :param stop_nodes: ops at which the traversal stops (and which are
                           collected into the result)
        :param reverse: traverse over incoming edges instead of outgoing ones
        :return: list of names of the stop nodes found
        """
        d = deque()
        res = []
        visited = set()
        visited.add(node_name)
        d.appendleft(node_name)
        while len(d) != 0:
            cur_node = d.popleft()
            if reverse:
                # in_edges yields (source, cur_node) pairs: take the source.
                neighbours = [src for src, _ in graph.in_edges(cur_node)]
            else:
                # Bug fix: out_edges yields (cur_node, target) pairs, so the
                # neighbour is the SECOND element. The original code unpacked
                # the first element (cur_node itself), which made forward
                # traversal a no-op that always returned an empty list.
                neighbours = [dst for _, dst in graph.out_edges(cur_node)]
            for next_node_name in neighbours:
                if next_node_name not in visited:
                    visited.add(next_node_name)
                    if op_type(graph, next_node_name) not in stop_nodes:
                        d.append(next_node_name)
                    else:
                        # Marking stop nodes as visited (above) prevents the
                        # same node from being collected once per path,
                        # which previously caused duplicate Permute insertions.
                        res.append(next_node_name)
        return res