Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / subgraph_matcher.py
1 """
2  Copyright (c) 2017-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import logging as log
17 import re
18
19 from mo.graph.graph import Node, Graph
20 from mo.utils.custom_replacement_config import CustomReplacementDescriptor
21 from mo.utils.error import Error
22 from mo.utils.graph import nodes_matching_name_pattern, sub_graph_between_nodes
23 from mo.utils.utils import refer_to_faq_msg
24
25
26 def find_object_by_pattern(names: list, pattern: str):
27     """
28     :param names: list of names to find objects from.
29     :param pattern: regular expression for the name.
30     :return: list of matched objects.
31     """
32     compiled_pattern = re.compile(pattern)
33     return [name for name in names if re.match(compiled_pattern, name)]
34
35
36 class SubgraphMatch(object):
37     """
38     Class providing information about matched sub-graph.
39     """
40
41     def __init__(self, graph: Graph, replacement_desc: CustomReplacementDescriptor, matched_nodes: list,
42                  inputs_order: list, outputs_order: list, prefix: str):
43         """
44         Creates instance of a SubgraphMatch class from the provided configuration.
45         :param graph: networkx graph.
46         :param replacement_desc: CustomReplacementDescriptor object describing sub-graph.
47         :param matched_nodes: list of matched nodes.
48         :param inputs_order: nodes description in the format described in the FrontReplacementFromConfigFileSubGraph.
49         :param outputs_order: nodes description in the format described in the FrontReplacementFromConfigFileSubGraph.
50         :param prefix: optional prefix of the node names. Is not used in the sub-graph match by points.
51         """
52         self._input_nodes_map = dict()
53         self._output_nodes_map = dict()
54         self._matched_nodes_names = matched_nodes
55         self.graph = graph
56         self.custom_replacement_desc = replacement_desc
57         self.scope = prefix
58
59         for sub_graph_input_port, input_desc in enumerate(inputs_order):
60             for node_pattern, node_in_port in input_desc:
61                 node = self.node_by_pattern(node_pattern)
62                 if node is not None:
63                     self._add_input_node(node.id, node_in_port, sub_graph_input_port)
64
65         for sub_graph_output_port, (node_pattern, out_port) in enumerate(outputs_order):
66             node = self.node_by_pattern(node_pattern)
67             if node is not None:
68                 self._add_output_node(node.id, out_port, sub_graph_output_port)
69
70     def matched_nodes_names(self):
71         """
72         Returns list of node names in the matched sub-graph.
73         :return: list of node names in the matched sub-graph.
74         """
75         return self._matched_nodes_names
76
77     def inputs_count(self):
78         """
79         Returns number of inputs for the matched sub-graph. Only unique input tensors are considered, thus if the same
80         tensor is consumed by two or more input nodes of the sub-graph it is counted only once.
81         :return: Number or unique input tensors.
82         """
83         return len(self._input_nodes_map.keys())
84
85     def outputs_count(self):
86         """
87         Returns number of outputs for the matched sub-graph. Only unique output tensors are considered, thus if the same
88         tensor is consumed by two or more nodes outside of the sub-graph it is counted only once.
89         :return: Number or unique input tensors.
90         """
91         return len(self._output_nodes_map.keys())
92
93     def input_nodes(self, port: int):
94         """
95         Returns list of tuples where the first element is a Node of the sub-graph and the second is the input port for
96         that node. Each node of this list gets the same input tensor through the input port with number 'port' of the
97         sub-graph.
98
99         For example, if the returned list requested for port 'portSG' is the following: [(NodeA, portA), (nodeB, portB)]
100         then the same tensor is passed to node 'NodeA' as input with number 'portA' and node 'nodeB' as input with
101         number 'portB' for the sub-graph input with number 'portSG'.
102         :param port: input port of the sub-graph.
103         :return: list describing nodes of the sub-graph getting tensor through the specified port.
104         """
105         return self._input_nodes_map[port]
106
107     def single_input_node(self, port: int):
108         """
109         The function does the same as function 'input_nodes' but it relies on fact that there is just one node that
110         gets input tensor for sub-graph input with number 'port', so it return just tuple (Node, nodePort) or raises
111         exception if the amount of nodes is not equal to 1.
112         :param port: input port of the sub-graph.
113         :return: tuple describing node of the sub-graph getting tensor through the specified port.
114         """
115         input_nodes = self.input_nodes(port)
116         if len(input_nodes) != 1:
117             raise Error('The amount of input nodes for port "{}" is not equal to 1. '.format(port) +
118                         refer_to_faq_msg(33))
119         return input_nodes[0]
120
121     def output_node(self, port: int):
122         """
123         Returns a tuple where the first element is a Node of the sub-graph and the second is the output port of that
124         node. Th node produces output tensor through the output port with number 'port' of the sub-graph.
125         :param port: output port of the sub-graph.
126         :return: tuple describing node of the sub-graph producing sub-graph output tensor through the specified port.
127         """
128         return self._output_nodes_map[port]
129
130     def node_by_pattern(self, pattern: str):
131         """
132         Returns Node from the list of sub-graph nodes matching node name regular expression 'pattern'. If there are more
133         than one nodes matched then the function raises exception.
134         :param pattern: the regular expression for the node name.
135         :return: matched Node.
136         """
137         if self.scope != '':
138             if self.scope[-1] == '/':
139                 pattern = self.scope + pattern
140             else:
141                 pattern = self.scope + '/' + pattern
142         found_names = find_object_by_pattern(self._matched_nodes_names, pattern)
143         if len(found_names) > 1:
144             raise Error('The amount of nodes matched pattern "{}" is more than 1. '.format(pattern) +
145                         refer_to_faq_msg(78))
146         if len(found_names) == 0:
147             return None
148         return Node(self.graph, found_names[0])
149
150     def _add_input_node(self, node_name: str, node_port: int, sub_graph_input_port: int):
151         self._input_nodes_map.setdefault(sub_graph_input_port, []).append((Node(self.graph, node_name), node_port))
152
153     def _add_output_node(self, node_name: str, node_port: int, sub_graph_output_port: int):
154         if sub_graph_output_port in self._output_nodes_map:
155             raise Error('Output node for port "{}" has already been specified. '.format(sub_graph_output_port) +
156                         refer_to_faq_msg(34))
157         self._output_nodes_map[sub_graph_output_port] = (Node(self.graph, node_name), node_port)
158
159
160 # TODO looks like this class is not needed. Can be implemented as pure functions.
161 class SubgraphMatcher(object):
162     def __init__(self, replacement_descriptor: CustomReplacementDescriptor):
163         self.replacement_desc = replacement_descriptor
164
165     def _match_sub_graph_for_scope(self, graph: Graph, scope_pattern: str):
166         """
167         :param graph: networkx graph to find sub-graph in.
168         :param scope_pattern: regular expression specifying sub-graph scope.
169         :return: an object describing matched sub-graph.
170         """
171         inputs_order = self.replacement_desc.get_inputs_description()
172         outputs_order = self.replacement_desc.get_outputs_description()
173
174         for list_nodes in inputs_order:
175             for node_name_pattern, port in list_nodes:
176                 if len(find_object_by_pattern(graph.nodes(), '.*' + node_name_pattern)) == 0:
177                     log.info('Node "{} does not exist in the graph". Failed to match sub-graph by scope "{}".'.format(
178                         node_name_pattern, self.replacement_desc.id))
179                     return None
180
181         matched_nodes = nodes_matching_name_pattern(graph, scope_pattern)
182         if len(matched_nodes) == 0:
183             log.info('There are no instances of the sub-graph by scope "{}"'.format(scope_pattern))
184             return None
185
186         return SubgraphMatch(graph, self.replacement_desc, matched_nodes, inputs_order, outputs_order, scope_pattern)
187
188     def _match_sub_graph_for_points(self, graph: Graph):
189         """
190         :param graph: networkx graph to find sub-graph in.
191         :return: an object describing matched sub-graph.
192         """
193         start_points = self.replacement_desc.get_internal_input_nodes(graph)
194         end_points = self.replacement_desc.get_internal_output_nodes(graph)
195         # check that start and end points exist in the graph
196         for node_name in start_points + end_points:
197             if node_name not in graph.nodes():
198                 log.info('Node "{}" does not exist in the graph. Failed to match sub-graph by points "{}".'.format(
199                     node_name, self.replacement_desc.id))
200                 return None
201
202         matched_nodes = sub_graph_between_nodes(graph, start_points, end_points)
203         return SubgraphMatch(graph, self.replacement_desc, matched_nodes,
204                              self.replacement_desc.get_inputs_description(),
205                              self.replacement_desc.get_outputs_description(), '')
206
207     def matched_sub_graph_instances(self, graph: Graph):
208         """
209         Generator to product all instances of matched sub-graphs.
210         :param graph: graph to find instances in.
211         :return: generator producing SubGraphMatch objects.
212         """
213         if self.replacement_desc.match_kind == 'points':  # instance is specified with lists of start/end nodes
214             match = self._match_sub_graph_for_points(graph)
215             if match is not None:
216                 yield match
217         elif self.replacement_desc.match_kind == 'scope':  # instance is specified with a node name pattern
218             for instance in self.replacement_desc.sub_graph_instances():
219                 match = self._match_sub_graph_for_scope(graph, instance)
220                 if match is not None:
221                     yield match
222         else:
223             raise Error('Unsupported match kind "{}". Match kinds "points" or "scope" are supported only. '.format(
224                 self.replacement_desc.match_kind) +
225                         refer_to_faq_msg(35))