2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from copy import deepcopy
24 from mo.graph.port import Port
25 from mo.utils.error import Error
26 from mo.utils.utils import refer_to_faq_msg, deprecated_api, shrink_str_value
def dict_to_ordered_dict(d: dict):
    """Return an OrderedDict with the items of ``d`` arranged by ascending key."""
    ordered = collections.OrderedDict()
    for key in sorted(d):
        ordered[key] = d[key]
    return ordered
def __init__(self, graph, node: str):
    """Bind this wrapper to the node with id ``node`` inside ``graph``.

    :param graph: graph that owns the node
    :param node: id of the node in ``graph``
    :raises AttributeError: if ``node`` is not present in ``graph``
    """
    # Only reject ids that are really absent; raising unconditionally would make the wrapper unusable.
    if node not in graph:
        raise AttributeError("Attempt to access node {} that not in graph".format(node))

    # The wrapper's own fields are stored through the parent __setattr__ so that they
    # do not go through this class's attribute-assignment override.
    super(Node, self).__setattr__('graph', graph)
    super(Node, self).__setattr__('node', node)  # obsolete
    super(Node, self).__setattr__('id', node)
def __str__(self, max_length: int = 100):
    """Return the node attribute dictionary as a string.

    The 'value' attribute is passed through ``shrink_str_value`` with ``max_symbols=max_length``;
    every other attribute is printed as is.
    """
    printable = {}
    for key, val in self.graph.node[self.id].items():
        printable[key] = shrink_str_value(val, max_symbols=max_length) if key == 'value' else val
    return str(printable)
def __setattr__(self, k, v):
    """Assign ``v`` to the already existing node attribute ``k``.

    :param k: attribute name; must already be present in the node attribute dictionary
    :param v: new value
    :raises AttributeError: if the node has no attribute named ``k``
    """
    # you can assign only existing attributes
    attrs = self.graph.node[self.node]
    if k not in attrs:
        raise AttributeError("Attribute {} is missing in node {}".format(k, self.node))
    attrs[k] = v
def __getattr__(self, k):
    """Resolve attribute ``k`` by looking it up among the attributes of the wrapped graph node."""
    # NOTE(review): a missing key surfaces whatever the attribute mapping raises on lookup
    # (presumably KeyError, not AttributeError) -- confirm before relying on hasattr()/getattr() defaults.
    node_attrs = self.graph.node[self.node]
    return node_attrs[k]
def __getitem__(self, k):
    """Return the value stored under key ``k`` in the attributes of the wrapped graph node."""
    node_attrs = self.graph.node[self.node]
    return node_attrs[k]
def __setitem__(self, k, v):
    """Store ``v`` under key ``k`` in the attributes of the wrapped graph node."""
    node_attrs = self.graph.node[self.node]
    node_attrs[k] = v
64 def __contains__(self, k):
def add_input_port(self, idx):
    """Register input port ``idx`` on this node.

    :param idx: index of the port to add
    :raises Error: if an input port with this index is already registered
    """
    if not self.has_valid('_in_ports'):
        # create the port registry on first use
        fresh_view = Node(self.graph, self.id)
        fresh_view['_in_ports'] = set()
    already_present = idx in self.in_ports()
    if already_present:
        raise Error("Input port with {} index already exists for {} node.".format(idx, self.name))
    self._in_ports.add(idx)
def add_output_port(self, idx):
    """Register output port ``idx`` on this node.

    :param idx: index of the port to add
    :raises Error: if an output port with this index is already registered
    """
    if not self.has_valid('_out_ports'):
        # create the port registry on first use
        fresh_view = Node(self.graph, self.id)
        fresh_view['_out_ports'] = set()
    already_present = idx in self.out_ports()
    if already_present:
        raise Error("Output port with {} index already exists for {} node.".format(idx, self.name))
    self._out_ports.add(idx)
def in_port(self, idx=None) -> Port:
    """Return a Port object describing input port ``idx`` of this node.

    :param idx: index of the requested input port
    :raises Error: if the node has no valid '_in_ports' attribute or ``idx`` is not registered in it
    """
    if not self.has_valid('_in_ports'):
        raise Error("Operation {} {} has no _in_ports attribute", self.op, self.name)
    registered = self._in_ports
    if idx in registered:
        return Port(node=self, idx=idx, type='in')
    raise Error("Input port with index {} is not in node {}".format(idx, self.name))
89 if not self.has_valid('_in_ports'):
90 raise Error("Operation {} {} has no _in_ports attribute", self.op, self.name)
91 return dict_to_ordered_dict({idx: self.in_port(idx) for idx in self._in_ports})
def out_port(self, idx=None) -> Port:
    """Return a Port object describing output port ``idx`` of this node.

    :param idx: index of the requested output port
    :raises Error: if the node has no valid '_out_ports' attribute or ``idx`` is not registered in it
    """
    if not self.has_valid('_out_ports'):
        raise Error("Operation {} {} has no _out_ports attribute", self.op, self.name)
    registered = self._out_ports
    if idx in registered:
        return Port(node=self, idx=idx, type='out')
    raise Error("Output port with index {} is not in node {}".format(idx, self.name))
101 if not self.has_valid('_out_ports'):
102 raise Error("Operation {} {} has no _out_ports attribute", self.op, self.name)
103 return dict_to_ordered_dict({idx: self.out_port(idx) for idx in self._out_ports})
def has_port(self, port_type, idx):
    """Tell whether the node has a port ``idx`` of the given direction.

    :param port_type: either 'in' or 'out'
    :param idx: port index to look for
    """
    assert port_type in ['in', 'out'], "Invalid usage of has_port method"

    if port_type != 'in':
        return self.has_valid('_out_ports') and idx in self.out_ports()
    return self.has_valid('_in_ports') and idx in self.in_ports()
114 return self.graph.node[self.node]
117 return k in self.graph.node[self.node]
def has_valid(self, k):
    """Tell whether attribute ``k`` exists on the node and its value is not None."""
    present = self.has(k)
    return present and self.graph.node[self.node][k] is not None
def has_and_set(self, k):
    """Return the value of attribute ``k`` when it is valid, otherwise the falsy result of ``has_valid``."""
    valid = self.has_valid(k)
    return self[k] if valid else valid
def in_nodes_edges(self, control_flow: bool=False):
    """Map each input port index to a ``(producer Node, edge attributes)`` pair, ordered by port index.

    :param control_flow: forwarded to ``get_inputs``
    """
    by_port = {}
    for src_id, edge_attrs in self.get_inputs(control_flow=control_flow):
        by_port[edge_attrs['in']] = (Node(self.graph, src_id), edge_attrs)
    return dict_to_ordered_dict(by_port)
def in_nodes(self, control_flow: bool=False):
    """Return the producers of this node.

    For an 'op' node the result is an ordered mapping ``input port -> Node``;
    for a 'data' node it is a plain list of Node objects.

    :param control_flow: forwarded to ``get_inputs``
    """
    assert self.has('kind')  # TODO: remove as it always exists
    assert self.kind in ['op', 'data']  # TODO: remove as it always exists
    if self.kind == 'op':
        by_port = {}
        for src_id, edge_attrs in self.get_inputs(control_flow=control_flow):
            by_port[edge_attrs['in']] = Node(self.graph, src_id)
        return dict_to_ordered_dict(by_port)
    elif self.kind == 'data':
        producers = []
        for src_id, _ in self.get_inputs(control_flow=control_flow):
            producers.append(Node(self.graph, src_id))
        return producers
def in_node(self, key=0, control_flow: bool=False):
    """Return the single entry ``key`` of ``in_nodes(control_flow=control_flow)``."""
    producers = self.in_nodes(control_flow=control_flow)
    return producers[key]
def in_edges(self, control_flow: bool=False):
    """Return attributes of incoming edges.

    For an 'op' node the result is an ordered mapping ``input port -> edge attributes``;
    for a 'data' node it is a plain list of edge attribute dictionaries.

    :param control_flow: forwarded to ``get_inputs``
    """
    assert self.has('kind')
    assert self.kind in ['op', 'data']
    if self.kind == 'op':
        by_port = {edge_attrs['in']: edge_attrs
                   for _, edge_attrs in self.get_inputs(control_flow=control_flow)}
        return dict_to_ordered_dict(by_port)
    elif self.kind == 'data':
        return [edge_attrs for _, edge_attrs in self.get_inputs(control_flow=control_flow)]
def out_nodes_edges(self, control_flow: bool=False):
    """Map each output port index to a ``(consumer Node, edge attributes)`` pair, ordered by port index.

    :param control_flow: forwarded to ``get_outputs``
    """
    by_port = {}
    for dst_id, edge_attrs in self.get_outputs(control_flow=control_flow):
        by_port[edge_attrs['out']] = (Node(self.graph, dst_id), edge_attrs)
    return dict_to_ordered_dict(by_port)
def out_nodes(self, control_flow: bool=False):
    """Return the consumers of this node.

    For an 'op' node the result is an ordered mapping ``output port -> Node``;
    for a 'data' node it is a plain list of Node objects.

    :param control_flow: forwarded to ``get_outputs``
    """
    assert self.has('kind')
    assert self.kind in ['op', 'data']
    if self.kind == 'op':
        by_port = {}
        for dst_id, edge_attrs in self.get_outputs(control_flow=control_flow):
            by_port[edge_attrs['out']] = Node(self.graph, dst_id)
        return dict_to_ordered_dict(by_port)
    elif self.kind == 'data':
        consumers = []
        for dst_id, _ in self.get_outputs(control_flow=control_flow):
            consumers.append(Node(self.graph, dst_id))
        return consumers
def out_edges(self, control_flow: bool=False):
    """Return attributes of outgoing edges.

    For an 'op' node the result is an ordered mapping ``output port -> edge attributes``;
    for a 'data' node it is a plain list of edge attribute dictionaries.

    :param control_flow: forwarded to ``get_outputs``
    """
    assert self.has('kind')
    assert self.kind in ['op', 'data']
    if self.kind == 'op':
        by_port = {edge_attrs['out']: edge_attrs
                   for _, edge_attrs in self.get_outputs(control_flow=control_flow)}
        return dict_to_ordered_dict(by_port)
    elif self.kind == 'data':
        return [edge_attrs for _, edge_attrs in self.get_outputs(control_flow=control_flow)]
def out_node(self, key=0, control_flow: bool=False):
    """Return the single entry ``key`` of ``out_nodes(control_flow=control_flow)``."""
    consumers = self.out_nodes(control_flow=control_flow)
    return consumers[key]
def in_edge(self, key=0, control_flow: bool=False):
    """Return the single entry ``key`` of ``in_edges(control_flow=control_flow)``."""
    incoming = self.in_edges(control_flow=control_flow)
    return incoming[key]
def out_edge(self, key=0, control_flow: bool=False):
    """Return the single entry ``key`` of ``out_edges(control_flow=control_flow)``."""
    outgoing = self.out_edges(control_flow=control_flow)
    return outgoing[key]
180 return self.graph.node[self.node]
def get_inputs(self, edge_attr: dict = None, control_flow: bool = False):
    """Return ``(source id, edge attributes)`` pairs for the incoming edges of this node.

    :param edge_attr: optional filter; an edge is kept only if it carries every listed key with an equal value.
                      ``None`` means "no filter".
    :param control_flow: when False, edges whose 'control_flow_edge' attribute is truthy are dropped
    :return: list of ``(u, d)`` tuples, where u is the source node id and d is the edge attribute dictionary
    """
    if edge_attr is None:
        edge_attr = {}
    in_edges = self.graph.in_edges(self.id, data=True)
    if not control_flow:
        in_edges = [(u, v, d) for u, v, d in in_edges if not d.get('control_flow_edge')]
    return [(u, d) for u, v, d in in_edges
            if all(key in d and d[key] == wanted for key, wanted in edge_attr.items())]
def get_outputs(self, edge_attr: dict = None, control_flow: bool = False):
    """Return ``(destination id, edge attributes)`` pairs for the outgoing edges of this node.

    :param edge_attr: optional filter; an edge is kept only if it carries every listed key with an equal value.
                      ``None`` means "no filter".
    :param control_flow: when False, edges whose 'control_flow_edge' attribute is truthy are dropped
    :return: list of ``(v, d)`` tuples, where v is the destination node id and d is the edge attribute dictionary
    """
    if edge_attr is None:
        edge_attr = {}
    out_edges = self.graph.out_edges(self.id, data=True)
    if not control_flow:
        out_edges = [(u, v, d) for u, v, d in out_edges if not d.get('control_flow_edge')]
    return [(v, d) for u, v, d in out_edges
            if all(key in d and d[key] == wanted for key, wanted in edge_attr.items())]
def get_sorted_inputs(self, control_flow: bool = False):
    """Return the ``get_inputs`` pairs whose edge carries an 'in' attribute, ordered by that attribute."""
    with_port = [pair for pair in self.get_inputs(control_flow=control_flow) if 'in' in pair[1]]
    with_port.sort(key=lambda pair: pair[1]['in'])
    return with_port
def get_sorted_outputs(self, control_flow: bool = False):
    """Return the ``get_outputs`` pairs whose edge carries an 'out' attribute, ordered by that attribute."""
    with_port = [pair for pair in self.get_outputs(control_flow=control_flow) if 'out' in pair[1]]
    with_port.sort(key=lambda pair: pair[1]['out'])
    return with_port
def soft_get(self, k):
    """Return attribute ``k`` when it is valid, otherwise the placeholder string '<UNKNOWN>'."""
    if self.has_valid(k):
        return self[k]
    return '<UNKNOWN>'
def edges(self, attrs: dict=None):
    """ Get all edges of this node (incoming first, then outgoing) that carry the specified attributes.

    Edge is represented as tuple (u, v, d), where u is source node,
    v is destination node and d is edge attributes.

    :param attrs: attributes an edge must include (matched with ``dict_includes``);
                  ``None`` or an empty dict selects every edge of the node
    :return: list of matching (u, v, d) tuples, possibly empty
    """
    edges = list(self.graph.in_edges([self.id], data=True)) + list(self.graph.out_edges([self.id], data=True))
    if not attrs:
        # Nothing to match against: every edge qualifies. This also keeps the default
        # attrs=None from reaching dict_includes, which expects a dictionary.
        return [(u, v, d) for u, v, d in edges]
    return [(u, v, d) for u, v, d in edges if dict_includes(d, attrs)]
def edge(self, attrs: dict=None):
    """ Get a single edge with specified set of attributes.

    If none or multiple edges satisfies this criteria, exception is raised
    Edge is represented as tuple (u, v, d), where u is source node,
    v is destination node and d is edge attributes.

    :param attrs: attributes the edge must include; forwarded to ``edges``
    :return: the only matching (u, v, d) tuple
    :raises AssertionError: if the number of matching edges is not exactly one
    """
    edges = self.edges(attrs)
    assert len(edges) == 1, 'edges: {}, required attributes: {}'.format(edges, attrs)
    return edges[0]
def copy_node(self, new_attrs: dict = None, dst_graph=None):
    ''' Copies node with all attributes (optionally updated) within the same graph or to different graph.

    :param new_attrs: attributes that override or extend the deep-copied ones; ``None`` means no overrides
    :param dst_graph: graph that receives the copy; defaults to the graph of this node
    :return: Node wrapper for the newly added node, whose id is obtained from ``dst_graph.unique_id()``
    '''
    if new_attrs is None:
        new_attrs = {}
    if dst_graph is None:
        dst_graph = self.graph

    # deep copy so that the new node does not share mutable attribute values with the source
    attrs = deepcopy(self.attrs())
    attrs.update(new_attrs)
    new_id = dst_graph.unique_id()
    dst_graph.add_node(new_id, **attrs)
    return Node(dst_graph, new_id)
def insert_node_with_data_before(self, inp, new_op_class: callable, op_before_params: dict = None,
                                 infer_current: bool = False, additional_inputs: list = None):
    """
    Inserts operation node with op_before_params and data node before current operation

    :param inp: input data node of current node
    :param new_op_class: class of operation that will be inserted before current operation node
    :param op_before_params: parameters to be added to operation that will be inserted before current operation
    :param infer_current: NOTE(review): accepted but not acted on by the statements of this method -- confirm
    :param additional_inputs: extra input nodes for the new operation, appended after ``inp``

    Before calling:
    [...] -> inp -> Cur_Op -> Cur_Data -> [...]

    After calling:
    [...] -> inp -> New_Op_bef -> New_Data_bef -> Cur_Op -> Cur_Data -> [...]
    """
    graph = self.graph
    node = Node(graph, self.node)
    cls_name = new_op_class.op
    op_before_params = {} if op_before_params is None else op_before_params

    # operating with input
    new_op_before = new_op_class(graph, op_before_params)
    # remember attributes of the inp -> node edge so they can be re-applied to the replacement edge
    edge_attrs = deepcopy(graph.get_edge_data(inp.id, node.id)[0])
    graph.remove_edge(inp.id, node.id)
    # form a list of input nodes for a new op node combining inp and additional_inputs
    inputs = [inp] + (additional_inputs if additional_inputs else [])
    new_inp = new_op_before.create_node_with_data(inputs, {'name': node.name + cls_name + '/Before'})
    graph.add_edge(new_inp.id, node.id, **edge_attrs)
def insert_node_with_data_after(self, out, new_op_class: callable, op_after_params: dict = None,
                                additional_inputs: list = None):
    """
    Inserts operation node with op_after_params and data node after current operation

    :param out: output data node of current node
    :param new_op_class: class of operation that will be inserted after current operation node
    :param op_after_params: parameters to be added to operation that will be inserted after current operation
    :param additional_inputs: other parameters for a new operation node in addition to one that is created
        at the 'out' placed; new nodes are added after 0-th input

        TODO Allow indexing for input parameters as well as for 'out' data node to explicitly
            specify ports that are connected to.

    Before calling:
    [...] -> Cur_Op -> Cur_Data -> [...]

    After calling:
    [...] -> Cur_Op -> Cur_Data -> New_Op_aft -> New_Data_aft(==out) -> [...]
    """
    # we import it here because Op imports Node and unique_id from this file
    from mo.ops.op import Op

    graph = self.graph
    node = Node(graph, self.node)
    cls_name = new_op_class.op
    op_after_params = {} if op_after_params is None else op_after_params

    new_op_after = new_op_class(graph, op_after_params)
    graph.remove_edge(node.id, out.id)
    new_out = Op.create_data_node(graph, node)

    # form a list of input nodes for a new op node combining new_out and additional_inputs
    inputs = [new_out] + (additional_inputs if additional_inputs else [])
    new_op_after.create_node_with_data(inputs, {'name': node.name + cls_name + '/After'}, data_nodes=out)
def bracket_with_different_nodes_with_data(self, inp, out, new_op_class_before: callable,
                                           new_op_class_after: callable,
                                           op_before_params: dict = None, op_after_params: dict = None):
    """
    Inserts one operation node with op_before_params and data node before current operation node and
    inserts one operation node with op_after_params and data node after current operation node
    :param inp: input data node of self.node node
    :param out: output data node of self.node node
    :param new_op_class_before: class of operation that will be inserted before current operation node
    :param new_op_class_after: class of operation that will be inserted after current operation node
    :param op_before_params: parameters to be added to operation that will be inserted before current operation
    :param op_after_params: parameters to be added to operation that will be inserted after current operation

    Before calling:
    [...] -> inp -> Cur_Op -> out -> [...]

    After calling:
    [...] -> inp -> New_Op_bef -> New_Data_bef -> Cur_Op -> Cur_Data -> New_Op_aft -> New_Data_aft(==out) -> [...]
                [op_before_params]                                     [op_after_params]
    """
    before_params = op_before_params if op_before_params is not None else {}
    after_params = op_after_params if op_after_params is not None else {}
    self.insert_node_with_data_before(inp, new_op_class_before, before_params)
    self.insert_node_with_data_after(out, new_op_class_after, after_params)
def bracket_op_with_another_op(self, inp, out, new_op_class: callable,
                               op_before_params: dict = None, op_after_params: dict = None):
    """
    Covers current operation with two similar another ones of class new_op_class:
    :param inp: input data node of self.node node
    :param out: output data node of self.node node
    :param new_op_class: class of operation with which current operation will be covered
    :param op_before_params: parameters to be added to operation that will be inserted before current operation
    :param op_after_params: parameters to be added to operation that will be inserted after current operation

    Before calling:
    [...] -> inp -> Cur_Op -> out -> [...]

    After calling:
    [...] -> inp -> New_Op_bef -> New_Data_bef -> Cur_Op -> Cur_Data -> New_Op_aft -> New_Data_aft(==out) -> [...]
                [op_before_params]                                     [op_after_params]
    """
    # the same operation class is used on both sides of the current node
    forwarded = dict(inp=inp, out=out,
                     new_op_class_before=new_op_class, new_op_class_after=new_op_class,
                     op_before_params=op_before_params, op_after_params=op_after_params)
    self.bracket_with_different_nodes_with_data(**forwarded)
def insert_node_after(self, new_node, node_out_port: int = 0):
    """
    Insert node 'new_node' after output with index 'node_out_port' of this node. All consumers of this node's
    output with index 'node_out_port' will be changed to consume node 'new_node'.
    The function should be used when graph doesn't contain data nodes yet.
    :param new_node: node to be inserted.
    :param node_out_port: the output index of this node to insert after
    :return: None
    """
    assert self.graph is new_node.graph
    assert (len([name for name in self.graph.nodes() if Node(self.graph, name).soft_get('kind') == 'data']) == 0)

    graph = self.graph
    old_edges = list(graph.out_edges(self.id, data=True, keys=True))
    # create new edges first and then remove all old edges. This is needed for case when this node has several
    # consumers getting input from 'node_out_port'.
    # save tuple ("name of the destination edge", "edge key") to be removed
    node_name_and_edge_key = []
    for _, dst_name, edge_key, edge_attrs in old_edges:
        if edge_attrs['out'] == node_out_port:
            log.debug('Create edge from "{}" to "{}"'.format(new_node.name, dst_name))
            graph.create_edge(new_node, Node(graph, dst_name), 0, edge_attrs['in'])
            node_name_and_edge_key.append((dst_name, edge_key))
    for dst_name, edge_key in node_name_and_edge_key:
        log.debug('Remove edge from "{}" to "{}"'.format(self.id, dst_name))
        graph.remove_edge(self.id, dst_name, edge_key)
    # finally connect this node's output to the freshly inserted node
    graph.create_edge(self, new_node, node_out_port, 0, {})
def replace_node(self, new_node, new_node_out_port: int=None):
    """
    Replaces this node with a node 'new_node' preserving edge attributes.
    :param new_node: node to replace with.
    :param new_node_out_port: when given, the 'out' attribute of every re-created edge is set to this value;
                              in that case existing edges must not use an output port other than 0.
    :return: None
    """
    graph = self.graph
    assert self.graph is new_node.graph
    assert self.id != new_node.id, "New node and replaceable node are the same"

    # save output edges and reconnect them to new node
    # (materialised into a list so that the edge view is not iterated while the graph is being modified)
    for _, dst_node_name, edge_attrs in list(graph.out_edges(self.id, data=True)):
        new_edge_attrs = deepcopy(edge_attrs)
        if new_node_out_port is not None:
            assert 'out' not in edge_attrs or edge_attrs['out'] == 0, \
                'replace_node function can replace old node with a single output port only if new_node_out_port is ' \
                'specified'
            new_edge_attrs.update({'out': new_node_out_port})
        graph.add_edge(new_node.id, dst_node_name, **new_edge_attrs)

    # if the node for replace is output node then we propagate this attribute to a new node
    if len(self.out_nodes()) == 1 and self.out_node().has('op') and self.out_node().op == 'OpOutput':
        graph.remove_node(self.out_node().id)
        add_opoutput(graph, new_node.id, 0, False)
    graph.remove_node(self.id)
def input_ports_with(self, node):
    """
    Returns a list of integers that specify input ports that connected to a given node.
    :param node: node in the graph that is expected to appear at input port for self node
    :return: a list of integers with port indices that are connected to self node
    """
    producers = self.in_nodes()
    # in_nodes() yields a port -> Node mapping for 'op' nodes and a plain list for 'data' nodes.
    # Walking the real keys (instead of range(len(...))) also copes with non-consecutive port numbers
    # and avoids rebuilding the whole mapping for every port.
    ports = producers.items() if isinstance(producers, dict) else enumerate(producers)
    return [port for port, producer in ports if producer.id == node.id]
423 class Graph(nx.MultiDiGraph):
def __init__(self, data=None, **attr):
    """Create the graph, forwarding ``data`` and ``attr`` to the parent constructor.

    :param data: forwarded unchanged to the parent graph constructor
    :param attr: forwarded unchanged to the parent graph constructor
    """
    # unique_id() reads and advances this counter, so it has to exist before the first id is requested
    self.unique_id_count = 0
    super().__init__(data, **attr)
430 # SAFE API DESCRIPTION
431 # all provided methods below are designed to be more safe and convenient
432 # be careful while using other methods from nx.MultiDiGraph
def add_node(self, node_for_adding, **attrs):
    """Add a node and initialise its '_in_ports' / '_out_ports' sets.

    When the node attributes carry a valid 'in_ports_count' / 'out_ports_count', that many
    consecutive ports (starting from 0) are registered for the respective direction.
    """
    # TODO: check required attrs for node
    super().add_node(node_for_adding, **attrs)
    node = Node(self, node_for_adding)

    declared_in = node.in_ports_count if node.has_valid('in_ports_count') else None
    declared_out = node.out_ports_count if node.has_valid('out_ports_count') else None

    node['_in_ports'] = set()
    node['_out_ports'] = set()

    # inputs are registered first, then outputs
    for declared, register in ((declared_in, node.add_input_port), (declared_out, node.add_output_port)):
        if declared is None:
            continue
        for idx in range(declared):
            register(idx=idx)
def add_edge(self, u_for_edge, v_for_edge, key=None, **attr):
    """Forward to the parent ``add_edge`` with the same arguments and return its result."""
    result = super().add_edge(u_for_edge, v_for_edge, key=key, **attr)
    return result
456 def add_edges_from(self, ebunch_to_add, **attr):
457 for e in ebunch_to_add:
469 raise Error("Edge tuple %s must be a 2-tuple, 3-tuple or 4-tuple." % (e,))
472 self.add_edge(u, v, key=key, **ddd)
def remove_edge(self, u, v, key=None):
    """Forward to the parent ``remove_edge`` with the same arguments and return its result."""
    result = super().remove_edge(u, v, key=key)
    return result
def erase_node(self, node: Node):
    """
    Erases node from the graph and reconnect edges from input node(s) to output node(s)
    Produces assertion error if the node being removed has multiple inputs or outputs.
    The function can be used in the front phase only (when there are no data nodes in the graph).
    :param node: Node to erase
    """
    node_id = node.id

    inputs = list(self.in_edges(node_id, data=True))
    outputs = list(self.out_edges(node_id, data=True))

    assert node.kind == 'op' and (len(node.out_nodes()) == 0 or list(node.out_nodes().values())[0].kind != 'data'), \
        "The function must be used before the partial infer when graph doesn't contain data nodes."
    assert len(node.out_nodes()) <= 1, "The node {} must produce just one output tensor".format(
        node.soft_get('name'))
    assert len(inputs) <= 1, "The node {} must have just one input".format(node.soft_get('name'))

    if len(outputs) == 0 and len(inputs) != 0:
        from mo.front.extractor import add_output_ops
        input_ids = {input_node_id: {'port': {'out': [attrs['out']]}} for input_node_id, _, attrs in inputs}
        if node.has('op') and node.op == 'OpOutput':
            add_output_ops(self, input_ids)

    if len(outputs) == 0 or len(inputs) == 0:
        # nothing to reconnect: drop the node and stop, otherwise inputs[0] below may not exist
        # and the node would be removed a second time
        self.remove_node(node_id)
        return

    input_node_id = inputs[0][0]
    for src, dst, attrs in outputs:
        self.remove_edge(src, dst)
        # update the 'out' attribute of the edge from the node being removed
        attrs['out'] = inputs[0][2]['out']
        self.add_edge(input_node_id, dst, **attrs)
    self.remove_node(node_id)
def get_edge_data(self, u, v, key=None, default=None):
    """Forward to the parent ``get_edge_data`` with the same arguments and return its result."""
    result = super().get_edge_data(u, v, key=key, default=default)
    return result
516 def get_inputs_with_ports(self, match, pattern_edges, input_names_in_pattern):
518 Front replacements of multi-input nodes should specify output port to add_node-like functions
519 This function is a helper to get such information out of matched nodes
520 :param graph: graph to operate on
521 :param match: dictionary returned by matching function
522 :param pattern_edges: edges that are specified in pattern
523 :param input_names_in_pattern: names of matched nodes as they were specified in pattern that should be in
525 :return: list of tuples of node and output port
528 for name in input_names_in_pattern:
529 assert name in match, "node named {} not in match {}".format(name, match)
532 for edge in pattern_edges:
534 assert edge[1] in match, "name from pattern_edges {} not in match {}".format(edge[1], match)
535 dst.append(match[edge[1]])
537 raise Error('Multiple output ports detected for node {} as {} in pattern'.format(match[name].id, name))
539 out_port = self.get_edge_data(src.id, dst.id)[0]['out']
540 inputs.append((src, out_port))
def get_node_id_by_name(self, name: str):
    """Return the id of the first node whose 'name' attribute equals ``name``.

    :param name: value of the 'name' attribute to look for
    :return: id of the matching node
    :raises Error: if no node in the graph carries this name
    """
    for node in self.nodes():
        if 'name' in self.node[node] and self.node[node]['name'] == name:
            return node
    raise Error('No node with name {}. ' +
                refer_to_faq_msg(51), name)
def get_op_nodes(self, **attrs):
    """Return Node wrappers for all nodes with kind='op' that also carry the given attribute values."""
    query = dict(kind='op', **attrs)
    matching_ids = self.get_nodes_with_attributes(**query)
    return [Node(self, node_id) for node_id in matching_ids]
def get_data_nodes(self, has_value=None):
    """
    Returns list of data nodes.
    If has_value = True, returns data nodes with value
    If has_value = False, returns data nodes without value
    If has_value is None, returns all data nodes
    """
    data_nodes = []
    for node_id in self.nodes():
        candidate = Node(self, node_id)
        if candidate.soft_get('kind') == 'data':
            data_nodes.append(candidate)
    if has_value is None:
        return data_nodes
    return [node for node in data_nodes if node.has_valid('value') == has_value]
def get_nodes_with_attributes(self, **attrs: dict):
    """Return ids of the nodes whose attribute dictionary contains every given key with an equal value."""
    wanted = attrs.items()
    matched = []
    for node_id, node_attrs in self.nodes(data=True):
        present = node_attrs.items()
        if all(pair in present for pair in wanted):
            matched.append(node_id)
    return matched
def unique_id(self, prefix: str = ""):
    """
    Generates a unique node id for a new node in a given graph.
    The optional string prefix can be specified.

    :param prefix: desired beginning of the id; when it is non-empty and not yet used as a node id,
                   it is returned as is
    :return: string id that is not present in the graph at the moment of the call
    """
    # TODO thread safety?
    self.unique_id_count = max(self.unique_id_count, self.number_of_nodes()) + 1
    if prefix and not self.has_node(prefix):
        # the bare prefix is still free, no numeric suffix is needed
        return str(prefix)
    while self.has_node(prefix + str(self.unique_id_count)):
        self.unique_id_count += 1
    return prefix + str(self.unique_id_count)
def check_empty_graph(self, description: str):
    """Raise an error when the graph holds at most one node.

    :param description: name of the step that has just been executed; used in the error message
    :raises Error: if the graph contains zero or one node
    """
    if len(self.nodes()) <= 1:
        raise Error(
            "Graph contains {} node after executing {}. It considered as error because resulting IR will be "
            "empty which is not usual".format(len(self.nodes()), description))
def check_shapes_consistency(self):
    """Verify that every data node has a 'shape' attribute that is either None or a numpy array.

    :raises Error: listing (node name, problem) pairs when at least one data node violates the rule
    """
    data_nodes = self.get_data_nodes()
    data_nodes_with_wrong_shapes = []
    for data_node in data_nodes:
        if not data_node.has('shape'):
            data_nodes_with_wrong_shapes.append((data_node.name, "no shape attribute"))
            # the attribute is absent, so it must not be read below
            continue
        if data_node.shape is not None and not isinstance(data_node.shape, np.ndarray):
            data_nodes_with_wrong_shapes.append((data_node.name, type(data_node.shape)))
    if len(data_nodes_with_wrong_shapes) > 0:
        raise Error("Graph contains data nodes ({}) with inconsistent shapes: {}".format(
            len(data_nodes_with_wrong_shapes),
            data_nodes_with_wrong_shapes
        ))
def check_nodes_ports_are_consecutive(self):
    """Ensure each operation node's input and output port indices form the range 0..count-1.

    :raises Error: naming the first node whose port indices have a gap
    """
    # Check that all operation nodes has consecutive ports indexes
    for node in self.get_op_nodes():
        in_count = len(node.in_ports())
        for idx in range(in_count):
            if idx in node.in_ports():
                continue
            raise Error("Node {} has not consecutive in ports indexes: {}".format(node.name,
                                                                                   list(node.in_ports().keys())))
        out_count = len(node.out_ports())
        for idx in range(out_count):
            if idx in node.out_ports():
                continue
            raise Error("Node {} has not consecutive out ports indexes: {}".format(node.name,
                                                                                    list(node.out_ports().keys())))
614 def dump_graph_for_graphviz(self, node_attrs: list = ['kind', 'op', 'shape'],
615 edge_attrs: list = ['in', 'out'],
616 nodes_to_dump: list = None, save_to_svg=False):
617 log.debug("---- GRAPHVIZ OUTPUT STARTS ----")
618 if nodes_to_dump is None:
619 nodes_to_dump = self.nodes()
620 string = '\ndigraph {\n'
621 visited_nodes = set()
622 for src_node_name, dst_node_name, attrs in self.edges(data=True):
623 visited_nodes.add(src_node_name)
624 visited_nodes.add(dst_node_name)
625 if src_node_name not in nodes_to_dump or dst_node_name not in nodes_to_dump:
627 src_node = self.node[src_node_name]
628 dst_node = self.node[dst_node_name]
629 src_node_string = str(src_node_name) + '\\n' + '\\n'.join(
630 [str(key) + '=' + str(src_node.get(key, 'None')) for key in node_attrs if key in src_node])
631 dst_node_string = str(dst_node_name) + '\\n' + '\\n'.join(
632 [str(key) + '=' + str(dst_node.get(key, 'None')) for key in node_attrs if key in dst_node])
633 edge_string = ' '.join([str(key) + '=' + str(attrs.get(key, 'None')) for key in edge_attrs if key in attrs])
634 string += '"{}" -> "{}" [label = "{}"];\n'.format(src_node_string, dst_node_string, edge_string)
635 for node in nodes_to_dump:
636 if node not in visited_nodes:
637 string += '"{}"'.format(node) # TODO: add attributes like it was done in the loop above
638 visited_nodes.add(node)
641 log.debug("---- GRAPHVIZ OUTPUT ENDS ----")
647 file_name = "{}_{}.txt".format(self.name.replace('/', '_'), 0)
649 while os.path.exists(file_name):
650 file_name = "{}_{}.txt".format(self.name.replace('/', '_'), id)
652 with open(file_name, "w") as f:
654 graphviz.render('dot', 'svg', file_name)
655 print('Graph was saved to {}.{}'.format(file_name, 'svg'))
657 raise ImportError('Can\'t import graphviz')
658 except Exception as e:
659 raise Error('Can\'t save graph to svg') from e
def print_graph_stat(self):
    """Log graph statistics: node and edge counts plus a per-key tally at debug level.

    Nodes whose 'shape' attribute contains a zero are reported through ``log.error``.
    """
    log.debug('Number of nodes in graph: {}'.format(self.number_of_nodes()))
    log.debug('Number of edges in graph: {}'.format(len(list(self.edges()))))
    # tally keyed by 'op/<operation type>'
    ops = collections.defaultdict(int)
    for _node in self.nodes():
        node = Node(self, _node)
        # NOTE(review): 'kind' is computed but not read by the statements shown here -- confirm intended use
        kind = node.kind if node.has('kind') else '<UNDEFINED>'
        ops['op/' + node.op] += 1
        if node.has('shape') and np.any(node.shape == 0):
            log.error("Found bad shape: '{}' for node '{}'".format(node.shape, node.node))
    for k, v in ops.items():
        log.debug(' {} : {}'.format(k, v))
def create_sub_graph_copy(self, nodes_to_extract: list):
    """
    Create new graph which is a sub-graph of this graph that contains just nodes from 'nodes_to_extract' list.
    The result is obtained by calling ``copy()`` on the sub-graph view.
    NOTE(review): how deep that copy is depends on the parent graph class -- confirm before relying on
    attribute values being independent from the source graph.
    :param nodes_to_extract: list of node names to extract.
    :return: new graph consisting of nodes from 'nodes_to_extract' list.
    """
    view = self.subgraph(nodes_to_extract)
    return view.copy()
def create_edge(self, src_node: Node, dst_node: Node, out_port: int = 0, in_port: int = 0, edge_attrs: dict = None):
    """
    Creates edge from node 'src_node' from output with index 'out_port' to node 'dst_node' with input index 'in_port'.
    :param src_node: node to create edge from.
    :param dst_node: node to create edge to.
    :param out_port: the index of output tensor of the 'src_node'.
    :param in_port: the input index of the node 'dst_node'.
    :param edge_attrs: dictionary with edge attrs; it is copied, the caller's dictionary is left untouched.
    :return: None
    """
    # edges must belong to the same graph
    assert src_node.graph is dst_node.graph
    graph = src_node.graph

    if edge_attrs is None:
        edge_attrs = {}
    else:
        # work on a copy so that the update below does not leak into the caller's dictionary
        edge_attrs = edge_attrs.copy()
    edge_attrs.update(
        {'in': in_port, 'out': out_port, 'in_attrs': ['in', 'permutation'], 'out_attrs': ['out', 'permutation'],
         'data_attrs': ['fw_tensor_debug_info']})

    # TODO: in case if in_port do not exists, we should raise an Exception here
    graph.add_edges_from([(src_node.id, dst_node.id, edge_attrs)])
def create_graph_with_nodes(src_nodes, get_id: callable, get_attrs: callable):
    """
    Go over all nodes in src_nodes that should be enumerable and create new NX nodes
    using get_id and get_attrs functions to create node id and node attributes correspondingly.

    :param src_nodes: iterable of source node descriptions
    :param get_id: function returning the node id for one item of src_nodes
    :param get_attrs: function returning the attribute dictionary for one item of src_nodes
    :return: new Graph populated with one node per item of src_nodes
    """
    graph = Graph()
    for node in src_nodes:
        graph.add_node(get_id(node), **get_attrs(node))
    return graph
def dict_includes_compare_attrs(attr, attr_probe):
    """Match ``attr`` against ``attr_probe``.

    A callable probe that is not a class is invoked as ``attr_probe(attr)`` and its result returned;
    anything else (including classes) is compared with ``==``.
    """
    is_predicate = callable(attr_probe) and not isinstance(attr_probe, type)
    return attr_probe(attr) if is_predicate else attr == attr_probe
def dict_includes(big: dict, sub_dict: dict, skip_attr_names=()):
    """ Searches attributes from sub_dict in big and ensures that all values match.

    Entries in sub_dict can be of two types: callable or not callable. If callable is specified
    it is treated as probing function for attribute value from big dictionary by callable(attr) expression.
    If it is not callable, the values are compared with == operator.

    :param big: dictionary that is being inspected; a key absent from it is probed as None
    :param sub_dict: required attributes (plain values or probing callables)
    :param skip_attr_names: names from sub_dict that are excluded from the check
    :return: True if every non-skipped entry of sub_dict matches, False otherwise
    """
    return all(
        dict_includes_compare_attrs(big.get(attr, None), sub_dict[attr])
        for attr in sub_dict.keys() if attr not in skip_attr_names
    )
def add_opoutput(graph: Graph, node_name: str, port: int, cut: bool = True):
    """
    Creates and connects OpOutput node to node_name port. Cuts existing port if requested.
    :param graph: graph to operate with
    :param node_name: name of existing node in the graph that we want to add OpOutput to
    :param port: output port of node to connect OpOutput to
    :param cut: determines way of operating with edge specified by node_name and port
    :return: id of the created OpOutput node
    """
    # we import it here because Op imports add_attrs_props and update_ie_fields from this file
    from mo.ops.output import Output
    node = Node(graph, node_name)
    if cut and len(node.out_edges()) != 0:
        opoutput_node = Output(graph).create_node_on_port(node, port, {'name': node_name + '/sink_port_' + str(port)})
    else:
        # exactly one of the two creation paths must run, otherwise two sink nodes would be added
        opoutput_node = Output(graph).create_node([(node, port)], {'name': node_name + '/sink_port_' + str(port)})
        opoutput_node.in_edge()['data_attrs'] = ['fw_tensor_debug_info']
        opoutput_node.in_edge()['fw_tensor_debug_info'] = [(node_name, port)]
    log.debug('Sink: {} for node {}'.format(opoutput_node.id, node_name))
    log.debug(str(graph.node[opoutput_node.id]))
    log.debug("Add edge from {} to {}".format(node_name, opoutput_node.id))
    return opoutput_node.id
769 # TODO implement merging for keys with dictionary values?
770 def merge_edge_props(attrs: dict, additional_attrs: dict):
772 Update edge attributes without changing 'in' and 'out' keys.
773 It is necessary to copy edge attributes during merging of nodes when
774 result of one subgraph call is passed as input to another subgraph call
777 for (key, value) in additional_attrs.items():
778 if key not in ['in', 'out']:
779 if type(additional_attrs[key]) is list:
780 if key not in result:
782 result[key].extend(additional_attrs[key])
783 result[key] = list(set(result[key])) # silly solution to find unique elements
789 # All functions below are deprecated and will be removed in next release
790 # Please, use methods from Graph/Node classes instead
@deprecated_api(Graph)
def get_node_id_by_name(graph: Graph, name: str):
    """Deprecated module-level form; forwards to ``graph.get_node_id_by_name``."""
    found = graph.get_node_id_by_name(name=name)
    return found
@deprecated_api(Graph)
def print_graph_stat(graph: Graph):
    """Deprecated module-level form; forwards to ``graph.print_graph_stat``."""
    result = graph.print_graph_stat()
    return result
@deprecated_api(Graph)
def get_inputs_with_ports(graph: Graph, match, pattern_edges, input_names_in_pattern):
    """
    Deprecated module-level form; forwards to ``graph.get_inputs_with_ports``.

    Front replacements of multi-input nodes should specify output port to add_node-like functions
    This function is a helper to get such information out of matched nodes
    :param graph: graph to operate on
    :param match: dictionary returned by matching function
    :param pattern_edges: edges that are specified in pattern
    :param input_names_in_pattern: names of matched nodes as they were specified in pattern
    :return: list of tuples of node and output port
    """
    forwarded = dict(match=match,
                     pattern_edges=pattern_edges,
                     input_names_in_pattern=input_names_in_pattern)
    return graph.get_inputs_with_ports(**forwarded)
@deprecated_api(Graph)
def dump_graph_for_graphviz(graph: Graph, node_attrs: list = None, edge_attrs: list = None,
                            nodes_to_dump: list = None, save_to_svg=False):
    """
    Deprecated module-level wrapper around graph.dump_graph_for_graphviz.
    :param graph: graph whose 'dump_graph_for_graphviz' method is called.
    :param node_attrs: forwarded as 'node_attrs'; defaults to ['kind', 'op', 'shape'] when None.
    :param edge_attrs: forwarded as 'edge_attrs'; defaults to ['in', 'out'] when None.
    :param nodes_to_dump: forwarded unchanged as 'nodes_to_dump'.
    :param save_to_svg: forwarded unchanged as 'save_to_svg'.
    :return: the value returned by graph.dump_graph_for_graphviz.
    """
    # Default lists are created per call: a list literal in the signature would be a single object
    # shared by all calls, so any mutation of it downstream would leak into later calls.
    if node_attrs is None:
        node_attrs = ['kind', 'op', 'shape']
    if edge_attrs is None:
        edge_attrs = ['in', 'out']
    return graph.dump_graph_for_graphviz(node_attrs=node_attrs,
                                         edge_attrs=edge_attrs,
                                         nodes_to_dump=nodes_to_dump,
                                         save_to_svg=save_to_svg)
@deprecated_api(Graph)
def create_sub_graph_copy(graph: Graph, nodes_to_extract: list):
    """
    Deprecated module-level wrapper around graph.create_sub_graph_copy.

    Create new graph which is a sub-graph of the 'graph' that contains just nodes from 'nodes_to_extract' list. The
    returned sub-graph is a deep copy of the provided graph nodes.
    :param graph: graph to create a sub-graph from.
    :param nodes_to_extract: list of node names to extract.
    :return: the value returned by graph.create_sub_graph_copy.
    """
    sub_graph = graph.create_sub_graph_copy(nodes_to_extract=nodes_to_extract)
    return sub_graph
@deprecated_api(Graph)
def get_graph_ops(graph: Graph):
    """
    Deprecated module-level wrapper; note that it forwards to graph.get_op_nodes (different name).
    :param graph: graph whose 'get_op_nodes' method is called.
    :return: the value returned by graph.get_op_nodes.
    """
    op_nodes = graph.get_op_nodes()
    return op_nodes
@deprecated_api(Graph)
def check_empty_graph(graph: Graph, description: str):
    """
    Deprecated module-level wrapper around graph.check_empty_graph.
    :param graph: graph whose 'check_empty_graph' method is called.
    :param description: forwarded as the 'description' keyword argument.
    :return: the value returned by graph.check_empty_graph.
    """
    outcome = graph.check_empty_graph(description=description)
    return outcome
@deprecated_api(Graph)
def create_edge(src_node: Node, dst_node: Node, out_port: int = 0, in_port: int = 0, edge_attrs: dict = None):
    """
    Deprecated module-level wrapper around the 'create_edge' method of the nodes' graph.

    Creates edge from node 'src_node' from output with index 'out_port' to node 'dst_node' with input index 'in_port'.
    Both nodes must belong to the same graph object; this is checked with an assertion.
    :param src_node: node to create edge from.
    :param dst_node: node to create edge to.
    :param out_port: the index of output tensor of the 'src_node'.
    :param in_port: the input index of the node 'dst_node'.
    :param edge_attrs: dictionary with edge attrs.
    :return: the value returned by the graph's 'create_edge' method.
    """
    owner = src_node.graph
    assert owner is dst_node.graph
    call_kwargs = dict(src_node=src_node, dst_node=dst_node, out_port=out_port, in_port=in_port,
                       edge_attrs=edge_attrs)
    return owner.create_edge(**call_kwargs)
@deprecated_api(Graph)
def erase_node(node: Node):
    """
    Deprecated module-level wrapper around the 'erase_node' method of the node's graph.

    Erases node from the graph and reconnect edges from input node(s) to output node(s)
    Produces assertion error if the node being removed has multiple inputs or outputs.
    The function can be used in the front phase only (when there are no data nodes in the graph).
    :param node: Node to erase
    :return: the value returned by the graph's 'erase_node' method.
    """
    # The graph is taken from the node itself (Node stores it as the 'graph' attribute),
    # the same way create_edge obtains it from 'src_node'.
    graph = node.graph
    return graph.erase_node(node)
@deprecated_api(Node)
def get_sorted_inputs(node: Node, control_flow: bool = False):
    """
    Deprecated module-level wrapper around node.get_sorted_inputs.
    :param node: node whose 'get_sorted_inputs' method is called.
    :param control_flow: forwarded as the 'control_flow' keyword argument.
    :return: the value returned by node.get_sorted_inputs.
    """
    sorted_inputs = node.get_sorted_inputs(control_flow=control_flow)
    return sorted_inputs
@deprecated_api(Node)
def get_sorted_outputs(node: Node, control_flow: bool = False):
    """
    Deprecated module-level wrapper around node.get_sorted_outputs.
    :param node: node whose 'get_sorted_outputs' method is called.
    :param control_flow: forwarded as the 'control_flow' keyword argument.
    :return: the value returned by node.get_sorted_outputs.
    """
    sorted_outputs = node.get_sorted_outputs(control_flow=control_flow)
    return sorted_outputs
@deprecated_api(Node)
def insert_node_after(node: Node, new_node: Node, node_out_port: int = 0):
    """
    Deprecated module-level wrapper around node.insert_node_after.

    Insert node 'new_node' after output with index 'node_out_port' of the node 'node'. All consumers of node 'node'
    output with index 'node_out_port' will be changed to consume node 'new_node'.
    The function should be used when graph doesn't contain data nodes yet.
    :param node: node after which new node should be inserted.
    :param new_node: node to be inserted.
    :param node_out_port: the output index of the node 'node' after which 'new_node' is inserted.
    :return: the value returned by node.insert_node_after.
    """
    forwarded = {'new_node': new_node, 'node_out_port': node_out_port}
    return node.insert_node_after(**forwarded)
@deprecated_api(Node)
def replace_node(old_node: Node, new_node: Node, new_node_out_port: int=None):
    """
    Deprecated module-level wrapper around old_node.replace_node.

    Replaces node 'old_node' with a node 'new_node' preserving edge attributes.
    :param old_node: node to be replaced.
    :param new_node: node to replace with.
    :param new_node_out_port: forwarded unchanged as the 'new_node_out_port' keyword argument.
    :return: the value returned by old_node.replace_node.
    """
    forwarded = {'new_node': new_node, 'new_node_out_port': new_node_out_port}
    return old_node.replace_node(**forwarded)
@deprecated_api(Node)
def copy_node(src_node: Node, new_attrs: dict=None, dst_graph: nx.MultiDiGraph = None):
    """
    Deprecated module-level wrapper around src_node.copy_node.

    Copies node with all attributes (optionally updated) within the same graph or to different graph.
    :param src_node: node whose 'copy_node' method is called.
    :param new_attrs: forwarded unchanged as the 'new_attrs' keyword argument.
    :param dst_graph: forwarded unchanged as the 'dst_graph' keyword argument.
    :return: the value returned by src_node.copy_node.
    """
    forwarded = {'new_attrs': new_attrs, 'dst_graph': dst_graph}
    return src_node.copy_node(**forwarded)
@deprecated_api(Node)
def get_inputs(graph: Graph, node: str, edge_attr: dict = None, control_flow: bool = False):
    """
    Deprecated module-level wrapper: wraps the node id into a Node and calls its 'get_inputs' method.
    :param graph: graph that contains the node.
    :param node: id of the node; used to construct Node(graph, node).
    :param edge_attr: forwarded unchanged as the 'edge_attr' keyword argument.
    :param control_flow: forwarded unchanged as the 'control_flow' keyword argument.
    :return: the value returned by Node.get_inputs.
    """
    wrapped = Node(graph, node)
    return wrapped.get_inputs(edge_attr=edge_attr, control_flow=control_flow)
@deprecated_api(Node)
def get_outputs(graph: Graph, node: str, edge_attr: dict = None, control_flow: bool = False):
    """
    Deprecated module-level wrapper: wraps the node id into a Node and calls its 'get_outputs' method.
    :param graph: graph that contains the node.
    :param node: id of the node; used to construct Node(graph, node).
    :param edge_attr: forwarded unchanged as the 'edge_attr' keyword argument.
    :param control_flow: forwarded unchanged as the 'control_flow' keyword argument.
    :return: the value returned by Node.get_outputs.
    """
    wrapped = Node(graph, node)
    return wrapped.get_outputs(edge_attr=edge_attr, control_flow=control_flow)