Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / ops / op.py
index 83d80fb..2028acc 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
  limitations under the License.
 """
 
+import copy
 import logging as log
 from collections import namedtuple
 
@@ -22,7 +23,8 @@ import numpy as np
 
 from mo.front.extractor import add_attrs_props
 from mo.front.extractor import update_ie_fields
-from mo.graph.graph import Node, unique_id
+from mo.graph.graph import Node, Graph
+from mo.graph.port import Port
 from mo.utils import class_registration
 from mo.utils.error import Error
 
@@ -33,7 +35,7 @@ class Op(object):
     # Add the derived class to excluded_classes if one should not be registered in registered_ops
     excluded_classes = []
 
-    def __init__(self, graph: nx.MultiDiGraph, attrs1: dict = None, attrs2: dict = None):
+    def __init__(self, graph: Graph, attrs1: dict = None, attrs2: dict = None):
         self.graph = graph
         try:
             self.ir_version = graph.graph['ir_version']
@@ -56,13 +58,15 @@ class Op(object):
         if attrs is not None:
             new_attrs.update(attrs)
         id_prefix = new_attrs['name'] if 'name' in new_attrs else ''
-        id = unique_id(self.graph, id_prefix)
+        id = self.graph.unique_id(id_prefix)
         new_attrs['name'] = id
         new_attrs = add_attrs_props(new_attrs)
         update_ie_fields(new_attrs, self.ir_version)
         self.substitute_ie_attrs(new_attrs)
         self.graph.add_node(id, **new_attrs)
-        return Node(self.graph, id)
+
+        node = Node(self.graph, id)
+        return node
 
     def substitute_ie_attrs(self, new_attrs: dict):
         """
@@ -71,6 +75,7 @@ class Op(object):
         """
         backend_attrs_mapping = {
             None: self.backend_attrs,
+            5: self.backend_attrs,
             4: self.backend_attrs,
             3: self.backend_attrs,
             2: self.backend_attrs_v2
@@ -103,23 +108,25 @@ class Op(object):
             raise Error('Node {} has more than one outputs. Provide output port explicitly. '.format(node.name))
         return node, port
 
-    def cut_edge_and_create_node(self, node: Node, out_port: int, attrs: dict = None):
+    def create_node_on_port(self, node: Node, out_port: int, attrs: dict = None, edge_attrs: dict = None):
         """
         Removes an edge, that is connected to nodes out_port. Creates new_node with attrs attributes and
         connects it to node by edge that stores the same information as cutted edge.
         :param node: Input node, to cut the edge from
         :param out_port: output port of edge to cut
         :param attrs: attributes of new node
+        :param edge_attrs: attributes to be changed/added to new edge
         :return: Node instance of created new_node
         """
-        edges = [(u, v, keys, params) for u, v, keys, params in node.graph.out_edges(node.id, data=True, keys=True)
-                 if 'out' in params and params['out'] == out_port]
-        edge_attrs = edges[0][3]
-        [self.graph.remove_edge(u, v, key=key) for u, v, key, params in edges]
+        if edge_attrs is None:
+            edge_attrs = {'in': 0}
+        prev_edge_attrs = copy.deepcopy(node.out_edge(out_port))
+        prev_edge_attrs.update(edge_attrs)
+        new_edge_attrs = prev_edge_attrs
         if attrs is None:
             attrs = dict()
         new_node = self.add_node(attrs)
-        self.graph.add_edge(node.id, new_node.id, **edge_attrs)
+        self.graph.add_edge(node.id, new_node.id, **new_edge_attrs)
         return new_node
 
     def create_node(self, inputs: list = None, attrs: dict = None, edge_attrs: dict = None):
@@ -176,7 +183,7 @@ class Op(object):
         old_data_value = [None]
         old_data_shape = [None]
         if data_nodes is None:
-            data_node = unique_id(self.graph)
+            data_node = self.graph.unique_id()
             self.graph.add_node(data_node, **add_attrs_props(
                 dict(kind='data', precision="FP32", name=data_node, value=None, shape=None, data_type=None,
                      infer=None)))
@@ -190,9 +197,11 @@ class Op(object):
                               data_nodes]
         for id, data_node in enumerate(data_nodes):
             self.graph.add_edges_from([(new_op_node.id, data_node.id, {'out': id})])
+
         if new_op_node.has_valid('infer'):
-            log.debug('Start running infer function for individual op node with attributes: {}'.format(
-                new_op_node.graph.node[new_op_node.id]))
+            if log.getLogger().isEnabledFor(log.DEBUG):
+                log.debug('Start running infer function for individual op node with attributes: {}'
+                          ''.format(str(new_op_node)))
             new_op_node.infer(new_op_node)
             assert all(old_value is None for old_value in old_data_value) or all(
                 [np.array_equal(old_data_value[id], data_node.value) for id, data_node in enumerate(data_nodes)])
@@ -203,36 +212,36 @@ class Op(object):
                     [old_data_shape[id] for id in range(len(data_nodes))],
                     [data_node.shape for data_node in data_nodes])
             for data_node in data_nodes:
-                log.debug(
-                    'Finished running infer function, data nodes attributes: {}'.format(
-                        data_node.graph.node[data_node.id]))
+                if log.getLogger().isEnabledFor(log.DEBUG):
+                    log.debug(
+                        'Finished running infer function, data nodes attributes: {}'.format(data_node))
         return data_nodes[0] if len(data_nodes) == 1 else data_nodes
 
     @staticmethod
-    def create_data_node(graph: nx.MultiDiGraph, op_node: Node, attrs: dict = None, edge_attrs: dict = None):
+    def create_data_node(graph: Graph, op_node: Node, attrs: dict = None, edge_attrs: dict = None, out_port=0):
         assert op_node is not None and op_node.kind == 'op'
         assert len(op_node.out_nodes()) == 0
         if attrs is None:
             attrs = {}
 
-        data_node = unique_id(graph, op_node.id)
+        data_node = graph.unique_id(op_node.id)
         defaul_attrs = dict(kind='data', precision="FP32", name=data_node, value=None, shape=None, data_type=None,
                             infer=None)
         defaul_attrs.update(attrs)
         graph.add_node(data_node, **add_attrs_props(defaul_attrs))
         data_node = Node(graph, data_node)
         if edge_attrs is not None:
-            graph.add_edges_from([(op_node.id, data_node.id, {'out': 0, **edge_attrs})])
+            graph.add_edges_from([(op_node.id, data_node.id, {'out': out_port, **edge_attrs})])
         else:
-            graph.add_edges_from([(op_node.id, data_node.id, {'out': 0})])
+            graph.add_edges_from([(op_node.id, data_node.id, {'out': out_port})])
         return data_node
 
     @staticmethod
-    def _create_data_node(graph: nx.MultiDiGraph, name: str, attrs: dict = None):
+    def _create_data_node(graph: Graph, name: str, attrs: dict = None):
         if attrs is None:
             attrs = {}
 
-        data_node = unique_id(graph, name)
+        data_node = graph.unique_id(name)
         defaul_attrs = dict(kind='data', precision="FP32", name=data_node, value=None, shape=None, data_type=None,
                             infer=None)
         defaul_attrs.update(attrs)
@@ -241,23 +250,24 @@ class Op(object):
         return data_node
 
     @staticmethod
-    def create_input_data_node(graph: nx.MultiDiGraph, name: str, value: np.array, attrs: dict = {}):
-        data_node = unique_id(graph, name)
-        defaul_attrs = dict(kind='data', precision="FP32", name=data_node, value=np.array(value), shape=value.shape,
+    def create_input_data_node(graph: Graph, name: str, value: np.array, attrs: dict = {}):
+        data_node = graph.unique_id(name)
+        defaul_attrs = dict(kind='data', precision="FP32", name=data_node, value=np.array(value),
+                            shape=np.array(value.shape),
                             data_type=None, infer=None)
         defaul_attrs.update(attrs)
         graph.add_node(data_node, **add_attrs_props(defaul_attrs))
         return Node(graph, data_node)
 
     @staticmethod
-    def create_and_connect_input_data_node(graph: nx.MultiDiGraph, op_node: Node, attrs: dict = None, edge_attrs: dict = None):
+    def create_and_connect_input_data_node(graph: Graph, op_node: Node, attrs: dict = None, edge_attrs: dict = None):
         assert op_node is not None and op_node.kind == 'op'
         if attrs is None:
             attrs = {}
         if edge_attrs is None:
             edge_attrs = {}
 
-        data_node = unique_id(graph, op_node.id)
+        data_node = graph.unique_id(op_node.id)
         defaul_attrs = dict(kind='data', precision="FP32", name=data_node, value=None, shape=None, data_type=None,
                             infer=None)
         defaul_attrs.update(attrs)