Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / common / replacement.py
index 6a2874d..6b86689 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2017-2018 Intel Corporation
+ Copyright (c) 2017-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@ import logging as log
 import networkx as nx
 
 from mo.front.subgraph_matcher import SubgraphMatch
-from mo.graph.graph import Node, merge_edge_props, get_sorted_inputs
+from mo.graph.graph import Node, merge_edge_props, Graph
 from mo.middle.pattern_match import apply_pattern
 from mo.utils import class_registration
 from mo.utils.replacement_pattern import ReplacementPattern
@@ -28,6 +28,14 @@ class FrontReplacementPattern(ReplacementPattern):
     registered_ops = {}
     registered_cls = []
 
+    def run_after(self):
+        from extensions.front.pass_separator import FrontStart
+        return [FrontStart]
+
+    def run_before(self):
+        from extensions.front.pass_separator import FrontFinish
+        return [FrontFinish]
+
     def pattern(self):
         raise Exception('Function "pattern" must be overridden in the sub-class')
 
@@ -45,6 +53,14 @@ class FrontReplacementSubgraph(FrontReplacementPattern):
     """
     replacement_id = 'None'
 
+    def run_after(self):
+        from extensions.front.pass_separator import FrontStart
+        return [FrontStart]
+
+    def run_before(self):
+        from extensions.front.pass_separator import FrontFinish
+        return [FrontFinish]
+
     def __init__(self):
         pass
 
@@ -53,7 +69,7 @@ class FrontReplacementSubgraph(FrontReplacementPattern):
         return node_port if isinstance(node_port, tuple) else (node_port, 0)
 
     @staticmethod
-    def replace_input_edges(graph: nx.DiGraph, input_edges_match: dict):
+    def replace_input_edges(graph: Graph, input_edges_match: dict):
         """
         Replacing existing input/output edges with a new ones to a new sub-graph.
         :param graph: networkX graph to operate on.
@@ -64,14 +80,14 @@ class FrontReplacementSubgraph(FrontReplacementPattern):
             old_node_name, old_in_port = __class__.extract_port(old_name_port)
             new_node_name, new_in_port = __class__.extract_port(new_name_port)
             old_node = Node(graph, old_node_name)
-            src_node_name = get_sorted_inputs(old_node)[old_in_port][0]
+            src_node_name = old_node.get_sorted_inputs()[old_in_port][0]
             edge_attrs = graph[src_node_name][old_node_name][0].copy()
             edge_attrs['in'] = new_in_port
             graph.add_edge(src_node_name, new_node_name, **edge_attrs)
             log.debug("Created edge from {} to {} with attrs: {}".format(src_node_name, new_node_name, edge_attrs))
 
     @staticmethod
-    def replace_output_edges(graph: nx.DiGraph, output_edges_match: dict):
+    def replace_output_edges(graph: Graph, output_edges_match: dict):
         """
         Replacing existing input/output edges with a new ones to a new sub-graph.
         :param graph: networkX graph to operate on.
@@ -88,28 +104,28 @@ class FrontReplacementSubgraph(FrontReplacementPattern):
                     graph.add_edge(new_node_name, dst, **new_edge_attrs)
                     log.debug("Created edge from {} to {} with attrs: {}".format(new_node_name, dst, new_edge_attrs))
 
-    def input_edges_match(self, graph: nx.MultiDiGraph, match: object, new_sub_graph: dict):
+    def input_edges_match(self, graph: Graph, match: object, new_sub_graph: dict):
         """
         Default implementation doesn't add new input edges automatically.
         """
         return {}
 
-    def output_edges_match(self, graph: nx.MultiDiGraph, match: object, new_sub_graph: dict):
+    def output_edges_match(self, graph: Graph, match: object, new_sub_graph: dict):
         """
         Default implementation doesn't add new output edges automatically.
         """
         return {}
 
-    def generate_sub_graph(self, graph: nx.MultiDiGraph, match: object):
+    def generate_sub_graph(self, graph: Graph, match: object):
         raise Exception("The function 'generate_sub_graph' must be implemented in the sub-class.")
 
-    def nodes_to_remove(self, graph: nx.MultiDiGraph, match: dict):
+    def nodes_to_remove(self, graph: Graph, match: dict):
         """
         Default implementation generates list of all matched nodes. So all matched nodes will be removed.
         """
         return [node.id for node in match.values()]
 
-    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: [dict, SubgraphMatch]):
+    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
         log.debug('replace_sub_graph: "{}" matched nodes: {}'.format(self.replacement_id,
                                                                      '\n'.join(sorted(match.matched_nodes_names()))))
         new_sub_graph = self.generate_sub_graph(graph, match)  # pylint: disable=assignment-from-no-return
@@ -121,7 +137,7 @@ class FrontReplacementSubgraph(FrontReplacementPattern):
             'replace_sub_graph: "{}" removing nodes: {}'.format(self.replacement_id, '\n'.join(sorted(remove_nodes))))
         graph.remove_nodes_from(remove_nodes)
 
-    def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
+    def find_and_replace_pattern(self, graph: Graph):
         apply_pattern(graph, action=self.replace_sub_graph, **self.pattern())
 
     registered_ops = {}
@@ -143,6 +159,14 @@ class FrontReplacementOp(FrontReplacementSubgraph):
     """
     op = 'UnknownOp'
 
+    def run_after(self):
+        from extensions.front.pass_separator import FrontStart
+        return [FrontStart]
+
+    def run_before(self):
+        from extensions.front.pass_separator import FrontFinish
+        return [FrontFinish]
+
     def pattern(self):
         return dict(
             nodes=[
@@ -150,7 +174,7 @@ class FrontReplacementOp(FrontReplacementSubgraph):
             edges=[]
         )
 
-    def replace_op(self, graph: nx.MultiDiGraph, node: Node):
+    def replace_op(self, graph: Graph, node: Node):
         raise Exception("The function 'replace_op' must be implemented in the sub-class.")
 
     @staticmethod
@@ -167,7 +191,7 @@ class FrontReplacementOp(FrontReplacementSubgraph):
         return out_edges_match_dict
 
     @staticmethod
-    def update_input_edges_attrs(graph: nx.MultiDiGraph, node: Node, added_nodes: list):
+    def update_input_edges_attrs(graph: Graph, node: Node, added_nodes: list):
         """
         Copy edge attributes from 'old' input edges of node 'node' to new input sub-graph edges.
         :param graph: graph to operate on
@@ -181,7 +205,7 @@ class FrontReplacementOp(FrontReplacementSubgraph):
                     if old_u == new_u and old_edge_attrs['out'] == new_edge_attrs['out']:
                         merge_edge_props(new_edge_attrs, old_edge_attrs)  # copy old edge attributes
 
-    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
+    def replace_sub_graph(self, graph: Graph, match: dict):
         assert 'op' in match
         assert len(match) == 1
         node = match['op']