Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / permute_tensor_iterator.py
index fbd3d63..7696660 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  limitations under the License.
 """
 
-import networkx as nx
 import numpy as np
-from copy import deepcopy
 
-from mo.graph.graph import copy_node, Node, dict_includes
-from mo.utils.error import Error
-from mo.middle.passes.eliminate import remove_op_node_with_data_node
-from mo.middle.pattern_match import find_isomorphisms, find_pattern_matches
-from mo.middle.replacement import MiddleReplacementPattern
-from mo.ops.op import Op
-from extensions.ops.lstm_sequence import LSTMSequence
 from extensions.middle.FusePermutesSequence import FusePermutesSequence
+from extensions.middle.LSTMRNNSequenceToTensorIterator import LSTMToTensorIterator
+from extensions.middle.ONNXRNNSequenceNormalize import ONNXRNNSequenceNormalize
 from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
-from extensions.middle.lstm_sequence_normalize import LSTMSequenceNormalize, permute_before_and_after
-from extensions.middle.lstm_sequence_tensor_iterator import LSTMSequenceTensorIterator
-from extensions.middle.decompose_bi_lstm import DecomposeBiLSTM
+from mo.graph.graph import dict_includes, Graph
+from mo.middle.passes.eliminate import remove_op_node_with_data_node
+from mo.middle.pattern_match import find_isomorphisms
+from mo.middle.replacement import MiddleReplacementPattern
 
 
 class PermuteTensorIteratorLSTM(MiddleReplacementPattern):
-    ''' Fuses Permute(1,0,2) --> TI --> Permute(1,0,2) pattern to a single TI with changed axis.
+    """ Fuses Permute(1,0,2) --> TI --> Permute(1,0,2) pattern to a single TI with changed axis.
 
         WARNING This transformation is limited to support of very special case of TI but
         code doesn't check all the cases.
-    '''
+    """
 
     enabled = True
 
     def run_after(self):
-        return [TensorIteratorMerge, LSTMSequenceNormalize, LSTMSequenceTensorIterator, FusePermutesSequence, DecomposeBiLSTM]
+        return [TensorIteratorMerge, ONNXRNNSequenceNormalize, LSTMToTensorIterator, FusePermutesSequence]
+
+
+    def run_before(self):
+        return []
 
     def pattern(self):
         return dict(
@@ -63,21 +61,21 @@ class PermuteTensorIteratorLSTM(MiddleReplacementPattern):
                 ('input', 'direct_permute'),
                 ('direct_permute', 'input_permuted'),
 
-                ('input_permuted', 'ti', {'in': 0}),   # affected by permute
+                ('input_permuted', 'ti', {'in': 0}),  # affected by permute
                 ('init_hidden', 'ti', {'in': 1}),
                 ('init_cell', 'ti', {'in': 2}),
-                ('ti', 'output_permuted', {'out': 0}), # affected by permute
+                ('ti', 'output_permuted', {'out': 0}),  # affected by permute
 
                 ('output_permuted', 'inverse_permute'),
                 ('inverse_permute', 'output'),
             ]
         )
 
-    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+    def replace_pattern(self, graph: Graph, match: dict):
 
         # This transformation works if and only if a body of TI
         # matches the following topology (Reshape -> LSTMCell -> Reshape)
-        nodes=[
+        nodes = [
             ('input_unsqueezed'),
             ('squeeze', dict(op='Reshape')),
             ('input_squeezed'),
@@ -92,8 +90,16 @@ class PermuteTensorIteratorLSTM(MiddleReplacementPattern):
             ('output_cell'),
             ('unsqueeze', dict(op='Reshape')),
             ('output_unsqueezed'),
+
+            ('const_w', dict(op='Const')),
+            ('const_b', dict(op='Const')),
+
+            ('op_output', dict(op='OpOutput')),
+            ('op_output_1', dict(op='OpOutput')),
+            ('op_output_2', dict(op='OpOutput'))
+
         ]
-        edges=[
+        edges = [
             ('input_unsqueezed', 'squeeze'),
             ('squeeze', 'input_squeezed'),
 
@@ -103,11 +109,19 @@ class PermuteTensorIteratorLSTM(MiddleReplacementPattern):
             ('weights', 'lstm', {'in': 3}),
             ('biases', 'lstm', {'in': 4}),
 
+            ('const_w', 'weights'),
+            ('const_b', 'biases'),
+
             ('lstm', 'output_hidden', {'out': 0}),
             ('lstm', 'output_cell', {'out': 1}),
 
             ('output_hidden', 'unsqueeze'),
             ('unsqueeze', 'output_unsqueezed'),
+
+            ('output_unsqueezed', 'op_output'),
+            ('output_hidden', 'op_output_1'),
+            ('output_cell', 'op_output_2'),
+
         ]
         ti = match['ti']
         isomorphisms = find_isomorphisms(ti.body, nodes, edges)
@@ -126,7 +140,6 @@ class PermuteTensorIteratorLSTM(MiddleReplacementPattern):
         if not inverse_permute.has_valid('order') or not np.array_equal(inverse_permute.order, permute_order):
             return
 
-
         def find_ports(port_map: list, attrs: dict):
             """ Find all ports in a given port map with specified attributes """
             result = []