7fc21154cf5b47d0bc441c74b5a4bfc1a0986a3c
[platform/upstream/opencv.git] / modules / dnn / src / layers / resize_layer.cpp
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 // Copyright (C) 2017, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
7 #include "../precomp.hpp"
8 #include "layers_common.hpp"
9 #include "../op_inf_engine.hpp"
10 #include <opencv2/imgproc.hpp>
11
12 #ifdef HAVE_DNN_NGRAPH
13 #include "../ie_ngraph.hpp"
14 #if INF_ENGINE_VER_MAJOR_GT(INF_ENGINE_RELEASE_2020_4)
15 #include <ngraph/op/interpolate.hpp>
16 #else
17 #include <ngraph/op/experimental/layers/interpolate.hpp>
18 #endif
19 #endif
20
21 namespace cv { namespace dnn {
22
23 class ResizeLayerImpl : public ResizeLayer
24 {
25 public:
26     ResizeLayerImpl(const LayerParams& params) : zoomFactorWidth(params.get<float>("zoom_factor_x", params.get<float>("zoom_factor", 0))),
27                                                  zoomFactorHeight(params.get<float>("zoom_factor_y", params.get<float>("zoom_factor", 0))),
28                                                  scaleWidth(0), scaleHeight(0)
29     {
30         setParamsFrom(params);
31         outWidth = params.get<float>("width", 0);
32         outHeight = params.get<float>("height", 0);
33         if (params.has("zoom_factor"))
34         {
35             CV_Assert(!params.has("zoom_factor_x") && !params.has("zoom_factor_y"));
36         }
37         else if (params.has("zoom_factor_x") || params.has("zoom_factor_y"))
38         {
39             CV_Assert(params.has("zoom_factor_x") && params.has("zoom_factor_y"));
40         }
41         interpolation = params.get<String>("interpolation");
42         CV_Assert(interpolation == "nearest" || interpolation == "opencv_linear" || interpolation == "bilinear");
43
44         alignCorners = params.get<bool>("align_corners", false);
45     }
46
47     bool getMemoryShapes(const std::vector<MatShape> &inputs,
48                          const int requiredOutputs,
49                          std::vector<MatShape> &outputs,
50                          std::vector<MatShape> &internals) const CV_OVERRIDE
51     {
52         CV_Assert_N(inputs.size() == 1 || inputs.size() == 2, inputs[0].size() == 4);
53         outputs.resize(1, inputs[0]);
54         if (inputs.size() == 1) {
55             outputs[0][2] = zoomFactorHeight > 0 ? (outputs[0][2] * zoomFactorHeight) : outHeight;
56             outputs[0][3] = zoomFactorWidth > 0 ? (outputs[0][3] * zoomFactorWidth) : outWidth;
57         } else {
58             outputs[0][2] = inputs[1][2];
59             outputs[0][3] = inputs[1][3];
60         }
61         // We can work in-place (do nothing) if input shape == output shape.
62         return (outputs[0][2] == inputs[0][2]) && (outputs[0][3] == inputs[0][3]);
63     }
64
65     virtual bool supportBackend(int backendId) CV_OVERRIDE
66     {
67 #ifdef HAVE_INF_ENGINE
68         if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 || backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
69         {
70             return (interpolation == "nearest" && scaleWidth == scaleHeight) ||
71                    (interpolation == "bilinear");
72         }
73 #endif
74         return backendId == DNN_BACKEND_OPENCV;
75     }
76
77     virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
78     {
79         std::vector<Mat> inputs, outputs;
80         inputs_arr.getMatVector(inputs);
81         outputs_arr.getMatVector(outputs);
82
83         outHeight = outputs[0].size[2];
84         outWidth = outputs[0].size[3];
85         if (alignCorners && outHeight > 1)
86             scaleHeight = static_cast<float>(inputs[0].size[2] - 1) / (outHeight - 1);
87         else
88             scaleHeight = static_cast<float>(inputs[0].size[2]) / outHeight;
89
90         if (alignCorners && outWidth > 1)
91             scaleWidth = static_cast<float>(inputs[0].size[3] - 1) / (outWidth - 1);
92         else
93             scaleWidth = static_cast<float>(inputs[0].size[3]) / outWidth;
94     }
95
96     void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
97     {
98         CV_TRACE_FUNCTION();
99         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
100
101         if (inputs_arr.depth() == CV_16S)
102         {
103             forward_fallback(inputs_arr, outputs_arr, internals_arr);
104             return;
105         }
106
107         std::vector<Mat> inputs, outputs, internals;
108         inputs_arr.getMatVector(inputs);
109         outputs_arr.getMatVector(outputs);
110         internals_arr.getMatVector(internals);
111
112         if (outHeight == inputs[0].size[2] && outWidth == inputs[0].size[3])
113             return;
114
115         Mat& inp = inputs[0];
116         Mat& out = outputs[0];
117         if (interpolation == "nearest" || interpolation == "opencv_linear")
118         {
119             InterpolationFlags mode = interpolation == "nearest" ? INTER_NEAREST : INTER_LINEAR;
120             for (size_t n = 0; n < inputs[0].size[0]; ++n)
121             {
122                 for (size_t ch = 0; ch < inputs[0].size[1]; ++ch)
123                 {
124                     resize(getPlane(inp, n, ch), getPlane(out, n, ch),
125                            Size(outWidth, outHeight), 0, 0, mode);
126                 }
127             }
128         }
129         else if (interpolation == "bilinear")
130         {
131             const int inpHeight = inp.size[2];
132             const int inpWidth = inp.size[3];
133             const int inpSpatialSize = inpHeight * inpWidth;
134             const int outSpatialSize = outHeight * outWidth;
135             const int numPlanes = inp.size[0] * inp.size[1];
136             CV_Assert_N(inp.isContinuous(), out.isContinuous());
137
138             Mat inpPlanes = inp.reshape(1, numPlanes * inpHeight);
139             Mat outPlanes = out.reshape(1, numPlanes * outHeight);
140             for (int y = 0; y < outHeight; ++y)
141             {
142                 float input_y = y * scaleHeight;
143                 int y0 = static_cast<int>(input_y);
144                 const float* inpData_row0 = inpPlanes.ptr<float>(y0);
145                 const float* inpData_row1 = inpPlanes.ptr<float>(std::min(y0 + 1, inpHeight - 1));
146                 for (int x = 0; x < outWidth; ++x)
147                 {
148                     float input_x = x * scaleWidth;
149                     int x0 = static_cast<int>(input_x);
150                     int x1 = std::min(x0 + 1, inpWidth - 1);
151
152                     float* outData = outPlanes.ptr<float>(y, x);
153                     const float* inpData_row0_c = inpData_row0;
154                     const float* inpData_row1_c = inpData_row1;
155                     for (int c = 0; c < numPlanes; ++c)
156                     {
157                         *outData = inpData_row0_c[x0] +
158                             (input_y - y0) * (inpData_row1_c[x0] - inpData_row0_c[x0]) +
159                             (input_x - x0) * (inpData_row0_c[x1] - inpData_row0_c[x0] +
160                             (input_y - y0) * (inpData_row1_c[x1] - inpData_row0_c[x1] - inpData_row1_c[x0] + inpData_row0_c[x0]));
161
162                         inpData_row0_c += inpSpatialSize;
163                         inpData_row1_c += inpSpatialSize;
164                         outData += outSpatialSize;
165                     }
166                 }
167             }
168         }
169         else
170             CV_Error(Error::StsNotImplemented, "Unknown interpolation: " + interpolation);
171     }
172
173 #ifdef HAVE_DNN_IE_NN_BUILDER_2019
174     virtual Ptr<BackendNode> initInfEngine(const std::vector<Ptr<BackendWrapper> >&) CV_OVERRIDE
175     {
176         InferenceEngine::Builder::Layer ieLayer(name);
177         ieLayer.setName(name);
178         if (interpolation == "nearest")
179         {
180             ieLayer.setType("Resample");
181             ieLayer.getParameters()["type"] = std::string("caffe.ResampleParameter.NEAREST");
182             ieLayer.getParameters()["antialias"] = false;
183             if (scaleWidth != scaleHeight)
184                 CV_Error(Error::StsNotImplemented, "resample with sw != sh");
185             ieLayer.getParameters()["factor"] = 1.0f / scaleWidth;
186         }
187         else if (interpolation == "bilinear")
188         {
189             ieLayer.setType("Interp");
190             ieLayer.getParameters()["pad_beg"] = 0;
191             ieLayer.getParameters()["pad_end"] = 0;
192             ieLayer.getParameters()["align_corners"] = alignCorners;
193         }
194         else
195             CV_Error(Error::StsNotImplemented, "Unsupported interpolation: " + interpolation);
196         ieLayer.getParameters()["width"] = outWidth;
197         ieLayer.getParameters()["height"] = outHeight;
198         ieLayer.setInputPorts(std::vector<InferenceEngine::Port>(1));
199         ieLayer.setOutputPorts(std::vector<InferenceEngine::Port>(1));
200         return Ptr<BackendNode>(new InfEngineBackendNode(ieLayer));
201     }
202 #endif  // HAVE_DNN_IE_NN_BUILDER_2019
203
204
205 #ifdef HAVE_DNN_NGRAPH
206     virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inputs,
207                                         const std::vector<Ptr<BackendNode> >& nodes) CV_OVERRIDE
208     {
209         auto& ieInpNode = nodes[0].dynamicCast<InfEngineNgraphNode>()->node;
210
211         ngraph::op::InterpolateAttrs attrs;
212         attrs.pads_begin.push_back(0);
213         attrs.pads_end.push_back(0);
214         attrs.axes = ngraph::AxisSet{2, 3};
215         attrs.align_corners = alignCorners;
216
217         if (interpolation == "nearest") {
218             attrs.mode = "nearest";
219             attrs.antialias = false;
220         } else if (interpolation == "bilinear") {
221             attrs.mode = "linear";
222         } else {
223             CV_Error(Error::StsNotImplemented, "Unsupported interpolation: " + interpolation);
224         }
225
226         std::vector<int64_t> shape = {outHeight, outWidth};
227         auto out_shape = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{2}, shape.data());
228         auto interp = std::make_shared<ngraph::op::Interpolate>(ieInpNode, out_shape, attrs);
229         return Ptr<BackendNode>(new InfEngineNgraphNode(interp));
230     }
231 #endif  // HAVE_DNN_NGRAPH
232
233 protected:
234     int outWidth, outHeight;
235     const float zoomFactorWidth, zoomFactorHeight;
236     String interpolation;
237     float scaleWidth, scaleHeight;
238     bool alignCorners;
239 };
240
241
242 Ptr<ResizeLayer> ResizeLayer::create(const LayerParams& params)
243 {
244     return Ptr<ResizeLayer>(new ResizeLayerImpl(params));
245 }
246
247 class InterpLayerImpl CV_FINAL : public ResizeLayerImpl
248 {
249 public:
250     InterpLayerImpl(const LayerParams& params) : ResizeLayerImpl(params) {}
251
252     bool getMemoryShapes(const std::vector<MatShape> &inputs,
253                          const int requiredOutputs,
254                          std::vector<MatShape> &outputs,
255                          std::vector<MatShape> &internals) const CV_OVERRIDE
256     {
257         CV_Assert_N(inputs.size() == 1, inputs[0].size() == 4);
258         outputs.resize(1, inputs[0]);
259         outputs[0][2] = zoomFactorHeight > 0 ? (1 + zoomFactorHeight * (outputs[0][2] - 1)) : outHeight;
260         outputs[0][3] = zoomFactorWidth > 0 ? (1 + zoomFactorWidth * (outputs[0][3] - 1)) : outWidth;
261         // We can work in-place (do nothing) if input shape == output shape.
262         return (outputs[0][2] == inputs[0][2]) && (outputs[0][3] == inputs[0][3]);
263     }
264 };
265
266 Ptr<Layer> InterpLayer::create(const LayerParams& params)
267 {
268     LayerParams lp(params);
269     lp.set("interpolation", "bilinear");
270     lp.set("align_corners", true);
271     return Ptr<Layer>(new InterpLayerImpl(lp));
272 }
273
274 }  // namespace dnn
275 }  // namespace cv