"""
 Copyright (c) 2017-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log

from mo.front.subgraph_matcher import SubgraphMatch
from mo.graph.graph import Node, merge_edge_props, Graph
from mo.middle.pattern_match import apply_pattern
from mo.utils import class_registration
from mo.utils.replacement_pattern import ReplacementPattern
class FrontReplacementPattern(ReplacementPattern):
    """
    Base class for front-phase graph transformations.

    Sub-classes describe what to match by overriding ``pattern``. The pass declares that it runs
    after ``FrontStart`` and before ``FrontFinish`` (see ``run_after`` / ``run_before``).
    """

    def run_after(self):
        # Imported inside the method, presumably to avoid a circular import with the pass separators.
        from extensions.front.pass_separator import FrontStart
        return [FrontStart]

    def run_before(self):
        from extensions.front.pass_separator import FrontFinish
        return [FrontFinish]

    def pattern(self):
        """
        Return the pattern description for this transformation.
        :raises Exception: always; sub-classes must override this method.
        """
        raise Exception('Function "pattern" must be overridden in the sub-class')

    @classmethod
    def class_type(cls):
        """Report the registration category of this transformation (front replacer)."""
        return class_registration.ClassType.FRONT_REPLACER


# NOTE(review): presumably keeps this abstract base itself from being run as a transformation;
# confirm against ReplacementPattern.
ReplacementPattern.excluded_replacers.append(FrontReplacementPattern)
class FrontReplacementSubgraph(FrontReplacementPattern):
    """
    Replace a set of nodes matched by ``pattern()`` with a new sub-graph.

    Workflow (see ``replace_sub_graph``): build the new sub-graph, copy the matched input and
    output edges onto it, then remove the nodes returned by ``nodes_to_remove``.
    """
    # Identifier shown in the debug log messages of replace_sub_graph.
    replacement_id = 'None'

    def run_after(self):
        from extensions.front.pass_separator import FrontStart
        return [FrontStart]

    def run_before(self):
        from extensions.front.pass_separator import FrontFinish
        return [FrontFinish]

    @staticmethod
    def extract_port(node_port):
        """
        Normalize a node reference to a ``(node_name, port)`` pair.
        :param node_port: a ``(node_name, port)`` tuple or a bare node name.
        :return: the tuple unchanged, or ``(node_port, 0)`` for a bare name (port defaults to 0).
        """
        return node_port if isinstance(node_port, tuple) else (node_port, 0)

    @staticmethod
    def replace_input_edges(graph: Graph, input_edges_match: dict):
        """
        Copy existing input edges onto the new sub-graph.
        For every ``old -> new`` entry, the edge feeding the old node's input port is duplicated to
        the new node with its 'in' attribute rewritten. The original edge is left in place.
        :param graph: networkX graph to operate on.
        :param input_edges_match: match of input edges between old and new sub-graph; keys and values
            are node names or ``(node_name, in_port)`` tuples.
        :return: None
        """
        for old_name_port, new_name_port in input_edges_match.items():
            old_node_name, old_in_port = __class__.extract_port(old_name_port)
            new_node_name, new_in_port = __class__.extract_port(new_name_port)
            old_node = Node(graph, old_node_name)
            src_node_name = old_node.get_sorted_inputs()[old_in_port][0]
            # One producer can feed several input ports of the same node (parallel edges), so copy
            # the edge that actually lands on `old_in_port`; fall back to key 0 otherwise.
            parallel_edges = graph[src_node_name][old_node_name]
            same_port = [attrs for attrs in parallel_edges.values() if attrs.get('in') == old_in_port]
            edge_attrs = (same_port[0] if same_port else parallel_edges[0]).copy()
            edge_attrs['in'] = new_in_port
            graph.add_edge(src_node_name, new_node_name, **edge_attrs)
            log.debug("Created edge from {} to {} with attrs: {}".format(src_node_name, new_node_name, edge_attrs))

    @staticmethod
    def replace_output_edges(graph: Graph, output_edges_match: dict):
        """
        Copy existing output edges onto the new sub-graph.
        Every out-edge of the old node leaving the matched output port is duplicated so that it
        starts at the new node, with its 'out' attribute rewritten. Original edges are left in place.
        :param graph: networkX graph to operate on.
        :param output_edges_match: match of output edges between old and new sub-graph; keys and
            values are node names or ``(node_name, out_port)`` tuples.
        :return: None
        """
        for old_name_port, new_name_port in output_edges_match.items():
            old_node_name, old_out_port = __class__.extract_port(old_name_port)
            new_node_name, new_out_port = __class__.extract_port(new_name_port)
            # Snapshot the edges first: add_edge mutates the graph, which must not happen while a
            # live edge view of the same node is being iterated (new node == old node case).
            for _, dst, edge_attrs in list(graph.out_edges(old_node_name, data=True)):
                if edge_attrs['out'] == old_out_port:
                    new_edge_attrs = edge_attrs.copy()
                    new_edge_attrs['out'] = new_out_port
                    graph.add_edge(new_node_name, dst, **new_edge_attrs)
                    log.debug("Created edge from {} to {} with attrs: {}".format(new_node_name, dst, new_edge_attrs))

    def input_edges_match(self, graph: Graph, match: object, new_sub_graph: dict):
        """
        Default implementation doesn't add new input edges automatically.
        :return: an empty mapping for ``replace_input_edges``.
        """
        return {}

    def output_edges_match(self, graph: Graph, match: object, new_sub_graph: dict):
        """
        Default implementation doesn't add new output edges automatically.
        :return: an empty mapping for ``replace_output_edges``.
        """
        return {}

    def generate_sub_graph(self, graph: Graph, match: object):
        """
        Build the replacement sub-graph for ``match``.
        :raises Exception: always; sub-classes must implement this method.
        """
        raise Exception("The function 'generate_sub_graph' must be implemented in the sub-class.")

    def nodes_to_remove(self, graph: Graph, match: dict):
        """
        Default implementation generates list of all matched nodes. So all matched nodes will be removed.
        :return: list of ids of the matched nodes.
        """
        return [node.id for node in match.values()]

    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        """
        Replace one match with the sub-graph produced by ``generate_sub_graph``.
        :param graph: graph to operate on.
        :param match: a SubgraphMatch or a plain dict whose values expose ``.id``.
        :return: None
        """
        # A plain dict match has no matched_nodes_names(); derive the names the same way
        # nodes_to_remove does instead of failing with AttributeError.
        if isinstance(match, SubgraphMatch):
            matched_names = match.matched_nodes_names()
        else:
            matched_names = [node.id for node in match.values()]
        log.debug('replace_sub_graph: "{}" matched nodes: {}'.format(self.replacement_id,
                                                                      '\n'.join(sorted(matched_names))))
        new_sub_graph = self.generate_sub_graph(graph, match)  # pylint: disable=assignment-from-no-return
        self.replace_input_edges(graph, self.input_edges_match(graph, match, new_sub_graph))
        self.replace_output_edges(graph, self.output_edges_match(graph, match, new_sub_graph))

        remove_nodes = self.nodes_to_remove(graph, match)
        log.debug(
            'replace_sub_graph: "{}" removing nodes: {}'.format(self.replacement_id, '\n'.join(sorted(remove_nodes))))
        graph.remove_nodes_from(remove_nodes)

    def find_and_replace_pattern(self, graph: Graph):
        """Apply ``replace_sub_graph`` to matches of ``self.pattern()`` (its keys are passed to apply_pattern)."""
        apply_pattern(graph, action=self.replace_sub_graph, **self.pattern())

    @classmethod
    def class_type(cls):
        """Report the registration category of this transformation (front replacer)."""
        return class_registration.ClassType.FRONT_REPLACER


# NOTE(review): presumably keeps this abstract base itself from being run as a transformation;
# confirm against ReplacementPattern.
ReplacementPattern.excluded_replacers.append(FrontReplacementSubgraph)
class FrontReplacementOp(FrontReplacementSubgraph):
    """
    A super class for an operation replacement.
    Replaces a single operation (identified by 'op' attribute) by a sub-graph of operations.
    It is a convenient specialization of FrontReplacementPattern.

    Sub-classes must define the class attribute ``op`` (compared with the node's 'op' attribute)
    and implement ``replace_op``.
    """

    def run_after(self):
        from extensions.front.pass_separator import FrontStart
        return [FrontStart]

    def run_before(self):
        from extensions.front.pass_separator import FrontFinish
        return [FrontFinish]

    def pattern(self):
        # Single-node pattern: the node labelled 'op' whose 'op' attribute equals the class's `op`.
        return dict(
            nodes=[
                ('op', dict(op=self.__class__.op))],
            edges=[]
        )

    def replace_op(self, graph: Graph, node: Node):
        """
        Build the replacement for ``node``.
        :return: one entry per output port of ``node``: a new node name or a ``(name, out_port)`` tuple.
        :raises Exception: always; sub-classes must implement this method.
        """
        raise Exception("The function 'replace_op' must be implemented in the sub-class.")

    @staticmethod
    def gen_output_edges_match(node: Node, out_node_replace: list):
        """
        Build the output-edge mapping consumed by ``replace_output_edges``.
        :param node: the node being replaced.
        :param out_node_replace: entries in output-port order of ``node``; each is a new node name
            (output port 0 is used) or a ``(new_node_name, out_port)`` tuple.
        :return: dict mapping ``(node.id, old_out_port)`` to ``(new_node_name, new_out_port)``.
        """
        out_edges_match_dict = dict()
        for old_out_port, new_node_desc in enumerate(out_node_replace):
            # Same normalization as the inherited extract_port: tuples pass through, names get port 0.
            new_node_name, new_out_port = __class__.extract_port(new_node_desc)
            out_edges_match_dict[(node.id, old_out_port)] = (new_node_name, new_out_port)
        return out_edges_match_dict

    @staticmethod
    def update_input_edges_attrs(graph: Graph, node: Node, added_nodes: list):
        """
        Copy edge attributes from 'old' input edges of node 'node' to new input sub-graph edges.
        A new edge receives the attributes of an old edge when both come from the same producer
        node and output port; merging is delegated to merge_edge_props.
        :param graph: graph to operate on
        :param node: Node object that was replaced.
        :param added_nodes: list of nodes names added.
        :return: None
        """
        added = set(added_nodes)
        # Edges entering the new sub-graph from outside; collected once instead of per old edge.
        # The attribute dicts are the graph's own, so merging into them updates the graph.
        external_in_edges = [(new_u, new_edge_attrs)
                             for new_u, _, new_edge_attrs in graph.in_edges(added_nodes, data=True)
                             if new_u not in added]
        for old_u, _, old_edge_attrs in graph.in_edges(node.id, data=True):
            for new_u, new_edge_attrs in external_in_edges:
                if old_u == new_u and old_edge_attrs['out'] == new_edge_attrs['out']:
                    merge_edge_props(new_edge_attrs, old_edge_attrs)  # copy old edge attributes

    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Replace the single matched node with the sub-graph built by ``replace_op``.
        :param graph: graph to operate on.
        :param match: single-entry dict produced by ``pattern()`` (key 'op').
        :return: None
        """
        assert len(match) == 1
        node = match['op']
        # Snapshot the node set before replace_op runs: graph.nodes() can be a live view
        # (networkx 2.x NodeView), which would make the difference below always empty.
        nodes_before_replacement = set(graph.nodes())
        self.replace_output_edges(graph, self.gen_output_edges_match(node, self.replace_op(graph, node)))

        # nodes added by the 'replace_op' function call
        added_nodes = list(set(graph.nodes()) - nodes_before_replacement)
        self.update_input_edges_attrs(graph, node, added_nodes)

        # TODO Need to check if there are other users for these nodes
        remove_nodes = self.nodes_to_remove(graph, match)
        log.debug("Removing nodes: {}".format(remove_nodes))
        graph.remove_nodes_from(remove_nodes)

    @classmethod
    def class_type(cls):
        """Report the registration category of this transformation (front replacer)."""
        return class_registration.ClassType.FRONT_REPLACER


# NOTE(review): presumably keeps this abstract base itself from being run as a transformation;
# confirm against ReplacementPattern.
ReplacementPattern.excluded_replacers.append(FrontReplacementOp)