Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low
[platform/upstream/opencv.git] / modules / dnn / src / cuda / activations.cu
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 #include <cuda_runtime.h>
6 #include <cuda_fp16.h>
7
8 #include "math.hpp"
9 #include "types.hpp"
10 #include "vector_traits.hpp"
11 #include "grid_stride_range.hpp"
12 #include "execution.hpp"
13
14 #include "../cuda4dnn/csl/stream.hpp"
15 #include "../cuda4dnn/csl/span.hpp"
16
17 #include "../cuda4dnn/kernels/scale_shift.hpp"
18
19 #include <opencv2/core.hpp>
20
21 #include <cstddef>
22
23 using namespace cv::dnn::cuda4dnn::csl;
24 using namespace cv::dnn::cuda4dnn::csl::device;
25
26 namespace cv { namespace dnn { namespace cuda4dnn  { namespace kernels {
27
28     namespace raw {
29         template <class T, std::size_t N>
30         __global__ void abs_vec(Span<T> output, View<T> input) {
31             using vector_type = get_vector_type_t<T, N>;
32
33             auto output_vPtr = vector_type::get_pointer(output.data());
34             auto input_vPtr = vector_type::get_pointer(input.data());
35
36             for (auto i : grid_stride_range(output.size() / vector_type::size())) {
37                 vector_type vec;
38                 v_load(vec, input_vPtr[i]);
39                 for (int j = 0; j < vector_type::size(); j++) {
40                     using device::abs;
41                     vec.data[j] = abs(vec.data[j]);
42                 }
43                 v_store(output_vPtr[i], vec);
44             }
45         }
46
47         template <class T, std::size_t N>
48         __global__ void tanh_vec(Span<T> output, View<T> input) {
49             using vector_type = get_vector_type_t<T, N>;
50
51             auto output_vPtr = vector_type::get_pointer(output.data());
52             auto input_vPtr = vector_type::get_pointer(input.data());
53
54             for (auto i : grid_stride_range(output.size() / vector_type::size())) {
55                 vector_type vec;
56                 v_load(vec, input_vPtr[i]);
57                 for (int j = 0; j < vector_type::size(); j++) {
58                     using device::tanh;
59                     vec.data[j] = tanh(vec.data[j]);
60                 }
61                 v_store(output_vPtr[i], vec);
62             }
63         }
64
65         template <class T, std::size_t N>
66         __global__ void sigmoid_vec(Span<T> output, View<T> input) {
67             using vector_type = get_vector_type_t<T, N>;
68
69             auto output_vPtr = vector_type::get_pointer(output.data());
70             auto input_vPtr = vector_type::get_pointer(input.data());
71
72             for (auto i : grid_stride_range(output.size() / vector_type::size())) {
73                 vector_type vec;
74                 v_load(vec, input_vPtr[i]);
75                 for (int j = 0; j < vector_type::size(); j++) {
76                     using device::sigmoid;
77                     vec.data[j] = sigmoid(vec.data[j]);
78                 }
79                 v_store(output_vPtr[i], vec);
80             }
81         }
82
83         template <class T, std::size_t N>
84         __global__ void bnll_vec(Span<T> output, View<T> input) {
85             using vector_type = get_vector_type_t<T, N>;
86
87             auto output_vPtr = vector_type::get_pointer(output.data());
88             auto input_vPtr = vector_type::get_pointer(input.data());
89
90             for (auto i : grid_stride_range(output.size() / vector_type::size())) {
91                 vector_type vec;
92                 v_load(vec, input_vPtr[i]);
93                 for (int j = 0; j < vector_type::size(); j++) {
94                     using device::log1pexp;
95                     vec.data[j] = vec.data[j] > T(0) ? vec.data[j] + log1pexp(-vec.data[j]) : log1pexp(vec.data[j]);
96                 }
97                 v_store(output_vPtr[i], vec);
98             }
99         }
100
101         template <class T, std::size_t N>
102         __global__ void elu_vec(Span<T> output, View<T> input) {
103             using vector_type = get_vector_type_t<T, N>;
104
105             auto output_vPtr = vector_type::get_pointer(output.data());
106             auto input_vPtr = vector_type::get_pointer(input.data());
107
108             for (auto i : grid_stride_range(output.size() / vector_type::size())) {
109                 vector_type vec;
110                 v_load(vec, input_vPtr[i]);
111                 for (int j = 0; j < vector_type::size(); j++) {
112                     using device::expm1;
113                     vec.data[j] = vec.data[j] >= T(0) ? vec.data[j] : expm1(vec.data[j]);
114                 }
115                 v_store(output_vPtr[i], vec);
116             }
117         }
118
119         template <class T, std::size_t N>
120         __global__ void relu_vec(Span<T> output, View<T> input, T slope) {
121             using vector_type = get_vector_type_t<T, N>;
122
123             auto output_vPtr = vector_type::get_pointer(output.data());
124             auto input_vPtr = vector_type::get_pointer(input.data());
125
126             for (auto i : grid_stride_range(output.size() / vector_type::size())) {
127                 vector_type vec;
128                 v_load(vec, input_vPtr[i]);
129                 for(int j = 0; j < vector_type::size(); j++)
130                     vec.data[j] = vec.data[j] >= T(0) ? vec.data[j] : slope * vec.data[j];
131                 v_store(output_vPtr[i], vec);
132             }
133         }
134
135         template <class T, std::size_t N>
136         __global__ void clipped_relu_vec(Span<T> output, View<T> input, T floor, T ceiling) {
137             using vector_type = get_vector_type_t<T, N>;
138
139             auto output_vPtr = vector_type::get_pointer(output.data());
140             auto input_vPtr = vector_type::get_pointer(input.data());
141
142             for (auto i : grid_stride_range(output.size() / vector_type::size())) {
143                 using device::clamp;
144
145                 vector_type vec;
146                 v_load(vec, input_vPtr[i]);
147                 for (int j = 0; j < vector_type::size(); j++)
148                     vec.data[j] = clamp(vec.data[j], floor, ceiling);
149                 v_store(output_vPtr[i], vec);
150             }
151         }
152
153         template <class T, std::size_t N>
154         __global__ void axiswise_relu_vec(Span<T> output, View<T> input, size_type inner_size, View<T> slope) {
155             using vector_type = get_vector_type_t<T, N>;
156
157             auto output_vPtr = vector_type::get_pointer(output.data());
158             auto input_vPtr = vector_type::get_pointer(input.data());
159
160             inner_size /= vector_type::size();
161             for (auto i : grid_stride_range(output.size() / vector_type::size())) {
162                 const index_type c = (i / inner_size) % static_cast<size_type>(slope.size());
163
164                 vector_type vec;
165                 v_load(vec, input_vPtr[i]);
166                 for (int j = 0; j < vector_type::size(); j++)
167                     vec.data[j] = vec.data[j] > T(0) ? vec.data[j] : vec.data[j] * slope[c];
168                 v_store(output_vPtr[i], vec);
169             }
170         }
171
172         template <class T, std::size_t N>
173         __global__ void power_vec(Span<T> output, View<T> input, T exp, T scale, T shift) {
174             using vector_type = get_vector_type_t<T, N>;
175
176             auto output_vPtr = vector_type::get_pointer(output.data());
177             auto input_vPtr = vector_type::get_pointer(input.data());
178
179             for (auto i : grid_stride_range(output.size() / vector_type::size())) {
180                 using device::pow;
181
182                 vector_type vec;
183                 v_load(vec, input_vPtr[i]);
184                 for (int j = 0; j < vector_type::size(); j++)
185                     vec.data[j] = pow(shift + scale * vec.data[j], exp);
186                 v_store(output_vPtr[i], vec);
187             }
188         }
189     }
190
191     template <class T, std::size_t N>
192     void launch_vectorized_abs(const Stream& stream, Span<T> output, View<T> input) {
193         CV_Assert(is_fully_aligned<T>(output, N));
194         CV_Assert(is_fully_aligned<T>(input, N));
195
196         auto kernel = raw::abs_vec<T, N>;
197         auto policy = make_policy(kernel, output.size() / N, 0, stream);
198         launch_kernel(kernel, policy, output, input);
199     }
200
201     template <class T>
202     void abs(const Stream& stream, Span<T> output, View<T> input) {
203         CV_Assert(input.size() == output.size());
204
205         if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
206             launch_vectorized_abs<T, 4>(stream, output, input);
207         } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
208             launch_vectorized_abs<T, 2>(stream, output, input);
209         } else {
210             launch_vectorized_abs<T, 1>(stream, output, input);
211         }
212     }
213
214     template void abs<__half>(const Stream& stream, Span<__half> output, View<__half> input);
215     template void abs<float>(const Stream& stream, Span<float> output, View<float> input);
216
217     template <class T, std::size_t N>
218     void launch_vectorized_tanh(const Stream& stream, Span<T> output, View<T> input) {
219         CV_Assert(is_fully_aligned<T>(output, N));
220         CV_Assert(is_fully_aligned<T>(input, N));
221
222         auto kernel = raw::tanh_vec<T, N>;
223         auto policy = make_policy(kernel, output.size() / N, 0, stream);
224         launch_kernel(kernel, policy, output, input);
225     }
226
227     template <class T>
228     void tanh(const Stream& stream, Span<T> output, View<T> input) {
229         CV_Assert(input.size() == output.size());
230
231         if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
232             launch_vectorized_tanh<T, 4>(stream, output, input);
233         } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
234             launch_vectorized_tanh<T, 2>(stream, output, input);
235         } else {
236             launch_vectorized_tanh<T, 1>(stream, output, input);
237         }
238     }
239
240     template void tanh<__half>(const Stream&, Span<__half>, View<__half>);
241     template void tanh<float>(const Stream&, Span<float>, View<float>);
242
243     template <class T, std::size_t N>
244     void launch_vectorized_sigmoid(const Stream& stream, Span<T> output, View<T> input) {
245         CV_Assert(is_fully_aligned<T>(output, N));
246         CV_Assert(is_fully_aligned<T>(input, N));
247
248         auto kernel = raw::sigmoid_vec<T, N>;
249         auto policy = make_policy(kernel, output.size() / N, 0, stream);
250         launch_kernel(kernel, policy, output, input);
251     }
252
253     template <class T>
254     void sigmoid(const Stream& stream, Span<T> output, View<T> input) {
255         CV_Assert(input.size() == output.size());
256
257         if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
258             launch_vectorized_sigmoid<T, 4>(stream, output, input);
259         } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
260             launch_vectorized_sigmoid<T, 2>(stream, output, input);
261         } else {
262             launch_vectorized_sigmoid<T, 1>(stream, output, input);
263         }
264     }
265
266     template void sigmoid<__half>(const Stream&, Span<__half>, View<__half>);
267     template void sigmoid<float>(const Stream&, Span<float>, View<float>);
268
269     template <class T, std::size_t N>
270     void launch_vectorized_bnll(const Stream& stream, Span<T> output, View<T> input) {
271         CV_Assert(is_fully_aligned<T>(output, N));
272         CV_Assert(is_fully_aligned<T>(input, N));
273
274         auto kernel = raw::bnll_vec<T, N>;
275         auto policy = make_policy(kernel, output.size() / N, 0, stream);
276         launch_kernel(kernel, policy, output, input);
277     }
278
279     template <class T>
280     void bnll(const Stream& stream, Span<T> output, View<T> input) {
281         CV_Assert(input.size() == output.size());
282
283         if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
284             launch_vectorized_bnll<T, 4>(stream, output, input);
285         } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
286             launch_vectorized_bnll<T, 2>(stream, output, input);
287         } else {
288             launch_vectorized_bnll<T, 1>(stream, output, input);
289         }
290     }
291
292     template void bnll<__half>(const Stream&, Span<__half>, View<__half>);
293     template void bnll<float>(const Stream&, Span<float>, View<float>);
294
295     template <class T, std::size_t N>
296     void launch_vectorized_elu(const Stream& stream, Span<T> output, View<T> input) {
297         CV_Assert(is_fully_aligned<T>(output, N));
298         CV_Assert(is_fully_aligned<T>(input, N));
299
300         auto kernel = raw::elu_vec<T, N>;
301         auto policy = make_policy(kernel, output.size() / N, 0, stream);
302         launch_kernel(kernel, policy, output, input);
303     }
304
305     template <class T>
306     void elu(const Stream& stream, Span<T> output, View<T> input) {
307         CV_Assert(input.size() == output.size());
308
309         if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
310             launch_vectorized_elu<T, 4>(stream, output, input);
311         } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
312             launch_vectorized_elu<T, 2>(stream, output, input);
313         } else {
314             launch_vectorized_elu<T, 1>(stream, output, input);
315         }
316     }
317
318     template void elu<__half>(const Stream&, Span<__half>, View<__half>);
319     template void elu<float>(const Stream&, Span<float>, View<float>);
320
321     template <class T, std::size_t N>
322     void launch_vectorized_relu(const Stream& stream, Span<T> output, View<T> input, T slope) {
323         CV_Assert(is_fully_aligned<T>(output, N));
324         CV_Assert(is_fully_aligned<T>(input, N));
325
326         auto kernel = raw::relu_vec<T, N>;
327         auto policy = make_policy(kernel, output.size() / N, 0, stream);
328         launch_kernel(kernel, policy, output, input, slope);
329     }
330
331     template <class T>
332     void relu(const Stream& stream, Span<T> output, View<T> input, T slope) {
333         CV_Assert(input.size() == output.size());
334
335         if(is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
336             launch_vectorized_relu<T, 4>(stream, output, input, slope);
337         } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
338             launch_vectorized_relu<T, 2>(stream, output, input, slope);
339         } else {
340             launch_vectorized_relu<T, 1>(stream, output, input, slope);
341         }
342     }
343
344     template void relu<__half>(const Stream&, Span<__half>, View<__half>, __half);
345     template void relu<float>(const Stream&, Span<float>, View<float>, float);
346
347     template <class T, std::size_t N>
348     void launch_vectorized_clipped_relu(const Stream& stream, Span<T> output, View<T> input, T floor, T ceiling) {
349         CV_Assert(is_fully_aligned<T>(output, N));
350         CV_Assert(is_fully_aligned<T>(input, N));
351
352         auto kernel = raw::clipped_relu_vec<T, N>;
353         auto policy = make_policy(kernel, output.size() / N, 0, stream);
354         launch_kernel(kernel, policy, output, input, floor, ceiling);
355     }
356
357     template <class T>
358     void clipped_relu(const Stream& stream, Span<T> output, View<T> input, T floor, T ceiling) {
359         CV_Assert(input.size() == output.size());
360         CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceiling));
361
362         if(is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
363             launch_vectorized_clipped_relu<T, 4>(stream, output, input, floor, ceiling);
364         } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
365             launch_vectorized_clipped_relu<T, 2>(stream, output, input, floor, ceiling);
366         } else {
367             launch_vectorized_clipped_relu<T, 1>(stream, output, input, floor, ceiling);
368         }
369     }
370
371     template void clipped_relu<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
372     template void clipped_relu<float>(const Stream&, Span<float>, View<float>, float, float);
373
374     template <class T, std::size_t N>
375     void launch_vectorized_axiswise_relu(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> slope) {
376         CV_Assert(is_fully_aligned<T>(output, N));
377         CV_Assert(is_fully_aligned<T>(input, N));
378         CV_Assert(inner_size % N == 0);
379
380         auto kernel = raw::axiswise_relu_vec<T, N>;
381         auto policy = make_policy(kernel, output.size() / N, 0, stream);
382         launch_kernel(kernel, policy, output, input, inner_size, slope);
383     }
384
385     template <class T>
386     void axiswise_relu(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> slope) {
387         CV_Assert(input.size() == output.size());
388
389         if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4) && inner_size % 4 == 0) {
390             launch_vectorized_axiswise_relu<T, 4>(stream, output, input, inner_size, slope);
391         } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2) && inner_size % 2 == 0) {
392             launch_vectorized_axiswise_relu<T, 2>(stream, output, input, inner_size, slope);
393         } else {
394             launch_vectorized_axiswise_relu<T, 1>(stream, output, input, inner_size, slope);
395         }
396     }
397
398     template void axiswise_relu<__half>(const Stream&, Span<__half>, View<__half>, std::size_t, View<__half>);
399     template void axiswise_relu<float>(const Stream&, Span<float>, View<float>, std::size_t, View<float>);
400
401     template <class T, std::size_t N>
402     void launch_vectorized_power(const Stream& stream, Span<T> output, View<T> input, T exp, T scale, T shift) {
403         CV_Assert(is_fully_aligned<T>(output, N));
404         CV_Assert(is_fully_aligned<T>(input, N));
405
406         auto kernel = raw::power_vec<T, N>;
407         auto policy = make_policy(kernel, output.size() / N, 0, stream);
408         launch_kernel(kernel, policy, output, input, exp, scale, shift);
409     }
410
411     template <class T>
412     void power(const Stream& stream, Span<T> output, View<T> input, T exp, T scale, T shift) {
413         CV_Assert(input.size() == output.size());
414
415         if (static_cast<float>(exp) == 1.0f) {
416             scale1_with_bias1(stream, output, input, scale, shift);
417             return;
418         }
419
420         if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4) && output.size()) {
421             launch_vectorized_power<T, 4>(stream, output, input, exp, scale, shift);
422         } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2) && output.size()) {
423             launch_vectorized_power<T, 2>(stream, output, input, exp, scale, shift);
424         } else {
425             launch_vectorized_power<T, 1>(stream, output, input, exp, scale, shift);
426         }
427     }
428
429     template void power<__half>(const Stream&, Span<__half>, View<__half>, __half, __half, __half);
430     template void power<float>(const Stream&, Span<float>, View<float>, float, float, float);
431
432 }}}} /* namespace cv::dnn::cuda4dnn::kernels */