Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low
[platform/upstream/opencv.git] / modules / dnn / src / cuda / permute.cu
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 #include <cuda_runtime.h>
6 #include <cuda_fp16.h>
7
8 #include "array.hpp"
9 #include "types.hpp"
10 #include "grid_stride_range.hpp"
11 #include "execution.hpp"
12 #include "kernel_dispatcher.hpp"
13
14 #include "../cuda4dnn/csl/stream.hpp"
15 #include "../cuda4dnn/csl/tensor.hpp"
16 #include "../cuda4dnn/csl/span.hpp"
17
18 #include <opencv2/core.hpp>
19
20 #include <cstddef>
21 #include <vector>
22
23 using namespace cv::dnn::cuda4dnn::csl;
24 using namespace cv::dnn::cuda4dnn::csl::device;
25
26 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
27
28     namespace raw {
29         template <class T, std::size_t Rank>
30         __global__ void permute(
31             array<index_type, Rank> axis_order,
32             Span<T> output, array<size_type, Rank> outStrides,
33             View<T> input, array<size_type, Rank> inStrides)
34         {
35             for (auto i : grid_stride_range(input.size())) {
36                 index_type oldPosition = 0;
37                 index_type newPosition = i;
38
39                 for (int j = 0; j < Rank; j++)
40                 {
41                     auto order = axis_order[j];
42                     oldPosition += (newPosition / outStrides[j]) * inStrides[order];
43                     newPosition %= outStrides[j];
44                 }
45
46                 output[i] = input[oldPosition];
47             }
48         }
49     }
50
51     template <class T, std::size_t Rank> static
52     void launch_permute_kernel(
53         const Stream& stream,
54         const std::vector<std::size_t>& order,
55         Span<T> output, const std::vector<std::size_t>& outStride,
56         View<T> input, const std::vector<std::size_t>& inStride)
57     {
58         CV_Assert(order.size() == Rank);
59         CV_Assert(outStride.size() == Rank);
60         CV_Assert(inStride.size() == Rank);
61
62         array<index_type, Rank> order_k;
63         order_k.assign(std::begin(order), std::end(order));
64
65         array<size_type, Rank> outStride_k, inStride_k;
66         outStride_k.assign(std::begin(outStride), std::end(outStride));
67         inStride_k.assign(std::begin(inStride), std::end(inStride));
68
69         auto kernel = raw::permute<T, Rank>;
70         auto policy = make_policy(kernel, input.size(), 0, stream);
71         launch_kernel(kernel, policy, order_k, output, outStride_k, input, inStride_k);
72     }
73
74     GENERATE_KERNEL_DISPATCHER(permute_dispatcher, launch_permute_kernel);
75
76     template <class T>
77     void permute(
78         const Stream& stream,
79         TensorSpan<T> output, TensorView<T> input,
80         std::vector<std::size_t> order)
81     {
82         CV_Assert(output.rank() == input.rank());
83         CV_Assert(input.rank() == order.size());
84         CV_Assert(input.size() == output.size());
85
86         /* squeezable axes at the begining of both tensors which aren't permuted can be eliminated
87          *
88          * Reasoning:
89          * ----------
90          * Suppose an item's indices in the input tensor is [i1, i2, ...]. The indices in the
91          * output tensor will be some permutation of the input tensor indices. Let the output
92          * tensor indices be [o1, o2, ...]. The permutation operation essentially copies items
93          * from the input tensor to new locations in the output tensor as dictated by the indices.
94          *
95          * If the size of the first axis of the input and output tensor is one and these axes are
96          * not involved in any permutation, i.e. order[0] = 0, the input and output indicies for
97          * all the elements will be of the form be [0, i2, ...] and [0, o2, ...] respectively.
98          * The first index does not contribute to the element's address calculation and hence does
99          * nothing apart from eating up few cycles.
100          */
101         while (order[0] == 0 && input.get_axis_size(0) == 1 && output.get_axis_size(0) == 1) {
102             /* remove the axes */
103             input.squeeze(0);
104             output.squeeze(0);
105
106             /* when we remove axis zero, the axis index will be one less than the previous index
107              * for the remaining axes
108              */
109             order.erase(order.begin());
110             for (auto& axis : order)
111                 axis--;
112
113             /* optimizations should not break the invariants */
114             CV_Assert(output.rank() == input.rank());
115             CV_Assert(input.rank() == order.size());
116             CV_Assert(input.size() == output.size());
117         }
118
119         auto rank = output.rank();
120         auto inShape = input.shape_as_vector();
121         auto outShape = output.shape_as_vector();
122
123         std::vector<std::size_t> inStride(rank), outStride(rank);
124         inStride.back() = 1;
125         outStride.back() = 1;
126         /* garbage, ..., garbage, 1 */
127
128         std::copy(std::begin(inShape) + 1, std::end(inShape), std::begin(inStride));
129         std::copy(std::begin(outShape) + 1, std::end(outShape), std::begin(outStride));
130         /* dim[0], dim[1], ..., dim[-1], 1 */
131
132         std::partial_sum(inStride.rbegin(), inStride.rend(), inStride.rbegin(), std::multiplies<std::size_t>());
133         std::partial_sum(outStride.rbegin(), outStride.rend(), outStride.rbegin(), std::multiplies<std::size_t>());
134         /* stride[0], stride[1], ..., stride[-2], 1 */
135
136         CV_Assert(2 <= rank && rank <= CSL_MAX_TENSOR_RANK);
137         permute_dispatcher<T, 2, CSL_MAX_TENSOR_RANK>(rank, stream, order, output, outStride, input, inStride);
138     }
139
140     template void permute(const Stream&, TensorSpan<__half>, TensorView<__half>, std::vector<std::size_t>);
141     template void permute(const Stream&, TensorSpan<float>, TensorView<float>, std::vector<std::size_t>);
142
143 }}}} /* namespace cv::dnn::cuda4dnn::kernels */