"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
limitations under the License.
"""
-import networkx as nx
import numpy as np
from extensions.ops.DetectionOutput import DetectionOutput
from extensions.ops.splitv import SplitV
from mo.front.subgraph_matcher import SubgraphMatch
from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.eltwise import Eltwise
replacement_id = 'RetinaNetFilteredDetectionsReplacement'
@staticmethod
- def _create_sub(graph: nx.MultiDiGraph, input_1: Node, port_1: int, input_2: Node, port_2: int):
+ def _create_sub(graph: Graph, input_1: Node, port_1: int, input_2: Node, port_2: int):
negate = Power(graph, dict(scale=-1, name=input_2.name + '/negate_'))
add = Eltwise(graph, dict(operation='sum', name=input_1.name + '/add_'))
out_node = add.create_node([(input_1, port_1), negate.create_node([(input_2, port_2)])])
return out_node
- def output_edges_match(self, graph: nx.DiGraph, match: SubgraphMatch, new_sub_graph: dict):
+ def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
return {match.output_node(0)[0].id: new_sub_graph['detection_output_node'].id}
- def nodes_to_remove(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+ def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
new_nodes_to_remove = match.matched_nodes_names()
new_nodes_to_remove.remove(match.single_input_node(0)[0].id)
new_nodes_to_remove.remove(match.single_input_node(1)[0].id)
new_nodes_to_remove.remove(match.single_input_node(2)[0].id)
return new_nodes_to_remove
- def generate_sub_graph(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+ def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
reshape_classes_op = Reshape(graph, {'dim': np.array([0, -1])})
reshape_classes_node = reshape_classes_op.create_node([match.single_input_node(1)[0]],
dict(name='do_reshape_classes'))
[priors_node, priors_scale_const_node])
# calculate prior boxes widths and heights
- split_node = SplitV(graph, {'axis': 2, 'size_splits': [1, 1, 1, 1]}).create_node([priors_scale_node])
+ split_node = SplitV(graph, {'axis': 2, 'size_splits': [1, 1, 1, 1], 'out_ports_count': 4}).create_node([priors_scale_node])
priors_width_node = __class__._create_sub(graph, split_node, 2, split_node, 0)
priors_height_node = __class__._create_sub(graph, split_node, 3, split_node, 1)
# concat weights and heights into a single tensor and multiple with the box coordinates regression values
- concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1}).create_node(
+ concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1, 'in_ports_count': 4}).create_node(
[priors_width_node, priors_height_node, priors_width_node, priors_height_node])
applied_width_height_regressions_node = Eltwise(graph, {'name': 'final_regressions', 'operation': 'mul'}). \
create_node([concat_width_height_node, match.single_input_node(0)[0]])