1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
5 #include <cuda_runtime.h>
11 #include "vector_traits.hpp"
12 #include "grid_stride_range.hpp"
13 #include "execution.hpp"
15 #include "../cuda4dnn/csl/stream.hpp"
16 #include "../cuda4dnn/csl/span.hpp"
20 using namespace cv::dnn::cuda4dnn::csl;
21 using namespace cv::dnn::cuda4dnn::csl::device;
23 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
    /* Generates the prior box coordinates for every point of a layerWidth x layerHeight feature map.
     *
     * Points are enumerated row-major (x = idx % layerWidth, y = idx / layerWidth). For point `idx`,
     * box shape `i` (boxWidth[i] x boxHeight[i]) and center offset `j` (offsetX[j], offsetY[j]), the box
     * [xmin, ymin, xmax, ymax] is written as the 4-vector at index
     *     (idx * boxWidth.size() + i) * offsetX.size() + j
     * of `output`.
     *
     * Normalize == true : coordinates are divided by imageWidth / imageHeight.
     * Normalize == false: coordinates are left unscaled and 1 is subtracted from xmax / ymax.
     *
     * Preconditions (not checked in this kernel):
     * - boxHeight.size() == boxWidth.size() and offsetY.size() == offsetX.size()
     *   (boxHeight / offsetY are indexed with the boxWidth / offsetX ranges)
     * - `output` holds at least layerWidth * layerHeight * boxWidth.size() * offsetX.size() * 4 elements,
     *   and its data pointer is assumed suitably aligned for 4-element vector stores
     */
    template <class T, bool Normalize>
    __global__ void prior_box(
        Span<T> output,
        View<float> boxWidth, View<float> boxHeight, View<float> offsetX, View<float> offsetY, float stepX, float stepY,
        size_type layerWidth, size_type layerHeight,
        size_type imageWidth, size_type imageHeight)
    {
        /* each box consists of two pairs of coordinates and hence 4 values in total;
         * since the entire output (first channel at least) consists of these boxes,
         * we are guaranteed that the output is aligned to a boundary of 4 values
         */
        using vector_type = get_vector_type_t<T, 4>;
        auto output_vPtr = vector_type::get_pointer(output.data());

        const size_type num_box_shapes = boxWidth.size();
        const size_type num_offsets = offsetX.size();

        /* num_points contains the number of points in the feature map of interest;
         * each iteration of the grid-stride loop selects a point and generates all of its prior boxes
         */
        const size_type num_points = layerWidth * layerHeight;
        for (auto idx : grid_stride_range(num_points)) {
            const index_type x = idx % layerWidth,
                             y = idx / layerWidth;

            /* every point owns num_box_shapes * num_offsets consecutive boxes */
            index_type output_offset_v4 = idx * num_offsets * num_box_shapes;
            for (size_type i = 0; i < num_box_shapes; i++) {
                /* multiplying by 0.5f is exact, so hoisting these does not change any result */
                const float half_width = boxWidth[i] * 0.5f;
                const float half_height = boxHeight[i] * 0.5f;

                for (size_type j = 0; j < num_offsets; j++) {
                    const float center_x = (x + offsetX[j]) * stepX;
                    const float center_y = (y + offsetY[j]) * stepY;

                    vector_type vec;
                    if (Normalize) {
                        vec.data[0] = (center_x - half_width) / imageWidth;
                        vec.data[1] = (center_y - half_height) / imageHeight;
                        vec.data[2] = (center_x + half_width) / imageWidth;
                        vec.data[3] = (center_y + half_height) / imageHeight;
                    } else {
                        vec.data[0] = center_x - half_width;
                        vec.data[1] = center_y - half_height;
                        vec.data[2] = center_x + half_width - 1.0f;
                        vec.data[3] = center_y + half_height - 1.0f;
                    }

                    v_store(output_vPtr[output_offset_v4], vec);
                    output_offset_v4++;
                }
            }
        }
    }
} /* namespace raw */
namespace raw {
    /* Clamps every element of `output` to the range [0, 1] in place. */
    template <class T>
    __global__ void prior_box_clip(Span<T> output) {
        /* bring the device-side clamp helper into scope for the unqualified call below */
        using device::clamp;
        for (auto i : grid_stride_range(output.size())) {
            /* bounds are built as T directly instead of converting double literals in device code */
            output[i] = clamp<T>(output[i], T(0), T(1));
        }
    }
} /* namespace raw */
namespace raw {
    /* Writes `variance` into every element of the first (output.size() / 4) * 4 elements of `output`,
     * four elements per vector store. Elements past the last full group of four are not touched.
     * Assumes output.data() is suitably aligned for 4-element vector stores.
     */
    template <class T>
    __global__ void prior_box_set_variance1(Span<T> output, float variance) {
        using vector_type = get_vector_type_t<T, 4>;
        auto output_vPtr = vector_type::get_pointer(output.data());

        /* the stored vector is identical for every slot, so build it once outside the loop */
        vector_type vec;
        for (int j = 0; j < 4; j++)
            vec.data[j] = variance;

        for (auto i : grid_stride_range(output.size() / 4))
            v_store(output_vPtr[i], vec);
    }
} /* namespace raw */
namespace raw {
    /* Tiles the four values of `variance` over the first (output.size() / 4) * 4 elements of `output`:
     * element 4 * k + j receives variance[j]. Elements past the last full group of four are not touched.
     * Assumes output.data() is suitably aligned for 4-element vector stores.
     */
    template <class T>
    __global__ void prior_box_set_variance4(Span<T> output, array<float, 4> variance) {
        using vector_type = get_vector_type_t<T, 4>;
        auto output_vPtr = vector_type::get_pointer(output.data());

        /* the stored vector is identical for every slot, so build it once outside the loop */
        vector_type vec;
        for (int j = 0; j < 4; j++)
            vec.data[j] = variance[j];

        for (auto i : grid_stride_range(output.size() / 4))
            v_store(output_vPtr[i], vec);
    }
} /* namespace raw */
/* Enqueues raw::prior_box<T, Normalize> on `stream`.
 *
 * The launch is sized by the number of feature-map points (layerWidth * layerHeight), matching the
 * per-point loop inside the kernel; every other argument is forwarded to the kernel unchanged.
 */
template <class T, bool Normalize> static
void launch_prior_box_kernel(
    const Stream& stream,
    Span<T> output, View<float> boxWidth, View<float> boxHeight, View<float> offsetX, View<float> offsetY, float stepX, float stepY,
    std::size_t layerWidth, std::size_t layerHeight, std::size_t imageWidth, std::size_t imageHeight)
{
    const auto kernel = raw::prior_box<T, Normalize>;

    const std::size_t point_count = layerHeight * layerWidth;
    auto policy = make_policy(kernel, point_count, 0, stream);

    launch_kernel(kernel, policy,
        output,
        boxWidth, boxHeight,
        offsetX, offsetY,
        stepX, stepY,
        layerWidth, layerHeight,
        imageWidth, imageHeight);
}
/* Fills `output` with the prior boxes of a PriorBox layer and their variances.
 *
 * `output` must hold exactly two channels of channel_size = layerHeight * layerWidth * numPriors * 4
 * elements each:
 * - channel 1: [xmin, ymin, xmax, ymax] boxes from raw::prior_box (normalized by the image size when
 *   `normalize` is set), optionally clamped to [0, 1] when `clip` is set
 * - channel 2: the variance; a single value is broadcast to every element, four values are tiled
 *   per box
 *
 * All kernels are enqueued on `stream`; this function itself performs no explicit synchronization.
 * Inconsistent argument sizes fail a CV_Assert before any kernel is enqueued.
 */
template <class T>
void generate_prior_boxes(
    const Stream& stream,
    Span<T> output,
    View<float> boxWidth, View<float> boxHeight, View<float> offsetX, View<float> offsetY, float stepX, float stepY,
    std::vector<float> variance,
    std::size_t numPriors,
    std::size_t layerWidth, std::size_t layerHeight,
    std::size_t imageWidth, std::size_t imageHeight,
    bool normalize, bool clip)
{
    const std::size_t channel_size = layerHeight * layerWidth * numPriors * 4;

    /* validate everything up front so nothing is enqueued for an invalid configuration */
    CV_Assert(channel_size * 2 == output.size());
    /* the box kernel reads boxHeight / offsetY using the index ranges of boxWidth / offsetX */
    CV_Assert(boxWidth.size() == boxHeight.size());
    CV_Assert(offsetX.size() == offsetY.size());
    /* the box kernel writes boxWidth.size() * offsetX.size() boxes per point; that must fill channel 1 exactly */
    CV_Assert(static_cast<std::size_t>(boxWidth.size()) * static_cast<std::size_t>(offsetX.size()) == numPriors);
    /* variance_k below has exactly four slots: more values would overflow it, fewer would leave slots unset */
    CV_Assert(variance.size() == 1 || variance.size() == 4);

    if (normalize) {
        launch_prior_box_kernel<T, true>(
            stream, output, boxWidth, boxHeight, offsetX, offsetY, stepX, stepY,
            layerWidth, layerHeight, imageWidth, imageHeight
        );
    } else {
        launch_prior_box_kernel<T, false>(
            stream, output, boxWidth, boxHeight, offsetX, offsetY, stepX, stepY,
            layerWidth, layerHeight, imageWidth, imageHeight
        );
    }

    if (clip) {
        /* only the box coordinates (channel 1) are clipped */
        auto output_span_c1 = Span<T>(output.data(), channel_size);
        auto kernel = raw::prior_box_clip<T>;
        auto policy = make_policy(kernel, output_span_c1.size(), 0, stream);
        launch_kernel(kernel, policy, output_span_c1);
    }

    /* channel 2 holds the variances; those kernels store four elements per work item */
    auto output_span_c2 = Span<T>(output.data() + channel_size, channel_size);
    if (variance.size() == 1) {
        auto kernel = raw::prior_box_set_variance1<T>;
        auto policy = make_policy(kernel, output_span_c2.size() / 4, 0, stream);
        launch_kernel(kernel, policy, output_span_c2, variance[0]);
    } else {
        array<float, 4> variance_k;
        variance_k.assign(std::begin(variance), std::end(variance));
        auto kernel = raw::prior_box_set_variance4<T>;
        auto policy = make_policy(kernel, output_span_c2.size() / 4, 0, stream);
        launch_kernel(kernel, policy, output_span_c2, variance_k);
    }
}
/* explicit instantiations of generate_prior_boxes for half- and single-precision outputs */
template void generate_prior_boxes<__half>(
    const Stream&, Span<__half>,
    View<float>, View<float>, View<float>, View<float>, float, float,
    std::vector<float>, std::size_t,
    std::size_t, std::size_t, std::size_t, std::size_t,
    bool, bool);

template void generate_prior_boxes<float>(
    const Stream&, Span<float>,
    View<float>, View<float>, View<float>, View<float>, float, float,
    std::vector<float>, std::size_t,
    std::size_t, std::size_t, std::size_t, std::size_t,
    bool, bool);
174 }}}} /* namespace cv::dnn::cuda4dnn::kernels */