Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer/extensions/middle/BlockLSTMtoLSTMSequence.py
index 9835442..aa4bdf6 100644
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  See the License for the specific language governing permissions and
  limitations under the License.
 """
-
-import networkx as nx
 import numpy as np
 
-from extensions.middle.FusePermutesSequence import FusePermutesSequence
-from extensions.middle.lstm_sequence_normalize import LSTMSequenceNormalize
-from extensions.middle.lstm_sequence_tensor_iterator import LSTMSequenceTensorIterator
+from extensions.ops.LSTM import LSTM
+from mo.graph.graph import Graph
 from mo.middle.replacement import MiddleReplacementPattern
 from mo.utils.error import Error
 
 
 class BlockLSTMtoLSTMSequence(MiddleReplacementPattern):
     """
-    MO virtual operation LSTMSequence that converts to IE TensorIterator with LSTMCell inside supports 3 outputs:
+    The MO virtual operation RNNSequence, which is converted to an IE TensorIterator with an LSTMCell inside, supports 3 outputs:
     0: concatenated hidden states over the whole time sequence,
     1: last hidden state,
     2: last cell state.
@@ -37,13 +34,21 @@ class BlockLSTMtoLSTMSequence(MiddleReplacementPattern):
     2. Searches for the sub-graph that takes the last cell state out of the unsupported concatenated cell state output.
     We cut this sub-graph off if there are no other consumers of the concatenated cell state output, and we connect
     BlockLSTM to the consumers of this sub-graph through the port producing the last cell state output.
-    3. (Optional. Resolves by multiple checks) We cut the same sug-graph (as in 2) for concatenated cell states check
+    3. Renumbers input ports of BlockLSTM to match the RNNSequence specification.
+    4. (Optional; resolved by multiple checks) Cuts the same sub-graph (as in 2) off the concatenated cell states output
     for better performance.
     """
     enabled = True
 
     def run_before(self):
-        return [FusePermutesSequence, LSTMSequenceTensorIterator]
+        from extensions.middle.FusePermutesSequence import FusePermutesSequence
+        from extensions.middle.LSTMRNNSequenceToTensorIterator import LSTMToTensorIterator
+        return [FusePermutesSequence, LSTMToTensorIterator]
+
+    def run_after(self):
+        from extensions.middle.pass_separator import MiddleStart
+        from extensions.middle.RNNSequenceNormalizeToIE import RNNSequenceNormalize
+        return [MiddleStart, RNNSequenceNormalize]
 
     def pattern(self):
         return dict(
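
As their names suggest, run_after() lists passes that must already have run and run_before() lists passes that must run later; the scheduler resolves these hints into one topological order. A toy sketch of that ordering with Python's graphlib (an illustration only, not how Model Optimizer actually builds its pipeline):

    from graphlib import TopologicalSorter

    # Each pass is registered together with the passes it must run after.
    order = TopologicalSorter()
    order.add('BlockLSTMtoLSTMSequence', 'MiddleStart', 'RNNSequenceNormalize')  # from run_after()
    order.add('FusePermutesSequence', 'BlockLSTMtoLSTMSequence')                 # from run_before()
    order.add('LSTMToTensorIterator', 'BlockLSTMtoLSTMSequence')                 # from run_before()

    # One valid order: MiddleStart, RNNSequenceNormalize, BlockLSTMtoLSTMSequence,
    # FusePermutesSequence, LSTMToTensorIterator
    print(list(order.static_order()))
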
@@ -96,11 +101,11 @@ class BlockLSTMtoLSTMSequence(MiddleReplacementPattern):
         )
 
     @staticmethod
-    def replace_pattern(graph: nx.MultiDiGraph, match: dict):
+    def replace_pattern(graph: Graph, match: dict):
         time_len = match['concatenated_hidden_states'].shape[0]
         """
         Working with the concatenated_cell_states_data part first, because the IE TensorIterator primitive doesn't have
-        concatenated cell states output and if we can not collepse it, then we does not support this type of BlockLSTM
+        a concatenated cell states output, and if we cannot collapse it, we do not support this type of BlockLSTM.
 
         We simplify the sub-graph below by taking another output of BlockLSTM:
         concatenated cell states over the whole time sequence -> last cell state
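
The collapse described here and in step 2 of the class docstring amounts to taking the final time step of the concatenated states; a minimal numpy sketch with made-up shapes (layout [time, batch, hidden_size], matching the sequence_dim=0 / batch_dim=1 attributes set later in this pass):

    import numpy as np

    # Made-up sizes; the layout is [time, batch, hidden_size].
    time_len, batch, hidden_size = 5, 2, 8
    concatenated_cell_states = np.random.rand(time_len, batch, hidden_size)

    # The sub-graph cut off by this pass is equivalent to selecting the last
    # time step, which the converted sequence op exposes as a separate output.
    last_cell_state = concatenated_cell_states[time_len - 1]
    assert last_cell_state.shape == (batch, hidden_size)
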
@@ -156,8 +161,10 @@ class BlockLSTMtoLSTMSequence(MiddleReplacementPattern):
         hidden_size = node.in_node(3).shape[-1]
         weights = weights_node.value
         biases = biases_node.value
-        assert weights.shape[0] == input_size + hidden_size, "weights.shape={} input_size={} hidden_size={}".format(weights.shape, input_size, hidden_size)
-        assert weights.shape[1] == biases.shape[0] == 4 * hidden_size, "weights.shape={} biases.shape={} hidden_size={}".format(weights.shape, biases.shape, hidden_size)
+        assert weights.shape[0] == input_size + hidden_size, \
+            "weights.shape={} input_size={} hidden_size={}".format(weights.shape, input_size, hidden_size)
+        assert weights.shape[1] == biases.shape[0] == 4 * hidden_size, \
+            "weights.shape={} biases.shape={} hidden_size={}".format(weights.shape, biases.shape, hidden_size)
 
         weights = weights.reshape([
             weights.shape[0],
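
The assertions above encode the fused gate layout of TF's BlockLSTM weights; a minimal sketch of that bookkeeping with made-up sizes (not taken from any real model):

    import numpy as np

    input_size, hidden_size = 16, 32  # illustrative sizes only

    # Rows span the concatenated [x; h_prev] input, columns span the four
    # fused gates; the bias vector matches the gate columns.
    weights = np.zeros([input_size + hidden_size, 4 * hidden_size])
    biases = np.zeros([4 * hidden_size])

    assert weights.shape[0] == input_size + hidden_size
    assert weights.shape[1] == biases.shape[0] == 4 * hidden_size

    # The reshape that follows in the pass splits the gate axis so the
    # per-gate blocks become individually addressable.
    weights = weights.reshape([weights.shape[0], 4, hidden_size])
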
@@ -199,15 +206,35 @@ class BlockLSTMtoLSTMSequence(MiddleReplacementPattern):
 
         graph.add_edge(match['BlockLSTM'].id, match['gather_1_data'].id, **attrs)
 
-        match['BlockLSTM'].op = 'LSTMSequence'
-        match['BlockLSTM']['sequence_dim'] = 0  # TF reference
-        match['BlockLSTM']['batch_dim'] = 1  # TF reference
-        match['BlockLSTM']['direction'] = 'forward'  # TF reference
-        match['BlockLSTM']['hidden_size'] = match['concatenated_hidden_states'].shape[-1]
-        match['BlockLSTM']['format'] = 'tf'
+        """
+        #3 Renumber the h_init_state and c_init_state input ports to match the RNNSequence port order.
+        """
+        h_init_port = 4
+        c_init_port = 5
+        # c_init_state
+        if 4 in node.in_nodes():
+            assert c_init_port not in node.in_nodes()
+            cell_state_edge = graph.get_edge_data(node.in_node(4).id, node.id)
+            cell_state_edge[0]['in'] = c_init_port
+
+        # h_init_state
+        if 3 in node.in_nodes():
+            assert h_init_port not in node.in_nodes()
+            hidden_state_edge = graph.get_edge_data(node.in_node(3).id, node.id)
+            hidden_state_edge[0]['in'] = h_init_port
+
+        new_attrs = {'sequence_dim': 0,
+                     'batch_dim': 1,
+                     'direction': 'forward',
+                     'hidden_size': match['concatenated_hidden_states'].shape[-1],
+                     'format': 'tf',
+                     }
+
+        LSTM.update_node_stat(match['BlockLSTM'], new_attrs)
 
         """
-        Optional #3 optimization from class description following
+        Optional optimization #4 from the class description follows
         """
         data_to_mul = [n for n in match['mul'].in_nodes().values() if n.id != match['concatenated_hidden_states'].id]
         if len(data_to_mul) != 1:
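
The port renumbering in step #3 above works because MO keeps the consumer port number in the 'in' attribute of the corresponding graph edge; a stripped-down sketch of that mechanism on a bare networkx MultiDiGraph (node names are hypothetical):

    import networkx as nx

    # Hypothetical data node feeding input port 3 of an op node.
    g = nx.MultiDiGraph()
    g.add_edge('h_init_state_data', 'BlockLSTM', **{'in': 3, 'out': 0})

    # get_edge_data() returns the attribute dicts keyed by edge key (0 here),
    # so moving the input to another port is a plain attribute update,
    # just like `hidden_state_edge[0]['in'] = h_init_port` above.
    edge = g.get_edge_data('h_init_state_data', 'BlockLSTM')
    edge[0]['in'] = 4

    assert g['h_init_state_data']['BlockLSTM'][0]['in'] == 4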