model-optimizer/extensions/middle/lstm_sequence_tensor_iterator.py
1 """
2  Copyright (c) 2018 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import networkx as nx
18 import numpy as np
19
20 from extensions.middle.FusePermutesSequence import FusePermutesSequence
21 from extensions.middle.lstm_sequence_normalize import LSTMSequenceNormalize
22 from extensions.middle.mxnet_lstm_sequence_normalize import MXNetLSTMSequenceNormalize
23 from extensions.ops.lstm_cell import LSTMCell
24 from extensions.ops.tensor_iterator import TensorIterator
25 from mo.middle.replacement import MiddleReplacementPattern
26 from mo.ops.op import Op
27 from mo.ops.reshape import Reshape
28
29
class LSTMSequenceTensorIterator(MiddleReplacementPattern):
    """ Converts a normalized LSTMSequence op to TensorIterator.

        A normalized LSTMSequence is one that has been processed by the
        LSTMSequenceNormalize transform, which ensures its strict form.

        This transformation builds an alternative sub-graph for LSTMSequence
        with a TensorIterator connected in the same way as the original LSTMSequence
        node and with an internal body represented as an LSTMCell op node with the
        necessary squeezes and unsqueezes around it.
    """

    enabled = True

    def run_after(self):
        return [LSTMSequenceNormalize, MXNetLSTMSequenceNormalize]

    def run_before(self):
        return [FusePermutesSequence]

    def pattern(self):
        return dict(
            nodes=[
                ('lstm', dict(kind='op', op='LSTMSequence')),
                ('input', dict(kind='data')),
                ('weights', dict(kind='data')),
                ('biases', dict(kind='data')),
                # don't capture optional input initial states here
                ('output', dict(kind='data')),
                # don't capture optional output last states here
            ],
            edges=[
                ('input', 'lstm', {'in': 0}),
                ('weights', 'lstm', {'bin': 'weights', 'in': 1}),
                ('biases', 'lstm', {'bin': 'biases', 'in': 2}),
                ('lstm', 'output', {'out': 0}),
            ]
        )

    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
        lstm = match['lstm']

        # Build TensorIterator body first
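        # The body contains three internal layers:
        #   0 - Reshape that squeezes the sequence dimension from the per-iteration input slice
        #   1 - LSTMCell fed with that slice, the hidden/cell states, weights and biases
        #   2 - Reshape that restores a sequence dimension of size 1 so TensorIterator can
        #       concatenate the per-iteration outputs along the sequence axis
        # Body data nodes are created for LSTMSequence inputs 0 (input sequence), 3 and 4
        # (initial hidden/cell states) and 1 and 2 (weights/biases; their constant values are copied).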
        body = nx.MultiDiGraph(name=lstm.name + '/sub_graph', layout=graph.graph['layout'])
        inputs = [Op._create_data_node(body, lstm.name + '/inport/' + str(inp),
                                       {'shape': lstm.in_node(inp).shape.copy(),
                                        'value': lstm.in_node(inp).value.copy()
                                        if lstm.in_node(inp).value is not None and inp in [1, 2] else None})
                  for inp in [0, 3, 4, 1, 2]]
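        # TensorIterator feeds the body one step of the sequence per iteration, so the
        # per-iteration input slice has length 1 along the sequence axis; the Reshape below
        # then drops that axis before the data reaches LSTMCell.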
        inputs[0].shape[lstm.sequence_dim] = 1
        reshape_dim = inputs[0].shape.copy()
        reshape_dim[lstm.batch_dim] = -1
        reshape_dim = np.delete(reshape_dim, lstm.sequence_dim)
        input_squeeze = Reshape(
            body,
            dict(name=lstm.name + '/input_squeeze', internal_layer_id=0, dim=reshape_dim)
        )
        inputs[0] = input_squeeze.create_node_with_data([inputs[0]], edge_attrs=[{'internal_port_id': 0}])
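        # LSTMCell (internal layer 1) consumes the squeezed input slice plus the hidden and cell
        # states on internal ports 1 and 2, together with the weights and biases.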
        lstm_cell_op = LSTMCell(body, dict(hidden_size=match['lstm'].hidden_size, name=lstm.name + '/LSTMCell',
                                           internal_layer_id=1))
        outputs = [Op._create_data_node(body, lstm.name + '/outport/' + str(out),
                                        {'shape': lstm.out_node(out).shape.copy() if out in lstm.out_nodes()
                                         else lstm.in_node(3).shape.copy(), 'is_output': True}) for out in [0, 1]]
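        # The cell output has no sequence dimension: compute the squeezed per-step shape for the
        # cell output data node and the unsqueezed shape (sequence dim of size 1, batch dim set to
        # -1 so Reshape infers it) for the Reshape that restores the sequence axis.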
        unsqueezed_output_shape = outputs[0].shape.copy()
        unsqueezed_output_shape[lstm.sequence_dim] = 1
        squeezed_output_shape = np.delete(unsqueezed_output_shape, lstm.sequence_dim)
        outputs[0].shape = squeezed_output_shape
        unsqueezed_output_shape[lstm.batch_dim] = -1
        output_unsqueeze = Reshape(body, dict(name=lstm.name + '/output_unsqueeze', dim=unsqueezed_output_shape,
                                              internal_layer_id=2))
        # TODO edge attributes should be assigned by the op itself
        lstm_cell_node = lstm_cell_op.create_node_with_data(inputs, data_nodes=outputs,
                                                            edge_attrs=[{}, {'internal_port_id': 1},
                                                                        {'internal_port_id': 2}, {'bin': 'weights'},
                                                                        {'bin': 'biases'}])
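        # Mark the LSTMCell output edges: internal port 4 carries the hidden state and port 5 the
        # cell state (both used by the back edges below); the unsqueezed hidden output is exposed
        # on internal port 3 for the TensorIterator output_port_map.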
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 4
        lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
        lstm_cell_node[0] = output_unsqueeze.create_node_with_data([lstm_cell_node[0]])
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
        lstm_cell_node[0]['is_output'] = True

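        # 'reverse' direction is implemented by iterating over the sequence axis with a negative
        # stride; 'forward' uses the default iteration order.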
        assert lstm.direction in ['forward', 'reverse']
        if lstm.direction == 'forward':
            stride = 1
            start = None
            end = None
        else:
            assert lstm.direction == 'reverse'
            stride = -1
            start = -1
            end = 0

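        # External port 3 collects the per-iteration hidden outputs (internal layer 2, port 3),
        # concatenated along the sequence axis.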
        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,
            'axis': lstm.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]

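        # If the original LSTMSequence also produces the last hidden and cell states, expose them
        # through external ports 4 and 5, taken directly from the LSTMCell (internal layer 1).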
        if len(lstm.out_nodes()) == 3:
            output_port_map.extend([{
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            }, {
                'external_port_id': 5,
                'internal_layer_id': 1,
                'internal_port_id': 5,
            }])

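        # Create the TensorIterator: external input 0 is sliced along the sequence axis one step
        # per iteration, external inputs 1 and 2 supply the initial hidden/cell states, and the
        # back edges feed the LSTMCell state outputs (ports 4/5) to its state inputs (ports 1/2)
        # on the next iteration.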
        ti_op = TensorIterator(graph, {
            'name': lstm.name + '/TensorIterator',
            'body': body,

            'input_port_map': [
                {
                    'external_port_id': 0,
                    'internal_layer_id': 0,
                    'internal_port_id': 0,
                    'axis': lstm.sequence_dim,
                    'stride': stride,
                    'start': start,
                    'end': end,
                    'part_size': 1,
                },
                {
                    'external_port_id': 1,
                    'internal_layer_id': 1,
                    'internal_port_id': 1,
                },
                {
                    'external_port_id': 2,
                    'internal_layer_id': 1,
                    'internal_port_id': 2,
                },
            ],

            'output_port_map': output_port_map,

            'back_edges': [
                {
                    'from_layer': 1,
                    'from_port': 4,
                    'to_layer': 1,
                    'to_port': 1,
                },
                {
                    'from_layer': 1,
                    'from_port': 5,
                    'to_layer': 1,
                    'to_port': 2,
                },
            ]
        })

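        # Wire the TensorIterator into the original graph: reuse LSTMSequence inputs 0, 3, 4 and
        # its output data nodes, then remove the LSTMSequence node and tag the output edges with
        # their external port ids (3 for the output sequence, 4/5 for the last states).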
        assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
            "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)
        outs = ti_op.create_node_with_data([lstm.in_node(i) for i in [0, 3, 4]],
                                           data_nodes=[lstm.out_node(i) for i in range(len(lstm.out_nodes()))],
                                           edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1},
                                                       {'external_port_id': 2}])

        if not isinstance(outs, list):
            outs = [outs]

        graph.remove_node(lstm.id)
        outs[0].in_edge(0)['external_port_id'] = 3
        for i, out in enumerate(outs[1:]):
            external_port_id = 4 + i
            out.in_edge()['external_port_id'] = external_port_id