dd8fb96bc28acafcbca4eb97a89eb58f2e859d3e
[platform/upstream/dldt.git] / model-optimizer / extensions / front / caffe / detection_output.py
1 """
2  Copyright (c) 2018 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 from mo.front.caffe.collect_attributes import merge_attrs
20 from mo.front.common.partial_infer.multi_box_detection import multi_box_detection_infer
21 from mo.front.extractor import FrontExtractorOp
22 from mo.ops.op import Op
23
24
25 class DetectionOutputFrontExtractor(FrontExtractorOp):
26     op = 'DetectionOutput'
27     enabled = True
28
29     @staticmethod
30     def extract(node):
31         pl = node.pb
32         assert pl, 'Protobuf layer can not be empty'
33
34         param = pl.detection_output_param
35
36         # TODO rewrite params as complex structures
37         if hasattr(param, 'nms_param'):
38             nms_threshold = param.nms_param.nms_threshold
39             eta = param.nms_param.eta
40             if param.nms_param.top_k == 0:
41                 top_k = -1
42             else:
43                 top_k = param.nms_param.top_k
44
45         code_type_values = [
46             "",
47             "caffe.PriorBoxParameter.CORNER",
48             "caffe.PriorBoxParameter.CENTER_SIZE",
49             "caffe.PriorBoxParameter.CORNER_SIZE"
50         ]
51
52         code_type = code_type_values[1]
53         if hasattr(param, 'code_type'):
54             if param.code_type < 1 or param.code_type > 3:
55                 log.error("Incorrect value of code_type parameter")
56                 return
57             code_type = code_type_values[param.code_type]
58
59         visualize_threshold = param.visualize_threshold if param.visualize_threshold else 0.6
60
61         resize_mode_values = [
62             "",
63             "caffe.ResizeParameter.WARP",
64             "caffe.ResizeParameter.FIT_SMALL_SIZE",
65             "caffe.ResizeParameter.FIT_LARGE_SIZE_AND_PAD"
66         ]
67
68         if param.save_output_param.resize_param.resize_mode < 1 or param.save_output_param.resize_param.resize_mode > 3:
69             log.error("Incorrect value of resize_mode parameter")
70             return
71         resize_mode = resize_mode_values[param.save_output_param.resize_param.resize_mode]
72
73         pad_mode_values = [
74             "",
75             "caffe.ResizeParameter.CONSTANT",
76             "caffe.ResizeParameter.MIRRORED",
77             "caffe.ResizeParameter.REPEAT_NEAREST"
78         ]
79
80         if param.save_output_param.resize_param.pad_mode < 1 or param.save_output_param.resize_param.pad_mode > 3:
81             log.error("Incorrect value of pad_mode parameter")
82         else:
83             pad_mode = pad_mode_values[param.save_output_param.resize_param.pad_mode]
84
85         interp_mode_values = [
86             "",
87             "caffe.ResizeParameter.LINEAR",
88             "caffe.ResizeParameter.AREA",
89             "caffe.ResizeParameter.NEAREST",
90             "caffe.ResizeParameter.CUBIC",
91             "caffe.ResizeParameter.LANCZOS4"
92         ]
93         interp_mode = ""
94         for x in param.save_output_param.resize_param.interp_mode:
95             if x < 1 or x > 5:
96                 log.error("Incorrect value of interp_mode parameter")
97                 return
98             interp_mode += interp_mode_values[x]
99
100         attrs = {
101             'num_classes': param.num_classes,
102             'share_location': int(param.share_location),
103             'background_label_id': param.background_label_id,
104             'code_type': code_type,
105             'variance_encoded_in_target': int(param.variance_encoded_in_target),
106             'keep_top_k': param.keep_top_k,
107             'confidence_threshold': param.confidence_threshold,
108             'visualize': param.visualize,
109             'visualize_threshold': visualize_threshold,
110             'save_file': param.save_file,
111             # nms_param
112             'nms_threshold': nms_threshold,
113             'top_k': top_k,
114             'eta': eta,
115             # save_output_param
116             'output_directory': param.save_output_param.output_directory,
117             'output_name_prefix': param.save_output_param.output_name_prefix,
118             'output_format': param.save_output_param.output_format,
119             'label_map_file': param.save_output_param.label_map_file,
120             'name_size_file': param.save_output_param.name_size_file,
121             'num_test_image': param.save_output_param.num_test_image,
122             # save_output_param.resize_param
123             'prob': param.save_output_param.resize_param.prob,
124             'resize_mode': resize_mode,
125             'height': param.save_output_param.resize_param.height,
126             'width': param.save_output_param.resize_param.width,
127             'height_scale': param.save_output_param.resize_param.height_scale,
128             'width_scale': param.save_output_param.resize_param.width_scale,
129             'pad_mode': pad_mode,
130             'pad_value': ','.join(str(x) for x in param.save_output_param.resize_param.pad_value),
131             'interp_mode': interp_mode,
132             'input_width': param.input_width,
133             'input_height': param.input_height,
134             'normalized': int(param.normalized)
135         }
136
137         mapping_rule = merge_attrs(param, attrs)
138
139         # force setting infer function because it doesn't exist in proto so merge_attrs will not set it
140         mapping_rule.update({'infer': multi_box_detection_infer})
141
142         # update the attributes of the node
143         Op.get_op_class_by_name(__class__.op).update_node_stat(node, mapping_rule)
144         return __class__.enabled