"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import networkx as nx
-from mo.graph.graph import NodeWrap
+from mo.graph.graph import Node, Graph
-def find_nodes_by_type(graph: nx.MultiDiGraph, t_name: str):
- nodes = nx.topological_sort(graph)
- inputs = []
- for n in nodes:
- node = NodeWrap(graph, n)
- if node.has('type') and node.type == t_name:
- inputs.append(node.id)
- return inputs
+def find_nodes_by_attribute_value(graph: Graph, attr: str, attr_name: str):
+ return [id for id, v in nx.get_node_attributes(graph, attr).items() if v == attr_name]
-def find_inputs(graph: nx.MultiDiGraph):
- return find_nodes_by_type(graph, 'Input')
+def find_inputs(graph: Graph):
+ return find_nodes_by_attribute_value(graph, 'type', 'Input')
-def find_outputs(graph):
- nodes = nx.topological_sort(graph)
+def find_outputs(graph: Graph):
outputs = []
- for n in nodes:
- node = NodeWrap(graph, n)
- if node.has('is_output') and node['is_output']:
- outputs.append(node.id)
- return outputs
+ for node_id in find_nodes_by_attribute_value(graph, 'op', 'OpOutput'):
+ parents = Node(graph, node_id).in_nodes()
+ assert len(parents) == 1, 'OpOutput node should have exactly one input'
+ parent = parents[0].id
+ outputs.append(parent)
+ return list(set(outputs))