Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / graph / graph.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import collections
18 import logging as log
19 from copy import deepcopy
20
21 import networkx as nx
22 import numpy as np
23
24 from mo.graph.port import Port
25 from mo.utils.error import Error
26 from mo.utils.utils import refer_to_faq_msg, deprecated_api, shrink_str_value
27
28
29 def dict_to_ordered_dict(d: dict):
30     return collections.OrderedDict(sorted(d.items(), key=lambda t: t[0]))
31
32
33 class Node:
34     def __init__(self, graph, node: str):
35         if node not in graph:
36             raise AttributeError("Attempt to access node {} that not in graph".format(node))
37
38         super(Node, self).__setattr__('graph', graph)
39         super(Node, self).__setattr__('node', node)  # obsolete
40         super(Node, self).__setattr__('id', node)
41
42     def __str__(self, max_length: int = 100):
43         node_dict = self.graph.node[self.id]
44         print_dict = {k: v if k != 'value' else shrink_str_value(v, max_symbols=max_length) for k, v in node_dict.items()}
45         return str(print_dict)
46
47     def __setattr__(self, k, v):
48         # you can assign only existing attributes
49         attrs = self.graph.node[self.node]
50         if not k in attrs:
51             raise AttributeError
52         attrs[k] = v
53
54     def __getattr__(self, k):
55         # hope it raises AttributeError if k is not in the dict
56         return self.graph.node[self.node][k]
57
58     def __getitem__(self, k):
59         return self.graph.node[self.node][k]
60
61     def __setitem__(self, k, v):
62         self.graph.node[self.node][k] = v
63
64     def __contains__(self, k):
65         return self.has(k)
66
67     def add_input_port(self, idx):
68         if not self.has_valid('_in_ports'):
69             Node(self.graph, self.id)['_in_ports'] = set()
70         if idx in self.in_ports():
71             raise Error("Input port with {} index already exists for {} node.".format(idx, self.name))
72         self._in_ports.add(idx)
73
74     def add_output_port(self, idx):
75         if not self.has_valid('_out_ports'):
76             Node(self.graph, self.id)['_out_ports'] = set()
77         if idx in self.out_ports():
78             raise Error("Output port with {} index already exists for {} node.".format(idx, self.name))
79         self._out_ports.add(idx)
80
81     def in_port(self, idx=None) -> Port:
82         if not self.has_valid('_in_ports'):
83             raise Error("Operation {} {} has no _in_ports attribute", self.op, self.name)
84         if idx not in self._in_ports:
85             raise Error("Input port with index {} is not in node {}".format(idx, self.name))
86         return Port(node=self, idx=idx, type='in')
87
88     def in_ports(self):
89         if not self.has_valid('_in_ports'):
90             raise Error("Operation {} {} has no _in_ports attribute", self.op, self.name)
91         return dict_to_ordered_dict({idx: self.in_port(idx) for idx in self._in_ports})
92
93     def out_port(self, idx=None) -> Port:
94         if not self.has_valid('_out_ports'):
95             raise Error("Operation {} {} has no _out_ports attribute", self.op, self.name)
96         if idx not in self._out_ports:
97             raise Error("Output port with index {} is not in node {}".format(idx, self.name))
98         return Port(node=self, idx=idx, type='out')
99
100     def out_ports(self):
101         if not self.has_valid('_out_ports'):
102             raise Error("Operation {} {} has no _out_ports attribute", self.op, self.name)
103         return dict_to_ordered_dict({idx: self.out_port(idx) for idx in self._out_ports})
104
105     def has_port(self, port_type, idx):
106         assert port_type in ['in', 'out'], "Invalid usage of has_port method"
107
108         if port_type == 'in':
109             return self.has_valid('_in_ports') and idx in self.in_ports()
110         else:
111             return self.has_valid('_out_ports') and idx in self.out_ports()
112
113     def attrs(self):
114         return self.graph.node[self.node]
115
116     def has(self, k):
117         return k in self.graph.node[self.node]
118
119     def has_valid(self, k):
120         return self.has(k) and not self.graph.node[self.node][k] is None
121
122     def has_and_set(self, k):
123         return self.has_valid(k) and self[k]
124
125     def in_nodes_edges(self, control_flow: bool=False):
126         return dict_to_ordered_dict({x[1]['in']: (Node(self.graph, x[0]), x[1]) for x in
127                                      self.get_inputs(control_flow=control_flow)})
128
129     def in_nodes(self, control_flow: bool=False):
130         assert self.has('kind')  # TODO: remove as it always exists
131         assert self.kind in ['op', 'data'] # TODO: remove as it always exists
132         if self.kind == 'op':
133             return dict_to_ordered_dict({x[1]['in']: Node(self.graph, x[0]) for x in
134                                          self.get_inputs(control_flow=control_flow)})
135         elif self.kind == 'data':
136             return [Node(self.graph, n) for n, d in self.get_inputs(control_flow=control_flow)]
137
138     def in_node(self, key=0, control_flow: bool=False):
139         return self.in_nodes(control_flow=control_flow)[key]
140
141     def in_edges(self, control_flow: bool=False):
142         assert self.has('kind')
143         assert self.kind in ['op', 'data']
144         if self.kind == 'op':
145             return dict_to_ordered_dict({x[1]['in']: x[1] for x in self.get_inputs(control_flow=control_flow)})
146         elif self.kind == 'data':
147             return [d for n, d in self.get_inputs(control_flow=control_flow)]
148
149     def out_nodes_edges(self, control_flow: bool=False):
150         return dict_to_ordered_dict({x[1]['out']: (Node(self.graph, x[0]), x[1]) for x in
151                                      self.get_outputs(control_flow=control_flow)})
152
153     def out_nodes(self, control_flow: bool=False):
154         assert self.has('kind')
155         assert self.kind in ['op', 'data']
156         if self.kind == 'op':
157             return dict_to_ordered_dict({x[1]['out']: Node(self.graph, x[0]) for x in
158                                          self.get_outputs(control_flow=control_flow)})
159         elif self.kind == 'data':
160             return [Node(self.graph, n) for n, d in self.get_outputs(control_flow=control_flow)]
161
162     def out_edges(self, control_flow: bool=False):
163         assert self.has('kind')
164         assert self.kind in ['op', 'data']
165         if self.kind == 'op':
166             return dict_to_ordered_dict({x[1]['out']: x[1] for x in self.get_outputs(control_flow=control_flow)})
167         elif self.kind == 'data':
168             return [d for n, d in self.get_outputs(control_flow=control_flow)]
169
170     def out_node(self, key=0, control_flow: bool=False):
171         return self.out_nodes(control_flow=control_flow)[key]
172
173     def in_edge(self, key=0, control_flow: bool=False):
174         return self.in_edges(control_flow=control_flow)[key]
175
176     def out_edge(self, key=0, control_flow: bool=False):
177         return self.out_edges(control_flow=control_flow)[key]
178
179     def get_attrs(self):
180         return self.graph.node[self.node]
181
182     def get_inputs(self, edge_attr: dict = None, control_flow: bool = False):
183         if edge_attr is None:
184             edge_attr = {}
185         in_edges = self.graph.in_edges(self.id, data=True)
186         if not control_flow:
187             in_edges = [(u, v, d) for u, v, d in in_edges if 'control_flow_edge' not in d or not d['control_flow_edge']]
188         return [(u, d) for u, v, d in in_edges if all([attr in d and d[attr] == edge_attr[attr] for attr in edge_attr])]
189
190     def get_outputs(self, edge_attr: dict = None, control_flow: bool = False):
191         if edge_attr is None:
192             edge_attr = {}
193         out_edges = self.graph.out_edges(self.id, data=True)
194         if not control_flow:
195             out_edges = [(u, v, d) for u, v, d in out_edges if
196                          'control_flow_edge' not in d or not d['control_flow_edge']]
197         return [(v, d) for u, v, d in out_edges if
198                 all([attr in d and d[attr] == edge_attr[attr] for attr in edge_attr])]
199
200     def get_sorted_inputs(self, control_flow: bool = False):
201         return sorted([x for x in self.get_inputs(control_flow=control_flow) if 'in' in x[1]],
202                       key=lambda x: x[1]['in'])
203
204     def get_sorted_outputs(self, control_flow: bool = False):
205         return sorted([x for x in self.get_outputs(control_flow=control_flow) if 'out' in x[1]],
206                       key=lambda x: x[1]['out'])
207
208     def soft_get(self, k):
209         return self[k] if self.has_valid(k) else '<UNKNOWN>'
210
211     def edges(self, attrs: dict=None):
212         """ Get a single edge with specified set of attributes.
213
214             If none or multiple edges satisfies this criteria, exception is raised
215             Edge is represented as tuple (u, v, d), where u is source node,
216             v is destination node and d is edge attributes.
217         """
218         edges = list(self.graph.in_edges([self.id], data=True)) + list(self.graph.out_edges([self.id], data=True))
219         return [(u, v, d) for u, v, d in edges if dict_includes(d, attrs)]
220
221     def edge(self, attrs: dict=None):
222         """ Get a single edge with specified set of attributes.
223
224             If none or multiple edges satisfies this criteria, exception is raised
225             Edge is represented as tuple (u, v, d), where u is source node,
226             v is destination node and d is edge attributes.
227         """
228         edges = self.edges(attrs)
229         assert len(edges) == 1, 'edges: {}, required attributes: {}'.format(edges, attrs)
230         return edges[0]
231
232     def copy_node(self, new_attrs: dict = None, dst_graph=None):
233         ''' Copies node with all attributes (optionally updated) within the same graph or to different graph.'''
234         if new_attrs is None:
235             new_attrs = {}
236         if dst_graph is None:
237             dst_graph = self.graph
238
239         attrs = deepcopy(self.attrs())
240         attrs.update(new_attrs)
241         new_id = dst_graph.unique_id()
242         dst_graph.add_node(new_id, **attrs)
243         return Node(dst_graph, new_id)
244
245     def insert_node_with_data_before(self, inp, new_op_class: callable, op_before_params: dict = None,
246                                      infer_current: bool = False, additional_inputs: list = None):
247         """
248         Inserts operation node with op_before_params and data node before current operation
249
250         :param inp: input data node of current node
251         :param new_op_class: class of operation that will be inserted before current operation node
252         :param op_before_params: parameters to be added to operation that will be inserted before current operation
253
254         Before calling:
255         [...] -> inp -> Cur_Op -> Cur_Data -> [...]
256
257         After calling:
258         [...] -> inp -> New_Op_bef -> New_Data_bef -> Cur_Op -> Cur_Data -> [...]
259                     [op_before_params]
260         """
261         graph = self.graph
262         node = Node(graph, self.node)
263         cls_name = new_op_class.op
264         op_before_params = {} if op_before_params is None else op_before_params
265
266         # operating with input
267         new_op_before = new_op_class(graph, op_before_params)
268         edge_attrs = deepcopy(graph.get_edge_data(inp.id, node.id)[0])
269         graph.remove_edge(inp.id, node.id)
270         # form a list of input nodes for a new op node combining new_out and additional_inputs
271         inputs = [inp] + (additional_inputs if additional_inputs else [])
272         new_inp = new_op_before.create_node_with_data(inputs, {'name': node.name + cls_name + '/Before'})
273         graph.add_edge(new_inp.id, node.id, **edge_attrs)
274         if infer_current:
275             node.infer(node)
276
277     def insert_node_with_data_after(self, out, new_op_class: callable, op_after_params: dict = None,
278                                     additional_inputs: list = None):
279         """
280         Inserts operation node with op_after_params and data node after current operation
281
282         :param out: output data node of current node
283         :param new_op_class: class of operation that will be inserted after current operation node
284         :param op_after_params:  parameters to be added to operation that will be inserted after current operation
285         :param additional_inputs:  other parameters for a new operation node in addition to one that is created
286             at the 'out' placed; new nodes are added after 0-th input
287
288             TODO Allow indexing for input parameters as well as for 'out' data node to explicitly
289                 specify ports that are connected to.
290
291         Before calling:
292         [...] -> Cur_Op -> Cur_Data -> [...]
293
294         After calling:
295         [...] -> Cur_Op -> Cur_Data -> New_Op_aft -> New_Data_aft(==out) -> [...]
296                                    [op_after_params]
297         """
298         # we import it here because Op imports Node and unique_id from this file
299         from mo.ops.op import Op
300
301         graph = self.graph
302         node = Node(graph, self.node)
303         cls_name = new_op_class.op
304         op_after_params = {} if op_after_params is None else op_after_params
305
306         new_op_after = new_op_class(graph, op_after_params)
307         graph.remove_edge(node.id, out.id)
308         new_out = Op.create_data_node(graph, node)
309         node.infer(node)
310         # form a list of input nodes for a new op node combining new_out and additional_inputs
311         inputs = [new_out] + (additional_inputs if additional_inputs else [])
312         new_op_after.create_node_with_data(inputs, {'name': node.name + cls_name + '/After'}, data_nodes=out)
313
314     def bracket_with_different_nodes_with_data(self, inp, out, new_op_class_before: callable,
315                                                new_op_class_after: callable,
316                                                op_before_params: dict = None, op_after_params: dict = None):
317         """
318         Inserts one operation node with op_before_params and data node before current operation node and
319         inserts one operation node with op_after_params and data node after current operation node
320         :param inp: input data node of self.node node
321         :param out: output data node of self.node node
322         :param new_op_class_before: class of operation that will be inserted before current operation node
323         :param new_op_class_after: class of operation that will be inserted after current operation node
324         :param op_before_params: parameters to be added to operation that will be inserted before current operation
325         :param op_after_params: parameters to be added to operation that will be inserted after current operation
326
327         Before calling:
328         [...] -> inp -> Cur_Op -> out -> [...]
329
330         After calling:
331         [...] -> inp -> New_Op_bef -> New_Data_bef -> Cur_Op -> Cur_Data -> New_Op_aft -> New_Data_aft(==out) -> [...]
332                     [op_before_params]                                  [op_after_params]
333         """
334         op_before_params = {} if op_before_params is None else op_before_params
335         op_after_params = {} if op_after_params is None else op_after_params
336         self.insert_node_with_data_before(inp, new_op_class_before, op_before_params)
337         self.insert_node_with_data_after(out, new_op_class_after, op_after_params)
338
339     def bracket_op_with_another_op(self, inp, out, new_op_class: callable,
340                                    op_before_params: dict = None, op_after_params: dict = None):
341         """
342         Covers current operation with two similar another ones of class new_op_class:
343         :param inp: input data node of self.node node
344         :param out: output data node of self.node node
345         :param new_op_class: class of operation with which current operation will be covered
346         :param op_before_params: parameters to be added to operation that will be inserted before current operation
347         :param op_after_params: parameters to be added to operation that will be inserted after current operation
348
349         Before calling:
350         [...] -> inp -> Cur_Op -> out -> [...]
351
352         After calling:
353         [...] -> inp -> New_Op_bef -> New_Data_bef -> Cur_Op -> Cur_Data -> New_Op_aft -> New_Data_aft(==out) -> [...]
354                     [op_before_params]                                  [op_after_params]
355         """
356         self.bracket_with_different_nodes_with_data(inp=inp, out=out,
357                                                     new_op_class_before=new_op_class, new_op_class_after=new_op_class,
358                                                     op_before_params=op_before_params, op_after_params=op_after_params)
359
360     def insert_node_after(self, new_node, node_out_port: int = 0):
361         """
362         Insert node 'new_node' after output with index 'node_out_port' of the node 'node'. All consumers of node 'node'
363         output with index 'node_out_port' will be changed to consume node 'new_node'.
364         The function should be used when graph doesn't contain data nodes yet.
365         :param node: node after which new node should be inserted.
366         :param new_node: node to be inserted.
367         :param node_out_port: the output index for the node 'node' to insert
368         :return: None
369         """
370         assert self.graph is new_node.graph
371         assert (len([name for name in self.graph.nodes() if Node(self.graph, name).soft_get('kind') == 'data']) == 0)
372
373         graph = self.graph
374         old_edges = list(graph.out_edges(self.id, data=True, keys=True))
375         # create new edges first and then remove all old edges. This is needed for case when 'node' has several consumers
376         # getting input from 'node_out_port'.
377         # save tuple ("name of the destination edge", "edge key") to be removed
378         node_name_and_edge_key = []
379         for _, dst_name, edge_key, edge_attrs in old_edges:
380             if edge_attrs['out'] == node_out_port:
381                 log.debug('Create edge from "{}" to "{}"'.format(new_node.name, dst_name))
382                 graph.create_edge(new_node, Node(graph, dst_name), 0, edge_attrs['in'])
383                 node_name_and_edge_key.append((dst_name, edge_key))
384         for dst_name, edge_key in node_name_and_edge_key:
385             log.debug('Remove edge from "{}" to "{}"'.format(self.id, dst_name))
386             graph.remove_edge(self.id, dst_name, edge_key)
387         graph.create_edge(self, new_node, node_out_port, 0, {})
388
389     def replace_node(self, new_node, new_node_out_port: int=None):
390         """
391         Replaces node 'old_node' with a node 'new_node' preserving edge attributes.
392         :param old_node: node to be replaced.
393         :param new_node: node to replace with.
394         :return: None
395         """
396         assert self.graph is new_node.graph
397         assert self.id != new_node.id, "New node and replaceable node are the same"
398         graph = self.graph
399         # save output edges and reconnect them to new node
400         for _, dst_node_name, edge_attrs in graph.out_edges(self.id, data=True):
401             new_edge_attrs = deepcopy(edge_attrs)
402             if new_node_out_port is not None:
403                 assert 'out' not in edge_attrs or edge_attrs['out'] == 0, \
404                     'replace_node function can replace old node with a single output port only if new_node_out_port is ' \
405                     'specified'
406                 new_edge_attrs.update({'out': new_node_out_port})
407             graph.add_edge(new_node.id, dst_node_name, **new_edge_attrs)
408
409         # if the node for replace is output node then we propagate this attribute to a new node
410         if len(self.out_nodes()) == 1 and self.out_node().has('op') and self.out_node().op == 'OpOutput':
411             graph.remove_node(self.out_node().id)
412             add_opoutput(graph, new_node.id, 0, False)
413         graph.remove_node(self.id)
414
415     def input_ports_with(self, node):
416         """
417         Returns a list of integers that specify input ports that connected to a given node.
418         :param node: node in the graph that is expected to appear at input port for self node
419         :return: a list of integers with port indices that are connected to self node
420         """
421         return [i for i in range(len(self.in_nodes())) if self.in_node(i).id == node.id]
422
423 class Graph(nx.MultiDiGraph):
424     def __init__(self, data=None, **attr):
425         self.stage = None
426         super().__init__(data, **attr)
427
428     unique_id_count = 0
429
430     # SAFE API DESCRIPTION
431     # all provided methods below are designed to be more safe and convenient
432     # be careful while using other methods from nx.MultiDiGraph
433
434     def add_node(self, node_for_adding, **attrs):
435         # TODO: check required attrs for node
436         super().add_node(node_for_adding, **attrs)
437         node = Node(self, node_for_adding)
438
439         in_ports_count = node.in_ports_count if node.has_valid('in_ports_count') else None
440         out_ports_count = node.out_ports_count if node.has_valid('out_ports_count') else None
441
442         node['_in_ports'] = set()
443         node['_out_ports'] = set()
444
445         if in_ports_count is not None:
446             for idx in range(in_ports_count):
447                 node.add_input_port(idx=idx)
448
449         if out_ports_count is not None:
450             for idx in range(out_ports_count):
451                 node.add_output_port(idx=idx)
452
453     def add_edge(self, u_for_edge, v_for_edge, key=None, **attr):
454         return super().add_edge(u_for_edge, v_for_edge, key=key, **attr)
455
456     def add_edges_from(self, ebunch_to_add, **attr):
457         for e in ebunch_to_add:
458             ne = len(e)
459             if ne == 4:
460                 u, v, key, dd = e
461             elif ne == 3:
462                 u, v, dd = e
463                 key = None
464             elif ne == 2:
465                 u, v = e
466                 dd = {}
467                 key = None
468             else:
469                 raise Error("Edge tuple %s must be a 2-tuple, 3-tuple or 4-tuple." % (e,))
470             ddd = attr.copy()
471             ddd.update(dd)
472             self.add_edge(u, v, key=key, **ddd)
473
474     def remove_edge(self, u, v, key=None):
475         return super().remove_edge(u, v, key=key)
476
477     def erase_node(self, node: Node):
478         """
479         Erases node from the graph and reconnect edges from input node(s) to output node(s)
480         Produces assertion error if the node being removed has multiple inputs or outputs.
481         The function can be used in the front phase only (when there are no data nodes in the graph).
482         :param node: Node to erase
483         """
484         node_id = node.id
485
486         inputs = list(self.in_edges(node_id, data=True))
487         outputs = list(self.out_edges(node_id, data=True))
488
489         assert node.kind == 'op' and (len(node.out_nodes()) == 0 or list(node.out_nodes().values())[0].kind != 'data'), \
490             "The function must be used before the partial infer when graph doesn't contain data nodes."
491         assert len(node.out_nodes()) <= 1, "The node {} must produce just one output tensor".format(
492             node.soft_get('name'))
493         assert len(inputs) <= 1, "The node {} must have just one input".format(node.soft_get('name'))
494
495         if len(outputs) == 0 and len(inputs) != 0:
496             from mo.front.extractor import add_output_ops
497             input_ids = {input_node_id: {'port': {'out': [attrs['out']]}} for input_node_id, _, attrs in inputs}
498             if node.has('op') and node.op == 'OpOutput':
499                 add_output_ops(self, input_ids)
500
501         if len(outputs) == 0 or len(inputs) == 0:
502             self.remove_node(node_id)
503             return
504
505         input_node_id = inputs[0][0]
506         for src, dst, attrs in outputs:
507             self.remove_edge(src, dst)
508             # update the 'out' attribute of the edge from the node being removed
509             attrs['out'] = inputs[0][2]['out']
510             self.add_edge(input_node_id, dst, **attrs)
511         self.remove_node(node_id)
512
513     def get_edge_data(self, u, v, key=None, default=None):
514         return super().get_edge_data(u, v, key=key, default=default)
515
516     def get_inputs_with_ports(self, match, pattern_edges, input_names_in_pattern):
517         """
518         Front replacements of multi-input nodes should specify output port to add_node-like functions
519         This function is a helper to get such information out of matched nodes
520         :param graph: graph to operate on
521         :param match: dictionary returned by matching function
522         :param pattern_edges: edges that are specified in pattern
523         :param input_names_in_pattern: names of matched nodes as they were specified in pattern that should be in
524         resulting list
525         :return: list of tuples of node and output port
526         """
527         inputs = []
528         for name in input_names_in_pattern:
529             assert name in match, "node named {} not in match {}".format(name, match)
530             src = match[name]
531             dst = []
532             for edge in pattern_edges:
533                 if edge[0] == name:
534                     assert edge[1] in match, "name from pattern_edges {} not in match {}".format(edge[1], match)
535                     dst.append(match[edge[1]])
536             if len(dst) != 1:
537                 raise Error('Multiple output ports detected for node {} as {} in pattern'.format(match[name].id, name))
538             dst = dst[0]
539             out_port = self.get_edge_data(src.id, dst.id)[0]['out']
540             inputs.append((src, out_port))
541         return inputs
542
543     def get_node_id_by_name(self, name: str):
544         for node in self.nodes():
545             if 'name' in self.node[node] and self.node[node]['name'] == name:
546                 return node
547         raise Error('No node with name {}. ' +
548                     refer_to_faq_msg(51), name)
549
550     def get_op_nodes(self, **attrs):
551         nodes = self.get_nodes_with_attributes(**dict(kind='op', **attrs))
552         return [Node(self, node) for node in nodes]
553
554     def get_data_nodes(self, has_value=None):
555         """
556         Returns list of data nodes.
557         If has_value = True, returns data nodes with value
558         If has_value = False, returns data nodes without value
559         """
560         data_nodes = [Node(self, node) for node in self.nodes() if Node(self, node).soft_get('kind') == 'data']
561         return [node for node in data_nodes if has_value is None or node.has_valid('value') == has_value]
562
563     def get_nodes_with_attributes(self, **attrs: dict):
564         node_attrs = self.nodes(data=True)
565         return [n for n, d in node_attrs if all(a in d.items() for a in attrs.items())]
566
567     def unique_id(self, prefix: str = ""):
568         """
569         Generates a unique node id for a new node in a given graph.
570         The optional string prefix can be specified.
571         """
572         # TODO thread safety?
573         self.unique_id_count = max(self.unique_id_count, self.number_of_nodes()) + 1
574         if prefix and not self.has_node(prefix):
575             return str(prefix)
576         while self.has_node(prefix + str(self.unique_id_count)):
577             self.unique_id_count += 1
578         return prefix + str(self.unique_id_count)
579
580     def check_empty_graph(self, description: str):
581         if len(self.nodes()) <= 1:
582             raise Error(
583                 "Graph contains {} node after executing {}. It considered as error because resulting IR will be "
584                 "empty which is not usual".format(len(self.nodes()), description))
585
586     def check_shapes_consistency(self):
587         data_nodes = self.get_data_nodes()
588         data_nodes_with_wrong_shapes = []
589         for data_node in data_nodes:
590             if not data_node.has('shape'):
591                 data_nodes_with_wrong_shapes.append((data_node.name, "no shape attribute"))
592                 continue
593             if data_node.shape is not None and not isinstance(data_node.shape, np.ndarray):
594                 data_nodes_with_wrong_shapes.append((data_node.name, type(data_node.shape)))
595         if len(data_nodes_with_wrong_shapes) > 0:
596             raise Error("Graph contains data nodes ({}) with inconsistent shapes: {}".format(
597                 len(data_nodes_with_wrong_shapes),
598                 data_nodes_with_wrong_shapes
599             ))
600
601     def check_nodes_ports_are_consecutive(self):
602         # Check that all operation nodes has consecutive ports indexes
603         op_nodes = self.get_op_nodes()
604         for node in op_nodes:
605             for idx in range(len(node.in_ports())):
606                 if idx not in node.in_ports():
607                     raise Error("Node {} has not consecutive in ports indexes: {}".format(node.name,
608                                                                                           list(node.in_ports().keys())))
609             for idx in range(len(node.out_ports())):
610                 if idx not in node.out_ports():
611                     raise Error("Node {} has not consecutive out ports indexes: {}".format(node.name,
612                                                                                            list(node.out_ports().keys())))
613
614     def dump_graph_for_graphviz(self, node_attrs: list = ['kind', 'op', 'shape'],
615                                 edge_attrs: list = ['in', 'out'],
616                                 nodes_to_dump: list = None, save_to_svg=False):
617         log.debug("---- GRAPHVIZ OUTPUT STARTS ----")
618         if nodes_to_dump is None:
619             nodes_to_dump = self.nodes()
620         string = '\ndigraph {\n'
621         visited_nodes = set()
622         for src_node_name, dst_node_name, attrs in self.edges(data=True):
623             visited_nodes.add(src_node_name)
624             visited_nodes.add(dst_node_name)
625             if src_node_name not in nodes_to_dump or dst_node_name not in nodes_to_dump:
626                 continue
627             src_node = self.node[src_node_name]
628             dst_node = self.node[dst_node_name]
629             src_node_string = str(src_node_name) + '\\n' + '\\n'.join(
630                 [str(key) + '=' + str(src_node.get(key, 'None')) for key in node_attrs if key in src_node])
631             dst_node_string = str(dst_node_name) + '\\n' + '\\n'.join(
632                 [str(key) + '=' + str(dst_node.get(key, 'None')) for key in node_attrs if key in dst_node])
633             edge_string = ' '.join([str(key) + '=' + str(attrs.get(key, 'None')) for key in edge_attrs if key in attrs])
634             string += '"{}" -> "{}" [label = "{}"];\n'.format(src_node_string, dst_node_string, edge_string)
635         for node in nodes_to_dump:
636             if node not in visited_nodes:
637                 string += '"{}"'.format(node)  # TODO: add attributes like it was done in the loop above
638                 visited_nodes.add(node)
639         string += '}'
640         log.debug(string)
641         log.debug("---- GRAPHVIZ OUTPUT ENDS ----")
642
643         if save_to_svg:
644             try:
645                 import graphviz
646                 import os
647                 file_name = "{}_{}.txt".format(self.name.replace('/', '_'), 0)
648                 id = 1
649                 while os.path.exists(file_name):
650                     file_name = "{}_{}.txt".format(self.name.replace('/', '_'), id)
651                     id += 1
652                 with open(file_name, "w") as f:
653                     f.write(string)
654                 graphviz.render('dot', 'svg', file_name)
655                 print('Graph was saved to {}.{}'.format(file_name, 'svg'))
656             except ImportError:
657                 raise ImportError('Can\'t import graphviz')
658             except Exception as e:
659                 raise Error('Can\'t save graph to svg') from e
660
661         return string
662
663     def print_graph_stat(self):
664         log.debug('Number of nodes in graph: {}'.format(self.number_of_nodes()))
665         log.debug('Number of edges in graph: {}'.format(len(list(self.edges()))))
666         ops = collections.defaultdict(int)
667         for _node in self.nodes():
668             node = Node(self, _node)
669             kind = node.kind if node.has('kind') else '<UNDEFINED>'
670             if node.has('op'):
671                 ops['op/' + node.op] += 1
672             else:
673                 ops[kind] += 1
674             if node.has('shape') and np.any(node.shape == 0):
675                 log.error("Found bad shape: '{}' for node '{}'".format(node.shape, node.node))
676         for k, v in ops.items():
677             log.debug('   {} : {}'.format(k, v))
678
679     def create_sub_graph_copy(self, nodes_to_extract: list):
680         """
681         Create new graph which is a sub-graph of the 'graph' that contains just nodes from 'nodes_to_extract' list. The
682         returned sub-graph is a deep copy of the provided graph nodes.
683         :param graph: graph to create a sub-graph from.
684         :param nodes_to_extract: list of node names to extract.
685         :return: new graph.
686         """
687         return self.subgraph(nodes_to_extract).copy()
688
689     def create_edge(self, src_node: Node, dst_node: Node, out_port: int = 0, in_port: int = 0, edge_attrs: dict = None):
690         """
691         Creates edge from node 'src_node' from output with index 'out_port' to node 'dst_node' with input index 'in_port'.
692         :param src_node: node to create edge from.
693         :param dst_node: node to create edge to.
694         :param out_port: the index of output tensor of the 'src_node'.
695         :param in_port: the input index of the node 'dst_node'.
696         :param edge_attrs: dictionary with edge attrs.
697         :return: None
698         """
699         # edges must belong to the same graph
700         assert src_node.graph is dst_node.graph
701         graph = src_node.graph
702
703         if edge_attrs is None:
704             edge_attrs = dict()
705         else:
706             edge_attrs = edge_attrs.copy()
707         edge_attrs.update(
708             {'in': in_port, 'out': out_port, 'in_attrs': ['in', 'permutation'], 'out_attrs': ['out', 'permutation'],
709              'data_attrs': ['fw_tensor_debug_info']})
710
711         # TODO: in case if in_port do not exists, we should raise an Exception here
712         graph.add_edges_from([(src_node.id, dst_node.id, edge_attrs)])
713
714
715 def create_graph_with_nodes(src_nodes, get_id: callable, get_attrs: callable):
716     """
717     Go over all nodes in src_nodes that should be enumerable and create new NX nodes
718     using get_id and get_attrs functions to create node id and node attributes correspondingly.
719     """
720     graph = Graph()
721     for node in src_nodes:
722         graph.add_node(get_id(node), **get_attrs(node))
723     return graph
724
725
726 def dict_includes_compare_attrs(attr, attr_probe):
727     if callable(attr_probe) and not isinstance(attr_probe, type):
728         return attr_probe(attr)
729     else:
730         return attr == attr_probe
731
732
733 def dict_includes(big: dict, sub_dict: dict, skip_attr_names=[]):
734     """ Searches attributes from sub_dict in big and ensures that all values match.
735
736         Entries in sub_dict can be of two types: callable or not callable. If callable is specified
737         it is treated as probing function for attribute value from big dictionary by callable(attr) expression.
738         If it is not callable, the values are compared with == operator.
739     """
740     return all(
741         dict_includes_compare_attrs(big.get(attr, None), sub_dict[attr])
742         for attr in sub_dict.keys() if attr not in skip_attr_names
743     )
744
745
746 def add_opoutput(graph: Graph, node_name: str, port: int, cut: bool = True):
747     """
748     Creates and connects OpOutput node to node_name port. Cuts existing port if requested.
749     :param graph: graph to operate with
750     :param node_name: name of existing node in the graph that we want to add OpOutput to
751     :param port: output port of node to connect OpOutput to
752     :param cut: determines way of operating with edge specified by node_name and port
753     """
754     # we import it here because Op imports add_attrs_props and update_ie_fields from this file
755     from mo.ops.output import Output
756     node = Node(graph, node_name)
757     if cut and len(node.out_edges()) != 0:
758         opoutput_node = Output(graph).create_node_on_port(node, port, {'name': node_name + '/sink_port_' + str(port)})
759     else:
760         opoutput_node = Output(graph).create_node([(node, port)], {'name': node_name + '/sink_port_' + str(port)})
761         opoutput_node.in_edge()['data_attrs'] = ['fw_tensor_debug_info']
762         opoutput_node.in_edge()['fw_tensor_debug_info'] = [(node_name, port)]
763     log.debug('Sink: {} for node {}'.format(opoutput_node.id, node_name))
764     log.debug(str(graph.node[opoutput_node.id]))
765     log.debug("Add edge from {} to {}".format(node_name, opoutput_node.id))
766     return opoutput_node.id
767
768
769 # TODO implement merging for keys with dictionary values?
770 def merge_edge_props(attrs: dict, additional_attrs: dict):
771     """
772     Update edge attributes without changing 'in' and 'out' keys.
773     It is necessary to copy edge attributes during merging of nodes when
774     result of one subgraph call is passed as input to another subgraph call
775     """
776     result = attrs
777     for (key, value) in additional_attrs.items():
778         if key not in ['in', 'out']:
779             if type(additional_attrs[key]) is list:
780                 if key not in result:
781                     result[key] = []
782                 result[key].extend(additional_attrs[key])
783                 result[key] = list(set(result[key]))  # silly solution to find unique elements
784             else:
785                 result[key] = value
786     return result
787
788
789 # All functions below are deprecated and will be removed in next release
790 # Please, use methods from Graph/Node classes instead
791
792
793 @deprecated_api(Graph)
794 def get_node_id_by_name(graph: Graph, name: str):
795     return graph.get_node_id_by_name(name=name)
796
797
798 @deprecated_api(Graph)
799 def print_graph_stat(graph: Graph):
800     return graph.print_graph_stat()
801
802
803 @deprecated_api(Graph)
804 def get_inputs_with_ports(graph: Graph, match, pattern_edges, input_names_in_pattern):
805     """
806     Front replacements of multi-input nodes should specify output port to add_node-like functions
807     This function is a helper to get such information out of matched nodes
808     :param graph: graph to operate on
809     :param match: dictionary returned by matching function
810     :param pattern_edges: edges that are specified in pattern
811     :param input_names_in_pattern: names of matched nodes as they were specified in pattern that should be in
812     resulting list
813     :return: list of tuples of node and output port
814     """
815     return graph.get_inputs_with_ports(match=match,
816                                        pattern_edges=pattern_edges,
817                                        input_names_in_pattern=input_names_in_pattern)
818
819
820 @deprecated_api(Graph)
821 def dump_graph_for_graphviz(graph: Graph, node_attrs: list = ['kind', 'op', 'shape'],
822                             edge_attrs: list = ['in', 'out'],
823                             nodes_to_dump: list = None, save_to_svg=False):
824     return graph.dump_graph_for_graphviz(node_attrs=node_attrs,
825                                          edge_attrs=edge_attrs,
826                                          nodes_to_dump=nodes_to_dump,
827                                          save_to_svg=save_to_svg)
828
829
830 @deprecated_api(Graph)
831 def create_sub_graph_copy(graph: Graph, nodes_to_extract: list):
832     """
833     Create new graph which is a sub-graph of the 'graph' that contains just nodes from 'nodes_to_extract' list. The
834     returned sub-graph is a deep copy of the provided graph nodes.
835     :param graph: graph to create a sub-graph from.
836     :param nodes_to_extract: list of node names to extract.
837     :return: new graph.
838     """
839     return graph.create_sub_graph_copy(nodes_to_extract=nodes_to_extract)
840
841
842 @deprecated_api(Graph)
843 def get_graph_ops(graph: Graph):
844     return graph.get_op_nodes()
845
846
847 @deprecated_api(Graph)
848 def check_empty_graph(graph: Graph, description: str):
849     return graph.check_empty_graph(description=description)
850
851
852 @deprecated_api(Graph)
853 def create_edge(src_node: Node, dst_node: Node, out_port: int = 0, in_port: int = 0, edge_attrs: dict = None):
854     """
855     Creates edge from node 'src_node' from output with index 'out_port' to node 'dst_node' with input index 'in_port'.
856     :param src_node: node to create edge from.
857     :param dst_node: node to create edge to.
858     :param out_port: the index of output tensor of the 'src_node'.
859     :param in_port: the input index of the node 'dst_node'.
860     :param edge_attrs: dictionary with edge attrs.
861     :return: None
862     """
863     assert src_node.graph is dst_node.graph
864     graph = src_node.graph
865     return graph.create_edge(src_node=src_node, dst_node=dst_node, out_port=out_port, in_port=in_port,
866                              edge_attrs=edge_attrs)
867
868
869 @deprecated_api(Graph)
870 def erase_node(node: Node):
871     """
872     Erases node from the graph and reconnect edges from input node(s) to output node(s)
873     Produces assertion error if the node being removed has multiple inputs or outputs.
874     The function can be used in the front phase only (when there are no data nodes in the graph).
875     :param node: Node to erase
876     """
877     graph = node.graph
878     return graph.erase_node(node)
879
880
881 @deprecated_api(Node)
882 def get_sorted_inputs(node: Node, control_flow: bool = False):
883     return node.get_sorted_inputs(control_flow=control_flow)
884
885
886 @deprecated_api(Node)
887 def get_sorted_outputs(node: Node, control_flow: bool = False):
888     return node.get_sorted_outputs(control_flow=control_flow)
889
890
891 @deprecated_api(Node)
892 def insert_node_after(node: Node, new_node: Node, node_out_port: int = 0):
893     """
894     Insert node 'new_node' after output with index 'node_out_port' of the node 'node'. All consumers of node 'node'
895     output with index 'node_out_port' will be changed to consume node 'new_node'.
896     The function should be used when graph doesn't contain data nodes yet.
897     :param node: node after which new node should be inserted.
898     :param new_node: node to be inserted.
899     :param node_out_port: the output index for the node 'node' to insert
900     :return: None
901     """
902     return node.insert_node_after(new_node=new_node, node_out_port=node_out_port)
903
904
905 @deprecated_api(Node)
906 def replace_node(old_node: Node, new_node: Node, new_node_out_port: int=None):
907     """
908     Replaces node 'old_node' with a node 'new_node' preserving edge attributes.
909     :param old_node: node to be replaced.
910     :param new_node: node to replace with.
911     :return: None
912     """
913     return old_node.replace_node(new_node=new_node, new_node_out_port=new_node_out_port)
914
915
916 @deprecated_api(Node)
917 def copy_node(src_node: Node, new_attrs: dict=None, dst_graph: nx.MultiDiGraph = None):
918     """ Copies node with all attributes (optionally updated) within the same graph or to different graph."""
919     return src_node.copy_node(new_attrs=new_attrs, dst_graph=dst_graph)
920
921
922 @deprecated_api(Node)
923 def get_inputs(graph: Graph, node: str, edge_attr: dict = None, control_flow: bool = False):
924     return Node(graph, node).get_inputs(edge_attr=edge_attr, control_flow=control_flow)
925
926
927 @deprecated_api(Node)
928 def get_outputs(graph: Graph, node: str, edge_attr: dict = None, control_flow: bool = False):
929     return Node(graph, node).get_outputs(edge_attr=edge_attr, control_flow=control_flow)