Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / TensorIteratorLSTMToLSTMSequence.py
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  limitations under the License.
 """
 
-import networkx as nx
-
-from mo.graph.graph import copy_node
-from mo.utils.error import Error
+from extensions.middle.TF_lstm_cell_to_generic import TensorFlowLSTMtoGeneric
+from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
+from mo.graph.graph import Graph
 from mo.middle.pattern_match import find_isomorphisms
 from mo.middle.replacement import MiddleReplacementPattern
-from extensions.ops.lstm_sequence import LSTMSequence
-from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
-from extensions.middle.lstm_sequence_normalize import LSTMSequenceNormalize, permute_before_and_after
-from extensions.middle.lstm_sequence_tensor_iterator import LSTMSequenceTensorIterator
-from extensions.middle.TF_lstm_cell_to_generic import TensorFlowLSTMtoGeneric
+from mo.utils.error import Error
+from extensions.middle.ONNXRNNSequenceNormalize import ONNXRNNSequenceNormalize, permute_before_and_after
 
 
 class TensorIteratorLSTM(MiddleReplacementPattern):
@@ -40,7 +36,7 @@ class TensorIteratorLSTM(MiddleReplacementPattern):
     enabled = False
 
     def run_after(self):
-        return [TensorIteratorMerge, LSTMSequenceNormalize, LSTMSequenceTensorIterator, TensorFlowLSTMtoGeneric]
+        return [TensorIteratorMerge, ONNXRNNSequenceNormalize, TensorFlowLSTMtoGeneric]
 
     def pattern(self):
         return dict(
@@ -52,8 +48,8 @@ class TensorIteratorLSTM(MiddleReplacementPattern):
         )
 
     @staticmethod
-    def replace_pattern(graph: nx.MultiDiGraph, match: dict):
-        nodes=[
+    def replace_pattern(graph: Graph, match: dict):
+        nodes = [
             ('input_unsqueezed'),
             ('squeeze', dict(op='Reshape')),
             ('input_squeezed'),
@@ -69,7 +65,7 @@ class TensorIteratorLSTM(MiddleReplacementPattern):
             ('unsqueeze', dict(op='Reshape')),
             ('output_unsqueezed'),
         ]
-        edges=[
+        edges = [
             ('input_unsqueezed', 'squeeze'),
             ('squeeze', 'input_squeezed'),
 
@@ -101,37 +97,3 @@ class TensorIteratorLSTM(MiddleReplacementPattern):
                         'Please modify the original network '
                         'to meet the requirements.'.format(ti.soft_get('name')))
         # TODO Additional checks for port indices
-        if body_match['lstm'].has_valid('mark_supported_by_IE'):
-            body_match['lstm'].mark_supported_by_IE(body_match['lstm'])
-
-
-class CheckUnsupportedLSTMCell(MiddleReplacementPattern):
-    """ Finds all unsupported LSTMCell.
-
-        Initiates the second translation round if find any not supported LSTMCell instances.
-    """
-
-    enabled = False
-
-    def run_after(self):
-        return [TensorIteratorLSTM]
-
-    def pattern(self):
-        return dict(
-            nodes=[
-                ('lstm', dict(op='LSTMCell')),
-            ],
-            edges=[
-            ]
-        )
-
-    @staticmethod
-    def replace_pattern(graph: nx.MultiDiGraph, match: dict):
-        lstmcell = match['lstm']
-        if lstmcell.has_valid('finalize_first_round'):
-            lstmcell.finalize_first_round()
-            if not lstmcell.has_and_set('supported_by_IE'):
-                # this is a signal for the main translation pipeline to repeat the entire conversion process
-                graph.graph['repeat_conversion'] = True
-        # in case when there is no lstmcell.finalize_first_round then this cell wasn't created with the pattern
-        # (for example in ONNX) and we don't initiate the second round.