Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low
[platform/upstream/opencv.git] / modules / dnn / src / cuda / normalize.cu
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 #include <cuda_runtime.h>
6 #include <cuda_fp16.h>
7
8 #include "array.hpp"
9 #include "math.hpp"
10 #include "types.hpp"
11 #include "atomics.hpp"
12 #include "grid_stride_range.hpp"
13 #include "execution.hpp"
14
15 #include "../cuda4dnn/csl/stream.hpp"
16 #include "../cuda4dnn/csl/span.hpp"
17
18 #include "../cuda4dnn/kernels/fill.hpp"
19 #include "../cuda4dnn/kernels/scale_shift.hpp"
20
21 #include <opencv2/core.hpp>
22
23 #include <cstddef>
24
25 using namespace cv::dnn::cuda4dnn::csl;
26 using namespace cv::dnn::cuda4dnn::csl::device;
27
28 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
29
30     namespace raw {
31         template <class T>
32         __global__ void reduce_sum_abs(Span<T> output, View<T> input, size_type outer_stride, size_type mid_stride) {
33             for (auto idx : grid_stride_range(input.size())) {
34                 const index_type outer_idx = idx / outer_stride;
35                 const index_type inner_idx = idx % mid_stride;
36
37                 const index_type sum_idx = outer_idx * mid_stride + inner_idx;
38                 atomicAdd(&output[sum_idx], device::abs(input[idx]));
39             }
40         }
41
42         template <class T>
43         __global__ void reciprocal(Span<T> output, T epsilon) {
44             for (auto idx : grid_stride_range(output.size()))
45                 output[idx] = T(1) / (output[idx] + epsilon);
46         }
47
48         template <class T>
49         __global__ void reduce_sum_squared(Span<T> output, View<T> input, size_type outer_stride, size_type mid_stride) {
50            for (auto idx : grid_stride_range(input.size())) {
51                 const index_type outer_idx = idx / outer_stride;
52                 const index_type inner_idx = idx % mid_stride;
53
54                 const index_type sum_idx = outer_idx * mid_stride + inner_idx;
55                 atomicAdd(&output[sum_idx], input[idx] * input[idx]);
56            }
57         }
58
59         template <class T>
60         __global__ void rsqrt(Span<T> output, T epsilon) {
61             for (auto idx : grid_stride_range(output.size())) {
62                 using device::sqrt;
63                 output[idx] = T(1) / sqrt(output[idx] + epsilon);
64             }
65         }
66
67         template <class T>
68         __global__ void apply_norm(Span<T> output, View<T> input, size_type outer_stride, size_type mid_stride, View<T> sums) {
69             for (auto idx : grid_stride_range(output.size())) {
70                 const index_type outer_idx = idx / outer_stride;
71                 const index_type inner_idx = idx % mid_stride;
72
73                 const index_type sum_idx = outer_idx * mid_stride + inner_idx;
74                 output[idx] = input[idx] * sums[sum_idx];
75             }
76         }
77     }
78
79     template <class T>
80     void normalize(
81         const Stream& stream,
82         Span<T> output,
83         View<T> input, std::size_t outer_size, std::size_t mid_size, std::size_t inner_size, std::size_t norm, T epsilon,
84         Span<T> workspace)
85     {
86         CV_Assert(output.size() == input.size());
87         CV_Assert(output.size() == outer_size * mid_size * inner_size);
88         CV_Assert(norm == 1 || norm == 2);
89         CV_Assert(workspace.size() >= outer_size * inner_size);
90
91         auto sums = Span<T>(workspace.data(), outer_size * inner_size);
92
93         fill<T>(stream, sums, 0.0);
94
95         if (norm == 1) {
96             auto reduce_kernel = raw::reduce_sum_abs<T>;
97             auto policy = make_policy(reduce_kernel, input.size(), 0, stream);
98             launch_kernel(reduce_kernel, policy, sums, input, mid_size * inner_size, inner_size);
99
100             auto reciprocal_kernel = raw::reciprocal<T>;
101             policy = make_policy(reciprocal_kernel, sums.size(), 0, stream);
102             launch_kernel(reciprocal_kernel, policy, sums, epsilon);
103         } else {
104             auto reduce_kernel = raw::reduce_sum_squared<T>;
105             auto policy = make_policy(reduce_kernel, input.size(), 0, stream);
106             launch_kernel(reduce_kernel, policy, sums, input, mid_size * inner_size, inner_size);
107
108             auto rsqrt_kernel = raw::rsqrt<T>;
109             policy = make_policy(rsqrt_kernel, sums.size(), 0, stream);
110             launch_kernel(rsqrt_kernel, policy, sums, epsilon);
111         }
112
113         auto scale_kernel = raw::apply_norm<T>;
114         auto policy = make_policy(scale_kernel, output.size(), 0, stream);
115         launch_kernel(scale_kernel, policy, output, input, mid_size * inner_size, inner_size, sums);
116     }
117
118     template void normalize(const Stream&, Span<__half>, View<__half>, std::size_t, std::size_t, std::size_t, std::size_t, __half, Span<__half>);
119     template void normalize(const Stream&, Span<float>, View<float>, std::size_t, std::size_t, std::size_t, std::size_t, float, Span<float>);
120
121 }}}} /* namespace cv::dnn::cuda4dnn::kernels */