Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / LSTMRNNSequenceToTensorIterator.py
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  See the License for the specific language governing permissions and
  limitations under the License.
 """
-
-import networkx as nx
 import numpy as np
 
 from extensions.middle.FusePermutesSequence import FusePermutesSequence
-from extensions.middle.lstm_sequence_normalize import LSTMSequenceNormalize
-from extensions.middle.mxnet_lstm_sequence_normalize import MXNetLSTMSequenceNormalize
+from extensions.middle.RNNSequenceNormalizeToIE import RNNSequenceNormalize
 from extensions.ops.lstm_cell import LSTMCell
 from extensions.ops.tensor_iterator import TensorIterator
+from mo.graph.graph import Graph, add_opoutput
 from mo.middle.replacement import MiddleReplacementPattern
 from mo.ops.op import Op
 from mo.ops.reshape import Reshape
 
 
-class LSTMSequenceTensorIterator(MiddleReplacementPattern):
-    """ Converts normalized LSTMSequence op to TensorIterator.
+class LSTMToTensorIterator(MiddleReplacementPattern):
+    """ Converts normalized RNNSequence with op=LSTM to TensorIterator.
 
-        Normalized LSTMSequence means that it should be processed by
-        LSTMSequenceNormalize transform that ensures its stict form.
+        Normalized RNNSequence means that it should be processed by
+        RNNSequenceNormalize transform that ensures its strict form.
 
-        This transformation builds an altenative sub-graph for LSTMSequence
+        This transformation builds an alternative sub-graph for LSTMSequence
         with TensorIterator connected in the same way as an original LSTMSequence
         node and with internal body represented as LSTMCell op node with necessary
         squeezes and unsqueezes around.
     """
 
     enabled = True
-
+    force_clean_up = True
+    id = 'lstm_to_tensor_iterator'
+    
     def run_after(self):
-        return [LSTMSequenceNormalize, MXNetLSTMSequenceNormalize]
+        return [RNNSequenceNormalize]
 
     def run_before(self):
         return [FusePermutesSequence]
@@ -50,7 +50,7 @@ class LSTMSequenceTensorIterator(MiddleReplacementPattern):
     def pattern(self):
         return dict(
             nodes=[
-                ('lstm', dict(kind='op', op='LSTMSequence')),
+                ('lstm', dict(kind='op', op='LSTM', type='RNNSequence')),
                 ('input', dict(kind='data')),
                 ('weights', dict(kind='data')),
                 ('biases', dict(kind='data')),
@@ -66,16 +66,20 @@ class LSTMSequenceTensorIterator(MiddleReplacementPattern):
             ]
         )
 
-    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+    def replace_pattern(self, graph: Graph, match: dict):
         lstm = match['lstm']
 
         # Build TensorIterator body first
-        body = nx.MultiDiGraph(name=lstm.name + '/sub_graph', layout=graph.graph['layout'])
+        body = Graph(name=lstm.name + '/sub_graph')
+        body.graph = graph.graph
+
+        # 1. Input squeeze Reshape
         inputs = [Op._create_data_node(body, lstm.name + '/inport/' + str(inp),
                                        {'shape': lstm.in_node(inp).shape.copy(),
                                         'value': lstm.in_node(inp).value.copy()
                                         if lstm.in_node(inp).value is not None and inp in [1, 2] else None})
-                                        for inp in [0, 3, 4, 1, 2]]
+                  for inp in [0, 4, 5, 1, 2]]  # X, h_init, c_init, WR, B
+
         inputs[0].shape[lstm.sequence_dim] = 1
         reshape_dim = inputs[0].shape.copy()
         reshape_dim[lstm.batch_dim] = -1
@@ -85,11 +89,14 @@ class LSTMSequenceTensorIterator(MiddleReplacementPattern):
             dict(name=lstm.name + '/input_squeeze', internal_layer_id=0, dim=reshape_dim)
         )
         inputs[0] = input_squeeze.create_node_with_data([inputs[0]], edge_attrs=[{'internal_port_id': 0}])
-        lstm_cell_op = LSTMCell(body, dict(hidden_size=match['lstm'].hidden_size, name=lstm.name + '/LSTMCell',
-                                           internal_layer_id=1))
+
+        # 2. Output unsqueeze Reshape
         outputs = [Op._create_data_node(body, lstm.name + '/outport/' + str(out),
                                         {'shape': lstm.out_node(out).shape.copy() if out in lstm.out_nodes()
-                                        else lstm.in_node(3).shape.copy(), 'is_output': True}) for out in [0, 1]]
+                                        else lstm.in_node(4).shape.copy()}) for out in [0, 1]]
+        for out in outputs:
+            add_opoutput(body, out.id, 0, False)
+
         unsqueezed_output_shape = outputs[0].shape.copy()
         unsqueezed_output_shape[lstm.sequence_dim] = 1
         squeezed_output_shape = np.delete(unsqueezed_output_shape, lstm.sequence_dim)
@@ -97,7 +104,16 @@ class LSTMSequenceTensorIterator(MiddleReplacementPattern):
         unsqueezed_output_shape[lstm.batch_dim] = -1
         output_unsqueeze = Reshape(body, dict(name=lstm.name + 'output_unsqueeze', dim=unsqueezed_output_shape,
                                               internal_layer_id=2))
-        # TODO edge attributes should be assigned by the op itself
+
+        # 3. LSTMCell
+        lstm_cell_op = LSTMCell(body, dict(hidden_size=lstm.hidden_size,
+                                           activations=lstm.activations,
+                                           activation_alpha=lstm.activation_alpha,
+                                           activation_beta=lstm.activation_beta,
+                                           clip=lstm.clip,
+                                           input_forget=lstm.input_forget,
+                                           name=lstm.name + '/LSTMCell',
+                                           internal_layer_id=1))
         lstm_cell_node = lstm_cell_op.create_node_with_data(inputs, data_nodes=outputs,
                                                             edge_attrs=[{}, {'internal_port_id': 1},
                                                                         {'internal_port_id': 2}, {'bin': 'weights'},
@@ -106,8 +122,9 @@ class LSTMSequenceTensorIterator(MiddleReplacementPattern):
         lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
         lstm_cell_node[0] = output_unsqueeze.create_node_with_data([lstm_cell_node[0]])
         lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
-        lstm_cell_node[0]['is_output'] = True
+        add_opoutput(body, lstm_cell_node[0].id, 0, False)
 
+        # 4. Create the TensorIterator layer
         assert lstm.direction in ['forward', 'reverse']
         if lstm.direction == 'forward':
             stride = 1
@@ -123,6 +140,7 @@ class LSTMSequenceTensorIterator(MiddleReplacementPattern):
             'external_port_id': 3,
             'internal_layer_id': 2,
             'internal_port_id': 3,
+
             'axis': lstm.sequence_dim,
             'stride': stride,
             'start': start,
@@ -130,6 +148,7 @@ class LSTMSequenceTensorIterator(MiddleReplacementPattern):
             'part_size': 1,
         }]
 
+        # Adding h_state, c_state to outputs
         if len(lstm.out_nodes()) == 3:
             output_port_map.extend([{
                 'external_port_id': 4,
@@ -144,12 +163,15 @@ class LSTMSequenceTensorIterator(MiddleReplacementPattern):
         ti_op = TensorIterator(graph, {
             'name': lstm.name + '/TensorIterator',
             'body': body,
+            'in_ports_count': 3,
+            'out_ports_count': len(lstm.out_nodes()),
 
             'input_port_map': [
                 {
                     'external_port_id': 0,
                     'internal_layer_id': 0,
                     'internal_port_id': 0,
+
                     'axis': lstm.sequence_dim,
                     'stride': stride,
                     'start': start,
@@ -188,7 +210,8 @@ class LSTMSequenceTensorIterator(MiddleReplacementPattern):
 
         assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
             "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)
-        outs = ti_op.create_node_with_data([lstm.in_node(i) for i in [0, 3, 4]],
+
+        outs = ti_op.create_node_with_data([lstm.in_node(i) for i in [0, 4, 5]],  # X, h_init, c_init
                                            data_nodes=[lstm.out_node(i) for i in range(len(lstm.out_nodes()))],
                                            edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1},
                                                        {'external_port_id': 2}])