Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low
[platform/upstream/opencv.git] / modules / dnn / src / cuda4dnn / primitives / scale_shift.hpp
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 #ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SCALE_SHIFT_HPP
6 #define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SCALE_SHIFT_HPP
7
8 #include "../../op_cuda.hpp"
9
10 #include "../csl/stream.hpp"
11 #include "../csl/tensor.hpp"
12
13 #include "../kernels/scale_shift.hpp"
14
15 #include <opencv2/core.hpp>
16
17 #include <cstddef>
18 #include <utility>
19
20 namespace cv { namespace dnn { namespace cuda4dnn {
21
22     template <class T>
23     class ScaleShiftOp final : public CUDABackendNode {
24     public:
25         using wrapper_type = GetCUDABackendWrapperType<T>;
26
27         ScaleShiftOp(csl::Stream stream_, std::size_t axis, const cv::Mat& weights, const cv::Mat& bias)
28             : stream(std::move(stream_)), axis{ axis }
29         {
30             if (!weights.empty())
31             {
32                 weightsTensor = csl::makeTensorHeader<T>(weights);
33                 csl::copyMatToTensor<T>(weights, weightsTensor, stream);
34             }
35
36             if (!bias.empty())
37             {
38                 biasTensor = csl::makeTensorHeader<T>(bias);
39                 csl::copyMatToTensor<T>(bias, biasTensor, stream);
40             }
41         }
42
43         void forward(
44             const std::vector<cv::Ptr<BackendWrapper>>& inputs,
45             const std::vector<cv::Ptr<BackendWrapper>>& outputs,
46             csl::Workspace& workspace) override
47         {
48             CV_Assert(outputs.size() == 1);
49
50             auto input_wrapper = inputs[0].dynamicCast<wrapper_type>();
51             auto input = input_wrapper->getView();
52
53             auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
54             auto output = output_wrapper->getSpan();
55
56             csl::TensorView<T> weights;
57             if (weightsTensor.empty() && biasTensor.empty())
58             {
59                 CV_Assert(inputs.size() == 2);
60
61                 /* no explicit scale/shift values provided; use the second input as weights */
62                 auto wrapper = inputs[1].dynamicCast<wrapper_type>();
63                 weights = wrapper->getView();
64             }
65             else if (!weightsTensor.empty())
66             {
67                 weights = csl::TensorSpan<T>(weightsTensor);
68             }
69
70             csl::TensorView<T> bias;
71             if (!biasTensor.empty())
72                 bias = csl::TensorSpan<T>(biasTensor);
73
74             const auto numParams = !weights.empty() ? weights.size() : bias.size();
75             CV_Assert(numParams != 0);
76             if (!weightsTensor.empty() && !biasTensor.empty())
77             {
78                 CV_CheckEQ(weights.size(), bias.size(), "weights and bias size are not equal");
79             }
80
81             /* the weights/bias might require broadcasting to scale/shift */
82             const int end_axis = [&] {
83                 for (int endAxis = axis + 1; endAxis <= input.rank(); endAxis++)
84                 {
85                     std::size_t size = input.size_range(axis, endAxis);
86                     if (size == numParams)
87                         return endAxis;
88                 }
89                 CV_Assert(0 /* invalid weights matrix */);
90             }();
91
92             std::size_t inner_size = input.size_range(end_axis, input.rank());
93
94             if (!weights.empty() && !bias.empty())
95                 kernels::scaleN_with_biasN<T>(stream, output, input, inner_size, weights, bias);
96             else if (!weights.empty())
97                 kernels::scaleN<T>(stream, output, input, inner_size, weights);
98             else
99                 kernels::biasN<T>(stream, output, input, inner_size, bias);
100         }
101
102     private:
103         csl::Stream stream;
104         csl::Tensor<T> weightsTensor, biasTensor;
105         std::size_t axis;
106     };
107
108 }}} /* namespace cv::dnn::cuda4dnn */
109
110 #endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SCALE_SHIFT_HPP */