Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / caffe / detection_output.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 from mo.front.caffe.collect_attributes import merge_attrs
20 from mo.front.common.partial_infer.multi_box_detection import multi_box_detection_infer
21 from mo.front.extractor import FrontExtractorOp
22 from mo.ops.op import Op
23
24
25 class DetectionOutputFrontExtractor(FrontExtractorOp):
26     op = 'DetectionOutput'
27     enabled = True
28
29     @staticmethod
30     def extract(node):
31         pl = node.pb
32         assert pl, 'Protobuf layer can not be empty'
33
34         param = pl.detection_output_param
35
36         # TODO rewrite params as complex structures
37         if hasattr(param, 'nms_param'):
38             nms_threshold = param.nms_param.nms_threshold
39             eta = param.nms_param.eta
40             if param.nms_param.top_k == 0:
41                 top_k = -1
42             else:
43                 top_k = param.nms_param.top_k
44
45         code_type_values = [
46             "",
47             "caffe.PriorBoxParameter.CORNER",
48             "caffe.PriorBoxParameter.CENTER_SIZE",
49             "caffe.PriorBoxParameter.CORNER_SIZE"
50         ]
51
52         code_type = code_type_values[1]
53         if hasattr(param, 'code_type'):
54             if param.code_type < 1 or param.code_type > 3:
55                 log.error("Incorrect value of code_type parameter")
56                 return
57             code_type = code_type_values[param.code_type]
58
59         visualize_threshold = param.visualize_threshold if param.visualize_threshold else 0.6
60
61         resize_mode_values = [
62             "",
63             "caffe.ResizeParameter.WARP",
64             "caffe.ResizeParameter.FIT_SMALL_SIZE",
65             "caffe.ResizeParameter.FIT_LARGE_SIZE_AND_PAD"
66         ]
67
68         if param.save_output_param.resize_param.resize_mode < 1 or param.save_output_param.resize_param.resize_mode > 3:
69             log.error("Incorrect value of resize_mode parameter")
70             return
71         resize_mode = resize_mode_values[param.save_output_param.resize_param.resize_mode]
72
73         pad_mode_values = [
74             "",
75             "caffe.ResizeParameter.CONSTANT",
76             "caffe.ResizeParameter.MIRRORED",
77             "caffe.ResizeParameter.REPEAT_NEAREST"
78         ]
79
80         if param.save_output_param.resize_param.pad_mode < 1 or param.save_output_param.resize_param.pad_mode > 3:
81             log.error("Incorrect value of pad_mode parameter")
82         else:
83             pad_mode = pad_mode_values[param.save_output_param.resize_param.pad_mode]
84
85         interp_mode_values = [
86             "",
87             "caffe.ResizeParameter.LINEAR",
88             "caffe.ResizeParameter.AREA",
89             "caffe.ResizeParameter.NEAREST",
90             "caffe.ResizeParameter.CUBIC",
91             "caffe.ResizeParameter.LANCZOS4"
92         ]
93         interp_mode = ""
94         for x in param.save_output_param.resize_param.interp_mode:
95             if x < 1 or x > 5:
96                 log.error("Incorrect value of interp_mode parameter")
97                 return
98             interp_mode += interp_mode_values[x]
99
100         attrs = {
101             'num_classes': param.num_classes,
102             'share_location': int(param.share_location),
103             'background_label_id': param.background_label_id,
104             'code_type': code_type,
105             'variance_encoded_in_target': int(param.variance_encoded_in_target),
106             'keep_top_k': param.keep_top_k,
107             'confidence_threshold': param.confidence_threshold,
108             'visualize': param.visualize,
109             'visualize_threshold': visualize_threshold,
110             'save_file': param.save_file,
111             # nms_param
112             'nms_threshold': nms_threshold,
113             'top_k': top_k,
114             'eta': eta,
115             # save_output_param
116             'output_directory': param.save_output_param.output_directory,
117             'output_name_prefix': param.save_output_param.output_name_prefix,
118             'output_format': param.save_output_param.output_format,
119             'label_map_file': param.save_output_param.label_map_file,
120             'name_size_file': param.save_output_param.name_size_file,
121             'num_test_image': param.save_output_param.num_test_image,
122             # save_output_param.resize_param
123             'prob': param.save_output_param.resize_param.prob,
124             'resize_mode': resize_mode,
125             'height': param.save_output_param.resize_param.height,
126             'width': param.save_output_param.resize_param.width,
127             'height_scale': param.save_output_param.resize_param.height_scale,
128             'width_scale': param.save_output_param.resize_param.width_scale,
129             'pad_mode': pad_mode,
130             'pad_value': ','.join(str(x) for x in param.save_output_param.resize_param.pad_value),
131             'interp_mode': interp_mode,
132         }
133
134         # these params can be omitted in caffe.proto and in param as consequence,
135         # so check if it is set or set to default
136         fields = [field[0].name for field in param.ListFields()]
137         if 'input_width' in fields:
138             attrs['input_width'] = param.input_width
139         if 'input_height' in fields:
140             attrs['input_height'] = param.input_height
141         if 'normalized' in fields:
142             attrs['normalized'] = int(param.normalized)
143         if 'objectness_score' in fields:
144             attrs['objectness_score'] = param.objectness_score
145
146         mapping_rule = merge_attrs(param, attrs)
147
148         # force setting infer function because it doesn't exist in proto so merge_attrs will not set it
149         mapping_rule.update({'infer': multi_box_detection_infer})
150
151         # update the attributes of the node
152         Op.get_op_class_by_name(__class__.op).update_node_stat(node, mapping_rule)
153         return __class__.enabled