Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low
[platform/upstream/opencv.git] / modules / dnn / src / cuda / slice.cu
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 #include <cuda_runtime.h>
6 #include <cuda_fp16.h>
7
8 #include "array.hpp"
9 #include "types.hpp"
10 #include "grid_stride_range.hpp"
11 #include "execution.hpp"
12 #include "kernel_dispatcher.hpp"
13
14 #include "../cuda4dnn/csl/stream.hpp"
15 #include "../cuda4dnn/csl/tensor.hpp"
16 #include "../cuda4dnn/csl/span.hpp"
17
18 #include <opencv2/core.hpp>
19
20 #include <cstddef>
21 #include <vector>
22 #include <iostream>
23
24 using namespace cv::dnn::cuda4dnn::csl;
25 using namespace cv::dnn::cuda4dnn::csl::device;
26
27 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
28
29     namespace raw {
30         template <class T, std::size_t Rank>
31         __global__ void slice(
32             Span<T> output, array<size_type, Rank> out_strides,
33             View<T> input, array<size_type, Rank> in_strides, array<index_type, Rank> in_offset)
34         {
35             for (auto i : grid_stride_range(output.size())) {
36                 index_type out_index = i / out_strides[0];
37                 index_type in_index = in_offset[0] + out_index;
38                 index_type iidx = in_index * in_strides[0];
39                 for (int j = 1; j < Rank; j++) {
40                     out_index = (i % out_strides[j - 1]) / out_strides[j];
41                     in_index = in_offset[j] + out_index;
42                     iidx += in_index * in_strides[j];
43                 }
44
45                 output[i] = input[iidx];
46             }
47         }
48     }
49
50     template <class T, std::size_t Rank> static
51     void launch_slice(
52         const Stream& stream,
53         Span<T> output, const std::vector<std::size_t>& outStride,
54         View<T> input, const std::vector<std::size_t>& inStride, const std::vector<std::size_t>& inOffset)
55     {
56         CV_Assert(outStride.size() == Rank);
57         CV_Assert(inStride.size() == Rank);
58         CV_Assert(inOffset.size() == Rank);
59
60         array<size_type, Rank> outStride_k, inStride_k;
61         outStride_k.assign(std::begin(outStride), std::end(outStride));
62         inStride_k.assign(std::begin(inStride), std::end(inStride));
63
64         array<index_type, Rank> inOffset_k;
65         inOffset_k.assign(std::begin(inOffset), std::end(inOffset));
66
67         auto kernel = raw::slice<T, Rank>;
68         auto policy = make_policy(kernel, output.size(), 0, stream);
69         launch_kernel(kernel, policy, output, outStride_k, input, inStride_k, inOffset_k);
70     }
71
72     GENERATE_KERNEL_DISPATCHER(slice_dispatcher, launch_slice);
73
74     template <class T>
75     void slice(const Stream& stream,
76         TensorSpan<T> output, TensorView<T> input,
77         std::vector<std::size_t> offsets)
78     {
79         CV_Assert(output.rank() == input.rank());
80         CV_Assert(output.rank() == offsets.size());
81
82         /* squeezable axes at the begining of both tensors can be eliminated
83          *
84          * Reasoning:
85          * ----------
86          * Suppose an item's indices in the output tensor is [o1, o2, ...]. The indices in the input
87          * tensor will be [o1 + off1, o2 + off2, ...]. The rest of the elements in the input are igored.
88          *
89          * If the size of the first axis of the input and output tensor is unity, the input and output indices
90          * for all the elements will be of the form be [0, o2 + off2, ...] and [0, o2, ...] respectively. Note that
91          * there cannot be any ignored items since the axes have unit size. The first index does not contribute to the
92          * element's address calculation and hence does nothing apart from eating up few cycles.
93          */
94         while (input.get_axis_size(0) == 1 && output.get_axis_size(0) == 1) {
95             CV_Assert(offsets[0] == 0);
96
97             input.squeeze(0);
98             output.squeeze(0);
99             offsets.erase(std::begin(offsets));
100
101             CV_Assert(output.rank() == input.rank());
102             CV_Assert(output.rank() == offsets.size());
103         }
104
105         auto inShape = input.shape_as_vector();
106         auto outShape = output.shape_as_vector();
107
108         /* contiguous axes which do not undergo slicing can be combined into one axis
109          *
110          * Reasoning:
111          * ----------
112          * Suppose an item's indices in the output tensor is [o1, o2, o3, ...]. Let the first two axes not undergo any
113          * slicing. The indices in the input tensor will be [o1, o2, o3 + off3, ...].
114          *
115          * Each axis in the contiguous unsliced axes sequence will add an offset of iN * strideN. In the above example,
116          * the two axes add a total offset of `o1 * stride1 + o2 * stride2`. We can merge the two axes into one axis with
117          * a size of `size1 * size2`. The new offset added will be o12 * stride2` as the kernel iterates through `o12`.
118          * Note that `o12` is actually `(o1 * size2 + o2)` in the original tensor.
119          */
120         for (int i = 0; i < inShape.size(); i++) {
121             /* check if axis `i` requires any slicing */
122             if (offsets[i] == 0 && inShape[i] == outShape[i]) {
123                 /* loop invariant: `i` is the first axis in the contiguous unsliced axis sequence */
124
125                 int j = i + 1; /* `j` is the axis which we will attempt to merge */
126                 while (j < inShape.size() && offsets[j] == 0 && inShape[j] == outShape[j]) {
127                     /* `j` axis is also unsliced; merge `i` and `j` */
128                     auto new_size = inShape[i] * inShape[j];
129                     inShape[i] = new_size;
130                     outShape[i] = new_size;
131                     offsets[i] = 0; /* redundant */
132
133                     /* delete axis `j` */
134                     inShape.erase(std::begin(inShape) + j);
135                     outShape.erase(std::begin(outShape) + j);
136                     offsets.erase(std::begin(offsets) + j);
137
138                     /* optimizations should not break the invariants */
139                     CV_Assert(inShape.size() == outShape.size());
140                     CV_Assert(inShape.size() == offsets.size());
141                     CV_Assert(inShape[i] == outShape[i]);
142                     CV_Assert(offsets[i] == 0);
143                 }
144             }
145         }
146
147         auto rank = inShape.size();
148
149         std::vector<std::size_t> inStride(rank), outStride(rank);
150         inStride.back() = 1;
151         outStride.back() = 1;
152         /* garbage, ..., garbage, 1 */
153
154         std::copy(std::begin(inShape) + 1, std::end(inShape), std::begin(inStride));
155         std::copy(std::begin(outShape) + 1, std::end(outShape), std::begin(outStride));
156         /* dim[0], dim[1], ..., dim[-1], 1 */
157
158         std::partial_sum(inStride.rbegin(), inStride.rend(), inStride.rbegin(), std::multiplies<std::size_t>());
159         std::partial_sum(outStride.rbegin(), outStride.rend(), outStride.rbegin(), std::multiplies<std::size_t>());
160         /* stride[0], stride[1], ..., stride[-2], 1 */
161
162         CV_Assert(1 <= rank && rank <= CSL_MAX_TENSOR_RANK);
163         slice_dispatcher<T, 1, CSL_MAX_TENSOR_RANK>(rank, stream, output, outStride, input, inStride, offsets);
164     }
165
166     template void slice(const Stream&, TensorSpan<__half>, TensorView<__half>, std::vector<std::size_t>);
167     template void slice(const Stream&, TensorSpan<float>, TensorView<float>, std::vector<std::size_t>);
168
169 }}}} /* namespace cv::dnn::cuda4dnn::kernels */