From: Yashas Samaga B L
Date: Mon, 21 Oct 2019 11:28:00 +0000 (+0530)
Subject: Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low
X-Git-Tag: accepted/tizen/6.0/unified/20201030.111113~1^2~78
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=613c12e59015f4bd7909916ceee195edd7ef88d0;p=platform%2Fupstream%2Fopencv.git

Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low

CUDA backend for the DNN module

* stub cuda4dnn design
* minor fixes for tests and doxygen
* add csl public api directory to module headers
* add low-level CSL components
* add high-level CSL components
* integrate csl::Tensor into backbone code
* switch to CPU iff unsupported; otherwise, fail on error
* add fully connected layer
* add softmax layer
* add activation layers
* support arbitrary rank TensorDescriptor
* pass input wrappers to `initCUDA()`
* add 1d/2d/3d-convolution
* add pooling layer
* reorganize and refactor code
* fixes for gcc, clang and doxygen; remove cxx14/17 code
* add blank_layer
* add LRN layer
* add rounding modes for pooling layer
* split tensor.hpp into tensor.hpp and tensor_ops.hpp
* add concat layer
* add scale layer
* add batch normalization layer
* split math.cu into activations.cu and math.hpp
* add eltwise layer
* add flatten layer
* add tensor transform api
* add asymmetric padding support for convolution layer
* add reshape layer
* fix rebase issues
* add permute layer
* add padding support for concat layer
* refactor and reorganize code
* add normalize layer
* optimize bias addition in scale layer
* add prior box layer
* fix and optimize normalize layer
* add asymmetric padding support for pooling layer
* add event API
* improve pooling performance for some padding scenarios
* avoid over-allocation of compute resources to kernels
* improve prior box performance
* enable layer fusion
* add const layer
* add resize layer
* add slice layer
* add padding layer
* add deconvolution layer
* fix channelwise ReLU initialization
* add vector traits
* add vectorized versions of relu, clipped_relu, power
* add vectorized concat kernels
* improve concat_with_offsets performance
* vectorize scale and bias kernels
* add support for multi-billion element tensors
* vectorize prior box kernels
* fix address alignment check
* improve bias addition performance of conv/deconv/fc layers
* restructure code for supporting multiple targets
* add DNN_TARGET_CUDA_FP64
* add DNN_TARGET_FP16
* improve vectorization
* add region layer
* improve tensor API, add dynamic ranks
  1. use ManagedPtr instead of a Tensor in backend wrapper
  2. add new methods to tensor classes
     - size_range: computes the combined size for a given axis range
     - tensor span/view can be constructed from a raw pointer and shape
  3. the tensor classes can change their rank at runtime (previously rank was fixed at compile-time)
  4. remove device code from tensor classes (as they are unused)
  5.
enforce strict conditions on tensor class APIs to improve debugging ability * fix parametric relu activation * add squeeze/unsqueeze tensor API * add reorg layer * optimize permute and enable 2d permute * enable 1d and 2d slice * add split layer * add shuffle channel layer * allow tensors of different ranks in reshape primitive * patch SliceOp to allow Crop Layer * allow extra shape inputs in reshape layer * use `std::move_backward` instead of `std::move` for insert in resizable_static_array * improve workspace management * add spatial LRN * add nms (cpu) to region layer * add max pooling with argmax ( and a fix to limits.hpp) * add max unpooling layer * rename DNN_TARGET_CUDA_FP32 to DNN_TARGET_CUDA * update supportBackend to be more rigorous * remove stray include from preventing non-cuda build * include op_cuda.hpp outside condition #if * refactoring, fixes and many optimizations * drop DNN_TARGET_CUDA_FP64 * fix gcc errors * increase max. tensor rank limit to six * add Interp layer * drop custom layers; use BackendNode * vectorize activation kernels * fixes for gcc * remove wrong assertion * fix broken assertion in unpooling primitive * fix build errors in non-CUDA build * completely remove workspace from public API * fix permute layer * enable accuracy and perf. tests for DNN_TARGET_CUDA * add asynchronous forward * vectorize eltwise ops * vectorize fill kernel * fixes for gcc * remove CSL headers from public API * remove csl header source group from cmake * update min. cudnn version in cmake * add numerically stable FP32 log1pexp * refactor code * add FP16 specialization to cudnn based tensor addition * vectorize scale1 and bias1 + minor refactoring * fix doxygen build * fix invalid alignment assertion * clear backend wrappers before allocateLayers * ignore memory lock failures * do not allocate internal blobs * integrate NVTX * add numerically stable half precision log1pexp * fix indentation, following coding style, improve docs * remove accidental modification of IE code * Revert "add asynchronous forward" This reverts commit 1154b9da9da07e9b52f8a81bdcea48cf31c56f70. * [cmake] throw error for unsupported CC versions * fix rebase issues * add more docs, refactor code, fix bugs * minor refactoring and fixes * resolve warnings/errors from clang * remove haveCUDA() checks from supportBackend() * remove NVTX integration * changes based on review comments * avoid exception when no CUDA device is present * add color code for CUDA in Net::dump --- diff --git a/cmake/OpenCVMinDepVersions.cmake b/cmake/OpenCVMinDepVersions.cmake index a57f2a4..ce0c0ba 100644 --- a/cmake/OpenCVMinDepVersions.cmake +++ b/cmake/OpenCVMinDepVersions.cmake @@ -2,7 +2,7 @@ if(NOT DEFINED MIN_VER_CMAKE) set(MIN_VER_CMAKE 3.5.1) endif() set(MIN_VER_CUDA 6.5) -set(MIN_VER_CUDNN 6) +set(MIN_VER_CUDNN 7.5) set(MIN_VER_PYTHON2 2.7) set(MIN_VER_PYTHON3 3.2) set(MIN_VER_ZLIB 1.2.3) diff --git a/modules/dnn/CMakeLists.txt b/modules/dnn/CMakeLists.txt index fa6eadf..9381453 100644 --- a/modules/dnn/CMakeLists.txt +++ b/modules/dnn/CMakeLists.txt @@ -90,6 +90,14 @@ endif() if(OPENCV_DNN_CUDA AND HAVE_CUDA AND HAVE_CUBLAS AND HAVE_CUDNN) list(APPEND include_dirs ${CUDA_TOOLKIT_INCLUDE} ${CUDNN_INCLUDE_DIRS}) + set(CC_LIST ${CUDA_ARCH_BIN}) + separate_arguments(CC_LIST) + foreach(cc ${CC_LIST}) + if(cc VERSION_LESS 5.3) + message(FATAL_ERROR "CUDA backend for DNN module requires CC 5.3 or higher. 
Please remove unsupported architectures from CUDA_ARCH_BIN option.") + endif() + endforeach() + unset(CC_LIST) else() set(sources_options ${sources_options} EXCLUDE_CUDA) endif() diff --git a/modules/dnn/include/opencv2/dnn/dnn.hpp b/modules/dnn/include/opencv2/dnn/dnn.hpp index 8b69b76..2c000b2 100644 --- a/modules/dnn/include/opencv2/dnn/dnn.hpp +++ b/modules/dnn/include/opencv2/dnn/dnn.hpp @@ -71,7 +71,8 @@ CV__DNN_INLINE_NS_BEGIN DNN_BACKEND_HALIDE, DNN_BACKEND_INFERENCE_ENGINE, //!< Intel's Inference Engine computational backend. DNN_BACKEND_OPENCV, - DNN_BACKEND_VKCOM + DNN_BACKEND_VKCOM, + DNN_BACKEND_CUDA }; /** @@ -85,7 +86,9 @@ CV__DNN_INLINE_NS_BEGIN DNN_TARGET_OPENCL_FP16, DNN_TARGET_MYRIAD, DNN_TARGET_VULKAN, - DNN_TARGET_FPGA //!< FPGA device with CPU fallbacks using Inference Engine's Heterogeneous plugin. + DNN_TARGET_FPGA, //!< FPGA device with CPU fallbacks using Inference Engine's Heterogeneous plugin. + DNN_TARGET_CUDA, + DNN_TARGET_CUDA_FP16 }; CV_EXPORTS std::vector< std::pair > getAvailableBackends(); @@ -274,6 +277,20 @@ CV__DNN_INLINE_NS_BEGIN virtual Ptr initInfEngine(const std::vector > &inputs); virtual Ptr initVkCom(const std::vector > &inputs); + + /** + * @brief Returns a CUDA backend node + * + * @param context void pointer to CSLContext object + * @param inputs layer inputs + * @param outputs layer outputs + */ + virtual Ptr initCUDA( + void *context, + const std::vector>& inputs, + const std::vector>& outputs + ); + /** * @brief Automatic Halide scheduling based on layer hyper-parameters. * @param[in] node Backend node with Halide functions. @@ -515,13 +532,15 @@ CV__DNN_INLINE_NS_BEGIN * @see Target * * List of supported combinations backend / target: - * | | DNN_BACKEND_OPENCV | DNN_BACKEND_INFERENCE_ENGINE | DNN_BACKEND_HALIDE | - * |------------------------|--------------------|------------------------------|--------------------| - * | DNN_TARGET_CPU | + | + | + | - * | DNN_TARGET_OPENCL | + | + | + | - * | DNN_TARGET_OPENCL_FP16 | + | + | | - * | DNN_TARGET_MYRIAD | | + | | - * | DNN_TARGET_FPGA | | + | | + * | | DNN_BACKEND_OPENCV | DNN_BACKEND_INFERENCE_ENGINE | DNN_BACKEND_HALIDE | DNN_BACKEND_CUDA | + * |------------------------|--------------------|------------------------------|--------------------|-------------------| + * | DNN_TARGET_CPU | + | + | + | | + * | DNN_TARGET_OPENCL | + | + | + | | + * | DNN_TARGET_OPENCL_FP16 | + | + | | | + * | DNN_TARGET_MYRIAD | | + | | | + * | DNN_TARGET_FPGA | | + | | | + * | DNN_TARGET_CUDA | | | | + | + * | DNN_TARGET_CUDA_FP16 | | | | + | */ CV_WRAP void setPreferableTarget(int targetId); diff --git a/modules/dnn/perf/perf_convolution3d.cpp b/modules/dnn/perf/perf_convolution3d.cpp index 1f512b2..e81a4bf 100644 --- a/modules/dnn/perf/perf_convolution3d.cpp +++ b/modules/dnn/perf/perf_convolution3d.cpp @@ -111,8 +111,8 @@ PERF_TEST_P_(Conv3D, conv3d) Backend backendId = get<0>(get<1>(GetParam())); Target targetId = get<1>(get<1>(GetParam())); - if (targetId != DNN_TARGET_CPU) - throw SkipTestException("Only CPU is supported"); + if (targetId != DNN_TARGET_CPU && backendId != DNN_BACKEND_CUDA) + throw SkipTestException("Only CPU and CUDA is supported"); int inChannels = inputShape[1]; diff --git a/modules/dnn/src/cuda/activations.cu b/modules/dnn/src/cuda/activations.cu new file mode 100644 index 0000000..344ef79 --- /dev/null +++ b/modules/dnn/src/cuda/activations.cu @@ -0,0 +1,432 @@ +// This file is part of OpenCV project. 
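For context, a minimal usage sketch (not part of this patch) of the DNN_BACKEND_CUDA / DNN_TARGET_CUDA enums added to dnn.hpp above; the model file name and input size are placeholders:

    // Illustration only: selecting the CUDA backend/target introduced by this patch.
    #include <opencv2/core.hpp>
    #include <opencv2/dnn.hpp>

    int main() {
        cv::dnn::Net net = cv::dnn::readNet("model.onnx");      // placeholder model file
        net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);    // enum added in this patch
        net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA);      // or DNN_TARGET_CUDA_FP16
        cv::Mat img = cv::Mat::zeros(224, 224, CV_8UC3);        // dummy input
        net.setInput(cv::dnn::blobFromImage(img));
        cv::Mat out = net.forward();
        return 0;
    }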
+// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#include +#include + +#include "math.hpp" +#include "types.hpp" +#include "vector_traits.hpp" +#include "grid_stride_range.hpp" +#include "execution.hpp" + +#include "../cuda4dnn/csl/stream.hpp" +#include "../cuda4dnn/csl/span.hpp" + +#include "../cuda4dnn/kernels/scale_shift.hpp" + +#include + +#include + +using namespace cv::dnn::cuda4dnn::csl; +using namespace cv::dnn::cuda4dnn::csl::device; + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + namespace raw { + template + __global__ void abs_vec(Span output, View input) { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto input_vPtr = vector_type::get_pointer(input.data()); + + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + vector_type vec; + v_load(vec, input_vPtr[i]); + for (int j = 0; j < vector_type::size(); j++) { + using device::abs; + vec.data[j] = abs(vec.data[j]); + } + v_store(output_vPtr[i], vec); + } + } + + template + __global__ void tanh_vec(Span output, View input) { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto input_vPtr = vector_type::get_pointer(input.data()); + + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + vector_type vec; + v_load(vec, input_vPtr[i]); + for (int j = 0; j < vector_type::size(); j++) { + using device::tanh; + vec.data[j] = tanh(vec.data[j]); + } + v_store(output_vPtr[i], vec); + } + } + + template + __global__ void sigmoid_vec(Span output, View input) { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto input_vPtr = vector_type::get_pointer(input.data()); + + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + vector_type vec; + v_load(vec, input_vPtr[i]); + for (int j = 0; j < vector_type::size(); j++) { + using device::sigmoid; + vec.data[j] = sigmoid(vec.data[j]); + } + v_store(output_vPtr[i], vec); + } + } + + template + __global__ void bnll_vec(Span output, View input) { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto input_vPtr = vector_type::get_pointer(input.data()); + + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + vector_type vec; + v_load(vec, input_vPtr[i]); + for (int j = 0; j < vector_type::size(); j++) { + using device::log1pexp; + vec.data[j] = vec.data[j] > T(0) ? vec.data[j] + log1pexp(-vec.data[j]) : log1pexp(vec.data[j]); + } + v_store(output_vPtr[i], vec); + } + } + + template + __global__ void elu_vec(Span output, View input) { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto input_vPtr = vector_type::get_pointer(input.data()); + + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + vector_type vec; + v_load(vec, input_vPtr[i]); + for (int j = 0; j < vector_type::size(); j++) { + using device::expm1; + vec.data[j] = vec.data[j] >= T(0) ? 
vec.data[j] : expm1(vec.data[j]); + } + v_store(output_vPtr[i], vec); + } + } + + template + __global__ void relu_vec(Span output, View input, T slope) { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto input_vPtr = vector_type::get_pointer(input.data()); + + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + vector_type vec; + v_load(vec, input_vPtr[i]); + for(int j = 0; j < vector_type::size(); j++) + vec.data[j] = vec.data[j] >= T(0) ? vec.data[j] : slope * vec.data[j]; + v_store(output_vPtr[i], vec); + } + } + + template + __global__ void clipped_relu_vec(Span output, View input, T floor, T ceiling) { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto input_vPtr = vector_type::get_pointer(input.data()); + + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + using device::clamp; + + vector_type vec; + v_load(vec, input_vPtr[i]); + for (int j = 0; j < vector_type::size(); j++) + vec.data[j] = clamp(vec.data[j], floor, ceiling); + v_store(output_vPtr[i], vec); + } + } + + template + __global__ void axiswise_relu_vec(Span output, View input, size_type inner_size, View slope) { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto input_vPtr = vector_type::get_pointer(input.data()); + + inner_size /= vector_type::size(); + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + const index_type c = (i / inner_size) % static_cast(slope.size()); + + vector_type vec; + v_load(vec, input_vPtr[i]); + for (int j = 0; j < vector_type::size(); j++) + vec.data[j] = vec.data[j] > T(0) ? vec.data[j] : vec.data[j] * slope[c]; + v_store(output_vPtr[i], vec); + } + } + + template + __global__ void power_vec(Span output, View input, T exp, T scale, T shift) { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto input_vPtr = vector_type::get_pointer(input.data()); + + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + using device::pow; + + vector_type vec; + v_load(vec, input_vPtr[i]); + for (int j = 0; j < vector_type::size(); j++) + vec.data[j] = pow(shift + scale * vec.data[j], exp); + v_store(output_vPtr[i], vec); + } + } + } + + template + void launch_vectorized_abs(const Stream& stream, Span output, View input) { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(input, N)); + + auto kernel = raw::abs_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + launch_kernel(kernel, policy, output, input); + } + + template + void abs(const Stream& stream, Span output, View input) { + CV_Assert(input.size() == output.size()); + + if (is_fully_aligned(output, 4) && is_fully_aligned(input, 4)) { + launch_vectorized_abs(stream, output, input); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(input, 2)) { + launch_vectorized_abs(stream, output, input); + } else { + launch_vectorized_abs(stream, output, input); + } + } + + template void abs<__half>(const Stream& stream, Span<__half> output, View<__half> input); + template void abs(const Stream& stream, Span output, View input); + + template + void launch_vectorized_tanh(const Stream& stream, Span output, View input) { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(input, N)); + + auto kernel = raw::tanh_vec; + auto policy = make_policy(kernel, output.size() / N, 0, 
stream); + launch_kernel(kernel, policy, output, input); + } + + template + void tanh(const Stream& stream, Span output, View input) { + CV_Assert(input.size() == output.size()); + + if (is_fully_aligned(output, 4) && is_fully_aligned(input, 4)) { + launch_vectorized_tanh(stream, output, input); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(input, 2)) { + launch_vectorized_tanh(stream, output, input); + } else { + launch_vectorized_tanh(stream, output, input); + } + } + + template void tanh<__half>(const Stream&, Span<__half>, View<__half>); + template void tanh(const Stream&, Span, View); + + template + void launch_vectorized_sigmoid(const Stream& stream, Span output, View input) { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(input, N)); + + auto kernel = raw::sigmoid_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + launch_kernel(kernel, policy, output, input); + } + + template + void sigmoid(const Stream& stream, Span output, View input) { + CV_Assert(input.size() == output.size()); + + if (is_fully_aligned(output, 4) && is_fully_aligned(input, 4)) { + launch_vectorized_sigmoid(stream, output, input); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(input, 2)) { + launch_vectorized_sigmoid(stream, output, input); + } else { + launch_vectorized_sigmoid(stream, output, input); + } + } + + template void sigmoid<__half>(const Stream&, Span<__half>, View<__half>); + template void sigmoid(const Stream&, Span, View); + + template + void launch_vectorized_bnll(const Stream& stream, Span output, View input) { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(input, N)); + + auto kernel = raw::bnll_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + launch_kernel(kernel, policy, output, input); + } + + template + void bnll(const Stream& stream, Span output, View input) { + CV_Assert(input.size() == output.size()); + + if (is_fully_aligned(output, 4) && is_fully_aligned(input, 4)) { + launch_vectorized_bnll(stream, output, input); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(input, 2)) { + launch_vectorized_bnll(stream, output, input); + } else { + launch_vectorized_bnll(stream, output, input); + } + } + + template void bnll<__half>(const Stream&, Span<__half>, View<__half>); + template void bnll(const Stream&, Span, View); + + template + void launch_vectorized_elu(const Stream& stream, Span output, View input) { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(input, N)); + + auto kernel = raw::elu_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + launch_kernel(kernel, policy, output, input); + } + + template + void elu(const Stream& stream, Span output, View input) { + CV_Assert(input.size() == output.size()); + + if (is_fully_aligned(output, 4) && is_fully_aligned(input, 4)) { + launch_vectorized_elu(stream, output, input); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(input, 2)) { + launch_vectorized_elu(stream, output, input); + } else { + launch_vectorized_elu(stream, output, input); + } + } + + template void elu<__half>(const Stream&, Span<__half>, View<__half>); + template void elu(const Stream&, Span, View); + + template + void launch_vectorized_relu(const Stream& stream, Span output, View input, T slope) { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(input, N)); + + auto kernel = raw::relu_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + 
launch_kernel(kernel, policy, output, input, slope); + } + + template + void relu(const Stream& stream, Span output, View input, T slope) { + CV_Assert(input.size() == output.size()); + + if(is_fully_aligned(output, 4) && is_fully_aligned(input, 4)) { + launch_vectorized_relu(stream, output, input, slope); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(input, 2)) { + launch_vectorized_relu(stream, output, input, slope); + } else { + launch_vectorized_relu(stream, output, input, slope); + } + } + + template void relu<__half>(const Stream&, Span<__half>, View<__half>, __half); + template void relu(const Stream&, Span, View, float); + + template + void launch_vectorized_clipped_relu(const Stream& stream, Span output, View input, T floor, T ceiling) { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(input, N)); + + auto kernel = raw::clipped_relu_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + launch_kernel(kernel, policy, output, input, floor, ceiling); + } + + template + void clipped_relu(const Stream& stream, Span output, View input, T floor, T ceiling) { + CV_Assert(input.size() == output.size()); + CV_Assert(static_cast(floor) <= static_cast(ceiling)); + + if(is_fully_aligned(output, 4) && is_fully_aligned(input, 4)) { + launch_vectorized_clipped_relu(stream, output, input, floor, ceiling); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(input, 2)) { + launch_vectorized_clipped_relu(stream, output, input, floor, ceiling); + } else { + launch_vectorized_clipped_relu(stream, output, input, floor, ceiling); + } + } + + template void clipped_relu<__half>(const Stream&, Span<__half>, View<__half>, __half, __half); + template void clipped_relu(const Stream&, Span, View, float, float); + + template + void launch_vectorized_axiswise_relu(const Stream& stream, Span output, View input, std::size_t inner_size, View slope) { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(input, N)); + CV_Assert(inner_size % N == 0); + + auto kernel = raw::axiswise_relu_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + launch_kernel(kernel, policy, output, input, inner_size, slope); + } + + template + void axiswise_relu(const Stream& stream, Span output, View input, std::size_t inner_size, View slope) { + CV_Assert(input.size() == output.size()); + + if (is_fully_aligned(output, 4) && is_fully_aligned(input, 4) && inner_size % 4 == 0) { + launch_vectorized_axiswise_relu(stream, output, input, inner_size, slope); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(input, 2) && inner_size % 2 == 0) { + launch_vectorized_axiswise_relu(stream, output, input, inner_size, slope); + } else { + launch_vectorized_axiswise_relu(stream, output, input, inner_size, slope); + } + } + + template void axiswise_relu<__half>(const Stream&, Span<__half>, View<__half>, std::size_t, View<__half>); + template void axiswise_relu(const Stream&, Span, View, std::size_t, View); + + template + void launch_vectorized_power(const Stream& stream, Span output, View input, T exp, T scale, T shift) { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(input, N)); + + auto kernel = raw::power_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + launch_kernel(kernel, policy, output, input, exp, scale, shift); + } + + template + void power(const Stream& stream, Span output, View input, T exp, T scale, T shift) { + CV_Assert(input.size() == output.size()); + + if (static_cast(exp) == 
1.0f) { + scale1_with_bias1(stream, output, input, scale, shift); + return; + } + + if (is_fully_aligned(output, 4) && is_fully_aligned(input, 4) && output.size()) { + launch_vectorized_power(stream, output, input, exp, scale, shift); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(input, 2) && output.size()) { + launch_vectorized_power(stream, output, input, exp, scale, shift); + } else { + launch_vectorized_power(stream, output, input, exp, scale, shift); + } + } + + template void power<__half>(const Stream&, Span<__half>, View<__half>, __half, __half, __half); + template void power(const Stream&, Span, View, float, float, float); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ diff --git a/modules/dnn/src/cuda/array.hpp b/modules/dnn/src/cuda/array.hpp new file mode 100644 index 0000000..97f3946 --- /dev/null +++ b/modules/dnn/src/cuda/array.hpp @@ -0,0 +1,73 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA_ARRAY_HPP +#define OPENCV_DNN_SRC_CUDA_ARRAY_HPP + +#include + +#include "types.hpp" + +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device { + + template + struct array { + using value_type = T; + using size_type = device::size_type; + using difference_type = std::ptrdiff_t; + using reference = typename std::add_lvalue_reference::type; + using const_reference = typename std::add_lvalue_reference::type>::type; + using pointer = typename std::add_pointer::type; + using const_pointer = typename std::add_pointer::type>::type; + using iterator = pointer; + using const_iterator = const_pointer; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; + + __host__ __device__ bool empty() const noexcept { return N == 0; } + __host__ __device__ size_type size() const noexcept { return N; } + + __host__ __device__ iterator begin() noexcept { return ptr; } + __host__ __device__ iterator end() noexcept { return ptr + N; } + __host__ __device__ const_iterator begin() const noexcept { return ptr; } + __host__ __device__ const_iterator end() const noexcept { return ptr + N; } + + __host__ __device__ const_iterator cbegin() const noexcept { return ptr; } + __host__ __device__ const_iterator cend() const noexcept { return ptr + N; } + + __host__ __device__ reverse_iterator rbegin() noexcept { return ptr + N; } + __host__ __device__ reverse_iterator rend() noexcept { return ptr; } + __host__ __device__ const_reverse_iterator rbegin() const noexcept { return ptr + N; } + __host__ __device__ const_reverse_iterator rend() const noexcept { return ptr; } + + __host__ __device__ const_reverse_iterator crbegin() const noexcept { return ptr + N; } + __host__ __device__ const_reverse_iterator crend() const noexcept { return ptr; } + + template + __host__ void assign(InputItr first, InputItr last) { + std::copy(first, last, std::begin(ptr)); + } + + __host__ __device__ reference operator[](int idx) { return ptr[idx]; } + __host__ __device__ const_reference operator[](int idx) const { return ptr[idx]; } + + __host__ __device__ reference front() { return ptr[0]; } + __host__ __device__ const_reference front() const { return ptr[0]; } + + __host__ __device__ reference back() { return ptr[N - 1]; } + __host__ __device__ const_reference back() const { return ptr[N - 1]; } + + __host__ __device__ pointer 
data() noexcept { return ptr; } + __host__ __device__ const_pointer data() const noexcept { return ptr; } + + T ptr[N]; + }; + +}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */ + +#endif /* OPENCV_DNN_SRC_CUDA_ARRAY_HPP */ diff --git a/modules/dnn/src/cuda/atomics.hpp b/modules/dnn/src/cuda/atomics.hpp new file mode 100644 index 0000000..034522d --- /dev/null +++ b/modules/dnn/src/cuda/atomics.hpp @@ -0,0 +1,32 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA_ATOMICS_HPP +#define OPENCV_DNN_SRC_CUDA_ATOMICS_HPP + +#include +#include + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 +#else +inline __device__ void atomicAdd(__half* address, __half val) { + unsigned int* address_as_ui = (unsigned int *)((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do { + assumed = old; + + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + __half tmpres = hsum + val; + hsum = __half_raw(tmpres); + + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); +} +#endif + +#endif /* OPENCV_DNN_SRC_CUDA_ATOMICS_HPP */ diff --git a/modules/dnn/src/cuda/concat.cu b/modules/dnn/src/cuda/concat.cu new file mode 100644 index 0000000..21e542f --- /dev/null +++ b/modules/dnn/src/cuda/concat.cu @@ -0,0 +1,259 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
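For reference, the per-element index mapping performed by the concat_with_offsets kernel defined below can be written on the host as follows; this is an illustration of the stride/offset arithmetic, not part of the patch:

    // Host-side reference (illustration only): element `i` of the input is scattered
    // to the output by decomposing `i` into per-axis indices and adding per-axis offsets.
    #include <cstddef>
    #include <vector>

    std::size_t map_index(std::size_t i,
                          const std::vector<std::size_t>& in_stride,
                          const std::vector<std::size_t>& out_stride,
                          const std::vector<std::size_t>& offset)
    {
        std::size_t in_index = i / in_stride[0];
        std::size_t oidx = (offset[0] + in_index) * out_stride[0];
        for (std::size_t j = 1; j < in_stride.size(); j++) {
            in_index = (i % in_stride[j - 1]) / in_stride[j];
            oidx += (offset[j] + in_index) * out_stride[j];
        }
        return oidx; // output[oidx] = input[i]
    }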
+ +#include +#include + +#include "array.hpp" +#include "types.hpp" +#include "vector_traits.hpp" +#include "grid_stride_range.hpp" +#include "execution.hpp" +#include "kernel_dispatcher.hpp" + +#include "../cuda4dnn/csl/stream.hpp" +#include "../cuda4dnn/csl/tensor.hpp" +#include "../cuda4dnn/csl/span.hpp" + +#include +#include + +using namespace cv::dnn::cuda4dnn::csl; +using namespace cv::dnn::cuda4dnn::csl::device; + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + namespace raw { + template + __global__ void concat_vec( + Span output, size_type output_axis_size, index_type output_axis_offset, + View input, size_type input_axis_size, size_type concat_size) + { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto input_vPtr = vector_type::get_pointer(input.data()); + + /* we need to copy all the elements of input to some location in the output + * we copy blocks of size `total_concat_size` to some location in the output + */ + const auto total_concat_size = concat_size * input_axis_size; + + for (auto in_idx : grid_stride_range(input.size() / vector_type::size())) { + const index_type idx = in_idx * vector_type::size(); + const index_type concat_num = idx / total_concat_size; + const index_type concat_index = idx % total_concat_size; + const index_type top_index = concat_index + + (concat_num * output_axis_size + output_axis_offset) * concat_size; + + const auto out_idx = top_index / vector_type::size(); + + vector_type vec; + v_load(vec, input_vPtr[in_idx]); + v_store(output_vPtr[out_idx], vec); + } + } + + template + __global__ void concat_with_offsets( + Span output, array out_strides, array out_offset, + View input, array in_strides) + { + for (auto i : grid_stride_range(input.size())) { + index_type in_index = i / in_strides[0]; + index_type out_index = out_offset[0] + in_index; + index_type oidx = out_index * out_strides[0]; + for (int j = 1; j < Rank; j++) { + in_index = (i % in_strides[j - 1]) / in_strides[j]; + out_index = out_offset[j] + in_index; + oidx += out_index * out_strides[j]; + } + + output[oidx] = input[i]; + } + } + } + + template static + void launch_vectorized_concat(const Stream& stream, + Span output, size_type output_axis_size, index_type output_axis_offset, + View input, size_type input_axis_size, size_type concat_size) + { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(input, N)); + /* more assertions are required to fully check for vectorization possiblity; check concat() */ + + auto kernel = raw::concat_vec; + auto policy = make_policy(kernel, input.size() / N, 0, stream); + launch_kernel(kernel, policy, output, output_axis_size, output_axis_offset, input, input_axis_size, concat_size); + } + + template + void concat( + const Stream& stream, + TensorSpan output, std::size_t output_axis_offset, + TensorView input, std::size_t axis) + { + /* let's call the axis of interest as the channel axis for the purpose of the following discussion + * even though it can be any axis + * + * for each batch item: + * we move all the channels from the input (which together, for a single batch item, is contiguous) + * of a batch item to its corresponding contiguous place in the output + * + * for a valid vector operation: + * - the size of each copy block must be aligned + * - input must be aligned + * - all the destination locations in the output must be aligned + */ + std::size_t concat_size = output.size_range(axis + 1, output.rank()); + + std::size_t 
input_axis_size = input.get_axis_size(axis); + std::size_t output_axis_size = output.get_axis_size(axis); + + std::size_t copy_block_size = concat_size * input_axis_size; + std::size_t copy_block_stride = concat_size * output_axis_size; + std::size_t starting_offset = output_axis_offset * concat_size; + + /* in a nutshell, all this concat operation does is copy several blocks of size `copy_block_size` + * to the output starting from `starting_offset` with blocks in the output strided by `copy_block_stride` + */ + + bool is_aligned_4 = copy_block_size % 4 == 0 && copy_block_stride % 4 == 0 && starting_offset % 4 == 0; + bool is_aligned_2 = copy_block_size % 2 == 0 && copy_block_stride % 2 == 0 && starting_offset % 2 == 0; + + if (is_fully_aligned(output, 4) && is_fully_aligned(input, 4) && is_aligned_4) { + launch_vectorized_concat(stream, output, output_axis_size, output_axis_offset, input, input_axis_size, concat_size); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(input, 2) && is_aligned_2) { + launch_vectorized_concat(stream, output, output_axis_size, output_axis_offset, input, input_axis_size, concat_size); + } else { + launch_vectorized_concat(stream, output, output_axis_size, output_axis_offset, input, input_axis_size, concat_size); + } + } + + template void concat<__half>(const Stream&, TensorSpan<__half>, std::size_t, TensorView<__half>, std::size_t); + template void concat(const Stream&, TensorSpan, std::size_t, TensorView, std::size_t); + + template static + void launch_concat_with_offsets( + const Stream& stream, + Span output, const std::vector& outStride, const std::vector& outOffset, + View input, const std::vector& inStride) + { + CV_Assert(outStride.size() == Rank); + CV_Assert(outOffset.size() == Rank); + CV_Assert(inStride.size() == Rank); + + array outStride_k, inStride_k; + outStride_k.assign(std::begin(outStride), std::end(outStride)); + inStride_k.assign(std::begin(inStride), std::end(inStride)); + + array outOffset_k; + outOffset_k.assign(std::begin(outOffset), std::end(outOffset)); + + auto kernel = raw::concat_with_offsets; + auto policy = make_policy(kernel, input.size(), 0, stream); + launch_kernel(kernel, policy, output, outStride_k, outOffset_k, input, inStride_k); + } + + GENERATE_KERNEL_DISPATCHER(concat_with_offsets_dispatcher, launch_concat_with_offsets); + + template + void concat_with_offsets( + const Stream& stream, + TensorSpan output, TensorView input, + std::vector offsets) + { + CV_Assert(output.rank() == input.rank()); + CV_Assert(output.rank() == offsets.size()); + + /* squeezable axes at the begining of both tensors can be eliminated + * + * Reasoning: + * ---------- + * Suppose an item's indices in the input tensor is [i1, i2, ...]. The indices in the output + * tensor will be [i1 + off1, i2 + off2, ...]. The concat operation essentially copies items + * from the input tensor to new locations in the output tensor. + * + * If the size of the first axis of the input and output tensor is unity, the input and output + * indices for all the elements will be of the form be [0, i2, ...] and [0, i2 + off2, ...] + * respectively. The first index does not contribute to the element's address calculation and + * hence does nothing apart from eating up few cycles. 
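+ * For example (illustration added for clarity): with an input of shape [1, 32, 64, 64], an output of
+ * shape [1, 64, 64, 64] and offsets [0, 16, 0, 0], the leading unit axis satisfies this condition and
+ * can be squeezed away, leaving a rank-3 copy with shapes [32, 64, 64] / [64, 64, 64] and offsets [16, 0, 0].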
+ */ + while (input.get_axis_size(0) == 1 && output.get_axis_size(0) == 1) { + CV_Assert(offsets[0] == 0); + + input.squeeze(0); + output.squeeze(0); + offsets.erase(std::begin(offsets)); + + CV_Assert(output.rank() == input.rank()); + CV_Assert(output.rank() == offsets.size()); + } + + auto inShape = input.shape_as_vector(); + auto outShape = output.shape_as_vector(); + + /* contiguous axes that undergo full copy can be combined into one axis + * + * Reasoning: + * ---------- + * Suppose an item's indices in the input tensor is [i1, i2, i3, ...]. Let the first two axes not undergo any + * concatenation. The indices in the output tensor will be [i1, i2, i3 + off3, ...]. + * + * Each axis in the contiguous axes sequence will add an offset of iN * strideN. In the above example, + * the two axes add a total offset of `i1 * stride1 + i2 * stride2`. We can merge the two axes into one axis with + * a size of `size1 * size2`. The new offset added will be i12 * stride2` as the kernel iterates through `i12`. + * Note that `i12` is actually `(i1 * size2 + i2)` in the original tensor. + */ + for (int i = 0; i < inShape.size(); i++) { + /* check if axis `i` requires any slicing */ + if (offsets[i] == 0 && inShape[i] == outShape[i]) { + /* loop invariant: `i` is the first axis in the contiguous unsliced axis sequence */ + + int j = i + 1; /* `j` is the axis which we will attempt to merge */ + while (j < inShape.size() && offsets[j] == 0 && inShape[j] == outShape[j]) { + /* `j` axis is also copied fully; merge `i` and `j` */ + auto new_size = inShape[i] * inShape[j]; + inShape[i] = new_size; + outShape[i] = new_size; + offsets[i] = 0; /* redundant */ + + /* delete axis `j` */ + inShape.erase(std::begin(inShape) + j); + outShape.erase(std::begin(outShape) + j); + offsets.erase(std::begin(offsets) + j); + + /* optimizations should not break the invariants */ + CV_Assert(inShape.size() == outShape.size()); + CV_Assert(inShape.size() == offsets.size()); + CV_Assert(inShape[i] == outShape[i]); + CV_Assert(offsets[i] == 0); + } + } + } + + auto rank = inShape.size(); + + std::vector inStride(rank), outStride(rank); + inStride.back() = 1; + outStride.back() = 1; + /* garbage, ..., garbage, 1 */ + + std::copy(std::begin(inShape) + 1, std::end(inShape), std::begin(inStride)); + std::copy(std::begin(outShape) + 1, std::end(outShape), std::begin(outStride)); + /* dim[0], dim[1], ..., dim[-1], 1 */ + + std::partial_sum(inStride.rbegin(), inStride.rend(), inStride.rbegin(), std::multiplies()); + std::partial_sum(outStride.rbegin(), outStride.rend(), outStride.rbegin(), std::multiplies()); + /* stride[0], stride[1], ..., stride[-2], 1 */ + + CV_Assert(1 <= rank && rank <= CSL_MAX_TENSOR_RANK); + concat_with_offsets_dispatcher(rank, stream, output, outStride, offsets, input, inStride); + } + + template void concat_with_offsets(const Stream&, TensorSpan<__half>, TensorView<__half>, std::vector); + template void concat_with_offsets(const Stream&, TensorSpan, TensorView, std::vector); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ diff --git a/modules/dnn/src/cuda/eltwise_ops.cu b/modules/dnn/src/cuda/eltwise_ops.cu new file mode 100644 index 0000000..260783c --- /dev/null +++ b/modules/dnn/src/cuda/eltwise_ops.cu @@ -0,0 +1,224 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
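The row-major stride computation used by concat_with_offsets above (shift the shape left by one, set the last stride to 1, then take a reverse running product) can be extracted as a standalone sketch; this is an illustration, not part of the patch:

    #include <algorithm>
    #include <cstddef>
    #include <functional>
    #include <numeric>
    #include <vector>

    // e.g. shape {2, 3, 4} -> strides {12, 4, 1}
    std::vector<std::size_t> make_strides(const std::vector<std::size_t>& shape) {
        std::vector<std::size_t> stride(shape.size());
        stride.back() = 1;
        // dim[1], dim[2], ..., dim[-1], 1
        std::copy(std::begin(shape) + 1, std::end(shape), std::begin(stride));
        // suffix product: stride[i] = dim[i+1] * dim[i+2] * ... * 1
        std::partial_sum(stride.rbegin(), stride.rend(), stride.rbegin(), std::multiplies<std::size_t>());
        return stride;
    }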
+ +#include +#include + +#include "math.hpp" +#include "grid_stride_range.hpp" +#include "execution.hpp" +#include "vector_traits.hpp" + +#include "../cuda4dnn/csl/stream.hpp" +#include "../cuda4dnn/csl/span.hpp" + +#include + +using namespace cv::dnn::cuda4dnn::csl; +using namespace cv::dnn::cuda4dnn::csl::device; + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + namespace raw { + template + __global__ void eltwise_max_2_vec(Span output, View x, View y) { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto x_vPtr = vector_type::get_pointer(x.data()); + auto y_vPtr = vector_type::get_pointer(y.data()); + + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + vector_type vec_x, vec_y; + v_load(vec_x, x_vPtr[i]); + v_load(vec_y, y_vPtr[i]); + + for (int j = 0; j < vector_type::size(); j++) { + using device::max; + vec_x.data[j] = max(vec_x.data[j], vec_y.data[j]); + } + + v_store(output_vPtr[i], vec_x); + } + } + + template + __global__ void eltwise_sum_2_vec(Span output, View x, View y) { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto x_vPtr = vector_type::get_pointer(x.data()); + auto y_vPtr = vector_type::get_pointer(y.data()); + + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + vector_type vec_x, vec_y; + v_load(vec_x, x_vPtr[i]); + v_load(vec_y, y_vPtr[i]); + + for (int j = 0; j < vector_type::size(); j++) + vec_x.data[j] = vec_x.data[j] + vec_y.data[j]; + + v_store(output_vPtr[i], vec_x); + } + } + + template + __global__ void eltwise_sum_coeff_2_vec(Span output, T coeff_x, View x, T coeff_y, View y) { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto x_vPtr = vector_type::get_pointer(x.data()); + auto y_vPtr = vector_type::get_pointer(y.data()); + + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + vector_type vec_x, vec_y; + v_load(vec_x, x_vPtr[i]); + v_load(vec_y, y_vPtr[i]); + + for (int j = 0; j < vector_type::size(); j++) + vec_x.data[j] = coeff_x * vec_x.data[j] + coeff_y * vec_y.data[j]; + + v_store(output_vPtr[i], vec_x); + } + } + + template + __global__ void eltwise_prod_2_vec(Span output, View x, View y) { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto x_vPtr = vector_type::get_pointer(x.data()); + auto y_vPtr = vector_type::get_pointer(y.data()); + + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + vector_type vec_x, vec_y; + v_load(vec_x, x_vPtr[i]); + v_load(vec_y, y_vPtr[i]); + + for (int j = 0; j < vector_type::size(); j++) + vec_x.data[j] = vec_x.data[j] * vec_y.data[j]; + + v_store(output_vPtr[i], vec_x); + } + } + } + + template + void launch_vectorized_eltwise_max_2(const Stream& stream, Span output, View x, View y) { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(x, N)); + CV_Assert(is_fully_aligned(y, N)); + + auto kernel = raw::eltwise_max_2_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + launch_kernel(kernel, policy, output, x, y); + } + + template + void eltwise_max_2(const Stream& stream, Span output, View x, View y) { + CV_Assert(x.size() == y.size()); + CV_Assert(x.size() == output.size()); + + if (is_fully_aligned(output, 4) && is_fully_aligned(x, 4) && is_fully_aligned(y, 4)) { + launch_vectorized_eltwise_max_2(stream, output, x, y); + 
} else if (is_fully_aligned(output, 2) && is_fully_aligned(x, 2) && is_fully_aligned(y, 2)) { + launch_vectorized_eltwise_max_2(stream, output, x, y); + } else { + launch_vectorized_eltwise_max_2(stream, output, x, y); + } + } + + template void eltwise_max_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y); + template void eltwise_max_2(const Stream& stream, Span output, View x, View y); + + template + void launch_vectorized_eltwise_sum_2(const Stream& stream, Span output, View x, View y) { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(x, N)); + CV_Assert(is_fully_aligned(y, N)); + + auto kernel = raw::eltwise_sum_2_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + launch_kernel(kernel, policy, output, x, y); + } + + template + void eltwise_sum_2(const Stream& stream, Span output, View x, View y) { + CV_Assert(x.size() == y.size()); + CV_Assert(x.size() == output.size()); + + if (is_fully_aligned(output, 4) && is_fully_aligned(x, 4) && is_fully_aligned(y, 4)) { + launch_vectorized_eltwise_sum_2(stream, output, x, y); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(x, 2) && is_fully_aligned(y, 2)) { + launch_vectorized_eltwise_sum_2(stream, output, x, y); + } else { + launch_vectorized_eltwise_sum_2(stream, output, x, y); + } + } + + template void eltwise_sum_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y); + template void eltwise_sum_2(const Stream& stream, Span output, View x, View y); + + template + void launch_vectorized_eltwise_sum_coeff_2(const Stream& stream, Span output, T coeff_x, View x, T coeff_y, View y) { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(x, N)); + CV_Assert(is_fully_aligned(y, N)); + + auto kernel = raw::eltwise_sum_coeff_2_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + launch_kernel(kernel, policy, output, coeff_x, x, coeff_y, y); + } + + template + void eltwise_sum_coeff_2(const Stream& stream, Span output, T coeff_x, View x, T coeff_y, View y) { + CV_Assert(x.size() == y.size()); + CV_Assert(x.size() == output.size()); + + if (static_cast(coeff_x) == 1.0f && static_cast(coeff_y) == 1.0f) { + eltwise_sum_2(stream, output, x, y); + return; + } + + if (is_fully_aligned(output, 4) && is_fully_aligned(x, 4) && is_fully_aligned(y, 4)) { + launch_vectorized_eltwise_sum_coeff_2(stream, output, coeff_x, x, coeff_y, y); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(x, 2) && is_fully_aligned(y, 2)) { + launch_vectorized_eltwise_sum_coeff_2(stream, output, coeff_x, x, coeff_y, y); + } else { + launch_vectorized_eltwise_sum_coeff_2(stream, output, coeff_x, x, coeff_y, y); + } + } + + template void eltwise_sum_coeff_2(const Stream&, Span<__half>, __half, View<__half>, __half, View<__half>); + template void eltwise_sum_coeff_2(const Stream&, Span, float, View, float, View); + + template + void launch_vectorized_eltwise_prod_2(const Stream& stream, Span output, View x, View y) { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(x, N)); + CV_Assert(is_fully_aligned(y, N)); + + auto kernel = raw::eltwise_prod_2_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + launch_kernel(kernel, policy, output, x, y); + } + + template + void eltwise_prod_2(const Stream& stream, Span output, View x, View y) { + CV_Assert(x.size() == y.size()); + CV_Assert(x.size() == output.size()); + + if (is_fully_aligned(output, 4) && is_fully_aligned(x, 4) && is_fully_aligned(y, 4)) 
{ + launch_vectorized_eltwise_prod_2(stream, output, x, y); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(x, 2) && is_fully_aligned(y, 2)) { + launch_vectorized_eltwise_prod_2(stream, output, x, y); + } else { + launch_vectorized_eltwise_prod_2(stream, output, x, y); + } + } + + template void eltwise_prod_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y); + template void eltwise_prod_2(const Stream& stream, Span output, View x, View y); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ diff --git a/modules/dnn/src/cuda/execution.hpp b/modules/dnn/src/cuda/execution.hpp new file mode 100644 index 0000000..57d1e30 --- /dev/null +++ b/modules/dnn/src/cuda/execution.hpp @@ -0,0 +1,81 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA_EXECUTION_HPP +#define OPENCV_DNN_SRC_CUDA_EXECUTION_HPP + +#include "../cuda4dnn/csl/error.hpp" +#include "../cuda4dnn/csl/stream.hpp" + +#include + +#include + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { + + struct execution_policy { + execution_policy(dim3 grid_size, dim3 block_size) + : grid{ grid_size }, block{ block_size }, sharedMem{ 0 }, stream{ 0 } { } + + execution_policy(dim3 grid_size, dim3 block_size, std::size_t shared_mem) + : grid{ grid_size }, block{ block_size }, sharedMem{ shared_mem }, stream{ nullptr } { } + + execution_policy(dim3 grid_size, dim3 block_size, const Stream& strm) + : grid{ grid_size }, block{ block_size }, sharedMem{ 0 }, stream{ strm.get() } { } + + execution_policy(dim3 grid_size, dim3 block_size, std::size_t shared_mem, const Stream& strm) + : grid{ grid_size }, block{ block_size }, sharedMem{ shared_mem }, stream{ strm.get() } { } + + dim3 grid; + dim3 block; + std::size_t sharedMem; + cudaStream_t stream; + }; + + /* this overload shouldn't be necessary; we should always provide a bound on the number of threads */ + /* + template inline + execution_policy make_policy(Kernel kernel, std::size_t sharedMem = 0, const Stream& stream = 0) { + int grid_size, block_size; + CUDA4DNN_CHECK_CUDA(cudaOccupancyMaxPotentialBlockSize(&grid_size, &block_size, kernel, sharedMem)); + return execution_policy(grid_size, block_size, sharedMem, stream); + }*/ + + template inline + execution_policy make_policy(Kernel kernel, std::size_t max_threads, std::size_t sharedMem = 0, const Stream& stream = 0) { + CV_Assert(max_threads > 0); + + int grid_size = 0, block_size = 0; + CUDA4DNN_CHECK_CUDA(cudaOccupancyMaxPotentialBlockSize(&grid_size, &block_size, kernel, sharedMem)); + if (grid_size * block_size > max_threads) { + grid_size = (max_threads + block_size - 1) / block_size; + if (block_size > max_threads) + block_size = max_threads; + } + + CV_Assert(grid_size >= 1 && block_size >= 1); + return execution_policy(grid_size, block_size, sharedMem, stream); + } + + template inline + void launch_kernel(Kernel kernel, Args ...args) { + auto policy = make_policy(kernel); + kernel <<>> (std::forward(args)...); + } + + template inline + void launch_kernel(Kernel kernel, dim3 grid, dim3 block, Args ...args) { + kernel <<>> (std::forward(args)...); + } + + template inline + void launch_kernel(Kernel kernel, execution_policy policy, Args ...args) { + kernel <<>> (std::forward(args)...); + } + +}}}} /* namespace cv::dnn::cuda4dnn::csl */ + +#endif /* OPENCV_DNN_SRC_CUDA_EXECUTION_HPP */ diff --git 
a/modules/dnn/src/cuda/fill.cu b/modules/dnn/src/cuda/fill.cu new file mode 100644 index 0000000..e4fea27 --- /dev/null +++ b/modules/dnn/src/cuda/fill.cu @@ -0,0 +1,58 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#include +#include + +#include "grid_stride_range.hpp" +#include "execution.hpp" +#include "vector_traits.hpp" + +#include "../cuda4dnn/csl/stream.hpp" +#include "../cuda4dnn/csl/span.hpp" + +using namespace cv::dnn::cuda4dnn::csl; +using namespace cv::dnn::cuda4dnn::csl::device; + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + namespace raw { + template + __global__ void fill_vec(Span output, T value) { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + vector_type vec; + for (int j = 0; j < vector_type::size(); j++) + vec.data[j] = value; + v_store(output_vPtr[i], vec); + } + } + } + + template + void launch_vectorized_fill(const Stream& stream, Span output, T value) { + CV_Assert(is_fully_aligned(output, N)); + + auto kernel = raw::fill_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + launch_kernel(kernel, policy, output, value); + } + + template + void fill(const Stream& stream, Span output, T value) { + if (is_fully_aligned(output, 4)) { + launch_vectorized_fill(stream, output, value); + } else if (is_fully_aligned(output, 2)) { + launch_vectorized_fill(stream, output, value); + } else { + launch_vectorized_fill(stream, output, value); + } + } + + template void fill(const Stream&, Span<__half>, __half); + template void fill(const Stream&, Span, float); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ diff --git a/modules/dnn/src/cuda/grid_stride_range.hpp b/modules/dnn/src/cuda/grid_stride_range.hpp new file mode 100644 index 0000000..4b61a0f --- /dev/null +++ b/modules/dnn/src/cuda/grid_stride_range.hpp @@ -0,0 +1,92 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
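The grid_stride_range header that follows wraps the classic grid-stride loop in a range-based interface (used by fill.cu above and the other kernels in this patch). For comparison, the equivalent hand-written loop looks like this; a sketch only, not part of the patch:

    #include <cstddef>

    // Classic grid-stride loop: each thread starts at its global index and
    // strides by the total number of threads in the grid.
    __global__ void scale_raw(float* data, std::size_t n, float alpha) {
        for (std::size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
             i += std::size_t(gridDim.x) * blockDim.x)
        {
            data[i] = alpha * data[i];
        }
    }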
+ +#ifndef OPENCV_DNN_SRC_CUDA_GRID_STRIDE_RANGE_HPP +#define OPENCV_DNN_SRC_CUDA_GRID_STRIDE_RANGE_HPP + +#include "types.hpp" + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device { + + namespace detail { + template __device__ auto getGridDim()->decltype(dim3::x); + template <> inline __device__ auto getGridDim<0>()->decltype(dim3::x) { return gridDim.x; } + template <> inline __device__ auto getGridDim<1>()->decltype(dim3::x) { return gridDim.y; } + template <> inline __device__ auto getGridDim<2>()->decltype(dim3::x) { return gridDim.z; } + + template __device__ auto getBlockDim()->decltype(dim3::x); + template <> inline __device__ auto getBlockDim<0>()->decltype(dim3::x) { return blockDim.x; } + template <> inline __device__ auto getBlockDim<1>()->decltype(dim3::x) { return blockDim.y; } + template <> inline __device__ auto getBlockDim<2>()->decltype(dim3::x) { return blockDim.z; } + + template __device__ auto getBlockIdx()->decltype(uint3::x); + template <> inline __device__ auto getBlockIdx<0>()->decltype(uint3::x) { return blockIdx.x; } + template <> inline __device__ auto getBlockIdx<1>()->decltype(uint3::x) { return blockIdx.y; } + template <> inline __device__ auto getBlockIdx<2>()->decltype(uint3::x) { return blockIdx.z; } + + template __device__ auto getThreadIdx()->decltype(uint3::x); + template <> inline __device__ auto getThreadIdx<0>()->decltype(uint3::x) { return threadIdx.x; } + template <> inline __device__ auto getThreadIdx<1>()->decltype(uint3::x) { return threadIdx.y; } + template <> inline __device__ auto getThreadIdx<2>()->decltype(uint3::x) { return threadIdx.z; } + } + + template + class grid_stride_range_generic { + public: + __device__ grid_stride_range_generic(index_type to_) : from(0), to(to_) { } + __device__ grid_stride_range_generic(index_type from_, index_type to_) : from(from_), to(to_) { } + + class iterator + { + public: + __device__ iterator(index_type pos_) : pos(pos_) {} + + /* these iterators return the index when dereferenced; this allows us to loop + * through the indices using a range based for loop + */ + __device__ index_type operator*() const { return pos; } + + __device__ iterator& operator++() { + pos += detail::getGridDim() * static_cast(detail::getBlockDim()); + return *this; + } + + __device__ bool operator!=(const iterator& other) const { + /* NOTE HACK + ** 'pos' can move in large steps (see operator++) + ** expansion of range for loop uses != as the loop conditioion + ** => operator!= must return false if 'pos' crosses the end + */ + return pos < other.pos; + } + + private: + index_type pos; + }; + + __device__ iterator begin() const { + using detail::getBlockDim; + using detail::getBlockIdx; + using detail::getThreadIdx; + return iterator(from + getBlockDim() * getBlockIdx() + getThreadIdx()); + } + + __device__ iterator end() const { + return iterator(to); + } + + private: + index_type from, to; + }; + + using grid_stride_range_x = grid_stride_range_generic<0>; + using grid_stride_range_y = grid_stride_range_generic<1>; + using grid_stride_range_z = grid_stride_range_generic<2>; + using grid_stride_range = grid_stride_range_x; + +}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */ + +#endif /* OPENCV_DNN_SRC_CUDA_GRID_STRIDE_RANGE_HPP */ diff --git a/modules/dnn/src/cuda/kernel_dispatcher.hpp b/modules/dnn/src/cuda/kernel_dispatcher.hpp new file mode 100644 index 0000000..6eff834 --- /dev/null +++ b/modules/dnn/src/cuda/kernel_dispatcher.hpp @@ -0,0 +1,76 @@ +// This file is part of 
OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA_KERNEL_DISPATCHER_HPP +#define OPENCV_DNN_SRC_CUDA_KERNEL_DISPATCHER_HPP + +#include +#include + +/* The performance of many kernels are highly dependent on the tensor rank. Instead of having + * one kernel which can work with the maximally ranked tensors, we make one kernel for each supported + * tensor rank. This is to ensure that the requirements of the maximally ranked tensors do not take a + * toll on the performance of the operation for low ranked tensors. Hence, many kernels take the tensor + * rank as a template parameter. + * + * The kernel is a template and we have different instantiations for each rank. This causes the following pattern + * to arise frequently: + * + * if(rank == 3) + * kernel(); + * else if(rank == 2) + * kernel(); + * else + * kernel(); + * + * The rank is a runtime variable. To facilitate creation of such structures, we use GENERATE_KERNEL_DISPATCHER. + * This macro creates a function which selects the correct kernel instantiation at runtime. + * + * Example: + * + * // function which setups the kernel and launches it + * template + * void launch_some_kernel(...); + * + * // creates the dispatcher named "some_dispatcher" which invokves the correct instantiation of "launch_some_kernel" + * GENERATE_KERNEL_DISPATCHER(some_dispatcher, launch_some_kernel); + * + * // internal API function + * template + * void some(...) { + * // ... + * auto rank = input.rank(); + * some_dispatcher(rank, ...); + * } + */ + +/* + * name name of the dispatcher function that is generated + * func template function that requires runtime selection + * + * T first template parameter to `func` + * start starting rank + * end ending rank (inclusive) + * + * Executes func based on runtime `selector` argument given `selector` lies + * within the range [start, end]. If outside the range, no instantiation of `func` is executed. + */ +#define GENERATE_KERNEL_DISPATCHER(name,func); \ + template static \ + typename std::enable_if \ + ::type name(int selector, Args&& ...args) { \ + if(selector == start) \ + func(std::forward(args)...); \ + } \ + \ + template static \ + typename std::enable_if \ + ::type name(int selector, Args&& ...args) { \ + if(selector == start) \ + func(std::forward(args)...); \ + else \ + name(selector, std::forward(args)...); \ + } + +#endif /* OPENCV_DNN_SRC_CUDA_KERNEL_DISPATCHER_HPP */ diff --git a/modules/dnn/src/cuda/limits.hpp b/modules/dnn/src/cuda/limits.hpp new file mode 100644 index 0000000..fec65e6 --- /dev/null +++ b/modules/dnn/src/cuda/limits.hpp @@ -0,0 +1,34 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
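Since the extraction above dropped the template angle brackets from GENERATE_KERNEL_DISPATCHER, here is a hand-written sketch of the if/else chain such a dispatcher effectively produces for ranks 1..3, using the hypothetical names (launch_some_kernel, some_dispatcher) from the macro's own documentation comment; the template parameters are assumptions, not the macro's exact expansion:

    #include <cstddef>
    #include <utility>

    // Hypothetical kernel launcher from the comment above (declaration only, for illustration).
    template <class T, std::size_t Rank, class ...Args>
    void launch_some_kernel(Args&& ...args);

    // Select the instantiation that matches the runtime rank.
    template <class T, class ...Args>
    void some_dispatcher(int rank, Args&& ...args) {
        if (rank == 3)
            launch_some_kernel<T, 3>(std::forward<Args>(args)...);
        else if (rank == 2)
            launch_some_kernel<T, 2>(std::forward<Args>(args)...);
        else if (rank == 1)
            launch_some_kernel<T, 1>(std::forward<Args>(args)...);
    }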
+ +#ifndef OPENCV_DNN_SRC_CUDA_LIMITS_HPP +#define OPENCV_DNN_SRC_CUDA_LIMITS_HPP + +#include +#include + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device { + + template + struct numeric_limits; + + template <> + struct numeric_limits<__half> { + __device__ static __half min() { return 0.0000610; } + __device__ static __half max() { return 65504.0; } + __device__ static __half lowest() { return -65504.0; } + }; + + template <> + struct numeric_limits { + __device__ static float min() { return FLT_MIN; } + __device__ static float max() { return FLT_MAX; } + __device__ static float lowest() { return -FLT_MAX; } + }; + +}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */ + +#endif /* OPENCV_DNN_SRC_CUDA_LIMITS_HPP */ diff --git a/modules/dnn/src/cuda/math.hpp b/modules/dnn/src/cuda/math.hpp new file mode 100644 index 0000000..d95191b --- /dev/null +++ b/modules/dnn/src/cuda/math.hpp @@ -0,0 +1,125 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA_MATH_HPP +#define OPENCV_DNN_SRC_CUDA_MATH_HPP + +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device { + + template __device__ T abs(T val) { return (val < T(0) ? -val : val); } + template <> inline __device__ __half2 abs(__half2 val) { + val.x = abs(val.x); + val.y = abs(val.y); + return val; + } + template <> inline __device__ float abs(float val) { return fabsf(val); } + template <> inline __device__ double abs(double val) { return fabs(val); } + + template __device__ T exp(T val); + template <> inline __device__ __half exp(__half val) { return hexp(val); } + template <> inline __device__ __half2 exp(__half2 val) { return h2exp(val); } + template <> inline __device__ float exp(float val) { return expf(val); } + template <> inline __device__ double exp(double val) { return ::exp(val); } + + template __device__ T expm1(T val); + template <> inline __device__ __half expm1(__half val) { return hexp(val) + __half(1); } + template <> inline __device__ __half2 expm1(__half2 val) { return h2exp(val) + __half2(1, 1); } + template <> inline __device__ float expm1(float val) { return expm1f(val); } + template <> inline __device__ double expm1(double val) { return ::expm1(val); } + + template __device__ T max(T x, T y) { return (x > y ? x : y); } + template <> inline __device__ __half2 max(__half2 a, __half2 b) { + a.x = max(a.x, a.x); + a.y = max(a.y, b.y); + return a; + } + template <> inline __device__ float max(float x, float y) { return fmaxf(x, y); } + template <> inline __device__ double max(double x, double y) { return fmax(x, y); } + + template __device__ T min(T x, T y) { return (x > y ? 
y : x); } + template <> inline __device__ __half2 min(__half2 a, __half2 b) { + a.x = min(a.x, a.x); + a.y = min(a.y, b.y); + return a; + } + template <> inline __device__ float min(float x, float y) { return fminf(x, y); } + template <> inline __device__ double min(double x, double y) { return fmin(x, y); } + + template __device__ T log1p(T val); + template <> inline __device__ __half log1p(__half val) { return hlog(val) + __half(1); } + template <> inline __device__ __half2 log1p(__half2 val) { return h2log(val) + __half2(1, 1); } + template <> inline __device__ float log1p(float val) { return log1pf(val); } + + template __device__ T log1pexp(T val); + template <> inline __device__ __half log1pexp(__half val) { + if (val <= __half(-4.0)) + return exp(val); + else if (val <= __half(8.0)) + return log1p(exp(val)); + else if (val <= __half(8.7)) + return val + exp(-val); + else + return val; + } + template <> inline __device__ __half2 log1pexp(__half2 val) { + val.x = log1pexp(val.x); + val.y = log1pexp(val.y); + return val; + } + template <> inline __device__ float log1pexp(float val) { + if (val <= -20) + return expf(val); + else if (val <= 9.0) + return log1pf(expf(val)); + else if (val <= 14.6) + return val + exp(-val); + else + return val; + } + template <> inline __device__ double log1pexp(double val) { + if (val <= -37) + return exp(val); + else if (val <= 18) + return log1p(exp(val)); + else if (val <= 33.3) + return val + exp(-val); + else + return val; + } + + template __device__ T tanh(T val); + template <> inline __device__ __half tanh(__half val) { return tanhf(val); } + template <> inline __device__ __half2 tanh(__half2 val) { return __half2(tanh(val.x), tanh(val.y)); } + template <> inline __device__ float tanh(float val) { return tanhf(val); } + template <> inline __device__ double tanh(double val) { return ::tanh(val); } + + template __device__ T pow(T val, T exp); + template <> inline __device__ __half pow(__half val, __half exp) { return powf(val, exp); } + template <> inline __device__ __half2 pow(__half2 val, __half2 exp) { return __half2(pow(val.x, exp.x), pow(val.y, exp.y)); } + template <> inline __device__ float pow(float val, float exp) { return powf(val, exp); } + template <> inline __device__ double pow(double val, double exp) { return ::pow(val, exp); } + + template __device__ T sqrt(T val); + template <> inline __device__ __half sqrt(__half val) { return hsqrt(val); } + template <> inline __device__ __half2 sqrt(__half2 val) { return h2sqrt(val); } + template <> inline __device__ float sqrt(float val) { return sqrtf(val); } + template <> inline __device__ double sqrt(double val) { return ::sqrt(val); } + + template __device__ T rsqrt(T val); + template <> inline __device__ __half rsqrt(__half val) { return hrsqrt(val); } + template <> inline __device__ __half2 rsqrt(__half2 val) { return h2rsqrt(val); } + template <> inline __device__ float rsqrt(float val) { return rsqrtf(val); } + template <> inline __device__ double rsqrt(double val) { return ::rsqrt(val); } + + template __device__ T sigmoid(T val) { return T(1) / (T(1) + exp(-val)); } + template <> inline __device__ __half2 sigmoid(__half2 val) { return __half2(1, 1) / (__half2(1, 1) + exp(__hneg2(val))); } + + template __device__ T clamp(T value, T lower, T upper) { return min(max(value, lower), upper); } + +}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */ + +#endif /* OPENCV_DNN_SRC_CUDA_MATH_HPP */ diff --git a/modules/dnn/src/cuda/max_unpooling.cu b/modules/dnn/src/cuda/max_unpooling.cu new file 
mode 100644 index 0000000..e388c95 --- /dev/null +++ b/modules/dnn/src/cuda/max_unpooling.cu @@ -0,0 +1,307 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#include +#include + +#include "math.hpp" +#include "array.hpp" +#include "limits.hpp" +#include "types.hpp" +#include "grid_stride_range.hpp" +#include "execution.hpp" + +#include "../cuda4dnn/csl/stream.hpp" +#include "../cuda4dnn/csl/tensor.hpp" +#include "../cuda4dnn/csl/span.hpp" + +#include "../cuda4dnn/kernels/fill.hpp" + +#include + +#include +#include +#include + +using namespace cv::dnn::cuda4dnn::csl; +using namespace cv::dnn::cuda4dnn::csl::device; + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + namespace raw { + template ::type = true> /* Order has been hardcoded; see code */ + __global__ void max_pooling_with_indices( + Span output, Span indices, View input, size_type channels, + array out_spatial_dims, array in_spatial_dims, + array window_size, array strides, array padding_left) + { + /* every element in the output is mapped to a window in the input and each thread processes several windows */ + for (auto idx : grid_stride_range(output.size())) { + size_type out_spatial_size = 1; + array window_idx; + for (int i = Order - 1; i >= 0; i--) { + window_idx[i] = (idx / out_spatial_size) % out_spatial_dims[i]; + out_spatial_size *= out_spatial_dims[i]; + } + + const index_type n = idx / (out_spatial_size * channels); + const index_type c = (idx / out_spatial_size) % channels; + + array start; + for(int i = 0; i < Order; i++) + start[i] = window_idx[i] * strides[i] - padding_left[i]; + + array end; + for (int i = 0; i < Order; i++) { + using device::min; + end[i] = min(start[i] + window_size[i], in_spatial_dims[i]); + } + + for (int i = 0; i < Order; i++) { + using device::max; + start[i] = max(start[i], 0); + } + + T max_value = numeric_limits::lowest(); + index_type max_idx = -1; + + size_type in_spatial_size = 1; + for (int i = 0; i < Order; i++) + in_spatial_size *= in_spatial_dims[i]; + + const auto outer_offset = (n * channels + c) * in_spatial_size; + if (Order == 2) { + array idx; + for (idx[0] = start[0]; idx[0] != end[0]; idx[0]++) { + for (idx[1] = start[1]; idx[1] != end[1]; idx[1]++) { + index_type offset = 0; + index_type stride = 1; + for (int i = Order - 1; i >= 0; i--) { + offset += stride * idx[i]; + stride *= in_spatial_dims[i]; + } + + if (input[outer_offset + offset] > max_value) { + max_idx = offset; + max_value = input[outer_offset + offset]; + } + } + } + } else if(Order == 3) { + array idx; + for (idx[0] = start[0]; idx[0] != end[0]; idx[0]++) { + for (idx[1] = start[1]; idx[1] != end[1]; idx[1]++) { + for (idx[2] = start[2]; idx[2] != end[2]; idx[2]++) { + index_type offset = 0; + index_type stride = 1; + for (int i = Order - 1; i >= 0; i--) { + offset += stride * idx[i]; + stride *= in_spatial_dims[i]; + } + + if (input[outer_offset + offset] > max_value) { + max_idx = offset; + max_value = input[outer_offset + offset]; + } + } + } + } + } + + output[idx] = max_value; + indices[idx] = max_idx; + } + } + + template + __global__ void max_unpooling( + Span output, View input, View indices, size_type channels, + array out_spatial_dims, array in_spatial_dims, + array window_size, array strides, array padding_left) + { + /* the output has already been zero filled */ + /* Every input value represents a window in the output. 
The max unpooling operation + * copies the input value to exactly one location in the output window which is given + * by the indices tensor. + */ + for (auto idx : grid_stride_range(input.size())) { + size_type in_spatial_size = 1; + array window_idx; + for (int i = Order - 1; i >= 0; i--) { + window_idx[i] = (idx / in_spatial_size) % in_spatial_dims[i]; + in_spatial_size *= in_spatial_dims[i]; + } + + const index_type n = idx / (in_spatial_size * channels); + const index_type c = (idx / in_spatial_size) % channels; + + array start; + for (int i = 0; i < Order; i++) { + using device::min; + using device::max; + start[i] = max(0, min(window_idx[i] * strides[i] - padding_left[i], out_spatial_dims[i] - 1)); + } + + size_type out_spatial_size = 1; + for (int i = 0; i < Order; i++) + out_spatial_size *= out_spatial_dims[i]; + + index_type outer_offset = (n * channels + c) * out_spatial_size; + output[outer_offset + static_cast(indices[idx])] = input[idx]; + } + } + } + + template static + void launch_max_pooling_kernel( + const Stream& stream, + Span output, Span indices, View input, std::size_t channels, + const std::vector& out_spatial_dims, const std::vector& in_spatial_dims, + const std::vector& window_size, + const std::vector& strides, const std::vector& padding_left) + { + CV_Assert(indices.size() == output.size()); + CV_Assert(out_spatial_dims.size() == Order); + CV_Assert(in_spatial_dims.size() == Order); + CV_Assert(window_size.size() == Order); + CV_Assert(strides.size() == Order); + CV_Assert(padding_left.size() == Order); + + array out_spatial_dims_k, in_spatial_dims_k; + out_spatial_dims_k.assign(std::begin(out_spatial_dims), std::end(out_spatial_dims)); + in_spatial_dims_k.assign(std::begin(in_spatial_dims), std::end(in_spatial_dims)); + + array window_size_k, strides_k, padding_left_k; + window_size_k.assign(std::begin(window_size), std::end(window_size)); + strides_k.assign(std::begin(strides), std::end(strides)); + padding_left_k.assign(std::begin(padding_left), std::end(padding_left)); + + auto kernel = raw::max_pooling_with_indices; + auto policy = make_policy(kernel, output.size(), 0, stream); + launch_kernel(kernel, policy, output, indices, input, channels, + out_spatial_dims_k, in_spatial_dims_k, window_size_k, strides_k, padding_left_k); + } + + template + void max_pooling_with_indices( + const Stream& stream, + TensorSpan output, TensorSpan indices, TensorView input, + const std::vector& window_size, const std::vector& strides, + const std::vector& padding_left) + { + CV_Assert(is_shape_same(output, indices)); + CV_Assert(input.get_axis_size(1) == output.get_axis_size(1)); + + auto order = window_size.size(); + CV_Assert(strides.size() == order); + CV_Assert(padding_left.size() == order); + CV_Assert(output.rank() == order + 2); + CV_Assert(input.rank() == order + 2); + + std::vector out_spatial_dims(order), in_spatial_dims(order); + for (int i = 0; i < order; i++) { + in_spatial_dims[i] = input.get_axis_size(2 + i); + out_spatial_dims[i] = output.get_axis_size(2 + i); + } + + /* only max_pooling2d and max_pooling3d are supported */ + CV_Assert(2 <= order && order <= 3); + std::size_t channels = input.get_axis_size(1); + if (order == 3) { + launch_max_pooling_kernel(stream, output, indices, input, channels, + out_spatial_dims, in_spatial_dims, window_size, strides, padding_left); + } else if (order == 2) { + launch_max_pooling_kernel(stream, output, indices, input, channels, + out_spatial_dims, in_spatial_dims, window_size, strides, padding_left); + } + } + + template 
void max_pooling_with_indices(const Stream&, + TensorSpan<__half>, TensorSpan<__half>, TensorView<__half>, + const std::vector&, const std::vector&, + const std::vector&); + + template void max_pooling_with_indices(const Stream&, + TensorSpan, TensorSpan, TensorView, + const std::vector&, const std::vector&, + const std::vector&); + + template static + void launch_max_unpooling_kernel( + const Stream& stream, + Span output, View input, View indices, std::size_t channels, + const std::vector& out_spatial_dims, const std::vector& in_spatial_dims, + const std::vector& window_size, + const std::vector& strides, const std::vector& padding_left) + { + CV_Assert(out_spatial_dims.size() == Order); + CV_Assert(in_spatial_dims.size() == Order); + CV_Assert(window_size.size() == Order); + CV_Assert(strides.size() == Order); + CV_Assert(padding_left.size() == Order); + CV_Assert(indices.size() == input.size()); + + array out_spatial_dims_k, in_spatial_dims_k; + out_spatial_dims_k.assign(std::begin(out_spatial_dims), std::end(out_spatial_dims)); + in_spatial_dims_k.assign(std::begin(in_spatial_dims), std::end(in_spatial_dims)); + + array window_size_k, strides_k, padding_left_k; + window_size_k.assign(std::begin(window_size), std::end(window_size)); + strides_k.assign(std::begin(strides), std::end(strides)); + padding_left_k.assign(std::begin(padding_left), std::end(padding_left)); + + auto kernel = raw::max_unpooling; + auto policy = make_policy(kernel, input.size(), 0, stream); + launch_kernel(kernel, policy, output, input, indices, channels, + out_spatial_dims_k, in_spatial_dims_k, window_size_k, strides_k, padding_left_k); + } + + template + void max_unpooling( + const Stream& stream, + TensorSpan output, TensorView input, TensorView indices, + const std::vector& window_size, const std::vector& strides, + const std::vector& padding_left) + { + CV_Assert(is_shape_same(input, indices)); + CV_Assert(input.get_axis_size(1) == output.get_axis_size(1)); + + auto order = window_size.size(); + CV_Assert(strides.size() == order); + CV_Assert(padding_left.size() == order); + CV_Assert(output.rank() == order + 2); + CV_Assert(input.rank() == order + 2); + + std::vector out_spatial_dims(order), in_spatial_dims(order); + for (int i = 0; i < order; i++) { + in_spatial_dims[i] = input.get_axis_size(2 + i); + out_spatial_dims[i] = output.get_axis_size(2 + i); + } + + kernels::fill(stream, output, 0.0); + + /* only max_unpooling2d and max_unpooling3d are supported */ + CV_Assert(2 <= order && order <= 3); + std::size_t channels = input.get_axis_size(1); + if (order == 3) { + launch_max_unpooling_kernel(stream, output, input, indices, channels, + out_spatial_dims, in_spatial_dims, window_size, strides, padding_left); + } else if (order == 2) { + launch_max_unpooling_kernel(stream, output, input, indices, channels, + out_spatial_dims, in_spatial_dims, window_size, strides, padding_left); + } + } + + template void max_unpooling(const Stream&, + TensorSpan<__half>, TensorView<__half>, TensorView<__half>, + const std::vector&, const std::vector&, + const std::vector&); + + template void max_unpooling(const Stream&, + TensorSpan, TensorView, TensorView, + const std::vector&, const std::vector&, + const std::vector&); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ diff --git a/modules/dnn/src/cuda/normalize.cu b/modules/dnn/src/cuda/normalize.cu new file mode 100644 index 0000000..49dff9b --- /dev/null +++ b/modules/dnn/src/cuda/normalize.cu @@ -0,0 +1,121 @@ +// This file is part of OpenCV project. 
+// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#include +#include + +#include "array.hpp" +#include "math.hpp" +#include "types.hpp" +#include "atomics.hpp" +#include "grid_stride_range.hpp" +#include "execution.hpp" + +#include "../cuda4dnn/csl/stream.hpp" +#include "../cuda4dnn/csl/span.hpp" + +#include "../cuda4dnn/kernels/fill.hpp" +#include "../cuda4dnn/kernels/scale_shift.hpp" + +#include + +#include + +using namespace cv::dnn::cuda4dnn::csl; +using namespace cv::dnn::cuda4dnn::csl::device; + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + namespace raw { + template + __global__ void reduce_sum_abs(Span output, View input, size_type outer_stride, size_type mid_stride) { + for (auto idx : grid_stride_range(input.size())) { + const index_type outer_idx = idx / outer_stride; + const index_type inner_idx = idx % mid_stride; + + const index_type sum_idx = outer_idx * mid_stride + inner_idx; + atomicAdd(&output[sum_idx], device::abs(input[idx])); + } + } + + template + __global__ void reciprocal(Span output, T epsilon) { + for (auto idx : grid_stride_range(output.size())) + output[idx] = T(1) / (output[idx] + epsilon); + } + + template + __global__ void reduce_sum_squared(Span output, View input, size_type outer_stride, size_type mid_stride) { + for (auto idx : grid_stride_range(input.size())) { + const index_type outer_idx = idx / outer_stride; + const index_type inner_idx = idx % mid_stride; + + const index_type sum_idx = outer_idx * mid_stride + inner_idx; + atomicAdd(&output[sum_idx], input[idx] * input[idx]); + } + } + + template + __global__ void rsqrt(Span output, T epsilon) { + for (auto idx : grid_stride_range(output.size())) { + using device::sqrt; + output[idx] = T(1) / sqrt(output[idx] + epsilon); + } + } + + template + __global__ void apply_norm(Span output, View input, size_type outer_stride, size_type mid_stride, View sums) { + for (auto idx : grid_stride_range(output.size())) { + const index_type outer_idx = idx / outer_stride; + const index_type inner_idx = idx % mid_stride; + + const index_type sum_idx = outer_idx * mid_stride + inner_idx; + output[idx] = input[idx] * sums[sum_idx]; + } + } + } + + template + void normalize( + const Stream& stream, + Span output, + View input, std::size_t outer_size, std::size_t mid_size, std::size_t inner_size, std::size_t norm, T epsilon, + Span workspace) + { + CV_Assert(output.size() == input.size()); + CV_Assert(output.size() == outer_size * mid_size * inner_size); + CV_Assert(norm == 1 || norm == 2); + CV_Assert(workspace.size() >= outer_size * inner_size); + + auto sums = Span(workspace.data(), outer_size * inner_size); + + fill(stream, sums, 0.0); + + if (norm == 1) { + auto reduce_kernel = raw::reduce_sum_abs; + auto policy = make_policy(reduce_kernel, input.size(), 0, stream); + launch_kernel(reduce_kernel, policy, sums, input, mid_size * inner_size, inner_size); + + auto reciprocal_kernel = raw::reciprocal; + policy = make_policy(reciprocal_kernel, sums.size(), 0, stream); + launch_kernel(reciprocal_kernel, policy, sums, epsilon); + } else { + auto reduce_kernel = raw::reduce_sum_squared; + auto policy = make_policy(reduce_kernel, input.size(), 0, stream); + launch_kernel(reduce_kernel, policy, sums, input, mid_size * inner_size, inner_size); + + auto rsqrt_kernel = raw::rsqrt; + policy = make_policy(rsqrt_kernel, sums.size(), 0, stream); + launch_kernel(rsqrt_kernel, policy, sums, 
epsilon); + } + + auto scale_kernel = raw::apply_norm; + auto policy = make_policy(scale_kernel, output.size(), 0, stream); + launch_kernel(scale_kernel, policy, output, input, mid_size * inner_size, inner_size, sums); + } + + template void normalize(const Stream&, Span<__half>, View<__half>, std::size_t, std::size_t, std::size_t, std::size_t, __half, Span<__half>); + template void normalize(const Stream&, Span, View, std::size_t, std::size_t, std::size_t, std::size_t, float, Span); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ diff --git a/modules/dnn/src/cuda/padding.cu b/modules/dnn/src/cuda/padding.cu new file mode 100644 index 0000000..d8f4812 --- /dev/null +++ b/modules/dnn/src/cuda/padding.cu @@ -0,0 +1,199 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#include +#include + +#include "array.hpp" +#include "math.hpp" +#include "types.hpp" +#include "grid_stride_range.hpp" +#include "execution.hpp" +#include "kernel_dispatcher.hpp" + +#include "../cuda4dnn/csl/stream.hpp" +#include "../cuda4dnn/csl/tensor.hpp" +#include "../cuda4dnn/csl/span.hpp" + +#include + +#include +#include +#include + +using namespace cv::dnn::cuda4dnn::csl; +using namespace cv::dnn::cuda4dnn::csl::device; + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + namespace raw { + template + __global__ void copy_with_reflection101( + Span output, array out_strides, array start, array end, + View input, array in_strides) + { + for (auto i : grid_stride_range(output.size())) { + /* compute output axis indices corresponding to element 'i' */ + array out_index; + out_index[0] = i / out_strides[0]; + for (int j = 1; j < Rank; j++) + out_index[j] = (i % out_strides[j - 1]) / out_strides[j]; + + /* compute input axis indices corresponding to output axis indices */ + array in_index; + for (int j = 0; j < Rank; j++) { + /* if out_index < start, the point is in the left reflection region + * the reflected value's index is the absolute value of the difference + * + * otherwise, if the value is in the copy region, out_index - start gives the input index + */ + using device::abs; + in_index[j] = abs(out_index[j] - start[j]); + + /* if out_index >= end, it's in the right reflection region */ + if (out_index[j] >= end[j]) + in_index[j] = (end[j] - start[j]) - (out_index[j] - end[j]) - 2; + } + + /* compute input element number from input axis indices */ + index_type iidx = 0; + for (int j = 0; j < Rank; j++) + iidx += in_index[j] * in_strides[j]; + + output[i] = input[iidx]; + } + } + } + + template static + void launch_copy_with_reflection101( + const Stream& stream, + Span output, const std::vector& outStride, + View input, const std::vector& inStride, + const std::vector>& ranges) + { + CV_Assert(outStride.size() == Rank); + CV_Assert(inStride.size() == Rank); + CV_Assert(ranges.size() == Rank); + + array outStride_k, inStride_k; + outStride_k.assign(std::begin(outStride), std::end(outStride)); + inStride_k.assign(std::begin(inStride), std::end(inStride)); + + array start_k, end_k; + for (int i = 0; i < Rank; i++) { + start_k[i] = ranges[i].first; + end_k[i] = ranges[i].second; + } + + auto kernel = raw::copy_with_reflection101; + auto policy = make_policy(kernel, output.size(), 0, stream); + launch_kernel(kernel, policy, output, outStride_k, start_k, end_k, input, inStride_k); + } + + 
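/* Editorial illustration (not part of the patch): the reflect-101 index mapping used by
 * raw::copy_with_reflection101 above, written out for a single axis on the host.
 * `reflect101_index` is a hypothetical helper name; `left_pad` is the number of padded
 * cells before the copied region and `end` equals `left_pad + input_size`, matching
 * `start[j]` and `end[j]` in the kernel.
 */
inline int reflect101_index(int out_idx, int left_pad, int end) {
    int in_idx = (out_idx < left_pad) ? (left_pad - out_idx)   /* left reflection */
                                      : (out_idx - left_pad);  /* copy region */
    if (out_idx >= end)                                        /* right reflection */
        in_idx = (end - left_pad) - (out_idx - end) - 2;
    return in_idx;
}

/* Example: an axis of size 4 padded by 2 on each side (left_pad = 2, end = 6).
 * Output indices 0..7 read input indices 2 1 0 1 2 3 2 1, i.e. the border element
 * itself is not duplicated (OpenCV's BORDER_REFLECT_101 convention).
 */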
GENERATE_KERNEL_DISPATCHER(copy_with_reflection101_dispatcher, launch_copy_with_reflection101); + + template + void copy_with_reflection101( + const Stream& stream, + TensorSpan output, TensorView input, + std::vector> ranges) + { + CV_Assert(output.rank() == input.rank()); + CV_Assert(output.rank() == ranges.size()); + + /* squeezable axes at the begining of both tensors can be eliminated + * + * Reasoning: + * ---------- + * Suppose an item's indices in the input tensor is [i1, i2, ...]. The indices in the + * output tensor will be [i1 + off1, i2 + off2, ...]. The rest of the elements in the output are padding. + * The padding operation essentially copies items from the input tensor to new locations in the output tensor + * and pads the remaining. + * + * If the size of the first axis of the input and output tensor is unity, the input and output indices + * for all the elements will be of the form be [0, i2, ...] and [0, i2 + off2, ...] respectively. Note that + * there cannot be extra padding since the axes have unit size. The first index does not contribute to the + * element's address calculation and hence does nothing apart from eating up few cycles. + */ + while (input.get_axis_size(0) == 1 && output.get_axis_size(0) == 1) { + CV_Assert(ranges[0].first == 0 && ranges[0].second == 1); + + input.squeeze(0); + output.squeeze(0); + ranges.erase(std::begin(ranges)); + + CV_Assert(output.rank() == input.rank()); + CV_Assert(output.rank() == ranges.size()); + } + + auto inShape = input.shape_as_vector(); + auto outShape = output.shape_as_vector(); + + /* contiguous axes which do not have any padding can be combined into one axis + * + * Reasoning: + * ---------- + * Suppose an item's indices in the input tensor is [i1, i2, i3, ...]. Let the first two axes not have any + * padding. The indices in the output tensor will be [i1, i2, i3 + off3, ...]. + * + * Each axis in the contiguous unpadded axes sequence will add an offset of iN * strideN. In the above example, + * the two axes add a total offset of `i1 * stride1 + i2 * stride2`. We can merge the two axes into one axis with + * a size of `size1 * size2`. The new offset added will be `i12 * stride2` as the kernel iterates through `i12`. + * Note that `i12` is actually `(i1 * size2 + i2)` in the original tensor. 
+ */ + for (int i = 0; i < inShape.size(); i++) { + /* check if axis `i` requires any padding */ + if (ranges[i].first == 0 && ranges[i].second == inShape[i]) { + /* loop invariant: `i` is the first axis in the contiguous unpadded axis sequence */ + CV_Assert(inShape[i] == outShape[i]); + + /* we now iterate through the axes which follow and try to merge */ + int j = i + 1; /* `j` is the axis which we will attempt to merge */ + while (j < inShape.size() && ranges[j].first == 0 && ranges[j].second == inShape[j]) { + CV_Assert(inShape[j] == outShape[j]); + + /* `j` is also unpadded; merge `i` and `j` */ + auto new_size = inShape[i] * inShape[j]; + inShape[i] = new_size; + outShape[i] = new_size; + ranges[i].second = new_size; + + /* delete axis `j` */ + inShape.erase(std::begin(inShape) + j); + outShape.erase(std::begin(outShape) + j); + ranges.erase(std::begin(ranges) + j); + + /* optimizations should not break the invariants */ + CV_Assert(inShape.size() == outShape.size()); + CV_Assert(inShape.size() == ranges.size()); + CV_Assert(inShape[i] == outShape[i]); + CV_Assert(ranges[i].first == 0 && ranges[i].second == inShape[i]); + } + } + } + + auto rank = inShape.size(); + + std::vector inStride(rank), outStride(rank); + inStride.back() = 1; + outStride.back() = 1; + /* garbage, ..., garbage, 1 */ + + std::copy(std::begin(inShape) + 1, std::end(inShape), std::begin(inStride)); + std::copy(std::begin(outShape) + 1, std::end(outShape), std::begin(outStride)); + /* dim[0], dim[1], ..., dim[-1], 1 */ + + std::partial_sum(inStride.rbegin(), inStride.rend(), inStride.rbegin(), std::multiplies()); + std::partial_sum(outStride.rbegin(), outStride.rend(), outStride.rbegin(), std::multiplies()); + /* stride[0], stride[1], ..., stride[-2], 1 */ + + CV_Assert(1 <= rank && rank <= CSL_MAX_TENSOR_RANK); + copy_with_reflection101_dispatcher(rank, stream, output, outStride, input, inStride, ranges); + } + + template void copy_with_reflection101(const Stream&, TensorSpan<__half>, TensorView<__half>, std::vector> ranges); + template void copy_with_reflection101(const Stream&, TensorSpan, TensorView, std::vector> ranges); + +}}}} /* namespace namespace cv::dnn::cuda4dnn::kernels */ diff --git a/modules/dnn/src/cuda/permute.cu b/modules/dnn/src/cuda/permute.cu new file mode 100644 index 0000000..7d0ffe8 --- /dev/null +++ b/modules/dnn/src/cuda/permute.cu @@ -0,0 +1,143 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
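/* Editorial illustration (not part of the patch): the stride computation used at the end
 * of copy_with_reflection101 above (and again in permute below), as a standalone host
 * function. The shape is shifted left by one into the stride buffer, the last stride is
 * set to 1, and a reverse running product turns extents into strides. `strides_of` is a
 * hypothetical helper name used only for this sketch.
 */
#include <algorithm>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

inline std::vector<std::size_t> strides_of(const std::vector<std::size_t>& shape) {
    std::vector<std::size_t> stride(shape.size());
    stride.back() = 1;                                       /* innermost stride */
    std::copy(shape.begin() + 1, shape.end(), stride.begin());
    /* stride now holds: dim[1], dim[2], ..., dim[-1], 1 */
    std::partial_sum(stride.rbegin(), stride.rend(), stride.rbegin(),
                     std::multiplies<std::size_t>());
    /* stride now holds: stride[0], stride[1], ..., stride[-2], 1 */
    return stride;
}

/* Example: strides_of({2, 3, 4, 5}) yields {60, 20, 5, 1}, the strides of a contiguous
 * tensor with that shape, which is exactly what the kernels above index with.
 */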
+ +#include +#include + +#include "array.hpp" +#include "types.hpp" +#include "grid_stride_range.hpp" +#include "execution.hpp" +#include "kernel_dispatcher.hpp" + +#include "../cuda4dnn/csl/stream.hpp" +#include "../cuda4dnn/csl/tensor.hpp" +#include "../cuda4dnn/csl/span.hpp" + +#include + +#include +#include + +using namespace cv::dnn::cuda4dnn::csl; +using namespace cv::dnn::cuda4dnn::csl::device; + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + namespace raw { + template + __global__ void permute( + array axis_order, + Span output, array outStrides, + View input, array inStrides) + { + for (auto i : grid_stride_range(input.size())) { + index_type oldPosition = 0; + index_type newPosition = i; + + for (int j = 0; j < Rank; j++) + { + auto order = axis_order[j]; + oldPosition += (newPosition / outStrides[j]) * inStrides[order]; + newPosition %= outStrides[j]; + } + + output[i] = input[oldPosition]; + } + } + } + + template static + void launch_permute_kernel( + const Stream& stream, + const std::vector& order, + Span output, const std::vector& outStride, + View input, const std::vector& inStride) + { + CV_Assert(order.size() == Rank); + CV_Assert(outStride.size() == Rank); + CV_Assert(inStride.size() == Rank); + + array order_k; + order_k.assign(std::begin(order), std::end(order)); + + array outStride_k, inStride_k; + outStride_k.assign(std::begin(outStride), std::end(outStride)); + inStride_k.assign(std::begin(inStride), std::end(inStride)); + + auto kernel = raw::permute; + auto policy = make_policy(kernel, input.size(), 0, stream); + launch_kernel(kernel, policy, order_k, output, outStride_k, input, inStride_k); + } + + GENERATE_KERNEL_DISPATCHER(permute_dispatcher, launch_permute_kernel); + + template + void permute( + const Stream& stream, + TensorSpan output, TensorView input, + std::vector order) + { + CV_Assert(output.rank() == input.rank()); + CV_Assert(input.rank() == order.size()); + CV_Assert(input.size() == output.size()); + + /* squeezable axes at the begining of both tensors which aren't permuted can be eliminated + * + * Reasoning: + * ---------- + * Suppose an item's indices in the input tensor is [i1, i2, ...]. The indices in the + * output tensor will be some permutation of the input tensor indices. Let the output + * tensor indices be [o1, o2, ...]. The permutation operation essentially copies items + * from the input tensor to new locations in the output tensor as dictated by the indices. + * + * If the size of the first axis of the input and output tensor is one and these axes are + * not involved in any permutation, i.e. order[0] = 0, the input and output indicies for + * all the elements will be of the form be [0, i2, ...] and [0, o2, ...] respectively. + * The first index does not contribute to the element's address calculation and hence does + * nothing apart from eating up few cycles. 
+ */ + while (order[0] == 0 && input.get_axis_size(0) == 1 && output.get_axis_size(0) == 1) { + /* remove the axes */ + input.squeeze(0); + output.squeeze(0); + + /* when we remove axis zero, the axis index will be one less than the previous index + * for the remaining axes + */ + order.erase(order.begin()); + for (auto& axis : order) + axis--; + + /* optimizations should not break the invariants */ + CV_Assert(output.rank() == input.rank()); + CV_Assert(input.rank() == order.size()); + CV_Assert(input.size() == output.size()); + } + + auto rank = output.rank(); + auto inShape = input.shape_as_vector(); + auto outShape = output.shape_as_vector(); + + std::vector inStride(rank), outStride(rank); + inStride.back() = 1; + outStride.back() = 1; + /* garbage, ..., garbage, 1 */ + + std::copy(std::begin(inShape) + 1, std::end(inShape), std::begin(inStride)); + std::copy(std::begin(outShape) + 1, std::end(outShape), std::begin(outStride)); + /* dim[0], dim[1], ..., dim[-1], 1 */ + + std::partial_sum(inStride.rbegin(), inStride.rend(), inStride.rbegin(), std::multiplies()); + std::partial_sum(outStride.rbegin(), outStride.rend(), outStride.rbegin(), std::multiplies()); + /* stride[0], stride[1], ..., stride[-2], 1 */ + + CV_Assert(2 <= rank && rank <= CSL_MAX_TENSOR_RANK); + permute_dispatcher(rank, stream, order, output, outStride, input, inStride); + } + + template void permute(const Stream&, TensorSpan<__half>, TensorView<__half>, std::vector); + template void permute(const Stream&, TensorSpan, TensorView, std::vector); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ diff --git a/modules/dnn/src/cuda/prior_box.cu b/modules/dnn/src/cuda/prior_box.cu new file mode 100644 index 0000000..313fefc --- /dev/null +++ b/modules/dnn/src/cuda/prior_box.cu @@ -0,0 +1,174 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
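/* Editorial illustration (not part of the patch): the index arithmetic performed by
 * raw::permute above, reproduced as a host function. For every flat output index the
 * kernel peels off one output coordinate per axis and accumulates it against the input
 * stride of the axis it came from (`order[j]`). `permuted_to_source_index` is a
 * hypothetical helper name used only for this sketch.
 */
#include <cstddef>
#include <vector>

inline std::size_t permuted_to_source_index(std::size_t flat_out_idx,
                                            const std::vector<std::size_t>& order,
                                            const std::vector<std::size_t>& out_stride,
                                            const std::vector<std::size_t>& in_stride) {
    std::size_t src = 0;
    for (std::size_t j = 0; j < order.size(); j++) {
        src += (flat_out_idx / out_stride[j]) * in_stride[order[j]];
        flat_out_idx %= out_stride[j];
    }
    return src;
}

/* Example: transposing a 2x3 matrix (order = {1, 0}, in_stride = {3, 1},
 * out_stride = {2, 1}): output element 3 reads input element
 * (3 / 2) * 1 + ((3 % 2) / 1) * 3 = 4, i.e. row 1, column 1 of the source.
 */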
+ +#include +#include + +#include "array.hpp" +#include "math.hpp" +#include "types.hpp" +#include "vector_traits.hpp" +#include "grid_stride_range.hpp" +#include "execution.hpp" + +#include "../cuda4dnn/csl/stream.hpp" +#include "../cuda4dnn/csl/span.hpp" + +#include + +using namespace cv::dnn::cuda4dnn::csl; +using namespace cv::dnn::cuda4dnn::csl::device; + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + namespace raw { + template + __global__ void prior_box( + Span output, + View boxWidth, View boxHeight, View offsetX, View offsetY, float stepX, float stepY, + size_type layerWidth, size_type layerHeight, + size_type imageWidth, size_type imageHeight) + { + /* each box consists of two pair of coordinates and hence 4 values in total */ + /* since the entire output consists (first channel at least) of these boxes, + * we are garunteeed that the output is aligned to a boundary of 4 values + */ + using vector_type = get_vector_type_t; + auto output_vPtr = vector_type::get_pointer(output.data()); + + /* num_points contains the number of points in the feature map of interest + * each iteration of the stride loop selects a point and generates prior boxes for it + */ + size_type num_points = layerWidth * layerHeight; + for (auto idx : grid_stride_range(num_points)) { + const index_type x = idx % layerWidth, + y = idx / layerWidth; + + index_type output_offset_v4 = idx * offsetX.size() * boxWidth.size(); + for (int i = 0; i < boxWidth.size(); i++) { + for (int j = 0; j < offsetX.size(); j++) { + float center_x = (x + offsetX[j]) * stepX; + float center_y = (y + offsetY[j]) * stepY; + + vector_type vec; + if(Normalize) { + vec.data[0] = (center_x - boxWidth[i] * 0.5f) / imageWidth; + vec.data[1] = (center_y - boxHeight[i] * 0.5f) / imageHeight; + vec.data[2] = (center_x + boxWidth[i] * 0.5f) / imageWidth; + vec.data[3] = (center_y + boxHeight[i] * 0.5f) / imageHeight; + } else { + vec.data[0] = center_x - boxWidth[i] * 0.5f; + vec.data[1] = center_y - boxHeight[i] * 0.5f; + vec.data[2] = center_x + boxWidth[i] * 0.5f - 1.0f; + vec.data[3] = center_y + boxHeight[i] * 0.5f - 1.0f; + } + + v_store(output_vPtr[output_offset_v4], vec); + output_offset_v4++; + } + } + } + } + + template + __global__ void prior_box_clip(Span output) { + for (auto i : grid_stride_range(output.size())) { + using device::clamp; + output[i] = clamp(output[i], 0.0, 1.0); + } + } + + template + __global__ void prior_box_set_variance1(Span output, float variance) { + using vector_type = get_vector_type_t; + auto output_vPtr = vector_type::get_pointer(output.data()); + for (auto i : grid_stride_range(output.size() / 4)) { + vector_type vec; + for (int j = 0; j < 4; j++) + vec.data[j] = variance; + v_store(output_vPtr[i], vec); + } + } + + template + __global__ void prior_box_set_variance4(Span output, array variance) { + using vector_type = get_vector_type_t; + auto output_vPtr = vector_type::get_pointer(output.data()); + for (auto i : grid_stride_range(output.size() / 4)) { + vector_type vec; + for(int j = 0; j < 4; j++) + vec.data[j] = variance[j]; + v_store(output_vPtr[i], vec); + } + } + } + + template static + void launch_prior_box_kernel( + const Stream& stream, + Span output, View boxWidth, View boxHeight, View offsetX, View offsetY, float stepX, float stepY, + std::size_t layerWidth, std::size_t layerHeight, std::size_t imageWidth, std::size_t imageHeight) + { + auto num_points = layerWidth * layerHeight; + auto kernel = raw::prior_box; + auto policy = make_policy(kernel, num_points, 0, 
stream); + launch_kernel(kernel, policy, + output, boxWidth, boxHeight, offsetX, offsetY, stepX, stepY, + layerWidth, layerHeight, imageWidth, imageHeight); + } + + template + void generate_prior_boxes( + const Stream& stream, + Span output, + View boxWidth, View boxHeight, View offsetX, View offsetY, float stepX, float stepY, + std::vector variance, + std::size_t numPriors, + std::size_t layerWidth, std::size_t layerHeight, + std::size_t imageWidth, std::size_t imageHeight, + bool normalize, bool clip) + { + if (normalize) { + launch_prior_box_kernel( + stream, output, boxWidth, boxHeight, offsetX, offsetY, stepX, stepY, + layerWidth, layerHeight, imageWidth, imageHeight + ); + } else { + launch_prior_box_kernel( + stream, output, boxWidth, boxHeight, offsetX, offsetY, stepX, stepY, + layerWidth, layerHeight, imageWidth, imageHeight + ); + } + + std::size_t channel_size = layerHeight * layerWidth * numPriors * 4; + CV_Assert(channel_size * 2 == output.size()); + + if (clip) { + auto output_span_c1 = Span(output.data(), channel_size); + auto kernel = raw::prior_box_clip; + auto policy = make_policy(kernel, output_span_c1.size(), 0, stream); + launch_kernel(kernel, policy, output_span_c1); + } + + auto output_span_c2 = Span(output.data() + channel_size, channel_size); + if (variance.size() == 1) { + auto kernel = raw::prior_box_set_variance1; + auto policy = make_policy(kernel, output_span_c2.size() / 4, 0, stream); + launch_kernel(kernel, policy, output_span_c2, variance[0]); + } else { + array variance_k; + variance_k.assign(std::begin(variance), std::end(variance)); + auto kernel = raw::prior_box_set_variance4; + auto policy = make_policy(kernel, output_span_c2.size() / 4, 0, stream); + launch_kernel(kernel, policy, output_span_c2, variance_k); + } + } + + template void generate_prior_boxes(const Stream&, Span<__half>, View, View, View, View, float, float, + std::vector, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t, bool, bool); + + template void generate_prior_boxes(const Stream&, Span, View, View, View, View, float, float, + std::vector, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t, bool, bool); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ diff --git a/modules/dnn/src/cuda/region.cu b/modules/dnn/src/cuda/region.cu new file mode 100644 index 0000000..158deb9 --- /dev/null +++ b/modules/dnn/src/cuda/region.cu @@ -0,0 +1,199 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
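/* Editorial illustration (not part of the patch): the per-box corner computation
 * performed by raw::prior_box above for the normalized branch, written as a host
 * function. (x, y) is the feature-map cell; each (offset, box size) pair produces one
 * box of four values, which is why the kernel can store every box with a single 4-wide
 * vectorized write. `PriorBox` and `make_prior_box` are hypothetical names for this sketch.
 */
struct PriorBox { float x0, y0, x1, y1; };

inline PriorBox make_prior_box(int x, int y, float offset_x, float offset_y,
                               float step_x, float step_y,
                               float box_width, float box_height,
                               float image_width, float image_height) {
    const float center_x = (x + offset_x) * step_x;
    const float center_y = (y + offset_y) * step_y;
    return { (center_x - box_width * 0.5f) / image_width,    /* left   */
             (center_y - box_height * 0.5f) / image_height,  /* top    */
             (center_x + box_width * 0.5f) / image_width,    /* right  */
             (center_y + box_height * 0.5f) / image_height   /* bottom */ };
}

/* The first output channel holds layerWidth * layerHeight * numPriors such boxes
 * (hence channel_size = layerHeight * layerWidth * numPriors * 4 above); the second
 * channel is filled with the variance values, so output.size() is twice channel_size.
 */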
+ +#include +#include + +#include "math.hpp" +#include "grid_stride_range.hpp" +#include "execution.hpp" +#include "limits.hpp" +#include "vector_traits.hpp" + +#include "../cuda4dnn/csl/stream.hpp" +#include "../cuda4dnn/csl/span.hpp" + +#include + +#include + +using namespace cv::dnn::cuda4dnn::csl; +using namespace cv::dnn::cuda4dnn::csl::device; + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + namespace raw { + template + __global__ void sigmoid_strided(Span output, View input, size_type n, size_type stride, size_type offset) { + /* - the input is divided into equal blocks strided by `stride` + * - we must apply sigmoid to a continuous range of `n` values starting from `offset` in every block + */ + for (auto i : grid_stride_range(n * output.size() / stride)) { + auto block_idx = i / n; + auto index = block_idx * stride + offset + (i % n); + + using device::sigmoid; + output[index] = sigmoid(input[index]); + } + } + + template + __global__ void softmax_strided(Span output, View input, size_type n, size_type stride, size_type offset_) { + for (auto idx : grid_stride_range(output.size() / stride)) { + index_type offset = idx * stride + offset_; + + auto largest = numeric_limits::lowest(); + for (int i = 0; i < n; i++) { + using device::max; + largest = max(largest, output[offset + i]); + } + + auto sum = T(0); + for (int i = 0; i < n; i++) { + using device::exp; + auto temp = exp(output[offset + i] - largest); + sum += temp; + output[offset + i] = temp; + } + + for (int i = 0; i < n; i++) { + output[offset + i] /= sum; + } + } + } + + template + __global__ void region_finalize(Span output, View input, View bias, + T object_prob_cutoff, T class_prob_cutoff, + size_type height_norm, size_type width_norm, + size_type rows, size_type cols, + size_type boxes_per_cell, + size_type box_size, + size_type classes) + { + for (auto box_index : grid_stride_range(output.size() / box_size)) { + auto box_of_the_cell = box_index % boxes_per_cell; /* box number within a cell */ + auto box_offset = box_index * box_size; + + auto batch_inner_size = rows * cols * boxes_per_cell; + auto row_inner_size = cols * boxes_per_cell; + auto col_inner_size = boxes_per_cell; + + auto y = (box_index % batch_inner_size) / row_inner_size; + auto x = (box_index % row_inner_size) / col_inner_size; + + using device::sigmoid; + using device::exp; + output[box_offset + 0] = (T(x) + sigmoid(input[box_offset + 0])) / T(cols); + output[box_offset + 1] = (T(y) + sigmoid(input[box_offset + 1])) / T(rows); + output[box_offset + 2] = exp(input[box_offset + 2]) * bias[2 * box_of_the_cell + 0] / T(width_norm); + output[box_offset + 3] = exp(input[box_offset + 3]) * bias[2 * box_of_the_cell + 1] / T(height_norm); + + /* squash objectness score into a probability */ + using device::sigmoid; + T objectness_prob = sigmoid(output[box_offset + 4]); + output[box_offset + 4] = objectness_prob; + + /* ignore prediction if the objectness probability is less than the cutoff */ + if (objectness_prob < object_prob_cutoff) + objectness_prob = 0; + + /* the class probabilities we have currently are conditional class probabilities + * given the object + * + * to obtain the actual class probability, we multiply the conditional probability + * with the object probability + */ + const index_type class_begin = box_offset + 5; /* 4 box coordinates, 1 obj prob, class probs... 
*/ + const index_type class_end = class_begin + classes; + index_type offset = class_begin; + + using vector_type = get_vector_type_t; + + /* process each class independently until the offset is aligned to an n-element boundary */ + while (offset % vector_type::size() != 0 && offset < class_end) { + T actual_class_prob = objectness_prob * output[offset]; + if (actual_class_prob <= class_prob_cutoff) + actual_class_prob = T(0); + output[offset] = actual_class_prob; + offset++; + } + + auto output_vPtr = vector_type::get_pointer(output.data() + offset); + auto input_vPtr = vector_type::get_pointer(input.data() + offset); + for (int i = 0; (offset + vector_type::size()) < class_end; i++) { + vector_type vec; + v_load(vec, output_vPtr[i]); + for (int j = 0; j < vector_type::size(); j++) { + T actual_class_prob = objectness_prob * vec.data[j]; + if (actual_class_prob <= class_prob_cutoff) + actual_class_prob = T(0); + vec.data[j] = actual_class_prob; + } + v_store(output_vPtr[i], vec); + offset += vector_type::size(); + } + + /* process the remaining classes */ + while (offset < class_end) { + T actual_class_prob = objectness_prob * output[offset]; + if (actual_class_prob <= class_prob_cutoff) + actual_class_prob = T(0); + output[offset] = actual_class_prob; + offset++; + } + } + } + } + + template + void sigmoid_strided(const Stream& stream, Span output, View input, std::size_t n, std::size_t stride, std::size_t offset) { + CV_Assert(output.size() % stride == 0); + + auto kernel = raw::sigmoid_strided; + auto policy = make_policy(kernel, n * output.size() / stride, 0, stream); + launch_kernel(kernel, policy, output, input, n, stride, offset); + } + + template void sigmoid_strided(const Stream&, Span<__half>, View<__half>, std::size_t, std::size_t, std::size_t); + template void sigmoid_strided(const Stream&, Span, View, std::size_t, std::size_t, std::size_t); + + template + void softmax_strided(const Stream& stream, Span output, View input, std::size_t n, std::size_t stride, std::size_t offset) { + CV_Assert(output.size() % stride == 0); + + auto kernel = raw::softmax_strided; + auto policy = make_policy(kernel, output.size() / stride, 0, stream); + launch_kernel(kernel, policy, output, input, n, stride, offset); + } + + template void softmax_strided(const Stream&, Span<__half>, View<__half>, std::size_t, std::size_t, std::size_t); + template void softmax_strided(const Stream&, Span, View, std::size_t, std::size_t, std::size_t); + + template + void region_finalize(const Stream& stream, Span output, View input, View bias, + T object_prob_cutoff, T class_prob_cutoff, + std::size_t height_norm, std::size_t width_norm, + std::size_t rows, std::size_t cols, + std::size_t boxes_per_cell, + std::size_t box_size, + std::size_t classes) + { + CV_Assert(output.size() % box_size == 0); + + auto kernel = raw::region_finalize; + auto policy = make_policy(kernel, output.size() / box_size, 0, stream); + launch_kernel(kernel, policy, output, input, bias, + object_prob_cutoff, class_prob_cutoff, + height_norm, width_norm, + rows, cols, boxes_per_cell, box_size, classes); + } + + template void region_finalize(const Stream&, Span<__half>, View<__half>, View<__half>, + __half, __half, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t); + + template void region_finalize(const Stream&, Span, View, View, + float, float, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ diff --git 
a/modules/dnn/src/cuda/resize.cu b/modules/dnn/src/cuda/resize.cu new file mode 100644 index 0000000..6eed48a --- /dev/null +++ b/modules/dnn/src/cuda/resize.cu @@ -0,0 +1,133 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#include +#include + +#include "math.hpp" +#include "types.hpp" +#include "grid_stride_range.hpp" +#include "execution.hpp" + +#include "../cuda4dnn/csl/stream.hpp" +#include "../cuda4dnn/csl/tensor.hpp" +#include "../cuda4dnn/csl/span.hpp" + +#include + +using namespace cv::dnn::cuda4dnn::csl; +using namespace cv::dnn::cuda4dnn::csl::device; + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + namespace raw { + template + __global__ void resize_nn( + Span output, size_type out_height, size_type out_width, + View input, size_type in_height, size_type in_width) + { + auto in_image_size = in_height * in_width; + auto out_image_size = out_height * out_width; + + /* o2i = output to input */ + auto o2i_fx = static_cast(in_width) / out_width; + auto o2i_fy = static_cast(in_height) / out_height; + + /* think of the output and input as a collection of 2d images with the last axis + * representing the width and the last but one axis representing the height + * + * the remaining axis together form a collection of these images + */ + for (auto idx : grid_stride_range(output.size())) { + const index_type n = idx / out_image_size; + const index_type x = (idx % out_image_size) % out_width; + const index_type y = (idx % out_image_size) / out_width; + + auto in_x = static_cast(x * o2i_fx); + auto in_y = static_cast(y * o2i_fy); + + index_type in_idx = n * in_image_size + in_y * in_width + in_x; + output[idx] = input[in_idx]; + } + } + + template + __global__ void resize_bilinear( + Span output, size_type out_height, size_type out_width, + View input, size_type in_height, size_type in_width, + float o2i_fy, float o2i_fx) + { + auto in_image_size = in_height * in_width; + auto out_image_size = out_height * out_width; + + /* think of the output and input as a collection of 2d images with the last axis + * representing the width and the last but one axis representing the height + * + * the remaining axis together form a collection of these images + */ + for (auto idx : grid_stride_range(output.size())) { + const index_type n = idx / out_image_size; + const index_type x = (idx % out_image_size) % out_width; + const index_type y = (idx % out_image_size) / out_width; + + auto in_x = x * o2i_fx; + auto in_y = y * o2i_fy; + + auto in_x0 = static_cast(in_x); + auto in_y0 = static_cast(in_y); + + using device::min; + auto in_x1 = min(in_x0 + 1, in_width - 1); + auto in_y1 = min(in_y0 + 1, in_height - 1); + + const index_type in_offset_r0 = n * in_image_size + in_y0 * in_width; + const index_type in_offset_r1 = n * in_image_size + in_y1 * in_width; + + auto v_00 = input[in_offset_r0 + in_x0], + v_01 = input[in_offset_r0 + in_x1], + v_10 = input[in_offset_r1 + in_x0], + v_11 = input[in_offset_r1 + in_x1]; + + output[idx] = + v_00 + + T(in_y - in_y0) * T(v_10 - v_00) + + T(in_x - in_x0) * T(v_01 - v_00) + + T(in_y - in_y0) * T(in_x - in_x0) * T(v_11 - v_01 - v_10 + v_00); + } + } + } + + template + void resize_nn(const Stream& stream, TensorSpan output, TensorView input) { + auto in_height = input.get_axis_size(-2); + auto in_width = input.get_axis_size(-1); + + auto out_height = output.get_axis_size(-2); + auto out_width 
= output.get_axis_size(-1); + + auto kernel = raw::resize_nn; + auto policy = make_policy(kernel, output.size(), 0, stream); + launch_kernel(kernel, policy, output, out_height, out_width, input, in_height, in_width); + } + + template void resize_nn<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>); + template void resize_nn(const Stream&, TensorSpan, TensorView); + + template + void resize_bilinear(const Stream& stream, TensorSpan output, TensorView input, float scale_y, float scale_x) { + auto in_height = input.get_axis_size(-2); + auto in_width = input.get_axis_size(-1); + + auto out_height = output.get_axis_size(-2); + auto out_width = output.get_axis_size(-1); + + auto kernel = raw::resize_bilinear; + auto policy = make_policy(kernel, output.size(), 0, stream); + launch_kernel(kernel, policy, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x); + } + + template void resize_bilinear<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, float, float); + template void resize_bilinear(const Stream&, TensorSpan, TensorView, float, float); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ diff --git a/modules/dnn/src/cuda/scale_shift.cu b/modules/dnn/src/cuda/scale_shift.cu new file mode 100644 index 0000000..05f4374 --- /dev/null +++ b/modules/dnn/src/cuda/scale_shift.cu @@ -0,0 +1,311 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#include +#include + +#include "types.hpp" +#include "vector_traits.hpp" +#include "grid_stride_range.hpp" +#include "execution.hpp" + +#include "../cuda4dnn/csl/stream.hpp" +#include "../cuda4dnn/csl/tensor.hpp" +#include "../cuda4dnn/csl/span.hpp" + +#include + +#include + +using namespace cv::dnn::cuda4dnn::csl; +using namespace cv::dnn::cuda4dnn::csl::device; + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + namespace raw { + template + __global__ void bias1_vec(Span output, View input, T beta) { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto input_vPtr = vector_type::get_pointer(input.data()); + + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + vector_type vec; + v_load(vec, input_vPtr[i]); + for (int j = 0; j < vec.size(); j++) + vec.data[j] = vec.data[j] + beta; + v_store(output_vPtr[i], vec); + } + } + + template + __global__ void biasN_vec(Span output, View input, size_type inner_size, View bias) { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto input_vPtr = vector_type::get_pointer(input.data()); + + inner_size /= vector_type::size(); + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + const index_type bias_idx = (i / inner_size) % static_cast(bias.size()); + + vector_type vec; + v_load(vec, input_vPtr[i]); + for(int j = 0; j < vec.size(); j++) + vec.data[j] = vec.data[j] + bias[bias_idx]; + v_store(output_vPtr[i], vec); + } + } + + template + __global__ void scale1_vec(Span output, View input, T alpha) { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto input_vPtr = vector_type::get_pointer(input.data()); + + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + vector_type vec; + v_load(vec, input_vPtr[i]); + for (int j = 0; j < vec.size(); j++) + vec.data[j] = 
vec.data[j] * alpha; + v_store(output_vPtr[i], vec); + } + } + + template + __global__ void scaleN_vec(Span output, View input, size_type inner_size, View weights) + { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto input_vPtr = vector_type::get_pointer(input.data()); + + inner_size /= vector_type::size(); + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + const index_type scale_idx = (i / inner_size) % static_cast(weights.size()); + + vector_type vec; + v_load(vec, input_vPtr[i]); + for (int j = 0; j < vec.size(); j++) + vec.data[j] = vec.data[j] * weights[scale_idx]; + v_store(output_vPtr[i], vec); + } + } + + template + __global__ void scale1_with_bias1_vec(Span output, View input, T alpha, T beta) + { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto input_vPtr = vector_type::get_pointer(input.data()); + + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + vector_type vec; + v_load(vec, input_vPtr[i]); + for (int j = 0; j < vec.size(); j++) + vec.data[j] = alpha * vec.data[j] + beta; + v_store(output_vPtr[i], vec); + } + } + + template + __global__ void scaleN_with_biasN_vec(Span output, View input, size_type inner_size, View weights, View bias) + { + using vector_type = get_vector_type_t; + + auto output_vPtr = vector_type::get_pointer(output.data()); + auto input_vPtr = vector_type::get_pointer(input.data()); + + inner_size /= vector_type::size(); + for (auto i : grid_stride_range(output.size() / vector_type::size())) { + const index_type scale_idx = (i / inner_size) % static_cast(weights.size()); + + vector_type vec; + v_load(vec, input_vPtr[i]); + for (int j = 0; j < vec.size(); j++) + vec.data[j] = vec.data[j] * weights[scale_idx] + bias[scale_idx]; + v_store(output_vPtr[i], vec); + } + } + } + + template static + void launch_bias1_vec_kernel(const Stream& stream, Span output, View input, T beta) { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(input, N)); + + auto kernel = raw::bias1_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + launch_kernel(kernel, policy, output, input, beta); + } + + template + void bias1(const Stream& stream, TensorSpan output, TensorView input, T beta) { + CV_Assert(is_shape_same(input, output)); + + if (is_fully_aligned(output, 4) && is_fully_aligned(input, 4)) { + launch_bias1_vec_kernel(stream, output, input, beta); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(input, 2)) { + launch_bias1_vec_kernel(stream, output, input, beta); + } else { + launch_bias1_vec_kernel(stream, output, input, beta); + } + } + + template void bias1<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, __half); + template void bias1(const Stream&, TensorSpan, TensorView, float); + + template static + void launch_biasN_vec_kernel(const Stream& stream, Span output, View input, std::size_t inner_size, View bias){ + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(input, N)); + CV_Assert(inner_size % N == 0); + + auto kernel = raw::biasN_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + launch_kernel(kernel, policy, output, input, inner_size, bias); + } + + template + void biasN( + const Stream& stream, + TensorSpan output, + TensorView input, std::size_t inner_size, + TensorView bias) + { + CV_Assert(is_shape_same(input, output)); + + if (is_fully_aligned(output, 4) && is_fully_aligned(input, 
4) && inner_size % 4 == 0) { + launch_biasN_vec_kernel(stream, output, input, inner_size, bias); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(input, 2) && inner_size % 2 == 0) { + launch_biasN_vec_kernel(stream, output, input, inner_size, bias); + } else { + launch_biasN_vec_kernel(stream, output, input, inner_size, bias); + } + } + + template void biasN<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, std::size_t, TensorView<__half>); + template void biasN(const Stream&, TensorSpan, TensorView, std::size_t, TensorView); + + template static + void launch_scale1_vec_kernel(const Stream& stream, Span output, View input, T alpha) { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(input, N)); + + auto kernel = raw::scale1_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + launch_kernel(kernel, policy, output, input, alpha); + } + + template + void scale1(const Stream& stream, TensorSpan output, TensorView input, T alpha) { + CV_Assert(is_shape_same(input, output)); + + if (is_fully_aligned(output, 4) && is_fully_aligned(input, 4)) { + launch_scale1_vec_kernel(stream, output, input, alpha); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(input, 2)) { + launch_scale1_vec_kernel(stream, output, input, alpha); + } else { + launch_scale1_vec_kernel(stream, output, input, alpha); + } + } + + template void scale1<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, __half); + template void scale1(const Stream&, TensorSpan, TensorView, float); + + template static + void launch_scaleN_vec_kernel(const Stream& stream, Span output, View input, std::size_t inner_size, View weights) { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(input, N)); + CV_Assert(inner_size % N == 0); + + auto kernel = raw::scaleN_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + launch_kernel(kernel, policy, output, input, inner_size, weights); + } + + template + void scaleN( + const Stream& stream, + TensorSpan output, + TensorView input, std::size_t inner_size, + TensorView weights) + { + CV_Assert(is_shape_same(input, output)); + + if (is_fully_aligned(output, 4) && is_fully_aligned(input, 4) && inner_size % 4 == 0) { + launch_scaleN_vec_kernel(stream, output, input, inner_size, weights); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(input, 2) && inner_size % 2 == 0) { + launch_scaleN_vec_kernel(stream, output, input, inner_size, weights); + } else { + launch_scaleN_vec_kernel(stream, output, input, inner_size, weights); + } + } + + template void scaleN<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, std::size_t, TensorView<__half>); + template void scaleN(const Stream&, TensorSpan, TensorView, std::size_t, TensorView); + + template static + void launch_scale1_with_bias1_vec_kernel(const Stream& stream, Span output, View input, T alpha, T beta) { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(input, N)); + + auto kernel = raw::scale1_with_bias1_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + launch_kernel(kernel, policy, output, input, alpha, beta); + } + + template + void scale1_with_bias1(const Stream& stream, Span output, View input, T alpha, T beta) { + CV_Assert(output.size() == input.size()); + + if (is_fully_aligned(output, 4) && is_fully_aligned(input, 4)) { + launch_scale1_with_bias1_vec_kernel(stream, output, input, alpha, beta); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(input, 
2)) { + launch_scale1_with_bias1_vec_kernel(stream, output, input, alpha, beta); + } else { + launch_scale1_with_bias1_vec_kernel(stream, output, input, alpha, beta); + } + } + + template void scale1_with_bias1<__half>(const Stream&, Span<__half>, View<__half>, __half, __half); + template void scale1_with_bias1(const Stream&, Span, View, float, float); + + template static + void launch_scaleN_with_biasN_vec_kernel(const Stream& stream, Span output, View input, std::size_t inner_size, View weights, View bias) { + CV_Assert(is_fully_aligned(output, N)); + CV_Assert(is_fully_aligned(input, N)); + CV_Assert(inner_size % N == 0); + + auto kernel = raw::scaleN_with_biasN_vec; + auto policy = make_policy(kernel, output.size() / N, 0, stream); + launch_kernel(kernel, policy, output, input, inner_size, weights, bias); + } + + template + void scaleN_with_biasN( + const Stream& stream, + TensorSpan output, + TensorView input, std::size_t inner_size, + TensorView weights, TensorView bias) + { + CV_Assert(is_shape_same(input, output)); + CV_Assert(weights.size() == bias.size()); + + if (is_fully_aligned(output, 4) && is_fully_aligned(input, 4) && inner_size % 4 == 0) { + launch_scaleN_with_biasN_vec_kernel(stream, output, input, inner_size, weights, bias); + } else if (is_fully_aligned(output, 2) && is_fully_aligned(input, 2) && inner_size % 2 == 0) { + launch_scaleN_with_biasN_vec_kernel(stream, output, input, inner_size, weights, bias); + } else { + launch_scaleN_with_biasN_vec_kernel(stream, output, input, inner_size, weights, bias); + } + } + + template void scaleN_with_biasN<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, std::size_t, TensorView<__half>, TensorView<__half>); + template void scaleN_with_biasN(const Stream&, TensorSpan, TensorView, std::size_t, TensorView, TensorView); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ diff --git a/modules/dnn/src/cuda/slice.cu b/modules/dnn/src/cuda/slice.cu new file mode 100644 index 0000000..a6e3a94 --- /dev/null +++ b/modules/dnn/src/cuda/slice.cu @@ -0,0 +1,169 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
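All of the bias/scale entry points above share one dispatch pattern: pick the widest vector width (4, then 2, then scalar) whose alignment and divisibility requirements both buffers satisfy, and instantiate the kernel for that width. Below is a minimal standalone CUDA sketch of that pattern only; the names (`bias1_kernel`, `is_fully_aligned`) and the scalar kernel body are illustrative, whereas the patch itself operates on csl::Span/View and issues wide loads through the vector_type helpers from vector_traits.hpp.

    #include <cuda_runtime.h>
    #include <cstddef>
    #include <cstdint>

    /* one thread processes one group of N elements in a grid-stride loop */
    template <class T, std::size_t N>
    __global__ void bias1_kernel(T* output, const T* input, std::size_t n, T beta) {
        const std::size_t stride = static_cast<std::size_t>(gridDim.x) * blockDim.x;
        for (std::size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n / N; i += stride) {
            T vec[N];
            for (std::size_t j = 0; j < N; j++) vec[j] = input[i * N + j] + beta;
            for (std::size_t j = 0; j < N; j++) output[i * N + j] = vec[j];
        }
    }

    /* width N qualifies iff the pointer is N*sizeof(T)-byte aligned and the size is divisible by N */
    template <std::size_t N, class T>
    bool is_fully_aligned(const T* ptr, std::size_t size) {
        return reinterpret_cast<std::uintptr_t>(ptr) % (N * sizeof(T)) == 0 && size % N == 0;
    }

    template <class T>
    void bias1(T* output, const T* input, std::size_t n, T beta, cudaStream_t stream) {
        const dim3 block(256), grid(1024);
        if (is_fully_aligned<4>(output, n) && is_fully_aligned<4>(input, n))
            bias1_kernel<T, 4><<<grid, block, 0, stream>>>(output, input, n, beta);
        else if (is_fully_aligned<2>(output, n) && is_fully_aligned<2>(input, n))
            bias1_kernel<T, 2><<<grid, block, 0, stream>>>(output, input, n, beta);
        else
            bias1_kernel<T, 1><<<grid, block, 0, stream>>>(output, input, n, beta);
    }

The scalar fallback keeps the wrappers correct for any pointer and size, which is why the vectorized paths can simply assert alignment instead of handling tails.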
+ +#include +#include + +#include "array.hpp" +#include "types.hpp" +#include "grid_stride_range.hpp" +#include "execution.hpp" +#include "kernel_dispatcher.hpp" + +#include "../cuda4dnn/csl/stream.hpp" +#include "../cuda4dnn/csl/tensor.hpp" +#include "../cuda4dnn/csl/span.hpp" + +#include + +#include +#include +#include + +using namespace cv::dnn::cuda4dnn::csl; +using namespace cv::dnn::cuda4dnn::csl::device; + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + namespace raw { + template + __global__ void slice( + Span output, array out_strides, + View input, array in_strides, array in_offset) + { + for (auto i : grid_stride_range(output.size())) { + index_type out_index = i / out_strides[0]; + index_type in_index = in_offset[0] + out_index; + index_type iidx = in_index * in_strides[0]; + for (int j = 1; j < Rank; j++) { + out_index = (i % out_strides[j - 1]) / out_strides[j]; + in_index = in_offset[j] + out_index; + iidx += in_index * in_strides[j]; + } + + output[i] = input[iidx]; + } + } + } + + template static + void launch_slice( + const Stream& stream, + Span output, const std::vector& outStride, + View input, const std::vector& inStride, const std::vector& inOffset) + { + CV_Assert(outStride.size() == Rank); + CV_Assert(inStride.size() == Rank); + CV_Assert(inOffset.size() == Rank); + + array outStride_k, inStride_k; + outStride_k.assign(std::begin(outStride), std::end(outStride)); + inStride_k.assign(std::begin(inStride), std::end(inStride)); + + array inOffset_k; + inOffset_k.assign(std::begin(inOffset), std::end(inOffset)); + + auto kernel = raw::slice; + auto policy = make_policy(kernel, output.size(), 0, stream); + launch_kernel(kernel, policy, output, outStride_k, input, inStride_k, inOffset_k); + } + + GENERATE_KERNEL_DISPATCHER(slice_dispatcher, launch_slice); + + template + void slice(const Stream& stream, + TensorSpan output, TensorView input, + std::vector offsets) + { + CV_Assert(output.rank() == input.rank()); + CV_Assert(output.rank() == offsets.size()); + + /* squeezable axes at the begining of both tensors can be eliminated + * + * Reasoning: + * ---------- + * Suppose an item's indices in the output tensor is [o1, o2, ...]. The indices in the input + * tensor will be [o1 + off1, o2 + off2, ...]. The rest of the elements in the input are igored. + * + * If the size of the first axis of the input and output tensor is unity, the input and output indices + * for all the elements will be of the form be [0, o2 + off2, ...] and [0, o2, ...] respectively. Note that + * there cannot be any ignored items since the axes have unit size. The first index does not contribute to the + * element's address calculation and hence does nothing apart from eating up few cycles. + */ + while (input.get_axis_size(0) == 1 && output.get_axis_size(0) == 1) { + CV_Assert(offsets[0] == 0); + + input.squeeze(0); + output.squeeze(0); + offsets.erase(std::begin(offsets)); + + CV_Assert(output.rank() == input.rank()); + CV_Assert(output.rank() == offsets.size()); + } + + auto inShape = input.shape_as_vector(); + auto outShape = output.shape_as_vector(); + + /* contiguous axes which do not undergo slicing can be combined into one axis + * + * Reasoning: + * ---------- + * Suppose an item's indices in the output tensor is [o1, o2, o3, ...]. Let the first two axes not undergo any + * slicing. The indices in the input tensor will be [o1, o2, o3 + off3, ...]. + * + * Each axis in the contiguous unsliced axes sequence will add an offset of iN * strideN. 
In the above example, + * the two axes add a total offset of `o1 * stride1 + o2 * stride2`. We can merge the two axes into one axis with + * a size of `size1 * size2`. The new offset added will be o12 * stride2` as the kernel iterates through `o12`. + * Note that `o12` is actually `(o1 * size2 + o2)` in the original tensor. + */ + for (int i = 0; i < inShape.size(); i++) { + /* check if axis `i` requires any slicing */ + if (offsets[i] == 0 && inShape[i] == outShape[i]) { + /* loop invariant: `i` is the first axis in the contiguous unsliced axis sequence */ + + int j = i + 1; /* `j` is the axis which we will attempt to merge */ + while (j < inShape.size() && offsets[j] == 0 && inShape[j] == outShape[j]) { + /* `j` axis is also unsliced; merge `i` and `j` */ + auto new_size = inShape[i] * inShape[j]; + inShape[i] = new_size; + outShape[i] = new_size; + offsets[i] = 0; /* redundant */ + + /* delete axis `j` */ + inShape.erase(std::begin(inShape) + j); + outShape.erase(std::begin(outShape) + j); + offsets.erase(std::begin(offsets) + j); + + /* optimizations should not break the invariants */ + CV_Assert(inShape.size() == outShape.size()); + CV_Assert(inShape.size() == offsets.size()); + CV_Assert(inShape[i] == outShape[i]); + CV_Assert(offsets[i] == 0); + } + } + } + + auto rank = inShape.size(); + + std::vector inStride(rank), outStride(rank); + inStride.back() = 1; + outStride.back() = 1; + /* garbage, ..., garbage, 1 */ + + std::copy(std::begin(inShape) + 1, std::end(inShape), std::begin(inStride)); + std::copy(std::begin(outShape) + 1, std::end(outShape), std::begin(outStride)); + /* dim[0], dim[1], ..., dim[-1], 1 */ + + std::partial_sum(inStride.rbegin(), inStride.rend(), inStride.rbegin(), std::multiplies()); + std::partial_sum(outStride.rbegin(), outStride.rend(), outStride.rbegin(), std::multiplies()); + /* stride[0], stride[1], ..., stride[-2], 1 */ + + CV_Assert(1 <= rank && rank <= CSL_MAX_TENSOR_RANK); + slice_dispatcher(rank, stream, output, outStride, input, inStride, offsets); + } + + template void slice(const Stream&, TensorSpan<__half>, TensorView<__half>, std::vector); + template void slice(const Stream&, TensorSpan, TensorView, std::vector); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ diff --git a/modules/dnn/src/cuda/test.cu b/modules/dnn/src/cuda/test.cu deleted file mode 100644 index 1a50e97..0000000 --- a/modules/dnn/src/cuda/test.cu +++ /dev/null @@ -1,18 +0,0 @@ -// This file is part of OpenCV project. -// It is subject to the license terms in the LICENSE file found in the top-level directory -// of this distribution and at http://opencv.org/license.html. - -// this file is a stub and will be removed once actual code is added - -#include "../precomp.hpp" - -#include - -#ifndef HAVE_CUDA -# error "CUDA files should not be compiled if CUDA was not enabled" -#endif - -__global__ void cuda4dnn_build_test_kernel(float* addr) { - int idx = threadIdx.x; - addr[idx] = 0.0; -} diff --git a/modules/dnn/src/cuda/types.hpp b/modules/dnn/src/cuda/types.hpp new file mode 100644 index 0000000..258aacf --- /dev/null +++ b/modules/dnn/src/cuda/types.hpp @@ -0,0 +1,27 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
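The two rank-reduction passes in the slice implementation above are easiest to see on a concrete shape. The host-only sketch below (illustrative only, not part of the patch) applies the same squeeze-then-merge steps to a rank-5 slice:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main() {
        std::vector<std::size_t> inShape  {1, 4, 8, 16, 16},
                                 outShape {1, 4, 8, 12, 12},
                                 offsets  {0, 0, 0,  2,  2};

        /* squeeze leading unit axes (they cannot carry an offset) */
        while (!inShape.empty() && inShape[0] == 1 && outShape[0] == 1) {
            inShape.erase(inShape.begin());
            outShape.erase(outShape.begin());
            offsets.erase(offsets.begin());
        }

        /* merge contiguous runs of unsliced axes into a single axis */
        for (std::size_t i = 0; i < inShape.size(); i++) {
            if (offsets[i] == 0 && inShape[i] == outShape[i]) {
                while (i + 1 < inShape.size() && offsets[i + 1] == 0 && inShape[i + 1] == outShape[i + 1]) {
                    inShape[i] *= inShape[i + 1];
                    outShape[i] = inShape[i];
                    inShape.erase(inShape.begin() + i + 1);
                    outShape.erase(outShape.begin() + i + 1);
                    offsets.erase(offsets.begin() + i + 1);
                }
            }
        }

        for (auto s : inShape)  std::cout << s << ' ';  std::cout << '\n';  /* 32 16 16 */
        for (auto s : outShape) std::cout << s << ' ';  std::cout << '\n';  /* 32 12 12 */
        for (auto o : offsets)  std::cout << o << ' ';  std::cout << '\n';  /* 0 2 2 */
    }

The rank-5 slice therefore dispatches as a rank-3 slice: input 32x16x16, output 32x12x12, offsets {0, 2, 2}, with packed strides {256, 16, 1} and {144, 12, 1} produced by the partial_sum step above.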
+ +#ifndef OPENCV_DNN_SRC_CUDA_TYPES_HPP +#define OPENCV_DNN_SRC_CUDA_TYPES_HPP + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device { + + /* For indices, we can use 32bit variables or 64bit variables. The GPU registers are 32 bits in size. + * Hence, a 64bit variable requires two registers and is significantly slower than the 32bit versions. + * + * If we do not need to handle huge tensors, we can use 32-bit indices and get better performance. + */ +#ifdef __CUDACC__ + using size_type = int; + using index_type = int; +#else + using size_type = std::int32_t; + using index_type = std::int32_t; +#endif + +}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */ + +#endif /* OPENCV_DNN_SRC_CUDA_TYPES_HPP */ diff --git a/modules/dnn/src/cuda/vector_traits.hpp b/modules/dnn/src/cuda/vector_traits.hpp new file mode 100644 index 0000000..b10bcd3 --- /dev/null +++ b/modules/dnn/src/cuda/vector_traits.hpp @@ -0,0 +1,109 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA_VECTOR_TRAITS_HPP +#define OPENCV_DNN_SRC_CUDA_VECTOR_TRAITS_HPP + +#include + +#include "types.hpp" + +#include "../cuda4dnn/csl/pointer.hpp" + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device { + + /** \file vector_traits.hpp + * \brief utility classes and functions for vectorized memory loads/stores + * + * Example: + * using vector_type = get_vector_type_t; + * + * auto input_vPtr = type::get_pointer(iptr); // iptr is of type DevicePtr + * auto output_vPtr = type::get_pointer(optr); // optr is of type DevicePtr + * + * vector_type vec; + * v_load(vec, input_vPtr); + * + * for(int i = 0; i < vector_type::size(); i++) + * vec[i] = do_something(vec[i]); + * + * v_store(output_vPtr, vec); + */ + + namespace detail { + template struct raw_type_ { }; + template <> struct raw_type_<256> { typedef ulonglong4 type; }; + template <> struct raw_type_<128> { typedef uint4 type; }; + template <> struct raw_type_<64> { typedef uint2 type; }; + template <> struct raw_type_<32> { typedef uint1 type; }; + template <> struct raw_type_<16> { typedef uchar2 type; }; + template <> struct raw_type_<8> { typedef uchar1 type; }; + + template struct raw_type { + using type = typename raw_type_::type; + static_assert(sizeof(type) * 8 == N, ""); + }; + } + + /* \tparam T type of element in the vector + * \tparam N "number of elements" of type T in the vector + */ + template + union vector_type { + using value_type = T; + using raw_type = typename detail::raw_type::type; + + __device__ vector_type() { } + + __device__ static constexpr size_type size() { return N; } + + raw_type raw; + T data[N]; + + template static __device__ + typename std::enable_if::value, const vector_type*> + ::type get_pointer(csl::DevicePtr ptr) { + return reinterpret_cast(ptr.get()); + } + + template static __device__ + typename std::enable_if::value, vector_type*> + ::type get_pointer(csl::DevicePtr ptr) { + return reinterpret_cast(ptr.get()); + } + }; + + template + __device__ void v_load(V& dest, const V& src) { + dest.raw = src.raw; + } + + template + __device__ void v_load(V& dest, const V* src) { + dest.raw = src->raw; + } + + template + __device__ void v_store(V* dest, const V& src) { + dest->raw = src.raw; + } + + template + __device__ void v_store(V& dest, const V& src) { + dest.raw = src.raw; + } + + 
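The `vector_type` union documented above is what turns per-element loops into single wide memory transactions: loading through `raw` moves the whole vector in one instruction, after which the elements are touched through `data`. A stripped-down standalone sketch for float and N = 4 (illustrative; the real class is templated and hands out pointers via csl::DevicePtr):

    #include <cuda_runtime.h>

    union float_v4 {
        float4 raw;      /* one 128-bit load/store */
        float data[4];   /* element-wise access once the vector is in registers */
    };

    /* scales n_vec groups of four floats; both pointers must be 16-byte aligned */
    __global__ void scale4(float* output, const float* input, float alpha, int n_vec) {
        auto out = reinterpret_cast<float_v4*>(output);
        auto in  = reinterpret_cast<const float_v4*>(input);
        for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n_vec; i += gridDim.x * blockDim.x) {
            float_v4 vec;
            vec.raw = in[i].raw;                 /* vectorized load */
            for (int j = 0; j < 4; j++)
                vec.data[j] *= alpha;
            out[i].raw = vec.raw;                /* vectorized store */
        }
    }

This is also why the dispatch code checks alignment before selecting a width, and why the 32-bit `index_type`/`size_type` declared in types.hpp above are preferred: the index arithmetic in these inner loops then stays in single 32-bit registers.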
template + struct get_vector_type { + typedef vector_type type; + }; + + template + using get_vector_type_t = typename get_vector_type::type; + +}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */ + +#endif /* OPENCV_DNN_SRC_CUDA_VECTOR_TRAITS_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/cublas.hpp b/modules/dnn/src/cuda4dnn/csl/cublas.hpp new file mode 100644 index 0000000..8320767 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/cublas.hpp @@ -0,0 +1,230 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_CUBLAS_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_CSL_CUBLAS_HPP + +#include "error.hpp" +#include "stream.hpp" +#include "pointer.hpp" +#include "fp16.hpp" + +#include + +#include + +#include +#include +#include + +#define CUDA4DNN_CHECK_CUBLAS(call) \ + ::cv::dnn::cuda4dnn::csl::cublas::detail::check((call), CV_Func, __FILE__, __LINE__) + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cublas { + + /** @brief exception class for errors thrown by the cuBLAS API */ + class cuBLASException : public CUDAException { + public: + using CUDAException::CUDAException; + }; + + namespace detail { + static void check(cublasStatus_t status, const char* func, const char* file, int line) { + auto cublasGetErrorString = [](cublasStatus_t err) { + switch (err) { + case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; + } + return "UNKNOWN_CUBLAS_ERROR"; + }; + + if (status != CUBLAS_STATUS_SUCCESS) + throw cuBLASException(Error::GpuApiCallError, cublasGetErrorString(status), func, file, line); + } + } + + /** noncopyable cuBLAS smart handle + * + * UniqueHandle is a smart non-sharable wrapper for cuBLAS handle which ensures that the handle + * is destroyed after use. The handle can be associated with a CUDA stream by specifying the + * stream during construction. By default, the handle is associated with the default stream. + */ + class UniqueHandle { + public: + UniqueHandle() { CUDA4DNN_CHECK_CUBLAS(cublasCreate(&handle)); } + UniqueHandle(UniqueHandle&) = delete; + UniqueHandle(UniqueHandle&& other) noexcept + : stream(std::move(other.stream)), handle{ other.handle } { + other.handle = nullptr; + } + + UniqueHandle(Stream strm) : stream(std::move(strm)) { + CUDA4DNN_CHECK_CUBLAS(cublasCreate(&handle)); + try { + CUDA4DNN_CHECK_CUBLAS(cublasSetStream(handle, stream.get())); + } catch (...) 
{ + /* cublasDestroy won't throw if a valid handle is passed */ + CUDA4DNN_CHECK_CUBLAS(cublasDestroy(handle)); + throw; + } + } + + ~UniqueHandle() noexcept { + if (handle != nullptr) { + /* cublasDestroy won't throw if a valid handle is passed */ + CUDA4DNN_CHECK_CUBLAS(cublasDestroy(handle)); + } + } + + UniqueHandle& operator=(const UniqueHandle&) = delete; + UniqueHandle& operator=(UniqueHandle&& other) noexcept { + stream = std::move(other.stream); + handle = other.handle; + other.handle = nullptr; + return *this; + } + + /** @brief returns the raw cuBLAS handle */ + cublasHandle_t get() const noexcept { return handle; } + + private: + Stream stream; + cublasHandle_t handle; + }; + + /** @brief sharable cuBLAS smart handle + * + * Handle is a smart sharable wrapper for cuBLAS handle which ensures that the handle + * is destroyed after all references to the handle are destroyed. The handle can be + * associated with a CUDA stream by specifying the stream during construction. By default, + * the handle is associated with the default stream. + * + * @note Moving a Handle object to another invalidates the former + */ + class Handle { + public: + Handle() : handle(std::make_shared()) { } + Handle(const Handle&) = default; + Handle(Handle&&) = default; + Handle(Stream strm) : handle(std::make_shared(std::move(strm))) { } + + Handle& operator=(const Handle&) = default; + Handle& operator=(Handle&&) = default; + + /** returns true if the handle is valid */ + explicit operator bool() const noexcept { return static_cast(handle); } + + cublasHandle_t get() const noexcept { + CV_Assert(handle); + return handle->get(); + } + + private: + std::shared_ptr handle; + }; + + /** @brief GEMM for colummn-major matrices + * + * \f$ C = \alpha AB + \beta C \f$ + * + * @tparam T matrix element type (must be `half` or `float`) + * + * @param handle valid cuBLAS Handle + * @param transa use transposed matrix of A for computation + * @param transb use transposed matrix of B for computation + * @param rows_c number of rows in C + * @param cols_c number of columns in C + * @param common_dim common dimension of A (or trans A) and B (or trans B) + * @param alpha scale factor for AB + * @param[in] A pointer to column-major matrix A in device memory + * @param lda leading dimension of matrix A + * @param[in] B pointer to column-major matrix B in device memory + * @param ldb leading dimension of matrix B + * @param beta scale factor for C + * @param[in,out] C pointer to column-major matrix C in device memory + * @param ldc leading dimension of matrix C + * + * Exception Guarantee: Basic + */ + template + void gemm(const Handle& handle, + bool transa, bool transb, + std::size_t rows_c, std::size_t cols_c, std::size_t common_dim, + T alpha, const DevicePtr A, std::size_t lda, + const DevicePtr B, std::size_t ldb, + T beta, const DevicePtr C, std::size_t ldc); + + template <> inline + void gemm(const Handle& handle, + bool transa, bool transb, + std::size_t rows_c, std::size_t cols_c, std::size_t common_dim, + half alpha, const DevicePtr A, std::size_t lda, + const DevicePtr B, std::size_t ldb, + half beta, const DevicePtr C, std::size_t ldc) + { + CV_Assert(handle); + + auto opa = transa ? CUBLAS_OP_T : CUBLAS_OP_N, + opb = transb ? 
CUBLAS_OP_T : CUBLAS_OP_N; + int irows_c = static_cast(rows_c), + icols_c = static_cast(cols_c), + icommon_dim = static_cast(common_dim), + ilda = static_cast(lda), + ildb = static_cast(ldb), + ildc = static_cast(ldc); + + CUDA4DNN_CHECK_CUBLAS( + cublasHgemm( + handle.get(), + opa, opb, + irows_c, icols_c, icommon_dim, + &alpha, A.get(), ilda, + B.get(), ildb, + &beta, C.get(), ildc + ) + ); + } + + template <> inline + void gemm(const Handle& handle, + bool transa, bool transb, + std::size_t rows_c, std::size_t cols_c, std::size_t common_dim, + float alpha, const DevicePtr A, std::size_t lda, + const DevicePtr B, std::size_t ldb, + float beta, const DevicePtr C, std::size_t ldc) + { + CV_Assert(handle); + + auto opa = transa ? CUBLAS_OP_T : CUBLAS_OP_N, + opb = transb ? CUBLAS_OP_T : CUBLAS_OP_N; + int irows_c = static_cast(rows_c), + icols_c = static_cast(cols_c), + icommon_dim = static_cast(common_dim), + ilda = static_cast(lda), + ildb = static_cast(ldb), + ildc = static_cast(ldc); + + CUDA4DNN_CHECK_CUBLAS( + cublasSgemm( + handle.get(), + opa, opb, + irows_c, icols_c, icommon_dim, + &alpha, A.get(), ilda, + B.get(), ildb, + &beta, C.get(), ildc + ) + ); + } + +}}}}} /* namespace cv::dnn::cuda4dnn::csl::cublas */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_CUBLAS_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/cudnn.hpp b/modules/dnn/src/cuda4dnn/csl/cudnn.hpp new file mode 100644 index 0000000..049f146 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/cudnn.hpp @@ -0,0 +1,10 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_CUDNN_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_CSL_CUDNN_HPP + +#include "cudnn/cudnn.hpp" + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_CUDNN_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/cudnn/convolution.hpp b/modules/dnn/src/cuda4dnn/csl/cudnn/convolution.hpp new file mode 100644 index 0000000..792776e --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/cudnn/convolution.hpp @@ -0,0 +1,408 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
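Since the `gemm` wrapper above follows cuBLAS and is strictly column-major, a caller holding row-major data can evaluate C = A*B by computing C^T = B^T * A^T, i.e. by passing the operands in swapped order with their row strides as leading dimensions. Below is a self-contained sketch against raw cuBLAS showing the same mapping the wrapper would receive (hypothetical sizes, error checking omitted for brevity):

    #include <cublas_v2.h>
    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
        const int M = 2, K = 3, N = 2;
        const float A[M * K] = {1, 2, 3, 4, 5, 6};   /* row-major A (2x3) */
        const float B[K * N] = {1, 0, 0, 1, 1, 1};   /* row-major B (3x2) */
        float C[M * N] = {};                         /* row-major C (2x2) */

        float *dA, *dB, *dC;
        cudaMalloc(&dA, sizeof A); cudaMalloc(&dB, sizeof B); cudaMalloc(&dC, sizeof C);
        cudaMemcpy(dA, A, sizeof A, cudaMemcpyHostToDevice);
        cudaMemcpy(dB, B, sizeof B, cudaMemcpyHostToDevice);

        cublasHandle_t handle;
        cublasCreate(&handle);

        /* column-major views: B^T is N x K (ld = N), A^T is K x M (ld = K), C^T is N x M (ld = N) */
        const float alpha = 1.f, beta = 0.f;
        cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                    N, M, K,
                    &alpha, dB, N, dA, K,
                    &beta, dC, N);

        cudaMemcpy(C, dC, sizeof C, cudaMemcpyDeviceToHost);
        std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]);   /* expected: 4 5 / 10 11 */

        cublasDestroy(handle);
        cudaFree(dA); cudaFree(dB); cudaFree(dC);
    }

With the wrapper declared above, the equivalent call would presumably be gemm<float>(handle, false, false, N, M, K, 1.f, B, N, A, K, 0.f, C, N).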
+ +#ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_CONVOLUTION_HPP +#define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_CONVOLUTION_HPP + +#include "cudnn.hpp" + +#include "../pointer.hpp" +#include "../workspace.hpp" + +#include + +#include +#include +#include +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn { + + /** describe convolution filters + * + * @tparam T type of elements in the kernels + */ + template + class FilterDescriptor { + public: + FilterDescriptor() noexcept : descriptor{ nullptr } { } + FilterDescriptor(const FilterDescriptor&) = delete; + FilterDescriptor(FilterDescriptor&& other) noexcept + : descriptor{ other.descriptor } { + other.descriptor = nullptr; + } + + /** constructs a filter descriptor from the filter dimensions provided in \p shape + * + * Shape dimensions: + * 0: number of filters + * 1: number of input feature maps + * 2..n: kernel dimensions + * + * Exception Guarantee: Strong + */ + template ()))> + FilterDescriptor(const SequenceContainer& shape) { + constructor(shape.begin(), shape.end()); + } + + /** constructs a filter descriptor from the filter dimensions provided in [begin, end) + * + * Shape dimensions: + * 0: number of filters + * 1: number of input feature maps + * 2..n: kernel dimensions + * + * Exception Guarantee: Strong + */ + template ::value, void>::type> // TODO is_iterator + FilterDescriptor(ForwardItr begin, ForwardItr end) { + constructor(begin, end); + } + + /** constructs a filter descriptor from the filter dimensions provided as arguments + * + * Shape dimensions: + * 0: number of filters + * 1: number of input feature maps + * 2..n: kernel dimensions + * + * Exception Guarantee: Strong + */ + template + FilterDescriptor(Sizes ...sizes) { + static_assert(sizeof...(Sizes) >= 3, "filter descriptors must have at least three dimensions"); + static_assert(sizeof...(Sizes) <= CUDNN_DIM_MAX, "required rank exceeds maximum supported rank"); + std::array dims = { static_cast(sizes)... }; + constructor(std::begin(dims), std::end(dims)); + } + + ~FilterDescriptor() noexcept { + if (descriptor != nullptr) { + /* cudnnDestroyFilterDescriptor will not fail for a valid descriptor object */ + CUDA4DNN_CHECK_CUDNN(cudnnDestroyFilterDescriptor(descriptor)); + } + } + + FilterDescriptor& operator=(const FilterDescriptor&) = delete; + FilterDescriptor& operator=(FilterDescriptor&& other) noexcept { + descriptor = other.descriptor; + other.descriptor = nullptr; + return *this; + }; + + cudnnFilterDescriptor_t get() const noexcept { return descriptor; } + + private: + template + void constructor(ForwardItr start, ForwardItr end) { + CV_Assert(start != end); + CV_Assert(std::distance(start, end) >= 3); + CV_Assert(std::distance(start, end) <= CUDNN_DIM_MAX); + + CUDA4DNN_CHECK_CUDNN(cudnnCreateFilterDescriptor(&descriptor)); + try { + const auto rank = std::distance(start, end); + if (rank == 4) { + std::array dims; + std::copy(start, end, std::begin(dims)); + CUDA4DNN_CHECK_CUDNN( + cudnnSetFilter4dDescriptor( + descriptor, + detail::get_data_type(), CUDNN_TENSOR_NCHW, + dims[0], dims[1], dims[2], dims[3] + ) + ); + } else { + std::vector dims(start, end); + CUDA4DNN_CHECK_CUDNN( + cudnnSetFilterNdDescriptor( + descriptor, + detail::get_data_type(), CUDNN_TENSOR_NCHW, + dims.size(), dims.data() + ) + ); + } + } catch (...) 
{ + /* cudnnDestroyFilterDescriptor will not fail for a valid descriptor object */ + CUDA4DNN_CHECK_CUDNN(cudnnDestroyFilterDescriptor(descriptor)); + throw; + } + } + + cudnnFilterDescriptor_t descriptor; + }; + + /** describes a convolution operation + * + * @tparam T type of element participating in convolution + */ + template + class ConvolutionDescriptor { + public: + ConvolutionDescriptor() noexcept : descriptor{ nullptr } { } + ConvolutionDescriptor(const ConvolutionDescriptor&) = delete; + ConvolutionDescriptor(ConvolutionDescriptor&& other) noexcept + : descriptor{ other.descriptor } { + other.descriptor = nullptr; + } + + /** constructs a convolution descriptor + * + * Pre-conditions: + * - \p zero_padding, \p stride and \p dilation must have the same size + * + * The length of the containers is interpreted as the order of the convolution. + * + * Exception Guarantee: Strong + */ + template ()))> + ConvolutionDescriptor( + const SequenceContainer& zero_padding, + const SequenceContainer& stride, + const SequenceContainer& dilation, + std::size_t group_count) + { + constructor(zero_padding, stride, dilation, group_count); + } + + ~ConvolutionDescriptor() noexcept { + if (descriptor != nullptr) { + /* cudnnDestroyConvolutionDescriptor will not fail for a valid descriptor object */ + CUDA4DNN_CHECK_CUDNN(cudnnDestroyConvolutionDescriptor(descriptor)); + } + } + + ConvolutionDescriptor& operator=(const ConvolutionDescriptor&) = delete; + ConvolutionDescriptor& operator=(ConvolutionDescriptor&& other) noexcept { + descriptor = other.descriptor; + other.descriptor = nullptr; + return *this; + }; + + cudnnConvolutionDescriptor_t get() const noexcept { return descriptor; } + + private: + template + void constructor( + const SequenceContainer& zero_padding, + const SequenceContainer& stride, + const SequenceContainer& dilation, + std::size_t group_count) + { + CV_Assert(zero_padding.size() == stride.size()); + CV_Assert(zero_padding.size() == dilation.size()); + + CUDA4DNN_CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&descriptor)); + try { + const auto rank = zero_padding.size(); + if (rank == 2) { + CUDA4DNN_CHECK_CUDNN( + cudnnSetConvolution2dDescriptor( + descriptor, + zero_padding[0], zero_padding[1], + stride[0], stride[1], + dilation[0], dilation[1], + CUDNN_CROSS_CORRELATION, + detail::get_data_type() + ) + ); + } else { + std::vector ipadding(std::begin(zero_padding), std::end(zero_padding)); + std::vector istride(std::begin(stride), std::end(stride)); + std::vector idilation(std::begin(dilation), std::end(dilation)); + CUDA4DNN_CHECK_CUDNN( + cudnnSetConvolutionNdDescriptor( + descriptor, + rank, ipadding.data(), istride.data(), idilation.data(), + CUDNN_CROSS_CORRELATION, + detail::get_data_type() + ) + ); + } + CUDA4DNN_CHECK_CUDNN(cudnnSetConvolutionGroupCount(descriptor, group_count)); + } catch (...) 
{ + /* cudnnDestroyConvolutionDescriptor will not fail for a valid desriptor object */ + CUDA4DNN_CHECK_CUDNN(cudnnDestroyConvolutionDescriptor(descriptor)); + throw; + } + } + + cudnnConvolutionDescriptor_t descriptor; + }; + + /** wrapper around a convolution algorithm + * + * @tparam T type of elements being convolved + */ + template + class ConvolutionAlgorithm { + public: + ConvolutionAlgorithm() noexcept : workspace_size{ 0 } { } + ConvolutionAlgorithm(ConvolutionAlgorithm&) = default; + ConvolutionAlgorithm(ConvolutionAlgorithm&&) = default; + + /** selects a good algorithm for convolution for given configuration + * + * Exception Guarantee: Strong + */ + ConvolutionAlgorithm( + const Handle& handle, + const ConvolutionDescriptor& conv, + const FilterDescriptor& filter, + const TensorDescriptor& input, + const TensorDescriptor& output) + { + CUDA4DNN_CHECK_CUDNN( + cudnnGetConvolutionForwardAlgorithm( + handle.get(), + input.get(), filter.get(), conv.get(), output.get(), + CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, + 0, /* no memory limit */ + &algo + ) + ); + + CUDA4DNN_CHECK_CUDNN( + cudnnGetConvolutionForwardWorkspaceSize( + handle.get(), + input.get(), filter.get(), conv.get(), output.get(), + algo, &workspace_size + ) + ); + } + + ConvolutionAlgorithm& operator=(const ConvolutionAlgorithm&) = default; + ConvolutionAlgorithm& operator=(ConvolutionAlgorithm&& other) = default; + + cudnnConvolutionFwdAlgo_t get() const noexcept { return algo; } + + /** number of bytes of workspace memory required by the algorithm */ + std::size_t get_workspace_size() const noexcept { return workspace_size; } + + private: + cudnnConvolutionFwdAlgo_t algo; + std::size_t workspace_size; + }; + + /** gives the shape of the output tensor of convolution + * + * Exception Guarantee: Basic + */ + template + void getConvolutionForwardOutputDim( + const ConvolutionDescriptor& convDesc, + const FilterDescriptor& filterDesc, + const TensorDescriptor& inputDesc, + std::vector& output) + { + output.clear(); + output.resize(CUDNN_DIM_MAX); /* we use `output` to hold temporaries */ + + std::vector temp(CUDNN_DIM_MAX); + cudnnDataType_t tempDataType; + CUDA4DNN_CHECK_CUDNN( + cudnnGetTensorNdDescriptor( + inputDesc.get(), + CUDNN_DIM_MAX + 1, /* according to docs, this is what we do to get the rank */ + &tempDataType, + output.data(), + temp.data(), + temp.data() + ) + ); + + const auto rank = output[0]; + output.resize(rank); + CUDA4DNN_CHECK_CUDNN( + cudnnGetConvolutionNdForwardOutputDim( + convDesc.get(), inputDesc.get(), filterDesc.get(), rank, output.data() + ) + ); + } + + /** @brief performs convolution + * + * dstValue = alpha * result + beta * priorDstValue + * + * @tparam T convolution element type (must be `half` or `float`) + * + * @param handle valid cuDNN Handle + * @param convDesc convolution description + * @param convAlgo algorithm to use for convolution + * @param workspace workspace memory which meets the requirements of \p convAlgo + * @param filterDesc filter descriptor + * @param[in] filterPtr pointer to device memory containing the filters + * @param inputDesc tensor descriptor describing the input + * @param[in] inputPtr pointer to input tensor in device memory + * @param alpha result scale factor + * @param beta previous value scale factor + * @param outputDesc tensor descriptor describing the output + * @param[out] outputPtr pointer to output tensor in device memory + * + * Exception Guarantee: Basic + */ + template + void convolve( + const Handle& handle, + const ConvolutionDescriptor& 
convDesc, + const ConvolutionAlgorithm& convAlgo, + WorkspaceInstance workspace, + const FilterDescriptor& filterDesc, + DevicePtr filterPtr, + const TensorDescriptor& inputDesc, + DevicePtr inputPtr, + T alpha, T beta, + const TensorDescriptor& outputDesc, + DevicePtr outputPtr) + { + CV_Assert(handle); + + CUDA4DNN_CHECK_CUDNN( + cudnnConvolutionForward( + handle.get(), + &alpha, inputDesc.get(), inputPtr.get(), + filterDesc.get(), filterPtr.get(), + convDesc.get(), convAlgo.get(), + static_cast(workspace.get()), workspace.size_in_bytes(), + &beta, outputDesc.get(), outputPtr.get() + ) + ); + } + + template <> inline + void convolve( + const Handle& handle, + const ConvolutionDescriptor& convDesc, + const ConvolutionAlgorithm& convAlgo, + WorkspaceInstance workspace, + const FilterDescriptor& filterDesc, + DevicePtr filterPtr, + const TensorDescriptor& inputDesc, + DevicePtr inputPtr, + half alpha, half beta, + const TensorDescriptor& outputDesc, + DevicePtr outputPtr) + { + CV_Assert(handle); + + /* we specalize for fp16 as the scaling factors must be provided as `float` */ + float alpha_ = alpha, beta_ = beta; + CUDA4DNN_CHECK_CUDNN( + cudnnConvolutionForward( + handle.get(), + &alpha_, inputDesc.get(), inputPtr.get(), + filterDesc.get(), filterPtr.get(), + convDesc.get(), convAlgo.get(), + static_cast(workspace.get()), workspace.size_in_bytes(), + &beta_, outputDesc.get(), outputPtr.get() + ) + ); + } + +}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */ + +#endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_CONVOLUTION_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/cudnn/cudnn.hpp b/modules/dnn/src/cuda4dnn/csl/cudnn/cudnn.hpp new file mode 100644 index 0000000..59d1989 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/cudnn/cudnn.hpp @@ -0,0 +1,280 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_CUDNN_HPP +#define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_CUDNN_HPP + +#include "../fp16.hpp" +#include "../pointer.hpp" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#define CUDA4DNN_CHECK_CUDNN(call) \ + ::cv::dnn::cuda4dnn::csl::cudnn::detail::check((call), CV_Func, __FILE__, __LINE__) + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn { + + /** @brief exception class for errors thrown by the cuDNN API */ + class cuDNNException : public CUDAException { + public: + using CUDAException::CUDAException; + }; + + namespace detail { + inline void check(cudnnStatus_t status, const char* func, const char* file, int line) { + if (status != CUDNN_STATUS_SUCCESS) + throw cuDNNException(Error::GpuApiCallError, cudnnGetErrorString(status), func, file, line); + } + + /** get_data_type returns the equivalent cudnn enumeration constant for type T */ + template auto get_data_type()->decltype(CUDNN_DATA_FLOAT); + template <> inline auto get_data_type()->decltype(CUDNN_DATA_HALF) { return CUDNN_DATA_HALF; } + template <> inline auto get_data_type()->decltype(CUDNN_DATA_FLOAT) { return CUDNN_DATA_FLOAT; } + } + + /** @brief noncopyable cuDNN smart handle + * + * UniqueHandle is a smart non-sharable wrapper for cuDNN handle which ensures that the handle + * is destroyed after use. 
+ */ + class UniqueHandle { + public: + /** creates a cuDNN handle which executes in the default stream + * + * Exception Guarantee: Basic + */ + UniqueHandle() { CUDA4DNN_CHECK_CUDNN(cudnnCreate(&handle)); } + + UniqueHandle(UniqueHandle&) = delete; + UniqueHandle(UniqueHandle&& other) noexcept + : stream(std::move(other.stream)), handle{ other.handle } { + other.handle = nullptr; + } + + /** creates a cuDNN handle and associates it with the stream specified + * + * Exception Guarantee: Basic + */ + UniqueHandle(Stream strm) : stream(std::move(strm)) { + CUDA4DNN_CHECK_CUDNN(cudnnCreate(&handle)); + try { + CUDA4DNN_CHECK_CUDNN(cudnnSetStream(handle, stream.get())); + } catch (...) { + /* cudnnDestroy won't throw if a valid handle is passed */ + CUDA4DNN_CHECK_CUDNN(cudnnDestroy(handle)); + throw; + } + } + + ~UniqueHandle() noexcept { + if (handle != nullptr) { + /* cudnnDestroy won't throw if a valid handle is passed */ + CUDA4DNN_CHECK_CUDNN(cudnnDestroy(handle)); + } + } + + UniqueHandle& operator=(const UniqueHandle&) = delete; + UniqueHandle& operator=(UniqueHandle&& other) noexcept { + stream = std::move(other.stream); + handle = other.handle; + other.handle = nullptr; + return *this; + } + + /** returns the raw cuDNN handle */ + cudnnHandle_t get() const noexcept { return handle; } + + private: + Stream stream; + cudnnHandle_t handle; + }; + + /** @brief sharable cuDNN smart handle + * + * Handle is a smart sharable wrapper for cuDNN handle which ensures that the handle + * is destroyed after all references to the handle are destroyed. + * + * @note Moving a Handle object to another invalidates the former + */ + class Handle { + public: + /** creates a cuDNN handle which executes in the default stream + * + * Exception Guarantee: Basic + */ + Handle() : handle(std::make_shared()) { } + + Handle(const Handle&) = default; + Handle(Handle&&) = default; + + /** creates a cuDNN handle and associates it with the stream specified + * + * Exception Guarantee: Basic + */ + Handle(Stream strm) : handle(std::make_shared(std::move(strm))) { } + + Handle& operator=(const Handle&) = default; + Handle& operator=(Handle&&) = default; + + /** returns true if the handle is valid */ + explicit operator bool() const noexcept { return static_cast(handle); } + + cudnnHandle_t get() const noexcept { + CV_Assert(handle); + return handle->get(); + } + + private: + std::shared_ptr handle; + }; + + /** describe a tensor + * + * @tparam T type of elements in the tensor + */ + template + class TensorDescriptor { + public: + TensorDescriptor() noexcept : descriptor{ nullptr } { } + TensorDescriptor(const TensorDescriptor&) = delete; + TensorDescriptor(TensorDescriptor&& other) noexcept + : descriptor{ other.descriptor } { + other.descriptor = nullptr; + } + + /** constructs a tensor descriptor from the axis lengths provided in \p shape + * + * Exception Guarantee: Basic + */ + template ()))> + TensorDescriptor(const SequenceContainer& shape) { + constructor(shape.begin(), shape.end()); + } + + /** constructs a tensor descriptor from the axis lengths provided in [begin, end) + * + * Exception Guarantee: Basic + */ + template ::value, void>::type> // TODO is_iterator + TensorDescriptor(ForwardItr begin, ForwardItr end) { + constructor(begin, end); + } + + /** constructs a tensor descriptor from the axis lengths provided as arguments + * + * Exception Guarantee: Basic + */ + template + TensorDescriptor(Sizes ...sizes) { + static_assert(sizeof...(Sizes) <= CUDNN_DIM_MAX, "required rank exceeds maximum supported 
rank"); + std::array dims = { static_cast(sizes)... }; + constructor(std::begin(dims), std::end(dims)); + } + + ~TensorDescriptor() noexcept { + if (descriptor != nullptr) { + /* cudnnDestroyTensorDescriptor will not fail */ + CUDA4DNN_CHECK_CUDNN(cudnnDestroyTensorDescriptor(descriptor)); + } + } + + TensorDescriptor& operator=(const TensorDescriptor&) = delete; + TensorDescriptor& operator=(TensorDescriptor&& other) noexcept { + descriptor = other.descriptor; + other.descriptor = nullptr; + return *this; + }; + + cudnnTensorDescriptor_t get() const noexcept { return descriptor; } + + private: + template + void constructor(ForwardItr start, ForwardItr end) { + CV_Assert(start != end); + CV_Assert(std::distance(start, end) <= CUDNN_DIM_MAX); + + CUDA4DNN_CHECK_CUDNN(cudnnCreateTensorDescriptor(&descriptor)); + try { + /* cuDNN documentation recommends using the 4d tensor API whenever possible + * hence, we create a 4d tensor descriptors for 3d tensor + */ + const auto rank = std::distance(start, end); + if (rank <= 4) { + std::array dims; + std::fill(std::begin(dims), std::end(dims), 1); + + /* suppose we have a 3d tensor, the first axis is the batch axis and + * the second axis is the channel axis (generally) + * + * cuDNN frequently assumes that the first axis is the batch axis and the + * second axis is the channel axis; hence, we copy the shape of a lower rank + * tensor to the begining of `dims` + */ + std::copy(start, end, std::begin(dims)); + + CUDA4DNN_CHECK_CUDNN( + cudnnSetTensor4dDescriptor(descriptor, + CUDNN_TENSOR_NCHW, detail::get_data_type(), + dims[0], dims[1], dims[2], dims[3] + ) + ); + } else { + std::vector stride(rank); + stride.back() = 1; + /* WHAT WE HAVE NOW: + * stride[-1] = 1 + * stride[-2] = garbage + * stride[-3] = garbage + * stride[-4] = garbage + * ... + */ + + std::copy(start + 1, end, stride.begin()); + /* WHAT WE HAVE NOW: + * stride[-1] = 1 + * stride[-2] = dim[-1] + * stride[-3] = dim[-2] + * stride[-4] = dim[-3] + * ... + */ + + std::partial_sum(stride.rbegin(), stride.rend(), stride.rbegin(), std::multiplies()); + /* WHAT WE HAVE NOW: + * stride[-1] = 1 + * stride[-2] = stride[-1] * dim[-1] + * stride[-3] = stride[-2] * dim[-2] + * stride[-4] = stride[-3] * dim[-3] + * ... + */ + + std::vector dims(start, end); + CUDA4DNN_CHECK_CUDNN( + cudnnSetTensorNdDescriptor(descriptor, + detail::get_data_type(), rank, + dims.data(), stride.data() + ) + ); + } + } catch (...) { + /* cudnnDestroyTensorDescriptor will not fail */ + CUDA4DNN_CHECK_CUDNN(cudnnDestroyTensorDescriptor(descriptor)); + throw; + } + } + + cudnnTensorDescriptor_t descriptor; + }; + +}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */ + +#endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/cudnn/lrn.hpp b/modules/dnn/src/cuda4dnn/csl/cudnn/lrn.hpp new file mode 100644 index 0000000..f06b7b0 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/cudnn/lrn.hpp @@ -0,0 +1,205 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
+ +#ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_LRN_HPP +#define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_LRN_HPP + +#include "cudnn.hpp" + +#include "../pointer.hpp" +#include "../workspace.hpp" + +#include + +#include + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn { + + class LRNDescriptor { + public: + enum class LRNType { + ACROSS_CHANNELS, + WITHIN_CHANNEL + }; + + LRNDescriptor() noexcept : descriptor{ nullptr } { } + LRNDescriptor(const LRNDescriptor&) = delete; + LRNDescriptor(LRNDescriptor&& other) noexcept + : descriptor{ other.descriptor }, type{ other.type } { + other.descriptor = nullptr; + } + + /** sets up a LRN descriptor + * + * @param local_size size of the normalization window + * @param alpha variance scaling parameter + * @param beta power parameter + * @param k bias parameter + * + * @note \p alpha is divided by the window width in across channels mode + * @note \p alpha is divided by the (window width)^spatialDimensions in within channel mode + * + * @note the \p alpha, \p beta and \p k will be type casted to the tensor datatype during operation + * + * Exception Guarantee: Basic + */ + LRNDescriptor(std::size_t local_size, double alpha, double beta, double k, LRNType type_) { + constructor(local_size, alpha, beta, k, type_); + } + + ~LRNDescriptor() noexcept { + if (descriptor != nullptr) { + /* cudnnDestroyLRNDescriptor will not fail for a valid descriptor */ + CUDA4DNN_CHECK_CUDNN(cudnnDestroyLRNDescriptor(descriptor)); + } + } + + LRNDescriptor& operator=(const LRNDescriptor&) = delete; + LRNDescriptor& operator=(LRNDescriptor&& other) noexcept { + descriptor = other.descriptor; + type = other.type; + other.descriptor = nullptr; + return *this; + }; + + cudnnLRNDescriptor_t get() const noexcept { return descriptor; } + LRNType getType() const noexcept { return type; } + + private: + void constructor(std::size_t local_size, double alpha, double beta, double k, LRNType type_) { + CV_Assert(CUDNN_LRN_MIN_N <= local_size && local_size <= CUDNN_LRN_MAX_N); + + type = type_; + + CUDA4DNN_CHECK_CUDNN(cudnnCreateLRNDescriptor(&descriptor)); + try { + CUDA4DNN_CHECK_CUDNN( + cudnnSetLRNDescriptor( + descriptor, + local_size, + alpha, + beta, + k + ) + ); + } catch (...) 
{ + /* cudnnDestroyLRNDescriptor will not fail for a valid descriptor */ + CUDA4DNN_CHECK_CUDNN(cudnnDestroyLRNDescriptor(descriptor)); + throw; + } + } + + cudnnLRNDescriptor_t descriptor; + LRNType type; + }; + + /** @brief performs local response normalization + * + * dstValue = alpha * result + beta * priorDstValue + * + * @tparam T element type (must be `half` or `float`) + * + * @param handle valid cuDNN Handle + * @param lrnDesc LRN description + * @param inputDesc tensor descriptor describing the input + * @param[in] inputPtr pointer to input tensor in device memory + * @param alpha result scale factor + * @param beta previous value scale factor + * @param outputDesc tensor descriptor describing the output + * @param[out] outputPtr pointer to output tensor in device memory + * @param workspace workspace memory which meets the requirements of \p convAlgo + * + * Exception Guarantee: Basic + */ + template + void LRNForward( + const Handle& handle, + const LRNDescriptor& lrnDesc, + const TensorDescriptor& inputDesc, + DevicePtr inputPtr, + T alpha, T beta, + const TensorDescriptor& outputDesc, + DevicePtr outputPtr, + WorkspaceInstance workspace) + { + CV_Assert(handle); + + if (lrnDesc.getType() == LRNDescriptor::LRNType::ACROSS_CHANNELS) { + CUDA4DNN_CHECK_CUDNN( + cudnnLRNCrossChannelForward( + handle.get(), + lrnDesc.get(), CUDNN_LRN_CROSS_CHANNEL_DIM1, + &alpha, inputDesc.get(), inputPtr.get(), + &beta, outputDesc.get(), outputPtr.get() + ) + ); + } else if (lrnDesc.getType() == LRNDescriptor::LRNType::WITHIN_CHANNEL) { + std::size_t size; + CUDA4DNN_CHECK_CUDNN(cudnnGetTensorSizeInBytes(inputDesc.get(), &size)); + + DevicePtr temp1 = workspace.get_span(size).data(); + DevicePtr temp2 = workspace.get_span(size).data(); + + CUDA4DNN_CHECK_CUDNN( + cudnnDivisiveNormalizationForward( + handle.get(), + lrnDesc.get(), CUDNN_DIVNORM_PRECOMPUTED_MEANS, + &alpha, inputDesc.get(), inputPtr.get(), + NULL, + static_cast(temp1), static_cast(temp2), + &beta, outputDesc.get(), outputPtr.get() + ) + ); + } + } + + template <> inline + void LRNForward( + const Handle& handle, + const LRNDescriptor& lrnDesc, + const TensorDescriptor& inputDesc, + DevicePtr inputPtr, + half alpha, half beta, + const TensorDescriptor& outputDesc, + DevicePtr outputPtr, + WorkspaceInstance workspace) + { + CV_Assert(handle); + + /* we specalize for fp16 as the scaling factors must be provided as `float` */ + float alpha_ = alpha, beta_ = beta; + if (lrnDesc.getType() == LRNDescriptor::LRNType::ACROSS_CHANNELS) { + CUDA4DNN_CHECK_CUDNN( + cudnnLRNCrossChannelForward( + handle.get(), + lrnDesc.get(), CUDNN_LRN_CROSS_CHANNEL_DIM1, + &alpha_, inputDesc.get(), inputPtr.get(), + &beta_, outputDesc.get(), outputPtr.get() + ) + ); + } else if (lrnDesc.getType() == LRNDescriptor::LRNType::WITHIN_CHANNEL) { + std::size_t size; + CUDA4DNN_CHECK_CUDNN(cudnnGetTensorSizeInBytes(inputDesc.get(), &size)); + + DevicePtr temp1 = workspace.get_span(size).data(); + DevicePtr temp2 = workspace.get_span(size).data(); + + CUDA4DNN_CHECK_CUDNN( + cudnnDivisiveNormalizationForward( + handle.get(), + lrnDesc.get(), CUDNN_DIVNORM_PRECOMPUTED_MEANS, + &alpha_, inputDesc.get(), inputPtr.get(), + NULL, + static_cast(temp1), static_cast(temp2), + &beta_, outputDesc.get(), outputPtr.get() + ) + ); + } + } + +}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */ + +#endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_LRN_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/cudnn/pooling.hpp b/modules/dnn/src/cuda4dnn/csl/cudnn/pooling.hpp new file mode 100644 index 
0000000..cee58f7 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/cudnn/pooling.hpp @@ -0,0 +1,236 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_POOLING_HPP +#define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_POOLING_HPP + +#include "cudnn.hpp" + +#include "../pointer.hpp" + +#include + +#include + +#include +#include +#include +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn { + + class PoolingDescriptor { + public: + enum class PoolingType { + MAX, + MAX_DETERMINISTIC, + AVERAGE_EXCLUDE_PADDING, + AVERAGE_INCLUDE_PADDING + }; + + PoolingDescriptor() noexcept : descriptor{ nullptr } { } + PoolingDescriptor(const PoolingDescriptor&) = delete; + PoolingDescriptor(PoolingDescriptor&& other) noexcept + : descriptor{ other.descriptor } { + other.descriptor = nullptr; + } + + /** constructs a pooling descriptor + * + * Pre-conditions: + * - \p window_size, \p padding and \p stride must have the same size + * + * The length of the containers is interpreted as the order of the pooling operation. + * + * Exception Guarantee: Basic + */ + template ()))> + PoolingDescriptor( + const SequenceContainer& window_size, + const SequenceContainer& padding, + const SequenceContainer& stride, + PoolingType type) + { + constructor(window_size, padding, stride, type); + } + + ~PoolingDescriptor() noexcept { + if (descriptor != nullptr) { + /* cudnnDestroyPoolingDescriptor will not fail for a valid descriptor */ + CUDA4DNN_CHECK_CUDNN(cudnnDestroyPoolingDescriptor(descriptor)); + } + } + + PoolingDescriptor& operator=(const PoolingDescriptor&) = delete; + PoolingDescriptor& operator=(PoolingDescriptor&& other) noexcept { + descriptor = other.descriptor; + other.descriptor = nullptr; + return *this; + }; + + cudnnPoolingDescriptor_t get() const noexcept { return descriptor; } + + private: + template + void constructor( + const SequenceContainer& window_size, + const SequenceContainer& padding, + const SequenceContainer& stride, + PoolingType type) + { + CV_Assert(window_size.size() == padding.size()); + CV_Assert(window_size.size() == stride.size()); + + auto get_pooling_type = [] (PoolingType type) { + switch (type) { + case PoolingType::MAX: + return CUDNN_POOLING_MAX; + case PoolingType::MAX_DETERMINISTIC: + return CUDNN_POOLING_MAX_DETERMINISTIC; + case PoolingType::AVERAGE_EXCLUDE_PADDING: + return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; + case PoolingType::AVERAGE_INCLUDE_PADDING: + return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + } + CV_Error(Error::StsBadArg, "unknown pooling type"); + }; + + CUDA4DNN_CHECK_CUDNN(cudnnCreatePoolingDescriptor(&descriptor)); + try { + const auto rank = window_size.size(); + if (rank == 2) { + CUDA4DNN_CHECK_CUDNN( + cudnnSetPooling2dDescriptor( + descriptor, + get_pooling_type(type), CUDNN_PROPAGATE_NAN, + window_size[0], window_size[1], + padding[0], padding[1], + stride[0], stride[1] + ) + ); + } else { + std::vector iwindow_size(std::begin(window_size), std::end(window_size)); + std::vector ipadding(std::begin(padding), std::end(padding)); + std::vector istride(std::begin(stride), std::end(stride)); + CUDA4DNN_CHECK_CUDNN( + cudnnSetPoolingNdDescriptor( + descriptor, + get_pooling_type(type), CUDNN_PROPAGATE_NAN, + rank, iwindow_size.data(), ipadding.data(), istride.data() + ) + ); + } + } catch (...) 
{ + /* cudnnDestroyPoolingDescriptor will not fail for a valid descriptor */ + CUDA4DNN_CHECK_CUDNN(cudnnDestroyPoolingDescriptor(descriptor)); + throw; + } + } + + cudnnPoolingDescriptor_t descriptor; + }; + + /** gives the shape of the output tensor after pooling + * + * @note it's not required to enforce the this shape in the output tensor; slightly different shapes will work + * + * Exception Guarantee: Basic + */ + template inline + void getPoolingForwardOutputDim( + const PoolingDescriptor& poolingDesc, + const TensorDescriptor& inputDesc, + std::vector& output_dim) + { + output_dim.clear(); + output_dim.resize(CUDNN_DIM_MAX); /* we use `output_dim` to hold temporaries */ + + std::vector temp(CUDNN_DIM_MAX); + cudnnDataType_t tempDataType; + CUDA4DNN_CHECK_CUDNN( + cudnnGetTensorNdDescriptor( + inputDesc.get(), + CUDNN_DIM_MAX + 1, /* according to docs, this is what we do to get the rank */ + &tempDataType, + output_dim.data(), + temp.data(), + temp.data() + ) + ); + + const auto rank = output_dim[0]; + output_dim.resize(rank); + CUDA4DNN_CHECK_CUDNN( + cudnnGetPoolingNdForwardOutputDim(poolingDesc.get(), inputDesc.get(), rank, output_dim.data()) + ); + } + + /** @brief performs pooling operation + * + * dstValue = alpha * result + beta * priorDstValue + * + * @tparam T pooling element type (must be `half` or `float`) + * + * @param handle valid cuDNN Handle + * @param poolingDesc pooling description + * @param inputDesc tensor descriptor describing the input + * @param[in] inputPtr pointer to input tensor in device memory + * @param alpha result scale factor + * @param beta previous value scale factor + * @param outputDesc tensor descriptor describing the output + * @param[out] outputPtr pointer to output tensor in device memory + * + * Exception Guarantee: Basic + */ + template + void pool( + const Handle& handle, + const PoolingDescriptor& poolingDesc, + const TensorDescriptor& inputDesc, + const DevicePtr inputPtr, + T alpha, T beta, + const TensorDescriptor& outputDesc, + DevicePtr outputPtr) + { + CV_Assert(handle); + + CUDA4DNN_CHECK_CUDNN( + cudnnPoolingForward( + handle.get(), + poolingDesc.get(), + &alpha, inputDesc.get(), inputPtr.get(), + &beta, outputDesc.get(), outputPtr.get() + ) + ); + } + + template <> inline + void pool( + const Handle& handle, + const PoolingDescriptor& poolingDesc, + const TensorDescriptor& inputDesc, + const DevicePtr inputPtr, + half alpha, half beta, + const TensorDescriptor& outputDesc, + DevicePtr outputPtr) + { + CV_Assert(handle); + + /* we specalize for fp16 as the scaling factors must be provided as `float` */ + float alpha_ = alpha, beta_ = beta; + CUDA4DNN_CHECK_CUDNN( + cudnnPoolingForward( + handle.get(), + poolingDesc.get(), + &alpha_, inputDesc.get(), inputPtr.get(), + &beta_, outputDesc.get(), outputPtr.get() + ) + ); + } + +}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */ + +#endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_POOLING_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/cudnn/softmax.hpp b/modules/dnn/src/cuda4dnn/csl/cudnn/softmax.hpp new file mode 100644 index 0000000..251a321 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/cudnn/softmax.hpp @@ -0,0 +1,68 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
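As the note above says, the shape reported by cudnnGetPoolingNdForwardOutputDim need not be enforced exactly; per the cuDNN documentation it should come out to 1 + (in + 2*pad - window) / stride on each spatial axis. A tiny standalone check of a layer configuration (illustrative only):

    #include <iostream>

    /* integer division rounds toward zero, which equals floor for the non-negative values used here */
    int pooled_dim(int in, int window, int pad, int stride) {
        return 1 + (in + 2 * pad - window) / stride;
    }

    int main() {
        /* e.g. 3x3 pooling, stride 2, padding 1 on a 112x112 feature map */
        std::cout << pooled_dim(112, 3, 1, 2) << '\n';   /* prints 56 */
    }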
+ +#ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_SOFTMAX_HPP +#define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_SOFTMAX_HPP + +#include "cudnn.hpp" + +#include "../pointer.hpp" + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn { + + /** @brief computes softmax (or log softmax) + * + * @tparam T element type (must be `half` or `float`) + * + * @param handle valid cuDNN handle + * @param outputDesc tensor descriptor for A + * @param[out] output pointer to tensor in device memory + * @param inputDesc tensor descriptor for C + * @param[in] input pointer to tensor in device memory + * @param log apply log on probabilities + * + * Exception Guarantee: Basic + */ + template + void softmax(const cudnn::Handle& handle, + const TensorDescriptor& outputDesc, DevicePtr output, + const TensorDescriptor& inputDesc, DevicePtr input, + bool log) + { + T alpha = 1.0, beta = 0.0; + cudnnSoftmaxAlgorithm_t algo = log ? CUDNN_SOFTMAX_LOG : CUDNN_SOFTMAX_ACCURATE; + CUDA4DNN_CHECK_CUDNN( + cudnnSoftmaxForward( + handle.get(), + algo, CUDNN_SOFTMAX_MODE_CHANNEL, + &alpha, inputDesc.get(), input.get(), + &beta, outputDesc.get(), output.get() + ) + ); + } + + template <> inline + void softmax(const cudnn::Handle& handle, + const TensorDescriptor& outputDesc, DevicePtr output, + const TensorDescriptor& inputDesc, DevicePtr input, + bool log) + { + /* we specalize for fp16 as the scaling factors must be provided as `float` */ + float alpha = 1.0, beta = 0.0; + cudnnSoftmaxAlgorithm_t algo = log ? CUDNN_SOFTMAX_LOG : CUDNN_SOFTMAX_ACCURATE; + CUDA4DNN_CHECK_CUDNN( + cudnnSoftmaxForward( + handle.get(), + algo, CUDNN_SOFTMAX_MODE_CHANNEL, + &alpha, inputDesc.get(), input.get(), + &beta, outputDesc.get(), output.get() + ) + ); + } + +}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */ + +#endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_SOFTMAX_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/cudnn/transform.hpp b/modules/dnn/src/cuda4dnn/csl/cudnn/transform.hpp new file mode 100644 index 0000000..029c5d8 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/cudnn/transform.hpp @@ -0,0 +1,142 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_TRANSFORM_HPP +#define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_TRANSFORM_HPP + +#include "../pointer.hpp" + +#include "cudnn.hpp" + +#include +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn { + + /** describes a tensor transform operation + * + * Supported transformations: + * - add or remove asymmetric padding + */ + class TensorTransformDescriptor { + public: + TensorTransformDescriptor() noexcept : descriptor{ nullptr } { } + TensorTransformDescriptor(const TensorTransformDescriptor&) = delete; + TensorTransformDescriptor(TensorTransformDescriptor&& other) noexcept + : descriptor{ other.descriptor } { + other.descriptor = nullptr; + } + + /** constructs a convolution descriptor + * + * Pre-conditions: + * - \p padding_left and \p padding_right must have the same size + * + * The length of the containers is interpreted as the rank of the tensors which will be given. 
+ * + * @note \p padding_left and \p padding_right may have negative values to remove padding + * + * Exception Guarantee: Basic + */ + template ()))> + TensorTransformDescriptor( + const SequenceContainer& padding_left, + const SequenceContainer& padding_right) + { + constructor(padding_left, padding_right); + } + + ~TensorTransformDescriptor() noexcept { + if (descriptor != nullptr) { + /* cudnnDestroyTensorTransformDescriptor will not fail for a valid descriptor */ + CUDA4DNN_CHECK_CUDNN(cudnnDestroyTensorTransformDescriptor(descriptor)); + } + } + + TensorTransformDescriptor& operator=(const TensorTransformDescriptor&) = delete; + TensorTransformDescriptor& operator=(TensorTransformDescriptor&& other) noexcept { + descriptor = other.descriptor; + other.descriptor = nullptr; + return *this; + }; + + cudnnTensorTransformDescriptor_t get() const noexcept { return descriptor; } + + private: + template + void constructor( + const SequenceContainer& padding_left, + const SequenceContainer& padding_right + ) + { + CV_Assert(padding_left.size() == padding_right.size()); + + auto ipadding_left = std::vector(std::begin(padding_left), std::end(padding_left)); + auto ipadding_right = std::vector(std::begin(padding_right), std::end(padding_right)); + CUDA4DNN_CHECK_CUDNN(cudnnCreateTensorTransformDescriptor(&descriptor)); + try { + CUDA4DNN_CHECK_CUDNN( + cudnnSetTensorTransformDescriptor( + descriptor, + ipadding_left.size(), CUDNN_TENSOR_NCHW, + ipadding_left.data(), ipadding_right.data(), + NULL, CUDNN_TRANSFORM_FOLD + ) + ); + } catch (...) { + /* cudnnDestroyTensorTransformDescriptor will not fail for a valid descriptor */ + CUDA4DNN_CHECK_CUDNN(cudnnDestroyTensorTransformDescriptor(descriptor)); + throw; + } + } + + cudnnTensorTransformDescriptor_t descriptor; + }; + + template + void transform( + const Handle& handle, + const TensorTransformDescriptor& transDesc, + const TensorDescriptor& inputDesc, + DevicePtr inputPtr, + const TensorDescriptor& outputDesc, + DevicePtr outputPtr) + { + T alpha = 1.0, beta = 0.0; + CUDA4DNN_CHECK_CUDNN( + cudnnTransformTensorEx( + handle.get(), + transDesc.get(), + &alpha, inputDesc.get(), inputPtr.get(), + &beta, outputDesc.get(), outputPtr.get() + ) + ); + } + + template <> inline + void transform( + const Handle& handle, + const TensorTransformDescriptor& transDesc, + const TensorDescriptor& inputDesc, + DevicePtr inputPtr, + const TensorDescriptor& outputDesc, + DevicePtr outputPtr) + { + /* we specalize for fp16 as the scaling factors must be provided as `float` */ + float alpha = 1.0, beta = 0.0; + CUDA4DNN_CHECK_CUDNN( + cudnnTransformTensorEx( + handle.get(), + transDesc.get(), + &alpha, inputDesc.get(), inputPtr.get(), + &beta, outputDesc.get(), outputPtr.get() + ) + ); + } + +}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */ + +#endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_TRANSFORM_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/cudnn/transpose_convolution.hpp b/modules/dnn/src/cuda4dnn/csl/cudnn/transpose_convolution.hpp new file mode 100644 index 0000000..d1d26aa --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/cudnn/transpose_convolution.hpp @@ -0,0 +1,148 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
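A similar minimal sketch for the TensorTransformDescriptor from transform.hpp above; the NCHW padding amounts are illustrative assumptions:

    #include <vector>

    namespace cudnn = cv::dnn::cuda4dnn::csl::cudnn;

    void example_transform_descriptor() {
        /* rank-4 NCHW tensor: add one element of padding before and after H and W;
         * negative values would remove padding instead (see the note above)
         */
        std::vector<int> padding_left  = { 0, 0, 1, 1 };
        std::vector<int> padding_right = { 0, 0, 1, 1 };
        cudnn::TensorTransformDescriptor desc(padding_left, padding_right);
        /* desc.get() is what transform<T>() forwards to cudnnTransformTensorEx */
    }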
+ +#ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_TRANSPOSE_CONVOLUTION_HPP +#define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_TRANSPOSE_CONVOLUTION_HPP + +#include "cudnn.hpp" +#include "convolution.hpp" + +#include "../pointer.hpp" +#include "../workspace.hpp" + +#include + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn { + + /** wrapper around a transpose convolution algorithm + * + * @tparam T type of elements being transpose-convolved + */ + template + class TransposeConvolutionAlgorithm { + public: + TransposeConvolutionAlgorithm() noexcept : workspace_size{ 0 } { } + TransposeConvolutionAlgorithm(TransposeConvolutionAlgorithm&) = default; + TransposeConvolutionAlgorithm(TransposeConvolutionAlgorithm&&) = default; + + TransposeConvolutionAlgorithm( + const Handle& handle, + const ConvolutionDescriptor& conv, + const FilterDescriptor& filter, + const TensorDescriptor& input, + const TensorDescriptor& output) + { + CUDA4DNN_CHECK_CUDNN( + cudnnGetConvolutionBackwardDataAlgorithm( + handle.get(), + filter.get(), input.get(), conv.get(), output.get(), + CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, + 0, /* no memory limit */ + &dalgo + ) + ); + + CUDA4DNN_CHECK_CUDNN( + cudnnGetConvolutionBackwardDataWorkspaceSize( + handle.get(), + filter.get(), input.get(), conv.get(), output.get(), + dalgo, &workspace_size + ) + ); + } + + TransposeConvolutionAlgorithm& operator=(const TransposeConvolutionAlgorithm&) = default; + TransposeConvolutionAlgorithm& operator=(TransposeConvolutionAlgorithm&& other) = default; + + cudnnConvolutionBwdDataAlgo_t get() const noexcept { return dalgo; } + + std::size_t get_workspace_size() const noexcept { return workspace_size; } + + private: + cudnnConvolutionBwdDataAlgo_t dalgo; + std::size_t workspace_size; + }; + + /** @brief performs transpose convolution + * + * dstValue = alpha * result + beta * priorDstValue + * + * @tparam T transpose convolution element type (must be `half` or `float`) + * + * @param handle valid cuDNN Handle + * @param convDesc convolution description + * @param transConvAlgo algorithm to use for convolution + * @param workspace workspace memory which meets the requirements of \p convAlgo + * @param filterDesc filter descriptor + * @param[in] filterPtr pointer to device memory containing the filters + * @param inputDesc tensor descriptor describing the input + * @param[in] inputPtr pointer to input tensor in device memory + * @param alpha result scale factor + * @param beta previous value scale factor + * @param outputDesc tensor descriptor describing the output + * @param[out] outputPtr pointer to output tensor in device memory + * + * Exception Guarantee: Basic + */ + template + void transpose_convolve( + const Handle& handle, + const ConvolutionDescriptor& convDesc, + const TransposeConvolutionAlgorithm& transConvAlgo, + WorkspaceInstance workspace, + const FilterDescriptor& filterDesc, + DevicePtr filterPtr, + const TensorDescriptor& inputDesc, + DevicePtr inputPtr, + T alpha, T beta, + const TensorDescriptor& outputDesc, + DevicePtr outputPtr) + { + CUDA4DNN_CHECK_CUDNN( + cudnnConvolutionBackwardData( + handle.get(), + &alpha, + filterDesc.get(), filterPtr.get(), + inputDesc.get(), inputPtr.get(), + convDesc.get(), transConvAlgo.get(), + static_cast(workspace.get()), workspace.size_in_bytes(), + &beta, outputDesc.get(), outputPtr.get() + ) + ); + } + + template <> inline + void transpose_convolve( + const Handle& handle, + const ConvolutionDescriptor& convDesc, + const TransposeConvolutionAlgorithm& 
convAlgo, + WorkspaceInstance workspace, + const FilterDescriptor& filterDesc, + DevicePtr filterPtr, + const TensorDescriptor& inputDesc, + DevicePtr inputPtr, + half alpha, half beta, + const TensorDescriptor& outputDesc, + DevicePtr outputPtr) + { + /* we specalize for fp16 as the scaling factors must be provided as `float` */ + float alpha_ = alpha, beta_ = beta; + CUDA4DNN_CHECK_CUDNN( + cudnnConvolutionBackwardData( + handle.get(), + &alpha_, + filterDesc.get(), filterPtr.get(), + inputDesc.get(), inputPtr.get(), + convDesc.get(), convAlgo.get(), + static_cast(workspace.get()), workspace.size_in_bytes(), + &beta_, outputDesc.get(), outputPtr.get() + ) + ); + } + +}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */ + +#endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_TRANSPOSE_CONVOLUTION_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/error.hpp b/modules/dnn/src/cuda4dnn/csl/error.hpp new file mode 100644 index 0000000..d49d912 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/error.hpp @@ -0,0 +1,30 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_ERROR_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_CSL_ERROR_HPP + +#include + +#include + +#define CUDA4DNN_CHECK_CUDA(call) \ + ::cv::dnn::cuda4dnn::csl::detail::check((call), CV_Func, __FILE__, __LINE__) + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { + /** @brief exception class for errors thrown by the CUDA APIs */ + class CUDAException : public cv::Exception { + public: + using cv::Exception::Exception; + }; + + namespace detail { + inline void check(cudaError_t err, const char* func, const char* file, int line) { + if (err != cudaSuccess) + throw CUDAException(Error::GpuApiCallError, cudaGetErrorString(err), func, file, line); + } + } +}}}} /* namespace cv::dnn::cuda4dnn::csl */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_ERROR_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/event.hpp b/modules/dnn/src/cuda4dnn/csl/event.hpp new file mode 100644 index 0000000..63da75a --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/event.hpp @@ -0,0 +1,101 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_EVENT_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_CSL_EVENT_HPP + +#include "error.hpp" +#include "stream.hpp" + +#include + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { + + /** @brief sharable CUDA event + * + * Event is a smart sharable wrapper for CUDA event handle which ensures that + * the handle is destroyed after use. + * + * @note Moving an Event object to another invalidates the former + */ + class Event { + public: + Event() noexcept : event{ nullptr } { } + Event(const Event&) = delete; + Event(Event&& other) noexcept + : event{ other.event } { + other.event = nullptr; + } + + /** if \p create is `true`, a new event will be created; otherwise, an empty event object is created */ + Event(bool create, bool timing_event = false) : event{nullptr} { + if (create) { + unsigned int flags = cudaEventBlockingSync | (timing_event ? 
0 : cudaEventDisableTiming); + CUDA4DNN_CHECK_CUDA(cudaEventCreateWithFlags(&event, flags)); + } + } + + ~Event() { + try { + if (event != nullptr) + CUDA4DNN_CHECK_CUDA(cudaEventDestroy(event)); + } catch (const CUDAException& ex) { + std::ostringstream os; + os << "Asynchronous exception caught during CUDA event destruction.\n"; + os << ex.what(); + os << "Exception will be ignored.\n"; + CV_LOG_WARNING(0, os.str().c_str()); + } + } + + Event& operator=(const Event&) noexcept = delete; + Event& operator=(Event&& other) noexcept { + event = other.event; + other.event = nullptr; + return *this; + } + + /** mark a point in \p stream */ + void record(const Stream& stream) { + CUDA4DNN_CHECK_CUDA(cudaEventRecord(event, stream.get())); + } + + /** blocks the caller thread until all operations before the event finish */ + void synchronize() const { CUDA4DNN_CHECK_CUDA(cudaEventSynchronize(event)); } + + /** returns true if there are operations pending before the event completes */ + bool busy() const { + auto status = cudaEventQuery(event); + if (status == cudaErrorNotReady) + return true; + CUDA4DNN_CHECK_CUDA(status); + return false; + } + + cudaEvent_t get() const noexcept { return event; } + + /** returns true if the event is valid */ + explicit operator bool() const noexcept { return event; } + + private: + cudaEvent_t event; + }; + + /** makes a stream wait on an event */ + void StreamWaitOnEvent(const Stream& stream, const Event& event) { + CUDA4DNN_CHECK_CUDA(cudaStreamWaitEvent(stream.get(), event.get(), 0)); + } + + /** returns the time elapsed between two events in milliseconds */ + float TimeElapsedBetweenEvents(const Event& start, const Event& end) { + float temp; + CUDA4DNN_CHECK_CUDA(cudaEventElapsedTime(&temp, start.get(), end.get())); + return temp; + } + +}}}} /* namespace cv::dnn::cuda4dnn::csl */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_EVENT_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/fp16.hpp b/modules/dnn/src/cuda4dnn/csl/fp16.hpp new file mode 100644 index 0000000..c76de45 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/fp16.hpp @@ -0,0 +1,84 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
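A minimal sketch of event-based timing with the Event class from event.hpp above; it assumes a valid csl::Stream (stream.hpp, later in this patch) and that some work is enqueued between the two record() calls:

    namespace csl = cv::dnn::cuda4dnn::csl;

    /* returns the elapsed time of the enqueued work in milliseconds */
    float example_time_region(const csl::Stream& stream) {
        csl::Event start(true, true), stop(true, true); /* create events with timing enabled */
        start.record(stream);
        /* ... enqueue kernels or cuDNN calls on `stream` here ... */
        stop.record(stream);
        stop.synchronize(); /* block until all work recorded before `stop` has finished */
        return csl::TimeElapsedBetweenEvents(start, stop);
    }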
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_FP16_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_CSL_FP16_HPP + +#include "nvcc_defs.hpp" + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { + + namespace detail { + template + struct is_half_convertible : std::false_type { }; + + template + struct is_half_convertible::value, void>::type> : std::true_type { }; + + template + struct is_half_convertible::value, void>::type> : std::true_type { }; + } + + /* Note: nvcc has a broken overload resolution; it considers host overloads inside device code + CUDA4DNN_HOST bool operator==(half lhs, half rhs) noexcept { return static_cast(lhs) == static_cast(rhs); } + CUDA4DNN_HOST bool operator!=(half lhs, half rhs) noexcept { return static_cast(lhs) != static_cast(rhs); } + CUDA4DNN_HOST bool operator<(half lhs, half rhs) noexcept { return static_cast(lhs) < static_cast(rhs); } + CUDA4DNN_HOST bool operator>(half lhs, half rhs) noexcept { return static_cast(lhs) > static_cast(rhs); } + CUDA4DNN_HOST bool operator<=(half lhs, half rhs) noexcept { return static_cast(lhs) <= static_cast(rhs); } + CUDA4DNN_HOST bool operator>=(half lhs, half rhs) noexcept { return static_cast(lhs) >= static_cast(rhs); } + */ + + template CUDA4DNN_HOST + typename std::enable_if::value, bool> + ::type operator==(half lhs, T rhs) noexcept { return static_cast(lhs) == static_cast(rhs); } + + template CUDA4DNN_HOST + typename std::enable_if::value, bool> + ::type operator!=(half lhs, T rhs) noexcept { return static_cast(lhs) != static_cast(rhs); } + + template CUDA4DNN_HOST + typename std::enable_if::value, bool> + ::type operator<(half lhs, T rhs) noexcept { return static_cast(lhs) < static_cast(rhs); } + + template CUDA4DNN_HOST + typename std::enable_if::value, bool> + ::type operator>(half lhs, T rhs) noexcept { return static_cast(lhs) > static_cast(rhs); } + + template CUDA4DNN_HOST + typename std::enable_if::value, bool> + ::type operator<=(half lhs, T rhs) noexcept { return static_cast(lhs) <= static_cast(rhs); } + + template CUDA4DNN_HOST + typename std::enable_if::value, bool> + ::type operator>=(half lhs, T rhs) noexcept { return static_cast(lhs) >= static_cast(rhs); } + + template CUDA4DNN_HOST + typename std::enable_if::value, bool> + ::type operator==(T lhs, half rhs) noexcept { return static_cast(lhs) == static_cast(rhs); } + + template CUDA4DNN_HOST + typename std::enable_if::value, bool> + ::type operator!=(T lhs, half rhs) noexcept { return static_cast(lhs) != static_cast(rhs); } + + template CUDA4DNN_HOST + typename std::enable_if::value, bool> + ::type operator<(T lhs, half rhs) noexcept { return static_cast(lhs) < static_cast(rhs); } + + template CUDA4DNN_HOST + typename std::enable_if::value, bool> + ::type operator>(T lhs, half rhs) noexcept { return static_cast(lhs) > static_cast(rhs); } + + template CUDA4DNN_HOST + typename std::enable_if::value, bool> + ::type operator<=(T lhs, half rhs) noexcept { return static_cast(lhs) <= static_cast(rhs); } + + template CUDA4DNN_HOST + typename std::enable_if::value, bool> + ::type operator>=(T lhs, half rhs) noexcept { return static_cast(lhs) >= static_cast(rhs); } + +}}}} /* namespace cv::dnn::cuda4dnn::csl */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_FP16_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/memory.hpp b/modules/dnn/src/cuda4dnn/csl/memory.hpp new file mode 100644 index 0000000..2ffa32f --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/memory.hpp @@ -0,0 +1,295 @@ +// This file is part of OpenCV project. 
+// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_MEMORY_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_CSL_MEMORY_HPP + +#include "error.hpp" +#include "pointer.hpp" + +#include + +#include + +#include +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { + + /* @brief smart device pointer with allocation/deallocation methods + * + * ManagedPtr is a smart shared device pointer which also handles memory allocation. + */ + template + class ManagedPtr { + static_assert(!std::is_const::value && !std::is_volatile::value, "T cannot be cv-qualified"); + static_assert(std::is_standard_layout::value, "T must satisfy StandardLayoutType"); + + public: + using element_type = T; + + using pointer = DevicePtr; + using const_pointer = DevicePtr::type>; + + using size_type = std::size_t; + + ManagedPtr() noexcept : wrapped{ nullptr }, n{ 0 }, capacity{ 0 } { } + ManagedPtr(const ManagedPtr&) noexcept = default; + ManagedPtr(ManagedPtr&& other) noexcept + : wrapped{ std::move(other.wrapped) }, n{ other.n }, capacity { other.capacity } + { + other.reset(); + } + + /** allocates device memory for \p count number of element */ + ManagedPtr(size_type count) { + if (count <= 0) { + CV_Error(Error::StsBadArg, "number of elements is zero or negative"); + } + + void* temp = nullptr; + CUDA4DNN_CHECK_CUDA(cudaMalloc(&temp, count * sizeof(element_type))); + + auto ptr = typename pointer::pointer(static_cast(temp)); + wrapped.reset(ptr, [](element_type* ptr) { + if (ptr != nullptr) { + /* contract violation for std::shared_ptr if cudaFree throws */ + try { + CUDA4DNN_CHECK_CUDA(cudaFree(ptr)); + } catch (const CUDAException& ex) { + std::ostringstream os; + os << "Device memory deallocation failed in deleter.\n"; + os << ex.what(); + os << "Exception will be ignored.\n"; + CV_LOG_WARNING(0, os.str().c_str()); + } + } + }); + /* std::shared_ptr::reset invokves the deleter if an exception occurs; hence, we don't + * need to have a try-catch block to free the allocated device memory + */ + + n = capacity = count; + } + + ManagedPtr& operator=(ManagedPtr&& other) noexcept { + wrapped = std::move(other.wrapped); + n = other.n; + capacity = other.capacity; + + other.reset(); + return *this; + } + + size_type size() const noexcept { return n; } + + void reset() noexcept { wrapped.reset(); n = capacity = 0; } + + /** + * deallocates any previously allocated memory and allocates device memory + * for \p count number of elements + * + * @note no reallocation if the previously allocated memory has no owners and the requested memory size fits in it + * @note use move constructor to guarantee a deallocation of the previously allocated memory + * + * Exception Guarantee: Strong + */ + void reset(size_type count) { + /* we need to fully own the memory to perform optimizations */ + if (wrapped.use_count() == 1) { + /* avoid reallocation if the existing capacity is sufficient */ + if (count <= capacity) { + n = count; + return; + } + } + + /* no optimization performed; allocate memory */ + ManagedPtr tmp(count); + swap(tmp, *this); + } + + pointer get() const noexcept { return pointer(wrapped.get()); } + + explicit operator bool() const noexcept { return wrapped; } + + friend bool operator==(const ManagedPtr& lhs, const ManagedPtr& rhs) noexcept { return lhs.wrapped == rhs.wrapped; } + friend bool operator!=(const ManagedPtr& lhs, const ManagedPtr& rhs) 
noexcept { return lhs.wrapped != rhs.wrapped; } + + friend void swap(ManagedPtr& lhs, ManagedPtr& rhs) noexcept { + using std::swap; + swap(lhs.wrapped, rhs.wrapped); + swap(lhs.n, rhs.n); + swap(lhs.capacity, rhs.capacity); + } + + private: + std::shared_ptr wrapped; + size_type n, capacity; + }; + + /** copies entire memory block pointed by \p src to \p dest + * + * \param[in] src device pointer + * \param[out] dest host pointer + * + * Pre-conditions: + * - memory pointed by \p dest must be large enough to hold the entire block of memory held by \p src + * + * Exception Guarantee: Basic + */ + template + void memcpy(T *dest, const ManagedPtr& src) { + memcpy(dest, src.get(), src.size()); + } + + /** copies data from memory pointed by \p src to fully fill \p dest + * + * \param[in] src host pointer + * \param[out] dest device pointer + * + * Pre-conditions: + * - memory pointed by \p src must be at least as big as the memory block held by \p dest + * + * Exception Guarantee: Basic + */ + template + void memcpy(const ManagedPtr& dest, const T* src) { + memcpy(dest.get(), src, dest.size()); + } + + /** copies data from memory pointed by \p src to \p dest + * + * if the two \p src and \p dest have different sizes, the number of elements copied is + * equal to the size of the smaller memory block + * + * \param[in] src device pointer + * \param[out] dest device pointer + * + * Exception Guarantee: Basic + */ + template + void memcpy(const ManagedPtr& dest, const ManagedPtr& src) { + memcpy(dest.get(), src.get(), std::min(dest.size(), src.size())); + } + + /** sets device memory block to a specific 8-bit value + * + * \param[in] src device pointer + * \param[out] ch 8-bit value to fill the device memory with + * + * Exception Guarantee: Basic + */ + template + void memset(const ManagedPtr& dest, std::int8_t ch) { + memset(dest.get(), ch, dest.size()); + } + + /** copies entire memory block pointed by \p src to \p dest asynchronously + * + * \param[in] src device pointer + * \param[out] dest host pointer + * \param stream CUDA stream that has to be used for the memory transfer + * + * Pre-conditions: + * - memory pointed by \p dest must be large enough to hold the entire block of memory held by \p src + * - \p dest points to page-locked memory + * + * Exception Guarantee: Basic + */ + template + void memcpy(T *dest, const ManagedPtr& src, const Stream& stream) { + CV_Assert(stream); + memcpy(dest, src.get(), src.size(), stream); + } + + /** copies data from memory pointed by \p src to \p dest asynchronously + * + * \param[in] src host pointer + * \param[out] dest device pointer + * \param stream CUDA stream that has to be used for the memory transfer + * + * Pre-conditions: + * - memory pointed by \p dest must be large enough to hold the entire block of memory held by \p src + * - \p src points to page-locked memory + * + * Exception Guarantee: Basic + */ + template + void memcpy(const ManagedPtr& dest, const T* src, const Stream& stream) { + CV_Assert(stream); + memcpy(dest.get(), src, dest.size(), stream); + } + + /** copies data from memory pointed by \p src to \p dest asynchronously + * + * \param[in] src device pointer + * \param[out] dest device pointer + * \param stream CUDA stream that has to be used for the memory transfer + * + * if the two \p src and \p dest have different sizes, the number of elements copied is + * equal to the size of the smaller memory block + * + * Exception Guarantee: Basic + */ + template + void memcpy(ManagedPtr& dest, const ManagedPtr& src, const Stream& 
stream) { + CV_Assert(stream); + memcpy(dest.get(), src.get(), std::min(dest.size(), src.size()), stream); + } + + /** sets device memory block to a specific 8-bit value asynchronously + * + * \param[in] src device pointer + * \param[out] ch 8-bit value to fill the device memory with + * \param stream CUDA stream that has to be used for the memory operation + * + * Exception Guarantee: Basic + */ + template + void memset(const ManagedPtr& dest, int ch, const Stream& stream) { + CV_Assert(stream); + memset(dest.get(), ch, dest.size(), stream); + } + + /** @brief registers host memory as page-locked and unregisters on destruction */ + class MemoryLockGuard { + public: + MemoryLockGuard() noexcept : ptr { nullptr } { } + MemoryLockGuard(const MemoryLockGuard&) = delete; + MemoryLockGuard(MemoryLockGuard&& other) noexcept : ptr{ other.ptr } { + other.ptr = nullptr; + } + + /** page-locks \p size_in_bytes bytes of memory starting from \p ptr_ + * + * Pre-conditons: + * - host memory should be unregistered + */ + MemoryLockGuard(void* ptr_, std::size_t size_in_bytes) { + CUDA4DNN_CHECK_CUDA(cudaHostRegister(ptr_, size_in_bytes, cudaHostRegisterPortable)); + ptr = ptr_; + } + + MemoryLockGuard& operator=(const MemoryLockGuard&) = delete; + MemoryLockGuard& operator=(MemoryLockGuard&& other) noexcept { + ptr = other.ptr; + other.ptr = nullptr; + return *this; + } + + ~MemoryLockGuard() { + if(ptr != nullptr) + CUDA4DNN_CHECK_CUDA(cudaHostUnregister(ptr)); + } + + private: + void *ptr; + }; + +}}}} /* namespace cv::dnn::cuda4dnn::csl */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_MEMORY_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/nvcc_defs.hpp b/modules/dnn/src/cuda4dnn/csl/nvcc_defs.hpp new file mode 100644 index 0000000..b9418c9 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/nvcc_defs.hpp @@ -0,0 +1,20 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_NVCC_DEFS_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_CSL_NVCC_DEFS_HPP + +#include + +#ifdef __CUDACC__ +# define CUDA4DNN_HOST __host__ +# define CUDA4DNN_DEVICE __device__ +# define CUDA4DNN_HOST_DEVICE CUDA4DNN_HOST CUDA4DNN_DEVICE +#else +# define CUDA4DNN_HOST +# define CUDA4DNN_DEVICE +# define CUDA4DNN_HOST_DEVICE +#endif + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_NVCC_DEFS_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/pointer.hpp b/modules/dnn/src/cuda4dnn/csl/pointer.hpp new file mode 100644 index 0000000..0d89112 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/pointer.hpp @@ -0,0 +1,411 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_POINTER_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_CSL_POINTER_HPP + +#include "nvcc_defs.hpp" +#include "error.hpp" +#include "stream.hpp" + +#include + +#include + +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { + + /** @brief provides a type-safe device pointer + * + * DevicePtr wraps a raw pointer and mimics its behaviour. It does not implicitly convert + * to a raw pointer. This ensures that accidental mixing of host and device pointers do not happen. + * + * It is meant to point to locations in device memory. 
Hence, it provides dereferencing or + * array subscript capability for device code only. + * + * A `const DevicePtr` represents an immutable pointer to a mutable memory. + * A `DevicePtr` represents a mutable pointer to an immutable memory. + * A `const DevicePtr` represents an immutable pointer to an immutable memory. + * + * A `DevicePtr` can implicitly convert to `DevicePtr`. + * + * Specalizations: + * - DevicePtr/DevicePtr do not support pointer arithmetic (but relational operators are provided) + * - any device pointer pointing to mutable memory is implicitly convertible to DevicePtr + * - any device pointer is implicitly convertible to DevicePtr + * - DevicePtr can be explicitly converted to any device pointer + * - DevicePtr can be explicitly converted to any device pointer pointing to immutable memory + */ + template + class DevicePtr { + static_assert(std::is_standard_layout::value, "T must satisfy StandardLayoutType"); + + public: + using element_type = T; + using difference_type = std::ptrdiff_t; + using pointer = typename std::add_pointer::type; + using reference = typename std::add_lvalue_reference::type; + + DevicePtr() = default; + CUDA4DNN_HOST_DEVICE explicit DevicePtr(pointer ptr_) noexcept : ptr{ ptr_ } { } + + CUDA4DNN_HOST_DEVICE DevicePtr operator=(pointer ptr_) noexcept { ptr = ptr_; return *this; } + + CUDA4DNN_HOST_DEVICE pointer get() const noexcept { return ptr; }; + + CUDA4DNN_DEVICE reference operator[](difference_type idx) const noexcept { return get()[idx]; } + CUDA4DNN_DEVICE reference operator*() const noexcept { return *get(); } + CUDA4DNN_DEVICE pointer operator->() const noexcept { return get(); } + + template::value, bool>::type = true> + CUDA4DNN_HOST_DEVICE operator DevicePtr::type>() const noexcept { + return DevicePtr::type>{ptr}; + } + + CUDA4DNN_HOST_DEVICE explicit operator bool() const noexcept { return ptr; } + + CUDA4DNN_HOST_DEVICE DevicePtr operator++() noexcept { + ++ptr; + return *this; + } + + CUDA4DNN_HOST_DEVICE DevicePtr operator++(int) noexcept { + auto tmp = DevicePtr(*this); + ptr++; + return tmp; + } + + CUDA4DNN_HOST_DEVICE DevicePtr operator--() noexcept { + --ptr; + return *this; + } + + CUDA4DNN_HOST_DEVICE DevicePtr operator--(int) noexcept { + auto tmp = DevicePtr(*this); + ptr--; + return tmp; + } + + CUDA4DNN_HOST_DEVICE DevicePtr operator+=(std::ptrdiff_t offset) noexcept { + ptr += offset; + return *this; + } + + CUDA4DNN_HOST_DEVICE DevicePtr operator-=(std::ptrdiff_t offset) noexcept { + ptr -= offset; + return *this; + } + + CUDA4DNN_HOST_DEVICE friend DevicePtr operator+(DevicePtr lhs, std::ptrdiff_t offset) noexcept { + return lhs += offset; + } + + CUDA4DNN_HOST_DEVICE friend DevicePtr operator-(DevicePtr lhs, std::ptrdiff_t offset) noexcept { + return lhs -= offset; + } + + CUDA4DNN_HOST_DEVICE friend difference_type operator-(DevicePtr lhs, DevicePtr rhs) noexcept { + return lhs.ptr - rhs.ptr; + } + + CUDA4DNN_HOST_DEVICE friend bool operator==(DevicePtr lhs, DevicePtr rhs) noexcept { return lhs.ptr == rhs.ptr; } + CUDA4DNN_HOST_DEVICE friend bool operator!=(DevicePtr lhs, DevicePtr rhs) noexcept { return !(lhs == rhs); } + CUDA4DNN_HOST_DEVICE friend bool operator<(DevicePtr lhs, DevicePtr rhs) noexcept { return lhs.ptr < rhs.ptr; } + CUDA4DNN_HOST_DEVICE friend bool operator>(DevicePtr lhs, DevicePtr rhs) noexcept { return rhs < lhs; } + CUDA4DNN_HOST_DEVICE friend bool operator<=(DevicePtr lhs, DevicePtr rhs) noexcept { return !(rhs < lhs); } + CUDA4DNN_HOST_DEVICE friend bool operator>=(DevicePtr lhs, 
DevicePtr rhs) noexcept { return !(lhs < rhs); } + + CUDA4DNN_HOST_DEVICE explicit operator pointer() const noexcept { return ptr; } + + CUDA4DNN_HOST friend void swap(DevicePtr& lhs, DevicePtr& rhs) noexcept { + using std::swap; + swap(lhs.ptr, rhs.ptr); + } + + template + CUDA4DNN_HOST friend std::basic_ostream& operator<<(std::basic_ostream& os, DevicePtr other) { + os << other.get() << " (device)"; + return os; + } + + private: + pointer ptr; + }; + + template <> + class DevicePtr { + public: + using element_type = const void; + using pointer = typename std::add_pointer::type; + + DevicePtr() = default; + + /* host const void pointer to const void device pointer */ + CUDA4DNN_HOST_DEVICE explicit DevicePtr(pointer ptr_) noexcept : ptr{ ptr_ } { } + + /* allow any device pointer to be implicitly convereted to void device pointer */ + template + CUDA4DNN_HOST_DEVICE DevicePtr(DevicePtr ptr_) noexcept : ptr{ ptr_.get() } { } + + CUDA4DNN_HOST_DEVICE DevicePtr operator=(pointer ptr_) noexcept { ptr = ptr_; return *this; } + + CUDA4DNN_HOST_DEVICE pointer get() const noexcept { return ptr; }; + + CUDA4DNN_HOST_DEVICE explicit operator bool() const noexcept { return ptr; } + + CUDA4DNN_HOST_DEVICE friend bool operator==(DevicePtr lhs, DevicePtr rhs) noexcept { return lhs.ptr == rhs.ptr; } + CUDA4DNN_HOST_DEVICE friend bool operator!=(DevicePtr lhs, DevicePtr rhs) noexcept { return !(lhs == rhs); } + CUDA4DNN_HOST_DEVICE friend bool operator<(DevicePtr lhs, DevicePtr rhs) noexcept { return lhs.ptr < rhs.ptr; } + CUDA4DNN_HOST_DEVICE friend bool operator>(DevicePtr lhs, DevicePtr rhs) noexcept { return rhs < lhs; } + CUDA4DNN_HOST_DEVICE friend bool operator<=(DevicePtr lhs, DevicePtr rhs) noexcept { return !(rhs < lhs); } + CUDA4DNN_HOST_DEVICE friend bool operator>=(DevicePtr lhs, DevicePtr rhs) noexcept { return !(lhs < rhs); } + + /* explicit conversion into host void pointer */ + CUDA4DNN_HOST_DEVICE explicit operator pointer() const noexcept { return ptr; } + + /* const void device pointer can be explicitly casted into any const device pointer type */ + template ::value, bool>::type = true> + CUDA4DNN_HOST_DEVICE explicit operator DevicePtr() const noexcept { + return static_cast(ptr); + } + + CUDA4DNN_HOST friend void swap(DevicePtr& lhs, DevicePtr& rhs) noexcept { + using std::swap; + swap(lhs.ptr, rhs.ptr); + } + + template + CUDA4DNN_HOST friend std::basic_ostream& operator<<(std::basic_ostream& os, DevicePtr other) { + os << other.get() << " (device)"; + return os; + } + + private: + pointer ptr; + }; + + template <> + class DevicePtr { + public: + using element_type = void; + using pointer = typename std::add_pointer::type; + + DevicePtr() = default; + + /* host pointer to device pointer */ + CUDA4DNN_HOST_DEVICE explicit DevicePtr(pointer ptr_) noexcept : ptr{ ptr_ } { } + + /* allow any device pointer to mutable memory to be implicitly convereted to void device pointer */ + template ::value, bool>::type = false> + CUDA4DNN_HOST_DEVICE DevicePtr(DevicePtr ptr_) noexcept : ptr { ptr_.get() } { } + + CUDA4DNN_HOST_DEVICE DevicePtr operator=(pointer ptr_) noexcept { ptr = ptr_; return *this; } + + CUDA4DNN_HOST_DEVICE pointer get() const noexcept { return ptr; }; + + CUDA4DNN_HOST_DEVICE operator DevicePtr() const noexcept { return DevicePtr{ptr}; } + + CUDA4DNN_HOST_DEVICE explicit operator bool() const noexcept { return ptr; } + + CUDA4DNN_HOST_DEVICE friend bool operator==(DevicePtr lhs, DevicePtr rhs) noexcept { return lhs.ptr == rhs.ptr; } + CUDA4DNN_HOST_DEVICE friend bool 
operator!=(DevicePtr lhs, DevicePtr rhs) noexcept { return !(lhs == rhs); } + CUDA4DNN_HOST_DEVICE friend bool operator<(DevicePtr lhs, DevicePtr rhs) noexcept { return lhs.ptr < rhs.ptr; } + CUDA4DNN_HOST_DEVICE friend bool operator>(DevicePtr lhs, DevicePtr rhs) noexcept { return rhs < lhs; } + CUDA4DNN_HOST_DEVICE friend bool operator<=(DevicePtr lhs, DevicePtr rhs) noexcept { return !(rhs < lhs); } + CUDA4DNN_HOST_DEVICE friend bool operator>=(DevicePtr lhs, DevicePtr rhs) noexcept { return !(lhs < rhs); } + + /* explicit conversion into host void pointer */ + CUDA4DNN_HOST_DEVICE explicit operator pointer() const noexcept { return ptr; } + + /* void device pointer can be explicitly casted into any device pointer type */ + template + CUDA4DNN_HOST_DEVICE explicit operator DevicePtr() const noexcept { + return DevicePtr(static_cast(ptr)); + } + + CUDA4DNN_HOST friend void swap(DevicePtr& lhs, DevicePtr& rhs) noexcept { + using std::swap; + swap(lhs.ptr, rhs.ptr); + } + + template + CUDA4DNN_HOST friend std::basic_ostream& operator<<(std::basic_ostream& os, DevicePtr other) { + os << other.get() << " (device)"; + return os; + } + + private: + pointer ptr; + }; + + template + bool is_aligned(DevicePtr ptr, std::size_t alignment) { + auto addr = reinterpret_cast(ptr.get()); + return addr % alignment == 0; + } + + /** copies \p n elements from \p src to \p dest4 + * + * \param[in] src device pointer + * \param[out] dest host pointer + * + * Pre-conditions: + * - memory pointed by \p dest and \p src must be large enough to hold \p n elements + * + * Exception Guarantee: Basic + */ + template + void memcpy(T *dest, DevicePtr src, std::size_t n) { + if (n <= 0) { + CV_Error(Error::StsBadArg, "number of elements to copy is zero or negtaive"); + } + + CUDA4DNN_CHECK_CUDA(cudaMemcpy(dest, src.get(), n * sizeof(T), cudaMemcpyDefault)); + } + + /** copies \p n elements from \p src to \p dest + * + * \param[in] src host pointer + * \param[out] dest device pointer + * + * Pre-conditions: + * - memory pointed by \p dest and \p src must be large enough to hold \p n elements + * + * Exception Guarantee: Basic + */ + template + void memcpy(DevicePtr dest, const T* src, std::size_t n) { + if (n <= 0) { + CV_Error(Error::StsBadArg, "number of elements to copy is zero or negtaive"); + } + + CUDA4DNN_CHECK_CUDA(cudaMemcpy(dest.get(), src, n * sizeof(T), cudaMemcpyDefault)); + } + + /** copies \p n elements from \p src to \p dest + * + * \param[in] src device pointer + * \param[out] dest device pointer + * + * Pre-conditions: + * - memory pointed by \p dest and \p src must be large enough to hold \p n elements + * + * Exception Guarantee: Basic + */ + template + void memcpy(DevicePtr dest, DevicePtr src, std::size_t n) { + if (n <= 0) { + CV_Error(Error::StsBadArg, "number of elements to copy is zero or negtaive"); + } + + CUDA4DNN_CHECK_CUDA(cudaMemcpy(dest.get(), src.get(), n * sizeof(T), cudaMemcpyDefault)); + } + + /** sets \p n elements to \p ch in \p dest + * + * \param[in] src device pointer + * \param[out] ch 8-bit value to fill the device memory with + * + * Pre-conditions: + * - memory pointed by \p dest must be large enough to hold \p n elements + * + * Exception Guarantee: Basic + */ + template + void memset(DevicePtr dest, std::int8_t ch, std::size_t n) { + if (n <= 0) { + CV_Error(Error::StsBadArg, "number of elements to copy is zero or negtaive"); + } + + CUDA4DNN_CHECK_CUDA(cudaMemset(dest.get(), ch, n * sizeof(T))); + } + + /** copies \p n elements from \p src to \p dest asynchronously + * + 
* \param[in] src device pointer + * \param[out] dest host pointer + * \param stream CUDA stream that has to be used for the memory transfer + * + * Pre-conditions: + * - memory pointed by \p dest and \p src must be large enough to hold \p n elements + * - \p dest points to page-locked memory + * + * Exception Guarantee: Basic + */ + template + void memcpy(T *dest, DevicePtr src, std::size_t n, const Stream& stream) { + if (n <= 0) { + CV_Error(Error::StsBadArg, "number of elements to copy is zero or negtaive"); + } + + CUDA4DNN_CHECK_CUDA(cudaMemcpyAsync(dest, src.get(), n * sizeof(T), cudaMemcpyDefault, stream.get())); + } + + /** copies data from memory pointed by \p src to \p dest asynchronously + * + * \param[in] src host pointer + * \param[out] dest device pointer + * \param stream CUDA stream that has to be used for the memory transfer + * + * Pre-conditions: + * - memory pointed by \p dest and \p src must be large enough to hold \p n elements + * - \p src points to page-locked memory + * + * Exception Guarantee: Basic + */ + template + void memcpy(DevicePtr dest, const T *src, std::size_t n, const Stream& stream) { + if (n <= 0) { + CV_Error(Error::StsBadArg, "number of elements to copy is zero or negtaive"); + } + + CUDA4DNN_CHECK_CUDA(cudaMemcpyAsync(dest.get(), src, n * sizeof(T), cudaMemcpyDefault, stream.get())); + } + + /** copies \p n elements from \p src to \p dest asynchronously + * + * \param[in] src device pointer + * \param[out] dest device pointer + * \param stream CUDA stream that has to be used for the memory transfer + * + * Pre-conditions: + * - memory pointed by \p dest and \p src must be large enough to hold \p n elements + * + * Exception Guarantee: Basic + */ + template + void memcpy(DevicePtr dest, DevicePtr src, std::size_t n, const Stream& stream) { + if (n <= 0) { + CV_Error(Error::StsBadArg, "number of elements to copy is zero or negtaive"); + } + + CUDA4DNN_CHECK_CUDA(cudaMemcpyAsync(dest.get(), src.get(), n * sizeof(T), cudaMemcpyDefault, stream.get())); + } + + /** sets \p n elements to \p ch in \p dest asynchronously + * + * \param[in] src device pointer + * \param[out] ch 8-bit value to fill the device memory with + * \param stream CUDA stream that has to be used for the memory operation + * + * Pre-conditions: + * - memory pointed by \p dest must be large enough to hold \p n elements + * + * Exception Guarantee: Basic + */ + template + void memset(DevicePtr dest, std::int8_t ch, std::size_t n, const Stream& stream) { + if (n <= 0) { + CV_Error(Error::StsBadArg, "number of elements to copy is zero or negtaive"); + } + + CUDA4DNN_CHECK_CUDA(cudaMemsetAsync(dest.get(), ch, n * sizeof(T), stream.get())); + } + +}}}} /* namespace cv::dnn::cuda4dnn::csl */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_POINTER_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/span.hpp b/modules/dnn/src/cuda4dnn/csl/span.hpp new file mode 100644 index 0000000..55e8e5f --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/span.hpp @@ -0,0 +1,83 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
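A minimal sketch of a host/device round trip using ManagedPtr (memory.hpp) together with the memcpy helpers above; the buffer size and element type are illustrative assumptions:

    #include <vector>

    namespace csl = cv::dnn::cuda4dnn::csl;

    void example_device_round_trip() {
        std::vector<float> host_src(1024, 1.0f), host_dst(1024);

        csl::ManagedPtr<float> device(host_src.size()); /* allocates 1024 floats on the device */
        csl::memcpy(device, host_src.data());           /* host -> device: fills the whole device block */
        csl::memcpy(host_dst.data(), device);           /* device -> host: copies the whole device block */
    }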
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_SPAN_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_CSL_SPAN_HPP + +#include "pointer.hpp" +#include "nvcc_defs.hpp" + +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { + + /** @brief provides non-owning mutable access for device arrays + * + * const Span/Span provides mutable access to the elements unless T is const qualified + * const Span makes the span immutable but not the elements + */ + template + class Span { + static_assert(std::is_standard_layout::value, "T must satisfy StandardLayoutType"); + + public: + using value_type = T; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + + using pointer = DevicePtr; + using const_pointer = DevicePtr::type>; + using reference = typename std::add_lvalue_reference::type; + using const_reference = typename std::add_lvalue_reference::type>; + + using iterator = pointer; + using const_iterator = const_pointer; + + Span() noexcept : ptr{ nullptr }, sz{ 0 } { } + CUDA4DNN_HOST_DEVICE Span(pointer first, pointer last) noexcept : ptr{ first }, sz{ last - first } { } + CUDA4DNN_HOST_DEVICE Span(pointer first, size_type count) noexcept : ptr{ first }, sz{ count } { } + + CUDA4DNN_HOST_DEVICE size_type size() const noexcept { return sz; } + CUDA4DNN_HOST_DEVICE bool empty() const noexcept { return size() == 0; } + + CUDA4DNN_DEVICE reference operator[](difference_type index) const { return ptr[index]; } + CUDA4DNN_HOST_DEVICE pointer data() const noexcept { return ptr; } + + template::type, + typename std::enable_if::value, bool>::type = true> + CUDA4DNN_HOST_DEVICE operator Span() const noexcept { return Span{ptr, sz}; } + + private: + pointer ptr; + size_type sz; + }; + + /** @brief provides non-owning immutable view for device arrays */ + template + using View = Span; + + /** returns true if the address of a span/view is aligned to \p alignment number of elements (not bytes) */ + template + bool is_address_aligned(View v, std::size_t alignment) { + return is_aligned(v.data(), alignment * sizeof(T)); + } + + /** returns true if the size of a span/view is a multiple of \p alignment */ + template + bool is_size_aligned(View v, std::size_t alignment) { + return v.size() % alignment == 0; + } + + /** @brief returns true if the address and the size of the span/view is aligned + * \p alignment refers to the number of elements (not bytes) + */ + template + bool is_fully_aligned(View v, std::size_t alignment) { + return is_address_aligned(v, alignment) && is_size_aligned(v, alignment); + } + +}}}} /* namespace cv::dnn::cuda4dnn::csl */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_SPAN_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/stream.hpp b/modules/dnn/src/cuda4dnn/csl/stream.hpp new file mode 100644 index 0000000..0a1d804 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/stream.hpp @@ -0,0 +1,118 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_STREAM_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_CSL_STREAM_HPP + +#include "error.hpp" + +#include +#include + +#include + +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { + + /** @brief noncopyable smart CUDA stream + * + * UniqueStream is a smart non-sharable wrapper for CUDA stream handle which ensures that + * the handle is destroyed after use. 
Unless explicitly specified by a constructor argument, + * the stream object represents the default stream. + */ + class UniqueStream { + public: + UniqueStream() noexcept : stream{ 0 } { } + UniqueStream(UniqueStream&) = delete; + UniqueStream(UniqueStream&& other) noexcept { + stream = other.stream; + other.stream = 0; + } + + UniqueStream(bool create) : stream{ 0 } { + if (create) { + CUDA4DNN_CHECK_CUDA(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + } + } + + ~UniqueStream() { + try { + if (stream != 0) + CUDA4DNN_CHECK_CUDA(cudaStreamDestroy(stream)); + } catch (const CUDAException& ex) { + std::ostringstream os; + os << "Asynchronous exception caught during CUDA stream destruction.\n"; + os << ex.what(); + os << "Exception will be ignored.\n"; + CV_LOG_WARNING(0, os.str().c_str()); + } + } + + UniqueStream& operator=(const UniqueStream&) = delete; + UniqueStream& operator=(UniqueStream&& other) noexcept { + stream = other.stream; + other.stream = 0; + return *this; + } + + /** returns the raw CUDA stream handle */ + cudaStream_t get() const noexcept { return stream; } + + void synchronize() const { CUDA4DNN_CHECK_CUDA(cudaStreamSynchronize(stream)); } + bool busy() const { + auto status = cudaStreamQuery(stream); + if (status == cudaErrorNotReady) + return true; + CUDA4DNN_CHECK_CUDA(status); + return false; + } + + private: + cudaStream_t stream; + }; + + /** @brief sharable smart CUDA stream + * + * Stream is a smart sharable wrapper for CUDA stream handle which ensures that + * the handle is destroyed after use. Unless explicitly specified by a constructor argument, + * the stream object represents the default stream. + * + * @note Moving a Stream object to another invalidates the former + */ + class Stream { + public: + Stream() : stream(std::make_shared()) { } + Stream(const Stream&) = default; + Stream(Stream&&) = default; + + /** if \p create is `true`, a new stream will be created instead of the otherwise default stream */ + Stream(bool create) : stream(std::make_shared(create)) { } + + Stream& operator=(const Stream&) = default; + Stream& operator=(Stream&&) = default; + + /** blocks the caller thread until all operations in the stream are complete */ + void synchronize() const { stream->synchronize(); } + + /** returns true if there are operations pending in the stream */ + bool busy() const { return stream->busy(); } + + /** returns true if the stream is valid */ + explicit operator bool() const noexcept { return static_cast(stream); } + + cudaStream_t get() const noexcept { + CV_Assert(stream); + return stream->get(); + } + + private: + std::shared_ptr stream; + }; + +}}}} /* namespace cv::dnn::cuda4dnn::csl */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_STREAM_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/tensor.hpp b/modules/dnn/src/cuda4dnn/csl/tensor.hpp new file mode 100644 index 0000000..eef69df --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/tensor.hpp @@ -0,0 +1,1143 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
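A minimal sketch of an asynchronous upload combining Stream, ManagedPtr and MemoryLockGuard from the headers above; the host buffer and its size are illustrative assumptions:

    #include <cstddef>

    namespace csl = cv::dnn::cuda4dnn::csl;

    void example_async_upload(float* host_buf, std::size_t n) {
        csl::Stream stream(true);         /* create a new non-default stream */
        csl::ManagedPtr<float> device(n);

        /* the asynchronous overloads require page-locked host memory; MemoryLockGuard
         * (memory.hpp above) registers the buffer for the lifetime of the guard
         */
        csl::MemoryLockGuard lock(host_buf, n * sizeof(float));

        csl::memcpy(device, host_buf, stream); /* enqueue host -> device copy on `stream` */
        stream.synchronize();                  /* block until the copy completes */
    }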
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_TENSOR_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_CSL_TENSOR_HPP + +#include "nvcc_defs.hpp" +#include "memory.hpp" +#include "cublas.hpp" +#include "cudnn.hpp" +#include "span.hpp" + +#include "../cxx_utils/resizable_static_array.hpp" +#include "../cxx_utils/is_iterator.hpp" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef CSL_MAX_TENSOR_RANK + #define CSL_MAX_TENSOR_RANK 6 +#endif + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { + + /** \file tensor.hpp + * + * TYPE | OWNERSHIP | MUTABLE + * ------------ + --------- + -------- + * Tensor | Yes | Yes + * TensorSpan | No | Yes + * TensorView | No | No + * + * Tensor is implicitly convertible to TensorSpan and TensorView + * TensorSpan is implicitly convertible to TensorView + * + * Concepts and template parameter naming convention: + * - "MutableTensorType" can refer to a Tensor or TensorSpan + * - "ImmutableTensorType" can refer to a Tensor, TensorSpan or TensorView + * - "TensorType" can refer to a Tensor, TensorSpan or TensorView + * + * "ImmutableTensorType" is used when the tensor data might be used. + * "TensorType" is used when only meta-information such as the size or shape is required, i.e. the data won't be touched + */ + + /** if the \p axis is a negative index, the equivalent postive index is returned; otherwise, returns \p axis */ + CUDA4DNN_HOST_DEVICE constexpr std::size_t clamp_axis(int axis, std::size_t rank) { + return axis < 0 ? axis + rank : axis; + } + + /** @brief multi-dimensional contiguous non-copyable GPU tensor + * + * \tparam T type of data stored + * + * @note scalars or zero rank tensors are not supported + * @note the maximum rank supported is controlled by the `CSL_MAX_TENSOR_RANK` preprocessor symbol + */ + template + class Tensor { + static_assert(std::is_standard_layout::value, "T must staisfy StandardLayoutType"); + + public: + using value_type = typename ManagedPtr::element_type; + using pointer = typename ManagedPtr::pointer; + using const_pointer = typename ManagedPtr::const_pointer; + using size_type = typename ManagedPtr::size_type; + + Tensor() noexcept { } + Tensor(const Tensor&) = delete; + Tensor(Tensor&& other) noexcept { + data = std::move(other.data); + shape = other.shape; + other.shape.clear(); + } + + /** @brief constructs a tensor of a specific shape + * + * Whatever arguments are accepted by the resize methods are accepted here. + */ + template + Tensor(Args&&... sizes) { resize(std::forward(sizes)...); } + + Tensor& operator=(const Tensor&) = delete; + Tensor& operator=(Tensor&& other) noexcept { + data = std::move(other.data); + shape = other.shape; + other.shape.clear(); + return *this; + } + + /** returns true if the tensor is empty (or uninitialized) */ + bool empty() const noexcept { return shape.size() == 0; } + + /** returns the total number of elements in the tensor + * + * Pre-conditions: + * - tensor must be non-empty + */ + size_type size() const noexcept { + CV_Assert(!empty()); + return std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies()); + } + + /** returns the rank of the tensor + * + * Pre-conditions: + * - tensor must be non-empty + */ + size_type rank() const noexcept { + CV_Assert(!empty()); + return shape.size(); + } + + /** @brief returns the length of the axis + * + * Every axis is assigned a zero-based index which can be used to select an axis. + * Negative index can be used to select an axis from the end. 
+ * + * Examples: + * > -1 represents the last axis + * > 0 represents the first axis + * > 1 represents the second axis + * + * Pre-conditions: + * - tensor must be non-empty + * - the axis must be in the range [-rank(), rank()) + */ + size_type get_axis_size(int axis) const noexcept { + axis = clamp_axis(axis, rank()); + CV_Assert(axis >= 0 && axis < rank()); + return shape[axis]; + } + + /** @brief returns the combined size of the axes in an axis range + * + * if the shape is [3 x 5 x 7 x 11] + * - `size_range(0, 2)` will return 3 x 5 = 15 + * - `size_range(1, 3)` will return 5 x 7 = 35 + * - `size_range(0, 4)` will return 3 x 5 x 7 x 11 = 1155 + * + * Pre-conditions: + * - tensor must be non-empty + * - `axis_start` must be less than or equal to `axis_end` + * - `axis_end` must be less than or equal to the rank + * + * returns one if the two `axis_start` and `axis_end` are equal + */ + size_type size_range(size_type axis_start, size_type axis_end) const noexcept { + CV_Assert(!empty()); + CV_Assert(axis_start <= axis_end); + CV_Assert(axis_end <= rank()); + auto start = std::begin(shape) + axis_start; + auto end = std::begin(shape) + axis_end; + return std::accumulate(start, end, 1, std::multiplies()); + } + + /** returns an std::vector containing axis lengths starting from axis zero + * + * Pre-conditions: + * - tensor must be non-empty + * + * Exception Guarantee: Strong + */ + std::vector shape_as_vector() const { + CV_Assert(!empty()); + return std::vector(std::begin(shape), std::end(shape)); + } + + /** returns a pointer to mutable device memory owned by the tensor */ + pointer get() noexcept { return data.get(); } + + /** returns a pointer to immutable device memory owned by the tensor */ + const_pointer get() const noexcept { return data.get(); } + + /** @brief releases the memory owned by the tensor + * + * Pre-conditions: + * - tensor must be non-empty + * + * Exception Guarantee: Strong + */ + void clear() { + CV_Assert(!empty()); + data.reset(); + shape.clear(); + } + + /** @brief resizes the tensor + * + * Pre-conditions: + * - [start, end) represents a forward range containing the length of the axes in order starting from axis zero + * - number of lengths provided must not exceed the maximum tensor rank (CSL_MAX_TENSOR_RANK) + * - the sizes must be positive integers + * + * Exception Guarantee: Strong + */ + template + typename std::enable_if::value, void> + ::type resize(ForwardItr start, ForwardItr end) { + CV_Assert(start != end); + CV_Assert(std::distance(start, end) <= CSL_MAX_TENSOR_RANK); + + using ItrValueType = typename std::iterator_traits::value_type; + auto total = std::accumulate(start, end, 1, std::multiplies()); + data.reset(total); + + shape.assign(start, end); + } + + /** @brief resizes the tensor + * constructs a range out of the arguments and invokes the range-based resize method + */ + template + void resize(Sizes... new_sizes_) { + static_assert(sizeof...(Sizes) <= CSL_MAX_TENSOR_RANK, "required rank exceeds maximum supported rank"); + static_assert(sizeof...(Sizes) > 0, "no sizes provided"); + std::array new_sizes = { static_cast(new_sizes_)... 
}; + resize(std::begin(new_sizes), std::end(new_sizes)); + } + + /** @brief resizes the tensor + * + * Pre-conditions: + * - the reference tensor must be non-empty + * + * Exception Guarantee: Strong + */ + template + void resize_as(const TensorType& tensor) { + CV_Assert(!tensor.empty()); + cxx_utils::resizable_static_array new_sizes(tensor.rank()); + for (int i = 0; i < new_sizes.size(); i++) + new_sizes[i] = tensor.get_axis_size(i); + resize(std::begin(new_sizes), std::end(new_sizes)); + } + + /** @brief reshapes the tensor + * + * Length deduction: + * The length of at most one axis can be deduced using the total size constraint. The axis can + * be marked for deduction by specifying the size as -1. + * + * The axes for which no size was provided (excluding -1) will be assumed to be one. + * + * Pre-conditions: + * - the tensor must be non-empty + * - [start, end) represents a forward range containing the length of the axes starting from axis zero + * - the number of lengths provided must be less than or equal to the tensor rank + * - at most one axis length is allowed for length deduction + * - the lengths provided must ensure that the total number of elements remains unchanged + * + * Exception Guarantee: Strong + */ + template + typename std::enable_if::value, void> + ::type reshape(ForwardItr start, ForwardItr end) { + CV_Assert(start != end); + CV_Assert(std::distance(start, end) <= rank()); + + using ItrValueType = typename std::iterator_traits::value_type; + + /* the user may leave at most one axis size for deduction by specifying -1 */ + auto sizes_to_deduce = std::count(start, end, -1); + if (sizes_to_deduce > 1) { CV_Error(Error::StsBadArg, "only one axis size can be deduced"); } + + /* sizes must be positive numbers with the exception of -1 */ + auto invalid_sizes = std::count_if(start, end, [](ItrValueType x) { + return !(x > 0 || x == -1); + }); + if (invalid_sizes) { CV_Error(Error::StsBadArg, "invalid axis size"); } + + /* compute the total number of elements in the new tensor */ + size_type unknown_size = 0; + auto total = std::accumulate(start, end, 1, std::multiplies()); + if (total < 0) { + /* there is an unknown size */ + if (std::abs(total) <= size()) { + unknown_size = size() / std::abs(total); + total = size(); + } + /* Edge case: if `total` is already more than size(), skip the deduction as it's impossible + ** Since `total` is negative, the size check which follows will fail and throw an error + */ + } + + /* the number of elements before and after reshape must be exactly same */ + if (total != size()) { + CV_Error(Error::StsBadArg, "new axes do not preserve the tensor element count"); + } + + /* we assume the size of the unspecified axes to be one */ + std::fill(std::begin(shape), std::end(shape), 1); + std::copy_backward(start, end, std::end(shape)); + + /* replace the unknown axis with the correct value */ + std::replace(std::begin(shape), std::end(shape), size_type(-1), unknown_size); + } + + /** @brief reshapes the tensor + * constructs a range out of the arguments and invokes range-based reshape method + */ + template + void reshape(Sizes... new_sizes_) { + static_assert(sizeof...(Sizes) <= CSL_MAX_TENSOR_RANK, "required rank exceeds maximum supported rank"); + static_assert(sizeof...(Sizes) > 0, "no sizes provided"); + std::array new_sizes = { static_cast(new_sizes_)... 
}; + reshape(std::begin(new_sizes), std::end(new_sizes)); + } + + /** @brief reshapes the tensor + * + * Pre-conditions: + * - the reference tensor must be a non-empty tensor + * - the reference tensor's rank must be lesser than or equal to the rank of target tensor + * + * Exception Guarantee: Strong + */ + template + void reshape_as(const TensorType& tensor) { + CV_Assert(!tensor.empty()); + cxx_utils::resizable_static_array new_sizes(tensor.rank()); + for (int i = 0; i < new_sizes.size(); i++) + new_sizes[i] = tensor.get_axis_size(i); + reshape(std::begin(new_sizes), std::end(new_sizes)); + } + + /** @brief squeezes the tensor + * + * removes all axes of unit size + * + * Pre-conditions: + * - the tensor must be non-empty + * - the tensor's rank must be at least two + * + * Exception Guarantee: Strong + */ + void squeeze() { + CV_Assert(!empty()); + CV_Assert(rank() >= 2); + auto itr = std::remove(std::begin(shape), std::end(shape), 1); + shape.resize(itr - std::begin(shape)); + } + + /** @brief squeezes the tensor + * + * removes the specified axis if the axis length is one; otherwise, ignores the request + * + * Pre-conditions: + * - the tensor must be non-empty + * - the tensor's rank must be at least two + * + * Exception Guarantee: Strong + */ + void squeeze(int axis) { + CV_Assert(!empty()); + CV_Assert(rank() >= 2); + axis = clamp_axis(axis, rank()); + CV_Assert(axis >= 0 && axis < rank()); + shape.erase(std::begin(shape) + axis); + } + + /** @brief unsqueezes the tensor + * + * adds a axis of unit size at the requested before the specified axis + * + * Pre-conditions: + * - the tensor must be non-empty + * - the tensor's rank must be less than the maximum supported rank (CSL_MAX_TENSOR_RANK) + * + * Exception Guarantee: Strong + */ + void unsqueeze(int axis = 0) { + CV_Assert(!empty()); + CV_Assert(rank() < CSL_MAX_TENSOR_RANK); + axis = clamp_axis(axis, rank()); + CV_Assert(axis >= 0 && axis < rank()); + shape.insert(std::begin(shape) + axis, 1); + } + + operator Span() noexcept { return Span(data.get(), size()); } + operator View() const noexcept { return View(data.get(), size()); } + + friend void swap(Tensor& lhs, Tensor& rhs) noexcept { + using std::swap; + swap(lhs.data, rhs.data); + swap(lhs.shape, rhs.shape); + } + + private: + cxx_utils::resizable_static_array shape; + ManagedPtr data; + }; + + /** @brief provides a non-owning mutable span of a Tensor + * + * \tparam T type of data stored by the tensor + * + * A span is valid if and only if the following hold true: + * - span is non-empty + * - spanned memory is still allocated + * + * A span may be used if and only if it is valid. 
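+ *
+ * A minimal usage sketch (assumes an already constructed tensor):
+ *
+ * \code
+ * Tensor<float> t(4, 64, 32, 32);
+ * TensorSpan<float> s = t;   // non-owning, mutable alias of the tensor's memory
+ * s.reshape(4, 64, -1);      // changes only the span's shape, not the tensor's
+ * \endcode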
+ */ + template + class TensorSpan { + public: + using value_type = typename Tensor::value_type; + using pointer = typename Tensor::pointer; + using const_pointer = typename Tensor::const_pointer; + using size_type = typename Tensor::size_type; + + TensorSpan() noexcept : ptr{ nullptr } { } + TensorSpan(const TensorSpan&) noexcept = default; + TensorSpan(Tensor& tensor) noexcept : ptr{ tensor.get() } { + const auto rank = tensor.rank(); + shape.resize(rank); + for (int i = 0; i < rank; i++) + shape[i] = tensor.get_axis_size(i); + } + + template + TensorSpan(pointer ptr_, ForwardItr start, ForwardItr end) : ptr{ ptr_ } { + CV_Assert(start != end); + CV_Assert(std::distance(start, end) <= CSL_MAX_TENSOR_RANK); + + using ItrValueType = typename std::iterator_traits::value_type; + if (std::any_of(start, end, [](ItrValueType x) { return x <= 0; })) { + CV_Error(Error::StsBadArg, "the given shape contains negative or zero size"); + } + + shape.assign(start, end); + } + + /** creates a subspan of a tensor (or span); refer to subspan method for more details */ + template + TensorSpan(TensorSpan other, size_type offset, Args&&... args) + : TensorSpan(other.subspan(offset, std::forward(args)...)) { } + + /** returns true if the span is empty */ + bool empty() const noexcept { return shape.size() == 0; } + + /** returns the total number of elements in the span + * + * Pre-conditions: + * - span must be non-empty + */ + size_type size() const noexcept { + CV_Assert(!empty()); + return std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies()); + } + + /** returns the rank of the span + * + * Pre-conditions: + * - span must be non-empty + */ + size_type rank() const noexcept { + CV_Assert(!empty()); + return shape.size(); + } + + /** @brief returns the length of the axis + * + * Every axis is assigned a zero-based index which can be used to select an axis. + * Negative index can be used to select an axis from the end. 
+ * + * Examples: + * > -1 represents the last axis + * > 0 represents the first axis + * > 1 represents the second axis + * + * Pre-conditions: + * - span must be non-empty + * - the axis must be in the range [-rank(), rank()) + */ + size_type get_axis_size(int axis) const noexcept { + axis = clamp_axis(axis, rank()); + CV_Assert(axis >= 0 && axis < rank()); + return shape[axis]; + } + + /** @brief returns the combined size of the axes in an axis range + * + * if the shape is [3 x 5 x 7 x 11] + * - `size_range(0, 2)` will return 3 x 5 = 15 + * - `size_range(1, 3)` will return 5 x 7 = 35 + * - `size_range(0, 4)` will return 3 x 5 x 7 x 11 = 1155 + * + * Pre-conditions: + * - span must be non-empty + * - `axis_start` must be less than or equal to `axis_end` + * - `axis_end` must be less than or equal to the rank + * + * returns one if the two `axis_start` and `axis_end` are equal + */ + size_type size_range(size_type axis_start, size_type axis_end) const noexcept { + CV_Assert(!empty()); + CV_Assert(axis_start <= axis_end); + CV_Assert(axis_end <= rank()); + auto start = std::begin(shape) + axis_start; + auto end = std::begin(shape) + axis_end; + return std::accumulate(start, end, 1, std::multiplies()); + } + + /** returns an std::vector containing axis lengths starting from axis zero + * + * Pre-conditions: + * - span must be non-empty + * + * Exception Guarantee: Strong + */ + std::vector shape_as_vector() const { + CV_Assert(!empty()); + return std::vector(std::begin(shape), std::end(shape)); + } + + /** returns a pointer to mutable device memory */ + pointer get() const noexcept { return ptr; } + + /** @brief clears the span + * + * Pre-conditions: + * - span must be non-empty + * + * Exception Guarantee: Strong + */ + void clear() noexcept { + CV_Assert(!empty()); + ptr = nullptr; + shape.clear(); + } + + /** @brief reshapes the span + * + * Length deduction: + * The length of at most one axis can be deduced using the total size constraint. The axis can + * be marked for deduction by specifying the corresponding size as -1. + * + * The axes for which no size was provided (excluding -1) will be assumed to be one. 
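+ *
+ * A worked sketch of the deduction rule:
+ *
+ * \code
+ * // span of shape [4 x 6 x 10], i.e. 240 elements
+ * span.reshape(-1, 5);
+ * // the missing leading axis is assumed to be 1; -1 is deduced as 240 / 5 = 48
+ * // resulting shape: [1 x 48 x 5]
+ * \endcode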
+ * + * Pre-conditions: + * - the span must be non-empty + * - [start, end) represents a forward range containing the length of the axes in order + * - the number of axis lengths must be less than or equal to the rank + * - at most one axis length is allowed for length deduction + * - the lengths provided must ensure that the total number of elements remains unchnged + * + * Exception Guarantee: Strong + */ + template + typename std::enable_if::value, void> + ::type reshape(ForwardItr start, ForwardItr end) { + CV_Assert(start != end); + CV_Assert(std::distance(start, end) <= rank()); + + using ItrValueType = typename std::iterator_traits::value_type; + + /* the user may leave at most one axis size for deduction by specifying -1 */ + auto sizes_to_deduce = std::count(start, end, -1); + if (sizes_to_deduce > 1) { CV_Error(Error::StsBadArg, "only one axis size can be deduced"); } + + /* sizes must be positive numbers with the exception of -1 */ + auto invalid_sizes = std::count_if(start, end, [](ItrValueType x) { + return !(x > 0 || x == -1); + }); + if (invalid_sizes) { CV_Error(Error::StsBadArg, "invalid axis size"); } + + /* compute the total number of elements in the new tensor */ + size_type unknown_size = 0; + auto total = std::accumulate(start, end, 1, std::multiplies()); + if (total < 0) { + /* there is an unknown size */ + if (std::abs(total) <= size()) { + unknown_size = size() / std::abs(total); + total = size(); + } + /* Edge case: if `total` is already more than size(), skip the deduction as it's impossible + ** Since `total` is negative, the size check which follows will fail and throw an error + */ + } + + /* the number of elements before and after reshape must be exactly same */ + if (total != size()) { + CV_Error(Error::StsBadArg, "new axes do not preserve the tensor element count"); + } + + /* we assume the size of the unspecified axes to be one */ + std::fill(std::begin(shape), std::end(shape), 1); + std::copy_backward(start, end, std::end(shape)); + + /* replace the unknown axis with the correct value */ + std::replace(std::begin(shape), std::end(shape), size_type(-1), unknown_size); + } + + /** @brief reshapes the tensor + * constructs a range out of the arguments and invokes the range-based reshape method + */ + template + void reshape(Sizes... new_sizes_) { + static_assert(sizeof...(Sizes) <= CSL_MAX_TENSOR_RANK, "unsupported tensor rank"); + static_assert(sizeof...(Sizes) > 0, "no sizes provided"); + std::array new_sizes = { static_cast(new_sizes_)... 
}; + reshape(std::begin(new_sizes), std::end(new_sizes)); + } + + /** @brief reshapes the span + * + * Pre-conditions: + * - the reference tensor/span/view must be non-empty + * - the reference tensor/span/view's rank must be less than or equal to the rank of the span + * + * Exception Guarantee: Strong + */ + template + void reshape_as(const TensorType& tensor) { + CV_Assert(!tensor.empty()); + cxx_utils::resizable_static_array new_sizes(tensor.rank()); + for (int i = 0; i < new_sizes.size(); i++) + new_sizes[i] = tensor.get_axis_size(i); + reshape(std::begin(new_sizes), std::end(new_sizes)); + } + + /** @brief squeezes the tensor + * + * removes all axes of unit size + * + * Pre-conditions: + * - the span must be non-empty + * - the span's rank must be at least two + * + * Exception Guarantee: Strong + */ + void squeeze() { + CV_Assert(!empty()); + CV_Assert(rank() >= 2); + auto itr = std::remove(std::begin(shape), std::end(shape), 1); + shape.resize(itr - std::begin(shape)); + } + + /** @brief squeezes the tensor + * + * removes the specified axis if the axis length is one; otherwise, ignores the request + * + * Pre-conditions: + * - the span must be non-empty + * - the span's rank must be at least two + * + * Exception Guarantee: Strong + */ + void squeeze(int axis) { + CV_Assert(!empty()); + CV_Assert(rank() >= 2); + axis = clamp_axis(axis, rank()); + CV_Assert(axis >= 0 && axis < rank()); + shape.erase(std::begin(shape) + axis); + } + + /** @brief unsqueezes the tensor + * + * adds a axis of unit size at the requested before the specified axis + * + * Pre-conditions: + * - the span must be non-empty + * - the span's rank must be less than the maximum supported rank (CSL_MAX_TENSOR_RANK) + * + * Exception Guarantee: Strong + */ + void unsqueeze(int axis = 0) { + CV_Assert(!empty()); + CV_Assert(rank() < CSL_MAX_TENSOR_RANK); + axis = clamp_axis(axis, rank()); + CV_Assert(axis >= 0 && axis < rank()); + shape.insert(std::begin(shape) + axis, 1); + } + + /** @brief obtains a subspan of the span + * + * Pre-conditions: + * - the span must be non-empty + * - the `offset` must be less than the size of the span + * - [start, end) represents a forward range containing length of the subspan axes + * - the lengths provided must ensure that the number of elements does not exceed (old size - offset) + * + * Exception Guarantee: Strong + */ + template + typename std::enable_if::value, TensorSpan> + ::type subspan(size_type offset, ForwardItr start, ForwardItr end) const { + CV_Assert(start != end); + CV_Assert(std::distance(start, end) <= rank()); + + auto cur_size = size(); + CV_Assert(offset < cur_size); + + using ItrValueType = typename std::iterator_traits::value_type; + + /* sizes must be positive numbers */ + auto invalid_sizes = std::count_if(start, end, [](ItrValueType x) { + return !(x > 0); + }); + if (invalid_sizes) { CV_Error(Error::StsBadArg, "invalid axis size"); } + + /* the number of elements must be equal to the new size */ + auto max_size = (cur_size - offset); + auto total = std::accumulate(start, end, 1, std::multiplies()); + if (total > max_size) { + CV_Error(Error::StsBadArg, "axis lengths lead to OOB accesses"); + } + + TensorSpan temp; + temp.shape.assign(start, end); + temp.ptr = ptr + offset; + return temp; + } + + /** @brief obtains a subspan of the span + * constructs a range out of the size arguments and invokes the range-based subspan method + */ + template + TensorSpan subspan(size_type offset, Sizes... 
new_sizes_) const { + static_assert(sizeof...(Sizes) <= CSL_MAX_TENSOR_RANK, "required rank exceeds maximum supported rank"); + static_assert(sizeof...(Sizes) > 0, "no sizes provided"); + std::array new_sizes = { static_cast(new_sizes_)... }; + return subspan(offset, std::begin(new_sizes), std::end(new_sizes)); + } + + operator Span() noexcept { return Span(ptr, size()); } + operator View() const noexcept { return View(ptr, size()); } + + friend void swap(TensorSpan& lhs, TensorSpan& rhs) noexcept { + using std::swap; + swap(lhs.ptr, rhs.ptr); + swap(lhs.shape, rhs.shape); + } + + private: + cxx_utils::resizable_static_array shape; + pointer ptr; + }; + + /** @brief view of a tensor + * + * \tparam T type of data stored by the tensor + * + * A view is valid if and only if the following hold true: + * - view is non-empty + * - viewed memory is still allocated + */ + template + class TensorView { + public: + using value_type = typename Tensor::value_type; + using pointer = typename Tensor::pointer; + using const_pointer = typename Tensor::const_pointer; + using size_type = typename Tensor::size_type; + + TensorView() noexcept : ptr{ nullptr } { } + TensorView(const TensorView&) noexcept = default; + TensorView(TensorSpan other) noexcept : ptr{ other.get() } { + const auto rank = other.rank(); + shape.resize(rank); + for (int i = 0; i < rank; i++) + shape[i] = other.get_axis_size(i); + } + TensorView(const Tensor& tensor) noexcept : ptr{ tensor.get() } { + const auto rank = tensor.rank(); + shape.resize(rank); + for (int i = 0; i < rank; i++) + shape[i] = tensor.get_axis_size(i); + } + + template + TensorView(pointer ptr_, ForwardItr start, ForwardItr end) : ptr{ ptr_ } { + CV_Assert(start != end); + CV_Assert(std::distance(start, end) <= CSL_MAX_TENSOR_RANK); + + using ItrValueType = typename std::iterator_traits::value_type; + if (std::any_of(start, end, [](ItrValueType x) { return x <= 0; })) { + CV_Error(Error::StsBadArg, "the given shape contains negative or zero size"); + } + + shape.assign(start, end); + } + + /** creates a subview of a tensor (or span or view); refer to subview method for more details */ + template + TensorView(TensorView other, size_type offset, Args&&... args) noexcept + : TensorView(other.subview(offset, std::forward(args)...)) { } + + TensorView& operator=(const TensorView&) = default; + TensorView& operator=(TensorSpan other) noexcept { + TensorView tmp(other); + swap(*this, tmp); + return *this; + } + + /** returns true if the view is empty */ + bool empty() const noexcept { return shape.size() == 0; } + + /** returns the total number of elements in the view + * + * Pre-conditions: + * - view must be non-empty + */ + size_type size() const noexcept { + CV_Assert(!empty()); + return std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies()); + } + + /** returns the rank of the view + * + * Pre-conditions: + * - view must be non-empty + */ + size_type rank() const noexcept { + CV_Assert(!empty()); + return shape.size(); + } + + /** @brief returns the length of the axis + * + * Every axis is assigned a zero-based index which can be used to select an axis. + * Negative index can be used to select an axis from the end. 
+ * + * Examples: + * > -1 represents the last axis + * > 0 represents the first axis + * > 1 represents the second axis + * + * Pre-conditions: + * - view must be non-empty + * - the axis must be in the range [-rank(), rank()) + */ + size_type get_axis_size(int axis) const noexcept { + axis = clamp_axis(axis, rank()); + CV_Assert(axis >= 0 && axis < rank()); + return shape[axis]; + } + + /** @brief returns the combined size of the axes in an axis range + * + * if the shape is [3 x 5 x 7 x 11] + * - `size_range(0, 2)` will return 3 x 5 = 15 + * - `size_range(1, 3)` will return 5 x 7 = 35 + * - `size_range(0, 4)` will return 3 x 5 x 7 x 11 = 1155 + * + * Pre-conditions: + * - view must be non-empty + * - `axis_start` must be less than or equal to `axis_end` + * - `axis_end` must be less than or equal to the rank + * + * returns one if the two `axis_start` and `axis_end` are equal + */ + size_type size_range(size_type axis_start, size_type axis_end) const noexcept { + CV_Assert(!empty()); + CV_Assert(axis_start <= axis_end); + CV_Assert(axis_end <= rank()); + auto start = std::begin(shape) + axis_start; + auto end = std::begin(shape) + axis_end; + return std::accumulate(start, end, 1, std::multiplies()); + } + + /** returns an std::vector containing axis lengths starting from axis zero + * + * Pre-conditions: + * - view must be non-empty + * + * Exception Guarantee: Strong + */ + std::vector shape_as_vector() const { + CV_Assert(!empty()); + return std::vector(std::begin(shape), std::end(shape)); + } + + /** returns a device pointer to immutable device memory */ + const_pointer get() const noexcept { return ptr; } + + /** @brief reshapes the view + * + * Length deduction: + * The length of at most one axis can be deduced using the total size constraint. The axis can + * be marked for deduction by specifying the size as -1. + * + * The axes for which no size was provided (excluding -1) will be assumed to be one. 
+ * + * Pre-conditions: + * - view must be non-empty + * - [start, end) represents a forward range containing length of the axes in order starting from axis zero + * - the number of axis lengths must be less than or equal to the tensor rank + * - at most one axis length is allowed for length deduction + * - the lengths provided must ensure that the total number of elements remains unchnged + * + * Exception Guarantee: Strong + */ + template + typename std::enable_if::value, void> + ::type reshape(ForwardItr start, ForwardItr end) { + CV_Assert(start != end); + CV_Assert(std::distance(start, end) <= rank()); + + using ItrValueType = typename std::iterator_traits::value_type; + + /* the user may leave at most one axis size for deduction by specifying -1 */ + auto sizes_to_deduce = std::count(start, end, -1); + if (sizes_to_deduce > 1) { CV_Error(Error::StsBadArg, "only one axis size can be deduced"); } + + /* sizes must be positive numbers with the exception of -1 */ + auto invalid_sizes = std::count_if(start, end, [](ItrValueType x) { + return !(x > 0 || x == -1); + }); + if (invalid_sizes) { CV_Error(Error::StsBadArg, "invalid axis size"); } + + /* compute the total number of elements in the new tensor */ + size_type unknown_size = 0; + auto total = std::accumulate(start, end, 1, std::multiplies()); + if (total < 0) { + /* there is an unknown size */ + if (std::abs(total) <= size()) { + unknown_size = size() / std::abs(total); + total = size(); + } + /* Edge case: if `total` is already more than size(), skip the deduction as it's impossible + ** Since `total` is negative, the size check which follows will fail and throw an error + */ + } + + /* the number of elements before and after reshape must be exactly same */ + if (total != size()) { + CV_Error(Error::StsBadArg, "new axes do not preserve the tensor element count"); + } + + /* we assume the size of the unspecified axes to be one */ + std::fill(std::begin(shape), std::end(shape), 1); + std::copy_backward(start, end, std::end(shape)); + + /* replace the unknown axis with the correct value */ + std::replace(std::begin(shape), std::end(shape), size_type(-1), unknown_size); + } + + /** @brief reshapes the view + * constructs a range out of the arguments and invokes the range-based reshape method + */ + template + void reshape(Sizes... new_sizes_) { + static_assert(sizeof...(Sizes) <= CSL_MAX_TENSOR_RANK, "required rank exceeds maximum supported rank"); + static_assert(sizeof...(Sizes) > 0, "no sizes provided"); + std::array new_sizes = { static_cast(new_sizes_)... 
}; + reshape(std::begin(new_sizes), std::end(new_sizes)); + } + + /** @brief reshapes the view + * + * Pre-conditions: + * - the reference tensor/span/view must be non-empty + * - the reference tensor/span/view's rank must be less than or equal to the rank of the view + * + * Exception Guarantee: Strong + */ + template + void reshape_as(const TensorType& tensor) { + CV_Assert(!tensor.empty()); + cxx_utils::resizable_static_array new_sizes(tensor.rank()); + for (int i = 0; i < new_sizes.size(); i++) + new_sizes[i] = tensor.get_axis_size(i); + reshape(std::begin(new_sizes), std::end(new_sizes)); + } + + /** @brief squeezes the tensor + * + * removes all axes of unit size + * + * Pre-conditions: + * - the view must be non-empty + * - the view's rank must be at least two + * + * Exception Guarantee: Strong + */ + void squeeze() { + CV_Assert(!empty()); + CV_Assert(rank() >= 2); + auto itr = std::remove(std::begin(shape), std::end(shape), 1); + shape.resize(itr - std::begin(shape)); + } + + /** @brief squeezes the tensor + * + * removes the specified axis if the axis length is one; otherwise, ignores the request + * + * Pre-conditions: + * - the view must be non-empty + * - the view's rank must be at least two + * + * Exception Guarantee: Strong + */ + void squeeze(int axis) { + CV_Assert(!empty()); + CV_Assert(rank() >= 2); + axis = clamp_axis(axis, rank()); + CV_Assert(axis >= 0 && axis < rank()); + shape.erase(std::begin(shape) + axis); + } + + /** @brief unsqueezes the tensor + * + * adds a axis of unit size at the requested before the specified axis + * + * Pre-conditions: + * - the view must be non-empty + * - the view's rank must be less than the maximum supported rank (CSL_MAX_TENSOR_RANK) + * + * Exception Guarantee: Strong + */ + void unsqueeze(int axis = 0) { + CV_Assert(!empty()); + CV_Assert(rank() < CSL_MAX_TENSOR_RANK); + axis = clamp_axis(axis, rank()); + CV_Assert(axis >= 0 && axis < rank()); + shape.insert(std::begin(shape) + axis, 1); + } + + /** @brief obtains a subview of the view + * + * The axes for which no size was provided will be assumed to be one. 
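+ *
+ * A minimal sketch (assuming a view of shape [4 x 3 x 32 x 32] and a valid sample index `i`):
+ *
+ * \code
+ * // pick the i-th sample as a [1 x 3 x 32 x 32] subview
+ * auto sample = view.subview(i * view.size_range(1, 4), 1, 3, 32, 32);
+ * \endcode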
+ * + * Pre-conditions: + * - the view must be non-empty + * - the `offset` must be less than the size of the view + * - [start, end) represents a forward range containing length of the subview axes in order + * - the number of axis lengths provided must be less than or equal to the tensor rank + * - the lengths provided must ensure that the number of elements does not exceed (old size - offset) + * + * Exception Guarantee: Strong + */ + template + typename std::enable_if::value, TensorView> + ::type subview(size_type offset, ForwardItr start, ForwardItr end) const { + CV_Assert(start != end); + CV_Assert(std::distance(start, end) <= rank()); + + auto cur_size = size(); + CV_Assert(offset < cur_size); + + using ItrValueType = typename std::iterator_traits::value_type; + + /* sizes must be positive numbers */ + auto invalid_sizes = std::count_if(start, end, [](ItrValueType x) { + return !(x > 0); + }); + if (invalid_sizes) { CV_Error(Error::StsBadArg, "invalid axis size"); } + + /* the number of elements must be equal to the new size */ + auto max_size = (cur_size - offset); + auto total = std::accumulate(start, end, 1, std::multiplies()); + if (total > max_size) { + CV_Error(Error::StsBadArg, "axes lengths lead to OOB accesses"); + } + + TensorView temp; + temp.shape.assign(start, end); + temp.ptr = ptr + offset; + return temp; + } + + /** @brief obtains a subview of the view + * constructs a range out of the size arguments and invokes the range-based subview method + */ + template + TensorView subview(size_type offset, Sizes... new_sizes_) const { + static_assert(sizeof...(Sizes) <= CSL_MAX_TENSOR_RANK, "required rank exceeds maximum supported rank"); + static_assert(sizeof...(Sizes) > 0, "no sizes provided"); + std::array new_sizes = { static_cast(new_sizes_)... 
}; + return subview(offset, std::begin(new_sizes), std::end(new_sizes)); + } + + operator View() const noexcept { return View(ptr, size()); } + + friend void swap(TensorView& lhs, TensorView& rhs) noexcept { + using std::swap; + swap(lhs.ptr, rhs.ptr); + swap(lhs.shape, rhs.shape); + } + + private: + cxx_utils::resizable_static_array shape; + const_pointer ptr; + }; + + /** returns true if the two TensorType objects have the same shape */ + template + bool is_shape_same(const TensorType1& x, const TensorType2& y) noexcept { + auto rank1 = x.rank(); + auto rank2 = y.rank(); + + if (rank1 != rank2) + return false; + + for (int i = 0; i < rank1; i++) + if (x.get_axis_size(i) != y.get_axis_size(i)) + return false; + return true; + } + + /** returns true if the two TensorType objects are compatible */ + template + bool is_shape_compatible(const TensorType1& x, const TensorType2& y) noexcept { + const auto rank1 = x.rank(); + const auto rank2 = y.rank(); + + /* mathematically not required but is a technically required */ + if (rank1 != rank2) + return false; + + for (int i = 0; i < rank1; i++) + if (x.get_axis_size(i) != y.get_axis_size(i) && + x.get_axis_size(i) != 1 && y.get_axis_size(i) != 1) + return false; + return true; + } + + /** returns the rank to which the given tensor can be squeezed to */ + template + std::size_t get_effective_rank(const TensorType& x) noexcept { + const auto rank = x.rank(); + auto effective_rank = rank; + for (int i = 0; i < rank; i++, effective_rank--) + if (x.get_axis_size(i) != 1) + break; + return effective_rank; + } + +}}}} /* namespace cv::dnn::cuda4dnn::csl */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_TENSOR_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/tensor_ops.hpp b/modules/dnn/src/cuda4dnn/csl/tensor_ops.hpp new file mode 100644 index 0000000..1d39674 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/tensor_ops.hpp @@ -0,0 +1,384 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
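+
+/* A short sketch of the shape helpers declared at the end of tensor.hpp above
+ * (assumes float tensors; the shapes are illustrative):
+ *
+ *   Tensor<float> a(4, 3, 8), b(4, 1, 8);
+ *   bool same       = is_shape_same(a, b);        // false: axis 1 differs
+ *   bool compatible = is_shape_compatible(a, b);  // true: the unit axis can broadcast
+ *   auto erank      = get_effective_rank(b);      // 3: only leading unit axes are dropped
+ */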
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_TENSOR_OPS_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_CSL_TENSOR_OPS_HPP + +#include "stream.hpp" +#include "tensor.hpp" +#include "pointer.hpp" +#include "cublas.hpp" +#include "cudnn.hpp" +#include "workspace.hpp" + +#include "cudnn/convolution.hpp" +#include "cudnn/pooling.hpp" +#include "cudnn/lrn.hpp" +#include "cudnn/softmax.hpp" +#include "cudnn/transform.hpp" +#include "cudnn/transpose_convolution.hpp" + +#include + +#include +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { + + namespace tensor_ops { + + /** @brief copies data between tensors + * + * Pre-conditions: + * - \p dest and \p src must have the same shape + * + * Exception Gaurantee: Basic + */ + template inline + void copy(const Stream& stream, TensorSpan dest, TensorView src) { + CV_Assert(is_shape_same(dest, src)); + if (dest.get() != src.get()) + memcpy(dest.get(), src.get(), dest.size(), stream); + } + + /** @brief performs generalized matrix-multiplication + * + * Pre-conditions: + * - \p A and \p B must meet the mathematical requirements for matrix multiplication + * - \p result must be large enough to hold the result + * + * Exception Gaurantee: Basic + */ + template inline + void gemm(const cublas::Handle& handle, T beta, TensorSpan result, T alpha, bool transa, TensorView A, bool transb, TensorView B) { + /* matrix operations can be performed only on rank two or less tensors */ + CV_Assert(get_effective_rank(A) <= 2 && + get_effective_rank(B) <= 2 && + get_effective_rank(result) <= 2); + + /* check dimension requirements for matrix multiplication */ + if (!transa && !transb) { + CV_Assert(A.get_axis_size(-2) == result.get_axis_size(-2)); + CV_Assert(A.get_axis_size(-1) == B.get_axis_size(-2)); + CV_Assert(B.get_axis_size(-1) == result.get_axis_size(-1)); + } else if (!transa && transb) { + CV_Assert(A.get_axis_size(-2) == result.get_axis_size(-2)); + CV_Assert(A.get_axis_size(-1) == B.get_axis_size(-1)); + CV_Assert(B.get_axis_size(-2) == result.get_axis_size(-1)); + } else if (transa && !transb) { + CV_Assert(A.get_axis_size(-1) == result.get_axis_size(-2)); + CV_Assert(A.get_axis_size(-2) == B.get_axis_size(-2)); + CV_Assert(B.get_axis_size(-1) == result.get_axis_size(-1)); + } else { + CV_Assert(A.get_axis_size(-1) == result.get_axis_size(-2)); + CV_Assert(A.get_axis_size(-2) == B.get_axis_size(-1)); + CV_Assert(B.get_axis_size(-2) == result.get_axis_size(-1)); + } + + const auto result_nr = result.get_axis_size(-2); + const auto result_nc = result.get_axis_size(-1); + const auto common_dim = A.get_axis_size(transa ? 
-2 : -1); + const auto A_nc = A.get_axis_size(-1); + const auto B_nc = B.get_axis_size(-1); + + /* tensors are stored in row-major but cublas::gemm operates on column-major matrices + * a row-major matrix when read as column-major matrix gives the transpose of the intended matrix + * + * Required: C = AB + * what cuBLAS sees: C^T = A^TB^T = (BA)^T + * + * By reversing operands, we effectively perform: + * C^T = B^TA^T = (AB)^T + * + * which gives C = AB + */ + cublas::gemm(handle, + transb, transa, + result_nc, result_nr, common_dim, + alpha, B.get(), B_nc, + A.get(), A_nc, + beta, result.get(), result_nc); + } + + /** @brief performs element-wise addition with broadcasting + * + * Pre-conditions: + * - \p A and \p result must be compatible tensors + * + * Exception Gaurantee: Basic + */ + template inline + void softmax(const cudnn::Handle& handle, TensorSpan output, TensorView input, int channel_axis, bool log) { + CV_Assert(is_shape_same(output, input)); + + channel_axis = clamp_axis(channel_axis, input.rank()); + + std::size_t outer_size = input.size_range(0, channel_axis); + auto channel_size = input.get_axis_size(channel_axis); + std::size_t inner_size = input.size_range(channel_axis + 1, input.rank()); + + std::array shape = { outer_size, channel_size, 1, inner_size }; + + using cudnn::TensorDescriptor; + auto inputDesc = TensorDescriptor(shape); + auto outputDesc = TensorDescriptor(shape); + cudnn::softmax(handle, outputDesc, output.get(), inputDesc, input.get(), log); + } + } + + template + class Convolution { + using TensorDescriptor = cudnn::TensorDescriptor; + using FilterDescriptor = cudnn::FilterDescriptor; + using ConvolutionDescriptor = cudnn::ConvolutionDescriptor; + using ConvolutionAlgorithm = cudnn::ConvolutionAlgorithm; + + public: + struct params_type { + std::vector input_shape; + std::vector filter_shape; + + std::vector padding; + std::vector stride; + std::vector dilation; + + std::size_t groups; + }; + + Convolution() = default; + Convolution(const Convolution&) = delete; + Convolution(Convolution&&) = default; + Convolution(cudnn::Handle handle, const params_type& params) { + cudnnHandle = std::move(handle); + + inputTensorDesc = TensorDescriptor(params.input_shape); + filterDesc = FilterDescriptor(params.filter_shape); + convDesc = ConvolutionDescriptor(params.padding, params.stride, params.dilation, params.groups); + + std::vector output_dims; + getConvolutionForwardOutputDim(convDesc, filterDesc, inputTensorDesc, output_dims); + outputTensorDesc = TensorDescriptor(output_dims); + + algo = ConvolutionAlgorithm(cudnnHandle, convDesc, filterDesc, inputTensorDesc, outputTensorDesc); + } + + Convolution& operator=(const Convolution&) = delete; + Convolution& operator=(Convolution&&) = default; + + std::size_t get_workspace_size() const noexcept { + return algo.get_workspace_size(); + } + + void convolve(TensorSpan output, TensorView input, TensorView filters, WorkspaceInstance scratchpad) { + cudnn::convolve( + cudnnHandle, + convDesc, algo, scratchpad, + filterDesc, filters.get(), + inputTensorDesc, input.get(), + 1.0, 0.0, outputTensorDesc, output.get() + ); + } + + private: + cudnn::Handle cudnnHandle; + TensorDescriptor inputTensorDesc, outputTensorDesc; + FilterDescriptor filterDesc; + ConvolutionDescriptor convDesc; + ConvolutionAlgorithm algo; + }; + + template + class TransposeConvolution { + using TensorDescriptor = cudnn::TensorDescriptor; + using FilterDescriptor = cudnn::FilterDescriptor; + using ConvolutionDescriptor = cudnn::ConvolutionDescriptor; + 
using TransposeConvolutionAlgorithm = cudnn::TransposeConvolutionAlgorithm; + + public: + struct params_type { + std::vector input_shape; + std::vector output_shape; + + std::vector filter_shape; + + std::vector padding; + std::vector stride; + std::vector dilation; + + std::size_t groups; + }; + + TransposeConvolution() = default; + TransposeConvolution(const TransposeConvolution&) = delete; + TransposeConvolution(TransposeConvolution&&) = default; + TransposeConvolution(cudnn::Handle handle, const params_type& params) { + cudnnHandle = std::move(handle); + + filterDesc = FilterDescriptor(params.filter_shape); + convDesc = ConvolutionDescriptor(params.padding, params.stride, params.dilation, params.groups); + + /* input_shape is the output shape for convolution + * output_shape is the input shape for convolution + */ + convInputTensorDesc = TensorDescriptor(params.output_shape); + + std::vector conv_output_dims; + getConvolutionForwardOutputDim(convDesc, filterDesc, convInputTensorDesc, conv_output_dims); + + /* the convolution output must be identical to what cuDNN expects */ + CV_Assert(std::equal(std::begin(conv_output_dims), std::end(conv_output_dims), std::begin(params.input_shape))); + + convOutputTensorDesc = TensorDescriptor(params.input_shape); + + algo = TransposeConvolutionAlgorithm(cudnnHandle, convDesc, filterDesc, convOutputTensorDesc, convInputTensorDesc); + } + + TransposeConvolution& operator=(const TransposeConvolution&) = delete; + TransposeConvolution& operator=(TransposeConvolution&&) = default; + + std::size_t get_workspace_size() const noexcept { + return algo.get_workspace_size(); + } + + void transpose_convolve(TensorSpan output, TensorView input, TensorView filters, WorkspaceInstance scratchpad) { + cudnn::transpose_convolve( + cudnnHandle, + convDesc, algo, scratchpad, + filterDesc, filters.get(), + convOutputTensorDesc, input.get(), + 1.0, 0.0, convInputTensorDesc, output.get() + ); + } + + private: + cudnn::Handle cudnnHandle; + TensorDescriptor convInputTensorDesc, convOutputTensorDesc; + FilterDescriptor filterDesc; + ConvolutionDescriptor convDesc; + TransposeConvolutionAlgorithm algo; + }; + + template + class Pooling { + using TensorDescriptor = cudnn::TensorDescriptor; + using PoolingDescriptor = cudnn::PoolingDescriptor; + + public: + using PoolingType = PoolingDescriptor::PoolingType; + + struct params_type { + std::vector input_shape; + std::vector output_shape; + + std::vector window_size; + std::vector padding; + std::vector stride; + + PoolingType type; + }; + + Pooling() = default; + Pooling(const Pooling&) = delete; + Pooling(Pooling&&) = default; + Pooling(cudnn::Handle handle, const params_type& params) { + cudnnHandle = std::move(handle); + + inputTensorDesc = TensorDescriptor(params.input_shape); + poolingDesc = PoolingDescriptor(params.window_size, params.padding, params.stride, params.type); + + //std::vector output_dim; + //getPoolingForwardOutputDim(poolingDesc, inputTensorDesc, output_dim); + outputTensorDesc = TensorDescriptor(params.output_shape); + } + + Pooling& operator=(const Pooling&) = delete; + Pooling& operator=(Pooling&&) = default; + + void pool(TensorView input, TensorSpan output) { + cudnn::pool( + cudnnHandle, + poolingDesc, + inputTensorDesc, input.get(), + 1.0, 0.0, outputTensorDesc, output.get() + ); + } + + private: + cudnn::Handle cudnnHandle; + TensorDescriptor inputTensorDesc, outputTensorDesc; + PoolingDescriptor poolingDesc; + }; + + template + class LRN { + using LRNDescriptor = cudnn::LRNDescriptor; + using 
TensorDescriptor = cudnn::TensorDescriptor; + + public: + using LRNType = LRNDescriptor::LRNType; + + LRN() = default; + LRN(const LRN&) = delete; + LRN(LRN&&) = default; + LRN(cudnn::Handle handle, std::size_t local_size, T alpha, T beta, T k, LRNType type) { + cudnnHandle = std::move(handle); + lrnDesc = LRNDescriptor(local_size, alpha, beta, k, type); + } + + LRN& operator=(const LRN&) = delete; + LRN& operator=(LRN&&) = default; + + void normalize(TensorView input, TensorSpan output, WorkspaceInstance workspace) { + cudnn::LRNForward( + cudnnHandle, + lrnDesc, + TensorDescriptor(input.shape_as_vector()), input.get(), + 1.0, 0.0, TensorDescriptor(output.shape_as_vector()), output.get(), + workspace + ); + } + + private: + cudnn::Handle cudnnHandle; + LRNDescriptor lrnDesc; + }; + + template + class TensorTransform { + using TensorTransformDescriptor = cudnn::TensorTransformDescriptor; + using TensorDescriptor = cudnn::TensorDescriptor; + + public: + TensorTransform() = default; + TensorTransform(const TensorTransform&) = delete; + TensorTransform(TensorTransform&&) = default; + + template + TensorTransform(cudnn::Handle handle, const SequenceContainer& paddingLeft, const SequenceContainer& paddingRight) { + cudnnHandle = std::move(handle); + transDesc = TensorTransformDescriptor(paddingLeft, paddingRight); + } + + TensorTransform& operator=(const TensorTransform&) = delete; + TensorTransform& operator=(TensorTransform&&) = default; + + void transform(TensorView input, TensorSpan output) { + cudnn::transform( + cudnnHandle, + transDesc, + TensorDescriptor(input.shape_as_vector()), input.get(), + TensorDescriptor(output.shape_as_vector()), output.get() + ); + } + + private: + cudnn::Handle cudnnHandle; + TensorTransformDescriptor transDesc; + }; + +}}}} /* namespace cv::dnn::cuda4dnn::csl */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_TENSOR_OPS_HPP */ diff --git a/modules/dnn/src/cuda4dnn/csl/workspace.hpp b/modules/dnn/src/cuda4dnn/csl/workspace.hpp new file mode 100644 index 0000000..0d852c6 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/csl/workspace.hpp @@ -0,0 +1,166 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_WORKSPACE_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_CSL_WORKSPACE_HPP + +#include "pointer.hpp" +#include "span.hpp" +#include "tensor.hpp" + +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { + + /** @brief maintains a single block of reusable device memory + * + * Each Workspace object is intended to be used by a single entity at a time but by + * different entities at different times. It maintains a single reusable block of memory which + * is sufficient for the largest consumer. 
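+ *
+ * A minimal sketch of the intended pattern (byte counts are illustrative):
+ *
+ * \code
+ * Workspace ws;
+ * ws.require(1 << 20);   // consumer A needs 1 MiB
+ * ws.require(1 << 22);   // consumer B needs 4 MiB; the block grows to 4 MiB
+ * ws.require(1 << 20);   // consumer A runs again; the existing block is reused
+ * \endcode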
+ */ + class Workspace { + public: + + /** @brief reserve \p bytes of memory */ + void require(std::size_t bytes) { + if (bytes > ptr.size()) + ptr.reset(bytes); + } + + /** @brief number of bytes reserved by the largest consumer */ + std::size_t size() const noexcept { + return ptr.size(); + } + + /** @brief returns the pointer to the workspace memory */ + DevicePtr get() { + return ptr.get(); + } + + private: + ManagedPtr ptr; + }; + + /** used to compute total workspace size from several workspace requests */ + class WorkspaceBuilder { + public: + WorkspaceBuilder() noexcept : max_size_in_bytes{ 0 } { } + + /** request memory for \p count number of elements of the type \tparam T */ + template + void require(std::size_t count) noexcept { + auto blocks256 = (count * sizeof(T) + 255) / 256; + max_size_in_bytes += blocks256 * 256; + } + + /** returns the total workspace memory that is required */ + std::size_t required_workspace_size() const noexcept { return max_size_in_bytes; } + + private: + std::size_t max_size_in_bytes; + }; + + /** general memory block from a workspace which can be passed on to the requester */ + class WorkspaceInstance { + public: + + /** returns a device pointer to the workspace memory */ + template + DevicePtr get() const noexcept { + return static_cast>(ptr); + } + + /** returnss the size of the workspace memory in bytes */ + std::size_t size_in_bytes() const noexcept { + return size_in_bytes_; + } + + /** creates a Span of \p count elements from the workspace memory */ + template + Span get_span(std::size_t count = 0) const { + if (count == 0) + count = size_in_bytes_ / sizeof(T); + + if (count * sizeof(T) > size_in_bytes_) + CV_Error(Error::StsNoMem, "memory not sufficient"); + + return Span(static_cast>(ptr), count); + } + + /** creates a TensorSpan of the given shape from the workspace memory */ + template + TensorSpan get_tensor_span(ForwardItr shape_begin, ForwardItr shape_end) const { + using ItrValueType = typename std::iterator_traits::value_type; + auto required_size = std::accumulate(shape_begin, shape_end, 1, std::multiplies()); + if (required_size * sizeof(T) > size_in_bytes_) + CV_Error(Error::StsNoMem, "memory not sufficient"); + return TensorSpan(static_cast>(ptr), shape_begin, shape_end); + } + + private: + DevicePtr ptr; + std::size_t size_in_bytes_; + + friend class WorkspaceAllocator; + WorkspaceInstance(DevicePtr ptr_, std::size_t size_in_bytes__) + : ptr{ ptr_ }, size_in_bytes_{ size_in_bytes__ } { } + }; + + /** used to split a single workspace into constituents */ + class WorkspaceAllocator { + public: + WorkspaceAllocator() = default; + WorkspaceAllocator(Workspace& workspace) noexcept + : current{ workspace.get() }, bytes_remaining { workspace.size() } + { + CV_Assert(is_aligned(current, 256)); + CV_Assert(bytes_remaining % 256 == 0); + } + + /** allocates a Span of \p count elements from the workspace memory */ + template + Span get_span(std::size_t count = 0) { + return accquire(count); + } + + /** allocates a TensorSpan of the given shape from the workspace memory */ + template + TensorSpan get_tensor_span(ForwardItr start, ForwardItr end) { + using ItrValueType = typename std::iterator_traits::value_type; + auto required_size = std::accumulate(start, end, 1, std::multiplies()); + return TensorSpan(accquire(required_size).data(), start, end); + } + + /** allocates a WorkspaceInstance of size \p bytes from the workspace memory */ + WorkspaceInstance get_instance(std::size_t bytes = 0) { + auto span = accquire(bytes); + return 
WorkspaceInstance(DevicePtr(span.data()), span.size()); + } + + private: + template + Span accquire(std::size_t count = 0) { + auto ptr = current; + + if (count == 0) + count = bytes_remaining / sizeof(T); + + auto blocks256 = (count * sizeof(T) + 255) / 256; + if (bytes_remaining < blocks256 * 256) + CV_Error(Error::StsNoMem, "out of workspace memory"); + + bytes_remaining -= blocks256 * 256; + current = static_cast>(current) + blocks256 * 256; + return Span(static_cast>(ptr), count); + } + + DevicePtr current; + std::size_t bytes_remaining; + }; + +}}}} /* namespace cv::dnn::cuda4dnn::csl */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_WORKSPACE_HPP */ diff --git a/modules/dnn/src/cuda4dnn/cxx_utils/is_iterator.hpp b/modules/dnn/src/cuda4dnn/cxx_utils/is_iterator.hpp new file mode 100644 index 0000000..3710d08 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/cxx_utils/is_iterator.hpp @@ -0,0 +1,31 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_CXX_UTILS_IS_ITERATOR_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_CXX_UTILS_IS_ITERATOR_HPP + +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace cxx_utils { + + namespace detail { + template + struct is_iterator_helper : std::false_type {}; + + template + struct is_iterator_helper::iterator_category>::value, void>::type + > : std::true_type {}; + } + + template + using is_iterator = typename detail::is_iterator_helper; + + template + using is_forward_iterator = typename detail::is_iterator_helper; + +}}}} /* namespace cv::dnn::cuda4dnn::csl::cxx_utils */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_CXX_UTILS_IS_ITERATOR_HPP */ diff --git a/modules/dnn/src/cuda4dnn/cxx_utils/resizable_static_array.hpp b/modules/dnn/src/cuda4dnn/cxx_utils/resizable_static_array.hpp new file mode 100644 index 0000000..ae53ac6 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/cxx_utils/resizable_static_array.hpp @@ -0,0 +1,110 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
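+
+/* A sketch of how the workspace classes in workspace.hpp above are meant to be
+ * combined (element counts are illustrative):
+ *
+ *   WorkspaceBuilder builder;
+ *   builder.require<float>(1024);   // each request is rounded up to 256-byte blocks
+ *   builder.require<float>(2048);
+ *
+ *   Workspace ws;
+ *   ws.require(builder.required_workspace_size());
+ *
+ *   WorkspaceAllocator allocator(ws);
+ *   auto scratch = allocator.get_span<float>(1024);   // carves the first block
+ */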
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_CXX_UTILS_RESIZABLE_STATIC_ARRAY_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_CXX_UTILS_RESIZABLE_STATIC_ARRAY_HPP + +#include +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace cxx_utils { + + template + class resizable_static_array { + using container_type = std::array; + + public: + using value_type = typename container_type::value_type; + using size_type = typename container_type::size_type; + using difference_type = typename container_type::difference_type; + using reference = typename container_type::reference; + using const_reference = typename container_type::const_reference; + using pointer = typename container_type::pointer; + using const_pointer = typename container_type::const_pointer; + using iterator = typename container_type::iterator; + using const_iterator = typename container_type::const_iterator; + using reverse_iterator = typename container_type::reverse_iterator; + using const_reverse_iterator = typename container_type::const_reverse_iterator; + + resizable_static_array() noexcept : size_{ 0 } { } + explicit resizable_static_array(size_type sz) noexcept : size_{ sz } { } + + bool empty() const noexcept { return static_cast(size_); } + size_type size() const noexcept { return size_; } + size_type capacity() const noexcept { return maxN; } + + void resize(size_type sz) noexcept { + assert(sz <= capacity()); + size_ = sz; + } + + void clear() noexcept { size_ = 0; } + + template + void assign(ForwardItr first, ForwardItr last) { + resize(std::distance(first, last)); + std::copy(first, last, begin()); + } + + iterator begin() noexcept { return std::begin(arr); } + iterator end() noexcept { return std::begin(arr) + size(); } + + const_iterator begin() const noexcept { return arr.cbegin(); } + const_iterator end() const noexcept { return arr.cbegin() + size(); } + + const_iterator cbegin() const noexcept { return arr.cbegin(); } + const_iterator cend() const noexcept { return arr.cbegin() + size(); } + + reverse_iterator rbegin() noexcept { return std::begin(arr) + size(); } + reverse_iterator rend() noexcept { return std::begin(arr); } + + const_reverse_iterator rbegin() const noexcept { return arr.cbegin()+ size(); } + const_reverse_iterator rend() const noexcept { return arr.cbegin(); } + + const_reverse_iterator crbegin() const noexcept { return arr.cbegin() + size(); } + const_reverse_iterator crend() const noexcept { return arr.cbegin(); } + + reference operator[](size_type pos) { + assert(pos < size()); + return arr[pos]; + } + + const_reference operator[](size_type pos) const { + assert(pos < size()); + return arr[pos]; + } + + iterator insert(iterator pos, const T& value) { + resize(size() + 1); + std::move_backward(pos, end() - 1, end()); + *pos = value; + return pos; + } + + iterator insert(iterator pos, T&& value) { + resize(size() + 1); + std::move_backward(pos, end() - 1, end()); + *pos = std::move(value); + return pos; + } + + iterator erase(iterator pos) { + std::move(pos + 1, end(), pos); + resize(size() - 1); + return pos; + } + + pointer data() noexcept { return arr.data(); } + const_pointer data() const noexcept { return arr.data(); } + + private: + std::size_t size_; + container_type arr; + }; + +}}}} /* namespace cv::dnn::cuda4dnn::csl::cxx_utils */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_CXX_UTILS_RESIZABLE_STATIC_ARRAY_HPP */ diff --git a/modules/dnn/src/cuda4dnn/kernels/activations.hpp b/modules/dnn/src/cuda4dnn/kernels/activations.hpp new file mode 100644 index 0000000..05f8f48 --- 
/dev/null +++ b/modules/dnn/src/cuda4dnn/kernels/activations.hpp @@ -0,0 +1,44 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ACTIVATIONS_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ACTIVATIONS_HPP + +#include "../csl/stream.hpp" +#include "../csl/span.hpp" + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + template + void abs(const csl::Stream& stream, csl::Span output, csl::View input); + + template + void tanh(const csl::Stream& stream, csl::Span output, csl::View input); + + template + void sigmoid(const csl::Stream& stream, csl::Span output, csl::View input); + + template + void bnll(const csl::Stream& stream, csl::Span output, csl::View input); + + template + void elu(const csl::Stream& stream, csl::Span output, csl::View input); + + template + void relu(const csl::Stream& stream, csl::Span output, csl::View input, T slope); + + template + void clipped_relu(const csl::Stream& stream, csl::Span output, csl::View input, T floor, T ceiling); + + template + void axiswise_relu(const csl::Stream& stream, csl::Span output, csl::View input, std::size_t inner_size, csl::View slope); + + template + void power(const csl::Stream& stream, csl::Span output, csl::View input, T exp, T scale, T shift); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ACTIVATIONS_HPP */ diff --git a/modules/dnn/src/cuda4dnn/kernels/concat.hpp b/modules/dnn/src/cuda4dnn/kernels/concat.hpp new file mode 100644 index 0000000..3916a24 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/kernels/concat.hpp @@ -0,0 +1,27 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_CONCAT_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_CONCAT_HPP + +#include "../csl/stream.hpp" +#include "../csl/tensor.hpp" + +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + template + void concat( + const csl::Stream& stream, + csl::TensorSpan output, std::size_t output_axis_offset, + csl::TensorView input, std::size_t axis); + + template + void concat_with_offsets(const csl::Stream& stream, csl::TensorSpan output, csl::TensorView input, std::vector axis_offsets); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_CONCAT_HPP */ diff --git a/modules/dnn/src/cuda4dnn/kernels/eltwise_ops.hpp b/modules/dnn/src/cuda4dnn/kernels/eltwise_ops.hpp new file mode 100644 index 0000000..7d84d07 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/kernels/eltwise_ops.hpp @@ -0,0 +1,29 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
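+
+/* A sketch of how the activation kernels in activations.hpp above are typically
+ * invoked (assumes an initialized csl::Stream and float tensors of equal shape):
+ *
+ *   csl::Tensor<float> input(1, 64, 56, 56), output(1, 64, 56, 56);
+ *   kernels::relu<float>(stream, output, input, 0.f);              // Tensor converts to Span/View
+ *   kernels::clipped_relu<float>(stream, output, input, 0.f, 6.f);
+ */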
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ELTWISE_OPS_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ELTWISE_OPS_HPP + +#include "../csl/stream.hpp" +#include "../csl/span.hpp" + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + template + void eltwise_max_2(const csl::Stream& stream, csl::Span output, csl::View x, csl::View y); + + template + void eltwise_sum_2(const csl::Stream& stream, csl::Span output, csl::View x, csl::View y); + + template + void eltwise_sum_coeff_2(const csl::Stream& stream, csl::Span output, T coeff_x, csl::View x, T coeff_y, csl::View y); + + template + void eltwise_prod_2(const csl::Stream& stream, csl::Span output, csl::View x, csl::View y); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ELTWISE_OPS_HPP */ diff --git a/modules/dnn/src/cuda4dnn/kernels/fill.hpp b/modules/dnn/src/cuda4dnn/kernels/fill.hpp new file mode 100644 index 0000000..4101481 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/kernels/fill.hpp @@ -0,0 +1,18 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_FILL_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_FILL_HPP + +#include "../csl/stream.hpp" +#include "../csl/span.hpp" + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + template + void fill(const csl::Stream& stream, csl::Span output, T value); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_FILL_HPP */ diff --git a/modules/dnn/src/cuda4dnn/kernels/max_unpooling.hpp b/modules/dnn/src/cuda4dnn/kernels/max_unpooling.hpp new file mode 100644 index 0000000..6fe4d61 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/kernels/max_unpooling.hpp @@ -0,0 +1,32 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_MAX_UNPOOLING_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_MAX_UNPOOLING_HPP + +#include "../csl/stream.hpp" +#include "../csl/tensor.hpp" + +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + template + void max_pooling_with_indices( + const csl::Stream& stream, + csl::TensorSpan output, csl::TensorSpan indices, csl::TensorView input, + const std::vector& kernel_size, const std::vector& strides, + const std::vector& padding_left); + + template + void max_unpooling( + const csl::Stream& stream, + csl::TensorSpan output, csl::TensorView input, csl::TensorView indices, + const std::vector& window_size, const std::vector& strides, + const std::vector& padding_left); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_MAX_UNPOOLING_HPP */ diff --git a/modules/dnn/src/cuda4dnn/kernels/normalize.hpp b/modules/dnn/src/cuda4dnn/kernels/normalize.hpp new file mode 100644 index 0000000..6a3f7eb --- /dev/null +++ b/modules/dnn/src/cuda4dnn/kernels/normalize.hpp @@ -0,0 +1,24 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
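+
+/* A sketch of the element-wise kernels in eltwise_ops.hpp above (assumes spans and
+ * views of equal size and an initialized csl::Stream):
+ *
+ *   // output = 0.75 * x + 0.25 * y
+ *   kernels::eltwise_sum_coeff_2<float>(stream, output, 0.75f, x, 0.25f, y);
+ *
+ *   // output = max(x, y), element by element
+ *   kernels::eltwise_max_2<float>(stream, output, x, y);
+ */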
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_NORMALIZE_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_NORMALIZE_HPP + +#include "../csl/stream.hpp" +#include "../csl/span.hpp" + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + template + void normalize( + const csl::Stream& stream, + csl::Span output, csl::View input, + std::size_t outer_size, std::size_t mid_size, std::size_t inner_size, std::size_t norm, T epsilon, + csl::Span workspace); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_NORMALIZE_HPP */ diff --git a/modules/dnn/src/cuda4dnn/kernels/padding.hpp b/modules/dnn/src/cuda4dnn/kernels/padding.hpp new file mode 100644 index 0000000..cb55e99 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/kernels/padding.hpp @@ -0,0 +1,25 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_PADDING_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_PADDING_HPP + +#include "../csl/stream.hpp" +#include "../csl/tensor.hpp" + +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + template + void copy_with_reflection101( + const csl::Stream& stream, + csl::TensorSpan output, csl::TensorView input, + std::vector> ranges); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_PADDING_HPP */ diff --git a/modules/dnn/src/cuda4dnn/kernels/permute.hpp b/modules/dnn/src/cuda4dnn/kernels/permute.hpp new file mode 100644 index 0000000..65fe46b --- /dev/null +++ b/modules/dnn/src/cuda4dnn/kernels/permute.hpp @@ -0,0 +1,21 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_PERMUTE_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_PERMUTE_HPP + +#include "../csl/stream.hpp" +#include "../csl/tensor.hpp" + +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + template + void permute(const csl::Stream& stream, csl::TensorSpan output, csl::TensorView input, std::vector order); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_PERMUTE_HPP */ diff --git a/modules/dnn/src/cuda4dnn/kernels/prior_box.hpp b/modules/dnn/src/cuda4dnn/kernels/prior_box.hpp new file mode 100644 index 0000000..643cecf --- /dev/null +++ b/modules/dnn/src/cuda4dnn/kernels/prior_box.hpp @@ -0,0 +1,28 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
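// Illustrative sketch (not part of this patch): the normalize kernel views the input as a
// (outer_size, mid_size, inner_size) block and normalizes across the mid axis. For example,
// L2-normalizing an NCHW tensor across C gives outer_size = N, mid_size = C and
// inner_size = H * W. A plain CPU reference of that reduction (epsilon placement here is one
// common convention; the kernel's exact formulation may differ):
#include <cmath>
#include <cstddef>
#include <vector>

void l2_normalize_ref(std::vector<float>& output, const std::vector<float>& input,
                      std::size_t outer_size, std::size_t mid_size, std::size_t inner_size,
                      float epsilon)
{
    for (std::size_t o = 0; o < outer_size; o++)
        for (std::size_t i = 0; i < inner_size; i++)
        {
            float sum = 0;
            for (std::size_t m = 0; m < mid_size; m++)
            {
                const float v = input[(o * mid_size + m) * inner_size + i];
                sum += v * v;
            }
            const float norm = std::sqrt(sum) + epsilon; // epsilon avoids division by zero
            for (std::size_t m = 0; m < mid_size; m++)
            {
                const std::size_t idx = (o * mid_size + m) * inner_size + i;
                output[idx] = input[idx] / norm;
            }
        }
}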
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_PRIOR_BOX_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_PRIOR_BOX_HPP + +#include "../csl/stream.hpp" +#include "../csl/span.hpp" + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + template + void generate_prior_boxes( + const csl::Stream& stream, + csl::Span output, + csl::View boxWidth, csl::View boxHeight, csl::View offsetX, csl::View offsetY, float stepX, float stepY, + std::vector variance, + std::size_t numPriors, + std::size_t layerWidth, std::size_t layerHeight, + std::size_t imageWidth, std::size_t imageHeight, + bool normalize, bool clip); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_PRIOR_BOX_HPP */ diff --git a/modules/dnn/src/cuda4dnn/kernels/region.hpp b/modules/dnn/src/cuda4dnn/kernels/region.hpp new file mode 100644 index 0000000..0e12ad3 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/kernels/region.hpp @@ -0,0 +1,32 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_REGION_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_REGION_HPP + +#include "../csl/stream.hpp" +#include "../csl/span.hpp" + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + template + void sigmoid_strided(const csl::Stream& stream, csl::Span output, csl::View input, std::size_t n, std::size_t stride, std::size_t offset); + + template + void softmax_strided(const csl::Stream& stream, csl::Span output, csl::View input, std::size_t n, std::size_t stride, std::size_t offset); + + template + void region_finalize(const csl::Stream& stream, csl::Span output, csl::View input, csl::View bias, + T object_prob_cutoff, T class_prob_cutoff, + std::size_t height_norm, std::size_t width_norm, + std::size_t rows, std::size_t cols, + std::size_t boxes_per_cell, + std::size_t box_size, + std::size_t classes); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_REGION_HPP */ diff --git a/modules/dnn/src/cuda4dnn/kernels/resize.hpp b/modules/dnn/src/cuda4dnn/kernels/resize.hpp new file mode 100644 index 0000000..5c5cc3d --- /dev/null +++ b/modules/dnn/src/cuda4dnn/kernels/resize.hpp @@ -0,0 +1,23 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_RESIZE_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_RESIZE_HPP + +#include "../csl/stream.hpp" +#include "../csl/tensor.hpp" + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + template + void resize_nn(const csl::Stream& stream, csl::TensorSpan output, csl::TensorView input); + + template + void resize_bilinear(const csl::Stream& stream, csl::TensorSpan output, csl::TensorView input, float scale_y, float scale_x); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_RESIZE_HPP */ diff --git a/modules/dnn/src/cuda4dnn/kernels/scale_shift.hpp b/modules/dnn/src/cuda4dnn/kernels/scale_shift.hpp new file mode 100644 index 0000000..32fa1d8 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/kernels/scale_shift.hpp @@ -0,0 +1,45 @@ +// This file is part of OpenCV project. 
+// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_SCALE_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_SCALE_HPP + +#include "../csl/stream.hpp" +#include "../csl/tensor.hpp" + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + template + void bias1(const csl::Stream& stream, csl::TensorSpan output, csl::TensorView input, T alpha); + + template + void biasN(const csl::Stream& stream, + csl::TensorSpan output, + csl::TensorView input, std::size_t inner_size, + csl::TensorView bias); + + template + void scale1(const csl::Stream& stream, csl::TensorSpan output, csl::TensorView input, T alpha); + + template + void scaleN(const csl::Stream& stream, + csl::TensorSpan output, + csl::TensorView input, std::size_t inner_size, + csl::TensorView weights); + + template + void scale1_with_bias1(const csl::Stream& stream, csl::Span output, csl::View input, T alpha, T beta); + + template + void scaleN_with_biasN( + const csl::Stream& stream, + csl::TensorSpan output, + csl::TensorView input, std::size_t inner_size, + csl::TensorView weights, csl::TensorView bias); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_SCALE_HPP */ diff --git a/modules/dnn/src/cuda4dnn/kernels/slice.hpp b/modules/dnn/src/cuda4dnn/kernels/slice.hpp new file mode 100644 index 0000000..d78efc5 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/kernels/slice.hpp @@ -0,0 +1,22 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_SLICE_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_SLICE_HPP + +#include "../csl/stream.hpp" +#include "../csl/tensor.hpp" + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { + + template + void slice(const csl::Stream& stream, + csl::TensorSpan output, csl::TensorView input, + std::vector offsets); + +}}}} /* namespace cv::dnn::cuda4dnn::kernels */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_SLICE_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/activation.hpp b/modules/dnn/src/cuda4dnn/primitives/activation.hpp new file mode 100644 index 0000000..d90aef3 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/activation.hpp @@ -0,0 +1,290 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
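// Illustrative sketch (not part of this patch): biasN/scaleN broadcast a per-channel vector
// over `inner_size` contiguous elements; callers later in this patch pass
// inner_size = input.size_range(2, input.rank()), i.e. the product of the spatial dimensions
// of an NCHW/NCDHW tensor. A host-side reference of biasN:
#include <cstddef>
#include <vector>

template <class T>
void biasN_ref(std::vector<T>& output, const std::vector<T>& input,
               std::size_t inner_size, const std::vector<T>& bias)
{
    for (std::size_t i = 0; i < input.size(); i++)
    {
        const std::size_t c = (i / inner_size) % bias.size(); // channel index
        output[i] = input[i] + bias[c];
    }
}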
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_ACTIVATION_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_ACTIVATION_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" +#include "../csl/tensor.hpp" + +#include "../kernels/activations.hpp" + +#include + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + template + class ReLUOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + ReLUOp(csl::Stream stream_, T slope_) + : stream(std::move(stream_)), slope{ slope_ } { } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + for (int i = 0; i < inputs.size(); i++) + { + auto input_wrapper = inputs[i].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[i].dynamicCast(); + auto output = output_wrapper->getSpan(); + + kernels::relu(stream, output, input, slope); + } + } + + private: + csl::Stream stream; + const T slope; + }; + + template + class ClippedReLUOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + ClippedReLUOp(csl::Stream stream_, T min_, T max_) + : stream(std::move(stream_)), min{ min_ }, max{ max_ } { } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + for (int i = 0; i < inputs.size(); i++) + { + auto input_wrapper = inputs[i].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[i].dynamicCast(); + auto output = output_wrapper->getSpan(); + + kernels::clipped_relu(stream, output, input, min, max); + } + } + + private: + csl::Stream stream; + const T min, max; + }; + + template + class ChannelwiseReLUOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + ChannelwiseReLUOp(csl::Stream stream_, const Mat& slope) + : stream(std::move(stream_)) + { + CV_Assert(!slope.empty()); + slopeTensor = csl::makeTensorHeader(slope); + csl::copyMatToTensor(slope, slopeTensor, stream); + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + for (int i = 0; i < inputs.size(); i++) + { + auto input_wrapper = inputs[i].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[i].dynamicCast(); + auto output = output_wrapper->getSpan(); + + CV_Assert(input.get_axis_size(1) == slopeTensor.size()); + std::size_t inner_size = input.size_range(2, input.rank()); + kernels::axiswise_relu(stream, output, input, inner_size, slopeTensor); + } + } + + private: + csl::Stream stream; + csl::Tensor slopeTensor; + }; + + template + class TanHOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + TanHOp(csl::Stream stream_) : stream(std::move(stream_)) { } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + for (int i = 0; i < inputs.size(); i++) + { + auto input_wrapper = inputs[i].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[i].dynamicCast(); + auto output = output_wrapper->getSpan(); + + kernels::tanh(stream, output, input); + } + } + + private: + csl::Stream stream; + }; + + template + class SigmoidOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + SigmoidOp(csl::Stream stream_) : stream(std::move(stream_)) { } + + void 
forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + for (int i = 0; i < inputs.size(); i++) + { + auto input_wrapper = inputs[i].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[i].dynamicCast(); + auto output = output_wrapper->getSpan(); + + kernels::sigmoid(stream, output, input); + } + } + + private: + csl::Stream stream; + }; + + template + class ELUOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + ELUOp(csl::Stream stream_) : stream(std::move(stream_)) { } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + for (int i = 0; i < inputs.size(); i++) + { + auto input_wrapper = inputs[i].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[i].dynamicCast(); + auto output = output_wrapper->getSpan(); + + kernels::elu(stream, output, input); + } + } + + private: + csl::Stream stream; + }; + + template + class AbsValOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + AbsValOp(csl::Stream stream_) : stream(std::move(stream_)) { } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + for (int i = 0; i < inputs.size(); i++) + { + auto input_wrapper = inputs[i].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[i].dynamicCast(); + auto output = output_wrapper->getSpan(); + + kernels::abs(stream, output, input); + } + } + + private: + csl::Stream stream; + }; + + template + class BNLLOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + BNLLOp(csl::Stream stream_) : stream(std::move(stream_)) { } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + for (int i = 0; i < inputs.size(); i++) + { + auto input_wrapper = inputs[i].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[i].dynamicCast(); + auto output = output_wrapper->getSpan(); + + kernels::bnll(stream, output, input); + } + } + + private: + csl::Stream stream; + }; + + template + class PowerOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + PowerOp(csl::Stream stream_, T exp_, T scale_, T shift_) + : stream(std::move(stream_)), exp{ exp_ }, scale{ scale_ }, shift{ shift_ } { } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + for (int i = 0; i < inputs.size(); i++) + { + auto input_wrapper = inputs[i].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[i].dynamicCast(); + auto output = output_wrapper->getSpan(); + + kernels::power(stream, output, input, exp, scale, shift); + } + } + + private: + csl::Stream stream; + const T exp, scale, shift; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_ACTIVATION_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/batch_norm.hpp b/modules/dnn/src/cuda4dnn/primitives/batch_norm.hpp new file mode 100644 index 0000000..293811f --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/batch_norm.hpp @@ -0,0 +1,58 @@ +// This file is part of OpenCV project. 
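// Illustrative sketch (not part of this patch): BNLL computes log(1 + exp(x)), which
// overflows for large positive x when evaluated naively. The commit message mentions a
// numerically stable log1pexp; one standard stable formulation looks like this:
#include <cmath>

float log1pexp_ref(float x)
{
    // For large x, log(1 + exp(x)) ~= x; folding the remainder through log1p(exp(-x))
    // keeps the argument of exp() non-positive and avoids overflow.
    return x > 0 ? x + std::log1p(std::exp(-x)) : std::log1p(std::exp(x));
}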
+// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_BATCH_NORM_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_BATCH_NORM_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" +#include "../csl/tensor.hpp" + +#include "../kernels/scale_shift.hpp" + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + template + class BatchNormOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + BatchNormOp(csl::Stream stream_, const cv::Mat& weights, const cv::Mat& bias) + : stream(std::move(stream_)) + { + biasTensor = csl::makeTensorHeader(bias); + csl::copyMatToTensor(bias, biasTensor, stream); + + weightsTensor = csl::makeTensorHeader(weights); + csl::copyMatToTensor(weights, weightsTensor, stream); + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + CV_Assert(inputs.size() == 1 && outputs.size() == 1); + + auto input_wrapper = inputs[0].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[0].dynamicCast(); + auto output = output_wrapper->getSpan(); + + std::size_t inner_size = input.size_range(2, input.rank()); + kernels::scaleN_with_biasN(stream, output, input, inner_size, weightsTensor, biasTensor); + } + + private: + csl::Stream stream; + csl::Tensor weightsTensor, biasTensor; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_BATCH_NORM_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/concat.hpp b/modules/dnn/src/cuda4dnn/primitives/concat.hpp new file mode 100644 index 0000000..9655627 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/concat.hpp @@ -0,0 +1,90 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
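// Illustrative sketch (not part of this patch): BatchNormOp expects per-channel weights and
// bias that already fold the batch-norm statistics, so inference reduces to a fused
// scale + shift (kernels::scaleN_with_biasN). Assuming the usual folding, those vectors
// would be derived as:
#include <cmath>
#include <cstddef>
#include <vector>

void fold_batch_norm(std::vector<float>& weight, std::vector<float>& bias,
                     const std::vector<float>& gamma, const std::vector<float>& beta,
                     const std::vector<float>& mean, const std::vector<float>& variance,
                     float eps)
{
    for (std::size_t c = 0; c < gamma.size(); c++)
    {
        weight[c] = gamma[c] / std::sqrt(variance[c] + eps); // per-channel scale
        bias[c] = beta[c] - weight[c] * mean[c];             // per-channel shift
    }
}
// y = weight[c] * x + bias[c] is then equivalent to gamma * (x - mean) / sqrt(var + eps) + beta.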
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_CONCAT_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_CONCAT_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" +#include "../csl/pointer.hpp" + +#include "../kernels/fill.hpp" +#include "../kernels/concat.hpp" + +#include + +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + template + class ConcatOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + ConcatOp(csl::Stream stream_, std::size_t concat_axis, bool zero_padding) + : stream(std::move(stream_)), concat_axis{ concat_axis }, zero_padding{ zero_padding } + { + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + CV_Assert(outputs.size() == 1); + + auto output_wrapper = outputs[0].dynamicCast(); + auto output = output_wrapper->getSpan(); + + if(zero_padding) + { + auto output_shape = output_wrapper->getShape(); + + kernels::fill(stream, output, 0.0); + + std::size_t output_concat_axis_offset = 0; + for (int i = 0; i < inputs.size(); i++) + { + auto input_wrapper = inputs[i].dynamicCast(); + auto input = input_wrapper->getView(); + auto input_shape = input_wrapper->getShape(); + + std::vector offsets(input_shape.size()); + for (int j = 0; j < offsets.size(); j++) + offsets[j] = (output_shape[j] - input_shape[j]) / 2; + offsets[concat_axis] = output_concat_axis_offset; + + kernels::concat_with_offsets(stream, output, input, offsets); + + output_concat_axis_offset += input.get_axis_size(concat_axis); + } + } + else + { + std::size_t output_axis_offset = 0; + for (int i = 0; i < inputs.size(); i++) + { + auto input_wrapper = inputs[i].dynamicCast(); + auto input = input_wrapper->getView(); + + kernels::concat(stream, output, output_axis_offset, input, concat_axis); + + output_axis_offset += input.get_axis_size(concat_axis); + } + } + } + + private: + csl::Stream stream; + std::size_t concat_axis; + bool zero_padding; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_CONCAT_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/const.hpp b/modules/dnn/src/cuda4dnn/primitives/const.hpp new file mode 100644 index 0000000..e883031 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/const.hpp @@ -0,0 +1,51 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
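// Illustrative sketch (not part of this patch): with padding enabled, each input is centred
// in every non-concat axis and stacked along the concat axis. For example, concatenating
// shapes (1,3,4,4) and (1,5,4,4) along axis 1 yields (1,8,4,4), with the second input copied
// at offset 3 on that axis. The per-input offset computation mirrors the loop above:
#include <cstddef>
#include <vector>

std::vector<std::size_t> concat_offsets(const std::vector<std::size_t>& output_shape,
                                        const std::vector<std::size_t>& input_shape,
                                        std::size_t concat_axis, std::size_t axis_offset)
{
    std::vector<std::size_t> offsets(input_shape.size());
    for (std::size_t j = 0; j < offsets.size(); j++)
        offsets[j] = (output_shape[j] - input_shape[j]) / 2; // centre in non-concat axes
    offsets[concat_axis] = axis_offset;                      // running offset along the concat axis
    return offsets;
}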
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_CONST_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_CONST_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" +#include "../csl/tensor.hpp" +#include "../csl/tensor_ops.hpp" + +#include + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + template + class ConstOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + ConstOp(csl::Stream stream_, const cv::Mat& data) + : stream(std::move(stream_)) + { + constTensor = csl::makeTensorHeader(data); + csl::copyMatToTensor(data, constTensor, stream); + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + CV_Assert(outputs.size() == 1 && inputs.size() == 0); + + auto output_wrapper = outputs[0].dynamicCast(); + auto output = output_wrapper->getSpan(); + csl::tensor_ops::copy(stream, output, constTensor); + } + + private: + csl::Stream stream; + csl::Tensor constTensor; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_CONST_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/convolution.hpp b/modules/dnn/src/cuda4dnn/primitives/convolution.hpp new file mode 100644 index 0000000..6713357 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/convolution.hpp @@ -0,0 +1,250 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_CONVOLUTION_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_CONVOLUTION_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/cudnn.hpp" +#include "../csl/stream.hpp" +#include "../csl/tensor.hpp" +#include "../csl/tensor_ops.hpp" +#include "../kernels/scale_shift.hpp" + +#include + +#include +#include +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + struct ConvolutionConfiguration { + /* the size of the following vectors must be equal to the kernel size */ + std::vector kernel_size; + std::vector dilations, strides; + + enum class PaddingMode { + MANUAL, /* uses explicit padding values provided in `pads_begin` and `pads_end` */ + VALID, /* no padding is added */ + SAME /* TensorFlow logic is used for same padding */ + }; + + /* explicit paddings are used if and only if padMode is set to manual */ + PaddingMode padMode; + std::vector pads_begin, pads_end; + + /* full shape inclusive of channel and batch axis */ + std::vector input_shape; + std::vector output_shape; + + /* group count for grouped convolution */ + std::size_t groups; + }; + + template + class ConvolutionOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + ConvolutionOp(csl::Stream stream_, csl::cudnn::Handle handle, const ConvolutionConfiguration& config, const Mat& filters, const Mat& bias) + : stream(std::move(stream_)), cudnnHandle(std::move(handle)) + { + const auto& kernel_size = config.kernel_size; + const auto& dilations = config.dilations; + const auto& strides = config.strides; + + const auto convolution_order = kernel_size.size(); + CV_Assert(convolution_order >= 1); + + CV_Assert(convolution_order == dilations.size()); + CV_Assert(convolution_order == strides.size()); + + const auto& input_shape = config.input_shape; + const auto& output_shape = config.output_shape; + CV_Assert(input_shape.size() == output_shape.size()); + 
CV_Assert(input_shape.size() == convolution_order + 2); + + const auto groups = config.groups; + + if (convolution_order > 3) + CV_Error(Error::StsNotImplemented, "Only 1D/2D/3D convolution is supported."); + + const auto rank = input_shape.size(); + const auto output_feature_maps = output_shape[1]; + const auto input_feature_maps = input_shape[1]; + const auto input_feature_maps_per_group = input_feature_maps / groups; + CV_Assert(input_feature_maps % groups == 0); + + filtersTensor = csl::makeTensorHeader(filters); + csl::copyMatToTensor(filters, filtersTensor, stream); + + if (!bias.empty()) + { + biasTensor = csl::makeTensorHeader(bias); + csl::copyMatToTensor(bias, biasTensor, stream); + } + + /* left and right are misleading as the padding is applicable for any number of dimensions + * but we use those identifiers to avoid confusion with `pads_begin` and `pads_end` + * + * `common_padding` contains the amount of padding that has to be added to both sides + * `padding_left` and `padding_right` contains the amount of padding that needs to be added + * to a particular side in addition to the common padding + */ + std::vector common_padding(rank, 0); + std::vector padding_left(rank, 0), padding_right(rank, 0); + if (config.padMode == ConvolutionConfiguration::PaddingMode::MANUAL) + { + const auto& pads_begin = config.pads_begin; + const auto& pads_end = config.pads_end; + + CV_Assert(convolution_order == pads_begin.size()); + CV_Assert(convolution_order == pads_end.size()); + + for (int i = 2; i < common_padding.size(); i++) + { + common_padding[i] = std::min(pads_begin[i - 2], pads_end[i - 2]); + padding_left[i] = pads_begin[i - 2] - common_padding[i]; + padding_right[i] = pads_end[i - 2] - common_padding[i]; + } + } + else if (config.padMode == ConvolutionConfiguration::PaddingMode::VALID) + { + /* nothing to do as the paddings are already preset to zero */ + } + else if (config.padMode == ConvolutionConfiguration::PaddingMode::SAME) + { + /* TensorFlow Logic: + * total_padding[i] = (o[i] - 1) * s[i] + effective_k[i] - i[i] + * + * if total padding is odd, the extra is added towards the end + */ + for (int i = 2; i < rank; i++) + { + const auto j = i - 2; /* filter index */ + const auto effective_kernel_size = dilations[j] * (kernel_size[j] - 1) + 1; + const auto required_total_padding = + std::max(0, (output_shape[i] - 1) * strides[j] + effective_kernel_size - input_shape[i]); + + common_padding[i] = required_total_padding / 2; + padding_left[i] = 0; + padding_right[i] = required_total_padding % 2; + } + } + + /* in some scenarios, the extra padding at the end may not change the output at all */ + for (int i = 2; i < rank; i++) { + const auto j = i - 2; /* filter idx */ + const auto total_padding = common_padding[i] * 2 + padding_left[i] + padding_right[i]; + const auto effective_kernel_size = dilations[j] * (kernel_size[j] - 1) + 1; + std::int64_t rem = (input_shape[i] + total_padding - effective_kernel_size) % strides[j]; + + /* the output shape doesn't change if we decrease the total padding by at most `rem` + * provided that we decrease from the right + */ + if (rem && padding_right[i] > 0) + padding_right[i] = std::max(0, padding_right[i] - rem); + } + + auto is_not_zero = [](std::size_t i) { return i != 0; }; + if(std::any_of(std::begin(padding_left), std::end(padding_left), is_not_zero) || + std::any_of(std::begin(padding_right), std::end(padding_right), is_not_zero)) + { + /* csl::Convolution supports symmetric padding only; hence, we deal with asymmetric padding by + * 
copying the input to a bigger tensor and padding the ends manually + */ + transformed_shape = input_shape; + for (int i = 0; i < rank; i++) + transformed_shape[i] += padding_left[i] + padding_right[i]; + + inputTransformer = csl::TensorTransform(cudnnHandle, padding_left, padding_right); + } + + typename csl::Convolution::params_type params; + if (transformed_shape.empty()) + { + params.input_shape.assign(std::begin(input_shape), std::end(input_shape)); + } + else + { + /* the convolution operation will be seeing the transformed input */ + params.input_shape.assign(std::begin(transformed_shape), std::end(transformed_shape)); + } + + auto& fshape = params.filter_shape; + fshape.resize(rank); + fshape[0] = output_feature_maps; + fshape[1] = input_feature_maps_per_group; + std::copy(std::begin(kernel_size), std::end(kernel_size), std::begin(fshape) + 2); + CV_Assert(fshape.size() == kernel_size.size() + 2); + + params.padding.assign(std::begin(common_padding) + 2, std::end(common_padding)); + params.stride = strides; + params.dilation = dilations; + params.groups = config.groups; + + convoluter = csl::Convolution(cudnnHandle, params); + + csl::WorkspaceBuilder builder; + if (!transformed_shape.empty()) { + auto& shape = transformed_shape; + auto sz = std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies()); + builder.require(sz); + } + builder.require(convoluter.get_workspace_size()); + scratch_mem_in_bytes = builder.required_workspace_size(); + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + CV_Assert(inputs.size() == 1 && outputs.size() == 1); + + csl::WorkspaceAllocator allocator(workspace); + + auto input_wrapper = inputs[0].dynamicCast(); + auto input = input_wrapper->getView(); + + if (!transformed_shape.empty()) + { + auto& shape = transformed_shape; + auto transformed_input = allocator.get_tensor_span(std::begin(shape), std::end(shape)); + inputTransformer.transform(input, transformed_input); + input = transformed_input; + } + + auto output_wrapper = outputs[0].dynamicCast(); + auto output = output_wrapper->getSpan(); + + convoluter.convolve(output, input, filtersTensor, allocator.get_instance()); + if (!biasTensor.empty()) + { + std::size_t inner_size = output.size_range(2, output.rank()); + kernels::biasN(stream, output, output, inner_size, biasTensor); + } + } + + std::size_t get_workspace_memory_in_bytes() const noexcept override { return scratch_mem_in_bytes; } + + private: + csl::Stream stream; + csl::cudnn::Handle cudnnHandle; + csl::Tensor filtersTensor, biasTensor; + csl::Convolution convoluter; + + std::vector transformed_shape; + csl::TensorTransform inputTransformer; + + std::size_t scratch_mem_in_bytes; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_CONVOLUTION_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/eltwise.hpp b/modules/dnn/src/cuda4dnn/primitives/eltwise.hpp new file mode 100644 index 0000000..c044730 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/eltwise.hpp @@ -0,0 +1,115 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
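// Illustrative sketch (not part of this patch): cuDNN convolution accepts a single symmetric
// padding value per spatial axis, so explicit (pads_begin, pads_end) pairs are split into a
// symmetric part plus per-side extras; only a non-zero remainder forces the manual copy into
// a larger, pre-padded tensor. The per-axis decomposition:
#include <algorithm>
#include <cstddef>

struct AxisPadding { std::size_t common, left, right; };

AxisPadding decompose_padding(std::size_t pad_begin, std::size_t pad_end)
{
    AxisPadding p;
    p.common = std::min(pad_begin, pad_end); // handled by cuDNN directly
    p.left = pad_begin - p.common;           // extra padding before the data
    p.right = pad_end - p.common;            // extra padding after the data
    return p;
}
// e.g. pads_begin = 1, pads_end = 2 gives common = 1, left = 0, right = 1.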
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_ELTWISE_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_ELTWISE_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" +#include "../csl/tensor.hpp" +#include "../csl/tensor_ops.hpp" + +#include "../kernels/eltwise_ops.hpp" + +#include + +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + enum class EltwiseOpType { + MAX, + SUM, + PRODUCT + }; + + template + class EltwiseOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + template + EltwiseOp(csl::Stream stream_, EltwiseOpType op_, std::vector coeffs_) + : stream(std::move(stream_)), op{ op_ }, coeffs(std::begin(coeffs_), std::end(coeffs_)) + { + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + CV_Assert(inputs.size() >= 2); + CV_Assert(outputs.size() == 1); + + CV_Assert(coeffs.size() == 0 || op == EltwiseOpType::SUM); + CV_Assert(coeffs.size() == 0 || inputs.size() == coeffs.size()); + + auto output_wrapper = outputs[0].dynamicCast(); + auto output = output_wrapper->getSpan(); + + if (inputs.size() == 2) + { + auto input_wrapper_x = inputs[0].dynamicCast(); + auto input_x = input_wrapper_x->getView(); + + auto input_wrapper_y = inputs[1].dynamicCast(); + auto input_y = input_wrapper_y->getView(); + + switch (op) + { + case EltwiseOpType::MAX: kernels::eltwise_max_2(stream, output, input_x, input_y); break; + case EltwiseOpType::PRODUCT: kernels::eltwise_prod_2(stream, output, input_x, input_y); break; + case EltwiseOpType::SUM: + if (coeffs.empty() || (coeffs[0] == 1 && coeffs[1] == 1)) + kernels::eltwise_sum_2(stream, output, input_x, input_y); + else + kernels::eltwise_sum_coeff_2(stream, output, coeffs[0], input_x, coeffs[1], input_y); + break; + } + } + else + { + auto input_wrapper_0 = inputs[0].dynamicCast(); + auto input_0 = input_wrapper_0->getView(); + + /* we first make a copy and then apply EltwiseOp cumulatively */ + csl::tensor_ops::copy(stream, output, input_0); + + for (int i = 1; i < inputs.size(); i++) + { + auto input_wrapper = inputs[i].dynamicCast(); + auto input = input_wrapper->getView(); + + switch (op) + { + case EltwiseOpType::MAX: kernels::eltwise_max_2(stream, output, output, input); break; + case EltwiseOpType::PRODUCT: kernels::eltwise_prod_2(stream, output, output, input); break; + case EltwiseOpType::SUM: + if (coeffs.empty() || coeffs[i] == 1) + kernels::eltwise_sum_2(stream, output, output, input); + else + { + /* if this is the first op, we must scale output too */ + auto coeff_x = (i == 1) ? coeffs[0] : static_cast(1.0); + kernels::eltwise_sum_coeff_2(stream, output, coeff_x, output, coeffs[i], input); + } + break; + } + } + } + } + + private: + csl::Stream stream; + EltwiseOpType op; + std::vector coeffs; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_ELTWISE_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/inner_product.hpp b/modules/dnn/src/cuda4dnn/primitives/inner_product.hpp new file mode 100644 index 0000000..d5baa50 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/inner_product.hpp @@ -0,0 +1,92 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
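// Illustrative sketch (not part of this patch): for more than two inputs the op is applied
// cumulatively into the output buffer; for SUM with coefficients the first accumulation pass
// must also scale the copied input 0, which is why coeff_x above is coeffs[0] only when
// i == 1. A host-side reference, assuming one coefficient per input:
#include <cstddef>
#include <vector>

void eltwise_sum_ref(std::vector<float>& output, const std::vector<std::vector<float>>& inputs,
                     const std::vector<float>& coeffs)
{
    output = inputs[0];
    for (std::size_t i = 1; i < inputs.size(); i++)
        for (std::size_t j = 0; j < output.size(); j++)
        {
            const float coeff_prev = (i == 1) ? coeffs[0] : 1.f; // output already scaled after first pass
            output[j] = coeff_prev * output[j] + coeffs[i] * inputs[i][j];
        }
}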
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_INNER_PRODUCT_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_INNER_PRODUCT_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" +#include "../csl/cublas.hpp" +#include "../csl/tensor.hpp" +#include "../csl/tensor_ops.hpp" + +#include "../kernels/scale_shift.hpp" + +#include + +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + template + class InnerProductOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + InnerProductOp(csl::Stream stream_, csl::cublas::Handle handle, std::size_t axis, const Mat& weights, const Mat& bias) + : stream(std::move(stream_)), cublasHandle(std::move(handle)), axis{ axis } + { + weightsTensor = csl::makeTensorHeader(weights); + CV_Assert(get_effective_rank(weightsTensor) == 2); + csl::copyMatToTensor(weights, weightsTensor, stream); + + if (!bias.empty()) + { + biasTensor = csl::makeTensorHeader(bias); + csl::copyMatToTensor(bias, biasTensor, stream); + CV_Assert(weightsTensor.get_axis_size(-2) == biasTensor.size()); + } + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + for (int i = 0; i < inputs.size(); i++) + { + auto input_wrapper = inputs[i].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[i].dynamicCast(); + auto output = output_wrapper->getSpan(); + + std::size_t batch_size = input.size_range(0, axis); + + auto input_size = input.size() / batch_size; + CV_Assert(input_size == weightsTensor.get_axis_size(-1)); + + auto output_size = output.size() / batch_size; + CV_Assert(output_size == weightsTensor.get_axis_size(-2)); + + /* we treat the input and output as a matrix with dimensions (batch_size, input_size) + * and (batch_size, output_size) respectively + * + * weight matrix dimensions: (output_size, input_size) + * + * I(W^T) = O + * (batch_size, input_size) * (input_size, output_size) = (batch_size, output_size) + */ + input.reshape(batch_size, input_size); + output.reshape(batch_size, output_size); + csl::tensor_ops::gemm(cublasHandle, 0.0, output, 1.0, false, input, true, weightsTensor); + + if (!biasTensor.empty()) + kernels::biasN(stream, output, output, 1, biasTensor); + } + } + + private: + csl::Stream stream; + csl::cublas::Handle cublasHandle; + csl::Tensor weightsTensor, biasTensor; + std::size_t axis; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_INNER_PRODUCT_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/lrn.hpp b/modules/dnn/src/cuda4dnn/primitives/lrn.hpp new file mode 100644 index 0000000..13d86fc --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/lrn.hpp @@ -0,0 +1,75 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
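// Illustrative sketch (not part of this patch): the fully connected layer is a GEMM with the
// weight matrix used transposed, O = I * W^T (plus bias). With axis = 1, an input blob of
// shape (N, C, H, W) is viewed as (N, C*H*W) and the weights as (output_size, input_size).
// A host-side reference of the same computation:
#include <cstddef>
#include <vector>

void inner_product_ref(std::vector<float>& output, const std::vector<float>& input,
                       const std::vector<float>& weights, /* row-major (output_size, input_size) */
                       std::size_t batch_size, std::size_t input_size, std::size_t output_size)
{
    for (std::size_t b = 0; b < batch_size; b++)
        for (std::size_t o = 0; o < output_size; o++)
        {
            float sum = 0;
            for (std::size_t i = 0; i < input_size; i++)
                sum += input[b * input_size + i] * weights[o * input_size + i];
            output[b * output_size + o] = sum;
        }
}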
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_LRN_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_LRN_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/cudnn.hpp" +#include "../csl/tensor_ops.hpp" + +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + enum class LRNType { + ACROSS_CHANNELS, + WITHIN_CHANNEL + }; + + template + class LRNOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + LRNOp(csl::cudnn::Handle handle, LRNType type_, std::size_t local_size, T alpha, T beta, T bias, std::size_t largestInputSize) + : scratch_mem_in_bytes { 0 } + { + typename csl::LRN::LRNType type{}; + switch (type_) { + case LRNType::ACROSS_CHANNELS: type = csl::LRN::LRNType::ACROSS_CHANNELS; break; + case LRNType::WITHIN_CHANNEL: type = csl::LRN::LRNType::WITHIN_CHANNEL; break; + } + lrn = csl::LRN(std::move(handle), local_size, alpha, beta, bias, type); + + csl::WorkspaceBuilder builder; + if (type_ == LRNType::WITHIN_CHANNEL) { + /* this is not a bug; we require two of these */ + builder.require(largestInputSize); + builder.require(largestInputSize); + } + + scratch_mem_in_bytes = builder.required_workspace_size(); + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + for (int i = 0; i < inputs.size(); i++) + { + auto input_wrapper = inputs[i].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[i].dynamicCast(); + auto output = output_wrapper->getSpan(); + + csl::WorkspaceAllocator allocator(workspace); + lrn.normalize(input, output, allocator.get_instance()); + } + } + + std::size_t get_workspace_memory_in_bytes() const noexcept override { return scratch_mem_in_bytes; } + + private: + csl::LRN lrn; + std::size_t scratch_mem_in_bytes; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_LRN_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/max_unpooling.hpp b/modules/dnn/src/cuda4dnn/primitives/max_unpooling.hpp new file mode 100644 index 0000000..1102dc5 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/max_unpooling.hpp @@ -0,0 +1,182 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MAX_UNPOOLING_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MAX_UNPOOLING_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" + +#include "../kernels/max_unpooling.hpp" + +#include + +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + struct MaxPoolingConfiguration { + /* the size of the following vectors must be equal to the pooling order */ + std::vector window_size; + std::vector strides; + + enum class PaddingMode { + MANUAL, /* uses explicit padding values provided in `pads_begin` and `pads_end` */ + VALID, /* no padding is added */ + SAME /* TensorFlow logic is used for same padding */ + }; + + PaddingMode padMode; + + /* explicit paddings are used if and only if padMode is set to manual */ + std::vector pads_begin; + + /* full shape inclusive of channel and batch axis */ + std::vector input_shape; + }; + + template + class MaxPoolingOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + MaxPoolingOp(csl::Stream stream_, const MaxPoolingConfiguration& config) + : stream(std::move(stream_)) + { + window_size = config.window_size; + + const auto pooling_order = window_size.size(); + CV_Assert(pooling_order >= 1); + + strides = config.strides; + CV_Assert(pooling_order == strides.size()); + + if (pooling_order != 2 && pooling_order != 3) + CV_Error(Error::StsNotImplemented, "Only 2D/3D max-pooling are supported."); + + padding_left.resize(pooling_order); + if (config.padMode == MaxPoolingConfiguration::PaddingMode::MANUAL) + { + const auto& pads_begin = config.pads_begin; + CV_Assert(pooling_order == pads_begin.size()); + + padding_left.assign(std::begin(pads_begin), std::end(pads_begin)); + } + else if (config.padMode == MaxPoolingConfiguration::PaddingMode::VALID) + { + /* nothing to do as the paddings are already preset to zero */ + } + else if (config.padMode == MaxPoolingConfiguration::PaddingMode::SAME) + { + /* TensorFlow Logic: + * total_padding[i] = (o[i] - 1) * s[i] + effective_k[i] - i[i] + * + * if total padding is odd, the extra is added towards the end + */ + const auto& input_shape = config.input_shape; + CV_Assert(input_shape.size() == pooling_order + 2); + + for (int i = 0; i < pooling_order; i++) + { + const auto output_dim = (input_shape[i + 2] - 1 + strides[i]) / strides[i]; + const auto required_total_padding = + std::max(0, (output_dim - 1) * strides[i] + window_size[i] - input_shape[i + 2]); + + padding_left[i] = required_total_padding / 2; + } + } + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + CV_Assert(inputs.size() == 1 && outputs.size() == 2); + + auto input_wrapper = inputs[0].dynamicCast(); + auto input_data = input_wrapper->getView(); + + auto output_wrapper = outputs[0].dynamicCast(); + auto output_data = output_wrapper->getSpan(); + + auto indices_wrapper = outputs[1].dynamicCast(); + auto output_indices = indices_wrapper->getSpan(); + + kernels::max_pooling_with_indices( + stream, output_data, output_indices, input_data, window_size, strides, padding_left + ); + } + + private: + csl::Stream stream; + + std::vector window_size, strides, padding_left; + }; + + struct MaxUnpoolingConfiguration { + /* the size of the following vectors must be equal to the unpooling order */ + std::vector window_size; + std::vector strides; + std::vector pads_begin; + }; + + template + class MaxUnpoolingOp final : public CUDABackendNode { + public: + 
using wrapper_type = GetCUDABackendWrapperType; + + MaxUnpoolingOp(csl::Stream stream_, const MaxUnpoolingConfiguration& config) + : stream(std::move(stream_)) + { + window_size = config.window_size; + + const auto pooling_order = window_size.size(); + CV_Assert(pooling_order >= 1); + + strides = config.strides; + padding_left = config.pads_begin; + CV_Assert(strides.size() == pooling_order); + CV_Assert(padding_left.size() == pooling_order); + + if (pooling_order != 2 && pooling_order != 3) + CV_Error(Error::StsNotImplemented, "Only 2D/3D max-unpooling are supported."); + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + /* sometimes a third input is passed to provide the output shape; we won't need it */ + CV_Assert(inputs.size() == 2 || inputs.size() == 3); + CV_Assert(outputs.size() >= 1); + + for(int i = 0; i < outputs.size(); i++) + { + auto input_wrapper = inputs[0].dynamicCast(); + auto input_data = input_wrapper->getView(); + + auto indices_wrapper = inputs[1].dynamicCast(); + auto input_indices = indices_wrapper->getView(); + + auto output_wrapper = outputs[i].dynamicCast(); + auto output_data = output_wrapper->getSpan(); + + kernels::max_unpooling(stream, output_data, input_data, input_indices, window_size, strides, padding_left); + } + } + + private: + csl::Stream stream; + + std::vector window_size, strides, padding_left; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MAX_UNPOOLING_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/normalize_bbox.hpp b/modules/dnn/src/cuda4dnn/primitives/normalize_bbox.hpp new file mode 100644 index 0000000..a61ab99 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/normalize_bbox.hpp @@ -0,0 +1,142 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
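// Illustrative sketch (not part of this patch): max_pooling_with_indices records, for each
// pooled output element, the location of the winning input element; max_unpooling then
// scatters values back to those locations and leaves everything else at zero. Conceptually,
// per 2D plane with row-major indices (hypothetical reference helper; the actual index
// encoding is defined by the kernels):
#include <cstddef>
#include <vector>

void max_unpool_plane_ref(std::vector<float>& unpooled /* zero-initialised, H*W */,
                          const std::vector<float>& pooled,
                          const std::vector<std::size_t>& indices /* same size as pooled */)
{
    for (std::size_t i = 0; i < pooled.size(); i++)
        unpooled[indices[i]] = pooled[i]; // indices[i] is the argmax position within the plane
}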
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_NORMALIZE_BBOX_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_NORMALIZE_BBOX_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" +#include "../csl/span.hpp" +#include "../csl/tensor.hpp" +#include "../csl/workspace.hpp" + +#include "../kernels/scale_shift.hpp" +#include "../kernels/normalize.hpp" + +#include + +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + template + struct NormalizeConfiguration { + std::vector input_shape; + + /* axis range across which values are normalized + * + * [0, axis_start) = outer range + * [axis_start, axis_end) = mid range + * [axis_end + 1, -1) = inner range + * + * for each location in the outer and inner range, all the values in the mid range are + * normalized together + */ + std::size_t axis_start, axis_end; + + /* 1 for L1 norm, 2 for L2 norm */ + std::size_t norm; + + /* epsilon to use to avoid divison by zero */ + T eps; + }; + + template + class NormalizeOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + template + NormalizeOp(csl::Stream stream_, const Mat& weights, const NormalizeConfiguration& config) + : stream(std::move(stream_)), weight{ 1.0 } + { + norm_order = config.norm; + epsilon = config.eps; + axis_start = config.axis_start; + axis_end = config.axis_end; + + if (!weights.empty()) + { + if (weights.total() == 1) + { + CV_Assert(weights.type() == CV_32F); + weight = weights.at(0, 0); + } + else + { + weightsTensor = csl::makeTensorHeader(weights); + csl::copyMatToTensor(weights, weightsTensor, stream); + } + } + + std::size_t outer_size = 1; + for (int i = 0; i < axis_start; i++) + outer_size *= config.input_shape[i]; + + std::size_t inner_size = 1; + for (int i = axis_end; i < config.input_shape.size(); i++) + inner_size *= config.input_shape[i]; + + csl::WorkspaceBuilder builder; + builder.require(outer_size * inner_size); + scratch_mem_in_bytes = builder.required_workspace_size(); + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + CV_Assert(inputs.size() == 1 && outputs.size() == 1); + + auto input_wrapper = inputs[0].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[0].dynamicCast(); + auto output = output_wrapper->getSpan(); + + std::size_t outer_size = input.size_range(0, axis_start); + std::size_t mid_size = input.size_range(axis_start, axis_end); + std::size_t inner_size = input.size_range(axis_end, input.rank()); + + auto ws_allocator = csl::WorkspaceAllocator(workspace); + auto scratch = ws_allocator.get_span(); + kernels::normalize(stream, output, input, outer_size, mid_size, inner_size, norm_order, epsilon, scratch); + + /* there might be a single weight in which case `weight` will be not equal to 1.0 + * or there might be several weights + * or we don't have to scale + */ + if (weight != 1.0) + { + kernels::scale1(stream, output, input, weight); + } + else if (!weightsTensor.empty()) + { + CV_Assert(weightsTensor.size() != 1); /* constructor should have set up to use `weight` */ + CV_Assert(weightsTensor.size() == mid_size); + kernels::scaleN(stream, output, input, inner_size, weightsTensor); + } + } + + std::size_t get_workspace_memory_in_bytes() const noexcept override { return scratch_mem_in_bytes; } + + private: + csl::Stream stream; + csl::Tensor weightsTensor; + T weight; /* if there is only one weight, we use this */ + + T epsilon; + std::size_t 
norm_order; + std::size_t axis_start, axis_end; + + std::size_t scratch_mem_in_bytes; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_NORMALIZE_BBOX_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/padding.hpp b/modules/dnn/src/cuda4dnn/primitives/padding.hpp new file mode 100644 index 0000000..9795378 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/padding.hpp @@ -0,0 +1,113 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_PADDING_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_PADDING_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" +#include "../csl/tensor.hpp" + +#include "../kernels/fill.hpp" +#include "../kernels/concat.hpp" +#include "../kernels/padding.hpp" + +#include + +#include +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + enum class PaddingType { + CONSTANT, + REFLECTION101 + }; + + template + class PaddingOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + /* `ranges` is indexed by axis and contains the range in the output where the input is copied to */ + PaddingOp(csl::Stream stream_, PaddingType type_, T value_, std::vector ranges) + : stream(std::move(stream_)), type{ type_ }, value{ value_ }, dstRanges(std::move(ranges)) + { + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + CV_Assert(inputs.size() == 1 && outputs.size() == 1); + + auto input_wrapper = inputs[0].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[0].dynamicCast(); + auto output = output_wrapper->getSpan(); + + auto effective_rank = get_effective_rank(input); + CV_Assert(get_effective_rank(input) == get_effective_rank(output)); + + /* suppose we require padding for the first spatial axis (H in NCHW or D in NCDHW) + * + * there could be a case where the batch axis, channel axis, and the first spatial axis are all one + * this would result in effective rank being less than the number of axes requiring padding + */ + effective_rank = std::max(effective_rank, dstRanges.size()); + + for (int i = effective_rank - dstRanges.size(); i < effective_rank; i++) + { + if (dstRanges[i] == Range::all()) + CV_Assert(input.get_axis_size(i) == output.get_axis_size(i)); + else + CV_Assert(input.get_axis_size(i) == dstRanges[i].size()); + } + + if (type == PaddingType::CONSTANT) + { + kernels::fill(stream, output, value); + + std::vector offsets(effective_rank, 0); + for (int i = 0; i < dstRanges.size(); i++) + { + const auto delta = effective_rank - dstRanges.size(); + if (dstRanges[i] != Range::all()) + offsets[delta + i] = dstRanges[i].start; + } + + kernels::concat_with_offsets(stream, output, input, offsets); + } + else if (type == PaddingType::REFLECTION101) + { + std::vector> ranges(effective_rank); + for (int i = 0; i < effective_rank; i++) + { + const auto delta = effective_rank - dstRanges.size(); + if (i < delta || dstRanges[i - delta] == Range::all()) + ranges[i] = { 0, input.get_axis_size(i) }; + else + ranges[i] = { dstRanges[i].start, dstRanges[i].end }; + } + + kernels::copy_with_reflection101(stream, output, input, ranges); + } + } + + private: + csl::Stream stream; + PaddingType type; + T value; + + std::vector dstRanges; + }; + +}}} 
/* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_PADDING_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/permute.hpp b/modules/dnn/src/cuda4dnn/primitives/permute.hpp new file mode 100644 index 0000000..9d49241 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/permute.hpp @@ -0,0 +1,70 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_PERMUTE_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_PERMUTE_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" +#include "../csl/tensor_ops.hpp" + +#include "../kernels/permute.hpp" + +#include + +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + template + class PermuteOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + PermuteOp(csl::Stream stream_, std::vector order_) + : stream(std::move(stream_)), order(std::move(order_)) { } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + for (int i = 0; i < inputs.size(); i++) + { + auto input_wrapper = inputs[i].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[i].dynamicCast(); + auto output = output_wrapper->getSpan(); + + auto needsPermute = [&] { + for (int i = 0; i < order.size(); i++) + if (order[i] != i) + return true; + return false; + }(); + + if (needsPermute) + { + kernels::permute(stream, output, input, order); + } + else + { + if (input.get() != output.get()) + csl::tensor_ops::copy(stream, output, input); + } + } + } + + private: + csl::Stream stream; + std::vector order; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_PERMUTE_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/pooling.hpp b/modules/dnn/src/cuda4dnn/primitives/pooling.hpp new file mode 100644 index 0000000..8b8cf37 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/pooling.hpp @@ -0,0 +1,258 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
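// Illustrative sketch (not part of this patch): `order` lists, for each output axis, the
// input axis it is taken from; e.g. order = {0, 2, 3, 1} turns NCHW into NHWC, while an
// identity order skips the kernel and falls back to a plain copy. The resulting shape:
#include <cstddef>
#include <vector>

std::vector<std::size_t> permute_shape(const std::vector<std::size_t>& input_shape,
                                       const std::vector<std::size_t>& order)
{
    std::vector<std::size_t> output_shape(order.size());
    for (std::size_t i = 0; i < order.size(); i++)
        output_shape[i] = input_shape[order[i]]; // output axis i takes input axis order[i]
    return output_shape;
}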
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_POOLING_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_POOLING_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/cudnn.hpp" +#include "../csl/tensor.hpp" +#include "../csl/tensor_ops.hpp" + +#include + +#include +#include +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + struct PoolingConfiguration { + enum class PoolingMode { + MAX, + AVERAGE_INCLUDE_PADDING, /* include padding while calculating average */ + AVERAGE_EXCLUDE_PADDING /* exclude padding while calculating average */ + }; + + PoolingMode poolMode; + + /* the size of the following vectors must be equal to the window size */ + std::vector window_size; + std::vector strides; + + enum class PaddingMode { + MANUAL, /* uses explicit padding values provided in `pads_begin` and `pads_end` */ + VALID, /* no padding is added */ + SAME /* TensorFlow logic is used for same padding */ + }; + + PaddingMode padMode; + + /* explicit paddings are used if and only if padMode is set to manual */ + std::vector pads_begin, pads_end; + + /* the output shape is calculated using the following formula: + * output_dim = func[(input_dim + padding_left + padding_right - kernel_dim)/stride] + 1 + * + * rounding mode decides what is used as `func` + */ + enum class RoundingMode { + CEIL, /* uses ceil */ + FLOOR + }; + + RoundingMode roundMode; + + /* full shape inclusive of channel and batch axis */ + std::vector input_shape; + }; + + template + class PoolingOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + PoolingOp(csl::cudnn::Handle handle, const PoolingConfiguration& config) + : cudnnHandle(std::move(handle)) + { + const auto& window_size = config.window_size; + + const auto pooling_order = window_size.size(); + CV_Assert(pooling_order >= 1); + + const auto& strides = config.strides; + CV_Assert(pooling_order == strides.size()); + + const auto& input_shape = config.input_shape; + CV_Assert(input_shape.size() == pooling_order + 2); + + if (pooling_order > 3) + CV_Error(Error::StsNotImplemented, "Only 1D/2D/3D pooling are supported."); + + const auto rank = input_shape.size(); + + /* left and right are misleading as the padding is applicable for any number of dimensions + * but we use those identifiers to avoid confusion with `pads_begin` and `pads_end` + * + * `common_padding` contains the amount of padding that has to be added to both sides + * `padding_left` and `padding_right` contains the amount of padding that needs to be added + * to a particular side in addition to the common padding + */ + std::vector common_padding(rank, 0); + std::vector padding_left(rank, 0), padding_right(rank, 0); + if (config.padMode == PoolingConfiguration::PaddingMode::MANUAL) + { + const auto& pads_begin = config.pads_begin; + const auto& pads_end = config.pads_end; + + CV_Assert(pooling_order == pads_begin.size()); + CV_Assert(pooling_order == pads_end.size()); + + /* cuDNN rounds down by default; hence, if ceilMode is false, we do nothing + * otherwise, we add extra padding towards the end so that the convolution arithmetic yeilds + * the correct output size without having to deal with fancy fractional sizes + */ + auto pads_end_modified = pads_end; + if (config.roundMode == PoolingConfiguration::RoundingMode::CEIL) + { + for (int i = 0; i < window_size.size(); i++) { + auto rem = (input_shape[i + 2] + pads_begin[i] + pads_end[i] - window_size[i]) % strides[i]; + if (rem) + pads_end_modified[i] += strides[i] - rem; + } + } + + 
for (int i = 2; i < common_padding.size(); i++) + { + common_padding[i] = std::min(pads_begin[i - 2], pads_end_modified[i - 2]); + padding_left[i] = pads_begin[i - 2] - common_padding[i]; + padding_right[i] = pads_end_modified[i - 2] - common_padding[i]; + } + } + else if (config.padMode == PoolingConfiguration::PaddingMode::VALID) + { + /* nothing to do as the paddings are already preset to zero */ + } + else if (config.padMode == PoolingConfiguration::PaddingMode::SAME) + { + /* TensorFlow Logic: + * total_padding[i] = (o[i] - 1) * s[i] + effective_k[i] - i[i] + * + * if total padding is odd, the extra is added towards the end + */ + for (int i = 2; i < rank; i++) + { + const auto j = i - 2; /* filter index */ + const auto output_dim = (input_shape[i] - 1 + strides[j]) / strides[j]; + const auto required_total_padding = + std::max(0, (output_dim - 1) * strides[j] + window_size[j] - input_shape[i]); + + common_padding[i] = required_total_padding / 2; + padding_left[i] = 0; + padding_right[i] = required_total_padding % 2; + } + } + + /* in some scenarios, the extra padding at the end may not change the output at all */ + for (int i = 2; i < rank; i++) { + const auto j = i - 2; /* filter idx */ + const auto total_padding = common_padding[i] * 2 + padding_left[i] + padding_right[i]; + std::int64_t rem = (input_shape[i] + total_padding - window_size[j]) % strides[j]; + + /* the output shape doesn't change if we decrease the total padding by at most `rem` + * provided that we decrease from the right + */ + if (rem && padding_right[i] > 0) + padding_right[i] = std::max(0, padding_right[i] - rem); + } + + auto is_not_zero = [](std::size_t i) { return i != 0; }; + if (std::any_of(std::begin(padding_left), std::end(padding_left), is_not_zero) || + std::any_of(std::begin(padding_right), std::end(padding_right), is_not_zero)) + { + /* csl::Pooling does not fully support asymmetric padding; hence, we deal with asymmetric padding by + * copying the input to a bigger tensor and padding the ends manually + * + * But we first try to avoid the transformation using cuDNN's flexibility. cuDNN can accept a smaller or + * a bigger output shape. This effectively allows us to have arbitary padding at the right. 
+ */ + if (std::any_of(std::begin(padding_left), std::end(padding_left), is_not_zero)) + { + /* there is padding on the left and we are forced to transform */ + auto transformed_input_shape = input_shape; + for (int i = 0; i < rank; i++) + transformed_input_shape[i] += padding_left[i] + padding_right[i]; + + transformedInput.resize(std::begin(transformed_input_shape), std::end(transformed_input_shape)); + inputTransformer = csl::TensorTransform(cudnnHandle, padding_left, padding_right); + } + } + + typename csl::Pooling::params_type params; + if (transformedInput.empty()) + { + /* no transform => use original input shape */ + params.input_shape.assign(std::begin(input_shape), std::end(input_shape)); + } + else + { + /* the pooling operation will be seeing the transformed input */ + auto transformed_input_shape = transformedInput.shape_as_vector(); + params.input_shape.assign(std::begin(transformed_input_shape), std::end(transformed_input_shape)); + } + + auto output_shape = input_shape; + for (int i = 2; i < rank; i++) + { + auto total_padding = common_padding[i] * 2 + padding_left[i] + padding_right[i]; + output_shape[i] = (params.input_shape[i] + total_padding - window_size[i - 2]) / strides[i - 2] + 1; + } + + params.output_shape.assign(std::begin(output_shape), std::end(output_shape)); + params.window_size = window_size; + params.padding.assign(std::begin(common_padding) + 2, std::end(common_padding)); + params.stride = strides; + + if (config.poolMode == PoolingConfiguration::PoolingMode::MAX) + { + params.type = csl::Pooling::PoolingType::MAX; + } + else if (config.poolMode == PoolingConfiguration::PoolingMode::AVERAGE_INCLUDE_PADDING) + { + params.type = csl::Pooling::PoolingType::AVERAGE_INCLUDE_PADDING; + } + else if (config.poolMode == PoolingConfiguration::PoolingMode::AVERAGE_EXCLUDE_PADDING) + { + params.type = csl::Pooling::PoolingType::AVERAGE_EXCLUDE_PADDING; + } + + pooler = csl::Pooling(cudnnHandle, params); + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + CV_Assert(inputs.size() == 1 && outputs.size() == 1); + + auto input_wrapper = inputs[0].dynamicCast(); + auto input = input_wrapper->getView(); + + if (!transformedInput.empty()) + { + inputTransformer.transform(input, transformedInput); + input = csl::TensorView(transformedInput); + } + + auto output_wrapper = outputs[0].dynamicCast(); + auto output = output_wrapper->getSpan(); + + pooler.pool(input, output); + } + + private: + csl::cudnn::Handle cudnnHandle; + csl::Pooling pooler; + + csl::Tensor transformedInput; + csl::TensorTransform inputTransformer; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_POOLING_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/prior_box.hpp b/modules/dnn/src/cuda4dnn/primitives/prior_box.hpp new file mode 100644 index 0000000..5f48a8c --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/prior_box.hpp @@ -0,0 +1,136 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
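As a quick cross-check of the output-shape arithmetic used by PoolingOp above, a stand-alone reference (a sketch for illustration; the real computation happens in the constructor):

#include <cstddef>

/* output_dim = floor((input_dim + pad_begin + pad_end - window) / stride) + 1
 * e.g. pool_output_dim(10, 3, 2, 0, 1) == 5: a 3-wide window striding by 2 over a
 * 10-wide input with one trailing padding cell
 */
static std::size_t pool_output_dim(std::size_t input_dim, std::size_t window,
                                   std::size_t stride, std::size_t pad_begin,
                                   std::size_t pad_end)
{
    return (input_dim + pad_begin + pad_end - window) / stride + 1;
}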
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_PRIOR_BOX_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_PRIOR_BOX_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" +#include "../csl/span.hpp" +#include "../csl/tensor.hpp" + +#include "../kernels/prior_box.hpp" + +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + struct PriorBoxConfiguration { + std::size_t feature_map_width, feature_map_height; + std::size_t image_width, image_height; + + /* parameters for prior boxes for each feature point */ + std::vector box_widths, box_heights; + std::vector offsets_x, offsets_y; + float stepX, stepY; + + std::vector variance; + + /* number of priors per feature point */ + std::size_t num_priors; + + /* clamps the box coordinates to [0, 1] range */ + bool clip; + + /* normalizes the box coordinates using the image dimensions */ + bool normalize; + }; + + template + class PriorBoxOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + PriorBoxOp(csl::Stream stream_, const PriorBoxConfiguration& config) + : stream(std::move(stream_)) + { + feature_map_width = config.feature_map_width; + feature_map_height = config.feature_map_height; + + image_width = config.image_width; + image_height = config.image_height; + + const auto& box_widths = config.box_widths; + const auto& box_heights = config.box_heights; + CV_Assert(box_widths.size() == box_heights.size()); + + box_size = box_widths.size(); + + const auto& offsets_x = config.offsets_x; + const auto& offsets_y = config.offsets_y; + CV_Assert(offsets_x.size() == offsets_y.size()); + + offset_size = offsets_x.size(); + + /* for better memory utilization and preassumably better cache performance, we merge + * the four vectors and put them in a single tensor + */ + auto total = box_widths.size() * 2 + offsets_x.size() * 2; + std::vector merged_params; + merged_params.insert(std::end(merged_params), std::begin(box_widths), std::end(box_widths)); + merged_params.insert(std::end(merged_params), std::begin(box_heights), std::end(box_heights)); + merged_params.insert(std::end(merged_params), std::begin(offsets_x), std::end(offsets_x)); + merged_params.insert(std::end(merged_params), std::begin(offsets_y), std::end(offsets_y)); + CV_Assert(merged_params.size() == total); + + paramsTensor.resize(total); + csl::memcpy(paramsTensor.get(), merged_params.data(), total, stream); /* synchronous copy */ + + const auto& variance_ = config.variance; + variance.assign(std::begin(variance_), std::end(variance_)); + + num_priors = config.num_priors; + stepX = config.stepX; + stepY = config.stepY; + clip = config.clip; + normalize = config.normalize; + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + CV_Assert(inputs.size() == 2); /* we don't need the inputs but we are given */ + CV_Assert(outputs.size() == 1); + + auto output_wrapper = outputs[0].dynamicCast(); + auto output = output_wrapper->getSpan(); + + /* we had stored all the parameters in a single tensor; now we create appropriate views + * for each of the parameter arrays from the single tensor + */ + auto boxWidths = csl::View(paramsTensor.get(), box_size); + auto boxHeights = csl::View(paramsTensor.get() + box_size, box_size); + auto offsetsX = csl::View(paramsTensor.get() + 2 * box_size, offset_size); + auto offsetsY = csl::View(paramsTensor.get() + 2 * box_size + offset_size, offset_size); + + kernels::generate_prior_boxes(stream, output, + 
boxWidths, boxHeights, offsetsX, offsetsY, stepX, stepY, + variance, num_priors, feature_map_width, feature_map_height, image_width, image_height, normalize, clip); + } + + private: + csl::Stream stream; + csl::Tensor paramsTensor; /* widths, heights, offsetsX, offsetsY */ + + std::size_t feature_map_width, feature_map_height; + std::size_t image_width, image_height; + + std::size_t box_size, offset_size; + float stepX, stepY; + + std::vector variance; + + std::size_t num_priors; + bool clip, normalize; + }; + + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_PRIOR_BOX_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/region.hpp b/modules/dnn/src/cuda4dnn/primitives/region.hpp new file mode 100644 index 0000000..775dd0f --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/region.hpp @@ -0,0 +1,181 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_REGION_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_REGION_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" +#include "../csl/cudnn.hpp" +#include "../csl/tensor_ops.hpp" + +#include "../kernels/region.hpp" + +#include "../../nms.inl.hpp" + +#include + +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + enum class SquashMethod { + SOFTMAX, + SIGMOID + }; + + template + struct RegionConfiguration { + /* The image is divided into (H, W) cells. + * + * Each cell is interested in exactly one object and predicts `boxes_per_cell` bounding boxes + * for that object. + * + * Each bounding box contains: + * - 4 box coordinates + * - objectness confidence score + * - `classes` number of class scores + * + * The object score is reduced to a probability using sigmoid and the class scores are reduced to + * probabilities by either applying sigmoid or softmax (which is a configuration option). 
+ * + * object_prob = sigmoid(object_score) + * conditional_class_prob = sigmoid, softmax across all classes + * + * actual class probability = conditional_class_prob * object_prob + */ + + /* method for reducing class scores to probabilities */ + SquashMethod squash_method; + + std::size_t classes, boxes_per_cell; + + std::size_t width_norm, height_norm; + + /* prob cutoffs below which the prediction is nulled */ + T object_prob_cutoff; + T class_prob_cutoff; + + T nms_iou_threshold; + }; + + template + class RegionOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + template + RegionOp(csl::Stream stream_, const cv::Mat& bias, const RegionConfiguration& config) + : stream(std::move(stream_)) + { + biasTensor = csl::makeTensorHeader(bias); + csl::copyMatToTensor(bias, biasTensor, stream); + + classes = config.classes; + boxes_per_cell = config.boxes_per_cell; + + width_norm = config.width_norm; + height_norm = config.height_norm; + + squash_type = config.squash_method; + + object_prob_cutoff = config.object_prob_cutoff; + class_prob_cutoff = config.class_prob_cutoff; + + nms_iou_threshold = config.nms_iou_threshold; + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + CV_Assert(outputs.size() == 1); + + auto input_wrapper = inputs[0].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[0].dynamicCast(); + auto output = output_wrapper->getSpan(); + + csl::memcpy(output.get(), input.get(), output.size(), stream); + + auto rows = input.get_axis_size(1); + auto cols = input.get_axis_size(2); + + auto cell_box_size = classes + 4 + 1; + + /* we squash class scores into probabilities using softmax or sigmoid */ + if (squash_type == SquashMethod::SOFTMAX) + kernels::softmax_strided(stream, output, input, classes, cell_box_size, 5); + else if (squash_type == SquashMethod::SIGMOID) + kernels::sigmoid_strided(stream, output, input, classes, cell_box_size, 5); + + kernels::region_finalize(stream, output, input, biasTensor, object_prob_cutoff, class_prob_cutoff, + height_norm, width_norm, rows, cols, boxes_per_cell, cell_box_size, classes); + + if (nms_iou_threshold > 0) { + auto output_mat = output_wrapper->getMutableHostMat(); + CV_Assert(output_mat.type() == CV_32F); + for (int i = 0; i < input.get_axis_size(0); i++) { + auto sample_size = rows * cols * boxes_per_cell * cell_box_size; + do_nms_sort(reinterpret_cast(output_mat.data) + i * sample_size, rows * cols * boxes_per_cell, class_prob_cutoff, nms_iou_threshold); + } + } + } + + private: + void do_nms_sort(float *detections, int total, float score_thresh, float nms_thresh) + { + std::vector boxes(total); + std::vector scores(total); + + for (int i = 0; i < total; ++i) + { + Rect2d &b = boxes[i]; + int box_index = i * (classes + 4 + 1); + b.width = detections[box_index + 2]; + b.height = detections[box_index + 3]; + b.x = detections[box_index + 0] - b.width / 2; + b.y = detections[box_index + 1] - b.height / 2; + } + + std::vector indices; + for (int k = 0; k < classes; ++k) + { + for (int i = 0; i < total; ++i) + { + int box_index = i * (classes + 4 + 1); + int class_index = box_index + 5; + scores[i] = detections[class_index + k]; + detections[class_index + k] = 0; + } + NMSBoxes(boxes, scores, score_thresh, nms_thresh, indices); + for (int i = 0, n = indices.size(); i < n; ++i) + { + int box_index = indices[i] * (classes + 4 + 1); + int class_index = box_index + 5; + 
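                    /* the class scores were zeroed out in the scoring loop above; only boxes
                     * that survive NMS get their score written back, so suppressed detections
                     * end up with a zero probability for this class
                     */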
detections[class_index + k] = scores[indices[i]]; + } + } + } + + private: + csl::Stream stream; + + csl::Tensor biasTensor; + std::size_t classes, boxes_per_cell; + std::size_t width_norm, height_norm; + SquashMethod squash_type; + + T object_prob_cutoff, class_prob_cutoff; + T nms_iou_threshold; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_REGION_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/reorg.hpp b/modules/dnn/src/cuda4dnn/primitives/reorg.hpp new file mode 100644 index 0000000..45185af --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/reorg.hpp @@ -0,0 +1,75 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_REORG_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_REORG_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" +#include "../kernels/permute.hpp" + +#include + +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + template + class ReorgOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + ReorgOp(csl::Stream stream_, std::size_t stride_) + : stream(std::move(stream_)), stride{ stride_ } { } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + CV_Assert(inputs.size() == 1 && outputs.size() == 1); + + auto input_wrapper = inputs[0].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[0].dynamicCast(); + auto output = output_wrapper->getSpan(); + + const std::size_t permute_input_shape[] = { + input.get_axis_size(0), + input.get_axis_size(1) * input.get_axis_size(2) / (stride * stride), + stride, + input.get_axis_size(3), + stride + }; + + constexpr std::size_t order[] = { 0, 2, 4, 1, 3 }; + + const std::size_t permute_output_shape[] = { + permute_input_shape[order[0]], + permute_input_shape[order[1]], + permute_input_shape[order[2]], + permute_input_shape[order[3]], + permute_input_shape[order[4]] + }; + + input.unsqueeze(); + input.reshape(std::begin(permute_input_shape), std::end(permute_input_shape)); + + output.unsqueeze(); + output.reshape(std::begin(permute_output_shape), std::end(permute_output_shape)); + + kernels::permute(stream, output, input, { std::begin(order), std::end(order) }); + } + + private: + csl::Stream stream; + std::size_t stride; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_REORG_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/reshape.hpp b/modules/dnn/src/cuda4dnn/primitives/reshape.hpp new file mode 100644 index 0000000..2cf1d67 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/reshape.hpp @@ -0,0 +1,61 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
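To make the ReorgOp reshuffle above concrete, a worked shape trace (illustrative numbers, stride = 2):

/* input NCHW = {1, 64, 26, 26}
 * permute_input_shape  = {1, 64 * 26 / (2 * 2), 2, 26, 2} = {1, 416, 2, 26, 2}
 * order                = {0, 2, 4, 1, 3}
 * permute_output_shape = {1, 2, 2, 416, 26}
 * element counts match (43264 both ways) and the caller reads the result back as
 * the usual reorg output blob of shape {1, 256, 13, 13}
 */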
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_RESHAPE_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_RESHAPE_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" +#include "../csl/tensor.hpp" +#include "../csl/tensor_ops.hpp" + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + template + class ReshapeOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + ReshapeOp(csl::Stream stream_) : stream(std::move(stream_)) { } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + /* sometimes the output shape is passed as extra inputs; hence, >= instead of == */ + CV_Assert(inputs.size() >= outputs.size()); + + for (int i = 0; i < outputs.size(); i++) + { + auto input_wrapper = inputs[i].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[i].dynamicCast(); + auto output = output_wrapper->getSpan(); + + if (input.get() != output.get()) + { + while (input.rank() < output.rank()) + input.unsqueeze(); + + while (output.rank() < input.rank()) + output.unsqueeze(); + + input.reshape_as(output); + csl::tensor_ops::copy(stream, output, input); + } + } + } + + private: + csl::Stream stream; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_RESHAPE_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/resize.hpp b/modules/dnn/src/cuda4dnn/primitives/resize.hpp new file mode 100644 index 0000000..3caf58d --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/resize.hpp @@ -0,0 +1,60 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
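A small illustration of the rank alignment performed by ReshapeOp above (illustrative shapes):

/* input view   {2, 3, 4}        (rank 3, 24 elements)
 * output span  {2, 12, 1, 1}    (rank 4, 24 elements)
 * input.unsqueeze() once so both tensors have rank 4, then
 * input.reshape_as(output) turns the view into {2, 12, 1, 1}, and
 * tensor_ops::copy moves the 24 elements; when input and output alias the same
 * memory the whole branch is skipped
 */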
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_RESIZE_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_RESIZE_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" + +#include "../kernels/resize.hpp" + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + enum class InterpolationType { + NEAREST_NEIGHBOUR, + BILINEAR + }; + + template + class ResizeOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + ResizeOp(csl::Stream stream_, InterpolationType type_, float scaleHeight_, float scaleWidth_) + : stream(std::move(stream_)), type{ type_ }, scaleHeight{ scaleHeight_ }, scaleWidth{ scaleWidth_ } + { + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + CV_Assert(inputs.size() == 1 && outputs.size() == 1); + + auto input_wrapper = inputs[0].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[0].dynamicCast(); + auto output = output_wrapper->getSpan(); + + if (type == InterpolationType::NEAREST_NEIGHBOUR) + kernels::resize_nn(stream, output, input); + else if (type == InterpolationType::BILINEAR) + kernels::resize_bilinear(stream, output, input, scaleHeight, scaleWidth); + } + + private: + csl::Stream stream; + InterpolationType type; + float scaleHeight, scaleWidth; /* for bilinear interpolation */ + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_RESIZE_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/scale_shift.hpp b/modules/dnn/src/cuda4dnn/primitives/scale_shift.hpp new file mode 100644 index 0000000..399cce0 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/scale_shift.hpp @@ -0,0 +1,110 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SCALE_SHIFT_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SCALE_SHIFT_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" +#include "../csl/tensor.hpp" + +#include "../kernels/scale_shift.hpp" + +#include + +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + template + class ScaleShiftOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + ScaleShiftOp(csl::Stream stream_, std::size_t axis, const cv::Mat& weights, const cv::Mat& bias) + : stream(std::move(stream_)), axis{ axis } + { + if (!weights.empty()) + { + weightsTensor = csl::makeTensorHeader(weights); + csl::copyMatToTensor(weights, weightsTensor, stream); + } + + if (!bias.empty()) + { + biasTensor = csl::makeTensorHeader(bias); + csl::copyMatToTensor(bias, biasTensor, stream); + } + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + CV_Assert(outputs.size() == 1); + + auto input_wrapper = inputs[0].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[0].dynamicCast(); + auto output = output_wrapper->getSpan(); + + csl::TensorView weights; + if (weightsTensor.empty() && biasTensor.empty()) + { + CV_Assert(inputs.size() == 2); + + /* no explicit scale/shift values provided; use the second input as weights */ + auto wrapper = inputs[1].dynamicCast(); + weights = wrapper->getView(); + } + else if (!weightsTensor.empty()) + { + weights = csl::TensorSpan(weightsTensor); + } + + csl::TensorView bias; + if (!biasTensor.empty()) + bias = csl::TensorSpan(biasTensor); + + const auto numParams = !weights.empty() ? weights.size() : bias.size(); + CV_Assert(numParams != 0); + if (!weightsTensor.empty() && !biasTensor.empty()) + { + CV_CheckEQ(weights.size(), bias.size(), "weights and bias size are not equal"); + } + + /* the weights/bias might require broadcasting to scale/shift */ + const int end_axis = [&] { + for (int endAxis = axis + 1; endAxis <= input.rank(); endAxis++) + { + std::size_t size = input.size_range(axis, endAxis); + if (size == numParams) + return endAxis; + } + CV_Assert(0 /* invalid weights matrix */); + }(); + + std::size_t inner_size = input.size_range(end_axis, input.rank()); + + if (!weights.empty() && !bias.empty()) + kernels::scaleN_with_biasN(stream, output, input, inner_size, weights, bias); + else if (!weights.empty()) + kernels::scaleN(stream, output, input, inner_size, weights); + else + kernels::biasN(stream, output, input, inner_size, bias); + } + + private: + csl::Stream stream; + csl::Tensor weightsTensor, biasTensor; + std::size_t axis; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SCALE_SHIFT_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/shuffle_channel.hpp b/modules/dnn/src/cuda4dnn/primitives/shuffle_channel.hpp new file mode 100644 index 0000000..6118886 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/shuffle_channel.hpp @@ -0,0 +1,79 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
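The end_axis search in ScaleShiftOp above is easiest to follow with numbers (illustrative):

/* input shape {8, 32, 55, 55}, axis = 1, weights with 32 elements (per-channel scale)
 * endAxis = 2: size_range(1, 2) = 32 == numParams  =>  end_axis = 2
 * inner_size  = size_range(2, 4) = 55 * 55 = 3025
 * kernels::scaleN therefore applies weights[c] to each contiguous block of 3025
 * elements, i.e. one scale factor per channel
 */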
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SHUFFLE_CHANNEL_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SHUFFLE_CHANNEL_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" +#include "../csl/tensor_ops.hpp" + +#include "../kernels/permute.hpp" + +#include + +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + template + class ShuffleChannelOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + ShuffleChannelOp(csl::Stream stream_, std::size_t group_) + : stream(std::move(stream_)), group{ group_ } { } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + CV_Assert(inputs.size() == 1 && outputs.size() == 1); + + auto input_wrapper = inputs[0].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[0].dynamicCast(); + auto output = output_wrapper->getSpan(); + + if (group == 1) { + /* permute is redundant; check else branch to know why */ + if (input.get() != output.get()) { + input.reshape_as(output); + csl::tensor_ops::copy(stream, output, input); + } + } else { + const std::size_t permute_input_shape[] = { + input.get_axis_size(0), + group, + input.get_axis_size(1) / group, + input.get_axis_size(2) * input.get_axis_size(3) + }; + + constexpr std::size_t order[] = { 0, 2, 1, 3 }; + + const std::size_t permute_output_shape[] = { + permute_input_shape[order[0]], + permute_input_shape[order[1]], + permute_input_shape[order[2]], + permute_input_shape[order[3]], + }; + + input.reshape(std::begin(permute_input_shape), std::end(permute_input_shape)); + output.reshape(std::begin(permute_output_shape), std::end(permute_output_shape)); + kernels::permute(stream, output, input, { std::begin(order), std::end(order) }); + } + } + + private: + csl::Stream stream; + std::size_t group; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SHUFFLE_CHANNEL_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/slice.hpp b/modules/dnn/src/cuda4dnn/primitives/slice.hpp new file mode 100644 index 0000000..e5847d7 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/slice.hpp @@ -0,0 +1,62 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
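For reference, the shuffle in ShuffleChannelOp above is the usual reshape-transpose-reshape trick; with illustrative numbers:

/* input NCHW = {1, 6, 4, 4}, group = 2
 * permute_input_shape  = {1, 2, 3, 16}   (N, group, C / group, H * W)
 * order                = {0, 2, 1, 3}
 * permute_output_shape = {1, 3, 2, 16}
 * read back as {1, 6, 4, 4}, the channels come out in the order 0, 3, 1, 4, 2, 5
 */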
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SLICE_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SLICE_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" + +#include "../kernels/slice.hpp" + +#include + +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + template + class SliceOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + /* offsets is indexed by output number and each subvector is indexed by axis number */ + SliceOp(csl::Stream stream_, std::vector> offsets) + : stream(std::move(stream_)), offsets(std::move(offsets)) + { + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + /* sometimes the output shape is passed in the form of a second input tensor + * it's only required for initialization and not here + */ + CV_Assert(inputs.size() == 1 || inputs.size() == 2); + + auto input_wrapper = inputs[0].dynamicCast(); + auto input = input_wrapper->getView(); + + for (int i = 0; i < outputs.size(); ++i) + { + auto output_wrapper = outputs[i].dynamicCast(); + auto output = output_wrapper->getSpan(); + + kernels::slice(stream, output, input, offsets[i]); + } + } + + private: + csl::Stream stream; + std::vector> offsets; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SLICE_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/softmax.hpp b/modules/dnn/src/cuda4dnn/primitives/softmax.hpp new file mode 100644 index 0000000..fd19c5b --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/softmax.hpp @@ -0,0 +1,53 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SOFTMAX_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SOFTMAX_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/cudnn.hpp" +#include "../csl/tensor_ops.hpp" + +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + template + class SoftmaxOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + SoftmaxOp(csl::cudnn::Handle handle, std::size_t axis_, bool log_) + : cudnnHandle(std::move(handle)), channel_axis{ axis_ }, log{ log_ } + { + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + for (int i = 0; i < inputs.size(); i++) + { + auto input_wrapper = inputs[i].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[i].dynamicCast(); + auto output = output_wrapper->getSpan(); + + csl::tensor_ops::softmax(cudnnHandle, output, input, channel_axis, log); + } + } + + private: + csl::cudnn::Handle cudnnHandle; + std::size_t channel_axis; + bool log; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SOFTMAX_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/split.hpp b/modules/dnn/src/cuda4dnn/primitives/split.hpp new file mode 100644 index 0000000..4e9535b --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/split.hpp @@ -0,0 +1,54 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
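SoftmaxOp above hands the reduction to cuDNN along channel_axis; as a plain CPU reference for what one slice of that reduction computes (a sketch, not the kernel the backend runs):

#include <algorithm>
#include <cmath>
#include <vector>

/* softmax_reference({1, 2, 3}, false) ~= {0.090, 0.245, 0.665}
 * softmax_reference({1, 2, 3}, true)  ~= {-2.408, -1.408, -0.408}   (log-softmax)
 */
static std::vector<float> softmax_reference(std::vector<float> values, bool log_variant)
{
    float max_value = *std::max_element(values.begin(), values.end()); /* numerical stability */

    float sum = 0.f;
    for (float& v : values) { v = std::exp(v - max_value); sum += v; }

    for (float& v : values)
        v = log_variant ? std::log(v / sum) : v / sum;
    return values;
}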
+ +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SPLIT_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SPLIT_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" +#include "../csl/tensor_ops.hpp" + +#include + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + template + class SplitOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + SplitOp(csl::Stream stream_) + : stream(std::move(stream_)) + { + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + CV_Assert(inputs.size() == 1); + + auto input_wrapper = inputs[0].dynamicCast(); + auto input = input_wrapper->getView(); + + for (int i = 0; i < outputs.size(); i++) + { + auto output_wrapper = outputs[i].dynamicCast(); + auto output = output_wrapper->getSpan(); + + csl::tensor_ops::copy(stream, output, input); + } + } + + private: + csl::Stream stream; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SPLIT_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/transpose_convolution.hpp b/modules/dnn/src/cuda4dnn/primitives/transpose_convolution.hpp new file mode 100644 index 0000000..8e5cda3 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/transpose_convolution.hpp @@ -0,0 +1,230 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_TRANSPOSE_CONVOLUTION_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_TRANSPOSE_CONVOLUTION_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/cudnn.hpp" +#include "../csl/stream.hpp" +#include "../csl/tensor.hpp" +#include "../csl/tensor_ops.hpp" + +#include "../kernels/scale_shift.hpp" + +#include + +#include +#include +#include +#include +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + struct TransposeConvolutionConfiguration { + /* other than `input_shape` and `output_shape`, all the configuration values must be provided + * for the corresponding convolution operation (not transpose convolution) + */ + + /* the size of the following vectors must be equal to the kernel size */ + std::vector kernel_size; + std::vector dilations, strides; + + enum class PaddingMode { + MANUAL, /* uses explicit padding values provided in `pads_begin` and `pads_end` */ + VALID, /* no padding is added */ + SAME /* TensorFlow logic is used for same padding */ + }; + + /* explicit paddings are used if and only if padMode is set to manual */ + PaddingMode padMode; + std::vector pads_begin, pads_end; + + /* full shape inclusive of channel and batch axis */ + std::vector input_shape; + std::vector output_shape; + + /* group count for grouped convolution */ + std::size_t groups; + }; + + template + class TransposeConvolutionOp final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + TransposeConvolutionOp(csl::Stream stream_, csl::cudnn::Handle handle, const TransposeConvolutionConfiguration& config, const Mat& filters, const Mat& bias) + : stream(std::move(stream_)), cudnnHandle(std::move(handle)) + { + /* we make use of backward pass of convolution to perform forward pass of transpose convolution + * hence, we must setup configuration for the convolution operation and perform backward pass + */ + const auto& kernel_size = config.kernel_size; + const auto& dilations = config.dilations; 
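A shape sanity check for the convolution-backward formulation described above (illustrative numbers):

            /* a forward convolution with kernel 3, stride 2, padding 1 maps a 64-wide input
             * to floor((64 + 2 - 3) / 2) + 1 = 32; its backward-data pass therefore maps 32
             * back to 64, which is exactly the transpose convolution computed here with
             * config.input_shape[i] = 32 and config.output_shape[i] = 64.
             * passing output_shape explicitly also resolves the rounding ambiguity: a
             * 63-wide input forward-convolves to 32 as well, so 64 cannot be inferred
             * from 32 alone
             */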
+ const auto& strides = config.strides; + + const auto convolution_order = kernel_size.size(); + CV_Assert(convolution_order >= 1); + + CV_Assert(convolution_order == dilations.size()); + CV_Assert(convolution_order == strides.size()); + + const auto& input_shape = config.input_shape; + const auto& output_shape = config.output_shape; + CV_Assert(input_shape.size() == output_shape.size()); + CV_Assert(input_shape.size() == convolution_order + 2); + + const auto groups = config.groups; + + if (convolution_order > 3) + CV_Error(Error::StsNotImplemented, "Only 1D/2D/3D transpose convolution is supported."); + + const auto rank = input_shape.size(); + const auto input_feature_maps = input_shape[1]; + const auto output_feature_maps = output_shape[1]; + const auto output_feature_maps_per_group = output_feature_maps / groups; + CV_Assert(output_feature_maps % groups == 0); + + filtersTensor = csl::makeTensorHeader(filters); + csl::copyMatToTensor(filters, filtersTensor, stream); + + if (!bias.empty()) + { + CV_Assert(bias.total() == output_feature_maps); + biasTensor = csl::makeTensorHeader(bias); + csl::copyMatToTensor(bias, biasTensor, stream); + } + + /* left and right are misleading as the padding is applicable for any number of dimensions + * but we use those identifiers to avoid confusion with `pads_begin` and `pads_end` + * + * `common_padding` contains the amount of padding that has to be added to both sides + * `padding_left` and `padding_right` contains the amount of padding that needs to be added + * to a particular side in addition to the common padding + * + * note that we compute the padding for the convolution operation + */ + std::vector common_padding(rank, 0); + std::vector padding_left(rank, 0), padding_right(rank, 0); + if (config.padMode == TransposeConvolutionConfiguration::PaddingMode::MANUAL) + { + const auto& pads_begin = config.pads_begin; + const auto& pads_end = config.pads_end; + + CV_Assert(convolution_order == pads_begin.size()); + CV_Assert(convolution_order == pads_end.size()); + + for (int i = 2; i < common_padding.size(); i++) + { + common_padding[i] = std::min(pads_begin[i - 2], pads_end[i - 2]); + padding_left[i] = pads_begin[i - 2] - common_padding[i]; + padding_right[i] = pads_end[i - 2] - common_padding[i]; + } + } + else if (config.padMode == TransposeConvolutionConfiguration::PaddingMode::VALID) + { + /* nothing to do as the paddings are already preset to zero */ + } + else if (config.padMode == TransposeConvolutionConfiguration::PaddingMode::SAME) + { + /* TensorFlow Logic: + * total_padding[i] = (o[i] - 1) * s[i] + effective_k[i] - i[i] + * + * if total padding is odd, the extra is added towards the end + */ + for (int i = 2; i < rank; i++) + { + const auto j = i - 2; /* filter index */ + const auto effective_kernel_size = dilations[j] * (kernel_size[j] - 1) + 1; + const auto required_total_padding = + std::max(0, (input_shape[i] - 1) * strides[j] + effective_kernel_size - output_shape[i]); + + common_padding[i] = required_total_padding / 2; + padding_left[i] = 0; + padding_right[i] = required_total_padding % 2; + } + } + + /* in some scenarios, the extra padding at the end may not change the output at all */ + for (int i = 2; i < rank; i++) { + const auto j = i - 2; /* filter idx */ + const auto total_padding = common_padding[i] * 2 + padding_left[i] + padding_right[i]; + const auto effective_kernel_size = dilations[j] * (kernel_size[j] - 1) + 1; + std::int64_t rem = (input_shape[i] + total_padding - effective_kernel_size) % strides[j]; + + /* the 
output shape doesn't change if we decrease the total padding by at most `rem` + * provided that we decrease from the right + */ + if (rem && padding_right[i] > 0) + padding_right[i] = std::max(0, padding_right[i] - rem); + } + + auto is_not_zero = [](std::size_t i) { return i != 0; }; + if(std::any_of(std::begin(padding_left), std::end(padding_left), is_not_zero) || + std::any_of(std::begin(padding_right), std::end(padding_right), is_not_zero)) + { + CV_Error(Error::StsNotImplemented, "Padding configuration requires asymmetric padding and hence is not supported."); + } + + typename csl::TransposeConvolution::params_type params; + params.input_shape.assign(std::begin(input_shape), std::end(input_shape)); + params.output_shape.assign(std::begin(output_shape), std::end(output_shape)); + + auto& fshape = params.filter_shape; + fshape.resize(rank); + fshape[0] = input_feature_maps; + fshape[1] = output_feature_maps_per_group; + std::copy(std::begin(kernel_size), std::end(kernel_size), std::begin(fshape) + 2); + CV_Assert(fshape.size() == kernel_size.size() + 2); + + params.padding.assign(std::begin(common_padding) + 2, std::end(common_padding)); + params.stride = strides; + params.dilation = dilations; + params.groups = config.groups; + + convoluter = csl::TransposeConvolution(cudnnHandle, params); + + csl::WorkspaceBuilder builder; + builder.require(convoluter.get_workspace_size()); + scratch_mem_in_bytes = builder.required_workspace_size(); + } + + void forward( + const std::vector>& inputs, + const std::vector>& outputs, + csl::Workspace& workspace) override + { + CV_Assert(inputs.size() == 1 && outputs.size() == 1); + + auto input_wrapper = inputs[0].dynamicCast(); + auto input = input_wrapper->getView(); + + auto output_wrapper = outputs[0].dynamicCast(); + auto output = output_wrapper->getSpan(); + + csl::WorkspaceAllocator allocator(workspace); + convoluter.transpose_convolve(output, input, filtersTensor, allocator.get_instance()); + if (!biasTensor.empty()) + { + std::size_t inner_size = total(output_wrapper->getShape(), 2, -1); + kernels::biasN(stream, output, output, inner_size, biasTensor); + } + } + + std::size_t get_workspace_memory_in_bytes() const noexcept override { return scratch_mem_in_bytes; } + + private: + csl::Stream stream; + csl::cudnn::Handle cudnnHandle; + csl::Tensor filtersTensor, biasTensor; + csl::TransposeConvolution convoluter; + + std::size_t scratch_mem_in_bytes; + }; + +}}} /* namespace cv::dnn::cuda4dnn */ + +#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_TRANSPOSE_CONVOLUTION_HPP */ diff --git a/modules/dnn/src/cuda4dnn/test.cpp b/modules/dnn/src/cuda4dnn/test.cpp deleted file mode 100644 index 066d919..0000000 --- a/modules/dnn/src/cuda4dnn/test.cpp +++ /dev/null @@ -1,18 +0,0 @@ -// This file is part of OpenCV project. -// It is subject to the license terms in the LICENSE file found in the top-level directory -// of this distribution and at http://opencv.org/license.html. 
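Both PoolingOp and TransposeConvolutionOp above apply the same TensorFlow-style SAME rule, so one worked instance covers both (illustrative numbers):

/* total_padding = (o - 1) * s + effective_k - i,  with o = ceil(i / s)
 * e.g. i = 112, kernel = 3, dilation = 1, stride = 2:
 *   effective_k   = 1 * (3 - 1) + 1 = 3
 *   o             = (112 - 1 + 2) / 2 = 56        (the (i - 1 + s) / s form used above)
 *   total_padding = (56 - 1) * 2 + 3 - 112 = 1
 *   => common_padding = 0, padding_left = 0, padding_right = 1 (the odd cell goes to the end)
 */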
- -// this file is a stub and will be removed once actual code is added - -#include "../precomp.hpp" - -#ifndef HAVE_CUDA -# error "CUDA4DNN should be enabled iff CUDA and cuDNN were found" -#endif - -#include - -void cuda4dnn_build_test_func() { - auto ver = cudnnGetVersion(); - CV_UNUSED(ver); -} diff --git a/modules/dnn/src/dnn.cpp b/modules/dnn/src/dnn.cpp index 480cf96..47acc07 100644 --- a/modules/dnn/src/dnn.cpp +++ b/modules/dnn/src/dnn.cpp @@ -43,7 +43,9 @@ #include "op_halide.hpp" #include "op_inf_engine.hpp" #include "op_vkcom.hpp" +#include "op_cuda.hpp" #include "halide_scheduler.hpp" + #include #include #include @@ -51,12 +53,15 @@ #include #include #include +#include #include #include #include #include +#include + namespace cv { namespace dnn { CV__DNN_INLINE_NS_BEGIN @@ -141,6 +146,13 @@ private: if (haveVulkan()) backends.push_back(std::make_pair(DNN_BACKEND_VKCOM, DNN_TARGET_VULKAN)); #endif + +#ifdef HAVE_CUDA + if (haveCUDA()) { + backends.push_back(std::make_pair(DNN_BACKEND_CUDA, DNN_TARGET_CUDA)); + backends.push_back(std::make_pair(DNN_BACKEND_CUDA, DNN_TARGET_CUDA_FP16)); + } +#endif } static inline bool checkIETarget(int target) { @@ -1010,6 +1022,22 @@ static Ptr wrapMat(int backendId, int targetId, cv::Mat& m) return Ptr(new VkComBackendWrapper(m)); #endif // HAVE_VULKAN } + else if (backendId == DNN_BACKEND_CUDA) + { + CV_Assert(haveCUDA()); + +#ifdef HAVE_CUDA + switch (targetId) + { + case DNN_TARGET_CUDA: + return CUDABackendWrapperFP32::create(m); + case DNN_TARGET_CUDA_FP16: + return CUDABackendWrapperFP16::create(m); + default: + CV_Assert(IS_DNN_CUDA_TARGET(targetId)); + } +#endif + } else CV_Error(Error::StsNotImplemented, "Unknown backend identifier"); return Ptr(); @@ -1038,6 +1066,18 @@ struct Net::Impl preferableBackend = DNN_BACKEND_DEFAULT; preferableTarget = DNN_TARGET_CPU; skipInfEngineInit = false; + +#ifdef HAVE_CUDA + if (cv::cuda::getCudaEnabledDeviceCount() > 0) + { + cuda4dnn::csl::CSLContext context; + context.stream = cuda4dnn::csl::Stream(true); + context.cublas_handle = cuda4dnn::csl::cublas::Handle(context.stream); + context.cudnn_handle = cuda4dnn::csl::cudnn::Handle(context.stream); + + cudaInfo = std::unique_ptr(new CudaInfo_t(std::move(context))); + } +#endif } Ptr netInputLayer; @@ -1060,6 +1100,17 @@ struct Net::Impl std::vector layersTimings; Mat output_blob; +#ifdef HAVE_CUDA + struct CudaInfo_t + { + CudaInfo_t(cuda4dnn::csl::CSLContext ctxt) : context(std::move(ctxt)) { } + cuda4dnn::csl::CSLContext context; + cuda4dnn::csl::Workspace workspace; + }; + + std::unique_ptr cudaInfo; +#endif + Ptr wrap(Mat& host) { if (preferableBackend == DNN_BACKEND_OPENCV && preferableTarget == DNN_TARGET_CPU) @@ -1095,6 +1146,21 @@ struct Net::Impl return Ptr(new VkComBackendWrapper(baseBuffer, host)); #endif } + else if (preferableBackend == DNN_BACKEND_CUDA) + { + CV_Assert(haveCUDA()); +#ifdef HAVE_CUDA + switch (preferableTarget) + { + case DNN_TARGET_CUDA: + return CUDABackendWrapperFP32::create(baseBuffer, shape); + case DNN_TARGET_CUDA_FP16: + return CUDABackendWrapperFP16::create(baseBuffer, shape); + default: + CV_Assert(IS_DNN_CUDA_TARGET(preferableTarget)); + } +#endif + } else CV_Error(Error::StsNotImplemented, "Unknown backend identifier"); } @@ -1200,6 +1266,9 @@ struct Net::Impl preferableTarget == DNN_TARGET_FPGA); CV_Assert(preferableBackend != DNN_BACKEND_VKCOM || preferableTarget == DNN_TARGET_VULKAN); + CV_Assert(preferableBackend != DNN_BACKEND_CUDA || + IS_DNN_CUDA_TARGET(preferableTarget)); + if (!netWasAllocated || 
this->blobsToKeep != blobsToKeep_) { if (preferableBackend == DNN_BACKEND_OPENCV && IS_DNN_OPENCL_TARGET(preferableTarget)) @@ -1235,6 +1304,17 @@ struct Net::Impl preferableTarget = DNN_TARGET_CPU; } + if (preferableBackend == DNN_BACKEND_CUDA && !haveCUDA()) + { +#ifdef HAVE_CUDA + CV_LOG_WARNING(NULL, "unable to use CUDA backend; switching to CPU"); +#else + CV_LOG_WARNING(NULL, "DNN module was not built with CUDA backend; switching to CPU"); +#endif + preferableBackend = DNN_BACKEND_OPENCV; + preferableTarget = DNN_TARGET_CPU; + } + clear(); allocateLayers(blobsToKeep_); @@ -1245,7 +1325,7 @@ struct Net::Impl initBackend(); - if (!netWasAllocated ) + if (!netWasAllocated) { #ifdef HAVE_HALIDE if (preferableBackend == DNN_BACKEND_HALIDE) @@ -1389,6 +1469,8 @@ struct Net::Impl initInfEngineBackend(); else if (preferableBackend == DNN_BACKEND_VKCOM) initVkComBackend(); + else if (preferableBackend == DNN_BACKEND_CUDA) + initCUDABackend(); else CV_Error(Error::StsNotImplemented, "Unknown backend identifier"); } @@ -1777,6 +1859,35 @@ struct Net::Impl #endif // HAVE_INF_ENGINE } + void initCUDABackend() { + CV_Assert(haveCUDA()); + +#ifdef HAVE_CUDA + for (auto& layer : layers) + { + auto& ld = layer.second; + auto& layerInstance = ld.layerInstance; + + if (!layerInstance->supportBackend(DNN_BACKEND_CUDA)) + { + std::ostringstream os; + os << "CUDA backend will fallback to the CPU implementation for the layer \"" << ld.name + << "\" of type " << ld.type << '\n'; + CV_LOG_INFO(NULL, os.str().c_str()); + continue; + } + + /* we make a copy so that `initCUDA` doesn't modify `cudaInfo->context` */ + auto context = cudaInfo->context; + auto node = layerInstance->initCUDA(&context, ld.inputBlobsWrappers, ld.outputBlobsWrappers); + ld.backendNodes[DNN_BACKEND_CUDA] = node; + + auto cudaNode = node.dynamicCast(); + cudaInfo->workspace.require(cudaNode->get_workspace_memory_in_bytes()); + } +#endif + } + void allocateLayer(int lid, const LayersShapesMap& layersShapes) { CV_TRACE_FUNCTION(); @@ -1822,6 +1933,13 @@ struct Net::Impl for (size_t i = 0; i < ninputs; i++) { ld.inputBlobsWrappers[i] = wrap(netInputLayer->inputsData[i]); +#ifdef HAVE_CUDA + if (IS_DNN_CUDA_TARGET(preferableTarget)) + { + auto wrapper = ld.inputBlobsWrappers[i].dynamicCast(); + wrapper->setStream(cudaInfo->context.stream); + } +#endif } } else @@ -1850,9 +1968,18 @@ struct Net::Impl for (int i = 0; i < ld.outputBlobs.size(); ++i) { ld.outputBlobsWrappers[i] = wrap(ld.outputBlobs[i]); +#ifdef HAVE_CUDA + if (IS_DNN_CUDA_TARGET(preferableTarget)) + { + auto wrapper = ld.outputBlobsWrappers[i].dynamicCast(); + wrapper->setStream(cudaInfo->context.stream); + } +#endif } - ld.internalBlobsWrappers.resize(ld.internals.size()); - for (int i = 0; i < ld.internals.size(); ++i) + + /* CUDA backend has its own system for internal blobs; we don't need these */ + ld.internalBlobsWrappers.resize((preferableBackend == DNN_BACKEND_CUDA) ? 
0 : ld.internals.size()); + for (int i = 0; i < ld.internalBlobsWrappers.size(); ++i) { ld.internalBlobsWrappers[i] = wrap(ld.internals[i]); } @@ -1893,6 +2020,7 @@ struct Net::Impl void fuseLayers(const std::vector& blobsToKeep_) { if( !fusion || (preferableBackend != DNN_BACKEND_OPENCV && + preferableBackend != DNN_BACKEND_CUDA && preferableBackend != DNN_BACKEND_INFERENCE_ENGINE)) return; @@ -2243,6 +2371,15 @@ struct Net::Impl blobManager.reset(); backendWrappers.clear(); + + for(auto& layer : layers) + { + auto& ld = layer.second; + ld.inputBlobsWrappers.clear(); + ld.outputBlobsWrappers.clear(); + ld.internalBlobsWrappers.clear(); + } + // Fake references to input blobs. for (int i = 0; i < layers[0].outputBlobs.size(); ++i) blobManager.addReference(LayerPin(0, i)); @@ -2437,7 +2574,18 @@ struct Net::Impl { Ptr node = it->second; CV_Assert(!node.empty()); - if (preferableBackend == DNN_BACKEND_HALIDE) + if (preferableBackend == DNN_BACKEND_CUDA) + { + CV_Assert(haveCUDA()); + +#ifdef HAVE_CUDA + Ptr cudaNode = node.dynamicCast(); + CV_Assert(!cudaNode.empty()); + + cudaNode->forward(ld.inputBlobsWrappers, ld.outputBlobsWrappers, cudaInfo->workspace); +#endif + } + else if (preferableBackend == DNN_BACKEND_HALIDE) { forwardHalide(ld.outputBlobsWrappers, node); } @@ -2500,6 +2648,11 @@ struct Net::Impl //forward itself forwardLayer(ld); + +#ifdef HAVE_CUDA + if (preferableBackend == DNN_BACKEND_CUDA) + cudaInfo->context.stream.synchronize(); +#endif } void getLayerShapesRecursively(int id, LayersShapesMap& inOutShapes) @@ -3124,13 +3277,14 @@ String Net::dump() prevNode = itBackend->second; } } - String colors[] = {"#ffffb3", "#fccde5", "#8dd3c7", "#bebada", "#80b1d3", "#fdb462"}; + String colors[] = {"#ffffb3", "#fccde5", "#8dd3c7", "#bebada", "#80b1d3", "#fdb462", "#ff4848"}; String backend; switch (prefBackend) { case DNN_BACKEND_DEFAULT: backend = "DEFAULT/"; break; case DNN_BACKEND_HALIDE: backend = "HALIDE/"; break; case DNN_BACKEND_INFERENCE_ENGINE: backend = "DLIE/"; break; case DNN_BACKEND_OPENCV: backend = "OCV/"; break; + case DNN_BACKEND_CUDA: backend = "CUDA/"; break; } out << "digraph G {" << '\n'; // Add nodes @@ -3227,6 +3381,8 @@ String Net::dump() case DNN_TARGET_OPENCL_FP16: out << "OCL_FP16\\n"; colorId = 2; break; case DNN_TARGET_MYRIAD: out << "MYRIAD\\n"; colorId = 3; break; case DNN_TARGET_FPGA: out << "FPGA\\n"; colorId = 4; break; + case DNN_TARGET_CUDA: out << "CUDA\\n"; colorId = 5; break; + case DNN_TARGET_CUDA_FP16: out << "CUDA_FP16\\n"; colorId = 6; break; } out << ((skipId.size() == 1)? "\" " : " }\" "); out << "fillcolor=\"" << colors[colorId] << "\" "; @@ -3632,6 +3788,16 @@ bool Layer::supportBackend(int backendId) return backendId == DNN_BACKEND_OPENCV; } +Ptr Layer::initCUDA( + void*, + const std::vector>&, + const std::vector>&) +{ + CV_Error(Error::StsNotImplemented, "CUDA pipeline of " + type + + " layers is not defined."); + return Ptr(); +} + Ptr Layer::initVkCom(const std::vector > &) { CV_Error(Error::StsNotImplemented, "VkCom pipeline of " + type + diff --git a/modules/dnn/src/layers/batch_norm_layer.cpp b/modules/dnn/src/layers/batch_norm_layer.cpp index 791d8f1..e6eb2c4 100644 --- a/modules/dnn/src/layers/batch_norm_layer.cpp +++ b/modules/dnn/src/layers/batch_norm_layer.cpp @@ -11,6 +11,7 @@ Implementation of Batch Normalization layer. 
#include "../precomp.hpp" #include "layers_common.hpp" +#include "../op_cuda.hpp" #include "../op_halide.hpp" #include "../op_inf_engine.hpp" #include @@ -19,6 +20,11 @@ Implementation of Batch Normalization layer. #include "opencl_kernels_dnn.hpp" #endif +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/batch_norm.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -155,6 +161,7 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { return (backendId == DNN_BACKEND_OPENCV) || + backendId == DNN_BACKEND_CUDA || (backendId == DNN_BACKEND_HALIDE && haveHalide()) || (backendId == DNN_BACKEND_INFERENCE_ENGINE && haveInfEngine() && (preferableTarget == DNN_TARGET_CPU || dims == 4)); } @@ -306,6 +313,18 @@ public: } } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + return make_cuda_node(preferableTarget, std::move(context->stream), weights_, bias_); + } +#endif + virtual Ptr tryAttach(const Ptr& node) CV_OVERRIDE { switch (node->backendId) diff --git a/modules/dnn/src/layers/blank_layer.cpp b/modules/dnn/src/layers/blank_layer.cpp index ef44ed7..ed6a743 100644 --- a/modules/dnn/src/layers/blank_layer.cpp +++ b/modules/dnn/src/layers/blank_layer.cpp @@ -40,8 +40,14 @@ // //M*/ #include "../precomp.hpp" +#include "../op_cuda.hpp" #include "../op_inf_engine.hpp" +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/reshape.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -57,6 +63,7 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || (backendId == DNN_BACKEND_INFERENCE_ENGINE && haveInfEngine()); } @@ -107,6 +114,18 @@ public: inputs[i].copyTo(outputs[i]); } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + return make_cuda_node(preferableTarget, std::move(context->stream)); + } +#endif + #ifdef HAVE_INF_ENGINE virtual Ptr initInfEngine(const std::vector >& inputs) CV_OVERRIDE { diff --git a/modules/dnn/src/layers/concat_layer.cpp b/modules/dnn/src/layers/concat_layer.cpp index aae9bde..ed2cf4f 100644 --- a/modules/dnn/src/layers/concat_layer.cpp +++ b/modules/dnn/src/layers/concat_layer.cpp @@ -42,6 +42,7 @@ #include "../precomp.hpp" #include "layers_common.hpp" +#include "../op_cuda.hpp" #include "../op_halide.hpp" #include "../op_inf_engine.hpp" #include "../op_vkcom.hpp" @@ -50,6 +51,11 @@ #include "opencl_kernels_dnn.hpp" #endif +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/concat.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -105,6 +111,7 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || (backendId == DNN_BACKEND_HALIDE && haveHalide() && axis == 1 && !padding) || // By channels (backendId == DNN_BACKEND_INFERENCE_ENGINE && haveInfEngine() && !padding) || (backendId == DNN_BACKEND_VKCOM && haveVulkan() && !padding); @@ -276,6 +283,22 @@ public: } } } + +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + + auto input_wrapper = inputs[0].dynamicCast(); + auto concat_axis = clamp(axis, input_wrapper->getRank()); + return 
make_cuda_node(preferableTarget, std::move(context->stream), concat_axis, padding); + } +#endif + virtual Ptr initVkCom(const std::vector > &input) CV_OVERRIDE { #ifdef HAVE_VULKAN diff --git a/modules/dnn/src/layers/const_layer.cpp b/modules/dnn/src/layers/const_layer.cpp index 7a33d6e..fd712cd 100644 --- a/modules/dnn/src/layers/const_layer.cpp +++ b/modules/dnn/src/layers/const_layer.cpp @@ -7,12 +7,18 @@ #include "../precomp.hpp" #include "../op_inf_engine.hpp" +#include "../op_cuda.hpp" #include "layers_common.hpp" #ifdef HAVE_OPENCL #include "opencl_kernels_dnn.hpp" #endif +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/const.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn { class ConstLayerImpl CV_FINAL : public ConstLayer @@ -26,7 +32,9 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { - return backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_INFERENCE_ENGINE; + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_INFERENCE_ENGINE || + backendId == DNN_BACKEND_CUDA; } virtual bool getMemoryShapes(const std::vector &inputs, @@ -73,6 +81,21 @@ public: return Ptr(new InfEngineBackendNode(ieLayer)); } #endif // HAVE_INF_ENGINE + +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + + CV_Assert(blobs.size() == 1); + return make_cuda_node(preferableTarget, std::move(context->stream), blobs[0]); + } +#endif + }; Ptr ConstLayer::create(const LayerParams& params) diff --git a/modules/dnn/src/layers/convolution_layer.cpp b/modules/dnn/src/layers/convolution_layer.cpp index c8744fa..09bdd93 100644 --- a/modules/dnn/src/layers/convolution_layer.cpp +++ b/modules/dnn/src/layers/convolution_layer.cpp @@ -42,6 +42,7 @@ #include "../precomp.hpp" #include "layers_common.hpp" +#include "../op_cuda.hpp" #include "../op_halide.hpp" #include "../op_inf_engine.hpp" #include "../op_vkcom.hpp" @@ -55,6 +56,12 @@ using namespace cv::dnn::ocl4dnn; #endif +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/convolution.hpp" +#include "../cuda4dnn/primitives/transpose_convolution.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -253,6 +260,15 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { + if (backendId == DNN_BACKEND_CUDA) + { + /* only convolution 2d and 3d supported */ + if(kernel_size.size() == 2 || kernel_size.size() == 3) + return true; + + return false; + } + #ifdef HAVE_INF_ENGINE if (backendId == DNN_BACKEND_INFERENCE_ENGINE) { @@ -491,8 +507,6 @@ public: return Ptr(); } - - virtual Ptr initHalide(const std::vector > &inputs) CV_OVERRIDE { #ifdef HAVE_HALIDE @@ -1281,6 +1295,66 @@ public: kernel_size, strides, pads_begin, pads_end, dilations, activ.get(), ngroups, nstripes); } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + + CV_Assert(inputs.size() == 1); + auto input_wrapper = inputs[0].dynamicCast(); + auto input_shape = input_wrapper->getShape(); + + CV_Assert(outputs.size() == 1); + auto output_wrapper = outputs[0].dynamicCast(); + auto output_shape = output_wrapper->getShape(); + + const auto output_feature_maps = blobs[0].size[0]; + const auto input_feature_maps = input_shape[1]; + const auto input_feature_maps_per_group = blobs[0].size[1]; + const auto groups = input_feature_maps / input_feature_maps_per_group; + 
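A brief note on the group arithmetic just above (illustrative numbers):

            /* the weight blob is laid out as {O, I / groups, kH, kW}; e.g. a depthwise
             * convolution over a 32-channel input has blobs[0] of shape {32, 1, 3, 3},
             * so input_feature_maps_per_group = 1 and groups = 32 / 1 = 32
             */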
+ ConvolutionConfiguration config; + config.kernel_size.assign(std::begin(kernel_size), std::end(kernel_size)); + config.dilations.assign(std::begin(dilations), std::end(dilations)); + config.strides.assign(std::begin(strides), std::end(strides)); + + if (padMode.empty()) + { + config.padMode = ConvolutionConfiguration::PaddingMode::MANUAL; + config.pads_begin.assign(std::begin(pads_begin), std::end(pads_begin)); + config.pads_end.assign(std::begin(pads_end), std::end(pads_end)); + } + else if (padMode == "VALID") + { + config.padMode = ConvolutionConfiguration::PaddingMode::VALID; + } + else if (padMode == "SAME") + { + config.padMode = ConvolutionConfiguration::PaddingMode::SAME; + } + else + { + CV_Error(Error::StsNotImplemented, padMode + " padding mode not supported by ConvolutionLayer"); + } + + config.input_shape.assign(std::begin(input_shape), std::end(input_shape)); + config.output_shape.assign(std::begin(output_shape), std::end(output_shape)); + config.groups = groups; + + Mat filtersMat = fusedWeights ? weightsMat : blobs[0]; + Mat biasMat = (hasBias() || fusedBias) ? Mat(output_feature_maps, 1, CV_32F, biasvec.data()) : Mat(); + if (countNonZero(biasMat) == 0) + biasMat = Mat(); + + return make_cuda_node( + preferableTarget, std::move(context->stream), std::move(context->cudnn_handle), config, filtersMat, biasMat); + } +#endif + virtual int64 getFLOPS(const std::vector &inputs, const std::vector &outputs) const CV_OVERRIDE { @@ -1323,6 +1397,15 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { + if (backendId == DNN_BACKEND_CUDA) + { + /* only deconvolution 2d and 3d supported */ + if (kernel_size.size() == 2 || kernel_size.size() == 3) + return true; + + return false; + } + #ifdef HAVE_INF_ENGINE const int outGroupCn = blobs[0].size[1]; // Weights are in IOHW or IODHW layout const int group = numOutput / outGroupCn; @@ -1372,7 +1455,8 @@ public: } else #endif // HAVE_INF_ENGINE - return kernel_size.size() == 2 && (backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_HALIDE); + return backendId == DNN_BACKEND_CUDA || + (kernel_size.size() == 2 && (backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_HALIDE)); } bool getMemoryShapes(const std::vector &inputs, @@ -1898,6 +1982,67 @@ public: } } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + + CV_Assert(inputs.size() == 1); + auto input_wrapper = inputs[0].dynamicCast(); + auto input_shape = input_wrapper->getShape(); + + CV_Assert(outputs.size() == 1); + auto output_wrapper = outputs[0].dynamicCast(); + auto output_shape = output_wrapper->getShape(); + + const auto output_feature_maps = numOutput; + const auto output_feature_maps_per_group = blobs[0].size[1]; + const auto groups = output_feature_maps / output_feature_maps_per_group; + + TransposeConvolutionConfiguration config; + config.kernel_size.assign(std::begin(kernel_size), std::end(kernel_size)); + config.dilations.assign(std::begin(dilations), std::end(dilations)); + config.strides.assign(std::begin(strides), std::end(strides)); + + if (padMode.empty()) + { + config.padMode = TransposeConvolutionConfiguration::PaddingMode::MANUAL; + config.pads_begin.assign(std::begin(pads_begin), std::end(pads_begin)); + config.pads_end.assign(std::begin(pads_end), std::end(pads_end)); + } + else if (padMode == "VALID") + { + config.padMode = TransposeConvolutionConfiguration::PaddingMode::VALID; + } + else if (padMode == "SAME") 
+ { + config.padMode = TransposeConvolutionConfiguration::PaddingMode::SAME; + } + else + { + CV_Error(Error::StsNotImplemented, padMode + " padding mode not supported by DeconvolutionLayer"); + } + + config.input_shape.assign(std::begin(input_shape), std::end(input_shape)); + config.output_shape.assign(std::begin(output_shape), std::end(output_shape)); + config.groups = groups; + + CV_Assert(blobs.size() >= 1); + Mat filtersMat = fusedWeights ? weightsMat.t() : blobs[0]; + + Mat biasMat = (hasBias() || fusedBias) ? biasesMat : Mat(); + if (countNonZero(biasMat) == 0) + biasMat = Mat(); + + return make_cuda_node( + preferableTarget, std::move(context->stream), std::move(context->cudnn_handle), config, filtersMat, biasMat); + } +#endif + virtual Ptr initHalide(const std::vector > &inputs) CV_OVERRIDE { #ifdef HAVE_HALIDE diff --git a/modules/dnn/src/layers/elementwise_layers.cpp b/modules/dnn/src/layers/elementwise_layers.cpp index 96dffce..632cac8 100644 --- a/modules/dnn/src/layers/elementwise_layers.cpp +++ b/modules/dnn/src/layers/elementwise_layers.cpp @@ -42,6 +42,7 @@ #include "../precomp.hpp" #include "layers_common.hpp" +#include "../op_cuda.hpp" #include "../op_halide.hpp" #include "../op_inf_engine.hpp" #include "../op_vkcom.hpp" @@ -52,6 +53,11 @@ #include "opencl_kernels_dnn.hpp" #endif +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/activation.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -221,6 +227,18 @@ public: func.apply(src, dst, len, planeSize, cn0, cn1); } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + return func.initCUDA(Layer::preferableTarget, context->stream); + } +#endif + virtual int64 getFLOPS(const std::vector &inputs, const std::vector &outputs) const CV_OVERRIDE { @@ -261,7 +279,9 @@ struct ReLUFunctor if (backendId == DNN_BACKEND_INFERENCE_ENGINE) return slope >= 0 || !INF_ENGINE_VER_MAJOR_EQ(INF_ENGINE_RELEASE_2019R1); #endif - return backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_HALIDE || + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || + backendId == DNN_BACKEND_HALIDE || backendId == DNN_BACKEND_VKCOM; } @@ -297,6 +317,13 @@ struct ReLUFunctor } } +#ifdef HAVE_CUDA + Ptr initCUDA(int target, csl::Stream stream) + { + return make_cuda_node(target, stream, slope); + } +#endif + #ifdef HAVE_OPENCL bool initKernel(ocl::Kernel &ker, const UMat &src) const { @@ -370,8 +397,6 @@ struct ReLUFunctor } #endif // HAVE_VULKAN - - bool tryFuse(Ptr&) { return false; } void getScaleShift(Mat&, Mat&) const {} @@ -392,7 +417,9 @@ struct ReLU6Functor bool supportBackend(int backendId, int) { - return backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_HALIDE || + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || + backendId == DNN_BACKEND_HALIDE || backendId == DNN_BACKEND_INFERENCE_ENGINE; } @@ -460,6 +487,13 @@ struct ReLU6Functor } #endif +#ifdef HAVE_CUDA + Ptr initCUDA(int target, csl::Stream stream) + { + return make_cuda_node(target, stream, minValue, maxValue); + } +#endif + #ifdef HAVE_HALIDE void attachHalide(const Halide::Expr& input, Halide::Func& top) { @@ -496,7 +530,9 @@ struct TanHFunctor bool supportBackend(int backendId, int) { - return backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_HALIDE || + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || + backendId == 
DNN_BACKEND_HALIDE || backendId == DNN_BACKEND_INFERENCE_ENGINE; } @@ -540,6 +576,13 @@ struct TanHFunctor } #endif +#ifdef HAVE_CUDA + Ptr initCUDA(int target, csl::Stream stream) + { + return make_cuda_node(target, stream); + } +#endif + #ifdef HAVE_HALIDE void attachHalide(const Halide::Expr& input, Halide::Func& top) { @@ -576,7 +619,9 @@ struct SigmoidFunctor bool supportBackend(int backendId, int) { - return backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_HALIDE || + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || + backendId == DNN_BACKEND_HALIDE || backendId == DNN_BACKEND_INFERENCE_ENGINE; } @@ -620,6 +665,13 @@ struct SigmoidFunctor } #endif +#ifdef HAVE_CUDA + Ptr initCUDA(int target, csl::Stream stream) + { + return make_cuda_node(target, stream); + } +#endif + #ifdef HAVE_HALIDE void attachHalide(const Halide::Expr& input, Halide::Func& top) { @@ -658,7 +710,9 @@ struct ELUFunctor bool supportBackend(int backendId, int) { - return backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_HALIDE || + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || + backendId == DNN_BACKEND_HALIDE || backendId == DNN_BACKEND_INFERENCE_ENGINE; } @@ -702,6 +756,13 @@ struct ELUFunctor } #endif +#ifdef HAVE_CUDA + Ptr initCUDA(int target, csl::Stream stream) + { + return make_cuda_node(target, stream); + } +#endif + #ifdef HAVE_HALIDE void attachHalide(const Halide::Expr& input, Halide::Func& top) { @@ -742,7 +803,9 @@ struct AbsValFunctor if (backendId == DNN_BACKEND_INFERENCE_ENGINE) return !INF_ENGINE_VER_MAJOR_EQ(INF_ENGINE_RELEASE_2019R1); #endif - return backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_HALIDE; + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || + backendId == DNN_BACKEND_HALIDE; } void apply(const float* srcptr, float* dstptr, int len, size_t planeSize, int cn0, int cn1) const @@ -785,6 +848,13 @@ struct AbsValFunctor } #endif +#ifdef HAVE_CUDA + Ptr initCUDA(int target, csl::Stream stream) + { + return make_cuda_node(target, stream); + } +#endif + #ifdef HAVE_HALIDE void attachHalide(const Halide::Expr& input, Halide::Func& top) { @@ -821,7 +891,9 @@ struct BNLLFunctor bool supportBackend(int backendId, int) { - return backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_HALIDE; + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || + backendId == DNN_BACKEND_HALIDE; } void apply(const float* srcptr, float* dstptr, int len, size_t planeSize, int cn0, int cn1) const @@ -865,6 +937,13 @@ struct BNLLFunctor } #endif +#ifdef HAVE_CUDA + Ptr initCUDA(int target, csl::Stream stream) + { + return make_cuda_node(target, stream); + } +#endif + #ifdef HAVE_HALIDE void attachHalide(const Halide::Expr& input, Halide::Func& top) { @@ -912,7 +991,9 @@ struct PowerFunctor if (backendId == DNN_BACKEND_INFERENCE_ENGINE) return (targetId != DNN_TARGET_OPENCL && targetId != DNN_TARGET_OPENCL_FP16) || power == 1.0 || power == 0.5; else - return backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_HALIDE; + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || + backendId == DNN_BACKEND_HALIDE; } void apply(const float* srcptr, float* dstptr, int len, size_t planeSize, int cn0, int cn1) const @@ -973,6 +1054,13 @@ struct PowerFunctor } #endif +#ifdef HAVE_CUDA + Ptr initCUDA(int target, csl::Stream stream) + { + return make_cuda_node(target, stream, power, scale, shift); + } +#endif + #ifdef HAVE_HALIDE void attachHalide(const 
Halide::Expr& input, Halide::Func& top) { @@ -1051,7 +1139,9 @@ struct ChannelsPReLUFunctor bool supportBackend(int backendId, int) { - return backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_HALIDE || + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || + backendId == DNN_BACKEND_HALIDE || backendId == DNN_BACKEND_INFERENCE_ENGINE; } @@ -1126,6 +1216,13 @@ struct ChannelsPReLUFunctor } #endif +#ifdef HAVE_CUDA + Ptr initCUDA(int target, csl::Stream stream) + { + return make_cuda_node(target, stream, scale); + } +#endif + #ifdef HAVE_HALIDE void attachHalide(const Halide::Expr& input, Halide::Func& top) { diff --git a/modules/dnn/src/layers/eltwise_layer.cpp b/modules/dnn/src/layers/eltwise_layer.cpp index 4743068..cccef29 100644 --- a/modules/dnn/src/layers/eltwise_layer.cpp +++ b/modules/dnn/src/layers/eltwise_layer.cpp @@ -42,6 +42,7 @@ #include "../precomp.hpp" #include "layers_common.hpp" +#include "../op_cuda.hpp" #include "../op_halide.hpp" #include "../op_inf_engine.hpp" @@ -49,6 +50,11 @@ #include "opencl_kernels_dnn.hpp" #endif +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/eltwise.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -97,6 +103,7 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || backendId == DNN_BACKEND_HALIDE || (backendId == DNN_BACKEND_INFERENCE_ENGINE && (preferableTarget != DNN_TARGET_OPENCL || coeffs.empty())); @@ -374,6 +381,28 @@ public: coeffs, op, activ.get(), nstripes); } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + + auto op_ = [this] { + switch (op) { + case MAX: return cuda4dnn::EltwiseOpType::MAX; + case SUM: return cuda4dnn::EltwiseOpType::SUM; + case PROD: return cuda4dnn::EltwiseOpType::PRODUCT; + } + return cuda4dnn::EltwiseOpType::SUM; + }(); + + return make_cuda_node(preferableTarget, std::move(context->stream), op_, coeffs); + } +#endif + virtual Ptr initHalide(const std::vector > &input) CV_OVERRIDE { #ifdef HAVE_HALIDE diff --git a/modules/dnn/src/layers/flatten_layer.cpp b/modules/dnn/src/layers/flatten_layer.cpp index f1250e7..4553d9e 100644 --- a/modules/dnn/src/layers/flatten_layer.cpp +++ b/modules/dnn/src/layers/flatten_layer.cpp @@ -42,11 +42,17 @@ #include "../precomp.hpp" #include "layers_common.hpp" +#include "../op_cuda.hpp" #include "../op_inf_engine.hpp" #include #include #include +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/reshape.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -65,6 +71,7 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || (backendId == DNN_BACKEND_INFERENCE_ENGINE && haveInfEngine()); } @@ -162,6 +169,18 @@ public: } } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + return make_cuda_node(preferableTarget, std::move(context->stream)); + } +#endif + #ifdef HAVE_INF_ENGINE virtual Ptr initInfEngine(const std::vector >& inputs) CV_OVERRIDE { diff --git a/modules/dnn/src/layers/fully_connected_layer.cpp b/modules/dnn/src/layers/fully_connected_layer.cpp index c9baf79..f3705d4 100644 --- a/modules/dnn/src/layers/fully_connected_layer.cpp +++ 
b/modules/dnn/src/layers/fully_connected_layer.cpp @@ -42,6 +42,7 @@ #include "../precomp.hpp" #include "layers_common.hpp" +#include "../op_cuda.hpp" #include "../op_halide.hpp" #include "../op_inf_engine.hpp" #include @@ -51,6 +52,11 @@ using namespace cv::dnn::ocl4dnn; #endif +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/inner_product.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -123,6 +129,7 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || (backendId == DNN_BACKEND_HALIDE && haveHalide() && axis == 1) || (backendId == DNN_BACKEND_INFERENCE_ENGINE && haveInfEngine() && axis == 1); } @@ -415,6 +422,24 @@ public: } } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + + auto input_wrapper = inputs[0].dynamicCast(); + + auto flatten_start_axis = clamp(axis, input_wrapper->getRank()); + + auto biasMat_ = bias ? biasMat : Mat(); + return make_cuda_node(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), flatten_start_axis, weightsMat, biasMat_); + } +#endif + virtual Ptr initHalide(const std::vector > &inputs) CV_OVERRIDE { #ifdef HAVE_HALIDE diff --git a/modules/dnn/src/layers/lrn_layer.cpp b/modules/dnn/src/layers/lrn_layer.cpp index b9e3876..be2c165 100644 --- a/modules/dnn/src/layers/lrn_layer.cpp +++ b/modules/dnn/src/layers/lrn_layer.cpp @@ -42,6 +42,7 @@ #include "../precomp.hpp" #include "layers_common.hpp" +#include "../op_cuda.hpp" #include "../op_halide.hpp" #include "../op_inf_engine.hpp" #include "../op_vkcom.hpp" @@ -55,6 +56,11 @@ using namespace cv::dnn::ocl4dnn; #endif +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/lrn.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -94,6 +100,7 @@ public: if (backendId == DNN_BACKEND_INFERENCE_ENGINE) return bias == (int)bias; return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || backendId == DNN_BACKEND_HALIDE || (backendId == DNN_BACKEND_VKCOM && haveVulkan() && (size % 2 == 1) && (type == CHANNEL_NRM)); } @@ -309,6 +316,46 @@ public: } } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + + cuda4dnn::LRNType type_; + if (type == CHANNEL_NRM) + type_ = cuda4dnn::LRNType::ACROSS_CHANNELS; + else if (type == SPATIAL_NRM) + type_ = cuda4dnn::LRNType::WITHIN_CHANNEL; + else + CV_Error(Error::StsNotImplemented, "Unknown normalization region"); + + float alphaSize = alpha; + if (!normBySize) { + switch (type) { + case CHANNEL_NRM: alphaSize = alpha * size; break; + case SPATIAL_NRM: alphaSize = alpha * size * size; break; + } + } + + std::size_t largestInputSize = 0; + for(auto& wrapper : inputs) { + auto input_wrapper = wrapper.dynamicCast(); + auto shape = input_wrapper->getShape(); + largestInputSize = std::max( + largestInputSize, + std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies()) + ); + } + + return make_cuda_node(preferableTarget, + std::move(context->cudnn_handle), type_, size, alphaSize, beta, bias, largestInputSize); + } +#endif + virtual Ptr initVkCom(const std::vector > &inputs) CV_OVERRIDE { #ifdef HAVE_VULKAN diff --git a/modules/dnn/src/layers/max_unpooling_layer.cpp b/modules/dnn/src/layers/max_unpooling_layer.cpp index 
2978509..a44d25c 100644 --- a/modules/dnn/src/layers/max_unpooling_layer.cpp +++ b/modules/dnn/src/layers/max_unpooling_layer.cpp @@ -11,10 +11,14 @@ Implementation of Batch Normalization layer. #include "../precomp.hpp" #include "layers_common.hpp" +#include "../op_cuda.hpp" #include "../op_halide.hpp" #include -#include +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/max_unpooling.hpp" +using namespace cv::dnn::cuda4dnn; +#endif namespace cv { @@ -35,6 +39,7 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || (backendId == DNN_BACKEND_HALIDE && haveHalide() && !poolPad.width && !poolPad.height); } @@ -124,6 +129,35 @@ public: } } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + + cuda4dnn::MaxUnpoolingConfiguration config; + auto& window_size = config.window_size; + window_size.resize(2); + window_size[0] = poolKernel.height; + window_size[1] = poolKernel.width; + + auto& strides = config.strides; + strides.resize(2); + strides[0] = poolStride.height; + strides[1] = poolStride.width; + + auto& pads_begin = config.pads_begin; + pads_begin.resize(2); + pads_begin[0] = poolPad.height; + pads_begin[1] = poolPad.width; + + return make_cuda_node(preferableTarget, std::move(context->stream), config); + } +#endif + virtual Ptr initHalide(const std::vector > &input) CV_OVERRIDE { #ifdef HAVE_HALIDE diff --git a/modules/dnn/src/layers/normalize_bbox_layer.cpp b/modules/dnn/src/layers/normalize_bbox_layer.cpp index b6b973d..cab25ab 100644 --- a/modules/dnn/src/layers/normalize_bbox_layer.cpp +++ b/modules/dnn/src/layers/normalize_bbox_layer.cpp @@ -42,8 +42,14 @@ #include "../precomp.hpp" #include "layers_common.hpp" +#include "../op_cuda.hpp" #include "../op_inf_engine.hpp" +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/normalize_bbox.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn { class NormalizeBBoxLayerImpl CV_FINAL : public NormalizeBBoxLayer @@ -70,7 +76,8 @@ public: return preferableTarget == DNN_TARGET_MYRIAD ? !acrossSpatial : startAxis == 1; } - return backendId == DNN_BACKEND_OPENCV; + return backendId == DNN_BACKEND_OPENCV || + (backendId == DNN_BACKEND_CUDA && (pnorm == 1 || pnorm == 2)); } bool getMemoryShapes(const std::vector &inputs, @@ -257,6 +264,33 @@ public: } } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + + if(pnorm != 1 && pnorm != 2) + CV_Error(Error::StsNotImplemented, "Unsupported normalization mode"); + + auto input_wrapper = inputs[0].dynamicCast(); + auto input_shape = input_wrapper->getShape(); + + NormalizeConfiguration config; + config.input_shape.assign(std::begin(input_shape), std::end(input_shape)); + config.axis_start = clamp(startAxis, input_shape.size()); + config.axis_end = clamp(endAxis, input_shape.size()) + 1; /* +1 because NormalizeOp follows [start, end) convention */ + config.norm = pnorm; + config.eps = epsilon; + + const auto& weightsMat = blobs.empty() ? 
Mat() : blobs[0]; + return make_cuda_node(preferableTarget, std::move(context->stream), weightsMat, config); + } +#endif + #ifdef HAVE_INF_ENGINE virtual Ptr initInfEngine(const std::vector >& inputs) CV_OVERRIDE { diff --git a/modules/dnn/src/layers/padding_layer.cpp b/modules/dnn/src/layers/padding_layer.cpp index cffb84d..f0726aa 100644 --- a/modules/dnn/src/layers/padding_layer.cpp +++ b/modules/dnn/src/layers/padding_layer.cpp @@ -11,10 +11,16 @@ Implementation of padding layer, which adds paddings to input blob. #include "../precomp.hpp" #include "layers_common.hpp" +#include "../op_cuda.hpp" #include "../op_halide.hpp" #include "../op_inf_engine.hpp" #include +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/padding.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -100,6 +106,7 @@ public: (dstRanges.size() == 4 && paddings[0].first == 0 && paddings[0].second == 0)); #endif return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || (backendId == DNN_BACKEND_HALIDE && haveHalide() && dstRanges.size() == 4); } @@ -161,6 +168,27 @@ public: CV_Error(Error::StsNotImplemented, "Unknown padding type: " + paddingType); } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + + cuda4dnn::PaddingType ptype; + if (paddingType == "constant") + ptype = PaddingType::CONSTANT; + else if (paddingType == "reflect") + ptype = PaddingType::REFLECTION101; + else + CV_Error(Error::StsNotImplemented, "Unsupported padding mode"); + + return make_cuda_node(preferableTarget, std::move(context->stream), ptype, paddingValue, dstRanges); + } +#endif + virtual Ptr initHalide(const std::vector > &inputs) CV_OVERRIDE { #ifdef HAVE_HALIDE diff --git a/modules/dnn/src/layers/permute_layer.cpp b/modules/dnn/src/layers/permute_layer.cpp index 6c0b53f..fbfe2c7 100644 --- a/modules/dnn/src/layers/permute_layer.cpp +++ b/modules/dnn/src/layers/permute_layer.cpp @@ -42,6 +42,7 @@ #include "../precomp.hpp" #include "layers_common.hpp" +#include "../op_cuda.hpp" #include "../op_inf_engine.hpp" #include "../op_vkcom.hpp" #include @@ -51,6 +52,11 @@ #include "opencl_kernels_dnn.hpp" #endif +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/permute.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -106,6 +112,7 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || (backendId == DNN_BACKEND_INFERENCE_ENGINE && haveInfEngine()) || (backendId == DNN_BACKEND_VKCOM && haveVulkan()); } @@ -372,6 +379,18 @@ public: } } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + return make_cuda_node(preferableTarget, std::move(context->stream), _order); + } +#endif + virtual Ptr initVkCom(const std::vector > &input) CV_OVERRIDE { #ifdef HAVE_VULKAN diff --git a/modules/dnn/src/layers/pooling_layer.cpp b/modules/dnn/src/layers/pooling_layer.cpp index 8143b93..8cda45e 100644 --- a/modules/dnn/src/layers/pooling_layer.cpp +++ b/modules/dnn/src/layers/pooling_layer.cpp @@ -43,6 +43,7 @@ #include "../precomp.hpp" #include "layers_common.hpp" #include "opencv2/core/hal/intrin.hpp" +#include "../op_cuda.hpp" #include "../op_halide.hpp" #include "../op_inf_engine.hpp" #include "../op_vkcom.hpp" @@ -57,6 +58,12 @@ using std::min; 
using namespace cv::dnn::ocl4dnn; #endif +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/pooling.hpp" +#include "../cuda4dnn/primitives/max_unpooling.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -161,7 +168,11 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { - if (backendId == DNN_BACKEND_INFERENCE_ENGINE) + if (backendId == DNN_BACKEND_CUDA) + { + return type == MAX || type == AVE; + } + else if (backendId == DNN_BACKEND_INFERENCE_ENGINE) { if (computeMaxIdx) return false; @@ -283,6 +294,100 @@ public: } } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + + auto input_wrapper = inputs[0].dynamicCast(); + auto input_shape = input_wrapper->getShape(); + + /* storing max indices is a special case and we deal with it separately */ + if (computeMaxIdx) { + CV_Assert(type == MAX); + + cuda4dnn::MaxPoolingConfiguration config; + config.window_size.assign(std::begin(kernel_size), std::end(kernel_size)); + config.strides.assign(std::begin(strides), std::end(strides)); + + if (padMode.empty()) + { + config.padMode = MaxPoolingConfiguration::PaddingMode::MANUAL; + config.pads_begin.assign(std::begin(pads_begin), std::end(pads_begin)); + } + else if (padMode == "VALID") + { + config.padMode = MaxPoolingConfiguration::PaddingMode::VALID; + } + else if (padMode == "SAME") + { + config.padMode = MaxPoolingConfiguration::PaddingMode::SAME; + } + else + { + CV_Error(Error::StsNotImplemented, padMode + " padding mode not supported by PoolingLayer"); + } + + config.input_shape.assign(std::begin(input_shape), std::end(input_shape)); + + return make_cuda_node(preferableTarget, std::move(context->stream), config); + } + + PoolingConfiguration config; + if (type == MAX) + { + config.poolMode = PoolingConfiguration::PoolingMode::MAX; + } + else if (type == AVE && !avePoolPaddedArea) + { + config.poolMode = PoolingConfiguration::PoolingMode::AVERAGE_EXCLUDE_PADDING; + } + else if (type == AVE && avePoolPaddedArea) + { + config.poolMode = PoolingConfiguration::PoolingMode::AVERAGE_INCLUDE_PADDING; + } + else + { + CV_Error(Error::StsNotImplemented, "Unsupported pooling mode"); + } + + config.window_size.assign(std::begin(kernel_size), std::end(kernel_size)); + config.strides.assign(std::begin(strides), std::end(strides)); + + if (padMode.empty()) + { + config.padMode = PoolingConfiguration::PaddingMode::MANUAL; + config.pads_begin.assign(std::begin(pads_begin), std::end(pads_begin)); + config.pads_end.assign(std::begin(pads_end), std::end(pads_end)); + } + else if (padMode == "VALID") + { + config.padMode = PoolingConfiguration::PaddingMode::VALID; + } + else if (padMode == "SAME") + { + config.padMode = PoolingConfiguration::PaddingMode::SAME; + } + else + { + CV_Error(Error::StsNotImplemented, padMode + " padding mode not supported by PoolingLayer"); + } + + if (ceilMode) + config.roundMode = PoolingConfiguration::RoundingMode::CEIL; + else + config.roundMode = PoolingConfiguration::RoundingMode::FLOOR; + + config.input_shape.assign(std::begin(input_shape), std::end(input_shape)); + + return make_cuda_node(preferableTarget, std::move(context->cudnn_handle), config); + } +#endif + virtual Ptr initVkCom(const std::vector > &inputs) CV_OVERRIDE { #ifdef HAVE_VULKAN diff --git a/modules/dnn/src/layers/prior_box_layer.cpp b/modules/dnn/src/layers/prior_box_layer.cpp index f5d9bf7..713ee00 100644 --- 
a/modules/dnn/src/layers/prior_box_layer.cpp +++ b/modules/dnn/src/layers/prior_box_layer.cpp @@ -42,6 +42,7 @@ #include "../precomp.hpp" #include "layers_common.hpp" +#include "../op_cuda.hpp" #include "../op_inf_engine.hpp" #include "../op_vkcom.hpp" #include @@ -52,6 +53,11 @@ #include "opencl_kernels_dnn.hpp" #endif +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/prior_box.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -274,6 +280,7 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || (backendId == DNN_BACKEND_INFERENCE_ENGINE && haveInfEngine() && ( _explicitSizes || (_minSize.size() == 1 && _maxSize.size() <= 1))) || (backendId == DNN_BACKEND_VKCOM && haveVulkan()); @@ -485,6 +492,44 @@ public: } } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + + auto feature_map_wrapper = inputs[0].dynamicCast(); + auto feature_map_shape = feature_map_wrapper->getShape(); + + auto image_wrapper = inputs[1].dynamicCast(); + auto image_shape = image_wrapper->getShape(); + + PriorBoxConfiguration config; + config.feature_map_width = feature_map_shape.rbegin()[0]; + config.feature_map_height = feature_map_shape.rbegin()[1]; + config.image_width = image_shape.rbegin()[0]; + config.image_height = image_shape.rbegin()[1]; + + config.num_priors = _numPriors; + config.box_widths = _boxWidths; + config.box_heights = _boxHeights; + config.offsets_x = _offsetsX; + config.offsets_y = _offsetsY; + config.stepX = _stepX; + config.stepY = _stepY; + + config.variance = _variance; + + config.clip = _clip; + config.normalize = _bboxesNormalized; + + return make_cuda_node(preferableTarget, std::move(context->stream), config); + } +#endif + virtual Ptr initVkCom(const std::vector > &input) CV_OVERRIDE { #ifdef HAVE_VULKAN diff --git a/modules/dnn/src/layers/region_layer.cpp b/modules/dnn/src/layers/region_layer.cpp index c33c1cb..9211251 100644 --- a/modules/dnn/src/layers/region_layer.cpp +++ b/modules/dnn/src/layers/region_layer.cpp @@ -41,6 +41,7 @@ //M*/ #include "../precomp.hpp" +#include "../op_cuda.hpp" #include #include #include "../nms.inl.hpp" @@ -49,6 +50,11 @@ #include "opencl_kernels_dnn.hpp" #endif +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/region.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -87,6 +93,12 @@ public: CV_Error(cv::Error::StsNotImplemented, "Yolo9000 is not implemented"); } + virtual bool supportBackend(int backendId) CV_OVERRIDE + { + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA; + } + bool getMemoryShapes(const std::vector &inputs, const int requiredOutputs, std::vector &outputs, @@ -332,6 +344,61 @@ public: } } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + + if (coords != 4) + CV_Error(Error::StsNotImplemented, "Only upright rectangular boxes are supported in RegionLayer."); + + std::size_t height_norm, width_norm; + if (inputs.size() == 1) + { + auto input_wrapper = inputs[0].dynamicCast(); + auto input_shape = input_wrapper->getShape(); + height_norm = input_shape[1]; + width_norm = input_shape[2]; + } + else + { + auto input_wrapper = inputs[1].dynamicCast(); + auto input_shape = input_wrapper->getShape(); + 
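+            /* with two inputs, the second blob is expected to be 4-D (checked below);
+             * its spatial dimensions supply height_norm and width_norm, which the
+             * region primitive uses to normalise the predicted box coordinates
+             */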
CV_Assert(input_shape.size() == 4); + height_norm = input_shape[2]; + width_norm = input_shape[3]; + } + + cuda4dnn::SquashMethod squash_method; + if(useLogistic) + squash_method = cuda4dnn::SquashMethod::SIGMOID; + else if (useSoftmax) + squash_method = cuda4dnn::SquashMethod::SOFTMAX; + + /* exactly one must be true */ + CV_Assert((useLogistic || useSoftmax) && !(useLogistic && useSoftmax)); + + cuda4dnn::RegionConfiguration config; + config.squash_method = squash_method; + config.classes = classes; + config.boxes_per_cell = anchors; + + config.height_norm = height_norm; + config.width_norm = width_norm; + + config.object_prob_cutoff = (classfix == -1) ? 0.5 : 0.0; + config.class_prob_cutoff = thresh; + + config.nms_iou_threshold = nmsThreshold; + + return make_cuda_node(preferableTarget, std::move(context->stream), blobs[0], config); + } +#endif + virtual int64 getFLOPS(const std::vector &inputs, const std::vector &outputs) const CV_OVERRIDE { diff --git a/modules/dnn/src/layers/reorg_layer.cpp b/modules/dnn/src/layers/reorg_layer.cpp index 659a795..2be87ce 100644 --- a/modules/dnn/src/layers/reorg_layer.cpp +++ b/modules/dnn/src/layers/reorg_layer.cpp @@ -41,6 +41,7 @@ //M*/ #include "../precomp.hpp" +#include "../op_cuda.hpp" #include "../op_inf_engine.hpp" #include #include @@ -49,6 +50,11 @@ #include "opencl_kernels_dnn.hpp" #endif +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/reorg.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -135,7 +141,9 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { - return backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_INFERENCE_ENGINE; + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || + backendId == DNN_BACKEND_INFERENCE_ENGINE; } #ifdef HAVE_OPENCL @@ -178,6 +186,18 @@ public: permute->forward(inputs, outputs, internals_arr); } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + return make_cuda_node(preferableTarget, std::move(context->stream), reorgStride); + } +#endif + #ifdef HAVE_INF_ENGINE virtual Ptr initInfEngine(const std::vector >&) CV_OVERRIDE { diff --git a/modules/dnn/src/layers/reshape_layer.cpp b/modules/dnn/src/layers/reshape_layer.cpp index 5cbfc03..ff94f05 100644 --- a/modules/dnn/src/layers/reshape_layer.cpp +++ b/modules/dnn/src/layers/reshape_layer.cpp @@ -42,9 +42,15 @@ #include "../precomp.hpp" #include "layers_common.hpp" +#include "../op_cuda.hpp" #include "../op_inf_engine.hpp" #include +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/reshape.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -179,6 +185,7 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || (backendId == DNN_BACKEND_INFERENCE_ENGINE && haveInfEngine()); } @@ -258,6 +265,18 @@ public: } } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + return make_cuda_node(preferableTarget, std::move(context->stream)); + } +#endif + #ifdef HAVE_INF_ENGINE virtual Ptr initInfEngine(const std::vector >& inputs) CV_OVERRIDE { diff --git a/modules/dnn/src/layers/resize_layer.cpp b/modules/dnn/src/layers/resize_layer.cpp index 339f2b7..3846dcc 100644 --- a/modules/dnn/src/layers/resize_layer.cpp +++ 
b/modules/dnn/src/layers/resize_layer.cpp @@ -6,9 +6,15 @@ // Third party copyrights are property of their respective owners. #include "../precomp.hpp" #include "layers_common.hpp" +#include "../op_cuda.hpp" #include "../op_inf_engine.hpp" #include +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/resize.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn { class ResizeLayerImpl : public ResizeLayer @@ -51,6 +57,9 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { + if (backendId == DNN_BACKEND_CUDA) + return interpolation == "nearest" || interpolation == "bilinear"; + #ifdef HAVE_INF_ENGINE if (backendId == DNN_BACKEND_INFERENCE_ENGINE) { @@ -159,6 +168,27 @@ public: CV_Error(Error::StsNotImplemented, "Unknown interpolation: " + interpolation); } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + + cuda4dnn::InterpolationType itype; + if (interpolation == "nearest") + itype = InterpolationType::NEAREST_NEIGHBOUR; + else if (interpolation == "bilinear") + itype = InterpolationType::BILINEAR; + else + CV_Error(Error::StsNotImplemented, "Requested interpolation mode is not available in resize layer."); + + return make_cuda_node(preferableTarget, std::move(context->stream), itype, scaleHeight, scaleWidth); + } +#endif + virtual Ptr initInfEngine(const std::vector >&) CV_OVERRIDE { #ifdef HAVE_INF_ENGINE @@ -224,7 +254,9 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { - return backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_INFERENCE_ENGINE; + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_INFERENCE_ENGINE || + backendId == DNN_BACKEND_CUDA; } virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE diff --git a/modules/dnn/src/layers/scale_layer.cpp b/modules/dnn/src/layers/scale_layer.cpp index 4486a0f..f556adb 100644 --- a/modules/dnn/src/layers/scale_layer.cpp +++ b/modules/dnn/src/layers/scale_layer.cpp @@ -11,10 +11,16 @@ Implementation of Scale layer. #include "../precomp.hpp" #include "layers_common.hpp" +#include "../op_cuda.hpp" #include "../op_halide.hpp" #include "../op_inf_engine.hpp" #include +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/scale_shift.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -50,7 +56,9 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { - return backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_HALIDE || + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || + backendId == DNN_BACKEND_HALIDE || (backendId == DNN_BACKEND_INFERENCE_ENGINE && axis == 1); } @@ -138,6 +146,28 @@ public: } } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + + CV_Assert(!blobs.empty() || inputs.size() == 2); + + cv::Mat weightsMat = hasWeights ? blobs[0] : Mat(); + + /* if the weights are provided, bias will be in blobs[1]; otherwise, it will be in blobs[0] + * in either case, it is at the end of the blobs vector => bias = blobs.back() + */ + cv::Mat biasMat = hasBias ? 
blobs.back() : Mat(); + + return make_cuda_node(preferableTarget, std::move(context->stream), axis, weightsMat, biasMat); + } +#endif + virtual Ptr tryAttach(const Ptr& node) CV_OVERRIDE { switch (node->backendId) diff --git a/modules/dnn/src/layers/shuffle_channel_layer.cpp b/modules/dnn/src/layers/shuffle_channel_layer.cpp index 44987f6..6db74d1 100644 --- a/modules/dnn/src/layers/shuffle_channel_layer.cpp +++ b/modules/dnn/src/layers/shuffle_channel_layer.cpp @@ -5,6 +5,12 @@ // Copyright (C) 2018, Intel Corporation, all rights reserved. // Third party copyrights are property of their respective owners. #include "../precomp.hpp" +#include "../op_cuda.hpp" + +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/shuffle_channel.hpp" +using namespace cv::dnn::cuda4dnn; +#endif namespace cv { namespace dnn { @@ -17,6 +23,12 @@ public: setParamsFrom(params); } + virtual bool supportBackend(int backendId) CV_OVERRIDE + { + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA; + } + bool getMemoryShapes(const std::vector &inputs, const int requiredOutputs, std::vector &outputs, @@ -123,6 +135,18 @@ public: } } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + return make_cuda_node(preferableTarget, std::move(context->stream), group); + } +#endif + private: Ptr permute; std::vector permuteInpShape, permuteOutShape; diff --git a/modules/dnn/src/layers/slice_layer.cpp b/modules/dnn/src/layers/slice_layer.cpp index 4305551..33d3a34 100644 --- a/modules/dnn/src/layers/slice_layer.cpp +++ b/modules/dnn/src/layers/slice_layer.cpp @@ -41,6 +41,7 @@ //M*/ #include "../precomp.hpp" +#include "../op_cuda.hpp" #include "../op_inf_engine.hpp" #include "layers_common.hpp" #include @@ -49,6 +50,11 @@ #include "opencl_kernels_dnn.hpp" #endif +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/slice.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -112,6 +118,7 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || (backendId == DNN_BACKEND_INFERENCE_ENGINE && #ifdef HAVE_INF_ENGINE INF_ENGINE_VER_MAJOR_GE(INF_ENGINE_RELEASE_2019R1) && @@ -260,6 +267,28 @@ public: } } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + + std::vector> offsets; + for (const auto& ranges : sliceRanges) + { + std::vector offsets_i; + for (const auto& range : ranges) + offsets_i.push_back(range.start); + offsets.push_back(std::move(offsets_i)); + } + + return make_cuda_node(preferableTarget, std::move(context->stream), std::move(offsets)); + } +#endif + #ifdef HAVE_INF_ENGINE #if INF_ENGINE_VER_MAJOR_GE(INF_ENGINE_RELEASE_2019R1) virtual Ptr initInfEngine(const std::vector >& inputs) CV_OVERRIDE diff --git a/modules/dnn/src/layers/softmax_layer.cpp b/modules/dnn/src/layers/softmax_layer.cpp index 59c8163..119cf0a 100644 --- a/modules/dnn/src/layers/softmax_layer.cpp +++ b/modules/dnn/src/layers/softmax_layer.cpp @@ -42,6 +42,7 @@ #include "../precomp.hpp" #include "layers_common.hpp" +#include "../op_cuda.hpp" #include "../op_halide.hpp" #include "../op_inf_engine.hpp" #include "../op_vkcom.hpp" @@ -54,6 +55,11 @@ using std::max; using namespace cv::dnn::ocl4dnn; #endif +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/softmax.hpp" +using 
namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -90,6 +96,7 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA || (backendId == DNN_BACKEND_HALIDE && haveHalide() && axisRaw == 1) || (backendId == DNN_BACKEND_INFERENCE_ENGINE && haveInfEngine() && !logSoftMax) || (backendId == DNN_BACKEND_VKCOM && haveVulkan()); @@ -286,6 +293,21 @@ public: } } +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + + auto input_wrapper = inputs[0].dynamicCast(); + auto channel_axis = clamp(axisRaw, input_wrapper->getRank()); + return make_cuda_node(preferableTarget, std::move(context->cudnn_handle), channel_axis, logSoftMax); + } +#endif + virtual Ptr initVkCom(const std::vector > &inputs) CV_OVERRIDE { #ifdef HAVE_VULKAN diff --git a/modules/dnn/src/layers/split_layer.cpp b/modules/dnn/src/layers/split_layer.cpp index b0ea1ae..b025d5f 100644 --- a/modules/dnn/src/layers/split_layer.cpp +++ b/modules/dnn/src/layers/split_layer.cpp @@ -41,8 +41,14 @@ //M*/ #include "../precomp.hpp" +#include "../op_cuda.hpp" #include "layers_common.hpp" +#ifdef HAVE_CUDA +#include "../cuda4dnn/primitives/split.hpp" +using namespace cv::dnn::cuda4dnn; +#endif + namespace cv { namespace dnn @@ -66,6 +72,12 @@ public: } } + virtual bool supportBackend(int backendId) CV_OVERRIDE + { + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_CUDA; + } + bool getMemoryShapes(const std::vector &inputs, const int requiredOutputs, std::vector &outputs, @@ -92,6 +104,19 @@ public: inputs[0].copyTo(outputs[i]); } } + +#ifdef HAVE_CUDA + Ptr initCUDA( + void *context_, + const std::vector>& inputs, + const std::vector>& outputs + ) override + { + auto context = reinterpret_cast(context_); + return make_cuda_node(preferableTarget, std::move(context->stream)); + } +#endif + }; Ptr SplitLayer::create(const LayerParams& params) diff --git a/modules/dnn/src/op_cuda.hpp b/modules/dnn/src/op_cuda.hpp new file mode 100644 index 0000000..5a106f3 --- /dev/null +++ b/modules/dnn/src/op_cuda.hpp @@ -0,0 +1,390 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_OP_CUDA_HPP +#define OPENCV_DNN_SRC_OP_CUDA_HPP + +#ifdef HAVE_CUDA +#include "cuda4dnn/csl/stream.hpp" +#include "cuda4dnn/csl/cublas.hpp" +#include "cuda4dnn/csl/cudnn.hpp" +#include "cuda4dnn/csl/tensor.hpp" +#include "cuda4dnn/csl/memory.hpp" +#include "cuda4dnn/csl/fp16.hpp" +#include "cuda4dnn/csl/workspace.hpp" +#endif + +#include +#include + +#include +#include +#include + +namespace cv { namespace dnn { + + constexpr bool IS_DNN_CUDA_TARGET(int id) { + return id == DNN_TARGET_CUDA_FP16 || id == DNN_TARGET_CUDA; + } + + constexpr bool haveCUDA() { +#ifdef HAVE_CUDA + return true; +#else + return false; +#endif + } + +#ifdef HAVE_CUDA + namespace cuda4dnn { namespace csl { + struct CSLContext { + Stream stream; + cublas::Handle cublas_handle; + cudnn::Handle cudnn_handle; + }; + + /** @brief creates Tensor object from cv::Mat (only the header is created, i.e. 
no data is copied) + * + * \tparam T element type for the tensor + * \param[in] mat cv::Mat from which the shape must be inferred + * + * \return a Tensor object with the shape of \p mat + */ + template + Tensor makeTensorHeader(const Mat& mat) { + auto sizes = shape(mat); + return Tensor(std::begin(sizes), std::end(sizes)); + } + + /** @brief copies data from a cv::Mat to TensorType + * + * \tparam T the type of the elements contained in TensorType object + * + * \param[in] srcMat source matrix + * \param[out] destTensor destination tensor + * \param stream CUDA stream to use for the memory transfer + * + * The memory copy starts from the beginning of \p srcMat. The number of elements copied is + * equal to the number of elements in \p destTensor. + * + * Pre-conditions: + * - \p srcMat must contain elements of type CV_32F + * - the size of \p srcMat must be larger than or equal to the size of \p destTensor + * + * @note best performance when \p srcMat is continuous and page-locked + * @note blocks calling thread if \p srcMat is not page-locked + */ + template + void copyMatToTensor(const Mat& srcMat, const TensorSpan destTensor, const Stream& stream); + + template <> inline + void copyMatToTensor(const Mat& srcMat, const TensorSpan destTensor, const Stream& stream) { + /* should perhaps convert cv::Mat of different type to the required type and copy */ + CV_Assert(srcMat.type() == CV_32F); + CV_Assert(srcMat.total() >= destTensor.size()); + + Mat temp; + srcMat.convertTo(temp, CV_16F); + CV_Assert(temp.isContinuous()); + + memcpy(destTensor.get(), reinterpret_cast(temp.data), destTensor.size(), stream); + } + + template <> inline + void copyMatToTensor(const Mat& srcMat, const TensorSpan destTensor, const Stream& stream) { + /* should perhaps convert cv::Mat of different type to the required type and copy */ + CV_Assert(srcMat.type() == CV_32F); + CV_Assert(srcMat.total() >= destTensor.size()); + + Mat temp = srcMat.isContinuous() ? srcMat : srcMat.clone(); + CV_Assert(temp.isContinuous()); + + memcpy(destTensor.get(), reinterpret_cast(temp.data), destTensor.size(), stream); + } + + /** @brief copies data from a TensorType to a cv::Mat + * + * \tparam T the type of the elements contained in TensorType object + * + * \param[in] srcTensor source tensor + * \param[out] destMat destination matrix + * \param stream CUDA stream to use for the memory transfer + * + * The entire memory block held by the \p srcTensor is copied to \p destMat. + * + * Pre-conditions: + * - \p destMat must contain elements of type CV_32F + * - the size of \p destMat must be larger than or equal to the size of \p srcTensor + * + * @note best performance when \p destMat is continuous and page-locked + * @note blocks calling thread if \p destMat is not page-locked + */ + template + void copyTensorToMat(TensorView srcTensor, Mat& destMat, const Stream& stream); + + template <> inline + void copyTensorToMat(TensorView srcTensor, Mat& destMat, const Stream& stream) { + CV_Assert(destMat.type() == CV_32F); + CV_Assert(destMat.total() >= srcTensor.size()); + + Mat temp(shape(destMat), CV_16F); + CV_Assert(temp.isContinuous()); + + memcpy(reinterpret_cast(temp.data), srcTensor.get(), srcTensor.size(), stream); + + temp.convertTo(destMat, CV_32F); + } + + template <> inline + void copyTensorToMat(TensorView srcTensor, Mat& destMat, const Stream& stream) { + CV_Assert(destMat.type() == CV_32F); + CV_Assert(destMat.total() >= srcTensor.size()); + + Mat temp = destMat.isContinuous() ? 
destMat : destMat.clone(); + CV_Assert(temp.isContinuous()); + + memcpy(reinterpret_cast(temp.data), srcTensor.get(), srcTensor.size(), stream); + + if (temp.data != destMat.data) + temp.copyTo(destMat); + } + + }} /* namespace cuda4dnn::csl */ + + /** base class for CUDA operation nodes (for all supported targets) */ + class CUDABackendNode : public BackendNode { + public: + CUDABackendNode() : BackendNode(DNN_BACKEND_CUDA) { } + virtual ~CUDABackendNode() { } + + virtual void forward( + const std::vector>& inputs, + const std::vector>& outputs, + cuda4dnn::csl::Workspace& workspace) = 0; + + virtual std::size_t get_workspace_memory_in_bytes() const noexcept { return 0; } + }; + + /** @brief utility function which creates CUDA node of correct type from `targetId` + * + * CUDA operation nodes take the type of data they operate on as a template parameter. + * For example, ConcatOp is an operation node which concats tensors of `float` type + * into a tensor of `float` type. + * + * This utility function aids the creation of nodes of different types and eliminates the + * need for CUDA target constants (`DNN_TARGET_XXX`) to appear in the operation code which + * reduces coupling between modules. + * + * Example: + * template + * class ConcatOp : public CUDABackendNode; + * + * // returns a cv::Ptr to a ConcatOp object + * auto node = make_cuda_node(DNN_TARGET_CUDA_FP16, axis); + * + * // returns a cv::Ptr to a ConcatOp object + * auto node = make_cuda_node(DNN_TARGET_CUDA, axis); + */ + template
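
As a rough illustration of the dispatch the comment above describes, a minimal make_cuda_node could look like the sketch below; the template-parameter names, the use of a `half` type from csl/fp16.hpp, and the exact error handling are assumptions rather than the patch's definitive implementation.

// sketch only: assumes <utility> plus the DNN and CSL headers above are available,
// and that this lives inside namespace cv { namespace dnn { ... } } like the code above
template <template <class> class NodeType, class ...Args>
Ptr<BackendNode> make_cuda_node(int targetId, Args&& ...args)
{
    switch (targetId)
    {
    case DNN_TARGET_CUDA_FP16:
        // half-precision specialization of the operation node
        return Ptr<BackendNode>(new NodeType<half>(std::forward<Args>(args)...));
    case DNN_TARGET_CUDA:
        // single-precision specialization
        return Ptr<BackendNode>(new NodeType<float>(std::forward<Args>(args)...));
    }
    CV_Error(Error::StsBadArg, "unsupported CUDA target");
    return Ptr<BackendNode>(); // unreachable; keeps compilers quiet
}

// hypothetical usage with a ConcatOp-like node type:
//   auto node = make_cuda_node<ConcatOp>(preferableTarget, std::move(stream), concat_axis, padding);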