do not create redundant handles
authorYashasSamaga <yashas_2010@yahoo.com>
Tue, 19 May 2020 17:45:02 +0000 (23:15 +0530)
committerYashasSamaga <yashas_2010@yahoo.com>
Fri, 22 May 2020 14:22:20 +0000 (19:52 +0530)
modules/dnn/src/cuda4dnn/csl/cublas.hpp
modules/dnn/src/cuda4dnn/csl/cudnn/cudnn.hpp
modules/dnn/src/cuda4dnn/csl/stream.hpp

index 8320767..3cda3d6 100644 (file)
@@ -52,22 +52,30 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
         }
     }
 
-    /** noncopyable cuBLAS smart handle
+    /** non-copyable cuBLAS smart handle
      *
      * UniqueHandle is a smart non-sharable wrapper for cuBLAS handle which ensures that the handle
-     * is destroyed after use. The handle can be associated with a CUDA stream by specifying the
-     * stream during construction. By default, the handle is associated with the default stream.
+     * is destroyed after use. The handle must always be associated with a non-default stream. The stream
+     * must be specified during construction.
+     *
+     * Refer to stream API for more information for the choice of forcing non-default streams.
      */
     class UniqueHandle {
     public:
-        UniqueHandle() { CUDA4DNN_CHECK_CUBLAS(cublasCreate(&handle)); }
+        UniqueHandle() noexcept : handle{ nullptr } { }
         UniqueHandle(UniqueHandle&) = delete;
-        UniqueHandle(UniqueHandle&& other) noexcept
-            : stream(std::move(other.stream)), handle{ other.handle } {
+        UniqueHandle(UniqueHandle&& other) noexcept {
+            stream = std::move(other.stream);
+            handle = other.handle;
             other.handle = nullptr;
         }
 
+        /** creates a cuBLAS handle and associates it with the stream specified
+         *
+         * Exception Guarantee: Basic
+         */
         UniqueHandle(Stream strm) : stream(std::move(strm)) {
+            CV_Assert(stream);
             CUDA4DNN_CHECK_CUBLAS(cublasCreate(&handle));
             try {
                 CUDA4DNN_CHECK_CUBLAS(cublasSetStream(handle, stream.get()));
@@ -79,7 +87,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
         }
 
         ~UniqueHandle() noexcept {
-            if (handle != nullptr) {
+            if (handle) {
                 /* cublasDestroy won't throw if a valid handle is passed */
                 CUDA4DNN_CHECK_CUBLAS(cublasDestroy(handle));
             }
@@ -87,14 +95,24 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
 
         UniqueHandle& operator=(const UniqueHandle&) = delete;
         UniqueHandle& operator=(UniqueHandle&& other) noexcept {
-            stream = std::move(other.stream);
-            handle = other.handle;
-            other.handle = nullptr;
+            CV_Assert(other);
+            if (&other != this) {
+                UniqueHandle(std::move(*this)); /* destroy current handle */
+                stream = std::move(other.stream);
+                handle = other.handle;
+                other.handle = nullptr;
+            }
             return *this;
         }
 
-        /** @brief returns the raw cuBLAS handle */
-        cublasHandle_t get() const noexcept { return handle; }
+        /** returns the raw cuBLAS handle */
+        cublasHandle_t get() const noexcept {
+            CV_Assert(handle);
+            return handle;
+        }
+
+        /** returns true if the handle is valid */
+        explicit operator bool() const noexcept { return static_cast<bool>(handle); }
 
     private:
         Stream stream;
@@ -104,17 +122,21 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
     /** @brief sharable cuBLAS smart handle
      *
      * Handle is a smart sharable wrapper for cuBLAS handle which ensures that the handle
-     * is destroyed after all references to the handle are destroyed. The handle can be
-     * associated with a CUDA stream by specifying the stream during construction. By default,
-     * the handle is associated with the default stream.
+     * is destroyed after all references to the handle are destroyed. The handle must always
+     * be associated with a non-default stream. The stream must be specified during construction.
      *
      * @note Moving a Handle object to another invalidates the former
      */
     class Handle {
     public:
-        Handle() : handle(std::make_shared<UniqueHandle>()) { }
+        Handle() = default;
         Handle(const Handle&) = default;
         Handle(Handle&&) = default;
+
+        /** creates a cuBLAS handle and associates it with the stream specified
+         *
+         * Exception Guarantee: Basic
+         */
         Handle(Stream strm) : handle(std::make_shared<UniqueHandle>(std::move(strm))) { }
 
         Handle& operator=(const Handle&) = default;
@@ -123,6 +145,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
         /** returns true if the handle is valid */
         explicit operator bool() const noexcept { return static_cast<bool>(handle); }
 
+        /** returns the raw cuBLAS handle */
         cublasHandle_t get() const noexcept {
             CV_Assert(handle);
             return handle->get();
index 13ecc1a..abfbc6b 100644 (file)
@@ -58,15 +58,11 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
      */
     class UniqueHandle {
     public:
-        /** creates a cuDNN handle which executes in the default stream
-         *
-         * Exception Guarantee: Basic
-         */
-        UniqueHandle() { CUDA4DNN_CHECK_CUDNN(cudnnCreate(&handle)); }
-
+        UniqueHandle() noexcept : handle{ nullptr } { }
         UniqueHandle(UniqueHandle&) = delete;
-        UniqueHandle(UniqueHandle&& other) noexcept
-            : stream(std::move(other.stream)), handle{ other.handle } {
+        UniqueHandle(UniqueHandle&& other) noexcept {
+            stream = std::move(other.stream);
+            handle = other.handle;
             other.handle = nullptr;
         }
 
@@ -75,6 +71,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
          * Exception Guarantee: Basic
          */
         UniqueHandle(Stream strm) : stream(std::move(strm)) {
+            CV_Assert(stream);
             CUDA4DNN_CHECK_CUDNN(cudnnCreate(&handle));
             try {
                 CUDA4DNN_CHECK_CUDNN(cudnnSetStream(handle, stream.get()));
@@ -94,14 +91,24 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
 
         UniqueHandle& operator=(const UniqueHandle&) = delete;
         UniqueHandle& operator=(UniqueHandle&& other) noexcept {
-            stream = std::move(other.stream);
-            handle = other.handle;
-            other.handle = nullptr;
+            CV_Assert(other);
+            if (&other != this) {
+                UniqueHandle(std::move(*this)); /* destroy current handle */
+                stream = std::move(other.stream);
+                handle = other.handle;
+                other.handle = nullptr;
+            }
             return *this;
         }
 
         /** returns the raw cuDNN handle */
-        cudnnHandle_t get() const noexcept { return handle; }
+        cudnnHandle_t get() const noexcept {
+            CV_Assert(handle);
+            return handle;
+        }
+
+        /** returns true if the handle is valid */
+        explicit operator bool() const noexcept { return static_cast<bool>(handle); }
 
     private:
         Stream stream;
@@ -111,18 +118,14 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
     /** @brief sharable cuDNN smart handle
      *
      * Handle is a smart sharable wrapper for cuDNN handle which ensures that the handle
-     * is destroyed after all references to the handle are destroyed.
+     * is destroyed after all references to the handle are destroyed. The handle must always
+     * be associated with a non-default stream. The stream must be specified during construction.
      *
      * @note Moving a Handle object to another invalidates the former
      */
     class Handle {
     public:
-        /** creates a cuDNN handle which executes in the default stream
-         *
-         * Exception Guarantee: Basic
-         */
-        Handle() : handle(std::make_shared<UniqueHandle>()) { }
-
+        Handle() = default;
         Handle(const Handle&) = default;
         Handle(Handle&&) = default;
 
@@ -138,6 +141,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
         /** returns true if the handle is valid */
         explicit operator bool() const noexcept { return static_cast<bool>(handle); }
 
+        /** returns the raw cuDNN handle */
         cudnnHandle_t get() const noexcept {
             CV_Assert(handle);
             return handle->get();
index 0a1d804..f715983 100644 (file)
 
 namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
 
-    /** @brief noncopyable smart CUDA stream
+    /** \file stream.hpp
+     *
+     * Default streams are not supported as they limit flexiblity. All operations are always
+     * carried out in non-default streams in the CUDA backend. The stream classes sacrifice
+     * the ability to support default streams in exchange for better error detection. That is,
+     * a default constructed stream represents no stream and any attempt to use it will throw an
+     * exception.
+     */
+
+    /** @brief non-copyable smart CUDA stream
      *
      * UniqueStream is a smart non-sharable wrapper for CUDA stream handle which ensures that
      * the handle is destroyed after use. Unless explicitly specified by a constructor argument,
-     * the stream object represents the default stream.
+     * the stream object does not represent any stream by default.
      */
     class UniqueStream {
     public:
@@ -33,14 +42,19 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
             other.stream = 0;
         }
 
+        /** creates a non-default stream if `create` is true; otherwise, no stream is created */
         UniqueStream(bool create) : stream{ 0 } {
             if (create) {
+                /* we create non-blocking streams to avoid inrerruptions from users using the default stream */
                 CUDA4DNN_CHECK_CUDA(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
             }
         }
 
         ~UniqueStream() {
             try {
+                /* cudaStreamDestroy does not throw if a valid stream is passed unless a previous
+                 * asynchronous operation errored.
+                 */
                 if (stream != 0)
                     CUDA4DNN_CHECK_CUDA(cudaStreamDestroy(stream));
             } catch (const CUDAException& ex) {
@@ -54,16 +68,31 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
 
         UniqueStream& operator=(const UniqueStream&) = delete;
         UniqueStream& operator=(UniqueStream&& other) noexcept {
-            stream = other.stream;
-            other.stream = 0;
+            CV_Assert(other);
+            if (&other != this) {
+                UniqueStream(std::move(*this)); /* destroy current stream */
+                stream = other.stream;
+                other.stream = 0;
+            }
             return *this;
         }
 
         /** returns the raw CUDA stream handle */
-        cudaStream_t get() const noexcept { return stream; }
+        cudaStream_t get() const noexcept {
+            CV_Assert(stream);
+            return stream;
+        }
 
-        void synchronize() const { CUDA4DNN_CHECK_CUDA(cudaStreamSynchronize(stream)); }
+        /** blocks the calling thread until all pending operations in the stream finish */
+        void synchronize() const {
+            CV_Assert(stream);
+            CUDA4DNN_CHECK_CUDA(cudaStreamSynchronize(stream));
+        }
+
+        /** returns true if there are pending operations in the stream */
         bool busy() const {
+            CV_Assert(stream);
+
             auto status = cudaStreamQuery(stream);
             if (status == cudaErrorNotReady)
                 return true;
@@ -71,6 +100,9 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
             return false;
         }
 
+        /** returns true if the stream is valid */
+        explicit operator bool() const noexcept { return static_cast<bool>(stream); }
+
     private:
         cudaStream_t stream;
     };
@@ -78,31 +110,42 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
     /** @brief sharable smart CUDA stream
      *
      * Stream is a smart sharable wrapper for CUDA stream handle which ensures that
-     * the handle is destroyed after use. Unless explicitly specified by a constructor argument,
-     * the stream object represents the default stream.
-     *
-     * @note Moving a Stream object to another invalidates the former
+     * the handle is destroyed after use. Unless explicitly specified in the constructor,
+     * the stream object represents no stream.
      */
     class Stream {
     public:
-        Stream() : stream(std::make_shared<UniqueStream>()) { }
+        Stream() { }
         Stream(const Stream&) = default;
         Stream(Stream&&) = default;
 
-        /** if \p create is `true`, a new stream will be created instead of the otherwise default stream */
-        Stream(bool create) : stream(std::make_shared<UniqueStream>(create)) { }
+        /** if \p create is `true`, a new stream will be created; otherwise, no stream is created */
+        Stream(bool create) {
+            if (create)
+                stream = std::make_shared<UniqueStream>(create);
+        }
 
         Stream& operator=(const Stream&) = default;
         Stream& operator=(Stream&&) = default;
 
         /** blocks the caller thread until all operations in the stream are complete */
-        void synchronize() const { stream->synchronize(); }
+        void synchronize() const {
+            CV_Assert(stream);
+            stream->synchronize();
+        }
 
         /** returns true if there are operations pending in the stream */
-        bool busy() const { return stream->busy(); }
+        bool busy() const {
+            CV_Assert(stream);
+            return stream->busy();
+        }
 
-        /** returns true if the stream is valid */
-        explicit operator bool() const noexcept { return static_cast<bool>(stream); }
+        /** returns true if the object points has a valid stream */
+        explicit operator bool() const noexcept {
+            if (!stream)
+                return false;
+            return stream->operator bool();
+        }
 
         cudaStream_t get() const noexcept {
             CV_Assert(stream);