Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low
[platform/upstream/opencv.git] / modules / dnn / src / cuda / concat.cu
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 #include <cuda_runtime.h>
6 #include <cuda_fp16.h>
7
8 #include "array.hpp"
9 #include "types.hpp"
10 #include "vector_traits.hpp"
11 #include "grid_stride_range.hpp"
12 #include "execution.hpp"
13 #include "kernel_dispatcher.hpp"
14
15 #include "../cuda4dnn/csl/stream.hpp"
16 #include "../cuda4dnn/csl/tensor.hpp"
17 #include "../cuda4dnn/csl/span.hpp"
18
19 #include <cstddef>
20 #include <vector>
21
22 using namespace cv::dnn::cuda4dnn::csl;
23 using namespace cv::dnn::cuda4dnn::csl::device;
24
25 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
26
27     namespace raw {
28         template <class T, std::size_t N>
29         __global__ void concat_vec(
30             Span<T> output, size_type output_axis_size, index_type output_axis_offset,
31             View<T> input, size_type input_axis_size, size_type concat_size)
32         {
33             using vector_type = get_vector_type_t<T, N>;
34
35             auto output_vPtr = vector_type::get_pointer(output.data());
36             auto input_vPtr = vector_type::get_pointer(input.data());
37
38             /* we need to copy all the elements of input to some location in the output
39              * we copy blocks of size `total_concat_size` to some location in the output
40              */
41             const auto total_concat_size = concat_size * input_axis_size;
42
43             for (auto in_idx : grid_stride_range(input.size() / vector_type::size())) {
44                 const index_type idx = in_idx * vector_type::size();
45                 const index_type concat_num = idx / total_concat_size;
46                 const index_type concat_index = idx % total_concat_size;
47                 const index_type top_index = concat_index +
48                     (concat_num * output_axis_size + output_axis_offset) * concat_size;
49
50                 const auto out_idx = top_index / vector_type::size();
51
52                 vector_type vec;
53                 v_load(vec, input_vPtr[in_idx]);
54                 v_store(output_vPtr[out_idx], vec);
55             }
56         }
57
58         template <class T, std::size_t Rank>
59         __global__ void concat_with_offsets(
60             Span<T> output, array<size_type, Rank> out_strides, array<index_type, Rank> out_offset,
61             View<T> input, array<size_type, Rank> in_strides)
62         {
63             for (auto i : grid_stride_range(input.size())) {
64                 index_type in_index = i / in_strides[0];
65                 index_type out_index = out_offset[0] + in_index;
66                 index_type oidx = out_index * out_strides[0];
67                 for (int j = 1; j < Rank; j++) {
68                     in_index = (i % in_strides[j - 1]) / in_strides[j];
69                     out_index = out_offset[j] + in_index;
70                     oidx += out_index * out_strides[j];
71                 }
72
73                 output[oidx] = input[i];
74             }
75         }
76     }
77
78     template <class T, std::size_t N> static
79     void launch_vectorized_concat(const Stream& stream,
80         Span<T> output, size_type output_axis_size, index_type output_axis_offset,
81         View<T> input, size_type input_axis_size, size_type concat_size)
82     {
83         CV_Assert(is_fully_aligned<T>(output, N));
84         CV_Assert(is_fully_aligned<T>(input, N));
85         /* more assertions are required to fully check for vectorization possiblity; check concat() */
86
87         auto kernel = raw::concat_vec<T, N>;
88         auto policy = make_policy(kernel, input.size() / N, 0, stream);
89         launch_kernel(kernel, policy, output, output_axis_size, output_axis_offset, input, input_axis_size, concat_size);
90     }
91
92     template <class T>
93     void concat(
94         const Stream& stream,
95         TensorSpan<T> output, std::size_t output_axis_offset,
96         TensorView<T> input, std::size_t axis)
97     {
98         /* let's call the axis of interest as the channel axis for the purpose of the following discussion
99          * even though it can be any axis
100          *
101          * for each batch item:
102          *    we move all the channels from the input (which together, for a single batch item, is contiguous)
103          *    of a batch item to its corresponding contiguous place in the output
104          *
105          * for a valid vector operation:
106          * - the size of each copy block must be aligned
107          * - input must be aligned
108          * - all the destination locations in the output must be aligned
109          */
110         std::size_t concat_size = output.size_range(axis + 1, output.rank());
111
112         std::size_t input_axis_size = input.get_axis_size(axis);
113         std::size_t output_axis_size = output.get_axis_size(axis);
114
115         std::size_t copy_block_size = concat_size * input_axis_size;
116         std::size_t copy_block_stride = concat_size * output_axis_size;
117         std::size_t starting_offset = output_axis_offset * concat_size;
118
119         /* in a nutshell, all this concat operation does is copy several blocks of size `copy_block_size`
120          * to the output starting from `starting_offset` with blocks in the output strided by `copy_block_stride`
121          */
122
123         bool is_aligned_4 = copy_block_size % 4 == 0 && copy_block_stride % 4 == 0 && starting_offset % 4 == 0;
124         bool is_aligned_2 = copy_block_size % 2 == 0 && copy_block_stride % 2 == 0 && starting_offset % 2 == 0;
125
126         if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4) && is_aligned_4) {
127             launch_vectorized_concat<T, 4>(stream, output, output_axis_size, output_axis_offset, input, input_axis_size, concat_size);
128         } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2) && is_aligned_2) {
129             launch_vectorized_concat<T, 2>(stream, output, output_axis_size, output_axis_offset, input, input_axis_size, concat_size);
130         } else {
131             launch_vectorized_concat<T, 1>(stream, output, output_axis_size, output_axis_offset, input, input_axis_size, concat_size);
132         }
133     }
134
135     template void concat<__half>(const Stream&, TensorSpan<__half>, std::size_t, TensorView<__half>, std::size_t);
136     template void concat<float>(const Stream&, TensorSpan<float>, std::size_t, TensorView<float>,  std::size_t);
137
138     template <class T, std::size_t Rank> static
139     void launch_concat_with_offsets(
140         const Stream& stream,
141         Span<T> output, const std::vector<std::size_t>& outStride, const std::vector<std::size_t>& outOffset,
142         View<T> input, const std::vector<std::size_t>& inStride)
143     {
144         CV_Assert(outStride.size() == Rank);
145         CV_Assert(outOffset.size() == Rank);
146         CV_Assert(inStride.size() == Rank);
147
148         array<size_type, Rank> outStride_k, inStride_k;
149         outStride_k.assign(std::begin(outStride), std::end(outStride));
150         inStride_k.assign(std::begin(inStride), std::end(inStride));
151
152         array<index_type, Rank> outOffset_k;
153         outOffset_k.assign(std::begin(outOffset), std::end(outOffset));
154
155         auto kernel = raw::concat_with_offsets<T, Rank>;
156         auto policy = make_policy(kernel, input.size(), 0, stream);
157         launch_kernel(kernel, policy, output, outStride_k, outOffset_k, input, inStride_k);
158     }
159
160     GENERATE_KERNEL_DISPATCHER(concat_with_offsets_dispatcher, launch_concat_with_offsets);
161
162     template <class T>
163     void concat_with_offsets(
164         const Stream& stream,
165         TensorSpan<T> output, TensorView<T> input,
166         std::vector<std::size_t> offsets)
167     {
168         CV_Assert(output.rank() == input.rank());
169         CV_Assert(output.rank() == offsets.size());
170
171         /* squeezable axes at the begining of both tensors can be eliminated
172          *
173          * Reasoning:
174          * ----------
175          * Suppose an item's indices in the input tensor is [i1, i2, ...]. The indices in the output
176          * tensor will be [i1 + off1, i2 + off2, ...]. The concat operation essentially copies items
177          * from the input tensor to new locations in the output tensor.
178          *
179          * If the size of the first axis of the input and output tensor is unity, the input and output
180          * indices for all the elements will be of the form be [0, i2, ...] and [0, i2 + off2, ...]
181          * respectively. The first index does not contribute to the element's address calculation and
182          * hence does nothing apart from eating up few cycles.
183          */
184         while (input.get_axis_size(0) == 1 && output.get_axis_size(0) == 1) {
185             CV_Assert(offsets[0] == 0);
186
187             input.squeeze(0);
188             output.squeeze(0);
189             offsets.erase(std::begin(offsets));
190
191             CV_Assert(output.rank() == input.rank());
192             CV_Assert(output.rank() == offsets.size());
193         }
194
195         auto inShape = input.shape_as_vector();
196         auto outShape = output.shape_as_vector();
197
198         /* contiguous axes that undergo full copy can be combined into one axis
199          *
200          * Reasoning:
201          * ----------
202          * Suppose an item's indices in the input tensor is [i1, i2, i3, ...]. Let the first two axes not undergo any
203          * concatenation. The indices in the output tensor will be [i1, i2, i3 + off3, ...].
204          *
205          * Each axis in the contiguous axes sequence will add an offset of iN * strideN. In the above example,
206          * the two axes add a total offset of `i1 * stride1 + i2 * stride2`. We can merge the two axes into one axis with
207          * a size of `size1 * size2`. The new offset added will be i12 * stride2` as the kernel iterates through `i12`.
208          * Note that `i12` is actually `(i1 * size2 + i2)` in the original tensor.
209          */
210         for (int i = 0; i < inShape.size(); i++) {
211             /* check if axis `i` requires any slicing */
212             if (offsets[i] == 0 && inShape[i] == outShape[i]) {
213                 /* loop invariant: `i` is the first axis in the contiguous unsliced axis sequence */
214
215                 int j = i + 1; /* `j` is the axis which we will attempt to merge */
216                 while (j < inShape.size() && offsets[j] == 0 && inShape[j] == outShape[j]) {
217                     /* `j` axis is also copied fully; merge `i` and `j` */
218                     auto new_size = inShape[i] * inShape[j];
219                     inShape[i] = new_size;
220                     outShape[i] = new_size;
221                     offsets[i] = 0; /* redundant */
222
223                     /* delete axis `j` */
224                     inShape.erase(std::begin(inShape) + j);
225                     outShape.erase(std::begin(outShape) + j);
226                     offsets.erase(std::begin(offsets) + j);
227
228                     /* optimizations should not break the invariants */
229                     CV_Assert(inShape.size() == outShape.size());
230                     CV_Assert(inShape.size() == offsets.size());
231                     CV_Assert(inShape[i] == outShape[i]);
232                     CV_Assert(offsets[i] == 0);
233                 }
234             }
235         }
236
237         auto rank = inShape.size();
238
239         std::vector<std::size_t> inStride(rank), outStride(rank);
240         inStride.back() = 1;
241         outStride.back() = 1;
242         /* garbage, ..., garbage, 1 */
243
244         std::copy(std::begin(inShape) + 1, std::end(inShape), std::begin(inStride));
245         std::copy(std::begin(outShape) + 1, std::end(outShape), std::begin(outStride));
246         /* dim[0], dim[1], ..., dim[-1], 1 */
247
248         std::partial_sum(inStride.rbegin(), inStride.rend(), inStride.rbegin(), std::multiplies<int>());
249         std::partial_sum(outStride.rbegin(), outStride.rend(), outStride.rbegin(), std::multiplies<int>());
250         /* stride[0], stride[1], ..., stride[-2], 1 */
251
252         CV_Assert(1 <= rank && rank <= CSL_MAX_TENSOR_RANK);
253         concat_with_offsets_dispatcher<T, 1, CSL_MAX_TENSOR_RANK>(rank, stream, output, outStride, offsets, input, inStride);
254     }
255
256     template void concat_with_offsets(const Stream&, TensorSpan<__half>, TensorView<__half>, std::vector<std::size_t>);
257     template void concat_with_offsets(const Stream&, TensorSpan<float>, TensorView<float>, std::vector<std::size_t>);
258
259 }}}} /* namespace cv::dnn::cuda4dnn::kernels */