Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low
[platform/upstream/opencv.git] / modules / dnn / src / cuda4dnn / csl / cudnn / transpose_convolution.hpp
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 #ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_TRANSPOSE_CONVOLUTION_HPP
6 #define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_TRANSPOSE_CONVOLUTION_HPP
7
8 #include "cudnn.hpp"
9 #include "convolution.hpp"
10
11 #include "../pointer.hpp"
12 #include "../workspace.hpp"
13
14 #include <cudnn.h>
15
16 #include <cstddef>
17
18 namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn {
19
20     /** wrapper around a transpose convolution algorithm
21      *
22      * @tparam  T   type of elements being transpose-convolved
23      */
24     template <class T>
25     class TransposeConvolutionAlgorithm {
26     public:
27         TransposeConvolutionAlgorithm() noexcept : workspace_size{ 0 } { }
28         TransposeConvolutionAlgorithm(TransposeConvolutionAlgorithm&) = default;
29         TransposeConvolutionAlgorithm(TransposeConvolutionAlgorithm&&) = default;
30
31         TransposeConvolutionAlgorithm(
32             const Handle& handle,
33             const ConvolutionDescriptor<T>& conv,
34             const FilterDescriptor<T>& filter,
35             const TensorDescriptor<T>& input,
36             const TensorDescriptor<T>& output)
37         {
38             CUDA4DNN_CHECK_CUDNN(
39                 cudnnGetConvolutionBackwardDataAlgorithm(
40                     handle.get(),
41                     filter.get(), input.get(), conv.get(), output.get(),
42                     CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
43                     0, /* no memory limit */
44                     &dalgo
45                 )
46             );
47
48             CUDA4DNN_CHECK_CUDNN(
49                 cudnnGetConvolutionBackwardDataWorkspaceSize(
50                     handle.get(),
51                     filter.get(), input.get(), conv.get(), output.get(),
52                     dalgo, &workspace_size
53                 )
54             );
55         }
56
57         TransposeConvolutionAlgorithm& operator=(const TransposeConvolutionAlgorithm&) = default;
58         TransposeConvolutionAlgorithm& operator=(TransposeConvolutionAlgorithm&& other) = default;
59
60         cudnnConvolutionBwdDataAlgo_t get() const noexcept { return dalgo; }
61
62         std::size_t get_workspace_size() const noexcept { return workspace_size; }
63
64     private:
65         cudnnConvolutionBwdDataAlgo_t dalgo;
66         std::size_t workspace_size;
67     };
68
69     /** @brief performs transpose convolution
70       *
71       * dstValue = alpha * result + beta * priorDstValue
72       *
73       * @tparam          T              transpose convolution element type (must be `half` or `float`)
74       *
75       * @param           handle         valid cuDNN Handle
76       * @param           convDesc       convolution description
77       * @param           transConvAlgo  algorithm to use for convolution
78       * @param           workspace      workspace memory which meets the requirements of \p convAlgo
79       * @param           filterDesc     filter descriptor
80       * @param[in]       filterPtr      pointer to device memory containing the filters
81       * @param           inputDesc      tensor descriptor describing the input
82       * @param[in]       inputPtr       pointer to input tensor in device memory
83       * @param           alpha          result scale factor
84       * @param           beta           previous value scale factor
85       * @param           outputDesc     tensor descriptor describing the output
86       * @param[out]      outputPtr      pointer to output tensor in device memory
87       *
88       * Exception Guarantee: Basic
89       */
90     template <class T>
91     void transpose_convolve(
92         const Handle& handle,
93         const ConvolutionDescriptor<T>& convDesc,
94         const TransposeConvolutionAlgorithm<T>& transConvAlgo,
95         WorkspaceInstance workspace,
96         const FilterDescriptor<T>& filterDesc,
97         DevicePtr<const T> filterPtr,
98         const TensorDescriptor<T>& inputDesc,
99         DevicePtr<const T> inputPtr,
100         T alpha, T beta,
101         const TensorDescriptor<T>& outputDesc,
102         DevicePtr<T> outputPtr)
103     {
104         CUDA4DNN_CHECK_CUDNN(
105             cudnnConvolutionBackwardData(
106                 handle.get(),
107                 &alpha,
108                 filterDesc.get(), filterPtr.get(),
109                 inputDesc.get(), inputPtr.get(),
110                 convDesc.get(), transConvAlgo.get(),
111                 static_cast<void*>(workspace.get()), workspace.size_in_bytes(),
112                 &beta, outputDesc.get(), outputPtr.get()
113             )
114         );
115     }
116
117     template <> inline
118     void transpose_convolve(
119         const Handle& handle,
120         const ConvolutionDescriptor<half>& convDesc,
121         const TransposeConvolutionAlgorithm<half>& convAlgo,
122         WorkspaceInstance workspace,
123         const FilterDescriptor<half>& filterDesc,
124         DevicePtr<const half> filterPtr,
125         const TensorDescriptor<half>& inputDesc,
126         DevicePtr<const half> inputPtr,
127         half alpha, half beta,
128         const TensorDescriptor<half>& outputDesc,
129         DevicePtr<half> outputPtr)
130     {
131         /* we specalize for fp16 as the scaling factors must be provided as `float` */
132         float alpha_ = alpha, beta_ = beta;
133         CUDA4DNN_CHECK_CUDNN(
134             cudnnConvolutionBackwardData(
135                 handle.get(),
136                 &alpha_,
137                 filterDesc.get(), filterPtr.get(),
138                 inputDesc.get(), inputPtr.get(),
139                 convDesc.get(), convAlgo.get(),
140                 static_cast<void*>(workspace.get()), workspace.size_in_bytes(),
141                 &beta_, outputDesc.get(), outputPtr.get()
142             )
143         );
144     }
145
146 }}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */
147
148 #endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_TRANSPOSE_CONVOLUTION_HPP */