"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
from collections import deque
from copy import deepcopy
+from numbers import Number
import networkx as nx
import numpy as np
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
from mo.middle.pattern_match import all_edges_in_nodes
from mo.utils.error import Error
update_nodes_attributes: dict = None, nodes_with_edges_only: bool = False,
add_nodes_from_edges: bool = False):
"""
- Build the nx.MultiDiGraph with specific nodes and edges. Also update of edge and node parameters is supported.
+ Build the Graph with specific nodes and edges. Also update of edge and node parameters is supported.
:param nodes_with_attrs: list of tuples ('node_name', {node_attrs})
:param edges_with_attrs: list of tuples like (start node, end node, (optional) {attrs of the edge}).
:param new_nodes_with_attrs: analogically nodes_with_attrs
if not add_nodes_from_edges and not all_edges_in_nodes(nodes=all_nodes_names, edges=all_edges):
raise Error("Some nodes from list of edges is not in nodes. Please, add all necessary nodes.")
- graph = nx.MultiDiGraph()
+ graph = Graph()
# Create dict for nodes with attrs
nodes_attrs = {}
def build_graph(nodes_attrs: dict, edges: list, update_attributes: dict = None, nodes_with_edges_only: bool = False):
"""
- Build the nx.MultiDiGraph with specific nodes and edges.
+ Build the Graph with specific nodes and edges.
:param nodes_attrs: dictionary where key is the node name and the value is the dictionary with node attributes.
:param edges: list of pairs with start and end node names of the edge.
:param update_attributes: optional dictionary which specifies nodes names and their attributes to be updated. The
:param nodes_with_edges_only: add nodes which has at least one incoming or outcoming edge.
:return: generated graph.
"""
- graph = nx.MultiDiGraph()
+ graph = Graph()
for node_name, attrs in nodes_attrs.items():
if 'name' not in attrs:
for attr, value in new_attrs.items():
graph.node[node_name][attr] = value
+ for node in graph.get_op_nodes():
+ # Add in_ports attribute
+ in_edges = node.in_edges()
+ for i in range(len(in_edges)):
+ node.add_input_port(idx=i)
+
+ # Add out_ports attribute
+ out_edges = node.out_edges()
+ for i in range(len(out_edges)):
+ node.add_output_port(idx=i)
+
return graph
def build_graph_with_edge_attrs(nodes_attrs: dict, edges: list, update_attributes: dict = None):
"""
- Build the nx.MultiDiGraph with specific nodes and edges.
+ Build the Graph with specific nodes and edges.
:param nodes_attrs: dictionary where key is the node name and the value is the dictionary with node attributes.
:param edges: list of pairs with start and end node names of the edge.
:param update_attributes: optional dictionary which specifies nodes names and their attributes to be updated. The
key is a node name to update attribute and the value is a dictionary with attribute name and its value.
:return: generated graph.
"""
- graph = nx.MultiDiGraph()
+ graph = Graph()
for node_1, node_2, attr in edges:
if node_1 not in graph.nodes():
graph.add_node(node_1, **deepcopy(nodes_attrs[node_1]))
return graph
-def compare_graphs(graph: nx.MultiDiGraph, graph_ref: nx.MultiDiGraph, last_node: str, last_node_ref=None,
+def compare_graphs(graph: Graph, graph_ref: Graph, last_node: str, last_node_ref=None,
check_op_attrs=False):
if last_node_ref is None:
last_node_ref = last_node
# Check that nodes has same operation
if check_op_attrs:
for attr in graph_ref.node[node_ref.id]:
- if graph_ref.node[node_ref.id][attr] is None or attr in ['name', 'id']:
+ if graph_ref.node[node_ref.id][attr] is None or attr in ['name', 'id', '_in_ports', '_out_ports', 'infer', 'IE']:
continue
if attr not in graph.node[node.id]:
return False, 'Node {} has missing attribute {}'.format(node.id, attr)
return False, '{} and {} has different attr {} : {} and {}'.format(
node.id, node_ref.id, attr, graph.node[node.id][attr],
graph_ref.node[node_ref.id][attr])
- else:
- if graph.node[node.id][attr] != graph_ref.node[node_ref.id][attr]:
+ elif isinstance(graph.node[node.id][attr], Number):
+ if abs(graph.node[node.id][attr] - graph_ref.node[node_ref.id][attr]) > 1e-4:
return False, '{} and {} has different attr {} : {} and {}'.format(
node.id, node_ref.id, attr, graph.node[node.id][attr],
graph_ref.node[node_ref.id][attr])
+ elif graph.node[node.id][attr] != graph_ref.node[node_ref.id][attr]:
+ return False, '{} and {} has different attr {} : {} and {}'.format(
+ node.id, node_ref.id, attr, graph.node[node.id][attr],
+ graph_ref.node[node_ref.id][attr])
+
else:
if node_ref.has_valid('shape') and not node.has_valid('shape'):
return False, '{} has None shape'.format(node.id)