2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from mo.graph.graph import Node
24 def multi_box_detection_infer(node: Node):
25 loc_shape = node.in_node(0).shape
26 conf_shape = node.in_node(1).shape
27 prior_boxes_shape = node.in_node(2).shape
30 if node.has('normalized') and not node.normalized:
33 if prior_boxes_shape[-1] % prior_size != 0:
34 log.warning('Amount of confidences "{}" is not divisible by {}'.format(conf_shape[-1], prior_size))
37 num_priors = prior_boxes_shape[-1] // prior_size
38 if not node.has_valid('keep_top_k') or node.keep_top_k == -1:
39 node['keep_top_k'] = num_priors
40 node.graph.node[node.id]['num_classes'] = conf_shape[-1] // num_priors
41 num_loc_classes = node.num_classes
42 if node.has_and_set('share_location') and node.share_location:
45 if loc_shape is None or conf_shape is None or prior_boxes_shape is None:
46 log.warning('Shapes for the Detection Output are not defined')
49 if num_priors * num_loc_classes * 4 != loc_shape[-1]:
50 log.warning('Locations and prior boxes shapes mismatch: "{}" vs "{}"'.format(loc_shape, prior_boxes_shape))
53 if not node.variance_encoded_in_target and prior_boxes_shape[-2] != 2:
54 log.warning('The "-2" dimension of the prior boxes must be 2 but it is "{}".'.format(prior_boxes_shape[-2]))
57 if conf_shape[-1] % num_priors != 0:
58 log.warning('Amount of confidences "{}" is not divisible by amount of priors "{}".'.format(
59 conf_shape[-1], num_priors))
62 log.debug('Inferred amount of classes "{}"'.format(node.num_classes))
63 node.out_node(0).shape = np.array([1, 1, conf_shape[0] * node.keep_top_k, 7], dtype=np.int64)
65 # the line below is needed for the TF framework so the MO will not change the layout
66 node.graph.node[node.out_node(0).id]['nchw_layout'] = True