1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
5 // Copyright (C) 2018, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
7 #include "../precomp.hpp"
9 namespace cv { namespace dnn {
11 class ShuffleChannelLayerImpl CV_FINAL : public ShuffleChannelLayer
14 ShuffleChannelLayerImpl(const LayerParams& params)
16 group = params.get<int>("group", 1);
17 setParamsFrom(params);
20 bool getMemoryShapes(const std::vector<MatShape> &inputs,
21 const int requiredOutputs,
22 std::vector<MatShape> &outputs,
23 std::vector<MatShape> &internals) const CV_OVERRIDE
25 CV_Assert(inputs.size() == 1 && inputs[0].size() == 4);
26 CV_Assert(inputs[0][1] % group == 0);
27 Layer::getMemoryShapes(inputs, requiredOutputs, outputs, internals);
31 virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
35 std::vector<Mat> inputs, outputs;
36 inputs_arr.getMatVector(inputs);
37 outputs_arr.getMatVector(outputs);
40 float order[] = {0, 2, 1, 3};
41 lp.set("order", DictValue::arrayInt(&order[0], 4));
42 permute = PermuteLayer::create(lp);
44 const Mat& inp = inputs[0];
45 const Mat& out = outputs[0];
47 permuteInpShape.resize(4);
48 permuteInpShape[0] = inp.size[0];
49 permuteInpShape[1] = group;
50 permuteInpShape[2] = inp.size[1] / group;
51 permuteInpShape[3] = inp.size[2]*inp.size[3];
53 permuteOutShape.resize(4);
54 permuteOutShape[0] = permuteInpShape[0];
55 permuteOutShape[1] = permuteInpShape[2];
56 permuteOutShape[2] = permuteInpShape[1];
57 permuteOutShape[3] = permuteInpShape[3];
59 std::vector<Mat> permuteInputs(1, inp.reshape(1, permuteInpShape));
60 std::vector<Mat> permuteOutputs(1, out.reshape(1, permuteOutShape));
61 permute->finalize(permuteInputs, permuteOutputs);
66 bool forward_ocl(InputArrayOfArrays inps, OutputArrayOfArrays outs, OutputArrayOfArrays internals)
68 std::vector<UMat> inputs;
69 std::vector<UMat> outputs;
71 inps.getUMatVector(inputs);
72 outs.getUMatVector(outputs);
74 if (inputs[0].u != outputs[0].u)
78 inputs[0] = inputs[0].reshape(1, permuteInpShape.size(), &permuteInpShape[0]);
79 outputs[0] = outputs[0].reshape(1, permuteOutShape.size(), &permuteOutShape[0]);
80 permute->preferableTarget = preferableTarget;
81 permute->forward(inputs, outputs, internals);
84 inputs[0].copyTo(outputs[0]);
90 void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
93 CV_TRACE_ARG_VALUE(name, "name", name.c_str());
95 CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget),
96 forward_ocl(inputs_arr, outputs_arr, internals_arr))
98 if (inputs_arr.depth() == CV_16S)
100 forward_fallback(inputs_arr, outputs_arr, internals_arr);
104 std::vector<Mat> inputs, outputs, internals;
105 inputs_arr.getMatVector(inputs);
106 outputs_arr.getMatVector(outputs);
107 internals_arr.getMatVector(internals);
110 Mat out = outputs[0];
111 if (inp.data != out.data)
113 if (!permute.empty())
115 inp = inp.reshape(1, permuteInpShape);
116 out = out.reshape(1, permuteOutShape);
117 std::vector<Mat> permuteInputs(1, inp);
118 std::vector<Mat> permuteOutputs(1, out);
119 permute->forward(permuteInputs, permuteOutputs, internals);
127 Ptr<PermuteLayer> permute;
128 std::vector<int> permuteInpShape, permuteOutShape;
131 Ptr<Layer> ShuffleChannelLayer::create(const LayerParams& params)
133 return Ptr<Layer>(new ShuffleChannelLayerImpl(params));