1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
5 #include <cuda_runtime.h>
10 #include "vector_traits.hpp"
11 #include "grid_stride_range.hpp"
12 #include "execution.hpp"
14 #include "../cuda4dnn/csl/stream.hpp"
15 #include "../cuda4dnn/csl/span.hpp"
17 #include "../cuda4dnn/kernels/scale_shift.hpp"
19 #include <opencv2/core.hpp>
23 using namespace cv::dnn::cuda4dnn::csl;
24 using namespace cv::dnn::cuda4dnn::csl::device;
26 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
/**
 * Element-wise absolute value over packs of N elements of T.
 *
 * For every pack index produced by grid_stride_range(output.size() / N) the
 * kernel loads one pack from `input`, replaces each lane with abs(lane) and
 * stores the pack at the same index of `output`.
 *
 * NOTE(review): the host launcher asserts is_fully_aligned for N on both
 * buffers before launching; the kernel itself performs no checks.
 * NOTE(review): launchers refer to this kernel as raw::abs_vec; the enclosing
 * `namespace raw` is not visible in this chunk.
 */
template <class T, std::size_t N>
__global__ void abs_vec(Span<T> output, View<T> input) {
    using vector_type = get_vector_type_t<T, N>;

    auto output_vPtr = vector_type::get_pointer(output.data());
    auto input_vPtr = vector_type::get_pointer(input.data());

    for (auto i : grid_stride_range(output.size() / vector_type::size())) {
        vector_type vec;
        v_load(vec, input_vPtr[i]);
        for (int j = 0; j < vector_type::size(); j++) {
            // pin the call to the device math overloads, as sigmoid_vec does
            using device::abs;
            vec.data[j] = abs(vec.data[j]);
        }
        v_store(output_vPtr[i], vec);
    }
}
/**
 * Element-wise hyperbolic tangent over packs of N elements of T.
 *
 * For every pack index produced by grid_stride_range(output.size() / N) the
 * kernel loads one pack from `input`, replaces each lane with tanh(lane) and
 * stores the pack at the same index of `output`.
 *
 * NOTE(review): the host launcher asserts is_fully_aligned for N on both
 * buffers before launching; the kernel itself performs no checks.
 * NOTE(review): launchers refer to this kernel as raw::tanh_vec; the enclosing
 * `namespace raw` is not visible in this chunk.
 */
template <class T, std::size_t N>
__global__ void tanh_vec(Span<T> output, View<T> input) {
    using vector_type = get_vector_type_t<T, N>;

    auto output_vPtr = vector_type::get_pointer(output.data());
    auto input_vPtr = vector_type::get_pointer(input.data());

    for (auto i : grid_stride_range(output.size() / vector_type::size())) {
        vector_type vec;
        v_load(vec, input_vPtr[i]);
        for (int j = 0; j < vector_type::size(); j++) {
            // pin the call to the device math overloads, as sigmoid_vec does
            using device::tanh;
            vec.data[j] = tanh(vec.data[j]);
        }
        v_store(output_vPtr[i], vec);
    }
}
/**
 * Element-wise sigmoid over packs of N elements of T.
 *
 * For every pack index produced by grid_stride_range(output.size() / N) the
 * kernel loads one pack from `input`, replaces each lane with
 * device::sigmoid(lane) and stores the pack at the same index of `output`.
 *
 * NOTE(review): the host launcher asserts is_fully_aligned for N on both
 * buffers before launching; the kernel itself performs no checks.
 * NOTE(review): launchers refer to this kernel as raw::sigmoid_vec; the
 * enclosing `namespace raw` is not visible in this chunk.
 */
template <class T, std::size_t N>
__global__ void sigmoid_vec(Span<T> output, View<T> input) {
    using vector_type = get_vector_type_t<T, N>;

    auto output_vPtr = vector_type::get_pointer(output.data());
    auto input_vPtr = vector_type::get_pointer(input.data());

    for (auto i : grid_stride_range(output.size() / vector_type::size())) {
        vector_type vec;
        v_load(vec, input_vPtr[i]);
        for (int j = 0; j < vector_type::size(); j++) {
            using device::sigmoid;
            vec.data[j] = sigmoid(vec.data[j]);
        }
        v_store(output_vPtr[i], vec);
    }
}
/**
 * Element-wise BNLL activation over packs of N elements of T.
 *
 * Each lane x becomes
 *     x + log1pexp(-x)   when x > 0
 *     log1pexp(x)        otherwise
 * The positive branch feeds log1pexp a negative argument
 * (presumably to keep the exponential from overflowing — confirm against
 * device::log1pexp).
 *
 * NOTE(review): the host launcher asserts is_fully_aligned for N on both
 * buffers before launching; the kernel itself performs no checks.
 * NOTE(review): launchers refer to this kernel as raw::bnll_vec; the enclosing
 * `namespace raw` is not visible in this chunk.
 */
template <class T, std::size_t N>
__global__ void bnll_vec(Span<T> output, View<T> input) {
    using vector_type = get_vector_type_t<T, N>;

    auto output_vPtr = vector_type::get_pointer(output.data());
    auto input_vPtr = vector_type::get_pointer(input.data());

    for (auto i : grid_stride_range(output.size() / vector_type::size())) {
        vector_type vec;
        v_load(vec, input_vPtr[i]);
        for (int j = 0; j < vector_type::size(); j++) {
            using device::log1pexp;
            vec.data[j] = vec.data[j] > T(0) ? vec.data[j] + log1pexp(-vec.data[j]) : log1pexp(vec.data[j]);
        }
        v_store(output_vPtr[i], vec);
    }
}
/**
 * Element-wise ELU over packs of N elements of T.
 *
 * Each lane x is kept when x >= 0 and replaced by expm1(x) otherwise.
 *
 * NOTE(review): the host launcher asserts is_fully_aligned for N on both
 * buffers before launching; the kernel itself performs no checks.
 * NOTE(review): launchers refer to this kernel as raw::elu_vec; the enclosing
 * `namespace raw` is not visible in this chunk.
 */
template <class T, std::size_t N>
__global__ void elu_vec(Span<T> output, View<T> input) {
    using vector_type = get_vector_type_t<T, N>;

    auto output_vPtr = vector_type::get_pointer(output.data());
    auto input_vPtr = vector_type::get_pointer(input.data());

    for (auto i : grid_stride_range(output.size() / vector_type::size())) {
        vector_type vec;
        v_load(vec, input_vPtr[i]);
        for (int j = 0; j < vector_type::size(); j++) {
            // pin the call to the device math overloads, as sigmoid_vec does
            using device::expm1;
            vec.data[j] = vec.data[j] >= T(0) ? vec.data[j] : expm1(vec.data[j]);
        }
        v_store(output_vPtr[i], vec);
    }
}
/**
 * Element-wise leaky ReLU over packs of N elements of T.
 *
 * Each lane x is kept when x >= 0 and replaced by slope * x otherwise
 * (slope == 0 yields the plain ReLU).
 *
 * NOTE(review): the host launcher asserts is_fully_aligned for N on both
 * buffers before launching; the kernel itself performs no checks.
 * NOTE(review): launchers refer to this kernel as raw::relu_vec; the enclosing
 * `namespace raw` is not visible in this chunk.
 */
template <class T, std::size_t N>
__global__ void relu_vec(Span<T> output, View<T> input, T slope) {
    using vector_type = get_vector_type_t<T, N>;

    auto output_vPtr = vector_type::get_pointer(output.data());
    auto input_vPtr = vector_type::get_pointer(input.data());

    for (auto i : grid_stride_range(output.size() / vector_type::size())) {
        vector_type vec;
        v_load(vec, input_vPtr[i]);
        for (int j = 0; j < vector_type::size(); j++)
            vec.data[j] = vec.data[j] >= T(0) ? vec.data[j] : slope * vec.data[j];
        v_store(output_vPtr[i], vec);
    }
}
/**
 * Element-wise clipped ReLU over packs of N elements of T.
 *
 * Each lane x is replaced by clamp(x, floor, ceiling).
 *
 * NOTE(review): the public entry point asserts floor <= ceiling and the
 * launcher asserts is_fully_aligned for N; the kernel performs no checks.
 * NOTE(review): launchers refer to this kernel as raw::clipped_relu_vec; the
 * enclosing `namespace raw` is not visible in this chunk.
 */
template <class T, std::size_t N>
__global__ void clipped_relu_vec(Span<T> output, View<T> input, T floor, T ceiling) {
    using vector_type = get_vector_type_t<T, N>;

    auto output_vPtr = vector_type::get_pointer(output.data());
    auto input_vPtr = vector_type::get_pointer(input.data());

    for (auto i : grid_stride_range(output.size() / vector_type::size())) {
        // pin the call to the device math overloads, as sigmoid_vec does
        using device::clamp;

        vector_type vec;
        v_load(vec, input_vPtr[i]);
        for (int j = 0; j < vector_type::size(); j++)
            vec.data[j] = clamp(vec.data[j], floor, ceiling);
        v_store(output_vPtr[i], vec);
    }
}
/**
 * Leaky ReLU with a per-slice slope, over packs of N elements of T.
 *
 * `inner_size` is first converted from elements to packs (divided by N). For
 * pack index i the slope index is (i / inner_size) % slope.size(); every lane
 * of that pack uses the same slope[c]. A lane x is kept when x > 0 and
 * replaced by x * slope[c] otherwise.
 *
 * NOTE(review): relies on inner_size being a non-zero multiple of N and on
 * slope being non-empty; the launcher only asserts inner_size % N == 0.
 * NOTE(review): launchers refer to this kernel as raw::axiswise_relu_vec; the
 * enclosing `namespace raw` is not visible in this chunk.
 */
template <class T, std::size_t N>
__global__ void axiswise_relu_vec(Span<T> output, View<T> input, size_type inner_size, View<T> slope) {
    using vector_type = get_vector_type_t<T, N>;

    auto output_vPtr = vector_type::get_pointer(output.data());
    auto input_vPtr = vector_type::get_pointer(input.data());

    // from here on inner_size counts packs, matching the pack index i
    inner_size /= vector_type::size();
    for (auto i : grid_stride_range(output.size() / vector_type::size())) {
        const index_type c = (i / inner_size) % static_cast<size_type>(slope.size());

        vector_type vec;
        v_load(vec, input_vPtr[i]);
        for (int j = 0; j < vector_type::size(); j++)
            vec.data[j] = vec.data[j] > T(0) ? vec.data[j] : vec.data[j] * slope[c];
        v_store(output_vPtr[i], vec);
    }
}
/**
 * Element-wise power activation over packs of N elements of T.
 *
 * Each lane x is replaced by pow(shift + scale * x, exp).
 *
 * NOTE(review): the host launcher asserts is_fully_aligned for N on both
 * buffers before launching; the kernel itself performs no checks.
 * NOTE(review): launchers refer to this kernel as raw::power_vec; the enclosing
 * `namespace raw` is not visible in this chunk.
 */
template <class T, std::size_t N>
__global__ void power_vec(Span<T> output, View<T> input, T exp, T scale, T shift) {
    using vector_type = get_vector_type_t<T, N>;

    auto output_vPtr = vector_type::get_pointer(output.data());
    auto input_vPtr = vector_type::get_pointer(input.data());

    for (auto i : grid_stride_range(output.size() / vector_type::size())) {
        // pin the call to the device math overloads, as sigmoid_vec does
        using device::pow;

        vector_type vec;
        v_load(vec, input_vPtr[i]);
        for (int j = 0; j < vector_type::size(); j++)
            vec.data[j] = pow(shift + scale * vec.data[j], exp);
        v_store(output_vPtr[i], vec);
    }
}
/**
 * Launches raw::abs_vec<T, N> using `stream`.
 *
 * Asserts that `output` and `input` pass is_fully_aligned for N, builds a
 * launch policy for output.size() / N work items via make_policy and hands
 * kernel, policy and arguments to launch_kernel.
 */
template <class T, std::size_t N>
void launch_vectorized_abs(const Stream& stream, Span<T> output, View<T> input) {
    CV_Assert(is_fully_aligned<T>(output, N));
    CV_Assert(is_fully_aligned<T>(input, N));

    auto kernel = raw::abs_vec<T, N>;
    auto policy = make_policy(kernel, output.size() / N, 0, stream);
    launch_kernel(kernel, policy, output, input);
}
/**
 * Writes the absolute value of every element of `input` to `output`.
 *
 * Asserts that both buffers have the same size, then launches the widest
 * variant (4, then 2, then 1 elements per pack) for which both buffers pass
 * is_fully_aligned. Exactly one variant is launched.
 */
template <class T>
void abs(const Stream& stream, Span<T> output, View<T> input) {
    CV_Assert(input.size() == output.size());

    if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
        launch_vectorized_abs<T, 4>(stream, output, input);
    } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
        launch_vectorized_abs<T, 2>(stream, output, input);
    } else {
        launch_vectorized_abs<T, 1>(stream, output, input);
    }
}

template void abs<__half>(const Stream& stream, Span<__half> output, View<__half> input);
template void abs<float>(const Stream& stream, Span<float> output, View<float> input);
/**
 * Launches raw::tanh_vec<T, N> using `stream`.
 *
 * Asserts that `output` and `input` pass is_fully_aligned for N, builds a
 * launch policy for output.size() / N work items via make_policy and hands
 * kernel, policy and arguments to launch_kernel.
 */
template <class T, std::size_t N>
void launch_vectorized_tanh(const Stream& stream, Span<T> output, View<T> input) {
    CV_Assert(is_fully_aligned<T>(output, N));
    CV_Assert(is_fully_aligned<T>(input, N));

    auto kernel = raw::tanh_vec<T, N>;
    auto policy = make_policy(kernel, output.size() / N, 0, stream);
    launch_kernel(kernel, policy, output, input);
}
/**
 * Writes tanh of every element of `input` to `output`.
 *
 * Asserts that both buffers have the same size, then launches the widest
 * variant (4, then 2, then 1 elements per pack) for which both buffers pass
 * is_fully_aligned. Exactly one variant is launched.
 */
template <class T>
void tanh(const Stream& stream, Span<T> output, View<T> input) {
    CV_Assert(input.size() == output.size());

    if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
        launch_vectorized_tanh<T, 4>(stream, output, input);
    } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
        launch_vectorized_tanh<T, 2>(stream, output, input);
    } else {
        launch_vectorized_tanh<T, 1>(stream, output, input);
    }
}

template void tanh<__half>(const Stream&, Span<__half>, View<__half>);
template void tanh<float>(const Stream&, Span<float>, View<float>);
/**
 * Launches raw::sigmoid_vec<T, N> using `stream`.
 *
 * Asserts that `output` and `input` pass is_fully_aligned for N, builds a
 * launch policy for output.size() / N work items via make_policy and hands
 * kernel, policy and arguments to launch_kernel.
 */
template <class T, std::size_t N>
void launch_vectorized_sigmoid(const Stream& stream, Span<T> output, View<T> input) {
    CV_Assert(is_fully_aligned<T>(output, N));
    CV_Assert(is_fully_aligned<T>(input, N));

    auto kernel = raw::sigmoid_vec<T, N>;
    auto policy = make_policy(kernel, output.size() / N, 0, stream);
    launch_kernel(kernel, policy, output, input);
}
/**
 * Writes the sigmoid of every element of `input` to `output`.
 *
 * Asserts that both buffers have the same size, then launches the widest
 * variant (4, then 2, then 1 elements per pack) for which both buffers pass
 * is_fully_aligned. Exactly one variant is launched.
 */
template <class T>
void sigmoid(const Stream& stream, Span<T> output, View<T> input) {
    CV_Assert(input.size() == output.size());

    if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
        launch_vectorized_sigmoid<T, 4>(stream, output, input);
    } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
        launch_vectorized_sigmoid<T, 2>(stream, output, input);
    } else {
        launch_vectorized_sigmoid<T, 1>(stream, output, input);
    }
}

template void sigmoid<__half>(const Stream&, Span<__half>, View<__half>);
template void sigmoid<float>(const Stream&, Span<float>, View<float>);
/**
 * Launches raw::bnll_vec<T, N> using `stream`.
 *
 * Asserts that `output` and `input` pass is_fully_aligned for N, builds a
 * launch policy for output.size() / N work items via make_policy and hands
 * kernel, policy and arguments to launch_kernel.
 */
template <class T, std::size_t N>
void launch_vectorized_bnll(const Stream& stream, Span<T> output, View<T> input) {
    CV_Assert(is_fully_aligned<T>(output, N));
    CV_Assert(is_fully_aligned<T>(input, N));

    auto kernel = raw::bnll_vec<T, N>;
    auto policy = make_policy(kernel, output.size() / N, 0, stream);
    launch_kernel(kernel, policy, output, input);
}
/**
 * Applies the BNLL activation to every element of `input`, writing to `output`.
 *
 * Asserts that both buffers have the same size, then launches the widest
 * variant (4, then 2, then 1 elements per pack) for which both buffers pass
 * is_fully_aligned. Exactly one variant is launched.
 */
template <class T>
void bnll(const Stream& stream, Span<T> output, View<T> input) {
    CV_Assert(input.size() == output.size());

    if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
        launch_vectorized_bnll<T, 4>(stream, output, input);
    } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
        launch_vectorized_bnll<T, 2>(stream, output, input);
    } else {
        launch_vectorized_bnll<T, 1>(stream, output, input);
    }
}

template void bnll<__half>(const Stream&, Span<__half>, View<__half>);
template void bnll<float>(const Stream&, Span<float>, View<float>);
/**
 * Launches raw::elu_vec<T, N> using `stream`.
 *
 * Asserts that `output` and `input` pass is_fully_aligned for N, builds a
 * launch policy for output.size() / N work items via make_policy and hands
 * kernel, policy and arguments to launch_kernel.
 */
template <class T, std::size_t N>
void launch_vectorized_elu(const Stream& stream, Span<T> output, View<T> input) {
    CV_Assert(is_fully_aligned<T>(output, N));
    CV_Assert(is_fully_aligned<T>(input, N));

    auto kernel = raw::elu_vec<T, N>;
    auto policy = make_policy(kernel, output.size() / N, 0, stream);
    launch_kernel(kernel, policy, output, input);
}
/**
 * Applies the ELU activation to every element of `input`, writing to `output`.
 *
 * Asserts that both buffers have the same size, then launches the widest
 * variant (4, then 2, then 1 elements per pack) for which both buffers pass
 * is_fully_aligned. Exactly one variant is launched.
 */
template <class T>
void elu(const Stream& stream, Span<T> output, View<T> input) {
    CV_Assert(input.size() == output.size());

    if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
        launch_vectorized_elu<T, 4>(stream, output, input);
    } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
        launch_vectorized_elu<T, 2>(stream, output, input);
    } else {
        launch_vectorized_elu<T, 1>(stream, output, input);
    }
}

template void elu<__half>(const Stream&, Span<__half>, View<__half>);
template void elu<float>(const Stream&, Span<float>, View<float>);
/**
 * Launches raw::relu_vec<T, N> using `stream`.
 *
 * Asserts that `output` and `input` pass is_fully_aligned for N, builds a
 * launch policy for output.size() / N work items via make_policy and hands
 * kernel, policy and arguments (including `slope`) to launch_kernel.
 */
template <class T, std::size_t N>
void launch_vectorized_relu(const Stream& stream, Span<T> output, View<T> input, T slope) {
    CV_Assert(is_fully_aligned<T>(output, N));
    CV_Assert(is_fully_aligned<T>(input, N));

    auto kernel = raw::relu_vec<T, N>;
    auto policy = make_policy(kernel, output.size() / N, 0, stream);
    launch_kernel(kernel, policy, output, input, slope);
}
/**
 * Applies leaky ReLU with negative-side factor `slope` to every element of
 * `input`, writing to `output`.
 *
 * Asserts that both buffers have the same size, then launches the widest
 * variant (4, then 2, then 1 elements per pack) for which both buffers pass
 * is_fully_aligned. Exactly one variant is launched.
 */
template <class T>
void relu(const Stream& stream, Span<T> output, View<T> input, T slope) {
    CV_Assert(input.size() == output.size());

    if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
        launch_vectorized_relu<T, 4>(stream, output, input, slope);
    } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
        launch_vectorized_relu<T, 2>(stream, output, input, slope);
    } else {
        launch_vectorized_relu<T, 1>(stream, output, input, slope);
    }
}

template void relu<__half>(const Stream&, Span<__half>, View<__half>, __half);
template void relu<float>(const Stream&, Span<float>, View<float>, float);
/**
 * Launches raw::clipped_relu_vec<T, N> using `stream`.
 *
 * Asserts that `output` and `input` pass is_fully_aligned for N, builds a
 * launch policy for output.size() / N work items via make_policy and hands
 * kernel, policy and arguments (including the clamp bounds) to launch_kernel.
 */
template <class T, std::size_t N>
void launch_vectorized_clipped_relu(const Stream& stream, Span<T> output, View<T> input, T floor, T ceiling) {
    CV_Assert(is_fully_aligned<T>(output, N));
    CV_Assert(is_fully_aligned<T>(input, N));

    auto kernel = raw::clipped_relu_vec<T, N>;
    auto policy = make_policy(kernel, output.size() / N, 0, stream);
    launch_kernel(kernel, policy, output, input, floor, ceiling);
}
/**
 * Clamps every element of `input` to [floor, ceiling], writing to `output`.
 *
 * Asserts that both buffers have the same size and that floor <= ceiling
 * (compared as double so the check also works for __half), then launches the
 * widest variant (4, then 2, then 1 elements per pack) for which both buffers
 * pass is_fully_aligned. Exactly one variant is launched.
 */
template <class T>
void clipped_relu(const Stream& stream, Span<T> output, View<T> input, T floor, T ceiling) {
    CV_Assert(input.size() == output.size());
    CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceiling));

    if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
        launch_vectorized_clipped_relu<T, 4>(stream, output, input, floor, ceiling);
    } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
        launch_vectorized_clipped_relu<T, 2>(stream, output, input, floor, ceiling);
    } else {
        launch_vectorized_clipped_relu<T, 1>(stream, output, input, floor, ceiling);
    }
}

template void clipped_relu<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
template void clipped_relu<float>(const Stream&, Span<float>, View<float>, float, float);
/**
 * Launches raw::axiswise_relu_vec<T, N> using `stream`.
 *
 * Asserts that `output` and `input` pass is_fully_aligned for N and that
 * inner_size is a multiple of N (the kernel divides inner_size by N), builds a
 * launch policy for output.size() / N work items via make_policy and hands
 * kernel, policy and arguments to launch_kernel.
 *
 * NOTE(review): inner_size == 0 passes the modulo assertion; the kernel then
 * divides by zero if output is non-empty — callers are presumed to pass a
 * positive inner_size.
 */
template <class T, std::size_t N>
void launch_vectorized_axiswise_relu(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> slope) {
    CV_Assert(is_fully_aligned<T>(output, N));
    CV_Assert(is_fully_aligned<T>(input, N));
    CV_Assert(inner_size % N == 0);

    auto kernel = raw::axiswise_relu_vec<T, N>;
    auto policy = make_policy(kernel, output.size() / N, 0, stream);
    launch_kernel(kernel, policy, output, input, inner_size, slope);
}
/**
 * Applies leaky ReLU with a slope chosen per slice: element i uses
 * slope[(i / inner_size) % slope.size()].
 *
 * Asserts that both buffers have the same size, then launches the widest
 * variant (4, then 2, then 1 elements per pack) for which both buffers pass
 * is_fully_aligned AND inner_size is a multiple of the pack width, so a pack
 * never straddles two slices. Exactly one variant is launched.
 */
template <class T>
void axiswise_relu(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> slope) {
    CV_Assert(input.size() == output.size());

    if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4) && inner_size % 4 == 0) {
        launch_vectorized_axiswise_relu<T, 4>(stream, output, input, inner_size, slope);
    } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2) && inner_size % 2 == 0) {
        launch_vectorized_axiswise_relu<T, 2>(stream, output, input, inner_size, slope);
    } else {
        launch_vectorized_axiswise_relu<T, 1>(stream, output, input, inner_size, slope);
    }
}

template void axiswise_relu<__half>(const Stream&, Span<__half>, View<__half>, std::size_t, View<__half>);
template void axiswise_relu<float>(const Stream&, Span<float>, View<float>, std::size_t, View<float>);
/**
 * Launches raw::power_vec<T, N> using `stream`.
 *
 * Asserts that `output` and `input` pass is_fully_aligned for N, builds a
 * launch policy for output.size() / N work items via make_policy and hands
 * kernel, policy and arguments (exp, scale, shift) to launch_kernel.
 */
template <class T, std::size_t N>
void launch_vectorized_power(const Stream& stream, Span<T> output, View<T> input, T exp, T scale, T shift) {
    CV_Assert(is_fully_aligned<T>(output, N));
    CV_Assert(is_fully_aligned<T>(input, N));

    auto kernel = raw::power_vec<T, N>;
    auto policy = make_policy(kernel, output.size() / N, 0, stream);
    launch_kernel(kernel, policy, output, input, exp, scale, shift);
}
/**
 * Computes output[i] = pow(shift + scale * input[i], exp).
 *
 * Asserts that both buffers have the same size. When exp == 1 the work is
 * delegated to scale1_with_bias1 and the function returns without launching
 * the generic power kernel. Otherwise the widest variant (4, then 2, then 1
 * elements per pack) for which both buffers pass is_fully_aligned is launched;
 * an empty output (size 0) always takes the N = 1 path because of the
 * `&& output.size()` terms.
 */
template <class T>
void power(const Stream& stream, Span<T> output, View<T> input, T exp, T scale, T shift) {
    CV_Assert(input.size() == output.size());

    if (static_cast<float>(exp) == 1.0f) {
        scale1_with_bias1(stream, output, input, scale, shift);
        // the affine fast path already produced the result; skip the pow kernel
        return;
    }

    if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4) && output.size()) {
        launch_vectorized_power<T, 4>(stream, output, input, exp, scale, shift);
    } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2) && output.size()) {
        launch_vectorized_power<T, 2>(stream, output, input, exp, scale, shift);
    } else {
        launch_vectorized_power<T, 1>(stream, output, input, exp, scale, shift);
    }
}

template void power<__half>(const Stream&, Span<__half>, View<__half>, __half, __half, __half);
template void power<float>(const Stream&, Span<float>, View<float>, float, float, float);
432 }}}} /* namespace cv::dnn::cuda4dnn::kernels */