Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / proposal.hpp
1 /*
2 // Copyright (c) 2017-2019 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19
20 #include <vector>
21
22 #include "../C/proposal.h"
23 #include "primitive.hpp"
24
25 namespace cldnn
26 {
27 /// @addtogroup cpp_api C++ API
28 /// @{
29 /// @addtogroup cpp_topology Network Topology
30 /// @{
31 /// @addtogroup cpp_primitives Primitives
32 /// @{
33
34 struct proposal : public primitive_base<proposal, CLDNN_PRIMITIVE_DESC(proposal)>
35 {
36     CLDNN_DECLARE_PRIMITIVE(proposal)
37
38     proposal(
39         const primitive_id& id,
40         const primitive_id& cls_scores,
41         const primitive_id& bbox_pred,
42         const primitive_id& image_info,
43         int max_proposals,
44         float iou_threshold,
45         int min_bbox_size,
46         int feature_stride,
47         int pre_nms_topn,
48         int post_nms_topn,
49         const std::vector<float>& ratios_param,
50         const std::vector<float>& scales_param,
51         const padding& output_padding = padding()
52         )
53         : primitive_base(id, {cls_scores, bbox_pred, image_info}, output_padding),
54                  max_proposals(max_proposals),
55                  iou_threshold(iou_threshold),
56                  base_bbox_size(16),
57                  min_bbox_size(min_bbox_size),
58                  feature_stride(feature_stride),
59                  pre_nms_topn(pre_nms_topn),
60                  post_nms_topn(post_nms_topn),
61                  ratios(ratios_param),
62                  scales(scales_param),
63                  coordinates_offset(1.0f),
64                  box_coordinate_scale(1.0f),
65                  box_size_scale(1.0f),
66                  swap_xy(false),
67                  initial_clip(false),
68                  clip_before_nms(true),
69                  clip_after_nms(false),
70                  round_ratios(true),
71                  shift_anchors(false),
72                  normalize(false)
73     {
74     }
75
76     proposal(
77         const primitive_id& id,
78         const primitive_id& cls_scores,
79         const primitive_id& bbox_pred,
80         const primitive_id& image_info,
81         int max_proposals,
82         float iou_threshold,
83         int base_bbox_size,
84         int min_bbox_size,
85         int feature_stride,
86         int pre_nms_topn,
87         int post_nms_topn,
88         const std::vector<float>& ratios_param,
89         const std::vector<float>& scales_param,
90         float coordinates_offset,
91         float box_coordinate_scale,
92         float box_size_scale,
93         bool swap_xy,
94         bool initial_clip,
95         bool clip_before_nms,
96         bool clip_after_nms,
97         bool round_ratios,
98         bool shift_anchors,
99         bool normalize,
100         const padding& output_padding = padding()
101         )
102         : primitive_base(id, {cls_scores, bbox_pred, image_info}, output_padding),
103                  max_proposals(max_proposals),
104                  iou_threshold(iou_threshold),
105                  base_bbox_size(base_bbox_size),
106                  min_bbox_size(min_bbox_size),
107                  feature_stride(feature_stride),
108                  pre_nms_topn(pre_nms_topn),
109                  post_nms_topn(post_nms_topn),
110                  ratios(ratios_param),
111                  scales(scales_param),
112                  coordinates_offset(coordinates_offset),
113                  box_coordinate_scale(box_coordinate_scale),
114                  box_size_scale(box_size_scale),
115                  swap_xy(swap_xy),
116                  initial_clip(initial_clip),
117                  clip_before_nms(clip_before_nms),
118                  clip_after_nms(clip_after_nms),
119                  round_ratios(round_ratios),
120                  shift_anchors(shift_anchors),
121                  normalize(normalize)
122     {
123     }
124
125     proposal(const dto* dto) :
126         primitive_base(dto),
127         max_proposals(dto->max_proposals),
128         iou_threshold(dto->iou_threshold),
129         base_bbox_size(dto->base_bbox_size),
130         min_bbox_size(dto->min_bbox_size),
131         feature_stride(dto->feature_stride),
132         pre_nms_topn(dto->pre_nms_topn),
133         post_nms_topn(dto->post_nms_topn),
134         ratios(float_arr_to_vector(dto->ratios)),
135         scales(float_arr_to_vector(dto->scales)),
136         coordinates_offset(dto->coordinates_offset),
137         box_coordinate_scale(dto->box_coordinate_scale),
138         box_size_scale(dto->box_size_scale),
139         swap_xy(dto->swap_xy != 0),
140         initial_clip(dto->initial_clip != 0),
141         clip_before_nms(dto->clip_before_nms != 0),
142         clip_after_nms(dto->clip_after_nms != 0),
143         round_ratios(dto->round_ratios != 0),
144         shift_anchors(dto->shift_anchors != 0),
145         normalize(dto->normalize != 0)
146     {
147     }
148
149     int max_proposals;
150     float iou_threshold;
151     int base_bbox_size;
152     int min_bbox_size;
153     int feature_stride;
154     int pre_nms_topn;
155     int post_nms_topn;
156     std::vector<float> ratios;
157     std::vector<float> scales;
158     float coordinates_offset;
159     float box_coordinate_scale;
160     float box_size_scale;
161     bool swap_xy;
162     bool initial_clip;
163     bool clip_before_nms;
164     bool clip_after_nms;
165     bool round_ratios;
166     bool shift_anchors;
167     bool normalize;
168
169 protected:
170     void update_dto(dto& dto) const override
171     {
172         dto.max_proposals = max_proposals;
173         dto.iou_threshold = iou_threshold;
174         dto.base_bbox_size = base_bbox_size;
175         dto.min_bbox_size = min_bbox_size;
176         dto.feature_stride = feature_stride;
177         dto.pre_nms_topn = pre_nms_topn;
178         dto.post_nms_topn = post_nms_topn;
179         dto.ratios = float_vector_to_arr(ratios);
180         dto.scales = float_vector_to_arr(scales);
181         dto.coordinates_offset = coordinates_offset;
182         dto.box_coordinate_scale = box_coordinate_scale;
183         dto.box_size_scale = box_size_scale;
184         dto.swap_xy = swap_xy;
185         dto.initial_clip = initial_clip;
186         dto.clip_before_nms = clip_before_nms;
187         dto.clip_after_nms = clip_after_nms;
188         dto.round_ratios = round_ratios;
189         dto.shift_anchors = shift_anchors;
190         dto.normalize = normalize;
191     }
192 };
193
194 /// @}
195 /// @}
196 /// @}
197 }