1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
5 // Copyright (C) 2018, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
7 #include "../precomp.hpp"
8 #include "../op_cuda.hpp"
11 #include "../cuda4dnn/primitives/shuffle_channel.hpp"
12 using namespace cv::dnn::cuda4dnn;
15 namespace cv { namespace dnn {
// Concrete implementation of ShuffleChannelLayer.
//
// From the visible code: the layer reads an integer "group" parameter and, in
// finalize(), views the 4-D input as [d0, group, d1/group, d2*d3] and sets up an
// internal PermuteLayer with order {0, 2, 1, 3}, i.e. the two middle axes of
// that view are swapped. forward()/forward_ocl() then run that permutation
// whenever input and output do not share the same buffer.
//
// NOTE(review): this listing appears to be missing lines (brace-only lines,
// preprocessor guards and some statements). Identifiers such as `lp`, `inp`
// and `context_` are used below without a visible declaration. The comments
// here describe only what can be seen.
17 class ShuffleChannelLayerImpl CV_FINAL : public ShuffleChannelLayer
// Constructor: reads the "group" parameter as int (defaults to 1 when absent)
// and passes the params on to setParamsFrom().
// NOTE(review): `group` is not declared in the visible lines — presumably a
// member inherited from ShuffleChannelLayer; confirm.
20 ShuffleChannelLayerImpl(const LayerParams& params)
22 group = params.get<int>("group", 1);
23 setParamsFrom(params);
// Reports which backends this layer can run on: the visible expression is true
// for DNN_BACKEND_OPENCV and DNN_BACKEND_CUDA.
26 virtual bool supportBackend(int backendId) CV_OVERRIDE
28 return backendId == DNN_BACKEND_OPENCV ||
29 backendId == DNN_BACKEND_CUDA;
// Validates the input shapes and delegates the output-shape computation to the
// base Layer::getMemoryShapes().
// Preconditions enforced via CV_Assert:
//   - exactly one input, and that input is 4-dimensional;
//   - dimension 1 of the input (presumably the channel axis) is divisible by group.
// NOTE(review): the return statement of this bool function is not visible here.
32 bool getMemoryShapes(const std::vector<MatShape> &inputs,
33 const int requiredOutputs,
34 std::vector<MatShape> &outputs,
35 std::vector<MatShape> &internals) const CV_OVERRIDE
37 CV_Assert(inputs.size() == 1 && inputs[0].size() == 4);
38 CV_Assert(inputs[0][1] % group == 0);
39 Layer::getMemoryShapes(inputs, requiredOutputs, outputs, internals);
// Prepares the internal permute layer and the reshaped views it operates on.
// NOTE(review): any guard around this body (e.g. a check on `group`) is not
// visible in this listing.
43 virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
47 std::vector<Mat> inputs, outputs;
48 inputs_arr.getMatVector(inputs);
49 outputs_arr.getMatVector(outputs);
// Axis order handed to the permute layer: swap axes 1 and 2, keep 0 and 3.
// NOTE(review): `order` is declared float although it is passed to
// DictValue::arrayInt — the values are whole numbers, but an int array would
// state the intent more directly. `lp` is not declared in the visible lines.
52 float order[] = {0, 2, 1, 3};
53 lp.set("order", DictValue::arrayInt(&order[0], 4));
54 permute = PermuteLayer::create(lp);
56 const Mat& inp = inputs[0];
57 const Mat& out = outputs[0];
// Shape of the input as seen by the permute layer:
// [size0, group, size1 / group, size2 * size3].
59 permuteInpShape.resize(4);
60 permuteInpShape[0] = inp.size[0];
61 permuteInpShape[1] = group;
62 permuteInpShape[2] = inp.size[1] / group;
63 permuteInpShape[3] = inp.size[2]*inp.size[3];
// Shape of the permuted output: same as the input view with axes 1 and 2
// exchanged, matching the {0, 2, 1, 3} order above.
65 permuteOutShape.resize(4);
66 permuteOutShape[0] = permuteInpShape[0];
67 permuteOutShape[1] = permuteInpShape[2];
68 permuteOutShape[2] = permuteInpShape[1];
69 permuteOutShape[3] = permuteInpShape[3];
// Finalize the permute layer on the reshaped input and output.
71 std::vector<Mat> permuteInputs(1, inp.reshape(1, permuteInpShape));
72 std::vector<Mat> permuteOutputs(1, out.reshape(1, permuteOutShape));
73 permute->finalize(permuteInputs, permuteOutputs);
// UMat-based forward path, invoked from forward() through CV_OCL_RUN.
// NOTE(review): the return statement and any branch structure (braces, else)
// are not visible in this listing.
78 bool forward_ocl(InputArrayOfArrays inps, OutputArrayOfArrays outs, OutputArrayOfArrays internals)
80 std::vector<UMat> inputs;
81 std::vector<UMat> outputs;
83 inps.getUMatVector(inputs);
84 outs.getUMatVector(outputs);
// Work is only done when input and output refer to different underlying UMat data.
86 if (inputs[0].u != outputs[0].u)
// Reshape both buffers to the shapes computed in finalize(), propagate the
// preferable target to the permute layer and run it.
90 inputs[0] = inputs[0].reshape(1, permuteInpShape.size(), &permuteInpShape[0]);
91 outputs[0] = outputs[0].reshape(1, permuteOutShape.size(), &permuteOutShape[0]);
92 permute->preferableTarget = preferableTarget;
93 permute->forward(inputs, outputs, internals);
// Plain copy of input to output — presumably the alternative branch used when
// no permutation is set up; the enclosing condition is not visible here.
96 inputs[0].copyTo(outputs[0]);
// Main forward entry point.
// Visible steps: emit a trace value with the layer name, try the OpenCL path
// when the preferable target is an OpenCL target, hand CV_16S inputs to
// forward_fallback(), and otherwise run the Mat-based path below.
102 void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
105 CV_TRACE_ARG_VALUE(name, "name", name.c_str());
107 CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget),
108 forward_ocl(inputs_arr, outputs_arr, internals_arr))
110 if (inputs_arr.depth() == CV_16S)
112 forward_fallback(inputs_arr, outputs_arr, internals_arr);
116 std::vector<Mat> inputs, outputs, internals;
117 inputs_arr.getMatVector(inputs);
118 outputs_arr.getMatVector(outputs);
119 internals_arr.getMatVector(internals);
// NOTE(review): the declaration of `inp` is not visible in this listing.
122 Mat out = outputs[0];
// Skip the work when input and output share the same data pointer.
123 if (inp.data != out.data)
// Only permute when a permute layer has been created.
125 if (!permute.empty())
127 inp = inp.reshape(1, permuteInpShape);
128 out = out.reshape(1, permuteOutShape);
129 std::vector<Mat> permuteInputs(1, inp);
130 std::vector<Mat> permuteOutputs(1, out);
131 permute->forward(permuteInputs, permuteOutputs, internals);
// Creates the CUDA backend node for this layer.
// The opaque `context_` is reinterpreted as a csl::CSLContext*, and a
// cuda4dnn::ShuffleChannelOp node is built from the preferable target, the
// context's stream (moved out of the context) and `group`.
// NOTE(review): the `context_` parameter is not visible in the parameter list
// shown here. `inputs` and `outputs` are not used in the visible body.
139 Ptr<BackendNode> initCUDA(
141 const std::vector<Ptr<BackendWrapper>>& inputs,
142 const std::vector<Ptr<BackendWrapper>>& outputs
145 auto context = reinterpret_cast<csl::CSLContext*>(context_);
146 return make_cuda_node<cuda4dnn::ShuffleChannelOp>(preferableTarget, std::move(context->stream), group);
// Internal permute layer created in finalize(); checked for emptiness before use in forward().
151 Ptr<PermuteLayer> permute;
// Reshaped input/output shapes computed in finalize() and reused by the forward paths.
152 std::vector<int> permuteInpShape, permuteOutShape;
// Factory for the shuffle-channel layer: allocates a ShuffleChannelLayerImpl
// from the given params and returns it wrapped in a Ptr<Layer>.
155 Ptr<Layer> ShuffleChannelLayer::create(const LayerParams& params)
157 return Ptr<Layer>(new ShuffleChannelLayerImpl(params));