"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
See the License for the specific language governing permissions and
limitations under the License.
"""
-
-import networkx as nx
import numpy as np
-from extensions.middle.FusePermutesSequence import FusePermutesSequence
-from extensions.middle.lstm_sequence_normalize import LSTMSequenceNormalize
-from extensions.middle.lstm_sequence_tensor_iterator import LSTMSequenceTensorIterator
+from extensions.ops.LSTM import LSTM
+from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.utils.error import Error
class BlockLSTMtoLSTMSequence(MiddleReplacementPattern):
"""
- MO virtual operation LSTMSequence that converts to IE TensorIterator with LSTMCell inside supports 3 outputs:
+ The MO virtual operation RNNSequence, which is converted to an IE TensorIterator with an LSTMCell inside, supports 3 outputs:
0: concatenated hidden states over the whole time sequence,
1: last hidden state,
2: last cell state.
Replacer does several tasks:
1. Checks whether the current BlockLSTM can be translated to IR (IE does not support the concatenated
cell state output that BlockLSTM can produce).
2. Searches for the sub-graph that takes the last cell state out of the unsupported concatenated cell state output.
We cut this sub-graph off if there are no other consumers of the concatenated cell state output, and connect
BlockLSTM to the consumers of this sub-graph via the port producing the last cell state output.
- 3. (Optional. Resolves by multiple checks) We cut the same sug-graph (as in 2) for concatenated cell states check
+ 3. Renumbers the input ports of BlockLSTM to match the RNNSequence specification.
+ 4. (Optional. Resolved by multiple checks) We cut the same sub-graph (as in 2) for the concatenated cell
states check, for better performance.
"""
enabled = True
def run_before(self):
- return [FusePermutesSequence, LSTMSequenceTensorIterator]
+ from extensions.middle.FusePermutesSequence import FusePermutesSequence
+ from extensions.middle.LSTMRNNSequenceToTensorIterator import LSTMToTensorIterator
+ return [FusePermutesSequence, LSTMToTensorIterator]
+
+ def run_after(self):
+ from extensions.middle.pass_separator import MiddleStart
+ from extensions.middle.RNNSequenceNormalizeToIE import RNNSequenceNormalize
+ return [MiddleStart, RNNSequenceNormalize]
def pattern(self):
return dict(
)
@staticmethod
- def replace_pattern(graph: nx.MultiDiGraph, match: dict):
+ def replace_pattern(graph: Graph, match: dict):
time_len = match['concatenated_hidden_states'].shape[0]
"""
We work with the concatenated_cell_states_data part first, because the IE TensorIterator primitive doesn't have a
- concatenated cell states output and if we can not collepse it, then we does not support this type of BlockLSTM
+ concatenated cell states output, and if we cannot collapse it, then we do not support this type of BlockLSTM.
We simplify the sub-graph below by taking another output of BlockLSTM:
concatenated cell states over the whole time sequence -> last cell state
hidden_size = node.in_node(3).shape[-1]
weights = weights_node.value
biases = biases_node.value
- assert weights.shape[0] == input_size + hidden_size, "weights.shape={} input_size={} hidden_size={}".format(weights.shape, input_size, hidden_size)
- assert weights.shape[1] == biases.shape[0] == 4 * hidden_size, "weights.shape={} biases.shape={} hidden_size={}".format(weights.shape, biases.shape, hidden_size)
+ assert weights.shape[0] == input_size + hidden_size, \
+ "weights.shape={} input_size={} hidden_size={}".format(weights.shape, input_size, hidden_size)
+ assert weights.shape[1] == biases.shape[0] == 4 * hidden_size, \
+ "weights.shape={} biases.shape={} hidden_size={}".format(weights.shape, biases.shape, hidden_size)
weights = weights.reshape([
weights.shape[0],
graph.add_edge(match['BlockLSTM'].id, match['gather_1_data'].id, **attrs)
- match['BlockLSTM'].op = 'LSTMSequence'
- match['BlockLSTM']['sequence_dim'] = 0 # TF reference
- match['BlockLSTM']['batch_dim'] = 1 # TF reference
- match['BlockLSTM']['direction'] = 'forward' # TF reference
- match['BlockLSTM']['hidden_size'] = match['concatenated_hidden_states'].shape[-1]
- match['BlockLSTM']['format'] = 'tf'
+ """
+ #3 Renumber the h_init_state and c_init_state input ports to match the RNNSequence port order.
+ """
+ h_init_port = 4
+ c_init_port = 5
+ # c_init_state
+ if 4 in node.in_nodes():
+ assert c_init_port not in node.in_nodes()
+ cell_state_edge = graph.get_edge_data(node.in_node(4).id, node.id)
+ cell_state_edge[0]['in'] = c_init_port
+
+ # h_init_state
+ if 3 in node.in_nodes():
+ assert h_init_port not in node.in_nodes()
+ hidden_state_edge = graph.get_edge_data(node.in_node(3).id, node.id)
+ hidden_state_edge[0]['in'] = h_init_port
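+ # Ordering note (inferred from the asserts above): the c_init edge must be re-labeled
+ # from port 4 to port 5 before the h_init edge moves from port 3 to port 4; in the
+ # reverse order port 4 would still be occupied and the h_init assert would fail.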
+
+ new_attrs = {'sequence_dim': 0,  # TF reference
+ 'batch_dim': 1,  # TF reference
+ 'direction': 'forward',  # TF reference
+ 'hidden_size': match['concatenated_hidden_states'].shape[-1],
+ 'format': 'tf',
+ }
+
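+ # A note on the call below: update_node_stat re-types the matched node with the LSTM
+ # operation's attributes merged with new_attrs, so the downstream LSTMToTensorIterator
+ # pass can pick it up as a regular LSTM sequence node.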
+ LSTM.update_node_stat(match['BlockLSTM'], new_attrs)
"""
- Optional #3 optimization from class description following
+ The optional #4 optimization from the class description follows.
"""
data_to_mul = [n for n in match['mul'].in_nodes().values() if n.id != match['concatenated_hidden_states'].id]
if len(data_to_mul) != 1: