1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
5 #ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_TRANSPOSE_CONVOLUTION_HPP
6 #define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_TRANSPOSE_CONVOLUTION_HPP
9 #include "convolution.hpp"
11 #include "../pointer.hpp"
12 #include "../workspace.hpp"
18 namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn {
20 /** wrapper around a transpose convolution algorithm
22 * @tparam T type of elements being transpose-convolved
25 class TransposeConvolutionAlgorithm {
27 TransposeConvolutionAlgorithm() noexcept : workspace_size{ 0 } { }
28 TransposeConvolutionAlgorithm(TransposeConvolutionAlgorithm&) = default;
29 TransposeConvolutionAlgorithm(TransposeConvolutionAlgorithm&&) = default;
31 TransposeConvolutionAlgorithm(
33 const ConvolutionDescriptor<T>& conv,
34 const FilterDescriptor<T>& filter,
35 const TensorDescriptor<T>& input,
36 const TensorDescriptor<T>& output)
39 cudnnGetConvolutionBackwardDataAlgorithm(
41 filter.get(), input.get(), conv.get(), output.get(),
42 CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
43 0, /* no memory limit */
49 cudnnGetConvolutionBackwardDataWorkspaceSize(
51 filter.get(), input.get(), conv.get(), output.get(),
52 dalgo, &workspace_size
57 TransposeConvolutionAlgorithm& operator=(const TransposeConvolutionAlgorithm&) = default;
58 TransposeConvolutionAlgorithm& operator=(TransposeConvolutionAlgorithm&& other) = default;
60 cudnnConvolutionBwdDataAlgo_t get() const noexcept { return dalgo; }
62 std::size_t get_workspace_size() const noexcept { return workspace_size; }
65 cudnnConvolutionBwdDataAlgo_t dalgo;
66 std::size_t workspace_size;
69 /** @brief performs transpose convolution
71 * dstValue = alpha * result + beta * priorDstValue
73 * @tparam T transpose convolution element type (must be `half` or `float`)
75 * @param handle valid cuDNN Handle
76 * @param convDesc convolution description
77 * @param transConvAlgo algorithm to use for convolution
78 * @param workspace workspace memory which meets the requirements of \p convAlgo
79 * @param filterDesc filter descriptor
80 * @param[in] filterPtr pointer to device memory containing the filters
81 * @param inputDesc tensor descriptor describing the input
82 * @param[in] inputPtr pointer to input tensor in device memory
83 * @param alpha result scale factor
84 * @param beta previous value scale factor
85 * @param outputDesc tensor descriptor describing the output
86 * @param[out] outputPtr pointer to output tensor in device memory
88 * Exception Guarantee: Basic
91 void transpose_convolve(
93 const ConvolutionDescriptor<T>& convDesc,
94 const TransposeConvolutionAlgorithm<T>& transConvAlgo,
95 WorkspaceInstance workspace,
96 const FilterDescriptor<T>& filterDesc,
97 DevicePtr<const T> filterPtr,
98 const TensorDescriptor<T>& inputDesc,
99 DevicePtr<const T> inputPtr,
101 const TensorDescriptor<T>& outputDesc,
102 DevicePtr<T> outputPtr)
104 CUDA4DNN_CHECK_CUDNN(
105 cudnnConvolutionBackwardData(
108 filterDesc.get(), filterPtr.get(),
109 inputDesc.get(), inputPtr.get(),
110 convDesc.get(), transConvAlgo.get(),
111 static_cast<void*>(workspace.get()), workspace.size_in_bytes(),
112 &beta, outputDesc.get(), outputPtr.get()
118 void transpose_convolve(
119 const Handle& handle,
120 const ConvolutionDescriptor<half>& convDesc,
121 const TransposeConvolutionAlgorithm<half>& convAlgo,
122 WorkspaceInstance workspace,
123 const FilterDescriptor<half>& filterDesc,
124 DevicePtr<const half> filterPtr,
125 const TensorDescriptor<half>& inputDesc,
126 DevicePtr<const half> inputPtr,
127 half alpha, half beta,
128 const TensorDescriptor<half>& outputDesc,
129 DevicePtr<half> outputPtr)
131 /* we specalize for fp16 as the scaling factors must be provided as `float` */
132 float alpha_ = alpha, beta_ = beta;
133 CUDA4DNN_CHECK_CUDNN(
134 cudnnConvolutionBackwardData(
137 filterDesc.get(), filterPtr.get(),
138 inputDesc.get(), inputPtr.get(),
139 convDesc.get(), convAlgo.get(),
140 static_cast<void*>(workspace.get()), workspace.size_in_bytes(),
141 &beta_, outputDesc.get(), outputPtr.get()
146 }}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */
148 #endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_TRANSPOSE_CONVOLUTION_HPP */