"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import networkx as nx
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
from mo.utils.error import Error
from mo.utils.graph import nodes_matching_name_pattern, sub_graph_between_nodes
from mo.utils.utils import refer_to_faq_msg
return None
return [(out['node'], out['port']) for out in self._replacement_desc['outputs']]
- def update_custom_replacement_attributes(self, graph: nx.MultiDiGraph):
+ def update_custom_replacement_attributes(self, graph: Graph):
"""
The function run specific functions to update attributes of the custom replacement description. Currently it
updates information about input/output nodes.
def get_outputs_description(self):
return [('^' + node_name + '$', 0) for node_name in self.instances['end_points']]
- def get_internal_input_nodes(self, graph: nx.MultiDiGraph):
+ def get_internal_input_nodes(self, graph: Graph):
"""
Gets list of node names getting input from outside of the sub-graph. This function checks whether input nodes
specified in the configuration file should be added to the sub-graph or not. If they should not be added to the
else:
return self.instances['start_points']
- def get_internal_output_nodes(self, graph: nx.MultiDiGraph):
+ def get_internal_output_nodes(self, graph: Graph):
"""
Gets list of node names producing output outside of the sub-graph. This function checks whether output nodes
specified in the configuration file should be added to the sub-graph or not. If they should not be added to the
else:
return self.instances['end_points']
- def update_custom_replacement_attributes(self, graph: nx.MultiDiGraph):
+ def update_custom_replacement_attributes(self, graph: Graph):
if not self.has('instances'):
raise Error("No instance(s) is(are) defined for the custom replacement '{}'. ".format(self.replacement_id) +
refer_to_faq_msg(66))
def __init__(self, replacement_id: str, attrs: dict = None):
super().__init__(replacement_id, attrs)
- def update_custom_replacement_attributes(self, graph: nx.MultiDiGraph):
+ def update_custom_replacement_attributes(self, graph: Graph):
if not self.has('instances') or len(self.instances) == 0:
raise Error("No instances are defined for replacement with id '{}'. ".format(self.replacement_id) +
refer_to_faq_msg(68))
return result
-def update_custom_replacement_config_file(graph: nx.MultiDiGraph, file_name: str):
- data = parse_custom_replacement_config_file(file_name)
- if data is None:
- raise Error("Cannot update the file '{}' because it is broken. ".format(file_name) +
- refer_to_faq_msg(73))
-
- for replacement_desc in data:
- replacement_desc.update_custom_replacement_attributes(graph)
-
- return save_custom_replacement_config_file(data, file_name)
-
-
-def save_custom_replacement_config_file(descriptions: list, file_name: str):
- """
- Save custom layer(s) description(s) to the file.
- :param file_name: file to save description information to.
- :param descriptions: list with instances of the CustomLayerDescriptor classes.
- :return: True if operation is successful.
- """
- try:
- json.dump([replacement_desc.get_config_file_representation() for replacement_desc in descriptions],
- open(file_name, "w"), indent=4, sort_keys=True)
- except Exception as ex:
- log.error("failed to update configuration file {}: {}".format(file_name, str(ex)))
- return False
- return True
-
-
-def generate_pattern_for_node(graph: nx.MultiDiGraph, sub_graph_pattern: str, node_name: str):
+def generate_pattern_for_node(graph: Graph, sub_graph_pattern: str, node_name: str):
if sub_graph_pattern == '':
return node_name
node_name_components = node_name.split("/")