2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from extensions.middle.FusePermutesSequence import FusePermutesSequence
20 from extensions.middle.LSTMRNNSequenceToTensorIterator import LSTMToTensorIterator
21 from extensions.middle.ONNXRNNSequenceNormalize import ONNXRNNSequenceNormalize
22 from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
23 from mo.graph.graph import dict_includes, Graph
24 from mo.middle.passes.eliminate import remove_op_node_with_data_node
25 from mo.middle.pattern_match import find_isomorphisms
26 from mo.middle.replacement import MiddleReplacementPattern
# NOTE(review): this listing is incomplete and cannot run as-is. Each line carries its original
# line number, indentation is stripped, and the numbering skips lines (e.g. 34-38, 40-48, 129,
# 139, 141-142, 145, 148-150). Those lines presumably held:
#   - the end of the class docstring;
#   - the run_after/pattern method headers;
#   - the openers of the `nodes`/`edges` lists;
#   - the early `return`s after failed checks;
#   - the result list of find_ports.
# `ti` is used from line 127 on but is never bound in the visible lines. `np` is used although
# `import numpy as np` is not among the visible imports. Restore the missing lines from the
# upstream file before relying on this code.
29 class PermuteTensorIteratorLSTM(MiddleReplacementPattern):
30 """ Fuses Permute(1,0,2) --> TI --> Permute(1,0,2) pattern to a single TI with changed axis.
Both Permute nodes are removed, and the partitioned input/output axis of the TensorIterator is
swapped between 0 and 1. It is meant to apply only when both Permute orders equal [1, 0, 2]
and the TI body is exactly Reshape -> LSTMCell -> Reshape (see replace_pattern).
32 WARNING This transformation is limited to support of a very special case of TI, and the
33 code doesn't check all the cases.
# Transformations this pass must run after (presumably the body of run_after; its `def` line is
# missing from this listing).
39 return [TensorIteratorMerge, ONNXRNNSequenceNormalize, LSTMToTensorIterator, FusePermutesSequence]
# Matched pattern (presumably the body of pattern(); its header and list openers are missing):
#   input -> direct_permute -> input_permuted -> TI (in port 0);
#   TI (out port 0) -> output_permuted -> inverse_permute -> output;
#   TI in ports 1 and 2 take the initial hidden and cell states.
49 ('direct_permute', dict(type='Permute')),
54 ('ti', dict(kind='op', op='TensorIterator')),
57 ('inverse_permute', dict(type='Permute')),
61 ('input', 'direct_permute'),
62 ('direct_permute', 'input_permuted'),
64 ('input_permuted', 'ti', {'in': 0}), # data input whose layout the direct permute changes
65 ('init_hidden', 'ti', {'in': 1}),
66 ('init_cell', 'ti', {'in': 2}),
67 ('ti', 'output_permuted', {'out': 0}), # data output whose layout the inverse permute changes
69 ('output_permuted', 'inverse_permute'),
70 ('inverse_permute', 'output'),
74 def replace_pattern(self, graph: Graph, match: dict):
76 # This transformation applies only if the TI body matches exactly the following topology:
77 # Reshape (squeeze, 3D -> 2D) -> LSTMCell -> Reshape (unsqueeze, 2D -> 3D)
80 ('squeeze', dict(op='Reshape')),
87 ('lstm', dict(op='LSTMCell')),
91 ('unsqueeze', dict(op='Reshape')),
92 ('output_unsqueezed'),
94 ('const_w', dict(op='Const')),
95 ('const_b', dict(op='Const')),
97 ('op_output', dict(op='OpOutput')),
98 ('op_output_1', dict(op='OpOutput')),
99 ('op_output_2', dict(op='OpOutput'))
103 ('input_unsqueezed', 'squeeze'),
104 ('squeeze', 'input_squeezed'),
106 ('input_squeezed', 'lstm', {'in': 0}),
107 ('input_hidden', 'lstm', {'in': 1}),
108 ('input_cell', 'lstm', {'in': 2}),
109 ('weights', 'lstm', {'in': 3}),
110 ('biases', 'lstm', {'in': 4}),
112 ('const_w', 'weights'),
113 ('const_b', 'biases'),
115 ('lstm', 'output_hidden', {'out': 0}),
116 ('lstm', 'output_cell', {'out': 1}),
118 ('output_hidden', 'unsqueeze'),
119 ('unsqueeze', 'output_unsqueezed'),
121 ('output_unsqueezed', 'op_output'),
122 ('output_hidden', 'op_output_1'),
123 ('output_cell', 'op_output_2'),
# Require exactly one match of the body topology above inside the TI body graph.
# NOTE(review): `len(list(isomorphisms))` followed by `isomorphisms[0]` works only if
# find_isomorphisms returns an indexable sequence. If it ever returned a one-shot iterator,
# `list(...)` would consume it. Confirm the return type; the `list(...)` call looks redundant.
127 isomorphisms = find_isomorphisms(ti.body, nodes, edges)
128 if len(list(isomorphisms)) != 1:
130 isomorphism = isomorphisms[0]
132 direct_permute = match['direct_permute']
133 inverse_permute = match['inverse_permute']
135 permute_order = [1, 0, 2]
137 # Check that both permute orders exactly match the expected one: [1, 0, 2]
138 if not direct_permute.has_valid('order') or not np.array_equal(direct_permute.order, permute_order):
140 if not inverse_permute.has_valid('order') or not np.array_equal(inverse_permute.order, permute_order):
143 def find_ports(port_map: list, attrs: dict):
144 """ Find the indices of all ports in port_map whose attributes match attrs (checked with dict_includes) """
146 for i, port in enumerate(port_map):
147 if dict_includes(port, attrs):
151 # Check the TI has exactly one partitioned input port and one partitioned output port, i.e. ports whose 'axis' is 0 or 1
152 data_input_port = find_ports(ti.input_port_map, {'axis': lambda attr: attr in [0, 1]})
153 data_output_port = find_ports(ti.output_port_map, {'axis': lambda attr: attr in [0, 1]})
# NOTE(review): the asserts below validate the matched graph, but `assert` statements are
# stripped when Python runs with -O. Consider raising an explicit error instead.
154 assert len(data_input_port) == 1
155 assert len(data_output_port) == 1
156 data_input_port = data_input_port[0]
157 data_output_port = data_output_port[0]
158 # Verify that they are really connected to Permute layers (guaranteed by the TI port numbers used in the pattern)
159 assert ti.in_edge(0)['external_port_id'] == ti.input_port_map[data_input_port]['external_port_id']
160 assert ti.out_edge(0)['external_port_id'] == ti.output_port_map[data_output_port]['external_port_id']
162 # Verify that the TI body has the required Reshapes connected to the found ports
163 squeeze = isomorphism['squeeze']
164 unsqueeze = isomorphism['unsqueeze']
165 assert squeeze['internal_layer_id'] == ti.input_port_map[data_input_port]['internal_layer_id']
166 assert squeeze.in_edge(0)['internal_port_id'] == ti.input_port_map[data_input_port]['internal_port_id']
167 assert unsqueeze['internal_layer_id'] == ti.output_port_map[data_output_port]['internal_layer_id']
168 assert unsqueeze.out_edge(0)['internal_port_id'] == ti.output_port_map[data_output_port]['internal_port_id']
# squeeze must reduce a 3D tensor to 2D, and unsqueeze must restore 2D to 3D.
169 assert len(squeeze.in_node().shape) == 3
170 assert len(squeeze.out_node().shape) == 2
171 assert len(unsqueeze.in_node().shape) == 2
172 assert len(unsqueeze.out_node().shape) == 3
# Remove both Permute nodes from the graph.
# NOTE(review): this relies on remove_op_node_with_data_node's reconnection semantics, which are
# not visible here. Re-permuting match['output'].shape presumably compensates for the shape the
# surviving data node inherits from the removed one. Confirm against that helper.
175 remove_op_node_with_data_node(graph, direct_permute)
176 remove_op_node_with_data_node(graph, inverse_permute)
177 match['output'].shape = match['output'].shape[permute_order]
179 # Swap axis 0 <-> 1 for the partitioned ports; `1 - axis` is valid because find_ports accepted only axis 0 or 1
180 ti.input_port_map[data_input_port]['axis'] = 1 - ti.input_port_map[data_input_port]['axis']
181 ti.output_port_map[data_output_port]['axis'] = 1 - ti.output_port_map[data_output_port]['axis']
183 # Swap the 0th and 1st shape entries of the 3D data nodes around the body Reshapes, and of the unsqueeze target dim ([1, 0, 2] is the same order as permute_order)
184 squeeze.in_node().shape = squeeze.in_node().shape[[1, 0, 2]]
185 unsqueeze.out_node().shape = unsqueeze.out_node().shape[[1, 0, 2]]
186 unsqueeze.dim = unsqueeze.dim[[1, 0, 2]]