Merge pull request #10255 from dkurt:dnn_roi_pooling
[platform/upstream/opencv.git] / modules / dnn / src / layers / pooling_layer.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                           License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2013, OpenCV Foundation, all rights reserved.
14 // Copyright (C) 2017, Intel Corporation, all rights reserved.
15 // Third party copyrights are property of their respective owners.
16 //
17 // Redistribution and use in source and binary forms, with or without modification,
18 // are permitted provided that the following conditions are met:
19 //
20 //   * Redistribution's of source code must retain the above copyright notice,
21 //     this list of conditions and the following disclaimer.
22 //
23 //   * Redistribution's in binary form must reproduce the above copyright notice,
24 //     this list of conditions and the following disclaimer in the documentation
25 //     and/or other materials provided with the distribution.
26 //
27 //   * The name of the copyright holders may not be used to endorse or promote products
28 //     derived from this software without specific prior written permission.
29 //
30 // This software is provided by the copyright holders and contributors "as is" and
31 // any express or implied warranties, including, but not limited to, the implied
32 // warranties of merchantability and fitness for a particular purpose are disclaimed.
33 // In no event shall the Intel Corporation or contributors be liable for any direct,
34 // indirect, incidental, special, exemplary, or consequential damages
35 // (including, but not limited to, procurement of substitute goods or services;
36 // loss of use, data, or profits; or business interruption) however caused
37 // and on any theory of liability, whether in contract, strict liability,
38 // or tort (including negligence or otherwise) arising in any way out of
39 // the use of this software, even if advised of the possibility of such damage.
40 //
41 //M*/
42
43 #include "../precomp.hpp"
44 #include "layers_common.hpp"
45 #include "opencv2/core/hal/intrin.hpp"
46 #include "op_halide.hpp"
47 #include "opencl_kernels_dnn.hpp"
48 #include <float.h>
49 #include <algorithm>
50 using std::max;
51 using std::min;
52 #ifdef HAVE_OPENCL
53 using namespace cv::dnn::ocl4dnn;
54 #endif
55
56 namespace cv
57 {
58 namespace dnn
59 {
60
61 class PoolingLayerImpl : public PoolingLayer
62 {
63 public:
64     PoolingLayerImpl(const LayerParams& params)
65     {
66         type = PoolingLayer::MAX;
67         computeMaxIdx = true;
68         globalPooling = false;
69
70         if (params.has("pool"))
71         {
72             String pool = params.get<String>("pool").toLowerCase();
73             if (pool == "max")
74                 type = PoolingLayer::MAX;
75             else if (pool == "ave")
76                 type = PoolingLayer::AVE;
77             else if (pool == "stochastic")
78                 type = PoolingLayer::STOCHASTIC;
79             else
80                 CV_Error(Error::StsBadArg, "Unknown pooling type \"" + pool + "\"");
81             getPoolingKernelParams(params, kernel.height, kernel.width, globalPooling,
82                                    pad.height, pad.width, stride.height, stride.width, padMode);
83         }
84         else if (params.has("pooled_w") || params.has("pooled_h") || params.has("spatial_scale"))
85         {
86             type = PoolingLayer::ROI;
87         }
88         setParamsFrom(params);
89         ceilMode = params.get<bool>("ceil_mode", true);
90         pooledSize.width = params.get<uint32_t>("pooled_w", 1);
91         pooledSize.height = params.get<uint32_t>("pooled_h", 1);
92         spatialScale = params.get<float>("spatial_scale", 1);
93     }
94
95 #ifdef HAVE_OPENCL
96     Ptr<OCL4DNNPool<float> > poolOp;
97 #endif
98
99     void finalize(const std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
100     {
101         CV_Assert(!inputs.empty());
102
103         cv::Size inp(inputs[0]->size[3], inputs[0]->size[2]),
104                 out(outputs[0].size[3], outputs[0].size[2]);
105
106         if(globalPooling)
107         {
108             kernel = inp;
109         }
110
111         getConvPoolPaddings(inp, out, kernel, stride, padMode, Size(1, 1), pad);
112     }
113
114     virtual bool supportBackend(int backendId)
115     {
116         return backendId == DNN_BACKEND_DEFAULT ||
117                backendId == DNN_BACKEND_HALIDE && haveHalide() &&
118                (type == PoolingLayer::MAX ||
119                 type == PoolingLayer::AVE && !pad.width && !pad.height);
120     }
121
122 #ifdef HAVE_OPENCL
123     bool forward_ocl(InputArrayOfArrays inps, OutputArrayOfArrays outs, InputArrayOfArrays internals)
124     {
125         std::vector<UMat> inputs;
126         std::vector<UMat> outputs;
127
128         inps.getUMatVector(inputs);
129         outs.getUMatVector(outputs);
130
131         if (poolOp.empty())
132         {
133             OCL4DNNPoolConfig config;
134
135             config.in_shape = shape(inputs[0]);
136             config.out_shape = shape(outputs[0]);
137             config.kernel = kernel;
138             config.pad = pad;
139             config.stride = stride;
140             config.channels = inputs[0].size[1];
141             config.pool_method = type == MAX ? LIBDNN_POOLING_METHOD_MAX :
142                                 (type == AVE ? LIBDNN_POOLING_METHOD_AVE :
143                                                LIBDNN_POOLING_METHOD_STO);
144             poolOp = Ptr<OCL4DNNPool<float> >(new OCL4DNNPool<float>(config));
145         }
146
147         for (size_t ii = 0; ii < inputs.size(); ii++)
148         {
149             UMat& inpMat = inputs[ii];
150             int out_index = (type == MAX) ? 2 : 1;
151             UMat& outMat = outputs[out_index * ii];
152             UMat maskMat = (type == MAX) ? outputs[2 * ii + 1] : UMat();
153
154             CV_Assert(inpMat.offset == 0 && outMat.offset == 0);
155
156             if (!poolOp->Forward(inpMat, outMat, maskMat))
157                 return false;
158         }
159
160         return true;
161     }
162 #endif
163
164     void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr)
165     {
166         CV_TRACE_FUNCTION();
167         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
168
169         CV_OCL_RUN((preferableTarget == DNN_TARGET_OPENCL) &&
170                    OCL_PERFORMANCE_CHECK(ocl::Device::getDefault().isIntel()),
171                    forward_ocl(inputs_arr, outputs_arr, internals_arr))
172
173         Layer::forward_fallback(inputs_arr, outputs_arr, internals_arr);
174     }
175
176     void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
177     {
178         CV_TRACE_FUNCTION();
179         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
180
181         switch (type)
182         {
183             case MAX:
184                 CV_Assert(inputs.size() == 1, outputs.size() == 2);
185                 maxPooling(*inputs[0], outputs[0], outputs[1]);
186                 break;
187             case AVE:
188                 CV_Assert(inputs.size() == 1, outputs.size() == 1);
189                 avePooling(*inputs[0], outputs[0]);
190                 break;
191             case ROI:
192                 CV_Assert(inputs.size() == 2, outputs.size() == 1);
193                 roiPooling(*inputs[0], *inputs[1], outputs[0]);
194                 break;
195             default:
196                 CV_Error(Error::StsNotImplemented, "Not implemented");
197                 break;
198         }
199     }
200
201     virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
202     {
203         if (type == PoolingLayer::MAX)
204             return initMaxPoolingHalide(inputs);
205         else if (type == PoolingLayer::AVE)
206             return initAvePoolingHalide(inputs);
207         else
208             return Ptr<BackendNode>();
209     }
210
211     class PoolingInvoker : public ParallelLoopBody
212     {
213     public:
214         const Mat* src, *rois;
215         Mat *dst, *mask;
216         Size kernel, stride, pad;
217         int nstripes;
218         bool computeMaxIdx;
219         std::vector<int> ofsbuf;
220         int poolingType;
221         float spatialScale;
222
223         PoolingInvoker() : src(0), rois(0), dst(0), mask(0), nstripes(0),
224                            computeMaxIdx(0), poolingType(PoolingLayer::MAX), spatialScale(0) {}
225
226         static void run(const Mat& src, const Mat& rois, Mat& dst, Mat& mask, Size kernel,
227                         Size stride, Size pad, int poolingType, float spatialScale,
228                         bool computeMaxIdx, int nstripes)
229         {
230             CV_Assert(src.isContinuous() && dst.isContinuous() &&
231                       src.type() == CV_32F && src.type() == dst.type() &&
232                       src.dims == 4 && dst.dims == 4 &&
233                       (poolingType == ROI && dst.size[0] == rois.size[0] ||
234                        src.size[0] == dst.size[0]) && src.size[1] == dst.size[1] &&
235                       (mask.empty() || (mask.type() == src.type() && mask.size == dst.size)));
236
237             PoolingInvoker p;
238
239             p.src = &src;
240             p.rois = &rois;
241             p.dst = &dst;
242             p.mask = &mask;
243             p.kernel = kernel;
244             p.stride = stride;
245             p.pad = pad;
246             p.nstripes = nstripes;
247             p.computeMaxIdx = computeMaxIdx;
248             p.poolingType = poolingType;
249             p.spatialScale = spatialScale;
250
251             if( !computeMaxIdx )
252             {
253                 p.ofsbuf.resize(kernel.width*kernel.height);
254                 for( int i = 0; i < kernel.height; i++ )
255                     for( int j = 0; j < kernel.width; j++ )
256                         p.ofsbuf[i*kernel.width + j] = src.size[3]*i + j;
257             }
258
259             parallel_for_(Range(0, nstripes), p, nstripes);
260         }
261
262         void operator()(const Range& r) const
263         {
264             int channels = dst->size[1], width = dst->size[3], height = dst->size[2];
265             int inp_width = src->size[3], inp_height = src->size[2];
266             size_t total = dst->total();
267             size_t stripeSize = (total + nstripes - 1)/nstripes;
268             size_t stripeStart = r.start*stripeSize;
269             size_t stripeEnd = std::min(r.end*stripeSize, total);
270             int kernel_w = kernel.width, kernel_h = kernel.height;
271             int pad_w = pad.width, pad_h = pad.height;
272             int stride_w = stride.width, stride_h = stride.height;
273             bool compMaxIdx = computeMaxIdx;
274
275 #if CV_SIMD128
276             const int* ofsptr = &ofsbuf[0];
277             v_float32x4 idx00(0.f, (float)stride_w, (float)(stride_w*2), (float)(stride_w*3));
278             v_float32x4 ones = v_setall_f32(1.f);
279             v_float32x4 idx_delta = v_setall_f32((float)(inp_width - kernel_w));
280 #endif
281
282             for( size_t ofs0 = stripeStart; ofs0 < stripeEnd; )
283             {
284                 size_t ofs = ofs0;
285                 int x0 = (int)(ofs % width);
286                 ofs /= width;
287                 int y0 = (int)(ofs % height);
288                 ofs /= height;
289                 int c = (int)(ofs % channels);
290                 int n = (int)(ofs / channels);
291                 int ystart, yend;
292
293                 const float *srcData;
294                 int xstartROI = 0;
295                 float roiRatio = 0;
296                 if (poolingType == ROI)
297                 {
298                     const float *roisData = rois->ptr<float>(n);
299                     int ystartROI = round(roisData[2] * spatialScale);
300                     int yendROI = round(roisData[4] * spatialScale);
301                     int roiHeight = std::max(yendROI - ystartROI + 1, 1);
302                     roiRatio = (float)roiHeight / height;
303
304                     ystart = ystartROI + y0 * roiRatio;
305                     yend = ystartROI + std::ceil((y0 + 1) * roiRatio);
306
307                     xstartROI = round(roisData[1] * spatialScale);
308                     int xendROI = round(roisData[3] * spatialScale);
309                     int roiWidth = std::max(xendROI - xstartROI + 1, 1);
310                     roiRatio = (float)roiWidth / width;
311
312                     CV_Assert(roisData[0] < src->size[0]);
313                     srcData = src->ptr<float>(roisData[0], c);
314                 }
315                 else
316                 {
317                     ystart = y0 * stride_h - pad_h;
318                     yend = min(ystart + kernel_h, inp_height + pad_h);
319                     srcData = src->ptr<float>(n, c);
320                 }
321                 int ydelta = yend - ystart;
322                 ystart = max(ystart, 0);
323                 yend = min(yend, inp_height);
324                 float *dstData = dst->ptr<float>(n, c, y0);
325                 float *dstMaskData = mask->data ? mask->ptr<float>(n, c, y0) : 0;
326
327                 int delta = std::min((int)(stripeEnd - ofs0), width - x0);
328                 ofs0 += delta;
329                 int x1 = x0 + delta;
330
331                 if( poolingType == MAX || poolingType == ROI)
332                     for( ; x0 < x1; x0++ )
333                     {
334                         int xstart, xend;
335                         if (poolingType == ROI)
336                         {
337                             xstart = xstartROI + x0 * roiRatio;
338                             xend = xstartROI + std::ceil((x0 + 1) * roiRatio);
339                         }
340                         else
341                         {
342                             xstart = x0 * stride_w - pad_w;
343                             xend = xstart + kernel_w;
344                         }
345                         xstart = max(xstart, 0);
346                         xend = min(xend, inp_width);
347                         if (xstart >= xend || ystart >= yend)
348                         {
349                             dstData[x0] = 0;
350                             if (compMaxIdx && dstMaskData)
351                                 dstMaskData[x0] = -1;
352                             continue;
353                         }
354 #if CV_SIMD128
355                         if( xstart > 0 && x0 + 7 < x1 && (x0 + 7) * stride_w - pad_w + kernel_w < inp_width )
356                         {
357                             if( compMaxIdx )
358                             {
359                                 v_float32x4 max_val0 = v_setall_f32(-FLT_MAX);
360                                 v_float32x4 max_val1 = max_val0;
361                                 v_float32x4 max_idx0 = v_setall_f32(-1.f);
362                                 v_float32x4 max_idx1 = max_idx0;
363                                 int index0 = ystart * inp_width + xstart;
364                                 v_float32x4 idx0 = idx00 + v_setall_f32((float)index0);
365                                 v_float32x4 idx1 = idx0 + v_setall_f32((float)(stride_w*4));
366
367                                 for (int y = ystart; y < yend; ++y)
368                                 {
369                                     for (int x = xstart; x < xend; ++x, idx0 += ones, idx1 += ones)
370                                     {
371                                         const int index = y * inp_width + x;
372                                         v_float32x4 v0(srcData[index], srcData[index + stride_w],
373                                                        srcData[index + stride_w*2], srcData[index + stride_w*3]);
374                                         v_float32x4 v1(srcData[index + stride_w*4], srcData[index + stride_w*5],
375                                                        srcData[index + stride_w*6], srcData[index + stride_w*7]);
376                                         max_idx0 = v_select(v0 > max_val0, idx0, max_idx0);
377                                         max_idx1 = v_select(v1 > max_val1, idx1, max_idx1);
378                                         max_val0 = v_max(max_val0, v0);
379                                         max_val1 = v_max(max_val1, v1);
380                                     }
381                                     idx0 += idx_delta;
382                                     idx1 += idx_delta;
383                                 }
384                                 v_store(dstData + x0, max_val0);
385                                 v_store(dstData + x0 + 4, max_val1);
386                                 if (dstMaskData)
387                                 {
388                                     v_store(dstMaskData + x0, max_idx0);
389                                     v_store(dstMaskData + x0 + 4, max_idx1);
390                                 }
391                                 x0 += 7;
392                             }
393                             else
394                             {
395                                 v_float32x4 max_val0 = v_setall_f32(-FLT_MAX);
396                                 v_float32x4 max_val1 = max_val0;
397
398                                 if( yend - ystart == kernel_h )
399                                 {
400                                     const float* srcData1 = srcData + ystart*inp_width + xstart;
401                                     if( stride_w == 1 )
402                                         for (int k = 0; k < kernel_w*kernel_h; k++)
403                                         {
404                                             int index = ofsptr[k];
405                                             v_float32x4 v0 = v_load(srcData1 + index);
406                                             v_float32x4 v1 = v_load(srcData1 + index + 4);
407                                             max_val0 = v_max(max_val0, v0);
408                                             max_val1 = v_max(max_val1, v1);
409                                         }
410 #if CV_SSE2
411                                     else if( stride_w == 2 )
412                                         for (int k = 0; k < kernel_w*kernel_h; k++)
413                                         {
414                                             int index = ofsptr[k];
415                                             v_float32x4 v00 = v_load(srcData1 + index), v01 = v_load(srcData1 + index + 4);
416                                             v_float32x4 v0(_mm_shuffle_ps(v00.val, v01.val, _MM_SHUFFLE(2, 0, 2, 0)));
417                                             v_float32x4 v10 = v_load(srcData1 + index + 8), v11 = v_load(srcData1 + index + 12);
418                                             v_float32x4 v1(_mm_shuffle_ps(v10.val, v11.val, _MM_SHUFFLE(2, 0, 2, 0)));
419                                             max_val0 = v_max(max_val0, v0);
420                                             max_val1 = v_max(max_val1, v1);
421                                         }
422 #endif
423                                     else
424                                         for (int k = 0; k < kernel_w*kernel_h; k++)
425                                         {
426                                             int index = ofsptr[k];
427                                             v_float32x4 v0(srcData1[index], srcData1[index + stride_w],
428                                                            srcData1[index + stride_w*2], srcData1[index + stride_w*3]);
429                                             v_float32x4 v1(srcData1[index + stride_w*4], srcData1[index + stride_w*5],
430                                                            srcData1[index + stride_w*6], srcData1[index + stride_w*7]);
431                                             max_val0 = v_max(max_val0, v0);
432                                             max_val1 = v_max(max_val1, v1);
433                                         }
434                                 }
435                                 else
436                                 {
437                                     for (int y = ystart; y < yend; ++y)
438                                     {
439                                         for (int x = xstart; x < xend; ++x)
440                                         {
441                                             const int index = y * inp_width + x;
442                                             v_float32x4 v0(srcData[index], srcData[index + stride_w],
443                                                            srcData[index + stride_w*2], srcData[index + stride_w*3]);
444                                             v_float32x4 v1(srcData[index + stride_w*4], srcData[index + stride_w*5],
445                                                            srcData[index + stride_w*6], srcData[index + stride_w*7]);
446                                             max_val0 = v_max(max_val0, v0);
447                                             max_val1 = v_max(max_val1, v1);
448                                         }
449                                     }
450                                 }
451                                 v_store(dstData + x0, max_val0);
452                                 v_store(dstData + x0 + 4, max_val1);
453                                 x0 += 7;
454                             }
455                         }
456                         else
457 #endif
458                         {
459                             float max_val = -FLT_MAX;
460                             if( compMaxIdx )
461                             {
462                                 int max_index = -1;
463                                 for (int y = ystart; y < yend; ++y)
464                                     for (int x = xstart; x < xend; ++x)
465                                     {
466                                         const int index = y * inp_width + x;
467                                         float val = srcData[index];
468                                         if (val > max_val)
469                                         {
470                                             max_val = val;
471                                             max_index = index;
472                                         }
473                                     }
474
475                                 dstData[x0] = max_val;
476                                 if (dstMaskData)
477                                     dstMaskData[x0] = max_index;
478                             }
479                             else
480                             {
481                                 for (int y = ystart; y < yend; ++y)
482                                     for (int x = xstart; x < xend; ++x)
483                                     {
484                                         const int index = y * inp_width + x;
485                                         float val = srcData[index];
486                                         max_val = std::max(max_val, val);
487                                     }
488
489                                 dstData[x0] = max_val;
490                             }
491                         }
492                     }
493                 else
494                 {
495                     for( ; x0 < x1; x0++ )
496                     {
497                         int xstart = x0 * stride_w - pad_w;
498                         int xend = min(xstart + kernel_w, inp_width + pad_w);
499                         int xdelta = xend - xstart;
500                         xstart = max(xstart, 0);
501                         xend = min(xend, inp_width);
502                         float inv_kernel_area = 1.f/(ydelta*xdelta);
503
504 #if CV_SIMD128
505                         if( xstart > 0 && x0 + 7 < x1 && (x0 + 7) * stride_w - pad_w + kernel_w < inp_width )
506                         {
507                             v_float32x4 sum_val0 = v_setzero_f32(), sum_val1 = v_setzero_f32();
508                             v_float32x4 ikarea = v_setall_f32(inv_kernel_area);
509
510                             for (int y = ystart; y < yend; ++y)
511                             {
512                                 for (int x = xstart; x < xend; ++x)
513                                 {
514                                     const int index = y * inp_width + x;
515                                     v_float32x4 v0(srcData[index], srcData[index + stride_w],
516                                                    srcData[index + stride_w*2], srcData[index + stride_w*3]);
517                                     v_float32x4 v1(srcData[index + stride_w*4], srcData[index + stride_w*5],
518                                                    srcData[index + stride_w*6], srcData[index + stride_w*7]);
519                                     sum_val0 += v0;
520                                     sum_val1 += v1;
521                                 }
522                             }
523                             v_store(dstData + x0, sum_val0*ikarea);
524                             v_store(dstData + x0 + 4, sum_val1*ikarea);
525                             x0 += 7;
526                         }
527                         else
528 #endif
529                         {
530                             float sum_val = 0.f;
531                             for (int y = ystart; y < yend; ++y)
532                                 for (int x = xstart; x < xend; ++x)
533                                 {
534                                     const int index = y * inp_width + x;
535                                     float val = srcData[index];
536                                     sum_val += val;
537                                 }
538
539                             dstData[x0] = sum_val*inv_kernel_area;
540                         }
541                     }
542                 }
543             }
544         }
545     };
546
547     void maxPooling(Mat &src, Mat &dst, Mat &mask)
548     {
549         const int nstripes = getNumThreads();
550         Mat rois;
551         PoolingInvoker::run(src, rois, dst, mask, kernel, stride, pad, type, spatialScale, computeMaxIdx, nstripes);
552     }
553
554     void avePooling(Mat &src, Mat &dst)
555     {
556         const int nstripes = getNumThreads();
557         Mat rois, mask;
558         PoolingInvoker::run(src, rois, dst, mask, kernel, stride, pad, type, spatialScale, computeMaxIdx, nstripes);
559     }
560
561     void roiPooling(const Mat &src, const Mat &rois, Mat &dst)
562     {
563         const int nstripes = getNumThreads();
564         Mat mask;
565         PoolingInvoker::run(src, rois, dst, mask, kernel, stride, pad, type, spatialScale, computeMaxIdx, nstripes);
566     }
567
568     virtual Ptr<BackendNode> initMaxPoolingHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
569     {
570 #ifdef HAVE_HALIDE
571         Halide::Buffer<float> inputBuffer = halideBuffer(inputs[0]);
572         const int inWidth = inputBuffer.width();
573         const int inHeight = inputBuffer.height();
574
575         Halide::Var x("x"), y("y"), c("c"), n("n");
576         Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
577         Halide::RDom r(0, kernel.width, 0, kernel.height);
578         Halide::Expr kx, ky;
579         if (pad.width || pad.height)
580         {
581             kx = clamp(x * stride.width + r.x - pad.width, 0, inWidth - 1);
582             ky = clamp(y * stride.height + r.y - pad.height, 0, inHeight - 1);
583         }
584         else
585         {
586             kx = min(x * stride.width + r.x, inWidth - 1);
587             ky = min(y * stride.height + r.y, inHeight - 1);
588         }
589
590         // Halide::argmax returns tuple (r.x, r.y, max).
591         Halide::Tuple res = argmax(inputBuffer(kx, ky, c, n));
592
593         // Compute offset from argmax in range [0, kernel_size).
594         Halide::Expr max_index;
595         if (pad.width || pad.height)
596         {
597             max_index = clamp(y * stride.height + res[1] - pad.height,
598                               0, inHeight - 1) * inWidth +
599                         clamp(x * stride.width + res[0] - pad.width,
600                               0, inWidth - 1);
601         }
602         else
603         {
604             max_index = min(y * stride.height + res[1], inHeight - 1) * inWidth +
605                         min(x * stride.width + res[0], inWidth - 1);
606         }
607         top(x, y, c, n) = { res[2], Halide::cast<float>(max_index) };
608         return Ptr<BackendNode>(new HalideBackendNode(top));
609 #endif  // HAVE_HALIDE
610         return Ptr<BackendNode>();
611     }
612
613     virtual Ptr<BackendNode> initAvePoolingHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
614     {
615 #ifdef HAVE_HALIDE
616         Halide::Buffer<float> inputBuffer = halideBuffer(inputs[0]);
617
618         const int inW = inputBuffer.width(), inH = inputBuffer.height();
619         if ((inW - kernel.width) % stride.width || (inH - kernel.height) % stride.height)
620         {
621             CV_Error(cv::Error::StsNotImplemented,
622                      "Halide backend for average pooling with partial "
623                      "kernels is not implemented");
624         }
625
626         const float norm = 1.0f / (kernel.width * kernel.height);
627
628         Halide::Var x("x"), y("y"), c("c"), n("n");
629         Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
630         Halide::RDom r(0, kernel.width, 0, kernel.height);
631         top(x, y, c, n) = sum(
632             inputBuffer(x * stride.width + r.x,
633                         y * stride.height + r.y, c, n)) * norm;
634         return Ptr<BackendNode>(new HalideBackendNode(top));
635 #endif  // HAVE_HALIDE
636         return Ptr<BackendNode>();
637     }
638
639     virtual void applyHalideScheduler(Ptr<BackendNode>& node,
640                                       const std::vector<Mat*> &inputs,
641                                       const std::vector<Mat> &outputs,
642                                       int targetId) const
643     {
644 #ifdef  HAVE_HALIDE
645         if (targetId != DNN_TARGET_CPU)
646         {
647             Layer::applyHalideScheduler(node, inputs, outputs, targetId);
648             return;
649         }
650         Halide::Var x("x"), y("y"), c("c"), n("n"), tile("tile"),
651                     xi("xi"), yi("yi"), ci("ci"), xo("xo"), yo("yo"), co("co");
652         Halide::Func& top = node.dynamicCast<HalideBackendNode>()->funcs.back();
653
654         int outW, outH, outC, outN;
655         getCanonicalSize(outputs[0].size, &outW, &outH, &outC, &outN);
656
657         if (outW < 8 || outH < 8)
658         {
659             if (outC > 8)
660                 top.split(c, co, ci, 8)
661                    .fuse(x, y, tile).fuse(co, tile, tile).fuse(n, tile, tile)
662                    .parallel(tile)
663                    .vectorize(ci);
664             else
665             {
666                 top.fuse(y, c, tile).fuse(n, tile, tile)
667                    .parallel(tile);
668                 if (outW > 1)
669                     top.vectorize(x);
670             }
671         }
672         else
673         {
674             if (outC > 8)
675                 top.split(x, xo, xi, 8).split(y, yo, yi, 8).split(c, co, ci, 8)
676                    .fuse(xo, yo, tile).fuse(co, tile, tile).fuse(n, tile, tile)
677                    .parallel(tile)
678                    .vectorize(xi);
679             else
680                 top.split(x, xo, xi, 8).split(y, yo, yi, 8)
681                    .fuse(xo, yo, tile).fuse(c, tile, tile).fuse(n, tile, tile)
682                    .parallel(tile)
683                    .vectorize(xi);
684         }
685 #endif  // HAVE_HALIDE
686     }
687
688     bool getMemoryShapes(const std::vector<MatShape> &inputs,
689                          const int requiredOutputs,
690                          std::vector<MatShape> &outputs,
691                          std::vector<MatShape> &internals) const
692     {
693         CV_Assert(inputs.size() != 0);
694         Size in(inputs[0][3], inputs[0][2]), out;
695
696         if (globalPooling)
697         {
698             out.height = 1;
699             out.width = 1;
700         }
701         else if (type == PoolingLayer::ROI)
702         {
703             out.height = pooledSize.height;
704             out.width = pooledSize.width;
705         }
706         else if (padMode.empty())
707         {
708             float height = (float)(in.height + 2 * pad.height - kernel.height) / stride.height;
709             float width = (float)(in.width + 2 * pad.width - kernel.width) / stride.width;
710             out.height = 1 + (ceilMode ? ceil(height) : floor(height));
711             out.width = 1 + (ceilMode ? ceil(width) : floor(width));
712
713             if (pad.height || pad.width)
714             {
715                 // If we have padding, ensure that the last pooling starts strictly
716                 // inside the image (instead of at the padding); otherwise clip the last.
717                 if ((out.height - 1) * stride.height >= in.height + pad.height)
718                     --out.height;
719                 if ((out.width - 1) * stride.width >= in.width + pad.width)
720                     --out.width;
721                 CV_Assert((out.height - 1) * stride.height < in.height + pad.height);
722                 CV_Assert((out.width - 1) * stride.width < in.width + pad.width);
723             }
724         }
725         else
726         {
727             getConvPoolOutParams(in, kernel, stride, padMode, Size(1, 1), out);
728         }
729
730         int dims[] = {inputs[0][0], inputs[0][1], out.height, out.width};
731         if (type == ROI)
732         {
733             CV_Assert(inputs.size() == 2);
734             dims[0] = inputs[1][0];  // Number of proposals;
735         }
736         outputs.assign(type == MAX ? 2 : 1, shape(dims));
737         return false;
738     }
739
740     virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
741                            const std::vector<MatShape> &outputs) const
742     {
743         (void)inputs; // suppress unused variable warning
744         long flops = 0;
745
746         for(int i = 0; i < outputs.size(); i++)
747         {
748             if (type == MAX)
749             {
750                 if (i%2 == 0)
751                     flops += total(outputs[i])*kernel.area();
752             }
753             else
754             {
755                 flops += total(outputs[i])*(kernel.area() + 1);
756             }
757         }
758         return flops;
759     }
760 };
761
762 Ptr<PoolingLayer> PoolingLayer::create(const LayerParams& params)
763 {
764     return Ptr<PoolingLayer>(new PoolingLayerImpl(params));
765 }
766
767 }
768 }