Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / utils / unittest / graph.py
index 64a0f30..2c36d61 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 
 from collections import deque
 from copy import deepcopy
+from numbers import Number
 
 import networkx as nx
 import numpy as np
 
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
 from mo.middle.pattern_match import all_edges_in_nodes
 from mo.utils.error import Error
 
@@ -51,7 +52,7 @@ def build_graph_with_attrs(nodes_with_attrs: list, edges_with_attrs: list, new_n
                            update_nodes_attributes: dict = None, nodes_with_edges_only: bool = False,
                            add_nodes_from_edges: bool = False):
     """
-    Build the nx.MultiDiGraph with specific nodes and edges. Also update of edge and node parameters is supported.
+    Build the Graph with specific nodes and edges. Also update of edge and node parameters is supported.
     :param nodes_with_attrs: list of tuples ('node_name', {node_attrs})
     :param edges_with_attrs: list of tuples like (start node, end node, (optional) {attrs of the edge}).
     :param new_nodes_with_attrs: analogically nodes_with_attrs
@@ -78,7 +79,7 @@ def build_graph_with_attrs(nodes_with_attrs: list, edges_with_attrs: list, new_n
     if not add_nodes_from_edges and not all_edges_in_nodes(nodes=all_nodes_names, edges=all_edges):
         raise Error("Some nodes from list of edges is not in nodes. Please, add all necessary nodes.")
 
-    graph = nx.MultiDiGraph()
+    graph = Graph()
 
     # Create dict for nodes with attrs
     nodes_attrs = {}
@@ -129,7 +130,7 @@ def build_graph_with_attrs(nodes_with_attrs: list, edges_with_attrs: list, new_n
 
 def build_graph(nodes_attrs: dict, edges: list, update_attributes: dict = None, nodes_with_edges_only: bool = False):
     """
-    Build the nx.MultiDiGraph with specific nodes and edges.
+    Build the Graph with specific nodes and edges.
     :param nodes_attrs: dictionary where key is the node name and the value is the dictionary with node attributes.
     :param edges: list of pairs with start and end node names of the edge.
     :param update_attributes: optional dictionary which specifies nodes names and their attributes to be updated. The
@@ -137,7 +138,7 @@ def build_graph(nodes_attrs: dict, edges: list, update_attributes: dict = None,
     :param nodes_with_edges_only: add nodes which has at least one incoming or outcoming edge.
     :return: generated graph.
     """
-    graph = nx.MultiDiGraph()
+    graph = Graph()
 
     for node_name, attrs in nodes_attrs.items():
         if 'name' not in attrs:
@@ -180,19 +181,30 @@ def build_graph(nodes_attrs: dict, edges: list, update_attributes: dict = None,
             for attr, value in new_attrs.items():
                 graph.node[node_name][attr] = value
 
+    for node in graph.get_op_nodes():
+        # Add in_ports attribute
+        in_edges = node.in_edges()
+        for i in range(len(in_edges)):
+            node.add_input_port(idx=i)
+
+        # Add out_ports attribute
+        out_edges = node.out_edges()
+        for i in range(len(out_edges)):
+            node.add_output_port(idx=i)
+
     return graph
 
 
 def build_graph_with_edge_attrs(nodes_attrs: dict, edges: list, update_attributes: dict = None):
     """
-    Build the nx.MultiDiGraph with specific nodes and edges.
+    Build the Graph with specific nodes and edges.
     :param nodes_attrs: dictionary where key is the node name and the value is the dictionary with node attributes.
     :param edges: list of pairs with start and end node names of the edge.
     :param update_attributes: optional dictionary which specifies nodes names and their attributes to be updated. The
     key is a node name to update attribute and the value is a dictionary with attribute name and its value.
     :return: generated graph.
     """
-    graph = nx.MultiDiGraph()
+    graph = Graph()
     for node_1, node_2, attr in edges:
         if node_1 not in graph.nodes():
             graph.add_node(node_1, **deepcopy(nodes_attrs[node_1]))
@@ -207,7 +219,7 @@ def build_graph_with_edge_attrs(nodes_attrs: dict, edges: list, update_attribute
     return graph
 
 
-def compare_graphs(graph: nx.MultiDiGraph, graph_ref: nx.MultiDiGraph, last_node: str, last_node_ref=None,
+def compare_graphs(graph: Graph, graph_ref: Graph, last_node: str, last_node_ref=None,
                    check_op_attrs=False):
     if last_node_ref is None:
         last_node_ref = last_node
@@ -249,7 +261,7 @@ def compare_graphs(graph: nx.MultiDiGraph, graph_ref: nx.MultiDiGraph, last_node
             # Check that nodes has same operation
             if check_op_attrs:
                 for attr in graph_ref.node[node_ref.id]:
-                    if graph_ref.node[node_ref.id][attr] is None or attr in ['name', 'id']:
+                    if graph_ref.node[node_ref.id][attr] is None or attr in ['name', 'id', '_in_ports', '_out_ports', 'infer', 'IE']:
                         continue
                     if attr not in graph.node[node.id]:
                         return False, 'Node {} has missing attribute {}'.format(node.id, attr)
@@ -259,11 +271,16 @@ def compare_graphs(graph: nx.MultiDiGraph, graph_ref: nx.MultiDiGraph, last_node
                             return False, '{} and {} has different attr {} : {} and {}'.format(
                                 node.id, node_ref.id, attr, graph.node[node.id][attr],
                                 graph_ref.node[node_ref.id][attr])
-                    else:
-                        if graph.node[node.id][attr] != graph_ref.node[node_ref.id][attr]:
+                    elif isinstance(graph.node[node.id][attr], Number):
+                        if abs(graph.node[node.id][attr] - graph_ref.node[node_ref.id][attr]) > 1e-4:
                             return False, '{} and {} has different attr {} : {} and {}'.format(
                                 node.id, node_ref.id, attr, graph.node[node.id][attr],
                                 graph_ref.node[node_ref.id][attr])
+                    elif graph.node[node.id][attr] != graph_ref.node[node_ref.id][attr]:
+                        return False, '{} and {} has different attr {} : {} and {}'.format(
+                            node.id, node_ref.id, attr, graph.node[node.id][attr],
+                            graph_ref.node[node_ref.id][attr])
+
         else:
             if node_ref.has_valid('shape') and not node.has_valid('shape'):
                 return False, '{} has None shape'.format(node.id)