2 Copyright (c) 2017-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from mo.graph.graph import Node, Graph
20 from mo.utils.custom_replacement_config import CustomReplacementDescriptor
21 from mo.utils.error import Error
22 from mo.utils.graph import nodes_matching_name_pattern, sub_graph_between_nodes
23 from mo.utils.utils import refer_to_faq_msg
def find_object_by_pattern(names: list, pattern: str):
    """
    Returns the names that match the regular expression 'pattern'.

    Matching uses re.match semantics: it is anchored at the start of the name only, so the pattern 'a/b' also
    matches the name 'a/bc'. The relative order of 'names' is preserved in the result.
    :param names: list (or any other iterable, e.g. a graph node view) of names to find objects from.
    :param pattern: regular expression for the name.
    :return: list of matched objects.
    """
    compiled_pattern = re.compile(pattern)
    # Use the compiled pattern directly instead of routing it back through the module-level re.match().
    return [name for name in names if compiled_pattern.match(name)]
class SubgraphMatch(object):
    """
    Class providing information about matched sub-graph.
    """

    def __init__(self, graph: Graph, replacement_desc: CustomReplacementDescriptor, matched_nodes: list,
                 inputs_order: list, outputs_order: list, prefix: str):
        """
        Creates instance of a SubgraphMatch class from the provided configuration.
        :param graph: networkx graph.
        :param replacement_desc: CustomReplacementDescriptor object describing sub-graph.
        :param matched_nodes: list of matched nodes.
        :param inputs_order: nodes description in the format described in the FrontReplacementFromConfigFileSubGraph.
        :param outputs_order: nodes description in the format described in the FrontReplacementFromConfigFileSubGraph.
        :param prefix: optional prefix of the node names. Is not used in the sub-graph match by points.
        """
        self._input_nodes_map = dict()
        self._output_nodes_map = dict()
        self._matched_nodes_names = matched_nodes
        # 'graph' and 'scope' are read by node_by_pattern() and the _add_*_node() helpers below, so they must be
        # stored before the port maps are populated.
        self.graph = graph
        self.custom_replacement_desc = replacement_desc
        self.scope = prefix

        for sub_graph_input_port, input_desc in enumerate(inputs_order):
            for node_pattern, node_in_port in input_desc:
                node = self.node_by_pattern(node_pattern)
                # A pattern that matches no node of this sub-graph instance contributes nothing to the input map.
                if node is not None:
                    self._add_input_node(node.id, node_in_port, sub_graph_input_port)

        for sub_graph_output_port, (node_pattern, out_port) in enumerate(outputs_order):
            node = self.node_by_pattern(node_pattern)
            if node is not None:
                self._add_output_node(node.id, out_port, sub_graph_output_port)

    def matched_nodes_names(self):
        """
        Returns list of node names in the matched sub-graph.
        :return: list of node names in the matched sub-graph.
        """
        return self._matched_nodes_names

    def inputs_count(self):
        """
        Returns number of inputs for the matched sub-graph. Only unique input tensors are considered, thus if the same
        tensor is consumed by two or more input nodes of the sub-graph it is counted only once.
        :return: Number of unique input tensors.
        """
        return len(self._input_nodes_map)

    def outputs_count(self):
        """
        Returns number of outputs for the matched sub-graph. Only unique output tensors are considered, thus if the same
        tensor is consumed by two or more nodes outside of the sub-graph it is counted only once.
        :return: Number of unique output tensors.
        """
        return len(self._output_nodes_map)

    def input_nodes(self, port: int):
        """
        Returns list of tuples where the first element is a Node of the sub-graph and the second is the input port for
        that node. Each node of this list gets the same input tensor through the input port with number 'port' of the
        sub-graph.

        For example, if the returned list requested for port 'portSG' is the following: [(NodeA, portA), (nodeB, portB)]
        then the same tensor is passed to node 'NodeA' as input with number 'portA' and node 'nodeB' as input with
        number 'portB' for the sub-graph input with number 'portSG'.
        :param port: input port of the sub-graph.
        :return: list describing nodes of the sub-graph getting tensor through the specified port.
        """
        return self._input_nodes_map[port]

    def single_input_node(self, port: int):
        """
        The function does the same as function 'input_nodes' but it relies on fact that there is just one node that
        gets input tensor for sub-graph input with number 'port', so it returns just tuple (Node, nodePort) or raises
        exception if the amount of nodes is not equal to 1.
        :param port: input port of the sub-graph.
        :return: tuple describing node of the sub-graph getting tensor through the specified port.
        """
        input_nodes = self.input_nodes(port)
        if len(input_nodes) != 1:
            raise Error('The amount of input nodes for port "{}" is not equal to 1. '.format(port) +
                        refer_to_faq_msg(33))
        return input_nodes[0]

    def output_node(self, port: int):
        """
        Returns a tuple where the first element is a Node of the sub-graph and the second is the output port of that
        node. The node produces output tensor through the output port with number 'port' of the sub-graph.
        :param port: output port of the sub-graph.
        :return: tuple describing node of the sub-graph producing sub-graph output tensor through the specified port.
        """
        return self._output_nodes_map[port]

    def node_by_pattern(self, pattern: str):
        """
        Returns Node from the list of sub-graph nodes matching node name regular expression 'pattern'. When the match
        has a non-empty scope, the pattern is prefixed with the scope (a '/' separator is inserted if the scope does
        not already end with one). If there are more than one nodes matched then the function raises exception.
        :param pattern: the regular expression for the node name.
        :return: matched Node or None if no node of the sub-graph matches the pattern.
        """
        # The match by points uses an empty prefix; indexing self.scope[-1] on it would raise IndexError.
        if self.scope:
            separator = '' if self.scope.endswith('/') else '/'
            pattern = self.scope + separator + pattern
        found_names = find_object_by_pattern(self._matched_nodes_names, pattern)
        if len(found_names) > 1:
            raise Error('The amount of nodes matched pattern "{}" is more than 1. '.format(pattern) +
                        refer_to_faq_msg(78))
        if not found_names:
            return None
        return Node(self.graph, found_names[0])

    def _add_input_node(self, node_name: str, node_port: int, sub_graph_input_port: int):
        """
        Registers (Node, node_port) as a consumer of the sub-graph input 'sub_graph_input_port'. Several nodes may
        share the same sub-graph input.
        """
        self._input_nodes_map.setdefault(sub_graph_input_port, []).append((Node(self.graph, node_name), node_port))

    def _add_output_node(self, node_name: str, node_port: int, sub_graph_output_port: int):
        """
        Registers (Node, node_port) as the producer of the sub-graph output 'sub_graph_output_port'. Raises Error if
        the output port has already been assigned.
        """
        if sub_graph_output_port in self._output_nodes_map:
            raise Error('Output node for port "{}" has already been specified. '.format(sub_graph_output_port) +
                        refer_to_faq_msg(34))
        self._output_nodes_map[sub_graph_output_port] = (Node(self.graph, node_name), node_port)
# TODO looks like this class is not needed. Can be implemented as pure functions.
class SubgraphMatcher(object):
    """
    Finds instances of the sub-graph described by a CustomReplacementDescriptor in a graph.
    """

    def __init__(self, replacement_descriptor: CustomReplacementDescriptor):
        self.replacement_desc = replacement_descriptor

    def _match_sub_graph_for_scope(self, graph: Graph, scope_pattern: str):
        """
        :param graph: networkx graph to find sub-graph in.
        :param scope_pattern: regular expression specifying sub-graph scope.
        :return: an object describing matched sub-graph or None if the sub-graph cannot be matched.
        """
        inputs_order = self.replacement_desc.get_inputs_description()
        outputs_order = self.replacement_desc.get_outputs_description()

        # Collect the node names once instead of re-querying the graph for every input pattern.
        node_names = list(graph.nodes())
        for list_nodes in inputs_order:
            for node_name_pattern, _port in list_nodes:
                if not find_object_by_pattern(node_names, '.*' + node_name_pattern):
                    log.info('Node "{}" does not exist in the graph. Failed to match sub-graph by scope "{}".'.format(
                        node_name_pattern, self.replacement_desc.id))
                    return None

        matched_nodes = nodes_matching_name_pattern(graph, scope_pattern)
        if len(matched_nodes) == 0:
            log.info('There are no instances of the sub-graph by scope "{}"'.format(scope_pattern))
            return None

        return SubgraphMatch(graph, self.replacement_desc, matched_nodes, inputs_order, outputs_order, scope_pattern)

    def _match_sub_graph_for_points(self, graph: Graph):
        """
        :param graph: networkx graph to find sub-graph in.
        :return: an object describing matched sub-graph or None if a start/end node is missing from the graph.
        """
        start_points = self.replacement_desc.get_internal_input_nodes(graph)
        end_points = self.replacement_desc.get_internal_output_nodes(graph)
        # check that start and end points exist in the graph
        for node_name in start_points + end_points:
            if node_name not in graph.nodes():
                log.info('Node "{}" does not exist in the graph. Failed to match sub-graph by points "{}".'.format(
                    node_name, self.replacement_desc.id))
                return None

        matched_nodes = sub_graph_between_nodes(graph, start_points, end_points)
        return SubgraphMatch(graph, self.replacement_desc, matched_nodes,
                             self.replacement_desc.get_inputs_description(),
                             self.replacement_desc.get_outputs_description(), '')

    def matched_sub_graph_instances(self, graph: Graph):
        """
        Generator to produce all instances of matched sub-graphs.
        Because this is a generator, the Error for an unsupported match kind is raised on the first iteration,
        not when the method is called.
        :param graph: graph to find instances in.
        :return: generator producing SubGraphMatch objects.
        """
        if self.replacement_desc.match_kind == 'points':  # instance is specified with lists of start/end nodes
            match = self._match_sub_graph_for_points(graph)
            if match is not None:
                yield match
        elif self.replacement_desc.match_kind == 'scope':  # instance is specified with a node name pattern
            for instance in self.replacement_desc.sub_graph_instances():
                match = self._match_sub_graph_for_scope(graph, instance)
                if match is not None:
                    yield match
        else:
            raise Error('Unsupported match kind "{}". Match kinds "points" or "scope" are supported only. '.format(
                self.replacement_desc.match_kind) +
                        refer_to_faq_msg(35))