// modules/dnn/src/layers/pooling_layer.cpp
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
// Copyright (C) 2017, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "../precomp.hpp"
#include "layers_common.hpp"
#include "opencv2/core/hal/intrin.hpp"
#include "op_halide.hpp"
#include <float.h>
#include <algorithm>
using std::max;
using std::min;

namespace cv
{
namespace dnn
{

//TODO: add ceil_mode param
class PoolingLayerImpl : public PoolingLayer
{
public:
    PoolingLayerImpl(const LayerParams& params)
    {
        type = PoolingLayer::MAX;
        computeMaxIdx = true;

        if (params.has("pool"))
        {
            String pool = params.get<String>("pool").toLowerCase();
            if (pool == "max")
                type = PoolingLayer::MAX;
            else if (pool == "ave")
                type = PoolingLayer::AVE;
            else if (pool == "stochastic")
                type = PoolingLayer::STOCHASTIC;
            else
                CV_Error(Error::StsBadArg, "Unknown pooling type \"" + pool + "\"");
        }

        getPoolingKernelParams(params, kernel.height, kernel.width, globalPooling,
                               pad.height, pad.width, stride.height, stride.width, padMode);
        setParamsFrom(params);
    }

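    // finalize() runs once the input/output shapes are known: for global pooling the
    // kernel is stretched to cover the whole input plane, and the effective paddings
    // are (re)computed from the chosen padMode.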
    void finalize(const std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
    {
        CV_Assert(inputs.size() == 1);

        cv::Size inp(inputs[0]->size[3], inputs[0]->size[2]),
                out(outputs[0].size[3], outputs[0].size[2]);

        if(globalPooling)
        {
            kernel = inp;
        }

        getConvPoolPaddings(inp, out, kernel, stride, padMode, pad);
    }

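    // The Halide backend currently covers max pooling and unpadded average pooling only;
    // everything else falls back to the default (CPU) implementation.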
    virtual bool supportBackend(int backendId)
    {
        return backendId == DNN_BACKEND_DEFAULT ||
               (backendId == DNN_BACKEND_HALIDE && haveHalide() &&
                (type == PoolingLayer::MAX ||
                 (type == PoolingLayer::AVE && !pad.width && !pad.height)));
    }

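    // For MAX pooling each input produces two outputs: the pooled values at
    // outputs[2*i] and a mask of flat argmax indices at outputs[2*i + 1].
    // AVE pooling writes a single output per input.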
    void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
    {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        for (size_t ii = 0; ii < inputs.size(); ii++)
        {
            switch (type)
            {
                case MAX:
                    maxPooling(*inputs[ii], outputs[2 * ii], outputs[2 * ii + 1]);
                    break;
                case AVE:
                    avePooling(*inputs[ii], outputs[ii]);
                    break;
                default:
                    CV_Error(Error::StsNotImplemented, "Not implemented");
                    break;
            }
        }
    }

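    // Dispatch to the pooling-type specific Halide graph builders below.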
    virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
    {
        if (type == PoolingLayer::MAX)
            return initMaxPoolingHalide(inputs);
        else if (type == PoolingLayer::AVE)
            return initAvePoolingHalide(inputs);
        else
            return Ptr<BackendNode>();
    }

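    // PoolingInvoker performs the actual CPU pooling: the flattened output tensor is
    // split into `nstripes` contiguous stripes that are processed in parallel via
    // parallel_for_, each stripe handled independently in operator().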
    class PoolingInvoker : public ParallelLoopBody
    {
    public:
        const Mat* src;
        Mat *dst, *mask;
        Size kernel, stride, pad;
        int nstripes;
        bool computeMaxIdx;
        std::vector<int> ofsbuf;
        int poolingType;

        PoolingInvoker() {}

        static void run(const Mat& src, Mat& dst, Mat& mask, Size kernel,
                        Size stride, Size pad, int poolingType,
                        bool computeMaxIdx, int nstripes)
        {
            CV_Assert(src.isContinuous() && dst.isContinuous() &&
                      src.type() == CV_32F && src.type() == dst.type() &&
                      src.dims == 4 && dst.dims == 4 &&
                      src.size[0] == dst.size[0] && src.size[1] == dst.size[1] &&
                      (mask.empty() || (mask.type() == src.type() && mask.size == dst.size)));

            PoolingInvoker p;

            p.src = &src;
            p.dst = &dst;
            p.mask = &mask;
            p.kernel = kernel;
            p.stride = stride;
            p.pad = pad;
            p.nstripes = nstripes;
            p.computeMaxIdx = computeMaxIdx;
            p.poolingType = poolingType;

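            // When no argmax mask is required, precompute the flat offset of every
            // kernel element relative to the top-left corner of a pooling window;
            // the SIMD max-pooling path below reuses these offsets.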
            if( !computeMaxIdx )
            {
                p.ofsbuf.resize(kernel.width*kernel.height);
                for( int i = 0; i < kernel.height; i++ )
                    for( int j = 0; j < kernel.width; j++ )
                        p.ofsbuf[i*kernel.width + j] = src.size[3]*i + j;
            }

            parallel_for_(Range(0, nstripes), p, nstripes);
        }

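        // Process one stripe of the flattened NCHW output. Each flat offset is decoded
        // back into (n, c, y, x); runs of consecutive output columns within a row are
        // then pooled, using an 8-wide SIMD fast path where possible.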
        void operator()(const Range& r) const
        {
            int channels = dst->size[1], width = dst->size[3], height = dst->size[2];
            int inp_width = src->size[3], inp_height = src->size[2];
            size_t total = dst->total();
            size_t stripeSize = (total + nstripes - 1)/nstripes;
            size_t stripeStart = r.start*stripeSize;
            size_t stripeEnd = std::min(r.end*stripeSize, total);
            int kernel_w = kernel.width, kernel_h = kernel.height;
            int pad_w = pad.width, pad_h = pad.height;
            int stride_w = stride.width, stride_h = stride.height;
            bool compMaxIdx = computeMaxIdx;

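            // SIMD helpers for the argmax path: idx00 holds the flat source indices of
            // four horizontally adjacent pooling windows, `ones` advances them by one
            // column, and idx_delta jumps from the end of one kernel row to the next.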
#if CV_SIMD128
            const int* ofsptr = &ofsbuf[0];
            v_float32x4 idx00(0.f, (float)stride_w, (float)(stride_w*2), (float)(stride_w*3));
            v_float32x4 ones = v_setall_f32(1.f);
            v_float32x4 idx_delta = v_setall_f32((float)(inp_width - kernel_w));
#endif

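            // Main loop over the stripe: decode each flat output offset into
            // (n, c, y0, x0), then pool up to `width - x0` consecutive columns of that row.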
            for( size_t ofs0 = stripeStart; ofs0 < stripeEnd; )
            {
                size_t ofs = ofs0;
                int x0 = (int)(ofs % width);
                ofs /= width;
                int y0 = (int)(ofs % height);
                ofs /= height;
                int c = (int)(ofs % channels);
                int n = (int)(ofs / channels);
                int ystart = y0 * stride_h - pad_h;
                int yend = min(ystart + kernel_h, inp_height + pad_h);
                int ydelta = yend - ystart;
                ystart = max(ystart, 0);
                yend = min(yend, inp_height);
                const float *srcData = src->ptr<float>(n, c);
                float *dstData = dst->ptr<float>(n, c, y0);
                float *dstMaskData = mask->data ? mask->ptr<float>(n, c, y0) : 0;

                int delta = std::min((int)(stripeEnd - ofs0), width - x0);
                ofs0 += delta;
                int x1 = x0 + delta;

                if( poolingType == PoolingLayer::MAX )
                    for( ; x0 < x1; x0++ )
                    {
                        int xstart = x0 * stride_w - pad_w;
                        int xend = min(xstart + kernel_w, inp_width);
                        xstart = max(xstart, 0);

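                        // SIMD fast path: when the next 8 output columns and their
                        // pooling windows lie fully inside the input row, compute them
                        // with two v_float32x4 accumulators; otherwise use the scalar loop.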
#if CV_SIMD128
                        if( xstart > 0 && x0 + 7 < x1 && (x0 + 7) * stride_w - pad_w + kernel_w < inp_width )
                        {
                            if( compMaxIdx )
                            {
                                v_float32x4 max_val0 = v_setall_f32(-FLT_MAX);
                                v_float32x4 max_val1 = max_val0;
                                v_float32x4 max_idx0 = v_setall_f32(-1.f);
                                v_float32x4 max_idx1 = max_idx0;
                                int index0 = ystart * inp_width + xstart;
                                v_float32x4 idx0 = idx00 + v_setall_f32((float)index0);
                                v_float32x4 idx1 = idx0 + v_setall_f32((float)(stride_w*4));

                                for (int y = ystart; y < yend; ++y)
                                {
                                    for (int x = xstart; x < xend; ++x, idx0 += ones, idx1 += ones)
                                    {
                                        const int index = y * inp_width + x;
                                        v_float32x4 v0(srcData[index], srcData[index + stride_w],
                                                       srcData[index + stride_w*2], srcData[index + stride_w*3]);
                                        v_float32x4 v1(srcData[index + stride_w*4], srcData[index + stride_w*5],
                                                       srcData[index + stride_w*6], srcData[index + stride_w*7]);
                                        max_idx0 = v_select(v0 > max_val0, idx0, max_idx0);
                                        max_idx1 = v_select(v1 > max_val1, idx1, max_idx1);
                                        max_val0 = v_max(max_val0, v0);
                                        max_val1 = v_max(max_val1, v1);
                                    }
                                    idx0 += idx_delta;
                                    idx1 += idx_delta;
                                }
                                v_store(dstData + x0, max_val0);
                                v_store(dstData + x0 + 4, max_val1);
                                v_store(dstMaskData + x0, max_idx0);
                                v_store(dstMaskData + x0 + 4, max_idx1);
                                x0 += 7;
                            }
                            else
                            {
                                v_float32x4 max_val0 = v_setall_f32(-FLT_MAX);
                                v_float32x4 max_val1 = max_val0;

                                if( yend - ystart == kernel_h )
                                {
                                    const float* srcData1 = srcData + ystart*inp_width + xstart;
                                    if( stride_w == 1 )
                                        for (int k = 0; k < kernel_w*kernel_h; k++)
                                        {
                                            int index = ofsptr[k];
                                            v_float32x4 v0 = v_load(srcData1 + index);
                                            v_float32x4 v1 = v_load(srcData1 + index + 4);
                                            max_val0 = v_max(max_val0, v0);
                                            max_val1 = v_max(max_val1, v1);
                                        }
#if CV_SSE2
                                    else if( stride_w == 2 )
                                        for (int k = 0; k < kernel_w*kernel_h; k++)
                                        {
                                            int index = ofsptr[k];
                                            v_float32x4 v00 = v_load(srcData1 + index), v01 = v_load(srcData1 + index + 4);
                                            v_float32x4 v0(_mm_shuffle_ps(v00.val, v01.val, _MM_SHUFFLE(2, 0, 2, 0)));
                                            v_float32x4 v10 = v_load(srcData1 + index + 8), v11 = v_load(srcData1 + index + 12);
                                            v_float32x4 v1(_mm_shuffle_ps(v10.val, v11.val, _MM_SHUFFLE(2, 0, 2, 0)));
                                            max_val0 = v_max(max_val0, v0);
                                            max_val1 = v_max(max_val1, v1);
                                        }
#endif
                                    else
                                        for (int k = 0; k < kernel_w*kernel_h; k++)
                                        {
                                            int index = ofsptr[k];
                                            v_float32x4 v0(srcData1[index], srcData1[index + stride_w],
                                                           srcData1[index + stride_w*2], srcData1[index + stride_w*3]);
                                            v_float32x4 v1(srcData1[index + stride_w*4], srcData1[index + stride_w*5],
                                                           srcData1[index + stride_w*6], srcData1[index + stride_w*7]);
                                            max_val0 = v_max(max_val0, v0);
                                            max_val1 = v_max(max_val1, v1);
                                        }
                                }
                                else
                                {
                                    for (int y = ystart; y < yend; ++y)
                                    {
                                        for (int x = xstart; x < xend; ++x)
                                        {
                                            const int index = y * inp_width + x;
                                            v_float32x4 v0(srcData[index], srcData[index + stride_w],
                                                           srcData[index + stride_w*2], srcData[index + stride_w*3]);
                                            v_float32x4 v1(srcData[index + stride_w*4], srcData[index + stride_w*5],
                                                           srcData[index + stride_w*6], srcData[index + stride_w*7]);
                                            max_val0 = v_max(max_val0, v0);
                                            max_val1 = v_max(max_val1, v1);
                                        }
                                    }
                                }
                                v_store(dstData + x0, max_val0);
                                v_store(dstData + x0 + 4, max_val1);
                                x0 += 7;
                            }
                        }
                        else
#endif
                        {
                            float max_val = -FLT_MAX;
                            if( compMaxIdx )
                            {
                                int max_index = -1;
                                for (int y = ystart; y < yend; ++y)
                                    for (int x = xstart; x < xend; ++x)
                                    {
                                        const int index = y * inp_width + x;
                                        float val = srcData[index];
                                        if (val > max_val)
                                        {
                                            max_val = val;
                                            max_index = index;
                                        }
                                    }

                                dstData[x0] = max_val;
                                dstMaskData[x0] = max_index;
                            }
                            else
                            {
                                for (int y = ystart; y < yend; ++y)
                                    for (int x = xstart; x < xend; ++x)
                                    {
                                        const int index = y * inp_width + x;
                                        float val = srcData[index];
                                        max_val = std::max(max_val, val);
                                    }

                                dstData[x0] = max_val;
                            }
                        }
                    }
                else
                {
                    for( ; x0 < x1; x0++ )
                    {
                        int xstart = x0 * stride_w - pad_w;
                        int xend = min(xstart + kernel_w, inp_width + pad_w);
                        int xdelta = xend - xstart;
                        xstart = max(xstart, 0);
                        xend = min(xend, inp_width);
                        float inv_kernel_area = 1.f/(ydelta*xdelta);

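                        // Average pooling divides by the unclipped window area
                        // (ydelta*xdelta), so padded positions count towards the
                        // denominator. The SIMD branch again covers 8 output columns.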
#if CV_SIMD128
                        if( xstart > 0 && x0 + 7 < x1 && (x0 + 7) * stride_w - pad_w + kernel_w < inp_width )
                        {
                            v_float32x4 sum_val0 = v_setzero_f32(), sum_val1 = v_setzero_f32();
                            v_float32x4 ikarea = v_setall_f32(inv_kernel_area);

                            for (int y = ystart; y < yend; ++y)
                            {
                                for (int x = xstart; x < xend; ++x)
                                {
                                    const int index = y * inp_width + x;
                                    v_float32x4 v0(srcData[index], srcData[index + stride_w],
                                                   srcData[index + stride_w*2], srcData[index + stride_w*3]);
                                    v_float32x4 v1(srcData[index + stride_w*4], srcData[index + stride_w*5],
                                                   srcData[index + stride_w*6], srcData[index + stride_w*7]);
                                    sum_val0 += v0;
                                    sum_val1 += v1;
                                }
                            }
                            v_store(dstData + x0, sum_val0*ikarea);
                            v_store(dstData + x0 + 4, sum_val1*ikarea);
                            x0 += 7;
                        }
                        else
#endif
                        {
                            float sum_val = 0.f;
                            for (int y = ystart; y < yend; ++y)
                                for (int x = xstart; x < xend; ++x)
                                {
                                    const int index = y * inp_width + x;
                                    float val = srcData[index];
                                    sum_val += val;
                                }

                            dstData[x0] = sum_val*inv_kernel_area;
                        }
                    }
                }
            }
        }
    };

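    // Thin wrapper that runs the invoker with as many stripes as there are threads.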
    void maxPooling(Mat &src, Mat &dst, Mat &mask)
    {
        const int nstripes = getNumThreads();
        PoolingInvoker::run(src, dst, mask, kernel, stride, pad, type, computeMaxIdx, nstripes);
    }

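    // Average pooling never produces a mask, so an empty Mat is passed to the invoker.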
    void avePooling(Mat &src, Mat &dst)
    {
        const int nstripes = getNumThreads();
        Mat mask;
        PoolingInvoker::run(src, dst, mask, kernel, stride, pad, type, computeMaxIdx, nstripes);
    }

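    // Halide max pooling: argmax over the kernel window yields both the maximum value
    // and its (r.x, r.y) position, which is converted back to a flat input-plane index
    // so the second output matches the mask produced by the CPU path.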
    virtual Ptr<BackendNode> initMaxPoolingHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
    {
#ifdef HAVE_HALIDE
        Halide::Buffer<float> inputBuffer = halideBuffer(inputs[0]);
        const int inWidth = inputBuffer.width();
        const int inHeight = inputBuffer.height();

        Halide::Var x("x"), y("y"), c("c"), n("n");
        Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
        Halide::RDom r(0, kernel.width, 0, kernel.height);
        Halide::Expr kx, ky;
        if (pad.width || pad.height)
        {
            kx = clamp(x * stride.width + r.x - pad.width, 0, inWidth - 1);
            ky = clamp(y * stride.height + r.y - pad.height, 0, inHeight - 1);
        }
        else
        {
            kx = min(x * stride.width + r.x, inWidth - 1);
            ky = min(y * stride.height + r.y, inHeight - 1);
        }

        // Halide::argmax returns the tuple (r.x, r.y, max).
        Halide::Tuple res = argmax(inputBuffer(kx, ky, c, n));

        // Convert the in-kernel argmax position into a flat index of the input plane.
        Halide::Expr max_index;
        if (pad.width || pad.height)
        {
            max_index = clamp(y * stride.height + res[1] - pad.height,
                              0, inHeight - 1) * inWidth +
                        clamp(x * stride.width + res[0] - pad.width,
                              0, inWidth - 1);
        }
        else
        {
            max_index = min(y * stride.height + res[1], inHeight - 1) * inWidth +
                        min(x * stride.width + res[0], inWidth - 1);
        }
        top(x, y, c, n) = { res[2], Halide::cast<float>(max_index) };
        return Ptr<BackendNode>(new HalideBackendNode(top));
#endif  // HAVE_HALIDE
        return Ptr<BackendNode>();
    }

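    // Halide average pooling: only full kernel placements are supported (partial
    // windows raise StsNotImplemented), so the window sum is simply scaled by a
    // constant 1 / (kernel.width * kernel.height).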
    virtual Ptr<BackendNode> initAvePoolingHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
    {
#ifdef HAVE_HALIDE
        Halide::Buffer<float> inputBuffer = halideBuffer(inputs[0]);

        const int inW = inputBuffer.width(), inH = inputBuffer.height();
        if ((inW - kernel.width) % stride.width || (inH - kernel.height) % stride.height)
        {
            CV_Error(cv::Error::StsNotImplemented,
                     "Halide backend for average pooling with partial "
                     "kernels is not implemented");
        }

        const float norm = 1.0f / (kernel.width * kernel.height);

        Halide::Var x("x"), y("y"), c("c"), n("n");
        Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
        Halide::RDom r(0, kernel.width, 0, kernel.height);
        top(x, y, c, n) = sum(
            inputBuffer(x * stride.width + r.x,
                        y * stride.height + r.y, c, n)) * norm;
        return Ptr<BackendNode>(new HalideBackendNode(top));
#endif  // HAVE_HALIDE
        return Ptr<BackendNode>();
    }

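    // CPU-specific Halide schedule: small spatial outputs are parallelized over fused
    // channel/batch loops, larger ones are additionally tiled 8x8 over x/y and
    // vectorized. Non-CPU targets defer to the generic Layer scheduler.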
    virtual void applyHalideScheduler(Ptr<BackendNode>& node,
                                      const std::vector<Mat*> &inputs,
                                      const std::vector<Mat> &outputs,
                                      int targetId) const
    {
#ifdef HAVE_HALIDE
        if (targetId != DNN_TARGET_CPU)
        {
            Layer::applyHalideScheduler(node, inputs, outputs, targetId);
            return;
        }
        Halide::Var x("x"), y("y"), c("c"), n("n"), tile("tile"),
                    xi("xi"), yi("yi"), ci("ci"), xo("xo"), yo("yo"), co("co");
        Halide::Func& top = node.dynamicCast<HalideBackendNode>()->funcs.back();

        int outW, outH, outC, outN;
        getCanonicalSize(outputs[0].size, &outW, &outH, &outC, &outN);

        if (outW < 8 || outH < 8)
        {
            if (outC > 8)
                top.split(c, co, ci, 8)
                   .fuse(x, y, tile).fuse(co, tile, tile).fuse(n, tile, tile)
                   .parallel(tile)
                   .vectorize(ci);
            else
            {
                top.fuse(y, c, tile).fuse(n, tile, tile)
                   .parallel(tile);
                if (outW > 1)
                    top.vectorize(x);
            }
        }
        else
        {
            if (outC > 8)
                top.split(x, xo, xi, 8).split(y, yo, yi, 8).split(c, co, ci, 8)
                   .fuse(xo, yo, tile).fuse(co, tile, tile).fuse(n, tile, tile)
                   .parallel(tile)
                   .vectorize(xi);
            else
                top.split(x, xo, xi, 8).split(y, yo, yi, 8)
                   .fuse(xo, yo, tile).fuse(c, tile, tile).fuse(n, tile, tile)
                   .parallel(tile)
                   .vectorize(xi);
        }
#endif  // HAVE_HALIDE
    }

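    // Output spatial size: 1x1 for global pooling, Caffe-style ceil rounding when no
    // padMode is given, or the padMode-specific formula otherwise. MAX pooling reserves
    // a second, equally-sized output per input for the argmax mask.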
    bool getMemoryShapes(const std::vector<MatShape> &inputs,
                         const int requiredOutputs,
                         std::vector<MatShape> &outputs,
                         std::vector<MatShape> &internals) const
    {
        CV_Assert(inputs.size() != 0);
        Size in(inputs[0][3], inputs[0][2]), out;

        if (globalPooling)
        {
            out.height = 1;
            out.width = 1;
        }
        else if (padMode.empty())
        {
            // Caffe's rounding scheme: the output size is computed with ceil(), so the
            // last pooling window may extend past the (padded) input.
            out.height = static_cast<int>(ceil(static_cast<float>(in.height + 2 * pad.height -
                                                                  kernel.height) / stride.height)) + 1;
            out.width = static_cast<int>(ceil(static_cast<float>(in.width + 2 * pad.width -
                                                                 kernel.width) / stride.width)) + 1;
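            // For example, in = 224, kernel = 3, stride = 2, pad = 0 gives
            // out = ceil((224 - 3) / 2) + 1 = ceil(110.5) + 1 = 112.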

            if (pad.height || pad.width)
            {
                // If we have padding, ensure that the last pooling starts strictly
                // inside the image (instead of at the padding); otherwise clip the last.
                if ((out.height - 1) * stride.height >= in.height + pad.height)
                    --out.height;
                if ((out.width - 1) * stride.width >= in.width + pad.width)
                    --out.width;
                CV_Assert((out.height - 1) * stride.height < in.height + pad.height);
                CV_Assert((out.width - 1) * stride.width < in.width + pad.width);
            }
        }
        else
        {
            getConvPoolOutParams(in, kernel, stride,
                                 padMode, out);
        }

        outputs.resize(type == MAX ? 2 * inputs.size() : inputs.size());
        for (size_t i = 0; i < inputs.size(); i++)
        {
            size_t index = type == MAX ? 2*i : i;
            int dims[] = {inputs[i][0], inputs[i][1], out.height, out.width};
            outputs[index] = shape(dims);

            if (type == MAX)
                outputs[index + 1] = shape(dims);
        }

        return false;
    }

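    // Rough FLOP estimate: one comparison per kernel element for MAX (counted only for
    // the value outputs, not the masks), and kernel.area() additions plus one multiply
    // per output element for AVE.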
    virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
                           const std::vector<MatShape> &outputs) const
    {
        (void)inputs; // suppress unused variable warning
        long flops = 0;

        for(size_t i = 0; i < outputs.size(); i++)
        {
            if (type == MAX)
            {
                if (i%2 == 0)
                    flops += total(outputs[i])*kernel.area();
            }
            else
            {
                flops += total(outputs[i])*(kernel.area() + 1);
            }
        }
        return flops;
    }
};

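// A minimal usage sketch (illustrative only; the parameter keys shown follow the
// Caffe-style names parsed by getPoolingKernelParams and are assumptions, not part
// of this file):
//
//     LayerParams lp;
//     lp.set("pool", "max");        // "max", "ave" or "stochastic"
//     lp.set("kernel_size", 2);
//     lp.set("stride", 2);
//     Ptr<PoolingLayer> layer = PoolingLayer::create(lp);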
Ptr<PoolingLayer> PoolingLayer::create(const LayerParams& params)
{
    return Ptr<PoolingLayer>(new PoolingLayerImpl(params));
}

}  // namespace dnn
}  // namespace cv