Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / common / find_unsupported_ops.py
index 8706706..8b632c2 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 
 import logging as log
 
-import networkx as nx
 import numpy as np
 
-from mo.graph.graph import Node
-from mo.utils.dsu import DSU, DSUElem
-from mo.utils.graph import bfs_search
+from mo.graph.graph import Node, Graph
 
 
-def find_unsupported_ops(graph: nx.MultiDiGraph):
+def find_unsupported_ops(graph: Graph):
     """
     The function returns list of node name those are not supported. Currently nodes that product non FP32 data tensors
     or has undefined 'type' attribute are considered unsupported.
@@ -36,57 +33,13 @@ def find_unsupported_ops(graph: nx.MultiDiGraph):
         node = Node(graph, node_name)
         # op node that produce non FP32 data or has no type are considered unsupported
         if node.kind == 'op':
-            if not node.has_valid('type'):
-                log.info('Node "{}" does not have type. Consider it unsupported'.format(node_name))
-                unsupported.append(node.id)
-            else:
+            if node.has_valid('type') or (node.has_valid('op') and node.op == 'OpOutput'):
                 for out_data_node in node.out_nodes().values():
                     if out_data_node.has_valid('data_type') and out_data_node.data_type != np.float32:
                         log.info('Node "{}" produces output as non FP32. Consider it unsupported'.format(node_name))
                         unsupported.append(node.id)
+            else:
+                log.info('Node "{}" does not have type. Consider it unsupported'.format(node_name))
+                unsupported.append(node.id)
     return unsupported
 
-
-def find_unsupported_ops_subgraphs(graph: nx.MultiDiGraph, unsupported_nodes: list,
-                                   find_constant_input_fn: callable = lambda node: node):
-    bfs_nodes = bfs_search(graph, list())
-    visited = set()
-    # mark initial set of nodes as not supported
-    for node_name in unsupported_nodes:
-        graph.node[node_name]['supported'] = False
-
-    for node_name in bfs_nodes:
-        if node_name in visited:
-            continue
-
-        node = Node(graph, node_name)
-        if node.has_valid('supported') and not node['supported']:
-            added_nodes = find_constant_input_fn(node)
-            visited.update(added_nodes)
-            for node in added_nodes:
-                node['supported'] = False
-
-    dsu_elems = list()
-    for node_name in bfs_nodes:
-        node = Node(graph, node_name)
-        if node.has_valid('supported') and not node['supported']:
-            dsu_elems.append(DSUElem(node_name))
-
-    dsu = DSU(dsu_elems)
-
-    # merge adjacent unsupported nodes
-    for dsu_elem in dsu_elems:
-        node = Node(graph, dsu_elem.name)
-        if not node['supported']:
-            for out_node in node.out_nodes().values():
-                if out_node.has_valid('supported') and not out_node['supported']:
-                    dsu.union(dsu_elem, dsu.find_elem(out_node.id))
-
-    subgraph_id = dict()  # key is the name of the node, value is the set of nodes that belong to this subgraph
-    for dsu_elem in dsu.map.values():
-        parent = dsu.find_parent(dsu_elem).name
-        if parent not in subgraph_id.keys():
-            subgraph_id[parent] = set()
-        subgraph_id[parent].add(dsu_elem.name)
-
-    return [list(s) for s in subgraph_id.values()]