44987f63906821c012d21116772d6d6b3f4b9cc9
[platform/upstream/opencv.git] / modules / dnn / src / layers / shuffle_channel_layer.cpp
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 // Copyright (C) 2018, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
7 #include "../precomp.hpp"
8
9 namespace cv { namespace dnn {
10
11 class ShuffleChannelLayerImpl CV_FINAL : public ShuffleChannelLayer
12 {
13 public:
14     ShuffleChannelLayerImpl(const LayerParams& params)
15     {
16         group = params.get<int>("group", 1);
17         setParamsFrom(params);
18     }
19
20     bool getMemoryShapes(const std::vector<MatShape> &inputs,
21                          const int requiredOutputs,
22                          std::vector<MatShape> &outputs,
23                          std::vector<MatShape> &internals) const CV_OVERRIDE
24     {
25         CV_Assert(inputs.size() == 1 && inputs[0].size() == 4);
26         CV_Assert(inputs[0][1] % group == 0);
27         Layer::getMemoryShapes(inputs, requiredOutputs, outputs, internals);
28         return group == 1;
29     }
30
31     virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
32     {
33         if (group != 1)
34         {
35             std::vector<Mat> inputs, outputs;
36             inputs_arr.getMatVector(inputs);
37             outputs_arr.getMatVector(outputs);
38
39             LayerParams lp;
40             float order[] = {0, 2, 1, 3};
41             lp.set("order", DictValue::arrayInt(&order[0], 4));
42             permute = PermuteLayer::create(lp);
43
44             const Mat& inp = inputs[0];
45             const Mat& out = outputs[0];
46
47             permuteInpShape.resize(4);
48             permuteInpShape[0] = inp.size[0];
49             permuteInpShape[1] = group;
50             permuteInpShape[2] = inp.size[1] / group;
51             permuteInpShape[3] = inp.size[2]*inp.size[3];
52
53             permuteOutShape.resize(4);
54             permuteOutShape[0] = permuteInpShape[0];
55             permuteOutShape[1] = permuteInpShape[2];
56             permuteOutShape[2] = permuteInpShape[1];
57             permuteOutShape[3] = permuteInpShape[3];
58
59             std::vector<Mat> permuteInputs(1, inp.reshape(1, permuteInpShape));
60             std::vector<Mat> permuteOutputs(1, out.reshape(1, permuteOutShape));
61             permute->finalize(permuteInputs, permuteOutputs);
62         }
63     }
64
65 #ifdef HAVE_OPENCL
66     bool forward_ocl(InputArrayOfArrays inps, OutputArrayOfArrays outs, OutputArrayOfArrays internals)
67     {
68         std::vector<UMat> inputs;
69         std::vector<UMat> outputs;
70
71         inps.getUMatVector(inputs);
72         outs.getUMatVector(outputs);
73
74         if (inputs[0].u != outputs[0].u)
75         {
76             if (!permute.empty())
77             {
78                 inputs[0] = inputs[0].reshape(1, permuteInpShape.size(), &permuteInpShape[0]);
79                 outputs[0] = outputs[0].reshape(1, permuteOutShape.size(), &permuteOutShape[0]);
80                 permute->preferableTarget = preferableTarget;
81                 permute->forward(inputs, outputs, internals);
82             }
83             else
84                 inputs[0].copyTo(outputs[0]);
85         }
86         return true;
87     }
88 #endif
89
90     void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
91     {
92         CV_TRACE_FUNCTION();
93         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
94
95         CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget),
96                    forward_ocl(inputs_arr, outputs_arr, internals_arr))
97
98         if (inputs_arr.depth() == CV_16S)
99         {
100             forward_fallback(inputs_arr, outputs_arr, internals_arr);
101             return;
102         }
103
104         std::vector<Mat> inputs, outputs, internals;
105         inputs_arr.getMatVector(inputs);
106         outputs_arr.getMatVector(outputs);
107         internals_arr.getMatVector(internals);
108
109         Mat inp = inputs[0];
110         Mat out = outputs[0];
111         if (inp.data != out.data)
112         {
113             if (!permute.empty())
114             {
115                 inp = inp.reshape(1, permuteInpShape);
116                 out = out.reshape(1, permuteOutShape);
117                 std::vector<Mat> permuteInputs(1, inp);
118                 std::vector<Mat> permuteOutputs(1, out);
119                 permute->forward(permuteInputs, permuteOutputs, internals);
120             }
121             else
122                 inp.copyTo(out);
123         }
124     }
125
126 private:
127     Ptr<PermuteLayer> permute;
128     std::vector<int> permuteInpShape, permuteOutShape;
129 };
130
131 Ptr<Layer> ShuffleChannelLayer::create(const LayerParams& params)
132 {
133     return Ptr<Layer>(new ShuffleChannelLayerImpl(params));
134 }
135
136 }  // namespace dnn
137 }  // namespace cv