"""
- Copyright (c) 2017-2018 Intel Corporation
+ Copyright (c) 2017-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import networkx as nx
from mo.front.subgraph_matcher import SubgraphMatch
-from mo.graph.graph import Node, merge_edge_props, get_sorted_inputs
+from mo.graph.graph import Node, merge_edge_props, Graph
from mo.middle.pattern_match import apply_pattern
from mo.utils import class_registration
from mo.utils.replacement_pattern import ReplacementPattern
registered_ops = {}
registered_cls = []
+ def run_after(self):
+ from extensions.front.pass_separator import FrontStart
+ return [FrontStart]
+
+ def run_before(self):
+ from extensions.front.pass_separator import FrontFinish
+ return [FrontFinish]
+
def pattern(self):
raise Exception('Function "pattern" must be overridden in the sub-class')
"""
replacement_id = 'None'
+ def run_after(self):
+ from extensions.front.pass_separator import FrontStart
+ return [FrontStart]
+
+ def run_before(self):
+ from extensions.front.pass_separator import FrontFinish
+ return [FrontFinish]
+
def __init__(self):
pass
return node_port if isinstance(node_port, tuple) else (node_port, 0)
@staticmethod
- def replace_input_edges(graph: nx.DiGraph, input_edges_match: dict):
+ def replace_input_edges(graph: Graph, input_edges_match: dict):
"""
Replacing existing input/output edges with a new ones to a new sub-graph.
:param graph: networkX graph to operate on.
old_node_name, old_in_port = __class__.extract_port(old_name_port)
new_node_name, new_in_port = __class__.extract_port(new_name_port)
old_node = Node(graph, old_node_name)
- src_node_name = get_sorted_inputs(old_node)[old_in_port][0]
+ src_node_name = old_node.get_sorted_inputs()[old_in_port][0]
edge_attrs = graph[src_node_name][old_node_name][0].copy()
edge_attrs['in'] = new_in_port
graph.add_edge(src_node_name, new_node_name, **edge_attrs)
log.debug("Created edge from {} to {} with attrs: {}".format(src_node_name, new_node_name, edge_attrs))
@staticmethod
- def replace_output_edges(graph: nx.DiGraph, output_edges_match: dict):
+ def replace_output_edges(graph: Graph, output_edges_match: dict):
"""
Replacing existing input/output edges with a new ones to a new sub-graph.
:param graph: networkX graph to operate on.
graph.add_edge(new_node_name, dst, **new_edge_attrs)
log.debug("Created edge from {} to {} with attrs: {}".format(new_node_name, dst, new_edge_attrs))
- def input_edges_match(self, graph: nx.MultiDiGraph, match: object, new_sub_graph: dict):
+ def input_edges_match(self, graph: Graph, match: object, new_sub_graph: dict):
"""
Default implementation doesn't add new input edges automatically.
"""
return {}
- def output_edges_match(self, graph: nx.MultiDiGraph, match: object, new_sub_graph: dict):
+ def output_edges_match(self, graph: Graph, match: object, new_sub_graph: dict):
"""
Default implementation doesn't add new output edges automatically.
"""
return {}
- def generate_sub_graph(self, graph: nx.MultiDiGraph, match: object):
+ def generate_sub_graph(self, graph: Graph, match: object):
raise Exception("The function 'generate_sub_graph' must be implemented in the sub-class.")
- def nodes_to_remove(self, graph: nx.MultiDiGraph, match: dict):
+ def nodes_to_remove(self, graph: Graph, match: dict):
"""
Default implementation generates list of all matched nodes. So all matched nodes will be removed.
"""
return [node.id for node in match.values()]
- def replace_sub_graph(self, graph: nx.MultiDiGraph, match: [dict, SubgraphMatch]):
+ def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
log.debug('replace_sub_graph: "{}" matched nodes: {}'.format(self.replacement_id,
'\n'.join(sorted(match.matched_nodes_names()))))
new_sub_graph = self.generate_sub_graph(graph, match) # pylint: disable=assignment-from-no-return
'replace_sub_graph: "{}" removing nodes: {}'.format(self.replacement_id, '\n'.join(sorted(remove_nodes))))
graph.remove_nodes_from(remove_nodes)
- def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
+ def find_and_replace_pattern(self, graph: Graph):
apply_pattern(graph, action=self.replace_sub_graph, **self.pattern())
registered_ops = {}
"""
op = 'UnknownOp'
+ def run_after(self):
+ from extensions.front.pass_separator import FrontStart
+ return [FrontStart]
+
+ def run_before(self):
+ from extensions.front.pass_separator import FrontFinish
+ return [FrontFinish]
+
def pattern(self):
return dict(
nodes=[
edges=[]
)
- def replace_op(self, graph: nx.MultiDiGraph, node: Node):
+ def replace_op(self, graph: Graph, node: Node):
raise Exception("The function 'replace_op' must be implemented in the sub-class.")
@staticmethod
return out_edges_match_dict
@staticmethod
- def update_input_edges_attrs(graph: nx.MultiDiGraph, node: Node, added_nodes: list):
+ def update_input_edges_attrs(graph: Graph, node: Node, added_nodes: list):
"""
Copy edge attributes from 'old' input edges of node 'node' to new input sub-graph edges.
:param graph: graph to operate on
if old_u == new_u and old_edge_attrs['out'] == new_edge_attrs['out']:
merge_edge_props(new_edge_attrs, old_edge_attrs) # copy old edge attributes
- def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
+ def replace_sub_graph(self, graph: Graph, match: dict):
assert 'op' in match
assert len(match) == 1
node = match['op']