Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / TensorIteratorInput.py
index 65cdb40..93d63fa 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 """
 
 import logging as log
-import networkx as nx
+
 import numpy as np
 
 from extensions.ops.TensorIterator_ops import TensorIteratorInput
+from mo.graph.graph import Graph
 from mo.middle.replacement import MiddleReplacementPattern
 
 
@@ -38,7 +39,16 @@ class SmartInputMatcher(MiddleReplacementPattern):
         |__________________________________________________|
     """
 
-    enabled = False  # called from mo.pipeline.tf directly
+    enabled = True
+    graph_condition = [lambda graph: graph.graph['is_cyclic']]
+
+    def run_after(self):
+        from extensions.middle.TensorIterator_utils import DeleteSelect
+        return [DeleteSelect]
+
+    def run_before(self):
+        from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
+        return [TensorIteratorMerge]
 
     @staticmethod
     def pattern():
@@ -115,7 +125,7 @@ class SmartInputMatcher(MiddleReplacementPattern):
         )
 
     @staticmethod
-    def replace_pattern(graph: nx.MultiDiGraph, match: dict):
+    def replace_pattern(graph: Graph, match: dict):
         log.debug('================== SmartInputFind ===============')
 
         assert match['Enter_data'].value is not None
@@ -141,12 +151,12 @@ class SmartInputMatcher(MiddleReplacementPattern):
         # axis == 0 because in TensorArray we ALWAYS iterate over 0 axis, other params will be fill later (with
         # condition)
         input_node = TensorIteratorInput(graph, dict(axis=0, start=start, stride=None, part_size=None,
-                                                external_port_id=str(match['Enter_data'].value),
-                                                internal_layer_id=match['TensorArrayRead_data'].id,
-                                                name=match['TensorArrayRead'].name + '/TensorIteratorInput_'
-                                                ))
+                                                     external_port_id=str(match['Enter_data'].value),
+                                                     internal_layer_id=match['TensorArrayRead_data'].id,
+                                                     name=match['TensorArrayRead'].name + '/TensorIteratorInput_'
+                                                     ))
         input_node.create_node_with_data(inputs=[ta_size_data, value, match['Condition_data']],
-                                    data_nodes=[match['TensorArrayRead_data']])
+                                         data_nodes=[match['TensorArrayRead_data']])
         # Delete useless nodes
         safe_nodes = ['TensorArrayRead_data', 'Condition', 'Condition_data']
 
@@ -158,12 +168,21 @@ class SmartInputMatcher(MiddleReplacementPattern):
 
 
 class SimpleInputMatcher(MiddleReplacementPattern):
+    enabled = True
+    graph_condition = [lambda graph: graph.graph['is_cyclic']]
 
-    enabled = False  # called from mo.pipeline.tf directly
+    def run_after(self):
+        from extensions.middle.DeleteNotExecutable import DeleteNotExecutable
+        return [DeleteNotExecutable]
+
+    def run_before(self):
+        from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
+        return [TensorIteratorMerge]
 
     """
     This pattern match simple inputs (without partitions) in while loops in TF (this inputs are set by Enter nodes).
     """
+
     @staticmethod
     def pattern():
         return dict(
@@ -175,13 +194,13 @@ class SimpleInputMatcher(MiddleReplacementPattern):
         )
 
     @staticmethod
-    def replace_pattern(graph: nx.MultiDiGraph, match: dict):
+    def replace_pattern(graph: Graph, match: dict):
         log.debug('================== SimpletInputFind ===============')
 
         input_node = TensorIteratorInput(graph, dict(external_port_id=None,
-                                         internal_layer_id=None,
-                                         name=match['Enter'].name + '/TensorIteratorInput_'
-                                         ))
+                                                     internal_layer_id=None,
+                                                     name=match['Enter'].name + '/TensorIteratorInput_'
+                                                     ))
         input_node.create_node_with_data(inputs=[match['Enter'].in_node()], data_nodes=[match['Enter'].out_node()])
 
         # Delete useless nodes
@@ -189,8 +208,15 @@ class SimpleInputMatcher(MiddleReplacementPattern):
 
 
 class BackEdgeSimpleInputMatcher(MiddleReplacementPattern):
+    enabled = True
+    graph_condition = [lambda graph: graph.graph['is_cyclic']]
 
-    enabled = False  # called from mo.pipeline.tf directly
+    def run_after(self):
+        return [SimpleInputMatcher]
+
+    def run_before(self):
+        from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
+        return [TensorIteratorMerge]
 
     @staticmethod
     def pattern():
@@ -203,7 +229,7 @@ class BackEdgeSimpleInputMatcher(MiddleReplacementPattern):
         )
 
     @staticmethod
-    def replace_pattern(graph: nx.MultiDiGraph, match: dict):
+    def replace_pattern(graph: Graph, match: dict):
         log.debug('================== SimpleBackEdgeInputFind ===============')
 
         assert len(match['BackEdge'].in_nodes()) == 3
@@ -212,11 +238,18 @@ class BackEdgeSimpleInputMatcher(MiddleReplacementPattern):
         cycle_input = match['BackEdge'].in_node(1)
 
         # We need to create new TensorItertorInput node only if this node doesn't exist already.
-        if len(init_input.in_nodes()) == 0:
+        if len(init_input.in_nodes()) == 0 or\
+           (len(init_input.in_nodes()) == 1 and init_input.has_valid('value')):
+
             input_node = TensorIteratorInput(graph, dict(external_port_id=None,
-                                             internal_layer_id=None,
-                                             name=match['BackEdge'].name + '/TensorIteratorInput_'
-                                            ))
+                                                         internal_layer_id=None,
+                                                         name=match['BackEdge'].name + '/TensorIteratorInput_'
+                                                         ))
+
+            # In case if data node has Constant producer
+            if len(init_input.in_nodes()) == 1:
+                graph.remove_edge(init_input.in_node(0).id, init_input.id)
+
             input_data_node = input_node.create_node_with_data(inputs=[init_input])
             input_data_node.shape = np.array(init_input.shape, dtype=np.int64)
             graph.remove_edges_from([(init_input.id, match['BackEdge'].id)])