Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / extractor.py
index 6ba1ea4..0b68294 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -22,8 +22,8 @@ import networkx as nx
 import numpy as np
 
 from mo.front.onnx.extractors.utils import get_backend_pad
-from mo.graph.graph import Node, unique_id, get_node_id_by_name
-from mo.middle.passes.eliminate import reverse_dfs, get_nodes_with_attributes
+from mo.graph.graph import Node, Graph, add_opoutput
+from mo.middle.passes.eliminate import reverse_dfs
 from mo.utils import class_registration
 from mo.utils.error import Error
 from mo.utils.graph import dfs
@@ -31,15 +31,14 @@ from mo.utils.unsupported_ops import UnsupportedOps
 from mo.utils.utils import refer_to_faq_msg
 
 
-def restore_edges(graph: nx.DiGraph, get_edges: callable):
+def restore_edges(graph: Graph, get_edges: callable):
     """
     Take a graph without edges and extract dependencies between nodes with the help of get_edges function.
     For a given node n the get_edges function returns a list of tuples (n1, n2, attrs), that is used to create
     n1 --> n2 edge with attributes attrs.
-    It is possible that two nodes n1 and n2 have more than one n1 --> n2 edges, so the resulting graph is
-    nx.MultiDiGraph.
+    It is possible that two nodes n1 and n2 have more than one n1 --> n2 edges, so the resulting graph is Graph.
     """
-    graph = nx.MultiDiGraph(graph)
+    graph = Graph(graph)
     for node in list(graph.nodes()):
         edges = get_edges(Node(graph, node))
         for u, v, d in edges:
@@ -56,7 +55,7 @@ def restore_edges(graph: nx.DiGraph, get_edges: callable):
     return graph
 
 
-def remove_control_dependency_inputs(graph: nx.MultiDiGraph):
+def remove_control_dependency_inputs(graph: Graph):
     """
     Delete control dependency inputs from pb all over the graph
     :param graph: graph to operate on 
@@ -473,6 +472,7 @@ def update_ie_fields(attrs: dict, ir_version = None):
     ir_version_mapping = {
         # Default behaviour is IR V3 attributes
         None: ir_v3_attrs,
+        5: ir_v3_attrs,
         4: ir_v3_attrs,
         3: ir_v3_attrs,
         2: ir_v2_attrs
@@ -484,7 +484,7 @@ def update_ie_fields(attrs: dict, ir_version = None):
     attrs.update(ir_version_mapping[ir_version])
 
 
-def create_tensor_nodes(graph: nx.MultiDiGraph):
+def create_tensor_nodes(graph: Graph):
     """
     Creates nodes between ops to represent intermediate data that flows from one op to another.
     For each edge with unique out attribute that goes from a given node,
@@ -528,7 +528,7 @@ def create_tensor_nodes(graph: nx.MultiDiGraph):
         node_name = str(smart_node.name) if smart_node.has_valid('name') else str(smart_node.id)
 
         # assign to each output port a tensor unique id in the graph
-        out_tensor_dict = {port: unique_id(graph, '{}/Output_{}/Data_'.format(node_name, port)) for port in out_ports}
+        out_tensor_dict = {port: graph.unique_id('{}/Output_{}/Data_'.format(node_name, port)) for port in out_ports}
 
         # add a new node with kind='data' per each tensor
         graph.add_nodes_from([(uid,
@@ -561,7 +561,7 @@ def create_tensor_nodes(graph: nx.MultiDiGraph):
         # data node content (numpy array). Shape is initialized by this array.
         if 'embedded_inputs' in node_attr:
             for port_index, value_attr, attrs in node_attr['embedded_inputs']:
-                input_node_id = unique_id(graph, 'embedded_input_')
+                input_node_id = graph.unique_id('embedded_input_')
                 value = node_attr[value_attr]
                 shape = np.array(value.shape, dtype=np.int64)
                 graph.add_node(input_node_id, **add_attrs_props(
@@ -569,6 +569,9 @@ def create_tensor_nodes(graph: nx.MultiDiGraph):
                 edge_attrs = {'in': port_index, 'name': value_attr}
                 edge_attrs.update(attrs)
                 graph.add_edge(input_node_id, node, **edge_attrs)
+                op_node = Node(graph, node)
+                if not op_node.has_port(port_type='in', idx=edge_attrs['in']):
+                    op_node.add_input_port(edge_attrs['in'])
                 del node_attr[value_attr]
     return graph
 
@@ -586,7 +589,7 @@ def get_specific_edge_attrs(attrs: dict, attrs_type: str, additional_attrs=None)
     return new_attrs
 
 
-def extract_node_attrs(graph: nx.MultiDiGraph, extractor: callable):
+def extract_node_attrs(graph: Graph, extractor: callable):
     """
     For each node produce new entries in a node attributes dictionary by existing attributes.
     Old attributes are not removed but merged with new ones.
@@ -652,7 +655,7 @@ def extract_port_from_string(node_name: str):
         return name, in_port, out_port
 
 
-def get_node_id_with_ports(graph: nx.MultiDiGraph, name: str):
+def get_node_id_with_ports(graph: Graph, name: str):
     """
     Extracts port and node ID out of user provided name
     :param graph: graph to operate on
@@ -660,7 +663,7 @@ def get_node_id_with_ports(graph: nx.MultiDiGraph, name: str):
     :return: node ID, direction of port ('in', 'out', 'port') and port number or None
     """
     node_name, in_port, out_port = extract_port_from_string(name)
-    node_id = get_node_id_by_name(graph, node_name)
+    node_id = graph.get_node_id_by_name(node_name)
     if in_port is not None:
         direction = 'in'
         port = in_port
@@ -673,7 +676,7 @@ def get_node_id_with_ports(graph: nx.MultiDiGraph, name: str):
     return node_id, direction, port
 
 
-def input_user_data_repack(graph: nx.MultiDiGraph, input_user_shapes: [None, list, dict, np.ndarray], freeze_placeholder: dict):
+def input_user_data_repack(graph: Graph, input_user_shapes: [None, list, dict, np.ndarray], freeze_placeholder: dict):
     """
     Restructures user input cutting request. Splits ports out of node names. Transforms node names to node ids.
     :param graph: graph to operate on
@@ -712,12 +715,12 @@ def input_user_data_repack(graph: nx.MultiDiGraph, input_user_shapes: [None, lis
     _freeze_placeholder = dict()
     # freeze placeholder restructure
     # Replaces placeholder name with placeholder id. Raises if there is no placeholder with such ID
-    placeholders_ids = get_nodes_with_attributes(graph, op='Placeholder')
+    placeholders_ids = graph.get_nodes_with_attributes(op='Placeholder')
     if freeze_placeholder is None:
         _freeze_placeholder = None
     else:
         for placeholder_name, value in freeze_placeholder.items():
-            placeholder_id = get_node_id_by_name(graph, placeholder_name)
+            placeholder_id = graph.get_node_id_by_name(placeholder_name)
             if placeholder_id not in placeholders_ids:
                 raise Error(
                     'There is no placeholder with name {}. Can not freeze it with value.'.format(placeholder_name))
@@ -761,7 +764,7 @@ def input_user_data_repack(graph: nx.MultiDiGraph, input_user_shapes: [None, lis
     return _input_shapes, _freeze_placeholder
 
 
-def output_user_data_repack(graph: nx.MultiDiGraph, outputs: list):
+def output_user_data_repack(graph: Graph, outputs: list):
     """
 
     :param graph: graph to operate on
@@ -795,7 +798,7 @@ def output_user_data_repack(graph: nx.MultiDiGraph, outputs: list):
     return _outputs
 
 
-def user_data_repack(graph: nx.MultiDiGraph, input_user_shapes: [None, list, dict, np.array], outputs: list,
+def user_data_repack(graph: Graph, input_user_shapes: [None, list, dict, np.array], outputs: list,
                      freeze_placeholder: dict):
     """
     :param graph: graph to operate on
@@ -809,41 +812,17 @@ def user_data_repack(graph: nx.MultiDiGraph, input_user_shapes: [None, list, dic
     return _input_shapes, _outputs, _freeze_placeholder
 
 
-def add_opoutput(graph: nx.MultiDiGraph, node_name: str, port: int, cut: bool = True):
-    """
-    Creates and connects OpOutput node to node_name port. Cuts existing port if requested.
-    :param graph: graph to operate with
-    :param node_name: name of existing node in the graph that we want to add OpOutput to
-    :param port: output port of node to connect OpOutput to
-    :param cut: determines way of operating with edge specified by node_name and port
-    """
-    # we import it here because Op imports add_attrs_props and update_ie_fields from this file
-    from mo.ops.output import Output
-    if cut and len(Node(graph, node_name).out_edges()) != 0:
-        opoutput_node = Output(graph).cut_edge_and_create_node(Node(graph, node_name), port,
-                                                               {'name': '{}/sink_port_{}'.format(node_name, port)})
-    else:
-        opoutput_node = Output(graph).create_node([(Node(graph, node_name), port)],
-                                                  {'name': '{}/sink_port_{}'.format(node_name, port)})
-        opoutput_node.in_edge()['data_attrs'] = ['fw_tensor_debug_info']
-        opoutput_node.in_edge()['fw_tensor_debug_info'] = [(node_name, port)]
-    log.debug('Sink: {} for node {}'.format(opoutput_node.id, node_name))
-    log.debug(str(graph.node[opoutput_node.id]))
-    log.debug("Add edge from {} to {}".format(node_name, opoutput_node.id))
-    return opoutput_node.id
-
-
-def add_output_ops(graph: nx.MultiDiGraph, user_defined_outputs: dict, inputs: dict = None):
+def add_output_ops(graph: Graph, user_defined_outputs: dict, inputs: dict = None):
     sinks = []
     # func sets all layers as outputs in case of empty user_defined_outputs list (it's impossible to reach by cli)
     assert not (isinstance(user_defined_outputs, list) and not len(user_defined_outputs))
 
     # remove previously generated OpOutput if any
     graph.remove_nodes_from([node_name for node_name in graph.nodes() if
-                             'type' in graph.node[node_name] and graph.node[node_name]['type'] == 'OpOutput'])
+                             'op' in graph.node[node_name] and graph.node[node_name]['op'] == 'OpOutput'])
 
     if user_defined_outputs is None:
-        inputs = get_nodes_with_attributes(graph, op='Placeholder') if inputs is None else list(inputs.keys())
+        inputs = graph.get_nodes_with_attributes(op='Placeholder') if inputs is None else list(inputs.keys())
         input_reachable, dead_outputs, undead_outputs = set(), [], []
         for input in inputs:
             dfs(graph=graph, node_name=input, visited=input_reachable)
@@ -885,12 +864,12 @@ def add_output_ops(graph: nx.MultiDiGraph, user_defined_outputs: dict, inputs: d
     return sinks
 
 
-def set_is_input(graph: nx.MultiDiGraph, placeholders: list, is_input: bool):
+def set_is_input(graph: Graph, placeholders: list, is_input: bool):
     for placeholder in placeholders:
         graph.node[placeholder]['is_input'] = is_input
 
 
-def check_input(graph: nx.MultiDiGraph, node_name: str):
+def check_input(graph: Graph, node_name: str):
     node = Node(graph, node_name)
     if node['kind'] == 'op' and node['op'] == 'Placeholder' and not len(graph.in_edges(node_name)) and not node[
         'is_input']:
@@ -914,7 +893,7 @@ def split_node_in_port(node_id: str):
     return node_id, None
 
 
-def add_input_op_input_port_without_data(graph: nx.MultiDiGraph, node_id: str, input_op, edge_attrs: dict):
+def add_input_op_input_port_without_data(graph: Graph, node_id: str, input_op, edge_attrs: dict):
     input_node = input_op.create_node()
     graph.add_edge(input_node.id, node_id, **edge_attrs)
     log.debug('Input: {} for node {}'.format(input_node.id, node_id))
@@ -922,7 +901,7 @@ def add_input_op_input_port_without_data(graph: nx.MultiDiGraph, node_id: str, i
     return input_node.id
 
 
-def add_input_op_input_port_with_data(graph: nx.MultiDiGraph, node_id: str, input_op, edge_attrs: dict):
+def add_input_op_input_port_with_data(graph: Graph, node_id: str, input_op, edge_attrs: dict):
     input_data_node = input_op.create_node_with_data()
     input_node = input_data_node.in_node()
     graph.add_edge(input_data_node.id, node_id, **edge_attrs)
@@ -933,7 +912,7 @@ def add_input_op_input_port_with_data(graph: nx.MultiDiGraph, node_id: str, inpu
     return input_node.id
 
 
-def add_input_op_output_port_without_data(graph: nx.MultiDiGraph, node_id: str, input_op, port: int):
+def add_input_op_output_port_without_data(graph: Graph, node_id: str, input_op, port: int):
     input_node = input_op.create_node()
     # In this case it can be more than one out edge from one port and we should iterate over all output edges
     for _, out_node, attrs in graph.out_edges(node_id, data=True):
@@ -947,7 +926,7 @@ def add_input_op_output_port_without_data(graph: nx.MultiDiGraph, node_id: str,
     return input_node.id
 
 
-def add_input_op_output_port_with_data(graph: nx.MultiDiGraph, node_id: str, input_op, port: int):
+def add_input_op_output_port_with_data(graph: Graph, node_id: str, input_op, port: int):
     # we assume that after op always data node
     data_node = Node(graph, node_id).out_node(port)
     assert data_node.has_valid('kind') and data_node.kind == 'data'
@@ -959,7 +938,7 @@ def add_input_op_output_port_with_data(graph: nx.MultiDiGraph, node_id: str, inp
     return input_node.id
 
 
-def add_input_op(graph: nx.MultiDiGraph, node_id: str, port: int = 0, data: bool = False, shape=None,
+def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False, shape=None,
                  is_out_port: bool = False):
     """
     This function adds Input node to node with id==node_id to specified port (in or out defined with is_out_port).
@@ -996,7 +975,7 @@ def add_input_op(graph: nx.MultiDiGraph, node_id: str, port: int = 0, data: bool
     return new_input_id
 
 
-def add_input_ops_helper_before_infer_input_port(graph: nx.MultiDiGraph, smart_node: Node, port: int, node_id: str,
+def add_input_ops_helper_before_infer_input_port(graph: Graph, smart_node: Node, port: int, node_id: str,
                                                  shape: np.array, inputs: list, edges_to_remove: list):
     n_inputs = len(smart_node.in_nodes())
     if n_inputs > 1 and port is None:
@@ -1010,7 +989,7 @@ def add_input_ops_helper_before_infer_input_port(graph: nx.MultiDiGraph, smart_n
                                shape=shape))
 
 
-def add_input_ops_helper_after_infer_input_port(graph: nx.MultiDiGraph, smart_node: Node, port:int, node_id: str,
+def add_input_ops_helper_after_infer_input_port(graph: Graph, smart_node: Node, port:int, node_id: str,
                                                 inputs: list, edges_to_remove: list):
     n_inputs = len(smart_node.in_nodes())
     if n_inputs > 1 and port is not None and port != 0:
@@ -1029,7 +1008,7 @@ def add_input_ops_helper_after_infer_input_port(graph: nx.MultiDiGraph, smart_no
     edges_to_remove.append((in_node.id, node_id))
 
 
-def add_input_ops_helper_before_infer_output_port(graph: nx.MultiDiGraph, port:int, node_id: str,
+def add_input_ops_helper_before_infer_output_port(graph: Graph, port:int, node_id: str,
                                                  shape: np.array, inputs: list, edges_to_remove: list):
     for u, v, edge_attrs in graph.out_edges(node_id, data=True):
         if edge_attrs['out'] == port:
@@ -1037,7 +1016,7 @@ def add_input_ops_helper_before_infer_output_port(graph: nx.MultiDiGraph, port:i
     inputs.append(add_input_op(graph=graph, node_id=node_id, port=port, data=False,
                                shape=shape, is_out_port=True))
 
-def add_input_ops_helper_after_infer_output_port(graph: nx.MultiDiGraph, smart_node: Node, port:int, node_id: str,
+def add_input_ops_helper_after_infer_output_port(graph: Graph, smart_node: Node, port:int, node_id: str,
                                                  inputs: list, edges_to_remove: list):
     out_node = smart_node.out_node(port)
     shape = out_node['shape'] if 'shape' in out_node else None
@@ -1049,7 +1028,7 @@ def add_input_ops_helper_after_infer_output_port(graph: nx.MultiDiGraph, smart_n
     edges_to_remove.append((node_id, out_node.id))
 
 
-def add_input_ops(graph: nx.MultiDiGraph, user_defined_inputs: dict, before_infer: bool):
+def add_input_ops(graph: Graph, user_defined_inputs: dict, before_infer: bool):
     """
     This function add user defined input operations.
     For cutting without port:
@@ -1067,9 +1046,9 @@ def add_input_ops(graph: nx.MultiDiGraph, user_defined_inputs: dict, before_infe
     For case with before_infer=False data nodes are added to this schemes.
     """
     inputs = []
-    set_is_input(graph, get_nodes_with_attributes(graph, op='Placeholder'), False)
+    set_is_input(graph, graph.get_nodes_with_attributes(op='Placeholder'), False)
     if user_defined_inputs is None:
-        inputs = get_nodes_with_attributes(graph, op='Placeholder')
+        inputs = graph.get_nodes_with_attributes(op='Placeholder')
     else:
         # cutting the net by inputs
         assert isinstance(user_defined_inputs, dict)
@@ -1137,7 +1116,7 @@ def add_input_ops(graph: nx.MultiDiGraph, user_defined_inputs: dict, before_infe
     if len(inputs):
         set_is_input(graph, inputs, True)
         # Check if there are inputs that are not listed in user_defined_inputs and are needed to calculate outputs
-        outputs = get_nodes_with_attributes(graph, is_output=True)
+        outputs = graph.get_nodes_with_attributes(op='OpOutput')
         visited = set()
         for output_name in outputs:
             reverse_dfs(graph, output_name, check_input, visited)
@@ -1145,13 +1124,12 @@ def add_input_ops(graph: nx.MultiDiGraph, user_defined_inputs: dict, before_infe
     return inputs
 
 
-def remove_output_ops(graph: nx.MultiDiGraph):
+def remove_output_ops(graph: Graph):
     for node in list(graph.nodes()):
         node = Node(graph, node)
         if node.has_valid('op') and node.op == 'OpOutput':
             if len(node.in_nodes()) > 0:
                 assert (len(node.in_nodes()) == 1)
-                list(node.in_nodes().values())[0]['is_output'] = node.is_output
             graph.remove_node(node.id)