Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / utils / find_inputs.py
index 87ab7bb..633859b 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 
 import networkx as nx
 
-from mo.graph.graph import NodeWrap
+from mo.graph.graph import Node, Graph
 
 
-def find_nodes_by_type(graph: nx.MultiDiGraph, t_name: str):
-    nodes = nx.topological_sort(graph)
-    inputs = []
-    for n in nodes:
-        node = NodeWrap(graph, n)
-        if node.has('type') and node.type == t_name:
-            inputs.append(node.id)
-    return inputs
+def find_nodes_by_attribute_value(graph: Graph, attr: str, attr_name: str):
+    return [id for id, v in nx.get_node_attributes(graph, attr).items() if v == attr_name]
 
 
-def find_inputs(graph: nx.MultiDiGraph):
-    return find_nodes_by_type(graph, 'Input')
+def find_inputs(graph: Graph):
+    return find_nodes_by_attribute_value(graph, 'type', 'Input')
 
 
-def find_outputs(graph):
-    nodes = nx.topological_sort(graph)
+def find_outputs(graph: Graph):
     outputs = []
-    for n in nodes:
-        node = NodeWrap(graph, n)
-        if node.has('is_output') and node['is_output']:
-            outputs.append(node.id)
-    return outputs
+    for node_id in find_nodes_by_attribute_value(graph, 'op', 'OpOutput'):
+        parents = Node(graph, node_id).in_nodes()
+        assert len(parents) == 1, 'OpOutput node should have exactly one input'
+        parent = parents[0].id
+        outputs.append(parent)
+    return list(set(outputs))