Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low
[platform/upstream/opencv.git] / modules / dnn / src / cuda4dnn / csl / cudnn / softmax.hpp
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 #ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_SOFTMAX_HPP
6 #define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_SOFTMAX_HPP
7
8 #include "cudnn.hpp"
9
10 #include "../pointer.hpp"
11
12 #include <cudnn.h>
13
14 namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn {
15
16     /** @brief computes softmax (or log softmax)
17      *
18      * @tparam          T           element type (must be `half` or `float`)
19      *
20      * @param           handle      valid cuDNN handle
21      * @param           outputDesc  tensor descriptor for A
22      * @param[out]      output      pointer to tensor in device memory
23      * @param           inputDesc   tensor descriptor for C
24      * @param[in]       input       pointer to tensor in device memory
25      * @param           log         apply log on probabilities
26      *
27      * Exception Guarantee: Basic
28      */
29     template <class T>
30     void softmax(const cudnn::Handle& handle,
31         const TensorDescriptor<T>& outputDesc, DevicePtr<T> output,
32         const TensorDescriptor<T>& inputDesc, DevicePtr<const T> input,
33         bool log)
34     {
35         T alpha = 1.0, beta = 0.0;
36         cudnnSoftmaxAlgorithm_t algo = log ? CUDNN_SOFTMAX_LOG : CUDNN_SOFTMAX_ACCURATE;
37         CUDA4DNN_CHECK_CUDNN(
38             cudnnSoftmaxForward(
39                 handle.get(),
40                 algo, CUDNN_SOFTMAX_MODE_CHANNEL,
41                 &alpha, inputDesc.get(), input.get(),
42                 &beta, outputDesc.get(), output.get()
43             )
44         );
45     }
46
47     template <> inline
48     void softmax(const cudnn::Handle& handle,
49         const TensorDescriptor<half>& outputDesc, DevicePtr<half> output,
50         const TensorDescriptor<half>& inputDesc, DevicePtr<const half> input,
51         bool log)
52     {
53         /* we specalize for fp16 as the scaling factors must be provided as `float` */
54         float alpha = 1.0, beta = 0.0;
55         cudnnSoftmaxAlgorithm_t algo = log ? CUDNN_SOFTMAX_LOG : CUDNN_SOFTMAX_ACCURATE;
56         CUDA4DNN_CHECK_CUDNN(
57             cudnnSoftmaxForward(
58                 handle.get(),
59                 algo, CUDNN_SOFTMAX_MODE_CHANNEL,
60                 &alpha, inputDesc.get(), input.get(),
61                 &beta, outputDesc.get(), output.get()
62             )
63         );
64     }
65
66 }}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */
67
68 #endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_SOFTMAX_HPP */