Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / utils / custom_replacement_config.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import json
18 import logging as log
19 import os
20 from re import compile, match
21
22 import networkx as nx
23
24 from mo.graph.graph import Node, Graph
25 from mo.utils.error import Error
26 from mo.utils.graph import nodes_matching_name_pattern, sub_graph_between_nodes
27 from mo.utils.utils import refer_to_faq_msg
28
29
30 class CustomReplacementDescriptor(object):
31     registered_types = dict()
32
33     def __init__(self, replacement_id: str, attrs: dict = None):
34         """
35         Create class instance based on attrs dictionary which is read from the configuration file.
36         :param attrs:
37         """
38         super(CustomReplacementDescriptor, self).__setattr__('replacement_id', replacement_id)
39         if attrs is not None:
40             super(CustomReplacementDescriptor, self).__setattr__('custom_attributes',
41                                                                  attrs.setdefault('custom_attributes', {}))
42             super(CustomReplacementDescriptor, self).__setattr__('_replacement_desc', attrs.copy())
43
44     def __getattr__(self, k):
45         return self._replacement_desc[k]
46
47     def __setattr__(self, k, v):
48         # you can assign only existing attributes
49         if k not in self._replacement_desc:
50             raise AttributeError
51         self._replacement_desc[k] = v
52
53     def has(self, attr):
54         """
55         Check that attribute 'attr' is defined for the CustomReplacementDescriptor.
56         :param attr: attribute to check.
57         :return: True if the attribute exists and False otherwise.
58         """
59         return attr in self._replacement_desc
60
61     @classmethod
62     def register_type(cls, match_kind: str, class_type: object):
63         if match_kind in cls.registered_types:
64             log.warning('Class for match kind "{}" is already registered'.format(match_kind))
65         else:
66             cls.registered_types[match_kind] = class_type
67
68     @classmethod
69     def create_instance(cls, match_kind: str, replacement_id: str, attrs: dict = None):
70         """
71         Fabric method to create proper object based on match_kind.
72         :param match_kind: match kind.
73         :param replacement_id: id of the replacement.
74         :param attrs: optional attributes to be set.
75         :return: object of the sub-class of the CustomLayerDescriptor class or None if the match kind is not registered.
76         """
77         if attrs is None:
78             attrs = dict()
79         if match_kind in cls.registered_types:
80             return cls.registered_types[match_kind](replacement_id, attrs)
81         else:
82             raise Error('No class registered for match kind "{}". Supported match kinds are "{}". '.format(
83                 match_kind, list(cls.registered_types.keys())) +
84                         refer_to_faq_msg(65))
85
86     def sub_graph_instances(self):
87         raise Exception("The function 'get_sub_graph_instances' must be implemented in the sub-class.")
88
89     def get_config_file_representation(self):
90         result = {
91             'match_kind': self.match_kind, 'instances': self.instances,
92             'inputs': self.inputs, 'outputs': self.outputs,
93             'custom_attributes': self.custom_attributes, 'id': self.id
94         }
95         if self.has('op'):
96             result.update({'op': self.op})
97         return result
98
99     def get_inputs_description(self):
100         """
101         Returns description of inputs of the layer with id 'layer_id'. The format of inputs is the following: list of
102         lists where each list contains information about nodes consuming the same tensor from outside of the graph. Each
103         element of the list is a pair where first element is a regular expression for the name of the node in the
104         sub-graph and the second is the input port of this node.
105         :return: description of inputs or None if layer with such id is not registered or information about inputs is
106         not available.
107         """
108         if 'inputs' not in self._replacement_desc:
109             log.error("Information about inputs of layer with id '{}' is not available".format(self.replacement_id))
110             return None
111         result = list()
112         for index, input_desc in enumerate(self._replacement_desc['inputs']):
113             result.append([(inp['node'], inp['port']) for inp in input_desc])
114         return result
115
116     def get_outputs_description(self):
117         """
118         Returns description of outputs of the layer with id 'layer_id'. The format of outputs is the following: list of
119         pairs where the first element of the pair is a regular expression for the name of the node that produces output
120         of the sub-graph and the second is the output port of this node.
121         :return: description of outputs or None if layer with such id is not registered or information about outputs is
122         not available.
123         """
124         if 'outputs' not in self._replacement_desc:
125             log.error("Information about outputs of layer with id '{}' is not available")
126             return None
127         return [(out['node'], out['port']) for out in self._replacement_desc['outputs']]
128
129     def update_custom_replacement_attributes(self, graph: Graph):
130         """
131         The function run specific functions to update attributes of the custom replacement description. Currently it
132         updates information about input/output nodes.
133         :param graph: graph to operate on.
134         :return: True if the update process completed successfully.
135         """
136         raise Exception("The function 'update_custom_layer_attributes' must be implemented in the sub-class.")
137
138     def validate_data(self):
139         """
140         Validates layer description dictionary.
141         :return: list of errors identified.
142         """
143         errors = list()
144         if not self.has('id'):
145             errors.append("Replacement id is not specified for custom replacement '{}'".format(self.replacement_id))
146         if not self.has('instances') or self.instances == '':
147             errors.append("Attribute 'instances' is not specified for replacement '{}'".format(self.replacement_id))
148         if not self.has('match_kind'):
149             errors.append("Replacement match type is not specified for replacement '{}'".format(self.replacement_id))
150         return errors
151
152
153 class CustomReplacementDescriptorPoints(CustomReplacementDescriptor):
154     """
155     Class that is used to describe custom replacement which is a sub-graph specified by start and end points.
156     """
157
158     def __init__(self, replacement_id: str, attrs: dict = None):
159         super().__init__(replacement_id, attrs)
160         if not self.has('include_inputs_to_sub_graph'):
161             super(CustomReplacementDescriptorPoints, self).__setattr__('include_inputs_to_sub_graph', True)
162         if not self.has('include_outputs_to_sub_graph'):
163             super(CustomReplacementDescriptorPoints, self).__setattr__('include_outputs_to_sub_graph', True)
164
165     def get_config_file_representation(self):
166         result = {
167             'match_kind': self.match_kind, 'instances': self.instances,
168             'custom_attributes': self.custom_attributes, 'id': self.id,
169             'include_inputs_to_sub_graph': bool(self.include_inputs_to_sub_graph),
170             'include_outputs_to_sub_graph': bool(self.include_outputs_to_sub_graph)
171         }
172         if self.has('op'):
173             result.update({'op': self.op})
174         return result
175
176     def get_inputs_description(self):
177         return [[('^' + node_name + '$', 0)] for node_name in self.instances['start_points']]
178
179     def get_outputs_description(self):
180         return [('^' + node_name + '$', 0) for node_name in self.instances['end_points']]
181
182     def get_internal_input_nodes(self, graph: Graph):
183         """
184         Gets list of node names getting input from outside of the sub-graph. This function checks whether input nodes
185         specified in the configuration file should be added to the sub-graph or not. If they should not be added to the
186         sub-graph then input nodes of the sub-graph are children of these nodes.
187         :param graph: graph to operate on.
188         :return: list of input node names.
189         """
190         if not self.include_inputs_to_sub_graph:
191             log.debug('Do not include inputs to sub-graph for replacement with id {}'.format(self.replacement_id))
192             new_start_nodes = set()
193             for start_node in self.instances['start_points']:
194                 for _, out_node_name in graph.out_edges(start_node):
195                     new_start_nodes.add(out_node_name)
196             start_nodes = list(new_start_nodes)
197             log.debug('New inputs are: {}'.format(start_nodes))
198             return start_nodes
199         else:
200             return self.instances['start_points']
201
202     def get_internal_output_nodes(self, graph: Graph):
203         """
204         Gets list of node names producing output outside of the sub-graph. This function checks whether output nodes
205         specified in the configuration file should be added to the sub-graph or not. If they should not be added to the
206         sub-graph then output nodes of the sub-graph are parents of these nodes.
207         :param graph: graph to operate on.
208         :return: list of output node names.
209         """
210         if not self.include_outputs_to_sub_graph:
211             log.debug('Do not include outputs of sub-graph for replacement with id {}'.format(self.replacement_id))
212             new_end_nodes = set()
213             for end_node in self.instances['end_points']:
214                 for in_node_name, _ in graph.in_edges(end_node):
215                     new_end_nodes.add(in_node_name)
216             end_nodes = list(new_end_nodes)
217             log.debug('New outputs are: {}'.format(end_nodes))
218             return end_nodes
219         else:
220             return self.instances['end_points']
221
222     def update_custom_replacement_attributes(self, graph: Graph):
223         if not self.has('instances'):
224             raise Error("No instance(s) is(are) defined for the custom replacement '{}'. ".format(self.replacement_id) +
225                         refer_to_faq_msg(66))
226         if not isinstance(self.instances, dict):
227             raise Error("The instance must be a single dictionary for the custom replacement with id '{}'. ".format(
228                 self.replacement_id) +
229                         refer_to_faq_msg(67))
230
231         start_points = self.get_internal_input_nodes(graph)
232         end_points = self.get_internal_output_nodes(graph)
233
234         matched_nodes = sub_graph_between_nodes(graph, start_points, end_points)
235         output_tensors = set()
236         input_nodes_mapping = dict()  # key is the input tensor name, value is the pair: (input_port, output_node_name)
237         for src_node_name, dst_node_name, edge_attrs in graph.edges(data=True):
238             dst_node = graph.node[dst_node_name]
239
240             # edge outside sub-graph into sub-graph
241             if (src_node_name not in matched_nodes) and (dst_node_name in matched_nodes):
242                 tensor_name = src_node_name + ":" + str(edge_attrs['out'])
243                 if tensor_name not in input_nodes_mapping:
244                     input_nodes_mapping[tensor_name] = list()
245                 input_nodes_mapping[tensor_name].append(('^' + dst_node_name + '$', edge_attrs['in']))
246
247             # edge from inside sub-graph to outside sub-graph
248             if (src_node_name in matched_nodes) and (dst_node_name not in matched_nodes):
249                 output_tensors.add(('^' + dst_node['pb'].input[edge_attrs['in']] + '$', edge_attrs['out']))
250
251         for node_name in graph.nodes():
252             node = Node(graph, node_name)
253             if node_name in matched_nodes and len(node.out_nodes()) == 0 and node['pb'].op != 'Const':
254                 log.debug("Node {} doesn't have output edges. Consider it output".format(node_name))
255                 output_tensors.add(('^' + node_name + '$', 0))
256
257         if not self.has('inputs'):
258             self._replacement_desc['inputs'] = [[{'node': desc[0], 'port': desc[1]} for desc in inp]
259                                                 for inp in sorted(input_nodes_mapping.values())]
260             log.debug('Updated inputs of sub-graph for instance "{}"'.format(self.instances))
261
262         if not self.has('outputs'):
263             self._replacement_desc['outputs'] = [{'node': node, 'port': port} for node, port in sorted(output_tensors)]
264             log.debug('Updated outputs of sub-graph for instance "{}"'.format(self.instances))
265
266     def sub_graph_instances(self):
267         return [self.instances]
268
269
270 CustomReplacementDescriptor.register_type('points', CustomReplacementDescriptorPoints)
271
272
273 class CustomReplacementDescriptorScope(CustomReplacementDescriptor):
274     """
275     Class that is used to describe custom layer which is a sub-graph specified by scope name.
276     """
277
278     def __init__(self, replacement_id: str, attrs: dict = None):
279         super().__init__(replacement_id, attrs)
280
281     def update_custom_replacement_attributes(self, graph: Graph):
282         if not self.has('instances') or len(self.instances) == 0:
283             raise Error("No instances are defined for replacement with id '{}'. ".format(self.replacement_id) +
284                         refer_to_faq_msg(68))
285
286         pattern = self.instances[0]  # use the first instance pattern to find input/output nodes patterns
287         # TODO verify that all instances will produce the same sub-graph
288         matched_nodes = nodes_matching_name_pattern(graph, pattern)
289
290         output_tensors = set()
291         input_nodes_mapping = dict()  # key is the input tensor name, value is the pair: (input_port, output_node_name)
292         for src_node_name, dst_node_name, edge_attrs in graph.edges(data=True):
293             dst_node = graph.node[dst_node_name]
294
295             # edge outside sub-graph into sub-graph
296             if (src_node_name not in matched_nodes) and (dst_node_name in matched_nodes):
297                 tensor_name = src_node_name + ":" + str(edge_attrs['out'])
298                 if tensor_name not in input_nodes_mapping:
299                     input_nodes_mapping[tensor_name] = list()
300                 input_nodes_mapping[tensor_name].append((generate_pattern_for_node(graph, pattern, dst_node_name),
301                                                          edge_attrs['in']))
302
303             # edge from inside sub-graph to outside sub-graph
304             if (src_node_name in matched_nodes) and (dst_node_name not in matched_nodes):
305                 output_tensors.add(
306                     (generate_pattern_for_node(graph, pattern, dst_node['pb'].input[edge_attrs['in']]),
307                      edge_attrs['out']))
308
309         for node_name in graph.nodes():
310             node = Node(graph, node_name)
311             if node_name in matched_nodes and len(node.out_nodes()) == 0 and node['pb'].op != 'Const':
312                 log.debug("Node {} doesn't have output edges. Consider it output".format(node_name))
313                 output_tensors.add((generate_pattern_for_node(graph, pattern, node_name), 0))
314
315         if not self.has('inputs'):
316             self._replacement_desc['inputs'] = [[{'node': desc[0], 'port': desc[1]} for desc in inp]
317                                                 for inp in sorted(input_nodes_mapping.values())]
318             log.debug('Updated inputs of sub-graph for instance "{}"'.format(self.instances))
319
320         if not self.has('outputs'):
321             self._replacement_desc['outputs'] = [{'node': node, 'port': port} for node, port in sorted(output_tensors)]
322             log.debug('Updated outputs of sub-graph for instance "{}"'.format(self.instances))
323
324     def sub_graph_instances(self):
325         return self.instances
326
327
328 CustomReplacementDescriptor.register_type('scope', CustomReplacementDescriptorScope)
329
330
331 class CustomReplacementDescriptorGeneral(CustomReplacementDescriptor):
332     def __init__(self, replacement_id: str, attrs: dict = None):
333         super().__init__(replacement_id, attrs)
334
335     def validate_data(self):
336         """
337         Validates layer description dictionary.
338         :return: list of errors identified.
339         """
340         errors = list()
341         if not self.has('id'):
342             errors.append("Replacement id is not specified for custom replacement '{}'".format(self.replacement_id))
343         if not self.has('match_kind'):
344             errors.append("Replacement match type is not specified for replacement '{}'".format(self.replacement_id))
345         return errors
346
347
348 CustomReplacementDescriptor.register_type('general', CustomReplacementDescriptorGeneral)
349
350
351 def parse_custom_replacement_config_file(file_name: str):
352     """
353     Reads custom replacement configuration file file_name.
354     :param file_name: name of the file to read from.
355     :return: The dictionary where key is the layer id and value is an instance of the CustomLayerDescriptor object.
356     """
357     if not os.path.exists(file_name):
358         raise Error("Custom replacements configuration file '{}' does not exist. ".format(file_name) +
359                     refer_to_faq_msg(69))
360     try:
361         with open(file_name, 'r') as f:
362             data = json.load(f)
363     except Exception as exc:
364         raise Error("Failed to parse custom replacements configuration file '{}': {}. ".format(file_name, exc) +
365                     refer_to_faq_msg(70)) from exc
366
367     result = list()
368     validation_errors = list()
369     for attrs in data:
370         if 'id' not in attrs:
371             raise Error('One of the custom replacements in the configuration file "{}" does not contain attribute '
372                         '"id". '.format(file_name) +
373                         refer_to_faq_msg(71))
374         if 'match_kind' not in attrs:
375             raise Error('One of the custom replacements in the configuration file "{}" does not contain attribute '
376                         '"match_kind". Possible values are "points", "scope" and "general". '.format(file_name) +
377                         refer_to_faq_msg(71))
378         desc = CustomReplacementDescriptor.create_instance(attrs['match_kind'], attrs['id'], attrs)
379         validation_errors.extend(desc.validate_data())
380         result.append(desc)
381     if len(validation_errors) > 0:
382         raise Error("File '{}' validation failed:\n{}. ".format(file_name, "\n".join(validation_errors)) +
383                     refer_to_faq_msg(72))
384     return result
385
386
387 def generate_pattern_for_node(graph: Graph, sub_graph_pattern: str, node_name: str):
388     if sub_graph_pattern == '':
389         return node_name
390     node_name_components = node_name.split("/")
391     cur_name = ''
392     matched_index = None  # index of the node name component to start new pattern from
393     compiled_pattern = compile(sub_graph_pattern)
394     for index in range(0, len(node_name_components)):
395         cur_name += node_name_components[index] + "/"
396         if match(compiled_pattern, cur_name):
397             matched_index = index
398             break
399     if matched_index is None:
400         raise RuntimeError('Node name "{}" does not match pattern "{}"'.format(node_name, sub_graph_pattern))
401
402     if sub_graph_pattern == '' or sub_graph_pattern[-1] != '/':
403         sub_graph_pattern += '/'
404
405     sub_graph_nodes = nodes_matching_name_pattern(graph, sub_graph_pattern)
406     name_suffix = '/'.join(node_name_components[matched_index + 1:]) + '$'
407     if len([node for node in sub_graph_nodes if match(sub_graph_pattern + name_suffix, node)]) == 1:
408         return name_suffix
409
410     raise RuntimeError('The pattern that uniquely identifies node "{}" using sub-graph pattern "{}" has not been found'.
411                        format(node_name, sub_graph_pattern))