1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
5 #ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_CONCAT_HPP
6 #define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_CONCAT_HPP
8 #include "../../op_cuda.hpp"
10 #include "../csl/stream.hpp"
11 #include "../csl/pointer.hpp"
13 #include "../kernels/fill.hpp"
14 #include "../kernels/concat.hpp"
16 #include <opencv2/core.hpp>
22 namespace cv { namespace dnn { namespace cuda4dnn {
25 class ConcatOp final : public CUDABackendNode {
27 using wrapper_type = GetCUDABackendWrapperType<T>;
29 ConcatOp(csl::Stream stream_, std::size_t concat_axis, bool zero_padding)
30 : stream(std::move(stream_)), concat_axis{ concat_axis }, zero_padding{ zero_padding }
35 const std::vector<cv::Ptr<BackendWrapper>>& inputs,
36 const std::vector<cv::Ptr<BackendWrapper>>& outputs,
37 csl::Workspace& workspace) override
39 CV_Assert(outputs.size() == 1);
41 auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
42 auto output = output_wrapper->getSpan();
46 auto output_shape = output_wrapper->getShape();
48 kernels::fill<T>(stream, output, 0.0);
50 std::size_t output_concat_axis_offset = 0;
51 for (int i = 0; i < inputs.size(); i++)
53 auto input_wrapper = inputs[i].dynamicCast<wrapper_type>();
54 auto input = input_wrapper->getView();
55 auto input_shape = input_wrapper->getShape();
57 std::vector<std::size_t> offsets(input_shape.size());
58 for (int j = 0; j < offsets.size(); j++)
59 offsets[j] = (output_shape[j] - input_shape[j]) / 2;
60 offsets[concat_axis] = output_concat_axis_offset;
62 kernels::concat_with_offsets(stream, output, input, offsets);
64 output_concat_axis_offset += input.get_axis_size(concat_axis);
69 std::size_t output_axis_offset = 0;
70 for (int i = 0; i < inputs.size(); i++)
72 auto input_wrapper = inputs[i].dynamicCast<wrapper_type>();
73 auto input = input_wrapper->getView();
75 kernels::concat(stream, output, output_axis_offset, input, concat_axis);
77 output_axis_offset += input.get_axis_size(concat_axis);
84 std::size_t concat_axis;
88 }}} /* namespace cv::dnn::cuda4dnn */
90 #endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_CONCAT_HPP */