Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / utils / custom_replacement_config.py
index 8709e19..63dc551 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@ from re import compile, match
 
 import networkx as nx
 
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
 from mo.utils.error import Error
 from mo.utils.graph import nodes_matching_name_pattern, sub_graph_between_nodes
 from mo.utils.utils import refer_to_faq_msg
@@ -126,7 +126,7 @@ class CustomReplacementDescriptor(object):
             return None
         return [(out['node'], out['port']) for out in self._replacement_desc['outputs']]
 
-    def update_custom_replacement_attributes(self, graph: nx.MultiDiGraph):
+    def update_custom_replacement_attributes(self, graph: Graph):
         """
         The function run specific functions to update attributes of the custom replacement description. Currently it
         updates information about input/output nodes.
@@ -179,7 +179,7 @@ class CustomReplacementDescriptorPoints(CustomReplacementDescriptor):
     def get_outputs_description(self):
         return [('^' + node_name + '$', 0) for node_name in self.instances['end_points']]
 
-    def get_internal_input_nodes(self, graph: nx.MultiDiGraph):
+    def get_internal_input_nodes(self, graph: Graph):
         """
         Gets list of node names getting input from outside of the sub-graph. This function checks whether input nodes
         specified in the configuration file should be added to the sub-graph or not. If they should not be added to the
@@ -199,7 +199,7 @@ class CustomReplacementDescriptorPoints(CustomReplacementDescriptor):
         else:
             return self.instances['start_points']
 
-    def get_internal_output_nodes(self, graph: nx.MultiDiGraph):
+    def get_internal_output_nodes(self, graph: Graph):
         """
         Gets list of node names producing output outside of the sub-graph. This function checks whether output nodes
         specified in the configuration file should be added to the sub-graph or not. If they should not be added to the
@@ -219,7 +219,7 @@ class CustomReplacementDescriptorPoints(CustomReplacementDescriptor):
         else:
             return self.instances['end_points']
 
-    def update_custom_replacement_attributes(self, graph: nx.MultiDiGraph):
+    def update_custom_replacement_attributes(self, graph: Graph):
         if not self.has('instances'):
             raise Error("No instance(s) is(are) defined for the custom replacement '{}'. ".format(self.replacement_id) +
                         refer_to_faq_msg(66))
@@ -278,7 +278,7 @@ class CustomReplacementDescriptorScope(CustomReplacementDescriptor):
     def __init__(self, replacement_id: str, attrs: dict = None):
         super().__init__(replacement_id, attrs)
 
-    def update_custom_replacement_attributes(self, graph: nx.MultiDiGraph):
+    def update_custom_replacement_attributes(self, graph: Graph):
         if not self.has('instances') or len(self.instances) == 0:
             raise Error("No instances are defined for replacement with id '{}'. ".format(self.replacement_id) +
                         refer_to_faq_msg(68))
@@ -384,35 +384,7 @@ def parse_custom_replacement_config_file(file_name: str):
     return result
 
 
-def update_custom_replacement_config_file(graph: nx.MultiDiGraph, file_name: str):
-    data = parse_custom_replacement_config_file(file_name)
-    if data is None:
-        raise Error("Cannot update the file '{}' because it is broken. ".format(file_name) +
-                    refer_to_faq_msg(73))
-
-    for replacement_desc in data:
-        replacement_desc.update_custom_replacement_attributes(graph)
-
-    return save_custom_replacement_config_file(data, file_name)
-
-
-def save_custom_replacement_config_file(descriptions: list, file_name: str):
-    """
-    Save custom layer(s) description(s) to the file.
-    :param file_name: file to save description information to.
-    :param descriptions: list with instances of the CustomLayerDescriptor classes.
-    :return: True if operation is successful.
-    """
-    try:
-        json.dump([replacement_desc.get_config_file_representation() for replacement_desc in descriptions],
-                  open(file_name, "w"), indent=4, sort_keys=True)
-    except Exception as ex:
-        log.error("failed to update configuration file {}: {}".format(file_name, str(ex)))
-        return False
-    return True
-
-
-def generate_pattern_for_node(graph: nx.MultiDiGraph, sub_graph_pattern: str, node_name: str):
+def generate_pattern_for_node(graph: Graph, sub_graph_pattern: str, node_name: str):
     if sub_graph_pattern == '':
         return node_name
     node_name_components = node_name.split("/")