1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
5 #ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_SOFTMAX_HPP
6 #define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_SOFTMAX_HPP
10 #include "../pointer.hpp"
14 namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn {
16 /** @brief computes softmax (or log softmax)
18 * @tparam T element type (must be `half` or `float`)
20 * @param handle valid cuDNN handle
21 * @param outputDesc tensor descriptor for `output`
22 * @param[out] output pointer to tensor in device memory
23 * @param inputDesc tensor descriptor for `input`
24 * @param[in] input pointer to tensor in device memory
25 * @param log apply log on probabilities
27 * Exception Guarantee: Basic
30 void softmax(const cudnn::Handle& handle,
31 const TensorDescriptor<T>& outputDesc, DevicePtr<T> output,
32 const TensorDescriptor<T>& inputDesc, DevicePtr<const T> input,
35 T alpha = 1.0, beta = 0.0; // scaling factors, declared with the element type T; passed by address below
36 cudnnSoftmaxAlgorithm_t algo = log ? CUDNN_SOFTMAX_LOG : CUDNN_SOFTMAX_ACCURATE; // `log` selects the LOG algorithm, otherwise ACCURATE
40 algo, CUDNN_SOFTMAX_MODE_CHANNEL, // NOTE(review): presumably softmax over the channel axis, per the cuDNN mode name - confirm
41 &alpha, inputDesc.get(), input.get(), // alpha travels with the input descriptor and input pointer
42 &beta, outputDesc.get(), output.get() // beta travels with the output descriptor and output pointer
48 void softmax(const cudnn::Handle& handle,
49 const TensorDescriptor<half>& outputDesc, DevicePtr<half> output,
50 const TensorDescriptor<half>& inputDesc, DevicePtr<const half> input,
53 /* we specialize for fp16 as the scaling factors must be provided as `float` */
54 float alpha = 1.0, beta = 0.0; // `float`, not `half`, for the reason given in the comment above
55 cudnnSoftmaxAlgorithm_t algo = log ? CUDNN_SOFTMAX_LOG : CUDNN_SOFTMAX_ACCURATE; // `log` selects the LOG algorithm, otherwise ACCURATE
59 algo, CUDNN_SOFTMAX_MODE_CHANNEL, // NOTE(review): presumably softmax over the channel axis, per the cuDNN mode name - confirm
60 &alpha, inputDesc.get(), input.get(), // alpha travels with the input descriptor and input pointer
61 &beta, outputDesc.get(), output.get() // beta travels with the output descriptor and output pointer
66 }}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */
68 #endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_SOFTMAX_HPP */