Add a section of how to link IE with CMake project (#99)
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / LSTMRNNSequenceToTensorIterator.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import numpy as np
17
18 from extensions.middle.FusePermutesSequence import FusePermutesSequence
19 from extensions.middle.RNNSequenceNormalizeToIE import RNNSequenceNormalize
20 from extensions.ops.lstm_cell import LSTMCell
21 from extensions.ops.tensor_iterator import TensorIterator
22 from mo.graph.graph import Graph, add_opoutput
23 from mo.middle.replacement import MiddleReplacementPattern
24 from mo.ops.op import Op
25 from mo.ops.reshape import Reshape
26
27
class LSTMToTensorIterator(MiddleReplacementPattern):
    """ Converts normalized RNNSequence with op=LSTM to TensorIterator.

        Normalized RNNSequence means that it should be processed by
        RNNSequenceNormalize transform that ensures its strict form.

        This transformation builds an alternative sub-graph for LSTMSequence
        with TensorIterator connected in the same way as an original LSTMSequence
        node and with internal body represented as LSTMCell op node with necessary
        squeezes and unsqueezes around.
    """

    enabled = True
    # Run the graph clean-up pass after this transform so the detached
    # LSTMSequence node and its dangling data nodes are removed.
    force_clean_up = True
    id = 'lstm_to_tensor_iterator'

    def run_after(self):
        """This transform requires the strict LSTM form produced by RNNSequenceNormalize."""
        return [RNNSequenceNormalize]

    def run_before(self):
        """Permute fusion must see the final TensorIterator sub-graph, not the raw LSTMSequence."""
        return [FusePermutesSequence]

    def pattern(self):
        """Match a normalized LSTM RNNSequence node with its three mandatory inputs.

        Optional inputs (initial hidden/cell states on ports 4/5) and optional
        outputs (last states on ports 1/2) are intentionally not captured here;
        they are handled explicitly in replace_pattern.
        """
        return dict(
            nodes=[
                ('lstm', dict(kind='op', op='LSTM', type='RNNSequence')),
                ('input', dict(kind='data')),
                ('weights', dict(kind='data')),
                ('biases', dict(kind='data')),
                # don't capture optional input initial states here
                ('output', dict(kind='data')),
                # don't capture optional output last states here
            ],
            edges=[
                ('input', 'lstm', {'in': 0}),
                ('weights', 'lstm', {'bin': 'weights', 'in': 1}),
                ('biases', 'lstm', {'bin': 'biases', 'in': 2}),
                ('lstm', 'output', {'out': 0}),
            ]
        )

    def replace_pattern(self, graph: Graph, match: dict):
        """Replace the matched LSTMSequence node with a TensorIterator.

        The TensorIterator body consists of:
          1. a Reshape squeezing the per-iteration input slice (drops sequence dim),
          2. an LSTMCell consuming the squeezed slice, weights/biases and states,
          3. a Reshape unsqueezing the cell output back to a per-iteration slice.
        Back edges feed h/c states of iteration i into iteration i+1.

        :param graph: the main graph being transformed
        :param match: pattern-match dictionary; 'lstm' is the RNNSequence node
        """
        lstm = match['lstm']

        # Build TensorIterator body first
        body = Graph(name=lstm.name + '/sub_graph')
        body.graph = graph.graph

        # 1. Input squeeze Reshape
        # Clone the data nodes for X, W/R, B, h_init, c_init into the body.
        # Only weights/biases (original ports 1, 2) keep their constant values;
        # activations are iteration-dependent and get value=None.
        inputs = [Op._create_data_node(body, lstm.name + '/inport/' + str(inp),
                                       {'shape': lstm.in_node(inp).shape.copy(),
                                        'value': lstm.in_node(inp).value.copy()
                                        if lstm.in_node(inp).value is not None and inp in [1, 2] else None})
                  for inp in [0, 4, 5, 1, 2]]  # X, WR, B, h_init, c_init

        # Inside the body each iteration sees a single time step, so the
        # sequence dimension collapses to 1 and is then squeezed away.
        inputs[0].shape[lstm.sequence_dim] = 1
        reshape_dim = inputs[0].shape.copy()
        reshape_dim[lstm.batch_dim] = -1  # batch size is resolved at runtime
        reshape_dim = np.delete(reshape_dim, lstm.sequence_dim)
        input_squeeze = Reshape(
            body,
            dict(name=lstm.name + '/input_squeeze', internal_layer_id=0, dim=reshape_dim)
        )
        inputs[0] = input_squeeze.create_node_with_data([inputs[0]], edge_attrs=[{'internal_port_id': 0}])

        # 2. Output unsqueeze Reshape
        # Output 0 is the per-step hidden output; output 1 (cell state) reuses
        # the h_init shape (port 4) when the original node does not expose it.
        outputs = [Op._create_data_node(body, lstm.name + '/outport/' + str(out),
                                        {'shape': lstm.out_node(out).shape.copy() if out in lstm.out_nodes()
                                        else lstm.in_node(4).shape.copy()}) for out in [0, 1]]
        for out in outputs:
            add_opoutput(body, out.id, 0, False)

        unsqueezed_output_shape = outputs[0].shape.copy()
        unsqueezed_output_shape[lstm.sequence_dim] = 1
        squeezed_output_shape = np.delete(unsqueezed_output_shape, lstm.sequence_dim)
        outputs[0].shape = squeezed_output_shape
        unsqueezed_output_shape[lstm.batch_dim] = -1
        # Fixed: use '/' separator to stay consistent with all other generated
        # node names in this transform (was lstm.name + 'output_unsqueeze').
        output_unsqueeze = Reshape(body, dict(name=lstm.name + '/output_unsqueeze', dim=unsqueezed_output_shape,
                                              internal_layer_id=2))

        # 3. LSTMCell
        lstm_cell_op = LSTMCell(body, dict(hidden_size=lstm.hidden_size,
                                           activations=lstm.activations,
                                           activation_alpha=lstm.activation_alpha,
                                           activation_beta=lstm.activation_beta,
                                           clip=lstm.clip,
                                           input_forget=lstm.input_forget,
                                           name=lstm.name + '/LSTMCell',
                                           internal_layer_id=1))
        lstm_cell_node = lstm_cell_op.create_node_with_data(inputs, data_nodes=outputs,
                                                            edge_attrs=[{}, {'internal_port_id': 1},
                                                                        {'internal_port_id': 2}, {'bin': 'weights'},
                                                                        {'bin': 'biases'}])
        # Internal ports 4/5 export the h/c states; port 3 exports the
        # unsqueezed per-iteration output slice.
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 4
        lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
        lstm_cell_node[0] = output_unsqueeze.create_node_with_data([lstm_cell_node[0]])
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
        add_opoutput(body, lstm_cell_node[0].id, 0, False)

        # 4. TensorIterator layer creating
        assert lstm.direction in ['forward', 'reverse']
        if lstm.direction == 'forward':
            stride = 1
            start = None
            end = None
        else:
            assert lstm.direction == 'reverse'
            # Iterate the sequence axis backwards for the reverse direction.
            stride = -1
            start = -1
            end = 0

        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,

            'axis': lstm.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]

        # Adding h_state, c_state to outputs
        # (only when the original LSTMSequence exposes all 3 outputs)
        if len(lstm.out_nodes()) == 3:
            output_port_map.extend([{
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            }, {
                'external_port_id': 5,
                'internal_layer_id': 1,
                'internal_port_id': 5,
            }])

        ti_op = TensorIterator(graph, {
            'name': lstm.name + '/TensorIterator',
            'body': body,
            'in_ports_count': 3,
            'out_ports_count': len(lstm.out_nodes()),

            'input_port_map': [
                {
                    # Port 0: X, sliced one step at a time along the sequence axis.
                    'external_port_id': 0,
                    'internal_layer_id': 0,
                    'internal_port_id': 0,

                    'axis': lstm.sequence_dim,
                    'stride': stride,
                    'start': start,
                    'end': end,
                    'part_size': 1,
                },
                {
                    # Port 1: initial hidden state (whole tensor, no slicing).
                    'external_port_id': 1,
                    'internal_layer_id': 1,
                    'internal_port_id': 1,
                },
                {
                    # Port 2: initial cell state (whole tensor, no slicing).
                    'external_port_id': 2,
                    'internal_layer_id': 1,
                    'internal_port_id': 2,
                },
            ],

            'output_port_map': output_port_map,

            # Carry h (port 4 -> 1) and c (port 5 -> 2) between iterations.
            'back_edges': [
                {
                    'from_layer': 1,
                    'from_port': 4,
                    'to_layer': 1,
                    'to_port': 1,
                },
                {
                    'from_layer': 1,
                    'from_port': 5,
                    'to_layer': 1,
                    'to_port': 2,
                },
            ]
        })

        assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
            "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)

        outs = ti_op.create_node_with_data([lstm.in_node(i) for i in [0, 4, 5]],  # X, h_init, c_init
                                           data_nodes=[lstm.out_node(i) for i in range(len(lstm.out_nodes()))],
                                           edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1},
                                                       {'external_port_id': 2}])

        # create_node_with_data returns a bare node when there is one output.
        if not isinstance(outs, list):
            outs = [outs]

        graph.remove_node(lstm.id)
        # Re-tag output edges with the external port ids declared above:
        # 3 for the main output, 4/5 for the optional h/c state outputs.
        outs[0].in_edge(0)['external_port_id'] = 3
        for i, out in enumerate(outs[1:]):
            external_port_id = 4 + i
            out.in_edge()['external_port_id'] = external_port_id