2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
17 from mo.front.common.layout import indices_mapping
18 from mo.graph.graph import Node, Graph
19 from mo.middle.replacement import MiddleReplacementPattern
20 from mo.ops.op import Op, PermuteAttrs
21 from mo.ops.permute import Permute
class ConvertLayoutDependentOperations(MiddleReplacementPattern):
    """
    Converts layout-dependent operations whose layout differs from the graph layout.

    For every op node with a valid ``layout`` attribute that differs from the graph layout of the same rank
    (e.g. NHWC vs NCHW, or NDHWC vs NCDHW), this pass:
      1. inserts a Permute in front of the node's input data that reorders it from the node's layout into
         the graph layout;
      2. makes the node produce a new data node in the graph layout and inserts a Permute (inverse order)
         that writes the result back into the node's original output data node;
      3. permutes the node's own attributes with the same permutation (``node.permute_attrs``).
    The data nodes around the operation therefore keep their original shapes, while the operation itself
    is converted to the graph layout.
    """

    enabled = True

    def run_after(self):
        """Schedule this pass after the MiddleStart pass separator."""
        from extensions.middle.pass_separator import MiddleStart
        return [MiddleStart]

    def find_and_replace_pattern(self, graph: Graph):
        """
        Wrap every op node whose layout mismatches the graph layout with a pair of Permute operations.

        :param graph: graph modified in place; ``graph.graph['layout']`` holds the graph layout
                      ('NCHW', otherwise the channels-last layout is assumed).
        """
        # Snapshot node ids first: the loop body adds Permute ops and data nodes to the graph.
        for node_id in list(graph.nodes()):
            node = Node(graph, node_id)
            if node.kind != 'op' or not node.has_valid('layout'):
                continue
            graph_layout = graph.graph['layout']
            # Layout of the graph for this rank, e.g. 'NCHW' -> 'NCDHW' for 5D nodes.
            if node.layout == indices_mapping[len(node.layout)][graph_layout]:
                continue

            in_data = node.in_node()
            out_data = node.out_node()

            # Permutation that reorders shapes/attributes from the node's layout into the graph layout.
            if graph_layout == 'NCHW':
                # Graph is NCHW, so the mismatching node is channels-last (NHWC / NDHWC).
                permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(len(node.layout))
            else:
                # Graph is channels-last, so the mismatching node is channels-first (NCHW / NCDHW).
                permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(len(node.layout))

            # Schematic representation (NHWC node inside an NCHW graph):
            #
            #            NHWC                                  perm        NCHW        inv
            #   data --> Convolution --> data   ==>   data -> Permute -> data -> Convolution -> data -> Permute -> data
            #   (NHWC)                   (NHWC)      (NHWC)           (NCHW)                  (NCHW)            (NHWC)

            # 1. Input Permute: reorder the node's input data from the node's layout into the graph layout.
            edge_attrs = graph.get_edge_data(in_data.id, node.id)[0]
            graph.remove_edge(in_data.id, node.id)

            input_permute_op = Permute(graph, {'order': permutation.perm})
            permuted_in_data = input_permute_op.create_node_with_data([in_data],
                                                                      dict(name=node.name + '/Permute_'))

            # Re-attach with the original edge attributes (presumably including the input port index).
            graph.add_edge(permuted_in_data.id, node.id, **edge_attrs)

            # 2. Output Permute: the node now writes a new data node shaped in the graph layout, and the
            #    Permute with the inverse order writes the result back into the original output data node,
            #    so downstream consumers are left untouched.
            edge_attrs = graph.get_edge_data(node.id, out_data.id)[0]
            graph.remove_edge(node.id, out_data.id)

            permuted_out_data = Op.create_data_node(graph, node, {'shape': out_data.shape[permutation.perm]},
                                                    edge_attrs)

            output_permute_op = Permute(graph, {'order': permutation.inv})
            output_permute_op.create_node_with_data([permuted_out_data], dict(name=node.name + '/Permute_'),
                                                    data_nodes=out_data)

            # 3. Permute the node's attributes: permute_attrs() reorders node attributes according to the
            #    'permutation' attribute of the node's data nodes, so it is set temporarily and reset after.
            node.in_node()['permutation'] = permutation
            node.out_node()['permutation'] = permutation
            node.permute_attrs.permute_attrs(node)

            node.in_node()['permutation'] = None
            node.out_node()['permutation'] = None