Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / RetinaNetFilteredDetectionsReplacement.py
index a46bb50..b0f6eae 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  limitations under the License.
 """
 
-import networkx as nx
 import numpy as np
 
 from extensions.ops.DetectionOutput import DetectionOutput
 from extensions.ops.splitv import SplitV
 from mo.front.subgraph_matcher import SubgraphMatch
 from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
 from mo.ops.concat import Concat
 from mo.ops.const import Const
 from mo.ops.eltwise import Eltwise
@@ -43,23 +42,23 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr
     replacement_id = 'RetinaNetFilteredDetectionsReplacement'
 
     @staticmethod
-    def _create_sub(graph: nx.MultiDiGraph, input_1: Node, port_1: int, input_2: Node, port_2: int):
+    def _create_sub(graph: Graph, input_1: Node, port_1: int, input_2: Node, port_2: int):
         negate = Power(graph, dict(scale=-1, name=input_2.name + '/negate_'))
         add = Eltwise(graph, dict(operation='sum', name=input_1.name + '/add_'))
         out_node = add.create_node([(input_1, port_1), negate.create_node([(input_2, port_2)])])
         return out_node
 
-    def output_edges_match(self, graph: nx.DiGraph, match: SubgraphMatch, new_sub_graph: dict):
+    def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
         return {match.output_node(0)[0].id: new_sub_graph['detection_output_node'].id}
 
-    def nodes_to_remove(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+    def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
         new_nodes_to_remove = match.matched_nodes_names()
         new_nodes_to_remove.remove(match.single_input_node(0)[0].id)
         new_nodes_to_remove.remove(match.single_input_node(1)[0].id)
         new_nodes_to_remove.remove(match.single_input_node(2)[0].id)
         return new_nodes_to_remove
 
-    def generate_sub_graph(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
         reshape_classes_op = Reshape(graph, {'dim': np.array([0, -1])})
         reshape_classes_node = reshape_classes_op.create_node([match.single_input_node(1)[0]],
                                                               dict(name='do_reshape_classes'))
@@ -79,12 +78,12 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr
             [priors_node, priors_scale_const_node])
 
         # calculate prior boxes widths and heights
-        split_node = SplitV(graph, {'axis': 2, 'size_splits': [1, 1, 1, 1]}).create_node([priors_scale_node])
+        split_node = SplitV(graph, {'axis': 2, 'size_splits': [1, 1, 1, 1], 'out_ports_count': 4}).create_node([priors_scale_node])
         priors_width_node = __class__._create_sub(graph, split_node, 2, split_node, 0)
         priors_height_node = __class__._create_sub(graph, split_node, 3, split_node, 1)
 
         # concat weights and heights into a single tensor and multiple with the box coordinates regression values
-        concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1}).create_node(
+        concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1, 'in_ports_count': 4}).create_node(
             [priors_width_node, priors_height_node, priors_width_node, priors_height_node])
         applied_width_height_regressions_node = Eltwise(graph, {'name': 'final_regressions', 'operation': 'mul'}). \
             create_node([concat_width_height_node, match.single_input_node(0)[0]])