2 Copyright (c) 2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
16 from mo.front.common.replacement import FrontReplacementPattern
17 from mo.graph.graph import Graph, Node
class AddInputDataToPriorBoxes(FrontReplacementPattern):
    """
    Connect a model input to every MXNet ``_contrib_MultiBoxPrior`` node.

    The PriorBox layer has a data input, unlike the MXNet ``_contrib_MultiBoxPrior``
    operation. This pass therefore wires a model input node into input port 1
    of each ``_contrib_MultiBoxPrior`` node so it can be converted to PriorBox.
    """
    enabled = True

    def run_before(self):
        # NOTE(review): imported inside the method, presumably to avoid import
        # cycles between transformation modules.
        from extensions.front.create_tensor_nodes import CreateTensorNodes
        return [CreateTensorNodes]

    def run_after(self):
        from extensions.front.pass_separator import FrontFinish
        return [FrontFinish]

    @staticmethod
    def add_input_data_to_prior_boxes(graph: Graph, input_names: str = ''):
        """
        Add the data input to every _contrib_MultiBoxPrior node.

        PriorBox layer has data input unlike mxnet, so the first model input
        node whose name is listed in ``input_names`` is connected to input
        port 1 of each _contrib_MultiBoxPrior node for correct conversion to
        PriorBox layer.

        Parameters
        ----------
        graph : Graph
           Graph with loaded model. Modified in place.
        input_names : str
           Comma-separated input node names, as passed via the ``--input``
           command-line parameter. When empty (or None), the node named
           ``data`` is used. Surrounding whitespace around each name is
           ignored. If several names match, only the first matching node in
           graph iteration order is connected.
        """
        if input_names:
            # Strip so that "data, im_info" still matches "im_info".
            names = {name.strip() for name in input_names.split(',')}
        else:
            names = {'data'}

        nodes = [Node(graph, node_id) for node_id in graph.nodes()]

        # First matching node in iteration order, same choice as before.
        data_node = next((node for node in nodes
                          if node.has_valid('op') and node.name in names), None)
        if data_node is None:
            return

        # Collect targets before mutating the graph, so edges are never
        # added while the node view is being iterated.
        prior_boxes = [node for node in nodes
                       if node.has_valid('op') and node.op == '_contrib_MultiBoxPrior']
        for prior_box in prior_boxes:
            prior_box.add_input_port(idx=1)
            graph.create_edge(data_node, prior_box, out_port=0, in_port=1)

    def find_and_replace_pattern(self, graph: Graph):
        # The value of the ``--input`` command-line option selects the data node.
        self.add_input_data_to_prior_boxes(graph, graph.graph['cmd_params'].input)