Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low
[platform/upstream/opencv.git] / modules / dnn / src / cuda4dnn / primitives / batch_norm.hpp
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 #ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_BATCH_NORM_HPP
6 #define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_BATCH_NORM_HPP
7
8 #include "../../op_cuda.hpp"
9
10 #include "../csl/stream.hpp"
11 #include "../csl/tensor.hpp"
12
13 #include "../kernels/scale_shift.hpp"
14
15 #include <utility>
16
17 namespace cv { namespace dnn { namespace cuda4dnn {
18
19     template <class T>
20     class BatchNormOp final : public CUDABackendNode {
21     public:
22         using wrapper_type = GetCUDABackendWrapperType<T>;
23
24         BatchNormOp(csl::Stream stream_, const cv::Mat& weights, const cv::Mat& bias)
25             : stream(std::move(stream_))
26         {
27             biasTensor = csl::makeTensorHeader<T>(bias);
28             csl::copyMatToTensor<T>(bias, biasTensor, stream);
29
30             weightsTensor = csl::makeTensorHeader<T>(weights);
31             csl::copyMatToTensor<T>(weights, weightsTensor, stream);
32         }
33
34         void forward(
35             const std::vector<cv::Ptr<BackendWrapper>>& inputs,
36             const std::vector<cv::Ptr<BackendWrapper>>& outputs,
37             csl::Workspace& workspace) override
38         {
39             CV_Assert(inputs.size() == 1 && outputs.size() == 1);
40
41             auto input_wrapper = inputs[0].dynamicCast<wrapper_type>();
42             auto input = input_wrapper->getView();
43
44             auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
45             auto output = output_wrapper->getSpan();
46
47             std::size_t inner_size = input.size_range(2, input.rank());
48             kernels::scaleN_with_biasN<T>(stream, output, input, inner_size, weightsTensor, biasTensor);
49         }
50
51     private:
52         csl::Stream stream;
53         csl::Tensor<T> weightsTensor, biasTensor;
54     };
55
56 }}} /* namespace cv::dnn::cuda4dnn */
57
58 #endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_BATCH_NORM_HPP */