2 Copyright (c) 2018 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from mo.front.caffe.collect_attributes import merge_attrs
20 from mo.front.common.partial_infer.multi_box_detection import multi_box_detection_infer
21 from mo.front.extractor import FrontExtractorOp
22 from mo.ops.op import Op
# Front extractor registered for the 'DetectionOutput' op: reads the layer's
# `detection_output_param` protobuf message, builds a flat attribute mapping
# from it and stores that mapping on the graph node.
# NOTE(review): this listing appears to be missing lines. The method header and
# the definitions of `pl`, `node`, `log`, `attrs`, `code_type_values`,
# `pad_mode_values` and the initial value of `interp_mode` are referenced below
# but are not visible here -- confirm against the complete file before editing.
25 class DetectionOutputFrontExtractor(FrontExtractorOp):
26 op = 'DetectionOutput'
# The protobuf layer object must be present before any parameter is read.
32 assert pl, 'Protobuf layer can not be empty'
34 param = pl.detection_output_param
36 # TODO rewrite params as complex structures
# NMS settings are copied from the nested `nms_param` message.
# NOTE(review): `nms_threshold` (and presumably `eta`/`top_k`) are only bound
# inside this `hasattr` branch but are used later -- verify that the branch is
# always taken.
37 if hasattr(param, 'nms_param'):
38 nms_threshold = param.nms_param.nms_threshold
39 eta = param.nms_param.eta
# A `top_k` of 0 is special-cased; the replacement value is not visible in this
# listing. Otherwise the proto value is used as is.
40 if param.nms_param.top_k == 0:
43 top_k = param.nms_param.top_k
# Names used to translate the numeric `code_type` value into a string.
# NOTE(review): presumably the list starts with a placeholder at index 0 so
# that values 1..3 map onto these three names -- confirm.
47 "caffe.PriorBoxParameter.CORNER",
48 "caffe.PriorBoxParameter.CENTER_SIZE",
49 "caffe.PriorBoxParameter.CORNER_SIZE"
# Default is the entry at index 1; it is overridden by the proto value when
# that value lies in the accepted range 1..3, otherwise an error is logged.
52 code_type = code_type_values[1]
53 if hasattr(param, 'code_type'):
54 if param.code_type < 1 or param.code_type > 3:
55 log.error("Incorrect value of code_type parameter")
57 code_type = code_type_values[param.code_type]
# Falls back to 0.6 when the proto value is falsy (e.g. 0).
59 visualize_threshold = param.visualize_threshold if param.visualize_threshold else 0.6
# Names used to translate the numeric `resize_mode` value into a string.
61 resize_mode_values = [
63 "caffe.ResizeParameter.WARP",
64 "caffe.ResizeParameter.FIT_SMALL_SIZE",
65 "caffe.ResizeParameter.FIT_LARGE_SIZE_AND_PAD"
# `resize_mode` is accepted only in the range 1..3; anything else is logged as
# an error.
68 if param.save_output_param.resize_param.resize_mode < 1 or param.save_output_param.resize_param.resize_mode > 3:
69 log.error("Incorrect value of resize_mode parameter")
71 resize_mode = resize_mode_values[param.save_output_param.resize_param.resize_mode]
# Names used to translate the numeric `pad_mode` value into a string.
75 "caffe.ResizeParameter.CONSTANT",
76 "caffe.ResizeParameter.MIRRORED",
77 "caffe.ResizeParameter.REPEAT_NEAREST"
# `pad_mode` is accepted only in the range 1..3; anything else is logged as an
# error.
80 if param.save_output_param.resize_param.pad_mode < 1 or param.save_output_param.resize_param.pad_mode > 3:
81 log.error("Incorrect value of pad_mode parameter")
83 pad_mode = pad_mode_values[param.save_output_param.resize_param.pad_mode]
# Names used to translate each numeric `interp_mode` value into a string.
85 interp_mode_values = [
87 "caffe.ResizeParameter.LINEAR",
88 "caffe.ResizeParameter.AREA",
89 "caffe.ResizeParameter.NEAREST",
90 "caffe.ResizeParameter.CUBIC",
91 "caffe.ResizeParameter.LANCZOS4"
# `interp_mode` is iterated as a sequence; the name of every entry is appended
# to the accumulated `interp_mode` value with `+=`.
# NOTE(review): no separator is added between appended names in the visible
# code; the range check that triggers the error below is not visible either.
94 for x in param.save_output_param.resize_param.interp_mode:
96 log.error("Incorrect value of interp_mode parameter")
98 interp_mode += interp_mode_values[x]
# Entries of the attribute mapping. Boolean proto fields (`share_location`,
# `variance_encoded_in_target`, `normalized`) are converted to int.
101 'num_classes': param.num_classes,
102 'share_location': int(param.share_location),
103 'background_label_id': param.background_label_id,
104 'code_type': code_type,
105 'variance_encoded_in_target': int(param.variance_encoded_in_target),
106 'keep_top_k': param.keep_top_k,
107 'confidence_threshold': param.confidence_threshold,
108 'visualize': param.visualize,
109 'visualize_threshold': visualize_threshold,
110 'save_file': param.save_file,
112 'nms_threshold': nms_threshold,
# Values copied from the nested `save_output_param` message.
116 'output_directory': param.save_output_param.output_directory,
117 'output_name_prefix': param.save_output_param.output_name_prefix,
118 'output_format': param.save_output_param.output_format,
119 'label_map_file': param.save_output_param.label_map_file,
120 'name_size_file': param.save_output_param.name_size_file,
121 'num_test_image': param.save_output_param.num_test_image,
122 # save_output_param.resize_param
123 'prob': param.save_output_param.resize_param.prob,
124 'resize_mode': resize_mode,
125 'height': param.save_output_param.resize_param.height,
126 'width': param.save_output_param.resize_param.width,
127 'height_scale': param.save_output_param.resize_param.height_scale,
128 'width_scale': param.save_output_param.resize_param.width_scale,
129 'pad_mode': pad_mode,
# `pad_value` is a repeated field; it is serialized as a comma-separated string.
130 'pad_value': ','.join(str(x) for x in param.save_output_param.resize_param.pad_value),
131 'interp_mode': interp_mode,
132 'input_width': param.input_width,
133 'input_height': param.input_height,
134 'normalized': int(param.normalized)
# The result of `merge_attrs(param, attrs)` is the mapping that ends up on the
# node. NOTE(review): precedence between proto fields and `attrs` is decided
# inside `merge_attrs`, which is not visible here.
137 mapping_rule = merge_attrs(param, attrs)
139 # force setting infer function because it doesn't exist in proto so merge_attrs will not set it
140 mapping_rule.update({'infer': multi_box_detection_infer})
142 # update the attributes of the node
# The Op class is looked up by this extractor's `op` name and asked to update
# the node; the extractor then returns its class-level `enabled` flag (defined
# outside the visible lines).
143 Op.get_op_class_by_name(__class__.op).update_node_stat(node, mapping_rule)
144 return __class__.enabled