Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / BlockLSTMtoLSTMSequence.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import numpy as np
17
18 from extensions.ops.LSTM import LSTM
19 from mo.graph.graph import Graph
20 from mo.middle.replacement import MiddleReplacementPattern
21 from mo.utils.error import Error
22
23
24 class BlockLSTMtoLSTMSequence(MiddleReplacementPattern):
25     """
26     MO virtual operation RNNSequence that converts to IE TensorIterator with LSTMCell inside supports 3 outputs:
27     0: concatenated hidden states over the whole time sequence,
28     1: last hidden state,
29     2: last cell state.
30
31     Replacer do several tasks:
32     1. Checks if current BlockLSTM can be translated to IR (IE does not support concatenated cell state output
33     which can be produced by BlockLSTM)
34     2. Searches for sub-graph, that takes last cell state out of unsupported concatenated cell state output.
35     We cut this sub-graph off in case if there are no other consumers of concatenated cell state output and we connect
36     BlockLSTM to consumers of this sub-graph by port producing last cell state output
37     3. Renumber input ports of BlockLSTM to match RNNSequence specification.
38     4. (Optional. Resolves by multiple checks) We cut the same sug-graph (as in 2) for concatenated cell states check
39     for better performance
40     """
41     enabled = True
42
43     def run_before(self):
44         from extensions.middle.FusePermutesSequence import FusePermutesSequence
45         from extensions.middle.LSTMRNNSequenceToTensorIterator import LSTMToTensorIterator
46         return [FusePermutesSequence, LSTMToTensorIterator]
47
48     def run_after(self):
49         from extensions.middle.pass_separator import MiddleStart
50         from extensions.middle.RNNSequenceNormalizeToIE import RNNSequenceNormalize
51         return [MiddleStart, RNNSequenceNormalize]
52
53     def pattern(self):
54         return dict(
55             nodes=[
56                 ('BlockLSTM', dict(op='BlockLSTM')),
57
58                 # 0 port: output h vector over the whole time sequence
59                 ('concatenated_hidden_states', (dict(kind='data'))),
60
61                 ('mul', dict(op='Mul')),
62                 ('mul_data', dict(kind='data')),
63                 ('after_mul_op_to_the_rest_of_model', dict(kind='op')),
64                 ('concat_0', dict(op='ConcatV2')),
65                 ('concat_0_data', dict(kind='data')),
66                 ('reshape_0', dict(op='Reshape')),
67                 ('reshape_0_data', dict(kind='data')),
68                 ('gather_0', dict(op='Gather')),
69                 ('gather_0_data', dict(kind='data')),
70
71                 # 1 port: cell state before the tanh over the whole time sequence
72                 ('concatenated_cell_states_data', (dict(kind='data'))),
73
74                 ('concat_1', dict(op='ConcatV2')),
75                 ('concat_1_data', dict(kind='data')),
76                 ('reshape_1', dict(op='Reshape')),
77                 ('reshape_1_data', dict(kind='data')),
78                 ('gather_1', dict(op='Gather')),
79                 ('gather_1_data', dict(kind='data')),
80             ],
81             edges=[
82                 ('BlockLSTM', 'concatenated_hidden_states', {'out': 0}),
83                 ('concatenated_hidden_states', 'mul'),
84                 ('mul', 'mul_data'),
85                 ('mul_data', 'after_mul_op_to_the_rest_of_model'),
86                 ('mul_data', 'concat_0'),
87                 ('concat_0', 'concat_0_data'),
88                 ('concat_0_data', 'reshape_0'),
89                 ('reshape_0', 'reshape_0_data'),
90                 ('reshape_0_data', 'gather_0'),
91                 ('gather_0', 'gather_0_data'),
92
93                 ('BlockLSTM', 'concatenated_cell_states_data', {'out': 1}),
94                 ('concatenated_cell_states_data', 'concat_1', {'in': 1}),
95                 ('concat_1', 'concat_1_data'),
96                 ('concat_1_data', 'reshape_1'),
97                 ('reshape_1', 'reshape_1_data'),
98                 ('reshape_1_data', 'gather_1'),
99                 ('gather_1', 'gather_1_data')
100             ]
101         )
102
103     @staticmethod
104     def replace_pattern(graph: Graph, match: dict):
105         time_len = match['concatenated_hidden_states'].shape[0]
106         """
107         Working with concatenated_cell_states_data part first, because IE TensorIterator primitive doesn't have
108         concatenated cell states output and if we can not collapse it, then we does not support this type of BlockLSTM
109
110         We simplify the sub-graph below by taking another output of BlockLSTM:
111         concatenated cell states over the whole time sequence -> last cell state
112
113         BlockLSTM
114            || out 1 (concatenated cell states comming out of BlockLSTM)
115            \/  in 1
116         ConcatV2
117            || (concatenation with initial state or another unused data)
118            \/
119         Reshape
120            ||
121            \/
122          Gather (taking the last cell state from previous BlockLSTM, if Gather indexes == time_len)
123         """
124         # check that there are no other consumers of concatenated_cell_states_data data flow
125         valid_output_names = ['concat_1', 'concat_1_data', 'reshape_1', 'reshape_1_data', 'gather_1', 'gather_1_data']
126         valid_output_node_ids = [match[name].id for name in valid_output_names]
127         node_names_to_check_outputs = ['concatenated_cell_states_data', 'concat_1_data', 'reshape_1_data']
128         for name in node_names_to_check_outputs:
129             for node in match[name].out_nodes():
130                 if node.id not in valid_output_node_ids:
131                     raise Error("BlockLSTM node {} has output which contains concatenated cell states over the whole "
132                                 "time sequence. It is not replaceable by another output and is not supported "
133                                 "originally".format(match['BlockLSTM'].id))
134
135         # check that we really take the last cell state data by Gather
136         gather_indexes = match['gather_1'].in_node(1).value
137         if len(gather_indexes) == 1:
138             gather_index = gather_indexes[0]
139         else:
140             raise Error("BlockLSTM node {} has output which contains concatenated cell states over the whole "
141                         "time sequence. It is not replaceable by another output and is not supported "
142                         "originally".format(match['BlockLSTM'].id))
143         if gather_index != time_len:
144             raise Error("BlockLSTM node {} has output which contains concatenated cell states over the whole "
145                         "time sequence. It is not replaceable by another output and is not supported "
146                         "originally".format(match['BlockLSTM'].id))
147
148         """
149         We passed #1 and #2 stages from class description. It means that we can translate the rest of the pattern 
150         to LSTMSequence even without following optimizations
151         """
152
153         node = match['BlockLSTM']
154         weights_node = node.in_node(1)
155         biases_node = node.in_node(2)
156         shift_const = node.forget_bias
157
158         # Assign temporary shape for them for easier manipulation
159         # TF stores weights in IO order
160         input_size = node.in_node(0).shape[-1]
161         hidden_size = node.in_node(3).shape[-1]
162         weights = weights_node.value
163         biases = biases_node.value
164         assert weights.shape[0] == input_size + hidden_size, \
165             "weights.shape={} input_size={} hidden_size={}".format(weights.shape, input_size, hidden_size)
166         assert weights.shape[1] == biases.shape[0] == 4 * hidden_size, \
167             "weights.shape={} biases.shape={} hidden_size={}".format(weights.shape, biases.shape, hidden_size)
168
169         weights = weights.reshape([
170             weights.shape[0],
171             4,  # gates
172             hidden_size
173         ])
174
175         biases = biases.reshape([
176             4,  # gates
177             hidden_size
178         ])
179
180         # Reorder gates icfo --> fico for both weights and biases
181         gate_reorder = [2, 0, 1, 3]
182         weights = np.take(weights, gate_reorder, axis=1)
183         biases = np.take(biases, gate_reorder, axis=0)
184
185         # shift_const.value should be added to the first 1/4th part of the biases (f-gate: 0)
186         # Note: in case of moving this code up before gate reordering, the addition
187         # should be applied at different place
188         biases[0] += shift_const
189
190         # Return to the original shapes
191         weights = weights.reshape([weights.shape[0], -1])
192         biases = biases.flatten()
193
194         # TF stores weights in IO, but IE requires it in OI: transpose
195         weights = weights.transpose()
196
197         weights_node.value = weights
198         weights_node.shape = np.array(weights.shape, dtype=np.int64)
199         biases_node.value = biases
200         biases_node.shape = np.array(biases.shape, dtype=np.int64)
201
202         attrs = dict(graph.get_edge_data(match['gather_1'].id, match['gather_1_data'].id)[0])
203         attrs.update({'out': 2})
204         graph.remove_edge(match['BlockLSTM'].id, match['concatenated_cell_states_data'].id)
205         graph.remove_edge(match['gather_1'].id, match['gather_1_data'].id)
206
207         graph.add_edge(match['BlockLSTM'].id, match['gather_1_data'].id, **attrs)
208
209         """
210         #3 Renumbering h_init_state, c_init_state input ports to match RNNSequence ports order.
211         """
212         h_init_port = 4
213         c_init_port = 5
214         # c_init_state
215         if 4 in node.in_nodes():
216             assert c_init_port not in node.in_nodes()
217             cell_state_edge = graph.get_edge_data(node.in_node(4).id, node.id)
218             cell_state_edge[0]['in'] = c_init_port
219
220
221         #h_init_state
222         if 3 in node.in_nodes():
223             assert h_init_port not in node.in_nodes()
224             hidden_state_edge = graph.get_edge_data(node.in_node(3).id, node.id)
225             hidden_state_edge[0]['in'] = h_init_port
226
227         new_attrs = {'sequence_dim': 0,
228                      'batch_dim': 1,
229                      'direction': 'forward',
230                      'hidden_size': match['concatenated_hidden_states'].shape[-1],
231                      'format': 'tf',
232                      }
233
234         LSTM.update_node_stat(match['BlockLSTM'], new_attrs)
235
236         """
237         Optional #4 optimization from class description following
238         """
239         data_to_mul = [n for n in match['mul'].in_nodes().values() if n.id != match['concatenated_hidden_states'].id]
240         if len(data_to_mul) != 1:
241             return  # unexpected type of mul
242         data_to_mul = data_to_mul[0]
243         if not data_to_mul.has_valid('value'):
244             return  # unexpected type of mul
245         data_to_mul_value = data_to_mul.value
246         if not np.all(data_to_mul_value == 1):
247             return  # unexpected type of mul
248
249         # remove useless mul
250         attrs = dict(graph.get_edge_data(match['BlockLSTM'].id, match['concatenated_hidden_states'].id)[0])
251         graph.remove_edge(match['BlockLSTM'].id, match['concatenated_hidden_states'].id)
252         graph.remove_edge(match['mul'].id, match['mul_data'].id)
253         graph.add_edge(match['BlockLSTM'].id, match['mul_data'].id, **attrs)
254
255         # find true usages of concatenated hidden states data (not last hidden state)
256         valid_output_names = ['mul_data', 'concat_0', 'concat_0_data', 'reshape_0', 'reshape_0_data', 'gather_0',
257                               'gather_0_data']
258         valid_output_node_ids = [match[name].id for name in valid_output_names]
259         node_names_to_check_outputs = ['mul_data', 'concat_0_data', 'reshape_0_data']
260
261         list_of_concatenated_hidden_states_children_node_ids = []
262         for name in node_names_to_check_outputs:
263             for node in match[name].out_nodes():
264                 if node.id not in valid_output_node_ids:
265                     list_of_concatenated_hidden_states_children_node_ids.append(node.id)
266
267         if len(list_of_concatenated_hidden_states_children_node_ids) != 1:
268             return  # not supported placement of patten
269         conacenated_child_node_id = list_of_concatenated_hidden_states_children_node_ids[0]
270         if conacenated_child_node_id != match['after_mul_op_to_the_rest_of_model'].id:
271             return  # not supported placement of patten
272
273         gather_indexes = match['gather_0'].in_node(1).value
274         if len(gather_indexes) == 1:
275             gather_index = gather_indexes[0]
276         else:
277             return  # we have to translate this type of BlockLSTM to LSTMSequence to TensorIterator as is
278         if gather_index != time_len:
279             return  # we have to translate this type of BlockLSTM to LSTMSequence to TensorIterator as is
280
281         attrs = dict(graph.get_edge_data(match['gather_0'].id, match['gather_0_data'].id)[0])
282         attrs.update({'out': 1})
283         graph.remove_edge(match['mul_data'].id, match['concat_0'].id)
284         graph.remove_edge(match['gather_0'].id, match['gather_0_data'].id)
285
286         graph.add_edge(match['BlockLSTM'].id, match['gather_0_data'].id, **attrs)