Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / TensorIteratorCondition.py
index 70b169f..435a686 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 
 import logging as log
 
-import networkx as nx
-
 from extensions.ops.TensorIterator_ops import TensorIteratorCondition
+from mo.graph.graph import Graph
 from mo.middle.replacement import MiddleReplacementPattern
+import numpy as np
 
 
 class LoopConditionMatcher(MiddleReplacementPattern):
@@ -46,6 +46,14 @@ Shape -> StridedSlice -> Enter -|    LogicalAnd --> LoopCond (data)
                                                                    Const----
     """
     enabled = True
+    graph_condition = [lambda graph: graph.graph['is_cyclic']]
+
+    def run_after(self):
+        return []
+
+    def run_before(self):
+        from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
+        return [TensorIteratorMerge]
 
     @staticmethod
     def pattern():
@@ -69,7 +77,6 @@ Shape -> StridedSlice -> Enter -|    LogicalAnd --> LoopCond (data)
 
                 ('Enter_2_less', dict(kind='op', op='Enter')),
                 ('Enter_2_less_data', dict(kind='data')),
-                ('minimum', dict(kind='op', op='Minimum')),
                 ('minimum_data', dict(kind='data')),
 
                 ('and', dict(kind='op', op='LogicalAnd')),
@@ -78,9 +85,9 @@ Shape -> StridedSlice -> Enter -|    LogicalAnd --> LoopCond (data)
                 ('loop_cond_data', dict(kind='data')),
 
                 ('init_1', dict(kind='op', op='Const')),
-                ('init_1_data',  dict(kind='data')),
+                ('init_1_data', dict(kind='data')),
                 ('Enter_1', dict(kind='op', op='Enter')),
-                ('Enter_1_data',  dict(kind='data')),
+                ('Enter_1_data', dict(kind='data')),
 
                 ('init_2', dict(kind='op', op='Const')),
                 ('init_2_data', dict(kind='data')),
@@ -92,7 +99,7 @@ Shape -> StridedSlice -> Enter -|    LogicalAnd --> LoopCond (data)
                 ('Identity_1', dict(kind='op', op='Identity')),
                 ('Identity_1_data', dict(kind='data')),
                 ('add_1', dict(kind='op', op='Add')),
-                ('add_1_y',  dict(kind='op', op='Const')),
+                ('add_1_y', dict(kind='op', op='Const')),
                 ('add_1_y_data', dict(kind='data')),
                 ('add_1_data', dict(kind='data')),
                 ('NextIteration_1', dict(kind='op', op='NextIteration')),
@@ -111,7 +118,6 @@ Shape -> StridedSlice -> Enter -|    LogicalAnd --> LoopCond (data)
             edges=[
                 ('Strided_slice', 'Strided_slice_data'),
                 ('Strided_slice_data', 'Enter_1_less'),
-                ('Strided_slice_data', 'minimum'),
                 ('Enter_1_less', 'Enter_1_less_data'),
                 ('Enter_1_less_data', 'Less_1'),
                 ('Less_1', 'Less_1_data'),
@@ -150,7 +156,6 @@ Shape -> StridedSlice -> Enter -|    LogicalAnd --> LoopCond (data)
                 ('add_2', 'add_2_data'),
                 ('add_2_data', 'NextIteration_2'),
 
-                ('minimum', 'minimum_data'),
                 ('minimum_data', 'Enter_2_less'),
                 ('Enter_2_less', 'Enter_2_less_data'),
                 ('Enter_2_less_data', 'Less_2'),
@@ -168,26 +173,35 @@ Shape -> StridedSlice -> Enter -|    LogicalAnd --> LoopCond (data)
         )
 
     @staticmethod
-    def replace_pattern(graph: nx.MultiDiGraph, match: dict):
+    def looking_for_iteration_counter(graph: Graph, match: dict):
+        types = ['TensorIteratorInput', 'TensorIteratorOutput']
+        candidates = np.array([match['Identity_1_data'], match['Identity_2_data']])
+        results = np.array([False for i in range(len(candidates))])
+        for i, candidat in enumerate(candidates):
+            for node in candidat.out_nodes():
+                if node['op'] in types:
+                    results[i] = True
+        assert not np.all(results)
+        assert sum(results) == 1
+        return candidates[results == True][0]
+
+    def replace_pattern(self, graph: Graph, match: dict):
         log.debug('================== ConditionFind ===============')
-        max_node = match['minimum'].in_node(1).in_node()
-        assert max_node['kind'] == 'op' and max_node['op'] == 'Maximum'
-
-        #init_1
+        # init_1
         init_1 = match['init_1_data'].value
         assert init_1 is not None
         init_1 = int(init_1)
 
-        #init_2
+        # init_2
         init_2 = match['init_2_data'].value
         assert init_2 is not None
         init_2 = int(init_2)
 
-        #step_1
+        # step_1
         assert match['add_1_y_data'].value is not None
         step_1 = int(match['add_1_y_data'].value)
 
-        #step_2
+        # step_2
         assert match['add_2_y_data'].value is not None
         step_2 = int(match['add_2_y_data'].value)
 
@@ -195,14 +209,17 @@ Shape -> StridedSlice -> Enter -|    LogicalAnd --> LoopCond (data)
         match['Identity_2_data'].value = None
 
         # Create condition node and delete all useless nodes from condition pattern
-        condition_attrs = dict(time=dict(init=init_2, step=step_2), iter=dict(init=init_1, step=step_1), \
+        loop_condiiton = match['loop_cond_data']
+        iterator_data = self.looking_for_iteration_counter(graph, match)
+
+        condition_attrs = dict(time=dict(init=init_2, step=step_2), iter=dict(init=init_1, step=step_1),
                                name=match['loop_cond'].name + '/TensorIteratorCondition_')
         condition = TensorIteratorCondition(graph, attrs=condition_attrs)
         condition.create_node_with_data(inputs=[match['Strided_slice_data'], match['minimum_data']],
-                                        data_nodes=[match['loop_cond_data'], match['Identity_2_data']])
+                                        data_nodes=[loop_condiiton, iterator_data])
 
         # Delete useless nodes
-        safe_nodes = ['loop_cond_data', 'Identity_2_data', 'Strided_slice', 'Strided_slice_data',
+        safe_nodes = ['loop_cond_data', 'Identity_1_data', 'Identity_2_data',  'Strided_slice', 'Strided_slice_data',
                       'minimum', 'minimum_data']
         nodes_for_remove = []
         for node in match.keys():
@@ -211,7 +228,17 @@ Shape -> StridedSlice -> Enter -|    LogicalAnd --> LoopCond (data)
         graph.remove_nodes_from(nodes_for_remove)
 
 
-class SimpleConditionMather(MiddleReplacementPattern):
+class SimpleConditionMatcher(MiddleReplacementPattern):
+    enabled = True
+    graph_condition = [lambda graph: graph.graph['is_cyclic']]
+
+    def run_after(self):
+        return [LoopConditionMatcher]
+
+    def run_before(self):
+        from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
+        return [TensorIteratorMerge]
+
     @staticmethod
     def pattern():
         log.debug('+++++++++++++++ SimpleConditionMatching ++++++++++++++++')
@@ -231,17 +258,16 @@ class SimpleConditionMather(MiddleReplacementPattern):
                 ('loop_cond_data', dict(kind='data')),
 
                 ('init_1', dict(kind='op', op='Const')),
-                ('init_1_data',  dict(kind='data')),
+                ('init_1_data', dict(kind='data')),
                 ('Enter_1', dict(kind='op', op='Enter')),
-                ('Enter_1_data',  dict(kind='data')),
-
+                ('Enter_1_data', dict(kind='data')),
 
                 ('Switch_1', dict(kind='op', op='Switch')),
                 ('Switch_1_data', dict(kind='data')),
                 ('Identity_1', dict(kind='op', op='Identity')),
                 ('Identity_1_data', dict(kind='data')),
                 ('add_1', dict(kind='op', op='Add')),
-                ('add_1_y',  dict(kind='op', op='Const')),
+                ('add_1_y', dict(kind='op', op='Const')),
                 ('add_1_y_data', dict(kind='data')),
                 ('add_1_data', dict(kind='data')),
                 ('NextIteration_1', dict(kind='op', op='NextIteration')),
@@ -278,7 +304,7 @@ class SimpleConditionMather(MiddleReplacementPattern):
         )
 
     @staticmethod
-    def replace_pattern(graph: nx.MultiDiGraph, match: dict):
+    def replace_pattern(graph: Graph, match: dict):
         log.debug('================== SimpleConditionFind ===============')
         # init_1
         init_1 = match['init_1_data'].value
@@ -292,7 +318,7 @@ class SimpleConditionMather(MiddleReplacementPattern):
         match['loop_cond_data'].value = None
 
         # Create condition node and delete all useless nodes from condition pattern
-        condition_attrs = dict(iter=dict(init=init_1, step=step_1), \
+        condition_attrs = dict(iter=dict(init=init_1, step=step_1),
                                name=match['loop_cond'].name + '/TensorIteratorCondition_')
         condition = TensorIteratorCondition(graph, attrs=condition_attrs)
         condition.create_node_with_data(inputs=[match['Strided_slice_data']],
@@ -304,4 +330,4 @@ class SimpleConditionMather(MiddleReplacementPattern):
         for node in match.keys():
             if node not in safe_nodes:
                 nodes_for_remove.append(match[node].id)
-        graph.remove_nodes_from(nodes_for_remove)
\ No newline at end of file
+        graph.remove_nodes_from(nodes_for_remove)