Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / SSDToolboxDetectionOutput.py
index 278998c..15fa70f 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  limitations under the License.
 """
 
-import networkx as nx
 import numpy as np
 
 from extensions.front.standalone_const_eraser import StandaloneConstEraser
 from extensions.ops.DetectionOutput import DetectionOutput
 from mo.front.subgraph_matcher import SubgraphMatch
 from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
 from mo.ops.op import PermuteAttrs
 from mo.ops.output import Output
 from mo.ops.reshape import Reshape
@@ -33,16 +32,28 @@ class SSDToolboxDetectionOutputReplacement(FrontReplacementFromConfigFileSubGrap
     def run_before(self):
         return [StandaloneConstEraser]
 
-    def nodes_to_remove(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+    def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
         return []
 
-    def generate_sub_graph(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
         # IE DetectionOutput layer consumes flattened confidences and locations tensors.
         # That is why we add reshapes before them.
         locs_node = match.single_input_node(0)
         conf_node = match.single_input_node(1)
         prior_boxes_node = match.single_input_node(2)
 
+        locs_out_nodes = locs_node[0].out_nodes()
+        assert len(locs_out_nodes) == 1
+        locs_out_node = locs_out_nodes[list(locs_out_nodes.keys())[0]]
+        assert locs_out_node.op == "OpOutput", locs_out_node.op
+        graph.remove_node(locs_out_node.id)
+
+        conf_out_nodes = conf_node[0].out_nodes()
+        assert len(conf_out_nodes) == 1
+        conf_out_node = conf_out_nodes[list(conf_out_nodes.keys())[0]]
+        assert conf_out_node.op == "OpOutput", conf_out_node.op
+        graph.remove_node(conf_out_node.id)
+
         # reshape operation to flatten confidence tensor
         reshape_loc_op = Reshape(graph, {'dim': np.array([0, -1])})
         reshape_loc_node = reshape_loc_op.create_node([locs_node], dict(name='DetectionOutput_Reshape_loc_'))