"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
from collections import deque

import numpy as np

from extensions.front.kaldi.add_reshape_around_convolution import ReplaceConvolutionReshape
from extensions.middle.TensorIteratorMerge import op_type
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.graph.graph import Node, Graph
from mo.ops.permute import Permute
class ReplaceConvolutionPermute(FrontReplacementSubgraph):
    """
    This pass adds Permute around a Convolution layer if after there is sequence Pooling or Activation afterConvolution
    **IMPORTANT**: This pass must run after inserting Reshapes around Poolings and Convolutions

    Let's suppose we have next graph:

    Convolution -> [Pooling | Activation -> Pooling | Pooling -> Activation | Activation]* -> ... -> (ScaleShift | FullyConnected)

    **NOTE**: Please, remember about Reshapes around Poolings and Convolutions.
    In this example we do not print them for simplicity.
    **NOTE**: After Convolution, it is not necessary to have a sequence [Pooling | Activation -> Pooling | Pooling -> Activation | Activation]*

    So this pass will convert this graph to the next one:

    Convolution -> * -> Permute (order 0, 3, 2, 1 )-> Next_Layer -> ... -> (ScaleShift|FullyConnected)
    """
    enabled = True

    def pattern(self):
        # Match every ScaleShift or FullyConnected node; the Convolutions
        # that feed it are found by a reverse DFS in replace_sub_graph.
        return dict(
            nodes=[
                ('target_node', dict(op=lambda x: x in ['ScaleShift', 'FullyConnected']))
            ],
            edges=[]
        )
55 def replace_sub_graph(self, graph: Graph, match: dict):
56 target_node = match['target_node']
57 nodes_with_weights = self.dfs(graph, target_node.name, ('Convolution', 'FullyConnected', 'ScaleShift'), True)
58 convolution_nodes = [node for node in nodes_with_weights if Node(graph, node).op == 'Convolution']
59 for convolution_node in convolution_nodes:
60 target_node = self.search_target_node(Node(graph, convolution_node))
61 permute_op = Permute(graph, {'order': np.array([0, 3, 2, 1])})
62 permute_node = permute_op.add_node({'name': '{}/Permute'.format(target_node.name)})
63 target_node.insert_node_after( permute_node, 0)
66 from extensions.front.kaldi.add_reshape_around_pooling import ReplacePoolingReshape
67 return [ReplaceConvolutionReshape, ReplacePoolingReshape]
70 def search_target_node(node: Node):
71 target_node = ReplaceConvolutionPermute.skip_reshapes(node)
72 sequence_layers = ['Pooling', 'Activation']
73 if target_node.op not in sequence_layers:
75 if target_node.op == 'Activation':
76 sequence_layers.reverse()
77 if target_node.op == sequence_layers[0]:
78 next_node = ReplaceConvolutionPermute.skip_reshapes(target_node)
79 if next_node.op == sequence_layers[1]:
80 target_node = next_node
85 def skip_reshapes(node: Node):
86 next_node = node.out_node()
87 while next_node.op == 'Reshape':
88 next_node = next_node.out_node()
92 def dfs(graph: Graph, node_name: str, stop_nodes: tuple, reverse: bool = False) -> list:
96 visited.add(node_name)
97 d.appendleft(node_name)
99 cur_node = d.popleft()
101 nodes = graph.in_edges(cur_node)
103 nodes = graph.out_edges(cur_node)
104 for in_node_name, _ in nodes:
105 if in_node_name not in visited:
106 if op_type(graph, in_node_name) not in stop_nodes:
107 visited.add(in_node_name)
108 d.append(in_node_name)
110 res.append(in_node_name)