--- /dev/null
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+// Copyright (C) 2018, Intel Corporation, all rights reserved.
+// Third party copyrights are property of their respective owners.
+#include "../precomp.hpp"
+
+namespace cv { namespace dnn {
+
+class ShuffleChannelLayerImpl CV_FINAL : public ShuffleChannelLayer
+{
+public:
+ ShuffleChannelLayerImpl(const LayerParams& params)
+ {
+ group = params.get<int>("group", 1);
+ }
+
+ bool getMemoryShapes(const std::vector<MatShape> &inputs,
+ const int requiredOutputs,
+ std::vector<MatShape> &outputs,
+ std::vector<MatShape> &internals) const CV_OVERRIDE
+ {
+ CV_Assert(inputs.size() == 1 && inputs[0].size() == 4);
+ CV_Assert(inputs[0][1] % group == 0);
+ Layer::getMemoryShapes(inputs, requiredOutputs, outputs, internals);
+ return group == 1;
+ }
+
+ virtual void finalize(const std::vector<Mat*>& inputs, std::vector<Mat> &outputs) CV_OVERRIDE
+ {
+ if (group != 1)
+ {
+ LayerParams lp;
+ float order[] = {0, 2, 1, 3};
+ lp.set("order", DictValue::arrayInt(&order[0], 4));
+ permute = PermuteLayer::create(lp);
+
+ Mat inp = *inputs[0];
+ Mat out = outputs[0];
+
+ permuteInpShape.resize(4);
+ permuteInpShape[0] = inp.size[0];
+ permuteInpShape[1] = group;
+ permuteInpShape[2] = inp.size[1] / group;
+ permuteInpShape[3] = inp.size[2]*inp.size[3];
+
+ permuteOutShape.resize(4);
+ permuteOutShape[0] = permuteInpShape[0];
+ permuteOutShape[1] = permuteInpShape[2];
+ permuteOutShape[2] = permuteInpShape[1];
+ permuteOutShape[3] = permuteInpShape[3];
+
+ inp = inp.reshape(1, permuteInpShape);
+ out = out.reshape(1, permuteOutShape);
+
+ std::vector<Mat*> permuteInputs(1, &inp);
+ std::vector<Mat> permuteOutputs(1, out);
+ permute->finalize(permuteInputs, permuteOutputs);
+ }
+ }
+
+ void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+ {
+ CV_TRACE_FUNCTION();
+ CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+ Layer::forward_fallback(inputs_arr, outputs_arr, internals_arr);
+ }
+
+ void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals) CV_OVERRIDE
+ {
+ CV_TRACE_FUNCTION();
+ CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+ Mat inp = *inputs[0];
+ Mat out = outputs[0];
+ if (inp.data != out.data)
+ {
+ if (!permute.empty())
+ {
+ inp = inp.reshape(1, permuteInpShape);
+ out = out.reshape(1, permuteOutShape);
+ std::vector<Mat*> permuteInputs(1, &inp);
+ std::vector<Mat> permuteOutputs(1, out);
+ permute->forward(permuteInputs, permuteOutputs, internals);
+ }
+ else
+ inp.copyTo(out);
+ }
+ }
+
+private:
+ Ptr<PermuteLayer> permute;
+ std::vector<int> permuteInpShape, permuteOutShape;
+};
+
+Ptr<Layer> ShuffleChannelLayer::create(const LayerParams& params)
+{
+ return Ptr<Layer>(new ShuffleChannelLayerImpl(params));
+}
+
+} // namespace dnn
+} // namespace cv
normAssert(indices, outputs[1].reshape(1, 5));
}
+typedef testing::TestWithParam<tuple<Vec4i, int> > Layer_Test_ShuffleChannel;
+TEST_P(Layer_Test_ShuffleChannel, Accuracy)
+{
+ Vec4i inpShapeVec = get<0>(GetParam());
+ int group = get<1>(GetParam());
+ ASSERT_EQ(inpShapeVec[1] % group, 0);
+ const int groupSize = inpShapeVec[1] / group;
+
+ Net net;
+ LayerParams lp;
+ lp.set("group", group);
+ lp.type = "ShuffleChannel";
+ lp.name = "testLayer";
+ net.addLayerToPrev(lp.name, lp.type, lp);
+
+ const int inpShape[] = {inpShapeVec[0], inpShapeVec[1], inpShapeVec[2], inpShapeVec[3]};
+ Mat inp(4, inpShape, CV_32F);
+ randu(inp, 0, 255);
+
+ net.setInput(inp);
+ Mat out = net.forward();
+
+ for (int n = 0; n < inpShapeVec[0]; ++n)
+ {
+ for (int c = 0; c < inpShapeVec[1]; ++c)
+ {
+ Mat outChannel = getPlane(out, n, c);
+ Mat inpChannel = getPlane(inp, n, groupSize * (c % group) + c / group);
+ normAssert(outChannel, inpChannel);
+ }
+ }
+}
+INSTANTIATE_TEST_CASE_P(/**/, Layer_Test_ShuffleChannel, Combine(
+/*input shape*/ Values(Vec4i(1, 6, 5, 7), Vec4i(3, 12, 1, 4)),
+/*group*/ Values(1, 2, 3, 6)
+));
+
}} // namespace