1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
21 #include <mkldnn_types.h>
24 #include "c_types_map.hpp"
25 #include "type_helpers.hpp"
27 #include "ref_roi_pooling.hpp"
// Reference (scalar) forward ROI pooling over one feature-map ("data") input
// and one ROI input. Dispatches on pd()->desc()->alg_kind:
//   - mkldnn_roi_pooling_max: max over integer-rounded bins of each ROI;
//   - mkldnn_roi_pooling_bilinear: one bilinearly interpolated sample per
//     output cell.
// Output element dst_d.off(n, c, ph, pw) holds ROI n, channel c, pooled row ph,
// pooled column pw.
// NOTE(review): this chunk is incomplete. data_idx, roi_idx, real_rois,
// roi_off and n are used without visible declarations, several statement
// bodies and all closing braces are absent. Comments below hedge where the
// missing statements matter.
33 template <impl::data_type_t data_type>
34 void ref_roi_pooling_fwd_t<data_type>::execute_forward_generic() const {
// Output descriptor plus descriptors of the data and ROI inputs.
38 const memory_desc_wrapper dst_d(pd()->dst_pd());
39 memory_desc_wrapper src_data_d = pd()->src_pd(data_idx);
40 memory_desc_wrapper src_roi_d = pd()->src_pd(roi_idx);
// When the ROI tensor's dim 0 is smaller than the data tensor's dim 0 the two
// descriptors are re-read. Presumably data_idx/roi_idx are swapped first so the
// inputs may arrive in either order (the swap is not visible here; confirm).
42 if (src_roi_d.dims()[0] < src_data_d.dims()[0]) {
46 src_data_d = pd()->src_pd(data_idx);
47 src_roi_d = pd()->src_pd(roi_idx);
// Raw buffers: dst is output 0; the inputs are selected by data_idx / roi_idx.
50 auto dst = reinterpret_cast<data_t*>(this->memory(0));
51 const data_t* src_data = reinterpret_cast<const data_t*>(this->input_memory(data_idx));
52 const data_t* src_roi = reinterpret_cast<const data_t*>(this->input_memory(roi_idx));
// Dims 1..3 of the data tensor are used as C, H, W.
54 int C = src_data_d.dims()[1];
55 int H = src_data_d.dims()[2];
56 int W = src_data_d.dims()[3];
// Number of ROI rows (dim 0 of the ROI tensor).
58 int ROIS = src_roi_d.dims()[0];
// Pooling parameters from the primitive descriptor.
60 double spatial_scale = pd()->spatialScale();
61 int pooled_h = pd()->pooledH();
62 int pooled_w = pd()->pooledW();
// Visits every dst element (size() is divided by sizeof(data_t), so it is
// presumably a byte count). The body is not visible; presumably it is the
// initial fill that the max branch's "batch_data > dst[pool_index]" relies on.
// TODO confirm.
64 for (size_t i = 0; i < dst_d.size() / sizeof(data_t); i++) {
// Scan ROI rows. Element 0 of each row is read as a batch index, and -1 is
// handled specially (body not visible; presumably it stops the scan so that
// real_rois counts the leading valid ROIs; confirm).
69 for (; real_rois < ROIS; real_rois++) {
// A 4-D ROI tensor uses a 4-index offset; the other case uses (row, 0).
71 if(src_roi_d.ndims() == 4) {
72 roi_off = src_roi_d.off(real_rois, 0, 0, 0);
75 roi_off = src_roi_d.off(real_rois, 0);
78 const data_t* src_roi_ptr = &src_roi[roi_off];
79 int roi_batch_ind = src_roi_ptr[0];
80 if (roi_batch_ind == -1) {
// Pool every ROI with index below real_rois.
85 for (; n < real_rois; ++n) {
87 if(src_roi_d.ndims() == 4) {
88 roi_off = src_roi_d.off(n, 0, 0, 0);
91 roi_off = src_roi_d.off(n, 0);
94 const data_t* src_roi_ptr = &src_roi[roi_off];
// NOTE(review): roi_batch_ind comes straight from the ROI data and is used to
// index the data tensor without a range check against its dim 0.
95 int roi_batch_ind = src_roi_ptr[0];
// Max pooling: ROI elements 1..4 are read as (start_w, start_h, end_w, end_h),
// scaled by spatial_scale and rounded to integer feature-map coordinates. The
// ROI extent is inclusive (+1) and clamped to at least 1.
97 if (pd()->desc()->alg_kind == mkldnn_roi_pooling_max) {
98 int roi_start_w = round(src_roi_ptr[1] * spatial_scale);
99 int roi_start_h = round(src_roi_ptr[2] * spatial_scale);
100 int roi_end_w = round(src_roi_ptr[3] * spatial_scale);
101 int roi_end_h = round(src_roi_ptr[4] * spatial_scale);
103 int roi_height = std::max(roi_end_h - roi_start_h + 1, 1);
104 int roi_width = std::max(roi_end_w - roi_start_w + 1, 1);
107 for (int c = 0; c < C; ++c) {
108 for (int ph = 0; ph < pooled_h; ++ph) {
109 for (int pw = 0; pw < pooled_w; ++pw) {
// Bin [hstart, hend) x [wstart, wend), relative to the ROI origin. The
// conditional adjustments (bodies not visible) presumably turn the start into a
// floor and the end into a ceil of the exact fractional bin boundaries.
110 int hstart = (ph * roi_height) / pooled_h;
111 if ((hstart * pooled_h) > (ph * roi_height)) {
115 int wstart = (pw * roi_width) / pooled_w;
116 if ((wstart * pooled_w) > (pw * roi_width)) {
120 int hend = ((ph + 1) * roi_height) / pooled_h;
121 if ((hend * pooled_h) < ((ph + 1) * roi_height)) {
125 int wend = ((pw + 1) * roi_width) / pooled_w;
126 if ((wend * pooled_w) < ((pw + 1) * roi_width)) {
// Shift by the ROI origin and clamp into [0, H] / [0, W].
130 hstart = std::min(std::max(hstart + roi_start_h, 0), H);
131 hend = std::min(std::max(hend + roi_start_h, 0), H);
132 wstart = std::min(std::max(wstart + roi_start_w, 0), W);
133 wend = std::min(std::max(wend + roi_start_w, 0), W);
// True when clamping left the bin with no pixels. How this flag is used is not
// visible here (presumably the output cell is set to 0; confirm).
135 bool is_empty = (hend <= hstart) || (wend <= wstart);
// NOTE(review): the off() result is stored in an int; confirm it cannot
// overflow for large output tensors.
137 const int pool_index = dst_d.off(n, c, ph, pw);
// Running max over the bin, compared against the value already in dst.
143 for (int h = hstart; h < hend; ++h) {
144 for (int w = wstart; w < wend; ++w) {
145 data_t batch_data = src_data[src_data_d.off(roi_batch_ind, c, h, w)];
147 if (batch_data > dst[pool_index]) {
148 dst[pool_index] = batch_data;
// Bilinear: the ROI coordinates are not multiplied by spatial_scale. They are
// multiplied by (H - 1) / (W - 1), so they appear to be normalized to [0, 1].
// TODO confirm.
155 } else if (pd()->desc()->alg_kind == mkldnn_roi_pooling_bilinear) {
156 float roi_start_w_ = src_roi_ptr[1];
157 float roi_start_h_ = src_roi_ptr[2];
158 float roi_end_w_ = src_roi_ptr[3];
159 float roi_end_h_ = src_roi_ptr[4];
// Distance in feature-map pixels between adjacent output samples.
// NOTE(review): when pooled_h or pooled_w is 1 this divides by zero. Under
// IEEE arithmetic that gives inf or NaN, and the resulting NaN sample
// coordinate fails every range test below. It then reaches
// static_cast<int>(floorf(...)), which is undefined behaviour. Consider
// guarding that case, e.g. by sampling the ROI centre.
161 float height_scale = (roi_end_h_ - roi_start_h_) * (H - 1) / (pooled_h - 1);
162 float width_scale = (roi_end_w_ - roi_start_w_) * (W - 1) / (pooled_w - 1);
164 for (int c = 0; c < C; ++c) {
165 for (int ph = 0; ph < pooled_h; ++ph) {
166 for (int pw = 0; pw < pooled_w; ++pw) {
// Sample position in feature-map pixel coordinates.
167 float in_y = (ph * height_scale + roi_start_h_ * (H - 1));
168 float in_x = (pw * width_scale + roi_start_w_ * (W - 1));
// Samples outside [0, H-1] x [0, W-1] produce 0.
170 if (in_y < 0 || in_y > H - 1 || in_x < 0 || in_x > W - 1) {
171 dst[dst_d.off(n, c, ph, pw)] = 0;
// Neighbouring integer pixels. The ceil indices are clamped to the last
// row/column.
173 int top_y_index = static_cast<int>(floorf(in_y));
174 int bottom_y_index = static_cast<int>(ceilf(in_y));
175 int left_x_index = static_cast<int>(floorf(in_x));
176 int right_x_index = static_cast<int>(ceilf(in_x));
178 if (right_x_index > W - 1)
179 right_x_index = W - 1;
181 if (bottom_y_index > H - 1)
182 bottom_y_index = H - 1;
// Interpolate along x on the top and bottom rows, then along y.
184 const float top_left = src_data[src_data_d.off(roi_batch_ind, c, top_y_index, left_x_index)];
185 const float top_right = src_data[src_data_d.off(roi_batch_ind, c, top_y_index, right_x_index)];
186 const float bottom_left = src_data[src_data_d.off(roi_batch_ind, c, bottom_y_index, left_x_index)];
187 const float bottom_right = src_data[src_data_d.off(roi_batch_ind, c, bottom_y_index, right_x_index)];
189 const float top = top_left + (top_right - top_left) * (in_x - left_x_index);
190 const float bottom = bottom_left + (bottom_right - bottom_left) * (in_x - left_x_index);
192 dst[dst_d.off(n, c, ph, pw)] = top + (bottom - top) * (in_y - top_y_index);
// Output slots for the remaining ROIs (from the current n up to ROIS) are
// zero-filled.
200 for (; n < ROIS; ++n) {
201 for (int c = 0; c < C; ++c) {
202 for (int ph = 0; ph < pooled_h; ++ph) {
203 for (int pw = 0; pw < pooled_w; ++pw) {
204 dst[dst_d.off(n, c, ph, pw)] = 0;
// Explicit instantiation of the reference ROI pooling primitive for f32 data.
211 template struct ref_roi_pooling_fwd_t<data_type::f32>;
217 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s