Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / ConvertLayoutDependentOperations.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 from mo.front.common.layout import indices_mapping
18 from mo.graph.graph import Node, Graph
19 from mo.middle.replacement import MiddleReplacementPattern
20 from mo.ops.op import Op, PermuteAttrs
21 from mo.ops.permute import Permute
22
23
class ConvertLayoutDependentOperations(MiddleReplacementPattern):
    """
    Wrap layout-dependent operations whose layout differs from the graph layout in a pair of Permutes.

    An operation is treated as layout-dependent when it has a valid 'layout' attribute (a convolution is the
    typical example). If that layout does not match the graph layout expressed at the same rank (for example
    NHWC vs NCHW, or NDHWC vs NCDHW), the pass:
      1. inserts a Permute on the operation's input (port 0) that reorders the data using `permutation.perm`;
      2. re-targets the operation to a new output data node whose shape is the original output shape reordered
         by `permutation.perm`, and inserts a Permute with `permutation.inv` that writes back into the original
         output data node, so consumers keep seeing the original layout;
      3. permutes the operation's own attributes with the same permutation via `node.permute_attrs`.
    """

    enabled = True

    def run_after(self):
        from extensions.middle.pass_separator import MiddleStart
        return [MiddleStart]

    def find_and_replace_pattern(self, graph: Graph):
        # Iterate over a snapshot: the loop body adds new nodes to the graph.
        for node_id in list(graph.nodes()):
            node = Node(graph, node_id)

            # Only operation nodes that declare their own layout are candidates.
            if node.kind != 'op' or not node.has_valid('layout'):
                continue

            graph_layout = graph.graph['layout']
            # indices_mapping translates the graph layout into the form matching the node's rank
            # (e.g. NCHW -> NCDHW for 5D), so the comparison is done at the same rank.
            if node.layout == indices_mapping[len(node.layout)][graph_layout]:
                continue

            in_data = node.in_node()
            out_data = node.out_node()

            # Calculate the permutation used by the inserted Permute operations.
            if graph_layout == 'NCHW':
                # The graph is NCHW-style and the node layout differs from it
                # (presumably NHWC / NDHWC).
                permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(len(node.layout))
            else:
                # The graph is not NCHW (presumably NHWC-style) and the node layout differs from it
                # (presumably NCHW / NCDHW).
                permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(len(node.layout))

            # Schematic representation of the transformation below
            #
            #                                           \            NCHW                              NCHW
            #            NHWC                        --  \            |  permutation       permutation  |
            #   data-->Convolution(example)-->data   --  /            |      |       NCHW      |        |
            #                                           /   data->Permute->data->Convolution->data->Permute->data

            # 1. Insert input Permute
            #    Detach the original input data node and feed the operation with its permuted copy instead,
            #    reusing the original edge attributes (key 0 of the multigraph edge).
            in_edge_attrs = graph.get_edge_data(in_data.id, node.id)[0]
            graph.remove_edge(in_data.id, node.id)

            input_permute_op = Permute(graph, {'order': permutation.perm})
            permuted_in_data = input_permute_op.create_node_with_data([in_data],
                                                                      dict(name=node.name + '/Permute_'))

            graph.add_edge(permuted_in_data.id, node.id, **in_edge_attrs)

            # 2. Insert output Permute
            #    The operation now writes to a fresh data node (shape reordered by permutation.perm); a Permute
            #    with the inverse order writes back into the original output data node so consumers are unchanged.
            out_edge_attrs = graph.get_edge_data(node.id, out_data.id)[0]
            graph.remove_edge(node.id, out_data.id)

            # NOTE(review): assumes out_data.shape is a numpy array, so it can be indexed by the permutation.
            new_out_data = Op.create_data_node(graph, node, {'shape': out_data.shape[permutation.perm]},
                                               out_edge_attrs)

            # NOTE(review): both inserted Permutes get the same 'name' (node.name + '/Permute_');
            # confirm downstream consumers tolerate duplicate names.
            output_permute_op = Permute(graph, {'order': permutation.inv})
            output_permute_op.create_node_with_data([new_out_data], dict(name=node.name + '/Permute_'),
                                                    data_nodes=out_data)

            # 3. Add permutations for Node
            #    The data nodes adjacent to the operation temporarily carry the 'permutation' attribute;
            #    permute_attrs then permutes the node's attributes according to those permutations.
            node.in_node()['permutation'] = permutation
            node.out_node()['permutation'] = permutation
            node.permute_attrs.permute_attrs(node)

            # Clear the temporary attribute once the node's attributes have been permuted.
            node.in_node()['permutation'] = None
            node.out_node()['permutation'] = None