Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / reverse_tensor_iterator.py
index 7cd529b..62f5133 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  limitations under the License.
 """
 
-import networkx as nx
-
-from mo.middle.replacement import MiddleReplacementPattern
-from extensions.ops.lstm_sequence import LSTMSequence
 from extensions.middle.FusePermutesSequence import FusePermutesSequence
-from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
-from extensions.middle.lstm_sequence_normalize import LSTMSequenceNormalize
-from extensions.middle.lstm_sequence_tensor_iterator import LSTMSequenceTensorIterator
+from extensions.middle.ONNXRNNSequenceNormalize import ONNXRNNSequenceNormalize
 from extensions.middle.permute_tensor_iterator import PermuteTensorIteratorLSTM
+from mo.graph.graph import Graph
 from mo.middle.passes.eliminate import remove_op_node_with_data_node
 from mo.middle.replacement import MiddleReplacementPattern
 
@@ -38,13 +33,16 @@ class ReverseTensorIteratorLSTM(MiddleReplacementPattern):
 
     def run_after(self):
         return [
-            TensorIteratorMerge,
-            LSTMSequenceNormalize,
-            LSTMSequenceTensorIterator,
+            ONNXRNNSequenceNormalize,
+
             FusePermutesSequence,
             PermuteTensorIteratorLSTM,
         ]
 
+    def run_before(self):
+        from extensions.middle.pass_separator import MiddleFinish
+        return [MiddleFinish]
+
     def pattern(self):
         return dict(
             nodes=[
@@ -52,7 +50,6 @@ class ReverseTensorIteratorLSTM(MiddleReplacementPattern):
                 ('direct_reverse', dict(op='ReverseSequence')),
                 ('input_reversed'),
                 ('init_hidden'),
-                ('init_cell'),
 
                 ('ti', dict(kind='op', op='TensorIterator')),
 
@@ -66,7 +63,6 @@ class ReverseTensorIteratorLSTM(MiddleReplacementPattern):
 
                 ('input_reversed', 'ti', {'in': 0}),
                 ('init_hidden', 'ti', {'in': 1}),
-                ('init_cell', 'ti', {'in': 2}),
                 ('ti', 'output_reversed', {'out': 0}),
 
                 ('output_reversed', 'inverse_reverse', {'in': 0}),
@@ -74,21 +70,21 @@ class ReverseTensorIteratorLSTM(MiddleReplacementPattern):
             ]
         )
 
-    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+    def replace_pattern(self, graph: Graph, match: dict):
         ti = match['ti']
         direct_reverse = match['direct_reverse']
         inverse_reverse = match['inverse_reverse']
 
-        assert direct_reverse.seq_dim == inverse_reverse.seq_dim
-        assert direct_reverse.batch_dim is None and inverse_reverse.batch_dim is None or \
-            direct_reverse.batch_dim == inverse_reverse.batch_dim
+        assert direct_reverse.seq_axis == inverse_reverse.seq_axis
+        assert direct_reverse.batch_axis is None and inverse_reverse.batch_axis is None or \
+               direct_reverse.batch_axis == inverse_reverse.batch_axis
 
         # Modify stride in TI
         for port_map in [ti.input_port_map, ti.output_port_map]:
             for port in port_map:
                 if 'axis' in port and port['axis'] is not None and 'external_port_id' in port:
-                    assert port['axis'] == direct_reverse.seq_dim, \
-                        'axis == {} != {} == direct_reverse.seq_dim'.format(port['axis'], direct_reverse.seq_dim)
+                    assert port['axis'] == direct_reverse.seq_axis, \
+                        'axis == {} != {} == direct_reverse.seq_dim'.format(port['axis'], direct_reverse.seq_axis)
                     if 'stride' not in port or port['stride'] is None:
                         port['stride'] = 1
                     assert port['stride'] in [-1, 1]