1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
5 #include <cuda_runtime.h>
12 #include "grid_stride_range.hpp"
13 #include "execution.hpp"
15 #include "../cuda4dnn/csl/stream.hpp"
16 #include "../cuda4dnn/csl/tensor.hpp"
17 #include "../cuda4dnn/csl/span.hpp"
19 #include "../cuda4dnn/kernels/fill.hpp"
21 #include <opencv2/core.hpp>
25 #include <type_traits>
27 using namespace cv::dnn::cuda4dnn::csl;
28 using namespace cv::dnn::cuda4dnn::csl::device;
30 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
/* Device kernels live in `raw` so that they can share a name with the host-side wrappers. */
namespace raw {
    /* Max pooling over 2 or 3 spatial axes that also reports where each maximum was found.
     *
     * Every output element corresponds to one pooling window of `input`. The flat output index
     * is decomposed as (n, c, spatial...) with the last spatial axis varying fastest, and `input`
     * is addressed with the same ordering.
     *
     * output            receives the largest input value found inside the window
     * indices           receives the position of that value as a flat offset inside the (n, c)
     *                   spatial plane of `input`, converted to T; written at the same positions as `output`
     * input             tensor being pooled
     * channels          extent of the channel axis used to split the flat index into (n, c)
     * out_spatial_dims  spatial extents of `output`
     * in_spatial_dims   spatial extents of `input`
     * window_size       pooling window extent along each spatial axis
     * strides           distance between consecutive windows along each spatial axis
     * padding_left      amount subtracted from the window origin along each spatial axis
     *
     * A window that does not overlap the input at all produces numeric_limits<T>::lowest() as the
     * value and -1 as the index.
     *
     * NOTE(review): indices are stored in a Span<T>; for narrow floating point T (e.g. __half) large
     * offsets may not be exactly representable - confirm the spatial sizes in use stay in range.
     */
    template <class T, std::size_t Order,
        typename std::enable_if<Order == 2 || Order == 3, bool>::type = true> /* Order has been hardcoded; see code */
    __global__ void max_pooling_with_indices(
        Span<T> output, Span<T> indices, View<T> input, size_type channels,
        array<size_type, Order> out_spatial_dims, array<size_type, Order> in_spatial_dims,
        array<size_type, Order> window_size, array<size_type, Order> strides, array<size_type, Order> padding_left)
    {
        constexpr int rank = static_cast<int>(Order);

        /* every element in the output is mapped to a window in the input and each thread processes several windows */
        for (auto idx : grid_stride_range(output.size())) {
            /* peel the spatial coordinates off the flat index, fastest-varying axis first */
            size_type out_spatial_size = 1;
            array<index_type, Order> window_idx;
            for (int i = rank - 1; i >= 0; i--) {
                window_idx[i] = (idx / out_spatial_size) % out_spatial_dims[i];
                out_spatial_size *= out_spatial_dims[i];
            }

            const index_type n = idx / (out_spatial_size * channels);
            const index_type c = (idx / out_spatial_size) % channels;

            /* window bounds in input coordinates, clipped to the valid input range [0, in_spatial_dims[i]) */
            array<index_type, Order> start;
            array<index_type, Order> end;
            size_type in_spatial_size = 1;
            for (int i = 0; i < rank; i++) {
                const index_type first = window_idx[i] * strides[i] - padding_left[i];
                const index_type last = first + window_size[i];
                const index_type limit = in_spatial_dims[i];

                start[i] = first < 0 ? 0 : first;
                end[i] = last < limit ? last : limit;
                in_spatial_size *= in_spatial_dims[i];
            }

            T max_value = numeric_limits<T>::lowest();
            index_type max_idx = -1;

            /* offset of the (n, c) spatial plane inside the input */
            const auto outer_offset = (n * channels + c) * in_spatial_size;

            /* The window loops compare with `<` rather than `!=`: after clipping, start[i] can lie
             * beyond end[i] (window entirely inside the padding or past the input), and such a
             * window must be treated as empty instead of iterating without bound.
             */
            if (Order == 2) {
                array<index_type, Order> pos;
                for (pos[0] = start[0]; pos[0] < end[0]; pos[0]++) {
                    for (pos[1] = start[1]; pos[1] < end[1]; pos[1]++) {
                        index_type offset = 0;
                        index_type stride = 1;
                        for (int i = rank - 1; i >= 0; i--) {
                            offset += stride * pos[i];
                            stride *= in_spatial_dims[i];
                        }

                        const T value = input[outer_offset + offset];
                        if (value > max_value) {
                            /* strict comparison: the first maximum in scan order wins */
                            max_idx = offset;
                            max_value = value;
                        }
                    }
                }
            } else if (Order == 3) {
                array<index_type, Order> pos;
                for (pos[0] = start[0]; pos[0] < end[0]; pos[0]++) {
                    for (pos[1] = start[1]; pos[1] < end[1]; pos[1]++) {
                        for (pos[2] = start[2]; pos[2] < end[2]; pos[2]++) {
                            index_type offset = 0;
                            index_type stride = 1;
                            for (int i = rank - 1; i >= 0; i--) {
                                offset += stride * pos[i];
                                stride *= in_spatial_dims[i];
                            }

                            const T value = input[outer_offset + offset];
                            if (value > max_value) {
                                /* strict comparison: the first maximum in scan order wins */
                                max_idx = offset;
                                max_value = value;
                            }
                        }
                    }
                }
            }

            output[idx] = max_value;
            indices[idx] = max_idx;
        }
    }
}
/* Device kernels live in `raw` so that they can share a name with the host-side wrappers. */
namespace raw {
    /* Scatters every input value to the output position named by the matching entry of `indices`.
     *
     * The flat input index is decomposed as (n, c, spatial...); the stored index is interpreted as
     * a flat offset inside the (n, c) spatial plane of `output`.
     *
     * output            destination tensor; expected to be zero filled before the kernel runs
     * input             values to scatter
     * indices           per-input-element target offset inside the output plane, stored as T
     * channels          extent of the channel axis used to split the flat index into (n, c)
     * out_spatial_dims  spatial extents of `output`
     * in_spatial_dims   spatial extents of `input`
     * window_size, strides, padding_left
     *                   not needed to place a value (the stored index alone decides the target);
     *                   kept so the signature stays unchanged for existing callers
     *
     * Entries whose stored index falls outside the output plane (for example the -1 that marks
     * an empty pooling window) are skipped instead of being written out of range.
     */
    template <class T, std::size_t Order>
    __global__ void max_unpooling(
        Span<T> output, View<T> input, View<T> indices, size_type channels,
        array<size_type, Order> out_spatial_dims, array<size_type, Order> in_spatial_dims,
        array<size_type, Order> window_size, array<size_type, Order> strides, array<size_type, Order> padding_left)
    {
        constexpr int rank = static_cast<int>(Order);

        /* the plane sizes do not depend on the element being processed, so compute them once */
        size_type in_spatial_size = 1;
        size_type out_spatial_size = 1;
        for (int i = 0; i < rank; i++) {
            in_spatial_size *= in_spatial_dims[i];
            out_spatial_size *= out_spatial_dims[i];
        }

        /* the output has already been zero filled */
        /* Every input value represents a window in the output. The max unpooling operation
         * copies the input value to exactly one location in the output window which is given
         * by the indices tensor.
         */
        for (auto idx : grid_stride_range(input.size())) {
            const index_type n = idx / (in_spatial_size * channels);
            const index_type c = (idx / in_spatial_size) % channels;

            /* offset of the (n, c) spatial plane inside the output */
            const index_type outer_offset = (n * channels + c) * out_spatial_size;
            const index_type inner_offset = static_cast<index_type>(indices[idx]);

            if (inner_offset >= 0 && inner_offset < static_cast<index_type>(out_spatial_size))
                output[outer_offset + inner_offset] = input[idx];
        }
    }
}
/* Launches raw::max_pooling_with_indices<T, Order> on `stream`.
 *
 * Checks that every per-axis vector has exactly `Order` entries and that `indices` is as large
 * as `output`, copies the vectors into fixed-size arrays that are passed to the kernel by value,
 * and launches with a policy sized by the number of output elements.
 */
template <class T, std::size_t Order> static
void launch_max_pooling_kernel(
    const Stream& stream,
    Span<T> output, Span<T> indices, View<T> input, std::size_t channels,
    const std::vector<std::size_t>& out_spatial_dims, const std::vector<std::size_t>& in_spatial_dims,
    const std::vector<std::size_t>& window_size,
    const std::vector<std::size_t>& strides, const std::vector<std::size_t>& padding_left)
{
    CV_Assert(indices.size() == output.size());
    CV_Assert(out_spatial_dims.size() == Order);
    CV_Assert(in_spatial_dims.size() == Order);
    CV_Assert(window_size.size() == Order);
    CV_Assert(strides.size() == Order);
    CV_Assert(padding_left.size() == Order);

    /* the `_k` copies are the fixed-size counterparts handed to the kernel */
    array<size_type, Order> out_spatial_dims_k, in_spatial_dims_k;
    out_spatial_dims_k.assign(std::begin(out_spatial_dims), std::end(out_spatial_dims));
    in_spatial_dims_k.assign(std::begin(in_spatial_dims), std::end(in_spatial_dims));

    array<size_type, Order> window_size_k, strides_k, padding_left_k;
    window_size_k.assign(std::begin(window_size), std::end(window_size));
    strides_k.assign(std::begin(strides), std::end(strides));
    padding_left_k.assign(std::begin(padding_left), std::end(padding_left));

    auto kernel = raw::max_pooling_with_indices<T, Order>;
    /* NOTE(review): the literal 0 is presumably the dynamic shared memory size - confirm against make_policy */
    auto policy = make_policy(kernel, output.size(), 0, stream);
    launch_kernel(kernel, policy, output, indices, input, channels,
        out_spatial_dims_k, in_spatial_dims_k, window_size_k, strides_k, padding_left_k);
}
/* Max pooling that also produces the index of every selected element.
 *
 * stream        stream the kernel is launched on
 * output        pooled values
 * indices       receives the position of each maximum; must have the same shape as `output`
 * input         tensor to pool; axis 1 must have the same extent as axis 1 of `output`
 * window_size   pooling window extent per spatial axis; its length selects 2D or 3D pooling
 * strides       window step per spatial axis (same length as `window_size`)
 * padding_left  leading padding per spatial axis (same length as `window_size`)
 *
 * Axes 2 and up of `input` and `output` are treated as the spatial axes. Any violated
 * precondition is reported through CV_Assert.
 */
template <class T>
void max_pooling_with_indices(
    const Stream& stream,
    TensorSpan<T> output, TensorSpan<T> indices, TensorView<T> input,
    const std::vector<std::size_t>& window_size, const std::vector<std::size_t>& strides,
    const std::vector<std::size_t>& padding_left)
{
    CV_Assert(is_shape_same(output, indices));
    CV_Assert(input.get_axis_size(1) == output.get_axis_size(1));

    const std::size_t order = window_size.size();

    /* only max_pooling2d and max_pooling3d are supported */
    CV_Assert(2 <= order && order <= 3);
    CV_Assert(strides.size() == order);
    CV_Assert(padding_left.size() == order);
    CV_Assert(output.rank() == order + 2);
    CV_Assert(input.rank() == order + 2);

    /* order is 2 or 3 here, so the conversion to int is lossless */
    const int spatial_axes = static_cast<int>(order);

    std::vector<std::size_t> out_spatial_dims(order), in_spatial_dims(order);
    for (int i = 0; i < spatial_axes; i++) {
        in_spatial_dims[i] = input.get_axis_size(2 + i);
        out_spatial_dims[i] = output.get_axis_size(2 + i);
    }

    std::size_t channels = input.get_axis_size(1);
    if (order == 3) {
        launch_max_pooling_kernel<T, 3>(stream, output, indices, input, channels,
            out_spatial_dims, in_spatial_dims, window_size, strides, padding_left);
    } else if (order == 2) {
        launch_max_pooling_kernel<T, 2>(stream, output, indices, input, channels,
            out_spatial_dims, in_spatial_dims, window_size, strides, padding_left);
    }
}
/* Explicit instantiations of max_pooling_with_indices for the supported element types. */
template void max_pooling_with_indices<__half>(
    const Stream&,
    TensorSpan<__half>, TensorSpan<__half>, TensorView<__half>,
    const std::vector<std::size_t>&,
    const std::vector<std::size_t>&,
    const std::vector<std::size_t>&);

template void max_pooling_with_indices<float>(
    const Stream&,
    TensorSpan<float>, TensorSpan<float>, TensorView<float>,
    const std::vector<std::size_t>&,
    const std::vector<std::size_t>&,
    const std::vector<std::size_t>&);
/* Launches raw::max_unpooling<T, Order> on `stream`.
 *
 * Checks that every per-axis vector has exactly `Order` entries and that `indices` is as large
 * as `input`, copies the vectors into fixed-size arrays that are passed to the kernel by value,
 * and launches with a policy sized by the number of input elements.
 */
template <class T, std::size_t Order> static
void launch_max_unpooling_kernel(
    const Stream& stream,
    Span<T> output, View<T> input, View<T> indices, std::size_t channels,
    const std::vector<std::size_t>& out_spatial_dims, const std::vector<std::size_t>& in_spatial_dims,
    const std::vector<std::size_t>& window_size,
    const std::vector<std::size_t>& strides, const std::vector<std::size_t>& padding_left)
{
    CV_Assert(out_spatial_dims.size() == Order);
    CV_Assert(in_spatial_dims.size() == Order);
    CV_Assert(window_size.size() == Order);
    CV_Assert(strides.size() == Order);
    CV_Assert(padding_left.size() == Order);
    CV_Assert(indices.size() == input.size());

    /* the `_k` copies are the fixed-size counterparts handed to the kernel */
    array<size_type, Order> out_spatial_dims_k, in_spatial_dims_k;
    out_spatial_dims_k.assign(std::begin(out_spatial_dims), std::end(out_spatial_dims));
    in_spatial_dims_k.assign(std::begin(in_spatial_dims), std::end(in_spatial_dims));

    array<size_type, Order> window_size_k, strides_k, padding_left_k;
    window_size_k.assign(std::begin(window_size), std::end(window_size));
    strides_k.assign(std::begin(strides), std::end(strides));
    padding_left_k.assign(std::begin(padding_left), std::end(padding_left));

    auto kernel = raw::max_unpooling<T, Order>;
    /* NOTE(review): the literal 0 is presumably the dynamic shared memory size - confirm against make_policy */
    auto policy = make_policy(kernel, input.size(), 0, stream);
    launch_kernel(kernel, policy, output, input, indices, channels,
        out_spatial_dims_k, in_spatial_dims_k, window_size_k, strides_k, padding_left_k);
}
/* Max unpooling: zero fills `output`, then copies every input value to the output position
 * recorded in `indices`.
 *
 * stream        stream the fill and the kernel are issued on
 * output        destination tensor; axis 1 must have the same extent as axis 1 of `input`
 * input         pooled values
 * indices       target position of every input value; must have the same shape as `input`
 * window_size   pooling window extent per spatial axis; its length selects 2D or 3D unpooling
 * strides       window step per spatial axis (same length as `window_size`)
 * padding_left  leading padding per spatial axis (same length as `window_size`)
 *
 * Axes 2 and up of `input` and `output` are treated as the spatial axes. All preconditions are
 * checked through CV_Assert before `output` is touched.
 */
template <class T>
void max_unpooling(
    const Stream& stream,
    TensorSpan<T> output, TensorView<T> input, TensorView<T> indices,
    const std::vector<std::size_t>& window_size, const std::vector<std::size_t>& strides,
    const std::vector<std::size_t>& padding_left)
{
    CV_Assert(is_shape_same(input, indices));
    CV_Assert(input.get_axis_size(1) == output.get_axis_size(1));

    const std::size_t order = window_size.size();

    /* only max_unpooling2d and max_unpooling3d are supported */
    CV_Assert(2 <= order && order <= 3);
    CV_Assert(strides.size() == order);
    CV_Assert(padding_left.size() == order);
    CV_Assert(output.rank() == order + 2);
    CV_Assert(input.rank() == order + 2);

    /* order is 2 or 3 here, so the conversion to int is lossless */
    const int spatial_axes = static_cast<int>(order);

    std::vector<std::size_t> out_spatial_dims(order), in_spatial_dims(order);
    for (int i = 0; i < spatial_axes; i++) {
        in_spatial_dims[i] = input.get_axis_size(2 + i);
        out_spatial_dims[i] = output.get_axis_size(2 + i);
    }

    /* the kernel writes only the selected positions, so everything else must already be zero */
    kernels::fill<T>(stream, output, 0.0);

    std::size_t channels = input.get_axis_size(1);
    if (order == 3) {
        launch_max_unpooling_kernel<T, 3>(stream, output, input, indices, channels,
            out_spatial_dims, in_spatial_dims, window_size, strides, padding_left);
    } else if (order == 2) {
        launch_max_unpooling_kernel<T, 2>(stream, output, input, indices, channels,
            out_spatial_dims, in_spatial_dims, window_size, strides, padding_left);
    }
}
/* Explicit instantiations of max_unpooling for the supported element types. */
template void max_unpooling<__half>(
    const Stream&,
    TensorSpan<__half>, TensorView<__half>, TensorView<__half>,
    const std::vector<std::size_t>&,
    const std::vector<std::size_t>&,
    const std::vector<std::size_t>&);

template void max_unpooling<float>(
    const Stream&,
    TensorSpan<float>, TensorView<float>, TensorView<float>,
    const std::vector<std::size_t>&,
    const std::vector<std::size_t>&,
    const std::vector<std::size_t>&);
307 }}}} /* namespace cv::dnn::cuda4dnn::kernels */