1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
5 #include <cuda_runtime.h>
9 #include "grid_stride_range.hpp"
10 #include "execution.hpp"
12 #include "vector_traits.hpp"
14 #include "../cuda4dnn/csl/stream.hpp"
15 #include "../cuda4dnn/csl/span.hpp"
17 #include <opencv2/core.hpp>
21 using namespace cv::dnn::cuda4dnn::csl;
22 using namespace cv::dnn::cuda4dnn::csl::device;
24 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
// Device kernel: writes sigmoid(input[k]) to output[k] for a run of `n` consecutive elements
// that starts at `offset` inside every `stride`-sized block of the buffer.
//
// output - destination span; output.size() / stride is the number of blocks
// input  - source view, read at exactly the indices that are written in `output`
// n      - number of consecutive elements processed in each block
// stride - distance, in elements, between the starts of successive blocks
// offset - position inside each block where the processed run begins
//
// Only indices of the form block * stride + offset + k, with 0 <= k < n, are written.
28 __global__ void sigmoid_strided(Span<T> output, View<T> input, size_type n, size_type stride, size_type offset) {
29 /* - the input is divided into equal blocks strided by `stride`
30 * - we must apply sigmoid to a continuous range of `n` values starting from `offset` in every block
// n * output.size() / stride work items in total, i.e. `n` per block when size is a multiple of stride
32 for (auto i : grid_stride_range(n * output.size() / stride)) {
// block that this work item belongs to
33 auto block_idx = i / n;
// flat index = start of the block + start of the run + position inside the run
34 auto index = block_idx * stride + offset + (i % n);
36 using device::sigmoid;
37 output[index] = sigmoid(input[index]);
// Device kernel: softmax-style normalisation of a run of `n` consecutive elements that starts
// at `offset_` inside every `stride`-sized block of `output`. One work item handles one block.
//
// For each block the visible code
//   1. takes the maximum of the run,
//   2. replaces every element of the run by exp(element - maximum),
//   3. divides every element of the run by `sum`.
//
// NOTE(review): every read in the lines shown here goes through `output`; `input` is never
// referenced. Presumably the caller is expected to have the input values in `output` already
// (for example by running in place) - confirm against the call site.
// NOTE(review): the declaration and accumulation of `sum` are not visible in this excerpt;
// presumably it holds the total of the exponentials written in step 2 - confirm.
42 __global__ void softmax_strided(Span<T> output, View<T> input, size_type n, size_type stride, size_type offset_) {
// output.size() / stride blocks, one work item each
43 for (auto idx : grid_stride_range(output.size() / stride)) {
// first element of the run inside this block
44 index_type offset = idx * stride + offset_;
// seed the running maximum with numeric_limits<T>::lowest()
46 auto largest = numeric_limits<T>::lowest();
47 for (int i = 0; i < n; i++) {
49 largest = max(largest, output[offset + i]);
53 for (int i = 0; i < n; i++) {
// subtracting the maximum of the run means the argument passed to exp is never positive
55 auto temp = exp(output[offset + i] - largest);
57 output[offset + i] = temp;
// final pass: scale the stored exponentials by `sum`
60 for (int i = 0; i < n; i++) {
61 output[offset + i] /= sum;
// Device kernel: finalises every box record of the buffer. A box record is `box_size`
// consecutive elements and one work item handles one box.
//
// Per box, the visible code does the following:
//   - elements 0 and 1: (cell column or row + sigmoid(input)) divided by cols or rows
//   - elements 2 and 3: exp(input) * bias entry for this box-of-the-cell, divided by
//                       width_norm or height_norm
//   - element 4:        sigmoid of the value currently stored in `output`
//   - elements 5 up to 5 + classes: value in `output` multiplied by the objectness
//                       probability, then set to zero when it is <= class_prob_cutoff
//
// output             - destination span; also read for element 4 and the class values
// input              - source view, read for elements 0 to 3
// bias               - two entries per box-of-the-cell, at 2 * b and 2 * b + 1
// object_prob_cutoff - compared against the objectness probability of the box
// class_prob_cutoff  - class values at or below this are written as zero
// height_norm, width_norm - divisors for elements 3 and 2
// rows, cols         - grid dimensions; divisors for elements 1 and 0
// boxes_per_cell     - number of boxes stored for each grid cell
//
// NOTE(review): `box_size` and `classes` are used below but their declarations are not
// visible in this excerpt; presumably they are trailing parameters - confirm.
// NOTE(review): element 4 and the class values are read from `output`, not `input`;
// presumably the caller has already placed the squashed values there - confirm.
67 __global__ void region_finalize(Span<T> output, View<T> input, View<T> bias,
68 T object_prob_cutoff, T class_prob_cutoff,
69 size_type height_norm, size_type width_norm,
70 size_type rows, size_type cols,
71 size_type boxes_per_cell,
// one work item per box; the buffer holds output.size() / box_size boxes
75 for (auto box_index : grid_stride_range(output.size() / box_size)) {
76 auto box_of_the_cell = box_index % boxes_per_cell; /* box number within a cell */
// first element of this box record
77 auto box_offset = box_index * box_size;
// boxes per batch item, per grid row and per grid cell, used to split box_index below
79 auto batch_inner_size = rows * cols * boxes_per_cell;
80 auto row_inner_size = cols * boxes_per_cell;
81 auto col_inner_size = boxes_per_cell;
// grid row (y) and grid column (x) of the cell that owns this box
83 auto y = (box_index % batch_inner_size) / row_inner_size;
84 auto x = (box_index % row_inner_size) / col_inner_size;
86 using device::sigmoid;
// elements 0 and 1: cell position plus a sigmoid of the input, divided by the grid size
88 output[box_offset + 0] = (T(x) + sigmoid(input[box_offset + 0])) / T(cols);
89 output[box_offset + 1] = (T(y) + sigmoid(input[box_offset + 1])) / T(rows);
// elements 2 and 3: exp of the input scaled by the bias pair of this box, divided by the norms
90 output[box_offset + 2] = exp(input[box_offset + 2]) * bias[2 * box_of_the_cell + 0] / T(width_norm);
91 output[box_offset + 3] = exp(input[box_offset + 3]) * bias[2 * box_of_the_cell + 1] / T(height_norm);
93 /* squash objectness score into a probability */
94 using device::sigmoid;
95 T objectness_prob = sigmoid(output[box_offset + 4]);
96 output[box_offset + 4] = objectness_prob;
98 /* ignore prediction if the objectness probability is less than the cutoff */
// NOTE(review): the statement controlled by this test is not visible in this excerpt
99 if (objectness_prob < object_prob_cutoff)
102 /* the class probabilities we have currently are conditional class probabilities
105 * to obtain the actual class probability, we multiply the conditional probability
106 * with the object probability
108 const index_type class_begin = box_offset + 5; /* 4 box coordinates, 1 obj prob, class probs... */
109 const index_type class_end = class_begin + classes;
110 index_type offset = class_begin;
// presumably a vector of four T elements; vector_type::size() is used as its width below
112 using vector_type = get_vector_type_t<T, 4>;
114 /* process each class independently until the offset is aligned to an n-element boundary */
// NOTE(review): the increment of `offset` in this loop is not visible in this excerpt
115 while (offset % vector_type::size() != 0 && offset < class_end) {
116 T actual_class_prob = objectness_prob * output[offset];
117 if (actual_class_prob <= class_prob_cutoff)
118 actual_class_prob = T(0);
119 output[offset] = actual_class_prob;
// vector pointers that start at the current, now aligned, offset
// NOTE(review): input_vPtr is computed but not used in the lines visible here
123 auto output_vPtr = vector_type::get_pointer(output.data() + offset);
124 auto input_vPtr = vector_type::get_pointer(input.data() + offset);
// the strict less-than leaves a full vector that ends exactly at class_end to the loop below
125 for (int i = 0; (offset + vector_type::size()) < class_end; i++) {
127 v_load(vec, output_vPtr[i]);
// same scaling and thresholding as the scalar loops, applied element by element
128 for (int j = 0; j < vector_type::size(); j++) {
129 T actual_class_prob = objectness_prob * vec.data[j];
130 if (actual_class_prob <= class_prob_cutoff)
131 actual_class_prob = T(0);
132 vec.data[j] = actual_class_prob;
134 v_store(output_vPtr[i], vec);
135 offset += vector_type::size();
138 /* process the remaining classes */
// NOTE(review): the increment of `offset` in this loop is not visible in this excerpt
139 while (offset < class_end) {
140 T actual_class_prob = objectness_prob * output[offset];
141 if (actual_class_prob <= class_prob_cutoff)
142 actual_class_prob = T(0);
143 output[offset] = actual_class_prob;
// Host-side launcher for raw::sigmoid_strided<T>.
//
// stream - stream handed to make_policy for the launch
// output - destination span; its size must be a multiple of `stride` (checked by CV_Assert)
// input  - source view
// n      - number of consecutive elements processed in each block
// stride - block length in elements
// offset - start of the processed run inside each block
//
// The launch policy is built for n * output.size() / stride work items, the same count the
// kernel iterates over.
//
// NOTE(review): `stride` is used as a divisor in the assertion and in the work-item count
// without a check for zero, and nothing here verifies that offset + n <= stride; presumably
// callers guarantee both - confirm.
151 void sigmoid_strided(const Stream& stream, Span<T> output, View<T> input, std::size_t n, std::size_t stride, std::size_t offset) {
152 CV_Assert(output.size() % stride == 0);
154 auto kernel = raw::sigmoid_strided<T>;
// third argument is 0; presumably the dynamic shared memory size - confirm
155 auto policy = make_policy(kernel, n * output.size() / stride, 0, stream);
156 launch_kernel(kernel, policy, output, input, n, stride, offset);
// explicit instantiations for the two element types, __half and float
159 template void sigmoid_strided(const Stream&, Span<__half>, View<__half>, std::size_t, std::size_t, std::size_t);
160 template void sigmoid_strided(const Stream&, Span<float>, View<float>, std::size_t, std::size_t, std::size_t);
// Host-side launcher for raw::softmax_strided<T>.
//
// stream - stream handed to make_policy for the launch
// output - span that is normalised; its size must be a multiple of `stride` (CV_Assert)
// input  - forwarded to the kernel
// n      - number of consecutive elements normalised together in each block
// stride - block length in elements
// offset - start of the normalised run inside each block
//
// The launch policy is built for output.size() / stride work items, one per block.
//
// NOTE(review): `stride` is used as a divisor in the assertion and in the work-item count
// without a check for zero, and nothing here verifies that offset + n <= stride; presumably
// callers guarantee both - confirm.
163 void softmax_strided(const Stream& stream, Span<T> output, View<T> input, std::size_t n, std::size_t stride, std::size_t offset) {
164 CV_Assert(output.size() % stride == 0);
166 auto kernel = raw::softmax_strided<T>;
// third argument is 0; presumably the dynamic shared memory size - confirm
167 auto policy = make_policy(kernel, output.size() / stride, 0, stream);
168 launch_kernel(kernel, policy, output, input, n, stride, offset);
// explicit instantiations for the two element types, __half and float
171 template void softmax_strided(const Stream&, Span<__half>, View<__half>, std::size_t, std::size_t, std::size_t);
172 template void softmax_strided(const Stream&, Span<float>, View<float>, std::size_t, std::size_t, std::size_t);
// Host-side launcher for raw::region_finalize<T>.
//
// stream             - stream handed to make_policy for the launch
// output             - destination span; its size must be a multiple of `box_size` (CV_Assert)
// input, bias        - forwarded to the kernel
// object_prob_cutoff, class_prob_cutoff - thresholds forwarded to the kernel
// height_norm, width_norm, rows, cols, boxes_per_cell, box_size - forwarded to the kernel
//
// The launch policy is built for output.size() / box_size work items, one per box.
//
// NOTE(review): `classes` is passed to the kernel below but its declaration is not visible in
// this excerpt; the explicit instantiations list seven std::size_t parameters, so presumably
// it is the last one - confirm.
// NOTE(review): `box_size` is used as a divisor without a check for zero; presumably callers
// guarantee a positive value - confirm.
175 void region_finalize(const Stream& stream, Span<T> output, View<T> input, View<T> bias,
176 T object_prob_cutoff, T class_prob_cutoff,
177 std::size_t height_norm, std::size_t width_norm,
178 std::size_t rows, std::size_t cols,
179 std::size_t boxes_per_cell,
180 std::size_t box_size,
183 CV_Assert(output.size() % box_size == 0);
185 auto kernel = raw::region_finalize<T>;
// third argument is 0; presumably the dynamic shared memory size - confirm
186 auto policy = make_policy(kernel, output.size() / box_size, 0, stream);
187 launch_kernel(kernel, policy, output, input, bias,
188 object_prob_cutoff, class_prob_cutoff,
189 height_norm, width_norm,
190 rows, cols, boxes_per_cell, box_size, classes);
// explicit instantiations for the two element types, __half and float
193 template void region_finalize(const Stream&, Span<__half>, View<__half>, View<__half>,
194 __half, __half, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t);
196 template void region_finalize(const Stream&, Span<float>, View<float>, View<float>,
197 float, float, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t);
199 }}}} /* namespace cv::dnn::cuda4dnn::kernels */