Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / roi_pooling.hpp
1 /*
2 // Copyright (c) 2017-2019 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19 #include "pooling.hpp"
20 #include "../C/roi_pooling.h"
21 #include "primitive.hpp"
22
23
24 namespace cldnn
25 {
26 /// @addtogroup cpp_api C++ API
27 /// @{
28 /// @addtogroup cpp_topology Network Topology
29 /// @{
30 /// @addtogroup cpp_primitives Primitives
31 /// @{
32
33 struct roi_pooling : public primitive_base<roi_pooling, CLDNN_PRIMITIVE_DESC(roi_pooling)>
34 {
35     CLDNN_DECLARE_PRIMITIVE(roi_pooling)
36
37     roi_pooling(
38         const primitive_id& id,
39         const primitive_id& input_data,
40         const primitive_id& input_rois,
41         pooling_mode mode,
42         bool position_sensitive,
43         int pooled_width,
44         int pooled_height,
45         float spatial_scale,
46         int output_dim = 0,
47         int spatial_bins_x = 1,
48         int spatial_bins_y = 1,
49         const padding& output_padding = padding()
50         )
51         : primitive_base(id, {input_data, input_rois}, output_padding)
52         , mode(mode)
53         , position_sensitive(position_sensitive)
54         , pooled_width(pooled_width)
55         , pooled_height(pooled_height)
56         , spatial_scale(spatial_scale)
57         , output_dim(output_dim)
58         , spatial_bins_x(spatial_bins_x)
59         , spatial_bins_y(spatial_bins_y)
60     {}
61
62     roi_pooling(const dto* dto)
63         : primitive_base(dto)
64         , mode(static_cast<pooling_mode>(dto->mode))
65         , position_sensitive(dto->position_sensitive)
66         , pooled_width(dto->pooled_width)
67         , pooled_height(dto->pooled_height)
68         , spatial_scale(dto->spatial_scale)
69         , output_dim(dto->output_dim)
70         , spatial_bins_x(dto->spatial_bins_x)
71         , spatial_bins_y(dto->spatial_bins_y)
72     {}
73
74     pooling_mode mode;
75     bool position_sensitive;
76     int pooled_width;
77     int pooled_height;
78     float spatial_scale;
79     int output_dim;
80     int spatial_bins_x;
81     int spatial_bins_y;
82
83 protected:
84     void update_dto(dto& dto) const override
85     {
86         dto.mode = static_cast<int32_t>(mode);
87         dto.position_sensitive = position_sensitive;
88         dto.pooled_width = pooled_width;
89         dto.pooled_height = pooled_height;
90         dto.spatial_scale = spatial_scale;
91         dto.output_dim = output_dim;
92         dto.spatial_bins_x = spatial_bins_x;
93         dto.spatial_bins_y = spatial_bins_y;
94     }
95 };
96
97 /// @}
98 /// @}
99 /// @}
100 }