2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
20 from re import compile, match
24 from mo.graph.graph import Node, Graph
25 from mo.utils.error import Error
26 from mo.utils.graph import nodes_matching_name_pattern, sub_graph_between_nodes
27 from mo.utils.utils import refer_to_faq_msg
class CustomReplacementDescriptor(object):
    """
    Base description of a custom sub-graph replacement read from the custom replacements configuration file.

    All replacement attributes are kept in the private dictionary '_replacement_desc'. Reading an attribute that is
    not found by normal lookup returns the value stored under that key; assigning is allowed only for keys that are
    already present in the dictionary. Sub-classes are registered for a 'match_kind' value with 'register_type' and
    instantiated through 'create_instance'.
    """
    # match_kind -> descriptor class. Shared by the whole class hierarchy and filled in by register_type().
    registered_types = dict()

    def __init__(self, replacement_id: str, attrs: dict = None):
        """
        Create class instance based on attrs dictionary which is read from the configuration file.
        :param replacement_id: id of the replacement.
        :param attrs: dictionary with the replacement attributes; None is treated as an empty dictionary. The
        dictionary is shallow-copied, so the caller's object is not modified.
        """
        # object.__setattr__ is used directly because our own __setattr__ only allows updating keys that already
        # exist in '_replacement_desc', and that dictionary does not exist yet at this point.
        super(CustomReplacementDescriptor, self).__setattr__('replacement_id', replacement_id)
        replacement_desc = dict(attrs) if attrs is not None else dict()
        super(CustomReplacementDescriptor, self).__setattr__('custom_attributes',
                                                             replacement_desc.setdefault('custom_attributes', {}))
        super(CustomReplacementDescriptor, self).__setattr__('_replacement_desc', replacement_desc)

    def __getattr__(self, k):
        """
        Resolve an attribute not found by normal lookup from the replacement description dictionary.
        :param k: attribute name.
        :return: value stored under key 'k' in the replacement description.
        :raises AttributeError: for dunder names and for '_replacement_desc' itself. Protocol probes made by
        copy/pickle/hasattr, and lookups on an object whose __init__ has not run, must not recurse or leak KeyError.
        :raises KeyError: if 'k' is not present in the replacement description.
        """
        if k == '_replacement_desc' or (k.startswith('__') and k.endswith('__')):
            raise AttributeError(k)
        return self._replacement_desc[k]

    def __setattr__(self, k, v):
        """
        Update an existing attribute of the replacement description.
        :raises AttributeError: if the attribute 'k' is not already present in the replacement description.
        """
        # you can assign only existing attributes
        if k not in self._replacement_desc:
            raise AttributeError("Attribute '{}' is not defined for the custom replacement '{}'".format(
                k, self.replacement_id))
        self._replacement_desc[k] = v

    def has(self, attr):
        """
        Check that attribute 'attr' is defined for the CustomReplacementDescriptor.
        :param attr: attribute to check.
        :return: True if the attribute exists and False otherwise.
        """
        return attr in self._replacement_desc

    @classmethod
    def register_type(cls, match_kind: str, class_type: object):
        """
        Register the descriptor class to be used for the given match kind. A second registration for the same match
        kind is ignored with a warning; the first registered class is kept.
        :param match_kind: match kind value as written in the configuration file.
        :param class_type: callable accepting (replacement_id, attrs) that creates the descriptor.
        """
        if match_kind in cls.registered_types:
            log.warning('Class for match kind "{}" is already registered'.format(match_kind))
        else:
            cls.registered_types[match_kind] = class_type

    @classmethod
    def create_instance(cls, match_kind: str, replacement_id: str, attrs: dict = None):
        """
        Fabric method to create proper object based on match_kind.
        :param match_kind: match kind.
        :param replacement_id: id of the replacement.
        :param attrs: optional attributes to be set.
        :return: object of the class registered for the match kind.
        :raises Error: if no class is registered for the match kind.
        """
        if attrs is None:
            attrs = {}
        if match_kind not in cls.registered_types:
            raise Error('No class registered for match kind "{}". Supported match kinds are "{}". '.format(
                match_kind, list(cls.registered_types.keys())) +
                        refer_to_faq_msg(65))
        return cls.registered_types[match_kind](replacement_id, attrs)

    def sub_graph_instances(self):
        """
        Return the list of sub-graph instance descriptions matched by this replacement. Must be overridden.
        """
        raise NotImplementedError("The function 'sub_graph_instances' must be implemented in the sub-class.")

    def get_config_file_representation(self):
        """
        Build the dictionary representation of the replacement suitable for dumping into the configuration file.
        :return: dictionary with the replacement attributes; 'op' is included only when it is defined.
        """
        result = {
            'match_kind': self.match_kind, 'instances': self.instances,
            'inputs': self.inputs, 'outputs': self.outputs,
            'custom_attributes': self.custom_attributes, 'id': self.id
        }
        if self.has('op'):
            result.update({'op': self.op})
        return result

    def get_inputs_description(self):
        """
        Returns description of inputs of the replacement. The format of inputs is the following: list of lists where
        each list contains information about nodes consuming the same tensor from outside of the graph. Each element
        of the list is a pair where first element is a regular expression for the name of the node in the sub-graph
        and the second is the input port of this node.
        :return: description of inputs or None if information about inputs is not available.
        """
        if 'inputs' not in self._replacement_desc:
            log.error("Information about inputs of layer with id '{}' is not available".format(self.replacement_id))
            return None
        return [[(inp['node'], inp['port']) for inp in input_desc]
                for input_desc in self._replacement_desc['inputs']]

    def get_outputs_description(self):
        """
        Returns description of outputs of the replacement. The format of outputs is the following: list of pairs
        where the first element of the pair is a regular expression for the name of the node that produces output
        of the sub-graph and the second is the output port of this node.
        :return: description of outputs or None if information about outputs is not available.
        """
        if 'outputs' not in self._replacement_desc:
            log.error("Information about outputs of layer with id '{}' is not available".format(self.replacement_id))
            return None
        return [(out['node'], out['port']) for out in self._replacement_desc['outputs']]

    def update_custom_replacement_attributes(self, graph: 'Graph'):
        """
        The function run specific functions to update attributes of the custom replacement description. Currently it
        updates information about input/output nodes. Must be overridden.
        :param graph: graph to operate on.
        :return: True if the update process completed successfully.
        """
        raise NotImplementedError(
            "The function 'update_custom_replacement_attributes' must be implemented in the sub-class.")

    def validate_data(self):
        """
        Validates layer description dictionary.
        :return: list of errors identified.
        """
        errors = list()
        if not self.has('id'):
            errors.append("Replacement id is not specified for custom replacement '{}'".format(self.replacement_id))
        if not self.has('instances') or self.instances == '':
            errors.append("Attribute 'instances' is not specified for replacement '{}'".format(self.replacement_id))
        if not self.has('match_kind'):
            errors.append("Replacement match type is not specified for replacement '{}'".format(self.replacement_id))
        return errors
class CustomReplacementDescriptorPoints(CustomReplacementDescriptor):
    """
    Class that is used to describe custom replacement which is a sub-graph specified by start and end points.
    The 'instances' attribute is a dictionary with the 'start_points' and 'end_points' lists of node names.
    """

    def __init__(self, replacement_id: str, attrs: dict = None):
        super().__init__(replacement_id, attrs)
        # Defaults are stored in the description dictionary itself. super().__setattr__ would resolve to
        # CustomReplacementDescriptor.__setattr__, which rejects keys that are not already present.
        self._replacement_desc.setdefault('include_inputs_to_sub_graph', True)
        self._replacement_desc.setdefault('include_outputs_to_sub_graph', True)

    def get_config_file_representation(self):
        """
        Build the dictionary representation of the replacement suitable for dumping into the configuration file.
        :return: dictionary with the replacement attributes; 'op' is included only when it is defined.
        """
        result = {
            'match_kind': self.match_kind, 'instances': self.instances,
            'custom_attributes': self.custom_attributes, 'id': self.id,
            'include_inputs_to_sub_graph': bool(self.include_inputs_to_sub_graph),
            'include_outputs_to_sub_graph': bool(self.include_outputs_to_sub_graph)
        }
        if self.has('op'):
            result.update({'op': self.op})
        return result

    def get_inputs_description(self):
        """
        :return: one single-element input group per start point: [[('^<start_point>$', 0)], ...].
        """
        return [[('^' + node_name + '$', 0)] for node_name in self.instances['start_points']]

    def get_outputs_description(self):
        """
        :return: one output per end point: [('^<end_point>$', 0), ...].
        """
        return [('^' + node_name + '$', 0) for node_name in self.instances['end_points']]

    def get_internal_input_nodes(self, graph: Graph):
        """
        Gets list of node names getting input from outside of the sub-graph. This function checks whether input nodes
        specified in the configuration file should be added to the sub-graph or not. If they should not be added to the
        sub-graph then input nodes of the sub-graph are children of these nodes.
        :param graph: graph to operate on.
        :return: list of input node names.
        """
        if self.include_inputs_to_sub_graph:
            return self.instances['start_points']

        log.debug('Do not include inputs to sub-graph for replacement with id {}'.format(self.replacement_id))
        start_nodes = list({out_node_name
                            for start_node in self.instances['start_points']
                            for _, out_node_name in graph.out_edges(start_node)})
        log.debug('New inputs are: {}'.format(start_nodes))
        return start_nodes

    def get_internal_output_nodes(self, graph: Graph):
        """
        Gets list of node names producing output outside of the sub-graph. This function checks whether output nodes
        specified in the configuration file should be added to the sub-graph or not. If they should not be added to the
        sub-graph then output nodes of the sub-graph are parents of these nodes.
        :param graph: graph to operate on.
        :return: list of output node names.
        """
        if self.include_outputs_to_sub_graph:
            return self.instances['end_points']

        log.debug('Do not include outputs of sub-graph for replacement with id {}'.format(self.replacement_id))
        end_nodes = list({in_node_name
                          for end_node in self.instances['end_points']
                          for in_node_name, _ in graph.in_edges(end_node)})
        log.debug('New outputs are: {}'.format(end_nodes))
        return end_nodes

    def update_custom_replacement_attributes(self, graph: Graph):
        """
        Fill in the 'inputs' and 'outputs' attributes (only those not already defined) from the sub-graph located
        between the start and end points.
        :param graph: graph to operate on.
        :raises Error: if 'instances' is missing or is not a dictionary.
        """
        if not self.has('instances'):
            raise Error("No instance(s) is(are) defined for the custom replacement '{}'. ".format(self.replacement_id) +
                        refer_to_faq_msg(66))
        if not isinstance(self.instances, dict):
            raise Error("The instance must be a single dictionary for the custom replacement with id '{}'. ".format(
                self.replacement_id) +
                        refer_to_faq_msg(67))

        start_points = self.get_internal_input_nodes(graph)
        end_points = self.get_internal_output_nodes(graph)

        # a set gives O(1) membership tests in the loops over all edges and nodes below
        matched_nodes = set(sub_graph_between_nodes(graph, start_points, end_points))
        output_tensors = set()
        # key is the input tensor name "<src_node>:<out_port>", value is the list of (node pattern, input port)
        input_nodes_mapping = dict()
        for src_node_name, dst_node_name, edge_attrs in graph.edges(data=True):
            src_inside = src_node_name in matched_nodes
            dst_inside = dst_node_name in matched_nodes
            if not src_inside and dst_inside:
                # edge outside sub-graph into sub-graph
                tensor_name = src_node_name + ":" + str(edge_attrs['out'])
                input_nodes_mapping.setdefault(tensor_name, list()).append(
                    ('^' + dst_node_name + '$', edge_attrs['in']))
            elif src_inside and not dst_inside:
                # edge from inside sub-graph to outside sub-graph
                dst_node = graph.node[dst_node_name]
                output_tensors.add(('^' + dst_node['pb'].input[edge_attrs['in']] + '$', edge_attrs['out']))

        for node_name in graph.nodes():
            if node_name not in matched_nodes:
                continue
            node = Node(graph, node_name)
            if len(node.out_nodes()) == 0 and node['pb'].op != 'Const':
                log.debug("Node {} doesn't have output edges. Consider it output".format(node_name))
                output_tensors.add(('^' + node_name + '$', 0))

        if not self.has('inputs'):
            self._replacement_desc['inputs'] = [[{'node': desc[0], 'port': desc[1]} for desc in inp]
                                                for inp in sorted(input_nodes_mapping.values())]
            log.debug('Updated inputs of sub-graph for instance "{}"'.format(self.instances))

        if not self.has('outputs'):
            self._replacement_desc['outputs'] = [{'node': node, 'port': port} for node, port in sorted(output_tensors)]
            log.debug('Updated outputs of sub-graph for instance "{}"'.format(self.instances))

    def sub_graph_instances(self):
        """
        :return: list with the single 'instances' dictionary of this replacement.
        """
        return [self.instances]


CustomReplacementDescriptor.register_type('points', CustomReplacementDescriptorPoints)
class CustomReplacementDescriptorScope(CustomReplacementDescriptor):
    """
    Class that is used to describe custom layer which is a sub-graph specified by scope name.
    The 'instances' attribute is a list of node name patterns, one per sub-graph instance.
    """

    def __init__(self, replacement_id: str, attrs: dict = None):
        super().__init__(replacement_id, attrs)

    def update_custom_replacement_attributes(self, graph: Graph):
        """
        Fill in the 'inputs' and 'outputs' attributes (only those not already defined) from the sub-graph matched by
        the first instance pattern. Node patterns are generated relative to that scope.
        :param graph: graph to operate on.
        :raises Error: if no instances are defined or 'instances' is a plain string instead of a list of patterns.
        """
        if not self.has('instances') or len(self.instances) == 0:
            raise Error("No instances are defined for replacement with id '{}'. ".format(self.replacement_id) +
                        refer_to_faq_msg(68))
        if isinstance(self.instances, str):
            # a string would otherwise be indexed character by character below
            raise Error("The instances must be a list of scope patterns for replacement with id '{}'. ".format(
                self.replacement_id) + refer_to_faq_msg(68))

        pattern = self.instances[0]  # use the first instance pattern to find input/output nodes patterns
        # TODO verify that all instances will produce the same sub-graph
        # a set gives O(1) membership tests in the loops over all edges and nodes below
        matched_nodes = set(nodes_matching_name_pattern(graph, pattern))

        output_tensors = set()
        # key is the input tensor name "<src_node>:<out_port>", value is the list of (node pattern, input port)
        input_nodes_mapping = dict()
        for src_node_name, dst_node_name, edge_attrs in graph.edges(data=True):
            src_inside = src_node_name in matched_nodes
            dst_inside = dst_node_name in matched_nodes
            if not src_inside and dst_inside:
                # edge outside sub-graph into sub-graph
                tensor_name = src_node_name + ":" + str(edge_attrs['out'])
                input_nodes_mapping.setdefault(tensor_name, list()).append(
                    (generate_pattern_for_node(graph, pattern, dst_node_name), edge_attrs['in']))
            elif src_inside and not dst_inside:
                # edge from inside sub-graph to outside sub-graph
                dst_node = graph.node[dst_node_name]
                output_tensors.add(
                    (generate_pattern_for_node(graph, pattern, dst_node['pb'].input[edge_attrs['in']]),
                     edge_attrs['out']))

        for node_name in graph.nodes():
            if node_name not in matched_nodes:
                continue
            node = Node(graph, node_name)
            if len(node.out_nodes()) == 0 and node['pb'].op != 'Const':
                log.debug("Node {} doesn't have output edges. Consider it output".format(node_name))
                output_tensors.add((generate_pattern_for_node(graph, pattern, node_name), 0))

        if not self.has('inputs'):
            self._replacement_desc['inputs'] = [[{'node': desc[0], 'port': desc[1]} for desc in inp]
                                                for inp in sorted(input_nodes_mapping.values())]
            log.debug('Updated inputs of sub-graph for instance "{}"'.format(self.instances))

        if not self.has('outputs'):
            self._replacement_desc['outputs'] = [{'node': node, 'port': port} for node, port in sorted(output_tensors)]
            log.debug('Updated outputs of sub-graph for instance "{}"'.format(self.instances))

    def sub_graph_instances(self):
        """
        :return: the list of scope patterns of this replacement.
        """
        return self.instances


CustomReplacementDescriptor.register_type('scope', CustomReplacementDescriptorScope)
class CustomReplacementDescriptorGeneral(CustomReplacementDescriptor):
    """
    Custom replacement that is not bound to a particular sub-graph; only 'id' and 'match_kind' are required.
    """

    def __init__(self, replacement_id: str, attrs: dict = None):
        super().__init__(replacement_id, attrs)

    def validate_data(self):
        """
        Validates layer description dictionary.
        :return: list of errors identified, in the order 'id', 'match_kind'.
        """
        required_attributes = (
            ('id', "Replacement id is not specified for custom replacement '{}'"),
            ('match_kind', "Replacement match type is not specified for replacement '{}'"),
        )
        return [message.format(self.replacement_id)
                for attr_name, message in required_attributes
                if not self.has(attr_name)]


CustomReplacementDescriptor.register_type(match_kind='general', class_type=CustomReplacementDescriptorGeneral)
def parse_custom_replacement_config_file(file_name: str):
    """
    Reads custom replacement configuration file file_name.
    :param file_name: name of the file to read from.
    :return: list of CustomReplacementDescriptor sub-class instances, one per entry of the file, in file order.
    :raises Error: if the file does not exist, cannot be read or parsed, does not contain a list of objects, or any
    entry fails validation.
    """
    if not os.path.exists(file_name):
        raise Error("Custom replacements configuration file '{}' does not exist. ".format(file_name) +
                    refer_to_faq_msg(69))
    try:
        # JSON text is UTF-8 (RFC 8259); do not depend on the platform locale. 'utf-8-sig' also accepts a BOM.
        with open(file_name, 'r', encoding='utf-8-sig') as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:
        raise Error("Failed to parse custom replacements configuration file '{}': {}. ".format(file_name, exc) +
                    refer_to_faq_msg(70)) from exc

    if not isinstance(data, list):
        raise Error("Custom replacements configuration file '{}' must contain a list of custom replacements. ".format(
            file_name) + refer_to_faq_msg(70))

    result = list()
    validation_errors = list()
    for attrs in data:
        if not isinstance(attrs, dict):
            raise Error('One of the custom replacements in the configuration file "{}" is not a JSON object. '.format(
                file_name) + refer_to_faq_msg(71))
        if 'id' not in attrs:
            raise Error('One of the custom replacements in the configuration file "{}" does not contain attribute '
                        '"id". '.format(file_name) +
                        refer_to_faq_msg(71))
        if 'match_kind' not in attrs:
            raise Error('One of the custom replacements in the configuration file "{}" does not contain attribute '
                        '"match_kind". Possible values are "points", "scope" and "general". '.format(file_name) +
                        refer_to_faq_msg(71))
        desc = CustomReplacementDescriptor.create_instance(attrs['match_kind'], attrs['id'], attrs)
        validation_errors.extend(desc.validate_data())
        result.append(desc)

    if validation_errors:
        raise Error("File '{}' validation failed:\n{}. ".format(file_name, "\n".join(validation_errors)) +
                    refer_to_faq_msg(72))
    return result
def generate_pattern_for_node(graph: Graph, sub_graph_pattern: str, node_name: str):
    """
    Build a regular expression identifying node 'node_name' relative to the scope matched by 'sub_graph_pattern'.

    The shortest prefix of the node name components (joined with '/' and ending with '/') that matches
    'sub_graph_pattern' is treated as the scope. The remaining components plus '$' form the returned suffix.
    :param graph: graph to operate on.
    :param sub_graph_pattern: regular expression for the scope (sub-graph) name.
    :param node_name: full name of the node inside the scope.
    :return: 'node_name' itself if 'sub_graph_pattern' is empty, otherwise the suffix pattern.
    :raises RuntimeError: if the node name does not match the scope pattern, or the suffix does not identify exactly
    one node of the scope.
    """
    if sub_graph_pattern == '':
        return node_name

    node_name_components = node_name.split("/")
    compiled_pattern = compile(sub_graph_pattern)
    cur_name = ''
    matched_index = None  # index of the node name component to start new pattern from
    for index, component in enumerate(node_name_components):
        cur_name += component + "/"
        if match(compiled_pattern, cur_name):
            matched_index = index
            break
    if matched_index is None:
        raise RuntimeError('Node name "{}" does not match pattern "{}"'.format(node_name, sub_graph_pattern))

    # sub_graph_pattern is non-empty here (handled above), so only the trailing slash has to be ensured
    if not sub_graph_pattern.endswith('/'):
        sub_graph_pattern += '/'

    sub_graph_nodes = nodes_matching_name_pattern(graph, sub_graph_pattern)
    name_suffix = '/'.join(node_name_components[matched_index + 1:]) + '$'
    full_pattern = sub_graph_pattern + name_suffix
    if sum(1 for node in sub_graph_nodes if match(full_pattern, node)) == 1:
        return name_suffix

    raise RuntimeError('The pattern that uniquely identifies node "{}" using sub-graph pattern "{}" has not been found'.
                       format(node_name, sub_graph_pattern))