2 // Copyright (c) 2017-2019 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
22 #include "../C/proposal.h"
23 #include "primitive.hpp"
27 /// @addtogroup cpp_api C++ API
29 /// @addtogroup cpp_topology Network Topology
31 /// @addtogroup cpp_primitives Primitives
34 struct proposal : public primitive_base<proposal, CLDNN_PRIMITIVE_DESC(proposal)>
36 CLDNN_DECLARE_PRIMITIVE(proposal)
39 const primitive_id& id,
40 const primitive_id& cls_scores,
41 const primitive_id& bbox_pred,
42 const primitive_id& image_info,
49 const std::vector<float>& ratios_param,
50 const std::vector<float>& scales_param,
51 const padding& output_padding = padding()
53 : primitive_base(id, {cls_scores, bbox_pred, image_info}, output_padding),
54 max_proposals(max_proposals),
55 iou_threshold(iou_threshold),
57 min_bbox_size(min_bbox_size),
58 feature_stride(feature_stride),
59 pre_nms_topn(pre_nms_topn),
60 post_nms_topn(post_nms_topn),
63 coordinates_offset(1.0f),
64 box_coordinate_scale(1.0f),
68 clip_before_nms(true),
69 clip_after_nms(false),
77 const primitive_id& id,
78 const primitive_id& cls_scores,
79 const primitive_id& bbox_pred,
80 const primitive_id& image_info,
88 const std::vector<float>& ratios_param,
89 const std::vector<float>& scales_param,
90 float coordinates_offset,
91 float box_coordinate_scale,
100 const padding& output_padding = padding()
102 : primitive_base(id, {cls_scores, bbox_pred, image_info}, output_padding),
103 max_proposals(max_proposals),
104 iou_threshold(iou_threshold),
105 base_bbox_size(base_bbox_size),
106 min_bbox_size(min_bbox_size),
107 feature_stride(feature_stride),
108 pre_nms_topn(pre_nms_topn),
109 post_nms_topn(post_nms_topn),
110 ratios(ratios_param),
111 scales(scales_param),
112 coordinates_offset(coordinates_offset),
113 box_coordinate_scale(box_coordinate_scale),
114 box_size_scale(box_size_scale),
116 initial_clip(initial_clip),
117 clip_before_nms(clip_before_nms),
118 clip_after_nms(clip_after_nms),
119 round_ratios(round_ratios),
120 shift_anchors(shift_anchors),
125 proposal(const dto* dto) :
127 max_proposals(dto->max_proposals),
128 iou_threshold(dto->iou_threshold),
129 base_bbox_size(dto->base_bbox_size),
130 min_bbox_size(dto->min_bbox_size),
131 feature_stride(dto->feature_stride),
132 pre_nms_topn(dto->pre_nms_topn),
133 post_nms_topn(dto->post_nms_topn),
134 ratios(float_arr_to_vector(dto->ratios)),
135 scales(float_arr_to_vector(dto->scales)),
136 coordinates_offset(dto->coordinates_offset),
137 box_coordinate_scale(dto->box_coordinate_scale),
138 box_size_scale(dto->box_size_scale),
139 swap_xy(dto->swap_xy != 0),
140 initial_clip(dto->initial_clip != 0),
141 clip_before_nms(dto->clip_before_nms != 0),
142 clip_after_nms(dto->clip_after_nms != 0),
143 round_ratios(dto->round_ratios != 0),
144 shift_anchors(dto->shift_anchors != 0),
145 normalize(dto->normalize != 0)
156 std::vector<float> ratios;
157 std::vector<float> scales;
158 float coordinates_offset;
159 float box_coordinate_scale;
160 float box_size_scale;
163 bool clip_before_nms;
170 void update_dto(dto& dto) const override
172 dto.max_proposals = max_proposals;
173 dto.iou_threshold = iou_threshold;
174 dto.base_bbox_size = base_bbox_size;
175 dto.min_bbox_size = min_bbox_size;
176 dto.feature_stride = feature_stride;
177 dto.pre_nms_topn = pre_nms_topn;
178 dto.post_nms_topn = post_nms_topn;
179 dto.ratios = float_vector_to_arr(ratios);
180 dto.scales = float_vector_to_arr(scales);
181 dto.coordinates_offset = coordinates_offset;
182 dto.box_coordinate_scale = box_coordinate_scale;
183 dto.box_size_scale = box_size_scale;
184 dto.swap_xy = swap_xy;
185 dto.initial_clip = initial_clip;
186 dto.clip_before_nms = clip_before_nms;
187 dto.clip_after_nms = clip_after_nms;
188 dto.round_ratios = round_ratios;
189 dto.shift_anchors = shift_anchors;
190 dto.normalize = normalize;