Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / common / replacement.py
1 """
2  Copyright (c) 2017-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import logging as log
17
18 import networkx as nx
19
20 from mo.front.subgraph_matcher import SubgraphMatch
21 from mo.graph.graph import Node, merge_edge_props, Graph
22 from mo.middle.pattern_match import apply_pattern
23 from mo.utils import class_registration
24 from mo.utils.replacement_pattern import ReplacementPattern
25
26
class FrontReplacementPattern(ReplacementPattern):
    """
    Base class for front-phase replacement transformations.

    Replacers derived from this class are ordered after the FrontStart pass separator and before the
    FrontFinish pass separator. Sub-classes are expected to provide their own 'pattern' implementation.
    """
    # Class-level registries defined on this class itself rather than inherited.
    # NOTE(review): presumably populated by mo.utils.class_registration — confirm.
    registered_ops = {}
    registered_cls = []

    @classmethod
    def class_type(cls):
        # Tells the class registration machinery that this hierarchy consists of front replacers.
        return class_registration.ClassType.FRONT_REPLACER

    def pattern(self):
        raise Exception('Function "pattern" must be overridden in the sub-class')

    def run_after(self):
        # The separator module is imported lazily, at scheduling time, not at module import time.
        from extensions.front import pass_separator
        return [pass_separator.FrontStart]

    def run_before(self):
        from extensions.front import pass_separator
        return [pass_separator.FrontFinish]


ReplacementPattern.excluded_replacers.append(FrontReplacementPattern)
48
49
class FrontReplacementSubgraph(FrontReplacementPattern):
    """
    Replace pattern defined set of nodes with a sub-graph.

    Sub-classes implement 'generate_sub_graph' to create the new nodes. They may override 'input_edges_match' /
    'output_edges_match' to describe how edges of the matched nodes are re-attached to the new nodes. Afterwards
    the nodes returned by 'nodes_to_remove' are deleted from the graph.
    """
    replacement_id = 'None'

    def run_after(self):
        from extensions.front.pass_separator import FrontStart
        return [FrontStart]

    def run_before(self):
        from extensions.front.pass_separator import FrontFinish
        return [FrontFinish]

    def __init__(self):
        pass

    @staticmethod
    def extract_port(node_port):
        """
        Normalize a node reference to a (node_name, port) pair.
        :param node_port: either a node name or a (node_name, port) tuple.
        :return: the tuple unchanged, or (node_port, 0) when only a node name is given.
        """
        return node_port if isinstance(node_port, tuple) else (node_port, 0)

    @staticmethod
    def _matched_node_names(match):
        """
        Return the names of the matched nodes for both supported match representations.
        :param match: a SubgraphMatch object, or a dict whose values are Node objects (the form expected by
                      'nodes_to_remove'; presumably what 'apply_pattern' passes — confirm).
        :return: list/iterable of node names.
        """
        if isinstance(match, SubgraphMatch):
            return match.matched_nodes_names()
        return [node.id for node in match.values()]

    @staticmethod
    def replace_input_edges(graph: Graph, input_edges_match: dict):
        """
        Create input edges of the new sub-graph by copying the input edges of the old (matched) nodes.
        The old edges are left in place; they disappear when the old nodes are removed.
        :param graph: networkX graph to operate on.
        :param input_edges_match: maps old node input (name or (name, in_port)) to new node input
                                  (name or (name, in_port)); a bare name means port 0.
        :return: None
        """
        for old_name_port, new_name_port in input_edges_match.items():
            old_node_name, old_in_port = __class__.extract_port(old_name_port)
            new_node_name, new_in_port = __class__.extract_port(new_name_port)
            old_node = Node(graph, old_node_name)
            # Producer feeding the old node's 'old_in_port' input.
            src_node_name = old_node.get_sorted_inputs()[old_in_port][0]
            # NOTE(review): key 0 picks the first parallel edge between the two nodes; if the same producer
            # feeds several inputs of the old node, this may not be the edge for 'old_in_port' — confirm.
            edge_attrs = graph[src_node_name][old_node_name][0].copy()
            edge_attrs['in'] = new_in_port
            graph.add_edge(src_node_name, new_node_name, **edge_attrs)
            log.debug("Created edge from {} to {} with attrs: {}".format(src_node_name, new_node_name, edge_attrs))

    @staticmethod
    def replace_output_edges(graph: Graph, output_edges_match: dict):
        """
        Create output edges of the new sub-graph by copying the output edges of the old (matched) nodes.
        The old edges are left in place; they disappear when the old nodes are removed.
        :param graph: networkX graph to operate on.
        :param output_edges_match: maps old node output (name or (name, out_port)) to new node output
                                   (name or (name, out_port)); a bare name means port 0.
        :return: None
        """
        for old_name_port, new_name_port in output_edges_match.items():
            old_node_name, old_out_port = __class__.extract_port(old_name_port)
            new_node_name, new_out_port = __class__.extract_port(new_name_port)
            # Materialize the edges first: edges are added to the graph inside the loop, so the live edge view
            # must not be iterated while the graph changes.
            for src, dst, edge_attrs in list(graph.out_edges(old_node_name, data=True)):
                if edge_attrs['out'] == old_out_port:
                    new_edge_attrs = edge_attrs.copy()
                    new_edge_attrs['out'] = new_out_port
                    graph.add_edge(new_node_name, dst, **new_edge_attrs)
                    log.debug("Created edge from {} to {} with attrs: {}".format(new_node_name, dst, new_edge_attrs))

    def input_edges_match(self, graph: Graph, match: object, new_sub_graph: dict):
        """
        Default implementation doesn't add new input edges automatically.
        """
        return {}

    def output_edges_match(self, graph: Graph, match: object, new_sub_graph: dict):
        """
        Default implementation doesn't add new output edges automatically.
        """
        return {}

    def generate_sub_graph(self, graph: Graph, match: object):
        raise Exception("The function 'generate_sub_graph' must be implemented in the sub-class.")

    def nodes_to_remove(self, graph: Graph, match: dict):
        """
        Default implementation generates list of all matched nodes. So all matched nodes will be removed.
        Expects 'match' to be a dict whose values are Node objects.
        """
        return [node.id for node in match.values()]

    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        """
        Replace the matched nodes with the sub-graph produced by 'generate_sub_graph', re-wire input/output edges
        and remove the nodes reported by 'nodes_to_remove'.
        :param graph: graph to operate on.
        :param match: SubgraphMatch object or dict of matched Node objects.
        :return: None
        """
        # The format() arguments are evaluated eagerly, so this must work for a plain dict too (dicts have no
        # 'matched_nodes_names' method).
        log.debug('replace_sub_graph: "{}" matched nodes: {}'.format(
            self.replacement_id, '\n'.join(sorted(self._matched_node_names(match)))))
        new_sub_graph = self.generate_sub_graph(graph, match)  # pylint: disable=assignment-from-no-return
        self.replace_input_edges(graph, self.input_edges_match(graph, match, new_sub_graph))
        self.replace_output_edges(graph, self.output_edges_match(graph, match, new_sub_graph))

        remove_nodes = self.nodes_to_remove(graph, match)
        log.debug(
            'replace_sub_graph: "{}" removing nodes: {}'.format(self.replacement_id, '\n'.join(sorted(remove_nodes))))
        graph.remove_nodes_from(remove_nodes)

    def find_and_replace_pattern(self, graph: Graph):
        apply_pattern(graph, action=self.replace_sub_graph, **self.pattern())

    registered_ops = {}
    registered_cls = []

    @classmethod
    def class_type(cls):
        return class_registration.ClassType.FRONT_REPLACER


ReplacementPattern.excluded_replacers.append(FrontReplacementSubgraph)
152
153
class FrontReplacementOp(FrontReplacementSubgraph):
    """
    A super class for an operation replacement.
    Replaces a single operation (identified by 'op' attribute) by a sub-graph of operations.
    It is a convenient specialization of FrontReplacementSubgraph.
    """
    op = 'UnknownOp'

    def run_after(self):
        from extensions.front.pass_separator import FrontStart
        return [FrontStart]

    def run_before(self):
        from extensions.front.pass_separator import FrontFinish
        return [FrontFinish]

    def pattern(self):
        # Match any single node whose 'op' attribute equals the sub-class's 'op' value.
        return dict(
            nodes=[
                ('op', dict(op=self.__class__.op))],
            edges=[]
        )

    def replace_op(self, graph: Graph, node: Node):
        """
        Build the replacement sub-graph for 'node'.
        Must return a list whose i-th element describes what replaces output port i of 'node': either a node name
        (output port 0 is used) or a (node_name, out_port) tuple.
        """
        raise Exception("The function 'replace_op' must be implemented in the sub-class.")

    @staticmethod
    def gen_output_edges_match(node: Node, out_node_replace: list):
        """
        Build the output edges match for 'replace_output_edges'.
        :param node: the node being replaced.
        :param out_node_replace: list returned by 'replace_op'; element i is a node name or a (node_name, out_port)
                                 tuple replacing output port i of 'node'.
        :return: dict mapping (node.id, old_out_port) to (new_node_name, new_out_port).
        """
        out_edges_match_dict = dict()
        for old_out_port, new_node_desc in enumerate(out_node_replace):
            # isinstance (not 'is tuple') — the identity check against the type object never matches a tuple value.
            if isinstance(new_node_desc, tuple):
                new_node_name, new_out_port = new_node_desc[0], new_node_desc[1]
            else:
                new_node_name, new_out_port = new_node_desc, 0
            out_edges_match_dict[(node.id, old_out_port)] = (new_node_name, new_out_port)
        return out_edges_match_dict

    @staticmethod
    def update_input_edges_attrs(graph: Graph, node: Node, added_nodes: list):
        """
        Copy edge attributes from 'old' input edges of node 'node' to new input sub-graph edges.
        A new edge receives the attributes of an old edge when both come from the same producer node and the same
        producer output port ('out' attribute). Only edges entering the new sub-graph from outside are considered.
        :param graph: graph to operate on
        :param node: Node object that was replaced.
        :param added_nodes: list of nodes names added.
        :return: None
        """
        added_set = set(added_nodes)
        # Loop-invariant: external edges entering the new sub-graph (attribute dicts are the graph's own, so
        # merging into them updates the graph).
        external_new_in_edges = [(new_u, new_edge_attrs)
                                 for new_u, _, new_edge_attrs in graph.in_edges(added_nodes, data=True)
                                 if new_u not in added_set]
        for old_u, _, old_edge_attrs in graph.in_edges(node.id, data=True):
            for new_u, new_edge_attrs in external_new_in_edges:
                if old_u == new_u and old_edge_attrs['out'] == new_edge_attrs['out']:
                    merge_edge_props(new_edge_attrs, old_edge_attrs)  # copy old edge attributes

    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Replace the single matched node with the sub-graph built by 'replace_op'.
        :param graph: graph to operate on.
        :param match: dict with exactly one entry, 'op', holding the matched Node.
        :return: None
        """
        assert 'op' in match
        assert len(match) == 1
        node = match['op']
        # Snapshot into a set: in networkx 2.x 'graph.nodes()' is a live view, so keeping the view would make the
        # difference computed below always empty.
        nodes_before_replacement = set(graph.nodes())
        self.replace_output_edges(graph, self.gen_output_edges_match(node, self.replace_op(graph, node)))

        # nodes added by the 'replace_op' function call
        added_nodes = list(set(graph.nodes()) - nodes_before_replacement)
        self.update_input_edges_attrs(graph, node, added_nodes)

        # TODO Need to check if there are other users for these nodes
        remove_nodes = self.nodes_to_remove(graph, match)
        log.debug("Removing nodes: {}".format(remove_nodes))
        graph.remove_nodes_from(remove_nodes)

    registered_ops = {}
    registered_cls = []

    @classmethod
    def class_type(cls):
        return class_registration.ClassType.FRONT_REPLACER


ReplacementPattern.excluded_replacers.append(FrontReplacementOp)
231
232 ReplacementPattern.excluded_replacers.append(FrontReplacementOp)