Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / mxnet / add_input_data_to_prior_boxes.py
1 """
2  Copyright (c) 2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 from mo.front.common.replacement import FrontReplacementPattern
17 from mo.graph.graph import Graph, Node
18
19
20 class AddInputDataToPriorBoxes(FrontReplacementPattern):
21     enabled = True
22
23     def run_before(self):
24         from extensions.front.create_tensor_nodes import CreateTensorNodes
25         return [CreateTensorNodes]
26
27     def run_after(self):
28         from extensions.front.pass_separator import FrontFinish
29         return [FrontFinish]
30
31     @staticmethod
32     def add_input_data_to_prior_boxes(graph: Graph, input_names: str = ''):
33         """
34         PriorBox layer has data input unlike mxnet.
35         Need to add data input to _contrib_MultiBoxPrior for
36         for correct conversion to PriorBox layer.
37
38         Parameters
39         ----------
40         graph : Graph
41            Graph with loaded model.
42         """
43         if not input_names:
44             input_names = ('data',)
45         else:
46             input_names = input_names.split(',')
47
48         input_nodes = {}
49         for node in graph.nodes():
50             node = Node(graph, node)
51             if node.has_valid('op') and node.name in input_names:
52                 input_nodes.update({node.id: node})
53
54         if len(input_nodes) > 0:
55             for node in graph.nodes():
56                 node = Node(graph, node)
57                 if node.has_valid('op') and node.op == '_contrib_MultiBoxPrior':
58                     node.add_input_port(idx=1)
59                     graph.create_edge(list(input_nodes.values())[0], node, out_port=0, in_port=1)
60
61     def find_and_replace_pattern(self, graph: Graph):
62         self.add_input_data_to_prior_boxes(graph, graph.graph['cmd_params'].input)