1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
5 #ifndef OPENCV_DNN_SRC_CUDA_VECTOR_TRAITS_HPP
6 #define OPENCV_DNN_SRC_CUDA_VECTOR_TRAITS_HPP
8 #include <cuda_runtime.h>
12 #include "../cuda4dnn/csl/pointer.hpp"
14 #include <type_traits>
16 namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device {
18 /** \file vector_traits.hpp
19 * \brief utility classes and functions for vectorized memory loads/stores
22 * using vector_type = get_vector_type_t<float, 4>;
24 * auto input_vPtr = vector_type::get_pointer(iptr); // iptr is of type DevicePtr<const float>
25 * auto output_vPtr = vector_type::get_pointer(optr); // optr is of type DevicePtr<float>
28 * v_load(vec, input_vPtr);
30 * for(int i = 0; i < vector_type::size(); i++)
31 * vec[i] = do_something(vec[i]);
33 * v_store(output_vPtr, vec);
// Lookup from a width N to a CUDA built-in vector type used as raw storage.
// N is a width in bits: the raw_type wrapper checks sizeof(type) * 8 == N.
// The primary template is empty, so a width without a specialization below
// has no nested `type`.
template <size_type N>
struct raw_type_ { };

template <>
struct raw_type_<8> {
    using type = uchar1;
};

template <>
struct raw_type_<16> {
    using type = uchar2;
};

template <>
struct raw_type_<32> {
    using type = uint1;
};

template <>
struct raw_type_<64> {
    using type = uint2;
};

template <>
struct raw_type_<128> {
    using type = uint4;
};

template <>
struct raw_type_<256> {
    using type = ulonglong4;
};
// Resolves a width N to its raw storage type by forwarding to raw_type_<N>.
// The static_assert rejects any mapping whose size, in bits, is not exactly N.
// NOTE(review): the closing of this struct is not visible in this view.
45 template <size_type N> struct raw_type {
46 using type = typename raw_type_<N>::type;
47 static_assert(sizeof(type) * 8 == N, "");
51 /* \tparam T type of element in the vector
52 * \tparam N "number of elements" of type T in the vector
// NOTE(review): the line that names and opens this template's definition, and its
// closing braces, are not visible in this view; the comments below cover only the
// members that are shown.
54 template <class T, size_type N>
// Raw storage type selected by the total width in bits: N * sizeof(T) * 8.
57 using raw_type = typename detail::raw_type<N * sizeof(T) * 8>::type;
// Device-side default constructor with an empty body.
59 __device__ vector_type() { }
// Returns the element count N as a compile-time constant.
61 __device__ static constexpr size_type size() { return N; }
// Overload enabled when U is const-qualified: reinterprets ptr.get() as a
// pointer to const vector_type.
// NOTE(review): no alignment or size check is visible here; presumably the caller
// must guarantee the address is suitable for vector_type — confirm.
66 template <class U> static __device__
67 typename std::enable_if<std::is_const<U>::value, const vector_type*>
68 ::type get_pointer(csl::DevicePtr<U> ptr) {
69 return reinterpret_cast<const vector_type*>(ptr.get());
// Overload enabled when U is not const-qualified: reinterprets ptr.get() as a
// pointer to mutable vector_type. The same unchecked-cast caveat applies.
72 template <class U> static __device__
73 typename std::enable_if<!std::is_const<U>::value, vector_type*>
74 ::type get_pointer(csl::DevicePtr<U> ptr) {
75 return reinterpret_cast<vector_type*>(ptr.get());
// NOTE(review): only the signature lines of the v_load overloads are visible in
// this view; their template headers and bodies are not shown. Presumably each
// one copies the vector at `src` into `dest` — confirm against the bodies.
// Overload taking the source by const reference.
80 __device__ void v_load(V& dest, const V& src) {
// Overload taking the source as a pointer to const V.
85 __device__ void v_load(V& dest, const V* src) {
// NOTE(review): only the signature lines of the v_store overloads are visible in
// this view; their template headers and bodies are not shown. Presumably each
// one copies the vector `src` into the destination — confirm against the bodies.
// Overload taking the destination as a pointer to V.
90 __device__ void v_store(V* dest, const V& src) {
// Overload taking the destination by mutable reference.
95 __device__ void v_store(V& dest, const V& src) {
// Trait whose nested `type` names vector_type<T, N>.
// NOTE(review): the closing of this struct is not visible in this view.
99 template <class T, size_type N>
100 struct get_vector_type {
101 typedef vector_type<T, N> type;
// Shorthand alias: get_vector_type_t<T, N> names get_vector_type<T, N>::type,
// sparing callers the `typename ...::type` spelling.
template <class ElementT, size_type Count>
using get_vector_type_t = typename get_vector_type<ElementT, Count>::type;
107 }}}}} /* namespace cv::dnn::cuda4dnn::csl::device */
109 #endif /* OPENCV_DNN_SRC_CUDA_VECTOR_TRAITS_HPP */