"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import numpy as np

from extensions.ops.LSTM import LSTM
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.utils.error import Error
class BlockLSTMtoLSTMSequence(MiddleReplacementPattern):
    """
    MO virtual operation RNNSequence that converts to IE TensorIterator with LSTMCell inside supports 3 outputs:
    0: concatenated hidden states over the whole time sequence,
    1: last hidden state,
    2: last cell state.

    Replacer do several tasks:
    1. Checks if current BlockLSTM can be translated to IR (IE does not support concatenated cell state output
    which can be produced by BlockLSTM)
    2. Searches for sub-graph, that takes last cell state out of unsupported concatenated cell state output.
    We cut this sub-graph off in case if there are no other consumers of concatenated cell state output and we connect
    BlockLSTM to consumers of this sub-graph by port producing last cell state output
    3. Renumber input ports of BlockLSTM to match RNNSequence specification.
    4. (Optional. Resolves by multiple checks) We cut the same sub-graph (as in 2) for concatenated hidden states
    for better performance
    """
    enabled = True

    def run_before(self):
        # Must be applied before permute fusing and the TensorIterator lowering consume the produced LSTM node.
        from extensions.middle.FusePermutesSequence import FusePermutesSequence
        from extensions.middle.LSTMRNNSequenceToTensorIterator import LSTMToTensorIterator
        return [FusePermutesSequence, LSTMToTensorIterator]

    def run_after(self):
        from extensions.middle.pass_separator import MiddleStart
        from extensions.middle.RNNSequenceNormalizeToIE import RNNSequenceNormalize
        return [MiddleStart, RNNSequenceNormalize]

    def pattern(self):
        return dict(
            nodes=[
                ('BlockLSTM', dict(op='BlockLSTM')),

                # 0 port: output h vector over the whole time sequence
                ('concatenated_hidden_states', (dict(kind='data'))),

                ('mul', dict(op='Mul')),
                ('mul_data', dict(kind='data')),
                ('after_mul_op_to_the_rest_of_model', dict(kind='op')),
                ('concat_0', dict(op='ConcatV2')),
                ('concat_0_data', dict(kind='data')),
                ('reshape_0', dict(op='Reshape')),
                ('reshape_0_data', dict(kind='data')),
                ('gather_0', dict(op='Gather')),
                ('gather_0_data', dict(kind='data')),

                # 1 port: cell state before the tanh over the whole time sequence
                ('concatenated_cell_states_data', (dict(kind='data'))),

                ('concat_1', dict(op='ConcatV2')),
                ('concat_1_data', dict(kind='data')),
                ('reshape_1', dict(op='Reshape')),
                ('reshape_1_data', dict(kind='data')),
                ('gather_1', dict(op='Gather')),
                ('gather_1_data', dict(kind='data')),
            ],
            edges=[
                ('BlockLSTM', 'concatenated_hidden_states', {'out': 0}),
                ('concatenated_hidden_states', 'mul'),
                ('mul', 'mul_data'),
                ('mul_data', 'after_mul_op_to_the_rest_of_model'),
                ('mul_data', 'concat_0'),
                ('concat_0', 'concat_0_data'),
                ('concat_0_data', 'reshape_0'),
                ('reshape_0', 'reshape_0_data'),
                ('reshape_0_data', 'gather_0'),
                ('gather_0', 'gather_0_data'),

                ('BlockLSTM', 'concatenated_cell_states_data', {'out': 1}),
                ('concatenated_cell_states_data', 'concat_1', {'in': 1}),
                ('concat_1', 'concat_1_data'),
                ('concat_1_data', 'reshape_1'),
                ('reshape_1', 'reshape_1_data'),
                ('reshape_1_data', 'gather_1'),
                ('gather_1', 'gather_1_data')
            ]
        )

    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        time_len = match['concatenated_hidden_states'].shape[0]
        """
        Working with concatenated_cell_states_data part first, because IE TensorIterator primitive doesn't have
        concatenated cell states output and if we can not collapse it, then we does not support this type of BlockLSTM

        We simplify the sub-graph below by taking another output of BlockLSTM:
        concatenated cell states over the whole time sequence -> last cell state

        BlockLSTM
            || out 1 (concatenated cell states comming out of BlockLSTM)
            \/  in 1
        ConcatV2
            || (concatenation with initial state or another unused data)
            \/
        Reshape
            ||
            \/
        Gather (taking the last cell state from previous BlockLSTM, if Gather indexes == time_len)
        """
        # check that there are no other consumers of concatenated_cell_states_data data flow
        valid_output_names = ['concat_1', 'concat_1_data', 'reshape_1', 'reshape_1_data', 'gather_1', 'gather_1_data']
        valid_output_node_ids = [match[name].id for name in valid_output_names]
        node_names_to_check_outputs = ['concatenated_cell_states_data', 'concat_1_data', 'reshape_1_data']
        for name in node_names_to_check_outputs:
            for node in match[name].out_nodes():
                if node.id not in valid_output_node_ids:
                    raise Error("BlockLSTM node {} has output which contains concatenated cell states over the whole "
                                "time sequence. It is not replaceable by another output and is not supported "
                                "originally".format(match['BlockLSTM'].id))

        # check that we really take the last cell state data by Gather
        gather_indexes = match['gather_1'].in_node(1).value
        if len(gather_indexes) == 1:
            gather_index = gather_indexes[0]
        else:
            raise Error("BlockLSTM node {} has output which contains concatenated cell states over the whole "
                        "time sequence. It is not replaceable by another output and is not supported "
                        "originally".format(match['BlockLSTM'].id))
        # NOTE: index equals time_len (not time_len - 1) because the Concat above prepends the initial state,
        # shifting every time step by one.
        if gather_index != time_len:
            raise Error("BlockLSTM node {} has output which contains concatenated cell states over the whole "
                        "time sequence. It is not replaceable by another output and is not supported "
                        "originally".format(match['BlockLSTM'].id))

        """
        We passed #1 and #2 stages from class description. It means that we can translate the rest of the pattern
        to LSTMSequence even without following optimizations
        """

        node = match['BlockLSTM']
        weights_node = node.in_node(1)
        biases_node = node.in_node(2)
        shift_const = node.forget_bias

        # Assign temporary shape for them for easier manipulation
        # TF stores weights in IO order
        input_size = node.in_node(0).shape[-1]
        hidden_size = node.in_node(3).shape[-1]
        weights = weights_node.value
        biases = biases_node.value
        assert weights.shape[0] == input_size + hidden_size, \
            "weights.shape={} input_size={} hidden_size={}".format(weights.shape, input_size, hidden_size)
        assert weights.shape[1] == biases.shape[0] == 4 * hidden_size, \
            "weights.shape={} biases.shape={} hidden_size={}".format(weights.shape, biases.shape, hidden_size)

        weights = weights.reshape([
            weights.shape[0],
            4,  # gates
            hidden_size
        ])

        biases = biases.reshape([
            4,  # gates
            hidden_size
        ])

        # Reorder gates icfo --> fico for both weights and biases
        gate_reorder = [2, 0, 1, 3]
        weights = np.take(weights, gate_reorder, axis=1)
        biases = np.take(biases, gate_reorder, axis=0)

        # shift_const.value should be added to the first 1/4th part of the biases (f-gate: 0)
        # Note: in case of moving this code up before gate reordering, the addition
        # should be applied at different place
        biases[0] += shift_const

        # Return to the original shapes
        weights = weights.reshape([weights.shape[0], -1])
        biases = biases.flatten()

        # TF stores weights in IO, but IE requires it in OI: transpose
        weights = weights.transpose()

        weights_node.value = weights
        weights_node.shape = np.array(weights.shape, dtype=np.int64)
        biases_node.value = biases
        biases_node.shape = np.array(biases.shape, dtype=np.int64)

        # Cut the last-cell-state sub-graph off and wire BlockLSTM's last cell state port (2) directly
        # to the consumers of the Gather output.
        attrs = dict(graph.get_edge_data(match['gather_1'].id, match['gather_1_data'].id)[0])
        attrs.update({'out': 2})
        graph.remove_edge(match['BlockLSTM'].id, match['concatenated_cell_states_data'].id)
        graph.remove_edge(match['gather_1'].id, match['gather_1_data'].id)

        graph.add_edge(match['BlockLSTM'].id, match['gather_1_data'].id, **attrs)

        """
        #3 Renumbering h_init_state, c_init_state input ports to match RNNSequence ports order.
        """
        h_init_port = 4
        c_init_port = 5
        # c_init_state
        if 4 in node.in_nodes():
            assert c_init_port not in node.in_nodes()
            cell_state_edge = graph.get_edge_data(node.in_node(4).id, node.id)
            cell_state_edge[0]['in'] = c_init_port

        # h_init_state
        if 3 in node.in_nodes():
            assert h_init_port not in node.in_nodes()
            hidden_state_edge = graph.get_edge_data(node.in_node(3).id, node.id)
            hidden_state_edge[0]['in'] = h_init_port

        new_attrs = {'sequence_dim': 0,
                     'batch_dim': 1,
                     'direction': 'forward',
                     'hidden_size': match['concatenated_hidden_states'].shape[-1],
                     'format': 'tf',
                     }

        LSTM.update_node_stat(match['BlockLSTM'], new_attrs)

        """
        Optional #4 optimization from class description following
        """
        data_to_mul = [n for n in match['mul'].in_nodes().values() if n.id != match['concatenated_hidden_states'].id]
        if len(data_to_mul) != 1:
            return  # unexpected type of mul
        data_to_mul = data_to_mul[0]
        if not data_to_mul.has_valid('value'):
            return  # unexpected type of mul
        data_to_mul_value = data_to_mul.value
        if not np.all(data_to_mul_value == 1):
            return  # unexpected type of mul

        # remove useless Mul (multiplication by ones) and connect BlockLSTM output directly
        attrs = dict(graph.get_edge_data(match['BlockLSTM'].id, match['concatenated_hidden_states'].id)[0])
        graph.remove_edge(match['BlockLSTM'].id, match['concatenated_hidden_states'].id)
        graph.remove_edge(match['mul'].id, match['mul_data'].id)
        graph.add_edge(match['BlockLSTM'].id, match['mul_data'].id, **attrs)

        # find true usages of concatenated hidden states data (not last hidden state)
        valid_output_names = ['mul_data', 'concat_0', 'concat_0_data', 'reshape_0', 'reshape_0_data', 'gather_0',
                              'gather_0_data']
        valid_output_node_ids = [match[name].id for name in valid_output_names]
        node_names_to_check_outputs = ['mul_data', 'concat_0_data', 'reshape_0_data']

        list_of_concatenated_hidden_states_children_node_ids = []
        for name in node_names_to_check_outputs:
            for node in match[name].out_nodes():
                if node.id not in valid_output_node_ids:
                    list_of_concatenated_hidden_states_children_node_ids.append(node.id)

        if len(list_of_concatenated_hidden_states_children_node_ids) != 1:
            return  # not supported placement of pattern
        concatenated_child_node_id = list_of_concatenated_hidden_states_children_node_ids[0]
        if concatenated_child_node_id != match['after_mul_op_to_the_rest_of_model'].id:
            return  # not supported placement of pattern

        gather_indexes = match['gather_0'].in_node(1).value
        if len(gather_indexes) == 1:
            gather_index = gather_indexes[0]
        else:
            return  # we have to translate this type of BlockLSTM to LSTMSequence to TensorIterator as is
        if gather_index != time_len:
            return  # we have to translate this type of BlockLSTM to LSTMSequence to TensorIterator as is

        # Wire BlockLSTM's last hidden state port (1) directly to the consumers of the Gather output.
        attrs = dict(graph.get_edge_data(match['gather_0'].id, match['gather_0_data'].id)[0])
        attrs.update({'out': 1})
        graph.remove_edge(match['mul_data'].id, match['concat_0'].id)
        graph.remove_edge(match['gather_0'].id, match['gather_0_data'].id)

        graph.add_edge(match['BlockLSTM'].id, match['gather_0_data'].id, **attrs)