Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / TensorIteratorOutput.py
index 695e776..07b64db 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -16,9 +16,8 @@
 
 import logging as log
 
-import networkx as nx
-
 from extensions.ops.TensorIterator_ops import TensorIteratorOutput
+from mo.graph.graph import Graph
 from mo.middle.replacement import MiddleReplacementPattern
 
 
@@ -40,6 +39,15 @@ class SmartOutputMatcher(MiddleReplacementPattern):
                                     --------> Identity -> TensorArrayWrite -> NextIteration
     """
     enabled = True
+    graph_condition = [lambda graph: graph.graph['is_cyclic']]
+
+    def run_after(self):
+        from extensions.middle.TensorIteratorInput import SmartInputMatcher
+        return [SmartInputMatcher]
+
+    def run_before(self):
+        from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
+        return [TensorIteratorMerge]
 
     @staticmethod
     def pattern():
@@ -121,7 +129,7 @@ class SmartOutputMatcher(MiddleReplacementPattern):
         )
 
     @staticmethod
-    def replace_pattern(graph: nx.MultiDiGraph, match: dict):
+    def replace_pattern(graph: Graph, match: dict):
         log.debug('================== SmartOutputFind ===============')
 
         assert match['WriteEnter_data'].value is not None
@@ -149,3 +157,132 @@ class SmartOutputMatcher(MiddleReplacementPattern):
             if node not in safe_nodes:
                 nodes_for_remove.append(match[node].id)
         graph.remove_nodes_from(nodes_for_remove)
+
+
+class SimpleOutputMatcher(MiddleReplacementPattern):
+    """
+    This pattern match partitioned outputs for TensorIterator in dynamic_rnn loops in TF.
+    The structure of pattern without Data nodes between ops. Every node is named as op attribute of this node
+    (data nodes is marked by (data)):
+        TensorArray
+        |         |
+    Flow(data)  Handle(data)------------------------------
+            |    |                                       |
+            v    v                                       v
+            Enter  ->  Merge -> Switch -> Exit -> TensorArrayRead
+                                    |
+                                    |
+                                    |
+                                    |
+                                    --------> Identity -> TensorArrayWrite -> NextIteration
+    """
+    enabled = True
+    graph_condition = [lambda graph: graph.graph['is_cyclic']]
+
+    def run_after(self):
+        return [SmartOutputMatcher]
+
+    def run_before(self):
+        from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
+        from extensions.middle.TensorIteratorCondition import LoopConditionMatcher
+        return [TensorIteratorMerge, LoopConditionMatcher]
+
+    @staticmethod
+    def pattern():
+        return dict(
+            nodes=[
+                ('TensorArray', dict(kind='op', op='TensorArrayV3')),
+                ('TensorArray_data', dict(kind='data')),
+                ('TensorArray_flow_data', dict(kind='data')),
+
+                ('TensorArrayWrite', dict(kind='op', op='TensorArrayWriteV3')),
+                ('TensorArrayWrite_data', dict(kind='data')),
+
+                ('NextIteration', dict(kind='op', op='NextIteration')),
+                ('NextIteration_data', dict(kind='data')),
+
+                ('Condition_data', dict(kind='data')),
+
+                ('Identity_2', dict(kind='op', op='Identity')),
+                ('Identity_2_data', dict(kind='data')),
+
+                ('Switch_2', dict(kind='op', op='Switch')),
+                ('Switch_2_data', dict(kind='data')),
+                ('Switch_2_data_exit', dict(kind='data')),
+
+                ('Merge_2', dict(kind='op', op='Merge')),
+                ('Merge_2_data', dict(kind='data')),
+
+                ('Enter_2', dict(kind='op', op='Enter')),
+                ('Enter_2_data', dict(kind='data')),
+
+                ('WriteEnter', dict(kind='op', op='Enter')),
+                ('WriteEnter_data', dict(kind='data')),
+
+                ('Exit', dict(kind='op', op='Exit')),
+                ('Exit_data', dict(kind='data')),
+                #
+                ('TensorArrayRead', dict(op='TensorArrayReadV3')),
+                ('TensorArrayRead_data', dict(kind='data')),
+            ],
+            edges=[
+                ('TensorArray', 'TensorArray_data'),
+                ('TensorArray', 'TensorArray_flow_data'),
+                ('TensorArray_flow_data', 'Enter_2'),
+                ('TensorArray_data', 'WriteEnter'),
+
+
+                ('Enter_2', 'Enter_2_data'),
+                ('Enter_2_data', 'Merge_2'),
+                ('Merge_2', 'Merge_2_data'),
+                ('Merge_2_data', 'Switch_2'),
+                ('Switch_2', 'Switch_2_data'),
+                ('Switch_2', 'Switch_2_data_exit'),
+                ('Switch_2_data', 'Identity_2'),
+                ('Identity_2', 'Identity_2_data'),
+
+                ('Switch_2_data_exit', 'Exit'),
+                ('Exit', 'Exit_data'),
+                ('Exit_data', 'TensorArrayRead'),
+
+                ('WriteEnter', 'WriteEnter_data'),
+                ('WriteEnter_data', 'TensorArrayWrite', {'in': 0}),
+
+                ('Identity_2_data', 'TensorArrayWrite', {'in': 3}),
+                #
+                ('TensorArrayWrite', 'TensorArrayWrite_data'),
+                ('TensorArrayWrite_data', 'NextIteration'),
+                ('Condition_data', 'Switch_2'),
+                #
+                ('TensorArray_data', 'TensorArrayRead'),
+                ('TensorArrayRead', 'TensorArrayRead_data'),
+                ('NextIteration', 'NextIteration_data'),
+                ('NextIteration_data', 'Merge_2'),
+            ],
+        )
+
+    @staticmethod
+    def replace_pattern(graph: Graph, match: dict):
+        log.debug('================== SimpleOutputFind ===============')
+        assert match['WriteEnter_data'].value is not None
+
+        index = match['TensorArrayWrite'].in_node(1)
+        value = match['TensorArrayWrite'].in_node(2)
+
+        # axis == 0 because in TensorArray we ALWAYS iterate over 0 axis, other params will be fill later (with
+        # condition)
+        output = TensorIteratorOutput(graph, dict(
+                                                  external_port_id=str(match['WriteEnter_data'].value),
+                                                  internal_layer_id=value.id,
+                                                  name=match['TensorArrayWrite'].name + '/TensorIteratorOutput_'
+                                                  ))
+        output.create_node_with_data(inputs=[value, index],
+                                     data_nodes=[match['TensorArrayRead_data']])
+
+        # Delete useless nodes
+        safe_nodes = ['TensorArrayRead_data', 'Condition_data']
+        nodes_for_remove = []
+        for node in match.keys():
+            if node not in safe_nodes:
+                nodes_for_remove.append(match[node].id)
+        graph.remove_nodes_from(nodes_for_remove)