add cuDNN 8 support
author YashasSamaga <yashas_2010@yahoo.com>
Tue, 30 Jun 2020 16:21:23 +0000 (21:51 +0530)
committer YashasSamaga <yashas_2010@yahoo.com>
Tue, 30 Jun 2020 16:21:23 +0000 (21:51 +0530)
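
cuDNN 8 removed the cudnnGetConvolutionForwardAlgorithm and
cudnnGetConvolutionBackwardDataAlgorithm APIs. The convolution and
transpose convolution algorithm selectors now enumerate the candidates
with the corresponding *_v7 APIs and pick the fastest algorithm whose
workspace fits in the free device memory. cuDNN 8 also allows TF32
Tensor Ops in its default math mode; the math type is forced to
CUDNN_FMA_MATH so that FP32 convolutions retain the precision of
cuDNN 7 and below. The CMake guard that disabled OPENCV_DNN_CUDA with
cuDNN 8 is dropped.
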
modules/dnn/CMakeLists.txt
modules/dnn/src/cuda4dnn/csl/cudnn/convolution.hpp
modules/dnn/src/cuda4dnn/csl/cudnn/transpose_convolution.hpp

diff --git a/modules/dnn/CMakeLists.txt b/modules/dnn/CMakeLists.txt
index ed15575..ff9d17f 100644
@@ -21,14 +21,10 @@ if(OPENCV_DNN_OPENCL AND HAVE_OPENCL)
   add_definitions(-DCV_OCL4DNN=1)
 endif()
 
-if(NOT DEFINED OPENCV_DNN_CUDA AND HAVE_CUDNN AND CUDNN_VERSION VERSION_LESS 8.0)
-  message(STATUS "DNN: CUDNN 8.0 is not supported yet. Details: https://github.com/opencv/opencv/issues/17496")
-endif()
 ocv_option(OPENCV_DNN_CUDA "Build with CUDA support"
     HAVE_CUDA
     AND HAVE_CUBLAS
     AND HAVE_CUDNN
-    AND CUDNN_VERSION VERSION_LESS 8.0
 )
 
 if(OPENCV_DNN_CUDA AND HAVE_CUDA AND HAVE_CUBLAS AND HAVE_CUDNN)
diff --git a/modules/dnn/src/cuda4dnn/csl/cudnn/convolution.hpp b/modules/dnn/src/cuda4dnn/csl/cudnn/convolution.hpp
index 46463b6..cad4b29 100644
@@ -225,6 +225,15 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
                     );
                 }
                 CUDA4DNN_CHECK_CUDNN(cudnnSetConvolutionGroupCount(descriptor, group_count));
+
+#if CUDNN_MAJOR >= 8
+                /* cuDNN 7 and below use FMA math by default. cuDNN 8's default math mode
+                 * also allows TF32 Tensor Ops, which have lower precision than FP32.
+                 * Hence, we set the math type to CUDNN_FMA_MATH to reproduce the old behavior.
+                 */
+                CUDA4DNN_CHECK_CUDNN(cudnnSetConvolutionMathType(descriptor, CUDNN_FMA_MATH));
+#endif
+
                 if (std::is_same<T, half>::value)
                     CUDA4DNN_CHECK_CUDNN(cudnnSetConvolutionMathType(descriptor, CUDNN_TENSOR_OP_MATH));
             } catch (...) {
@@ -254,15 +263,52 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
          */
         ConvolutionAlgorithm(
             const Handle& handle,
-            const ConvolutionDescriptor<T>& conv,
-            const FilterDescriptor<T>& filter,
-            const TensorDescriptor<T>& input,
-            const TensorDescriptor<T>& output)
+            const ConvolutionDescriptor<T>& convDesc,
+            const FilterDescriptor<T>& filterDesc,
+            const TensorDescriptor<T>& inputDesc,
+            const TensorDescriptor<T>& outputDesc)
         {
+#if CUDNN_MAJOR >= 8
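+            /* cuDNN 8 removed cudnnGetConvolutionForwardAlgorithm; enumerate the candidates
+             * with the v7 API and pick the fastest algorithm that fits in free device memory.
+             */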
+            int requestedAlgoCount = 0, returnedAlgoCount = 0;
+            CUDA4DNN_CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithmMaxCount(handle.get(), &requestedAlgoCount));
+            std::vector<cudnnConvolutionFwdAlgoPerf_t> results(requestedAlgoCount);
+            CUDA4DNN_CHECK_CUDNN(
+                cudnnGetConvolutionForwardAlgorithm_v7(
+                    handle.get(),
+                    inputDesc.get(), filterDesc.get(), convDesc.get(), outputDesc.get(),
+                    requestedAlgoCount,
+                    &returnedAlgoCount,
+                    &results[0]
+                )
+            );
+
+            size_t free_memory, total_memory;
+            CUDA4DNN_CHECK_CUDA(cudaMemGetInfo(&free_memory, &total_memory));
+
+            bool found_conv_algorithm = false;
+            for (int i = 0; i < returnedAlgoCount; i++)
+            {
+                if (results[i].status == CUDNN_STATUS_SUCCESS &&
+                    results[i].algo != CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
+                    results[i].memory < free_memory)
+                {
+                    found_conv_algorithm = true;
+                    algo = results[i].algo;
+                    workspace_size = results[i].memory;
+                    break;
+                }
+            }
+
+            if (!found_conv_algorithm)
+                CV_Error(cv::Error::GpuApiCallError, "cuDNN did not return a suitable algorithm for convolution.");
+#else
             CUDA4DNN_CHECK_CUDNN(
                 cudnnGetConvolutionForwardAlgorithm(
                     handle.get(),
-                    input.get(), filter.get(), conv.get(), output.get(),
+                    inputDesc.get(), filterDesc.get(), convDesc.get(), outputDesc.get(),
                     CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
                     0, /* no memory limit */
                     &algo
@@ -272,10 +318,11 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
             CUDA4DNN_CHECK_CUDNN(
                 cudnnGetConvolutionForwardWorkspaceSize(
                     handle.get(),
-                    input.get(), filter.get(), conv.get(), output.get(),
+                    inputDesc.get(), filterDesc.get(), convDesc.get(), outputDesc.get(),
                     algo, &workspace_size
                 )
             );
+#endif
         }
 
         ConvolutionAlgorithm& operator=(const ConvolutionAlgorithm&) = default;
diff --git a/modules/dnn/src/cuda4dnn/csl/cudnn/transpose_convolution.hpp b/modules/dnn/src/cuda4dnn/csl/cudnn/transpose_convolution.hpp
index d1d26aa..e1596b9 100644
@@ -30,15 +30,52 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
 
         TransposeConvolutionAlgorithm(
             const Handle& handle,
-            const ConvolutionDescriptor<T>& conv,
-            const FilterDescriptor<T>& filter,
-            const TensorDescriptor<T>& input,
-            const TensorDescriptor<T>& output)
+            const ConvolutionDescriptor<T>& convDesc,
+            const FilterDescriptor<T>& filterDesc,
+            const TensorDescriptor<T>& inputDesc,
+            const TensorDescriptor<T>& outputDesc)
         {
+#if CUDNN_MAJOR >= 8
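+            /* cuDNN 8 removed cudnnGetConvolutionBackwardDataAlgorithm; enumerate the candidates
+             * with the v7 API and pick the fastest algorithm that fits in free device memory.
+             */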
+            int requestedAlgoCount = 0, returnedAlgoCount = 0;
+            CUDA4DNN_CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(handle.get(), &requestedAlgoCount));
+            std::vector<cudnnConvolutionBwdDataAlgoPerf_t> results(requestedAlgoCount);
+            CUDA4DNN_CHECK_CUDNN(
+                cudnnGetConvolutionBackwardDataAlgorithm_v7(
+                    handle.get(),
+                    filterDesc.get(), inputDesc.get(), convDesc.get(), outputDesc.get(),
+                    requestedAlgoCount,
+                    &returnedAlgoCount,
+                    &results[0]
+                )
+            );
+
+            size_t free_memory, total_memory;
+            CUDA4DNN_CHECK_CUDA(cudaMemGetInfo(&free_memory, &total_memory));
+
+            bool found_conv_algorithm = false;
+            for (int i = 0; i < returnedAlgoCount; i++)
+            {
+                if (results[i].status == CUDNN_STATUS_SUCCESS &&
+                    results[i].algo != CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED &&
+                    results[i].memory < free_memory)
+                {
+                    found_conv_algorithm = true;
+                    dalgo = results[i].algo;
+                    workspace_size = results[i].memory;
+                    break;
+                }
+            }
+
+            if (!found_conv_algorithm)
+                CV_Error(cv::Error::GpuApiCallError, "cuDNN did not return a suitable algorithm for transpose convolution.");
+#else
             CUDA4DNN_CHECK_CUDNN(
                 cudnnGetConvolutionBackwardDataAlgorithm(
                     handle.get(),
-                    filter.get(), input.get(), conv.get(), output.get(),
+                    filterDesc.get(), inputDesc.get(), convDesc.get(), outputDesc.get(),
                     CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
                     0, /* no memory limit */
                     &dalgo
@@ -48,10 +85,11 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
             CUDA4DNN_CHECK_CUDNN(
                 cudnnGetConvolutionBackwardDataWorkspaceSize(
                     handle.get(),
-                    filter.get(), input.get(), conv.get(), output.get(),
+                    filterDesc.get(), inputDesc.get(), convDesc.get(), outputDesc.get(),
                     dalgo, &workspace_size
                 )
             );
+#endif
         }
 
         TransposeConvolutionAlgorithm& operator=(const TransposeConvolutionAlgorithm&) = default;
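
A condensed, self-contained sketch of the algorithm-selection strategy that
both cuDNN 8 code paths above implement, shown here for the forward case.
The function name pickConvolutionFwdAlgo and the exception-based error
handling are illustrative stand-ins, not part of the patch; status checks on
the cuDNN/CUDA calls are omitted for brevity.

    #include <cudnn.h>
    #include <cuda_runtime.h>
    #include <cstddef>
    #include <vector>
    #include <stdexcept>

    /* Pick the fastest forward convolution algorithm whose workspace fits in
     * the currently free device memory. */
    cudnnConvolutionFwdAlgo_t pickConvolutionFwdAlgo(
        cudnnHandle_t handle,
        cudnnTensorDescriptor_t inputDesc,
        cudnnFilterDescriptor_t filterDesc,
        cudnnConvolutionDescriptor_t convDesc,
        cudnnTensorDescriptor_t outputDesc,
        std::size_t& workspace_size)
    {
        int requestedAlgoCount = 0, returnedAlgoCount = 0;
        cudnnGetConvolutionForwardAlgorithmMaxCount(handle, &requestedAlgoCount);

        /* the v7 API returns candidates sorted by expected execution time */
        std::vector<cudnnConvolutionFwdAlgoPerf_t> results(requestedAlgoCount);
        cudnnGetConvolutionForwardAlgorithm_v7(
            handle,
            inputDesc, filterDesc, convDesc, outputDesc,
            requestedAlgoCount, &returnedAlgoCount, results.data());

        std::size_t free_memory = 0, total_memory = 0;
        cudaMemGetInfo(&free_memory, &total_memory);

        for (int i = 0; i < returnedAlgoCount; i++)
        {
            /* skip failed candidates, the Winograd variant the patch excludes,
             * and algorithms whose workspace does not fit in free memory */
            if (results[i].status == CUDNN_STATUS_SUCCESS &&
                results[i].algo != CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
                results[i].memory < free_memory)
            {
                workspace_size = results[i].memory;
                return results[i].algo;
            }
        }
        throw std::runtime_error("cuDNN did not return a suitable algorithm for convolution");
    }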