Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low
[platform/upstream/opencv.git] / modules / dnn / src / cuda / scale_shift.cu
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 #include <cuda_runtime.h>
6 #include <cuda_fp16.h>
7
8 #include "types.hpp"
9 #include "vector_traits.hpp"
10 #include "grid_stride_range.hpp"
11 #include "execution.hpp"
12
13 #include "../cuda4dnn/csl/stream.hpp"
14 #include "../cuda4dnn/csl/tensor.hpp"
15 #include "../cuda4dnn/csl/span.hpp"
16
17 #include <opencv2/core.hpp>
18
19 #include <cstddef>
20
21 using namespace cv::dnn::cuda4dnn::csl;
22 using namespace cv::dnn::cuda4dnn::csl::device;
23
24 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
25
26     namespace raw {
27         template <class T, std::size_t N>
28         __global__ void bias1_vec(Span<T> output, View<T> input, T beta) {
29             using vector_type = get_vector_type_t<T, N>;
30
31             auto output_vPtr = vector_type::get_pointer(output.data());
32             auto input_vPtr = vector_type::get_pointer(input.data());
33
34             for (auto i : grid_stride_range(output.size() / vector_type::size())) {
35                 vector_type vec;
36                 v_load(vec, input_vPtr[i]);
37                 for (int j = 0; j < vec.size(); j++)
38                     vec.data[j] = vec.data[j] + beta;
39                 v_store(output_vPtr[i], vec);
40             }
41         }
42
43         template <class T, std::size_t N>
44         __global__ void biasN_vec(Span<T> output, View<T> input, size_type inner_size, View<T> bias) {
45             using vector_type = get_vector_type_t<T, N>;
46
47             auto output_vPtr = vector_type::get_pointer(output.data());
48             auto input_vPtr = vector_type::get_pointer(input.data());
49
50             inner_size /= vector_type::size();
51             for (auto i : grid_stride_range(output.size() / vector_type::size())) {
52                 const index_type bias_idx = (i / inner_size) % static_cast<size_type>(bias.size());
53
54                 vector_type vec;
55                 v_load(vec, input_vPtr[i]);
56                 for(int j = 0; j < vec.size(); j++)
57                     vec.data[j] = vec.data[j] + bias[bias_idx];
58                 v_store(output_vPtr[i], vec);
59             }
60         }
61
62         template <class T, std::size_t N>
63         __global__ void scale1_vec(Span<T> output, View<T> input, T alpha) {
64             using vector_type = get_vector_type_t<T, N>;
65
66             auto output_vPtr = vector_type::get_pointer(output.data());
67             auto input_vPtr = vector_type::get_pointer(input.data());
68
69             for (auto i : grid_stride_range(output.size() / vector_type::size())) {
70                 vector_type vec;
71                 v_load(vec, input_vPtr[i]);
72                 for (int j = 0; j < vec.size(); j++)
73                     vec.data[j] = vec.data[j] * alpha;
74                 v_store(output_vPtr[i], vec);
75             }
76         }
77
78         template <class T, std::size_t N>
79         __global__ void scaleN_vec(Span<T> output, View<T> input, size_type inner_size, View<T> weights)
80         {
81             using vector_type = get_vector_type_t<T, N>;
82
83             auto output_vPtr = vector_type::get_pointer(output.data());
84             auto input_vPtr = vector_type::get_pointer(input.data());
85
86             inner_size /= vector_type::size();
87             for (auto i : grid_stride_range(output.size() / vector_type::size())) {
88                 const index_type scale_idx = (i / inner_size) % static_cast<size_type>(weights.size());
89
90                 vector_type vec;
91                 v_load(vec, input_vPtr[i]);
92                 for (int j = 0; j < vec.size(); j++)
93                     vec.data[j] = vec.data[j] * weights[scale_idx];
94                 v_store(output_vPtr[i], vec);
95             }
96         }
97
98         template <class T, std::size_t N>
99         __global__ void scale1_with_bias1_vec(Span<T> output, View<T> input, T alpha, T beta)
100         {
101             using vector_type = get_vector_type_t<T, N>;
102
103             auto output_vPtr = vector_type::get_pointer(output.data());
104             auto input_vPtr = vector_type::get_pointer(input.data());
105
106             for (auto i : grid_stride_range(output.size() / vector_type::size())) {
107                 vector_type vec;
108                 v_load(vec, input_vPtr[i]);
109                 for (int j = 0; j < vec.size(); j++)
110                     vec.data[j] = alpha * vec.data[j] + beta;
111                 v_store(output_vPtr[i], vec);
112             }
113         }
114
115         template <class T, std::size_t N>
116         __global__ void scaleN_with_biasN_vec(Span<T> output, View<T> input, size_type inner_size, View<T> weights, View<T> bias)
117         {
118             using vector_type = get_vector_type_t<T, N>;
119
120             auto output_vPtr = vector_type::get_pointer(output.data());
121             auto input_vPtr = vector_type::get_pointer(input.data());
122
123             inner_size /= vector_type::size();
124             for (auto i : grid_stride_range(output.size() / vector_type::size())) {
125                 const index_type scale_idx = (i / inner_size) % static_cast<size_type>(weights.size());
126
127                 vector_type vec;
128                 v_load(vec, input_vPtr[i]);
129                 for (int j = 0; j < vec.size(); j++)
130                     vec.data[j] = vec.data[j] * weights[scale_idx] + bias[scale_idx];
131                 v_store(output_vPtr[i], vec);
132             }
133         }
134     }
135
136     template <class T, std::size_t N> static
137     void launch_bias1_vec_kernel(const Stream& stream, Span<T> output, View<T> input, T beta) {
138         CV_Assert(is_fully_aligned<T>(output, N));
139         CV_Assert(is_fully_aligned<T>(input, N));
140
141         auto kernel = raw::bias1_vec<T, N>;
142         auto policy = make_policy(kernel, output.size() / N, 0, stream);
143         launch_kernel(kernel, policy, output, input, beta);
144     }
145
146     template <class T>
147     void bias1(const Stream& stream, TensorSpan<T> output, TensorView<T> input, T beta) {
148         CV_Assert(is_shape_same(input, output));
149
150         if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
151             launch_bias1_vec_kernel<T, 4>(stream, output, input, beta);
152         } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
153             launch_bias1_vec_kernel<T, 2>(stream, output, input, beta);
154         } else {
155             launch_bias1_vec_kernel<T, 1>(stream, output, input, beta);
156         }
157     }
158
159     template void bias1<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, __half);
160     template void bias1<float>(const Stream&, TensorSpan<float>, TensorView<float>, float);
161
162     template <class T, std::size_t N> static
163     void launch_biasN_vec_kernel(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> bias){
164         CV_Assert(is_fully_aligned<T>(output, N));
165         CV_Assert(is_fully_aligned<T>(input, N));
166         CV_Assert(inner_size % N == 0);
167
168         auto kernel = raw::biasN_vec<T, N>;
169         auto policy = make_policy(kernel, output.size() / N, 0, stream);
170         launch_kernel(kernel, policy, output, input, inner_size, bias);
171     }
172
173     template <class T>
174     void biasN(
175         const Stream& stream,
176         TensorSpan<T> output,
177         TensorView<T> input, std::size_t inner_size,
178         TensorView<T> bias)
179     {
180         CV_Assert(is_shape_same(input, output));
181
182         if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4) && inner_size % 4 == 0) {
183             launch_biasN_vec_kernel<T, 4>(stream, output, input, inner_size, bias);
184         } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2) && inner_size % 2 == 0) {
185             launch_biasN_vec_kernel<T, 2>(stream, output, input, inner_size, bias);
186         } else {
187             launch_biasN_vec_kernel<T, 1>(stream, output, input, inner_size, bias);
188         }
189     }
190
191     template void biasN<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, std::size_t, TensorView<__half>);
192     template void biasN<float>(const Stream&, TensorSpan<float>, TensorView<float>, std::size_t, TensorView<float>);
193
194     template <class T, std::size_t N> static
195     void launch_scale1_vec_kernel(const Stream& stream, Span<T> output, View<T> input, T alpha) {
196         CV_Assert(is_fully_aligned<T>(output, N));
197         CV_Assert(is_fully_aligned<T>(input, N));
198
199         auto kernel = raw::scale1_vec<T, N>;
200         auto policy = make_policy(kernel, output.size() / N, 0, stream);
201         launch_kernel(kernel, policy, output, input, alpha);
202     }
203
204     template <class T>
205     void scale1(const Stream& stream, TensorSpan<T> output, TensorView<T> input, T alpha) {
206         CV_Assert(is_shape_same(input, output));
207
208         if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
209             launch_scale1_vec_kernel<T, 4>(stream, output, input, alpha);
210         } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
211             launch_scale1_vec_kernel<T, 2>(stream, output, input, alpha);
212         } else {
213             launch_scale1_vec_kernel<T, 1>(stream, output, input, alpha);
214         }
215     }
216
217     template void scale1<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, __half);
218     template void scale1<float>(const Stream&, TensorSpan<float>, TensorView<float>, float);
219
220     template <class T, std::size_t N> static
221     void launch_scaleN_vec_kernel(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> weights) {
222         CV_Assert(is_fully_aligned<T>(output, N));
223         CV_Assert(is_fully_aligned<T>(input, N));
224         CV_Assert(inner_size % N == 0);
225
226         auto kernel = raw::scaleN_vec<T, N>;
227         auto policy = make_policy(kernel, output.size() / N, 0, stream);
228         launch_kernel(kernel, policy, output, input, inner_size, weights);
229     }
230
231     template <class T>
232     void scaleN(
233         const Stream& stream,
234         TensorSpan<T> output,
235         TensorView<T> input, std::size_t inner_size,
236         TensorView<T> weights)
237     {
238         CV_Assert(is_shape_same(input, output));
239
240         if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4) && inner_size % 4 == 0) {
241             launch_scaleN_vec_kernel<T, 4>(stream, output, input, inner_size, weights);
242         } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2) && inner_size % 2 == 0) {
243             launch_scaleN_vec_kernel<T, 2>(stream, output, input, inner_size, weights);
244         } else {
245             launch_scaleN_vec_kernel<T, 1>(stream, output, input, inner_size, weights);
246         }
247     }
248
249     template void scaleN<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, std::size_t, TensorView<__half>);
250     template void scaleN<float>(const Stream&, TensorSpan<float>, TensorView<float>, std::size_t, TensorView<float>);
251
252     template <class T, std::size_t N> static
253     void launch_scale1_with_bias1_vec_kernel(const Stream& stream, Span<T> output, View<T> input, T alpha, T beta) {
254         CV_Assert(is_fully_aligned<T>(output, N));
255         CV_Assert(is_fully_aligned<T>(input, N));
256
257         auto kernel = raw::scale1_with_bias1_vec<T, N>;
258         auto policy = make_policy(kernel, output.size() / N, 0, stream);
259         launch_kernel(kernel, policy, output, input, alpha, beta);
260     }
261
262     template <class T>
263     void scale1_with_bias1(const Stream& stream, Span<T> output, View<T> input, T alpha, T beta) {
264         CV_Assert(output.size() == input.size());
265
266         if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
267             launch_scale1_with_bias1_vec_kernel<T, 4>(stream, output, input, alpha, beta);
268         } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
269             launch_scale1_with_bias1_vec_kernel<T, 2>(stream, output, input, alpha, beta);
270         } else {
271             launch_scale1_with_bias1_vec_kernel<T, 1>(stream, output, input, alpha, beta);
272         }
273     }
274
275     template void scale1_with_bias1<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
276     template void scale1_with_bias1<float>(const Stream&, Span<float>, View<float>, float, float);
277
278     template <class T, std::size_t N> static
279     void launch_scaleN_with_biasN_vec_kernel(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> weights, View<T> bias) {
280         CV_Assert(is_fully_aligned<T>(output, N));
281         CV_Assert(is_fully_aligned<T>(input, N));
282         CV_Assert(inner_size % N == 0);
283
284         auto kernel = raw::scaleN_with_biasN_vec<T, N>;
285         auto policy = make_policy(kernel, output.size() / N, 0, stream);
286         launch_kernel(kernel, policy, output, input, inner_size, weights, bias);
287     }
288
289     template <class T>
290     void scaleN_with_biasN(
291         const Stream& stream,
292         TensorSpan<T> output,
293         TensorView<T> input, std::size_t inner_size,
294         TensorView<T> weights, TensorView<T> bias)
295     {
296         CV_Assert(is_shape_same(input, output));
297         CV_Assert(weights.size() == bias.size());
298
299         if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4) && inner_size % 4 == 0) {
300             launch_scaleN_with_biasN_vec_kernel<T, 4>(stream, output, input, inner_size, weights, bias);
301         } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2) && inner_size % 2 == 0) {
302             launch_scaleN_with_biasN_vec_kernel<T, 2>(stream, output, input, inner_size, weights, bias);
303         } else {
304             launch_scaleN_with_biasN_vec_kernel<T, 1>(stream, output, input, inner_size, weights, bias);
305         }
306     }
307
308     template void scaleN_with_biasN<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, std::size_t, TensorView<__half>, TensorView<__half>);
309     template void scaleN_with_biasN<float>(const Stream&, TensorSpan<float>, TensorView<float>, std::size_t, TensorView<float>, TensorView<float>);
310
311 }}}} /* namespace cv::dnn::cuda4dnn::kernels */