"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
limitations under the License.
"""
-import networkx as nx
import numpy as np
from extensions.front.standalone_const_eraser import StandaloneConstEraser
from extensions.ops.DetectionOutput import DetectionOutput
from mo.front.subgraph_matcher import SubgraphMatch
from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
from mo.ops.op import PermuteAttrs
from mo.ops.output import Output
from mo.ops.reshape import Reshape
def run_before(self):
return [StandaloneConstEraser]
- def nodes_to_remove(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+ def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
return []
- def generate_sub_graph(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+ def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
# IE DetectionOutput layer consumes flattened confidences and locations tensors.
# That is why we add reshapes before them.
locs_node = match.single_input_node(0)
conf_node = match.single_input_node(1)
prior_boxes_node = match.single_input_node(2)
+ locs_out_nodes = locs_node[0].out_nodes()
+ assert len(locs_out_nodes) == 1
+ locs_out_node = locs_out_nodes[list(locs_out_nodes.keys())[0]]
+ assert locs_out_node.op == "OpOutput", locs_out_node.op
+ graph.remove_node(locs_out_node.id)
+
+ conf_out_nodes = conf_node[0].out_nodes()
+ assert len(conf_out_nodes) == 1
+ conf_out_node = conf_out_nodes[list(conf_out_nodes.keys())[0]]
+ assert conf_out_node.op == "OpOutput", conf_out_node.op
+ graph.remove_node(conf_out_node.id)
+
# reshape operation to flatten confidence tensor
reshape_loc_op = Reshape(graph, {'dim': np.array([0, -1])})
reshape_loc_node = reshape_loc_op.create_node([locs_node], dict(name='DetectionOutput_Reshape_loc_'))