Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / eliminate.py
index 2878add..d131875 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -17,18 +17,21 @@ import logging as log
 from collections import deque
 
 import networkx as nx
+import numpy as np
 
-from mo.graph.graph import Node, create_edge
+from mo.graph.graph import Node, Graph
 from mo.middle.pattern_match import apply_pattern
+from mo.utils.error import Error
 from mo.utils.graph import bfs_search, pseudo_topological_sort
 
 
-def get_nodes_with_attributes(graph: nx.MultiDiGraph, **attrs: dict):
+# TODO: dep warning
+def get_nodes_with_attributes(graph: Graph, **attrs: dict):
     node_attrs = graph.nodes(data=True)
     return [n for n, d in node_attrs if all(a in d.items() for a in attrs.items())]
 
 
-def reverse_dfs(graph: nx.MultiDiGraph, node_name: str, update_func: callable, visited: set = None):
+def reverse_dfs(graph: Graph, node_name: str, update_func: callable, visited: set = None):
     d = deque()
 
     if visited is None:
@@ -44,23 +47,23 @@ def reverse_dfs(graph: nx.MultiDiGraph, node_name: str, update_func: callable, v
                 d.append(in_node_name)
 
 
-def mark_input_nodes(graph: nx.MultiDiGraph, node_name: str, key: str, value):
+def mark_input_nodes(graph: Graph, node_name: str, key: str, value):
     for input, _ in graph.in_edges(node_name):
         graph.node[input][key] = value
 
 
-def mark_output_nodes(graph: nx.MultiDiGraph, node_name: str, key: str, value):
+def mark_output_nodes(graph: Graph, node_name: str, key: str, value):
     for output, _ in graph.out_edges(node_name):
         graph.node[output][key] = value
 
 
-def mark_output_reachable_nodes(graph: nx.MultiDiGraph):
+def mark_output_reachable_nodes(graph: Graph):
     """
     Mark nodes whether they are outputs reachable or not. The node is considered output reachable if it is connected to
-    one of the nodes that has attribute is_output=True.
+    one of the nodes that has attribute op=OpOutput.
     """
     nx.set_node_attributes(G=graph, name='is_output_reachable', values=False)
-    outputs = get_nodes_with_attributes(graph, is_output=True)
+    outputs = graph.get_nodes_with_attributes(op='OpOutput')
     log.debug('The following nodes are seeded as output reachable:\n{}'.format('\n'.join(sorted(map(str, outputs)))))
     nx.set_node_attributes(G=graph, name='is_output_reachable', values={n: True for n in outputs})
     visited = set()
@@ -69,7 +72,7 @@ def mark_output_reachable_nodes(graph: nx.MultiDiGraph):
                     lambda graph, node_name: mark_input_nodes(graph, node_name, 'is_output_reachable', True), visited)
 
 
-def mark_undead_nodes(graph: nx.MultiDiGraph, undead_types: list):
+def mark_undead_nodes(graph: Graph, undead_types: list):
     """
     Mark output nodes and nodes of the specific type as undead, meaning that they should survive the dead nodes
     elimination phase. Then mark all children nodes of the undead nodes (except children of inputs) as undead.
@@ -80,29 +83,30 @@ def mark_undead_nodes(graph: nx.MultiDiGraph, undead_types: list):
     nx.set_node_attributes(G=graph, name='is_undead', values=False)
 
     # mark output nodes as undead
-    outputs = get_nodes_with_attributes(graph, is_output=True)
+    outputs = graph.get_nodes_with_attributes(op='OpOutput')
     nx.set_node_attributes(G=graph, name='is_undead', values={n: True for n in outputs})
 
     # mark specifically defined with node type set of nodes
     for type in undead_types:
-        node_of_specific_type = get_nodes_with_attributes(graph, type=type)
+        node_of_specific_type = graph.get_nodes_with_attributes(type=type)
         nx.set_node_attributes(G=graph, name='is_undead', values={n: True for n in node_of_specific_type})
 
-    undead_nodes = get_nodes_with_attributes(graph, is_undead=True)
+    undead_nodes = graph.get_nodes_with_attributes(is_undead=True)
     # propagate 'undead' attribute to children nodes of undead nodes if the node produces constant value
     for node_name in bfs_search(graph, undead_nodes):
         if graph.node[node_name]['is_undead']:
             for _, dst_node_name in graph.out_edges(node_name):
                 node_attrs = graph.node[dst_node_name]
-                if 'kind' in node_attrs and node_attrs['kind'] == 'data' and node_attrs['value'] is not None:
+                if 'kind' in node_attrs and (
+                        node_attrs['kind'] == 'data' and node_attrs['value'] is not None or node_attrs['kind'] == 'op'):
                     graph.node[dst_node_name]['is_undead'] = True
 
     # mark input nodes as undead
-    inputs = get_nodes_with_attributes(graph, is_input=True)
+    inputs = graph.get_nodes_with_attributes(is_input=True)
     nx.set_node_attributes(G=graph, name='is_undead', values={n: True for n in inputs})
 
 
-def mark_const_producer_nodes(graph: nx.MultiDiGraph):
+def mark_const_producer_nodes(graph: Graph):
     """
     Mark nodes that produce constant values.
     :param graph: graph to operate on.
@@ -122,7 +126,7 @@ def mark_const_producer_nodes(graph: nx.MultiDiGraph):
                 graph.node[input]['is_const_producer'] = False
 
 
-def eliminate_dead_nodes(graph: nx.MultiDiGraph):
+def eliminate_dead_nodes(graph: Graph):
     nodes_to_remove = set()
     for node_name, node_attrs in graph.nodes(data=True):
         if not node_attrs['is_output_reachable'] or (node_attrs['is_const_producer'] and not node_attrs['is_undead']):
@@ -131,25 +135,69 @@ def eliminate_dead_nodes(graph: nx.MultiDiGraph):
     graph.remove_nodes_from(nodes_to_remove)
 
 
-def graph_clean_up(graph: nx.MultiDiGraph, undead_node_types: list = []):
+def add_constant_operations(graph: Graph):
+    data_nodes = graph.get_data_nodes(has_value=True)
+    for node in data_nodes:
+        # If data node has no producers we create Const operation
+        if len(node.in_nodes()) == 0 and len(node.out_nodes()) != 0:
+            # It's necessary to import here due to cycle dependencies
+            from mo.ops.const import Const
+            Const(graph, dict(value=node.value, shape=np.array(node.value.shape))).create_node_with_data(data_nodes=node)
+
+
+def remove_const_ops(graph: Graph):
+    ops = [node for node in graph.get_op_nodes() if node.soft_get('type') == 'Const']
+    for node in ops:
+        graph.remove_edge(node.id, node.out_node().id)
+        graph.remove_node(node.id)
+
+
+def shape_inference(graph: Graph):
+    nodes = pseudo_topological_sort(graph)
+    for node in nodes:
+        node = Node(graph, node)
+        if node.has_and_set('need_shape_inference'):
+            old_out_shapes = [port.data.get_shape() for port in node.out_ports().values()]
+            node.infer(node)
+            new_out_shapes = [port.data.get_shape() for port in node.out_ports().values()]
+            for shape1, shape2 in zip(old_out_shapes, new_out_shapes):
+                if shape1 is not None and not np.array_equal(shape1, shape2):
+                    raise Error("After partial shape inference were found shape collision for node {} (old shape: {}, new shape: {})".format(node.name, shape1, shape2))
+            node.need_shape_inference = False
+
+
+def graph_clean_up(graph: Graph, undead_node_types: list = None):
+    if undead_node_types is None:
+        undead_node_types = []
+
+    if 'Shape' in undead_node_types and not graph.graph['cmd_params'].keep_shape_ops:
+        undead_node_types.remove('Shape')
+
     mark_output_reachable_nodes(graph)
     mark_undead_nodes(graph, undead_node_types)
     mark_const_producer_nodes(graph)
     eliminate_dead_nodes(graph)
+    # Add Const op for constant data nodes
+    add_constant_operations(graph)
+    shape_inference(graph)
+
 
+def graph_clean_up_tf(graph: Graph):
+    graph_clean_up(graph, ['TFCustomSubgraphCall', 'Shape'])
 
-def graph_clean_up_tf(graph: nx.MultiDiGraph):
-    graph_clean_up(graph, ['TFCustomSubgraphCall'])
 
+def graph_clean_up_onnx(graph: Graph):
+    graph_clean_up(graph, ['Shape'])
 
-def remove_identity_action(graph: nx.MultiDiGraph, matches: dict):
+
+def remove_identity_action(graph: Graph, matches: dict):
     remove_op_node_with_data_node(graph, matches['identity'])
 
 
 # TODO: unit tests
-def merge_data_nodes(graph: nx.MultiDiGraph, survived: Node, removed: Node):
-    if survived.has_and_set('is_output'):
-        graph.node[removed.id].update({'is_output': True})
+def merge_data_nodes(graph: Graph, survived: Node, removed: Node):
+    if survived.has_and_set('op') and survived.op == 'OpOutput':
+        graph.node[removed.id].update({'op': 'OpOutput'})
 
     for u, v, d in list(graph.in_edges(removed.id, data=True)):
         graph.add_edges_from([(u, survived.id, d)])
@@ -172,7 +220,7 @@ def merge_data_nodes(graph: nx.MultiDiGraph, survived: Node, removed: Node):
 
 
 # TODO: unit tests
-def remove_op_node_with_data_node(graph: nx.MultiDiGraph, node_to_remove: Node):
+def remove_op_node_with_data_node(graph: Graph, node_to_remove: Node):
     assert node_to_remove.kind == 'op'
     input_data_node = node_to_remove.in_node()
     output_node = [v for _, v in graph.out_edges(node_to_remove.id)]
@@ -190,7 +238,7 @@ def remove_op_node_with_data_node(graph: nx.MultiDiGraph, node_to_remove: Node):
     graph.remove_nodes_from([node_to_remove.id, input_data_node.id])
 
 
-def remove_op_nodes(graph: nx.MultiDiGraph, attrs: dict):
+def remove_op_nodes(graph: Graph, attrs: dict):
     op_attrs = {'kind': 'op'}
     op_attrs.update(attrs)
     apply_pattern(
@@ -201,7 +249,7 @@ def remove_op_nodes(graph: nx.MultiDiGraph, attrs: dict):
     )
 
 
-def remove_edges_for_nodes(graph: nx.MultiDiGraph, node_attrs: dict, edge_attrs: dict):
+def remove_edges_for_nodes(graph: Graph, node_attrs: dict, edge_attrs: dict):
     for node in graph.nodes():
         node = Node(graph, node)
         if all([node.has(attr) and node[attr] == node_attrs[attr] for attr in node_attrs]):
@@ -212,21 +260,3 @@ def remove_edges_for_nodes(graph: nx.MultiDiGraph, node_attrs: dict, edge_attrs:
                     graph.remove_edge(src_node.id, node.id)
 
 
-def remove_useless_split_action(graph: nx.MultiDiGraph, matches: dict):
-    split_node = matches['split']
-    input = split_node.in_node(1)
-    output = split_node.out_node()
-    graph.remove_edge(input.id, split_node.id)
-
-    for u, v, d in list(graph.out_edges(output.id, data=True)):
-        graph.add_edges_from([(input.id, v, d)])
-        graph.remove_edge(u, v)
-
-
-def remove_useless_split(graph: nx.MultiDiGraph):
-    apply_pattern(
-        graph,
-        nodes=[('split', {'kind': 'op', 'op': 'Split', 'num_split': 1})],
-        edges=[],
-        action=remove_useless_split_action
-    )