Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / permute_tensor_iterator.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import numpy as np
18
19 from extensions.middle.FusePermutesSequence import FusePermutesSequence
20 from extensions.middle.LSTMRNNSequenceToTensorIterator import LSTMToTensorIterator
21 from extensions.middle.ONNXRNNSequenceNormalize import ONNXRNNSequenceNormalize
22 from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
23 from mo.graph.graph import dict_includes, Graph
24 from mo.middle.passes.eliminate import remove_op_node_with_data_node
25 from mo.middle.pattern_match import find_isomorphisms
26 from mo.middle.replacement import MiddleReplacementPattern
27
28
29 class PermuteTensorIteratorLSTM(MiddleReplacementPattern):
30     """ Fuses Permute(1,0,2) --> TI --> Permute(1,0,2) pattern to a single TI with changed axis.
31
32         WARNING This transformation is limited to support of very special case of TI but
33         code doesn't check all the cases.
34     """
35
36     enabled = True
37
38     def run_after(self):
39         return [TensorIteratorMerge, ONNXRNNSequenceNormalize, LSTMToTensorIterator, FusePermutesSequence]
40
41
42     def run_before(self):
43         return []
44
45     def pattern(self):
46         return dict(
47             nodes=[
48                 ('input'),
49                 ('direct_permute', dict(type='Permute')),
50                 ('input_permuted'),
51                 ('init_hidden'),
52                 ('init_cell'),
53
54                 ('ti', dict(kind='op', op='TensorIterator')),
55
56                 ('output_permuted'),
57                 ('inverse_permute', dict(type='Permute')),
58                 ('output'),
59             ],
60             edges=[
61                 ('input', 'direct_permute'),
62                 ('direct_permute', 'input_permuted'),
63
64                 ('input_permuted', 'ti', {'in': 0}),  # affected by permute
65                 ('init_hidden', 'ti', {'in': 1}),
66                 ('init_cell', 'ti', {'in': 2}),
67                 ('ti', 'output_permuted', {'out': 0}),  # affected by permute
68
69                 ('output_permuted', 'inverse_permute'),
70                 ('inverse_permute', 'output'),
71             ]
72         )
73
74     def replace_pattern(self, graph: Graph, match: dict):
75
76         # This transformation works if and only if a body of TI
77         # matches the following topology (Reshape -> LSTMCell -> Reshape)
78         nodes = [
79             ('input_unsqueezed'),
80             ('squeeze', dict(op='Reshape')),
81             ('input_squeezed'),
82             ('input_hidden'),
83             ('input_cell'),
84             ('weights'),
85             ('biases'),
86
87             ('lstm', dict(op='LSTMCell')),
88
89             ('output_hidden'),
90             ('output_cell'),
91             ('unsqueeze', dict(op='Reshape')),
92             ('output_unsqueezed'),
93
94             ('const_w', dict(op='Const')),
95             ('const_b', dict(op='Const')),
96
97             ('op_output', dict(op='OpOutput')),
98             ('op_output_1', dict(op='OpOutput')),
99             ('op_output_2', dict(op='OpOutput'))
100
101         ]
102         edges = [
103             ('input_unsqueezed', 'squeeze'),
104             ('squeeze', 'input_squeezed'),
105
106             ('input_squeezed', 'lstm', {'in': 0}),
107             ('input_hidden', 'lstm', {'in': 1}),
108             ('input_cell', 'lstm', {'in': 2}),
109             ('weights', 'lstm', {'in': 3}),
110             ('biases', 'lstm', {'in': 4}),
111
112             ('const_w', 'weights'),
113             ('const_b', 'biases'),
114
115             ('lstm', 'output_hidden', {'out': 0}),
116             ('lstm', 'output_cell', {'out': 1}),
117
118             ('output_hidden', 'unsqueeze'),
119             ('unsqueeze', 'output_unsqueezed'),
120
121             ('output_unsqueezed', 'op_output'),
122             ('output_hidden', 'op_output_1'),
123             ('output_cell', 'op_output_2'),
124
125         ]
126         ti = match['ti']
127         isomorphisms = find_isomorphisms(ti.body, nodes, edges)
128         if len(list(isomorphisms)) != 1:
129             return
130         isomorphism = isomorphisms[0]
131
132         direct_permute = match['direct_permute']
133         inverse_permute = match['inverse_permute']
134
135         permute_order = [1, 0, 2]
136
137         # Check both perumute orders exactly match expected one - [1, 0, 2]
138         if not direct_permute.has_valid('order') or not np.array_equal(direct_permute.order, permute_order):
139             return
140         if not inverse_permute.has_valid('order') or not np.array_equal(inverse_permute.order, permute_order):
141             return
142
143         def find_ports(port_map: list, attrs: dict):
144             """ Find all ports in a given port map with specified attributes """
145             result = []
146             for i, port in enumerate(port_map):
147                 if dict_includes(port, attrs):
148                     result.append(i)
149             return result
150
151         # Check TI has only single partitioned input/output port; all partitioned ports have defined axis
152         data_input_port = find_ports(ti.input_port_map, {'axis': lambda attr: attr in [0, 1]})
153         data_output_port = find_ports(ti.output_port_map, {'axis': lambda attr: attr in [0, 1]})
154         assert len(data_input_port) == 1
155         assert len(data_output_port) == 1
156         data_input_port = data_input_port[0]
157         data_output_port = data_output_port[0]
158         # Verify that they are really connected to Permute layers (guarantied by port numbers of TI, see the pattern)
159         assert ti.in_edge(0)['external_port_id'] == ti.input_port_map[data_input_port]['external_port_id']
160         assert ti.out_edge(0)['external_port_id'] == ti.output_port_map[data_output_port]['external_port_id']
161
162         # Verify that the TI body have required Reshapes connected to the found ports
163         squeeze = isomorphism['squeeze']
164         unsqueeze = isomorphism['unsqueeze']
165         assert squeeze['internal_layer_id'] == ti.input_port_map[data_input_port]['internal_layer_id']
166         assert squeeze.in_edge(0)['internal_port_id'] == ti.input_port_map[data_input_port]['internal_port_id']
167         assert unsqueeze['internal_layer_id'] == ti.output_port_map[data_output_port]['internal_layer_id']
168         assert unsqueeze.out_edge(0)['internal_port_id'] == ti.output_port_map[data_output_port]['internal_port_id']
169         assert len(squeeze.in_node().shape) == 3
170         assert len(squeeze.out_node().shape) == 2
171         assert len(unsqueeze.in_node().shape) == 2
172         assert len(unsqueeze.out_node().shape) == 3
173
174         # Remove permutes
175         remove_op_node_with_data_node(graph, direct_permute)
176         remove_op_node_with_data_node(graph, inverse_permute)
177         match['output'].shape = match['output'].shape[permute_order]
178
179         # swap 0/1 axis for partitioned ports
180         ti.input_port_map[data_input_port]['axis'] = 1 - ti.input_port_map[data_input_port]['axis']
181         ti.output_port_map[data_output_port]['axis'] = 1 - ti.output_port_map[data_output_port]['axis']
182
183         # smap 0-th and 1-th shape entries for reshapes inside body
184         squeeze.in_node().shape = squeeze.in_node().shape[[1, 0, 2]]
185         unsqueeze.out_node().shape = unsqueeze.out_node().shape[[1, 0, 2]]
186         unsqueeze.dim = unsqueeze.dim[[1, 0, 2]]