Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low
[platform/upstream/opencv.git] / modules / dnn / src / cuda / padding.cu
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 #include <cuda_runtime.h>
6 #include <cuda_fp16.h>
7
8 #include "array.hpp"
9 #include "math.hpp"
10 #include "types.hpp"
11 #include "grid_stride_range.hpp"
12 #include "execution.hpp"
13 #include "kernel_dispatcher.hpp"
14
15 #include "../cuda4dnn/csl/stream.hpp"
16 #include "../cuda4dnn/csl/tensor.hpp"
17 #include "../cuda4dnn/csl/span.hpp"
18
19 #include <opencv2/core.hpp>
20
21 #include <cstddef>
22 #include <vector>
23 #include <utility>
24
25 using namespace cv::dnn::cuda4dnn::csl;
26 using namespace cv::dnn::cuda4dnn::csl::device;
27
28 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
29
30     namespace raw {
31         template <class T, std::size_t Rank>
32         __global__ void copy_with_reflection101(
33             Span<T> output, array<size_type, Rank> out_strides, array<index_type, Rank> start, array<index_type, Rank> end,
34             View<T> input, array<size_type, Rank> in_strides)
35         {
36             for (auto i : grid_stride_range(output.size())) {
37                 /* compute output axis indices corresponding to element 'i' */
38                 array<index_type, Rank> out_index;
39                 out_index[0] = i / out_strides[0];
40                 for (int j = 1; j < Rank; j++)
41                     out_index[j] = (i % out_strides[j - 1]) / out_strides[j];
42
43                 /* compute input axis indices corresponding to output axis indices */
44                 array<index_type, Rank> in_index;
45                 for (int j = 0; j < Rank; j++) {
46                     /* if out_index < start, the point is in the left reflection region
47                      * the reflected value's index is the absolute value of the difference
48                      *
49                      * otherwise, if the value is in the copy region, out_index - start gives the input index
50                      */
51                     using device::abs;
52                     in_index[j] = abs(out_index[j] - start[j]);
53
54                     /* if out_index >= end, it's in the right reflection region */
55                     if (out_index[j] >= end[j])
56                         in_index[j] = (end[j] - start[j]) - (out_index[j] - end[j]) - 2;
57                 }
58
59                 /* compute input element number from input axis indices */
60                 index_type iidx = 0;
61                 for (int j = 0; j < Rank; j++)
62                     iidx += in_index[j] * in_strides[j];
63
64                 output[i] = input[iidx];
65             }
66         }
67     }
68
69     template <class T, std::size_t Rank> static
70     void launch_copy_with_reflection101(
71         const Stream& stream,
72         Span<T> output, const std::vector<std::size_t>& outStride,
73         View<T> input, const std::vector<std::size_t>& inStride,
74         const std::vector<std::pair<std::size_t, std::size_t>>& ranges)
75     {
76         CV_Assert(outStride.size() == Rank);
77         CV_Assert(inStride.size() == Rank);
78         CV_Assert(ranges.size() == Rank);
79
80         array<size_type, Rank> outStride_k, inStride_k;
81         outStride_k.assign(std::begin(outStride), std::end(outStride));
82         inStride_k.assign(std::begin(inStride), std::end(inStride));
83
84         array<index_type, Rank> start_k, end_k;
85         for (int i = 0; i < Rank; i++) {
86             start_k[i] = ranges[i].first;
87             end_k[i] = ranges[i].second;
88         }
89
90         auto kernel = raw::copy_with_reflection101<T, Rank>;
91         auto policy = make_policy(kernel, output.size(), 0, stream);
92         launch_kernel(kernel, policy, output, outStride_k, start_k, end_k, input, inStride_k);
93     }
94
95     GENERATE_KERNEL_DISPATCHER(copy_with_reflection101_dispatcher, launch_copy_with_reflection101);
96
97     template <class T>
98     void copy_with_reflection101(
99         const Stream& stream,
100         TensorSpan<T> output, TensorView<T> input,
101         std::vector<std::pair<std::size_t, std::size_t>> ranges)
102     {
103         CV_Assert(output.rank() == input.rank());
104         CV_Assert(output.rank() == ranges.size());
105
106         /* squeezable axes at the begining of both tensors can be eliminated
107          *
108          * Reasoning:
109          * ----------
110          * Suppose an item's indices in the input tensor is [i1, i2, ...]. The indices in the
111          * output tensor will be [i1 + off1, i2 + off2, ...]. The rest of the elements in the output are padding.
112          * The padding operation essentially copies items from the input tensor to new locations in the output tensor
113          * and pads the remaining.
114          *
115          * If the size of the first axis of the input and output tensor is unity, the input and output indices
116          * for all the elements will be of the form be [0, i2, ...] and [0, i2 + off2, ...] respectively. Note that
117          * there cannot be extra padding since the axes have unit size. The first index does not contribute to the
118          * element's address calculation and hence does nothing apart from eating up few cycles.
119          */
120         while (input.get_axis_size(0) == 1 && output.get_axis_size(0) == 1) {
121             CV_Assert(ranges[0].first == 0 && ranges[0].second == 1);
122
123             input.squeeze(0);
124             output.squeeze(0);
125             ranges.erase(std::begin(ranges));
126
127             CV_Assert(output.rank() == input.rank());
128             CV_Assert(output.rank() == ranges.size());
129         }
130
131         auto inShape = input.shape_as_vector();
132         auto outShape = output.shape_as_vector();
133
134         /* contiguous axes which do not have any padding can be combined into one axis
135          *
136          * Reasoning:
137          * ----------
138          * Suppose an item's indices in the input tensor is [i1, i2, i3, ...]. Let the first two axes not have any
139          * padding. The indices in the output tensor will be [i1, i2, i3 + off3, ...].
140          *
141          * Each axis in the contiguous unpadded axes sequence will add an offset of iN * strideN. In the above example,
142          * the two axes add a total offset of `i1 * stride1 + i2 * stride2`. We can merge the two axes into one axis with
143          * a size of `size1 * size2`. The new offset added will be `i12 * stride2` as the kernel iterates through `i12`.
144          * Note that `i12` is actually `(i1 * size2 + i2)` in the original tensor.
145          */
146         for (int i = 0; i < inShape.size(); i++) {
147             /* check if axis `i` requires any padding */
148             if (ranges[i].first == 0 && ranges[i].second == inShape[i]) {
149                 /* loop invariant: `i` is the first axis in the contiguous unpadded axis sequence */
150                 CV_Assert(inShape[i] == outShape[i]);
151
152                 /* we now iterate through the axes which follow and try to merge */
153                 int j = i + 1; /* `j` is the axis which we will attempt to merge */
154                 while (j < inShape.size() && ranges[j].first == 0 && ranges[j].second == inShape[j]) {
155                     CV_Assert(inShape[j] == outShape[j]);
156
157                     /* `j` is also unpadded; merge `i` and `j` */
158                     auto new_size = inShape[i] * inShape[j];
159                     inShape[i] = new_size;
160                     outShape[i] = new_size;
161                     ranges[i].second = new_size;
162
163                     /* delete axis `j` */
164                     inShape.erase(std::begin(inShape) + j);
165                     outShape.erase(std::begin(outShape) + j);
166                     ranges.erase(std::begin(ranges) + j);
167
168                     /* optimizations should not break the invariants */
169                     CV_Assert(inShape.size() == outShape.size());
170                     CV_Assert(inShape.size() == ranges.size());
171                     CV_Assert(inShape[i] == outShape[i]);
172                     CV_Assert(ranges[i].first == 0 && ranges[i].second == inShape[i]);
173                 }
174             }
175         }
176
177         auto rank = inShape.size();
178
179         std::vector<std::size_t> inStride(rank), outStride(rank);
180         inStride.back() = 1;
181         outStride.back() = 1;
182         /* garbage, ..., garbage, 1 */
183
184         std::copy(std::begin(inShape) + 1, std::end(inShape), std::begin(inStride));
185         std::copy(std::begin(outShape) + 1, std::end(outShape), std::begin(outStride));
186         /* dim[0], dim[1], ..., dim[-1], 1 */
187
188         std::partial_sum(inStride.rbegin(), inStride.rend(), inStride.rbegin(), std::multiplies<int>());
189         std::partial_sum(outStride.rbegin(), outStride.rend(), outStride.rbegin(), std::multiplies<int>());
190         /* stride[0], stride[1], ..., stride[-2], 1 */
191
192         CV_Assert(1 <= rank && rank <= CSL_MAX_TENSOR_RANK);
193         copy_with_reflection101_dispatcher<T, 1, CSL_MAX_TENSOR_RANK>(rank, stream, output, outStride, input, inStride, ranges);
194     }
195
196     template void copy_with_reflection101(const Stream&, TensorSpan<__half>, TensorView<__half>, std::vector<std::pair<std::size_t, std::size_t>> ranges);
197     template void copy_with_reflection101(const Stream&, TensorSpan<float>, TensorView<float>, std::vector<std::pair<std::size_t, std::size_t>> ranges);
198
199 }}}} /* namespace namespace cv::dnn::cuda4dnn::kernels */