Publishing 2019 R1 content
diff --git a/model-optimizer/extensions/middle/TensorIteratorMerge.py b/model-optimizer/extensions/middle/TensorIteratorMerge.py (platform/upstream/dldt.git)
index 218b129..29e9749 100644
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  limitations under the License.
 """
 
-
 from collections import deque
 from copy import deepcopy
 
-import networkx as nx
 import numpy as np
 
-from mo.graph.graph import Node
-from mo.utils.graph import sub_graph_between_nodes
-from mo.middle.replacement import MiddleReplacementPattern
 from extensions.ops.tensor_iterator import TensorIterator
+from mo.graph.graph import Node, Graph, add_opoutput
+from mo.middle.replacement import MiddleReplacementPattern
 from mo.ops.op import Op
 from mo.ops.reshape import Reshape
+from mo.utils.graph import sub_graph_between_nodes
 
 stop_nodes = ['TensorIteratorInput', 'TensorIteratorOutput', 'TensorIteratorBackEdge', 'TensorIteratorCondition']
 
+
 def op_type(graph, node_name: str):
     node = Node(graph, node_name)
     if node.has_valid('kind') and node['kind'] == 'op':
@@ -45,7 +44,7 @@ def update_inputs(graph, inputs: list, node_name: str):
             inputs.append(node_name)
 
 
-def reverse_dfs(graph: nx.MultiDiGraph, node_name: str, stop_nodes: list, inputs: list, visited: set = None):
+def reverse_dfs(graph: Graph, node_name: str, stop_nodes: list, inputs: list, visited: set = None):
     d = deque()
 
     if visited is None:
@@ -62,7 +61,8 @@ def reverse_dfs(graph: nx.MultiDiGraph, node_name: str, stop_nodes: list, inputs
                 else:
                     update_inputs(graph, inputs, in_node_name)
 
-def dfs(graph: nx.MultiDiGraph, node_name: str, stop_nodes: list, visited: set = None):
+
+def dfs(graph: Graph, node_name: str, stop_nodes: list, visited: set = None):
     d = deque()
 
     visited.add(node_name)
@@ -75,18 +75,28 @@ def dfs(graph: nx.MultiDiGraph, node_name: str, stop_nodes: list, visited: set =
                     visited.add(out_node_name)
                     d.append(out_node_name)
 
+
 def get_body(graph, inputs, outputs):
     nodes, extra_inputs = sub_graph_between_nodes(
         graph,
         inputs,
         outputs,
-        lambda node: node.soft_get('op')  == 'TensorIteratorInput'
+        lambda node: node.soft_get('op') == 'TensorIteratorInput'
     )
     nodes = list(set(nodes) - set(inputs) - set(outputs) - set(extra_inputs))
     return nodes, extra_inputs
 
 
 class TensorIteratorMerge(MiddleReplacementPattern):
+    enabled = True
+    graph_condition = [lambda graph: graph.graph['is_cyclic']]
+
+    def run_after(self):
+        return []
+
+    def run_before(self):
+        return []
+
     @staticmethod
     def pattern():
         return dict(
@@ -144,7 +154,7 @@ class TensorIteratorMerge(MiddleReplacementPattern):
         inputs = [Node(graph, node) for node in inputs]
         outputs = [Node(graph, node) for node in outputs]
         back_edges = [Node(graph, node) for node in back_edges]
-        
+
         external_inputs = [
             {
                 'external_data_id': node.in_node(1 if node.has_valid('axis') else 0),
@@ -156,7 +166,6 @@ class TensorIteratorMerge(MiddleReplacementPattern):
                 'part_size': node.part_size
             } for node in inputs]
 
-
         external_outputs = [
             {
                 'external_data_id': node.out_node(0),
@@ -168,7 +177,6 @@ class TensorIteratorMerge(MiddleReplacementPattern):
                 'part_size': node.part_size
             } for node in outputs]
 
-
         back_edges_data = [
             {
                 'from_data_id': node.in_node(1),
@@ -177,12 +185,14 @@ class TensorIteratorMerge(MiddleReplacementPattern):
             } for node in back_edges
         ]
 
-        body = nx.MultiDiGraph(name='body')
-        body.graph['layout'] = graph.graph['layout']
+        body = Graph(name='body')
+        body.graph = graph.graph
         body.add_nodes_from([(node, graph.node[node]) for node in body_nodes])
-        body.add_edges_from([(u,v,k,d)for u,v,k,d in graph.edges(data=True, keys=True) if u in body_nodes and v in body_nodes])
+        body.add_edges_from(
+            [(u, v, k, d) for u, v, k, d in graph.edges(data=True, keys=True) if u in body_nodes and v in body_nodes])
 
-        graph.remove_nodes_from(body_nodes + [match['condition'].id] + [inp.id for inp in inputs] + [out.id for out in outputs])
+        graph.remove_nodes_from(
+            body_nodes + [match['condition'].id] + [inp.id for inp in inputs] + [out.id for out in outputs])
         internal_id_count = 0
         real_back_edges = []
         for edge in back_edges_data:
@@ -192,7 +202,7 @@ class TensorIteratorMerge(MiddleReplacementPattern):
             edge['from_data_id'] = Node(body, edge['from_data_id'].id)
             edge['to_data_id'] = Node(body, edge['to_data_id'].id)
             edge['init_data_id'] = Node(body, edge['init_data_id'].id)
-            edge['from_data_id']['is_output'] = True
+            add_opoutput(body, edge['from_data_id'].id, 0, False)
 
             # Assign/reuse ids for the back-edge start; it comes from from_data_id
             assert len(edge['from_data_id'].in_nodes()) == 1
@@ -214,13 +224,14 @@ class TensorIteratorMerge(MiddleReplacementPattern):
             for _, consumer, key, edge_attrs in body.out_edges(edge['to_data_id'].id, data=True, keys=True):
 
                 real_edge = {}
-                real_edge.update(edge) # all real back_edges have the same back-edge start
+                real_edge.update(edge)  # all real back_edges have the same back-edge start
 
                 consumer = Node(body, consumer)
 
                 if real_edge['to_data_id'].in_node().has_valid('internal_layer_id'):
                     assert False
-                    real_edge['to_data_id'].out_node()['internal_layer_id'] = real_edge['to_data_id'].in_node().internal_layer_id
+                    real_edge['to_data_id'].out_node()['internal_layer_id'] = \
+                        real_edge['to_data_id'].in_node().internal_layer_id
                 elif not consumer.has_valid('internal_layer_id'):
                     consumer['internal_layer_id'] = internal_id_count
                     internal_id_count += 1
@@ -245,7 +256,7 @@ class TensorIteratorMerge(MiddleReplacementPattern):
                     real_edge['consumer'].id,
                     real_edge['consumer_key'],
                     real_edge['attrs'])
-            for real_edge in current_real_back_edges])
+                for real_edge in current_real_back_edges])
 
             body.remove_nodes_from([edge['to_data_id'].id, edge['to_data_id'].in_node().id])
             real_back_edges += current_real_back_edges
@@ -261,7 +272,8 @@ class TensorIteratorMerge(MiddleReplacementPattern):
                 # Insert squeezing resize at input port that has partitioning
                 shape = ext_inp['internal_data_id'].shape.copy()
                 assert not ext_inp['internal_data_id'].has_valid('value')
-                new_input_data = Op._create_data_node(body, ext_inp['internal_data_id'].name + '/UnsqueezedInput', dict(shape=np.insert(shape, ext_inp['axis'], 1)))
+                new_input_data = Op._create_data_node(body, ext_inp['internal_data_id'].name + '/UnsqueezedInput',
+                                                      dict(shape=np.insert(shape, ext_inp['axis'], 1)))
                 dim = shape.copy()
                 # try to do it dynamically reshapable along one of the axis
                 # it is practically useful to reshape along batch dimension, but here we cannot detect where it is
@@ -300,13 +312,14 @@ class TensorIteratorMerge(MiddleReplacementPattern):
                 # trying to make it dynamically reshapable (see related comment above for the first Reshape)
                 dim[0] = -1
                 assert not ext_out['internal_data_id'].has_valid('value')
-                reshape_op = Reshape(body, dict(name=ext_out['internal_data_id'].name + '/OutputUnsqueeze', dim=np.insert(dim, ext_out['axis'], 1)))
+                reshape_op = Reshape(body, dict(name=ext_out['internal_data_id'].name + '/OutputUnsqueeze',
+                                                dim=np.insert(dim, ext_out['axis'], 1)))
                 ext_out['internal_data_id'] = reshape_op.create_node_with_data([ext_out['internal_data_id']])
 
             # TODO: add here working with simple outputs
 
-            ext_out['internal_data_id']['is_output'] = True
-            #assert len(ext_out['internal_data_id'].out_nodes()) == 0
+            add_opoutput(body, ext_out['internal_data_id'].id, 0, False)
+            # assert len(ext_out['internal_data_id'].out_nodes()) == 0
             assert len(ext_out['internal_data_id'].in_nodes()) == 1
             if not 'internal_layer_id' in ext_out['internal_data_id'].in_node():
                 ext_out['internal_data_id'].in_node()['internal_layer_id'] = internal_id_count
@@ -322,16 +335,22 @@ class TensorIteratorMerge(MiddleReplacementPattern):
         ti_op = TensorIterator(graph, {
             'name': name + '/TensorIterator',
             'body': body,
+            'in_ports_count': len(external_inputs),
+            'out_ports_count': len(external_outputs),
 
             'input_port_map': [
-                {field: external_input[field] for field in [ 'external_port_id', 'internal_layer_id', 'internal_port_id', 'axis', 'stride', 'part_size', 'start', 'end']}
+                {field: external_input[field] for field in
+                 ['external_port_id', 'internal_layer_id', 'internal_port_id', 'axis', 'stride', 'part_size', 'start',
+                  'end']}
                 for external_input in real_external_inputs],
 
             'output_port_map': [
-                {field: external_output[field] for field in [ 'external_port_id', 'internal_layer_id', 'internal_port_id', 'axis', 'stride', 'part_size', 'start', 'end']}
+                {field: external_output[field] for field in
+                 ['external_port_id', 'internal_layer_id', 'internal_port_id', 'axis', 'stride', 'part_size', 'start',
+                  'end']}
                 for external_output in external_outputs],
             'back_edges': [
-                {field: edge[field] for field in [ 'from_layer', 'from_port', 'to_layer', 'to_port']}
+                {field: edge[field] for field in ['from_layer', 'from_port', 'to_layer', 'to_port']}
                 for edge in real_back_edges],
         })
 
@@ -346,7 +365,3 @@ class TensorIteratorMerge(MiddleReplacementPattern):
 
         for i, out in enumerate(ti_outs):
             out.in_edge()['external_port_id'] = external_outputs[i]['external_port_id']
-
-
-
-        # Create TI operation
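
Note (not part of the commit): the attributes this patch adds to TensorIteratorMerge -- enabled, graph_condition, run_after, run_before -- follow the usual Model Optimizer replacement-pattern skeleton. The sketch below is illustrative only; the class name and the matched node are assumptions, while the pattern()/replace_pattern() hooks mirror the API visible in the file above.

```python
# Minimal sketch of a middle replacement pattern, mirroring the attributes
# added to TensorIteratorMerge in this patch. Illustrative only.
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern


class ExampleCyclicTransform(MiddleReplacementPattern):  # hypothetical name
    enabled = True
    # Run only on graphs marked as cyclic (i.e. containing loop constructs).
    graph_condition = [lambda graph: graph.graph['is_cyclic']]

    def run_after(self):
        return []  # no ordering constraints in this sketch

    def run_before(self):
        return []

    @staticmethod
    def pattern():
        # Describe the subgraph to match; real patterns list several nodes
        # plus the edges that connect them.
        return dict(
            nodes=[('condition', dict(kind='op', op='TensorIteratorCondition'))],
            edges=[],
        )

    def replace_pattern(self, graph: Graph, match: dict):
        # 'match' maps the names used in pattern() to the matched Node objects.
        condition = match['condition']
        # ... rewrite the matched subgraph here ...
```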