Merge pull request #9041 from terfendail:filter_avx
[platform/upstream/opencv.git] / modules / dnn / src / layers / pooling_layer.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                           License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2013, OpenCV Foundation, all rights reserved.
14 // Copyright (C) 2017, Intel Corporation, all rights reserved.
15 // Third party copyrights are property of their respective owners.
16 //
17 // Redistribution and use in source and binary forms, with or without modification,
18 // are permitted provided that the following conditions are met:
19 //
20 //   * Redistribution's of source code must retain the above copyright notice,
21 //     this list of conditions and the following disclaimer.
22 //
23 //   * Redistribution's in binary form must reproduce the above copyright notice,
24 //     this list of conditions and the following disclaimer in the documentation
25 //     and/or other materials provided with the distribution.
26 //
27 //   * The name of the copyright holders may not be used to endorse or promote products
28 //     derived from this software without specific prior written permission.
29 //
30 // This software is provided by the copyright holders and contributors "as is" and
31 // any express or implied warranties, including, but not limited to, the implied
32 // warranties of merchantability and fitness for a particular purpose are disclaimed.
33 // In no event shall the Intel Corporation or contributors be liable for any direct,
34 // indirect, incidental, special, exemplary, or consequential damages
35 // (including, but not limited to, procurement of substitute goods or services;
36 // loss of use, data, or profits; or business interruption) however caused
37 // and on any theory of liability, whether in contract, strict liability,
38 // or tort (including negligence or otherwise) arising in any way out of
39 // the use of this software, even if advised of the possibility of such damage.
40 //
41 //M*/
42
43 #include "../precomp.hpp"
44 #include "layers_common.hpp"
45 #include "opencv2/core/hal/intrin.hpp"
46 #include "op_halide.hpp"
47 #include <float.h>
48 #include <algorithm>
49 using std::max;
50 using std::min;
51
52 namespace cv
53 {
54 namespace dnn
55 {
56
57 //TODO: add ceil_mode param
58 class PoolingLayerImpl : public PoolingLayer
59 {
60 public:
61     PoolingLayerImpl(const LayerParams& params)
62     {
63         type = PoolingLayer::MAX;
64         computeMaxIdx = true;
65
66         if (params.has("pool"))
67         {
68             String pool = params.get<String>("pool").toLowerCase();
69             if (pool == "max")
70                 type = PoolingLayer::MAX;
71             else if (pool == "ave")
72                 type = PoolingLayer::AVE;
73             else if (pool == "stochastic")
74                 type = PoolingLayer::STOCHASTIC;
75             else
76                 CV_Error(Error::StsBadArg, "Unknown pooling type \"" + pool + "\"");
77         }
78
79         getPoolingKernelParams(params, kernel.height, kernel.width, globalPooling,
80                                pad.height, pad.width, stride.height, stride.width, padMode);
81         setParamsFrom(params);
82     }
83
84     void finalize(const std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
85     {
86         CV_Assert(inputs.size() == 1);
87
88         cv::Size inp(inputs[0]->size[3], inputs[0]->size[2]),
89                 out(outputs[0].size[3], outputs[0].size[2]);
90
91         if(globalPooling)
92         {
93             kernel = inp;
94         }
95
96         getConvPoolPaddings(inp, out, kernel, stride, padMode, pad);
97     }
98
99     virtual bool supportBackend(int backendId)
100     {
101         return backendId == DNN_BACKEND_DEFAULT ||
102                backendId == DNN_BACKEND_HALIDE && haveHalide() &&
103                (type == PoolingLayer::MAX ||
104                 type == PoolingLayer::AVE && !pad.width && !pad.height);
105     }
106
107     void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
108     {
109         CV_TRACE_FUNCTION();
110         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
111
112         for (size_t ii = 0; ii < inputs.size(); ii++)
113         {
114             switch (type)
115             {
116                 case MAX:
117                     maxPooling(*inputs[ii], outputs[2 * ii], outputs[2 * ii + 1]);
118                     break;
119                 case AVE:
120                     avePooling(*inputs[ii], outputs[ii]);
121                     break;
122                 default:
123                     CV_Error(Error::StsNotImplemented, "Not implemented");
124                     break;
125             }
126         }
127     }
128
129     virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
130     {
131         if (type == PoolingLayer::MAX)
132             return initMaxPoolingHalide(inputs);
133         else if (type == PoolingLayer::AVE)
134             return initAvePoolingHalide(inputs);
135         else
136             return Ptr<BackendNode>();
137     }
138
139     class PoolingInvoker : public ParallelLoopBody
140     {
141     public:
142         const Mat* src;
143         Mat *dst, *mask;
144         Size kernel, stride, pad;
145         int nstripes;
146         bool computeMaxIdx;
147         std::vector<int> ofsbuf;
148         int poolingType;
149
150         PoolingInvoker() : src(0), dst(0), mask(0), nstripes(0), computeMaxIdx(0), poolingType(PoolingLayer::MAX) {}
151
152         static void run(const Mat& src, Mat& dst, Mat& mask, Size kernel,
153                         Size stride, Size pad, int poolingType,
154                         bool computeMaxIdx, int nstripes)
155         {
156             CV_Assert(src.isContinuous() && dst.isContinuous() &&
157                       src.type() == CV_32F && src.type() == dst.type() &&
158                       src.dims == 4 && dst.dims == 4 &&
159                       src.size[0] == dst.size[0] && src.size[1] == dst.size[1] &&
160                       (mask.empty() || (mask.type() == src.type() && mask.size == dst.size)));
161
162             PoolingInvoker p;
163
164             p.src = &src;
165             p.dst = &dst;
166             p.mask = &mask;
167             p.kernel = kernel;
168             p.stride = stride;
169             p.pad = pad;
170             p.nstripes = nstripes;
171             p.computeMaxIdx = computeMaxIdx;
172             p.poolingType = poolingType;
173
174             if( !computeMaxIdx )
175             {
176                 p.ofsbuf.resize(kernel.width*kernel.height);
177                 for( int i = 0; i < kernel.height; i++ )
178                     for( int j = 0; j < kernel.width; j++ )
179                         p.ofsbuf[i*kernel.width + j] = src.size[3]*i + j;
180             }
181
182             parallel_for_(Range(0, nstripes), p, nstripes);
183         }
184
185         void operator()(const Range& r) const
186         {
187             int channels = dst->size[1], width = dst->size[3], height = dst->size[2];
188             int inp_width = src->size[3], inp_height = src->size[2];
189             size_t total = dst->total();
190             size_t stripeSize = (total + nstripes - 1)/nstripes;
191             size_t stripeStart = r.start*stripeSize;
192             size_t stripeEnd = std::min(r.end*stripeSize, total);
193             int kernel_w = kernel.width, kernel_h = kernel.height;
194             int pad_w = pad.width, pad_h = pad.height;
195             int stride_w = stride.width, stride_h = stride.height;
196             bool compMaxIdx = computeMaxIdx;
197
198 #if CV_SIMD128
199             const int* ofsptr = &ofsbuf[0];
200             v_float32x4 idx00(0.f, (float)stride_w, (float)(stride_w*2), (float)(stride_w*3));
201             v_float32x4 ones = v_setall_f32(1.f);
202             v_float32x4 idx_delta = v_setall_f32((float)(inp_width - kernel_w));
203 #endif
204
205             for( size_t ofs0 = stripeStart; ofs0 < stripeEnd; )
206             {
207                 size_t ofs = ofs0;
208                 int x0 = (int)(ofs % width);
209                 ofs /= width;
210                 int y0 = (int)(ofs % height);
211                 ofs /= height;
212                 int c = (int)(ofs % channels);
213                 int n = (int)(ofs / channels);
214                 int ystart = y0 * stride_h - pad_h;
215                 int yend = min(ystart + kernel_h, inp_height + pad_h);
216                 int ydelta = yend - ystart;
217                 ystart = max(ystart, 0);
218                 yend = min(yend, inp_height);
219                 const float *srcData = src->ptr<float>(n, c);
220                 float *dstData = dst->ptr<float>(n, c, y0);
221                 float *dstMaskData = mask->data ? mask->ptr<float>(n, c, y0) : 0;
222
223                 int delta = std::min((int)(stripeEnd - ofs0), width - x0);
224                 ofs0 += delta;
225                 int x1 = x0 + delta;
226
227                 if( poolingType == PoolingLayer::MAX )
228                     for( ; x0 < x1; x0++ )
229                     {
230                         int xstart = x0 * stride_w - pad_w;
231                         int xend = min(xstart + kernel_w, inp_width);
232                         xstart = max(xstart, 0);
233
234 #if CV_SIMD128
235                         if( xstart > 0 && x0 + 7 < x1 && (x0 + 7) * stride_w - pad_w + kernel_w < inp_width )
236                         {
237                             if( compMaxIdx )
238                             {
239                                 v_float32x4 max_val0 = v_setall_f32(-FLT_MAX);
240                                 v_float32x4 max_val1 = max_val0;
241                                 v_float32x4 max_idx0 = v_setall_f32(-1.f);
242                                 v_float32x4 max_idx1 = max_idx0;
243                                 int index0 = ystart * inp_width + xstart;
244                                 v_float32x4 idx0 = idx00 + v_setall_f32((float)index0);
245                                 v_float32x4 idx1 = idx0 + v_setall_f32((float)(stride_w*4));
246
247                                 for (int y = ystart; y < yend; ++y)
248                                 {
249                                     for (int x = xstart; x < xend; ++x, idx0 += ones, idx1 += ones)
250                                     {
251                                         const int index = y * inp_width + x;
252                                         v_float32x4 v0(srcData[index], srcData[index + stride_w],
253                                                        srcData[index + stride_w*2], srcData[index + stride_w*3]);
254                                         v_float32x4 v1(srcData[index + stride_w*4], srcData[index + stride_w*5],
255                                                        srcData[index + stride_w*6], srcData[index + stride_w*7]);
256                                         max_idx0 = v_select(v0 > max_val0, idx0, max_idx0);
257                                         max_idx1 = v_select(v1 > max_val1, idx1, max_idx1);
258                                         max_val0 = v_max(max_val0, v0);
259                                         max_val1 = v_max(max_val1, v1);
260                                     }
261                                     idx0 += idx_delta;
262                                     idx1 += idx_delta;
263                                 }
264                                 v_store(dstData + x0, max_val0);
265                                 v_store(dstData + x0 + 4, max_val1);
266                                 if (dstMaskData)
267                                 {
268                                     v_store(dstMaskData + x0, max_idx0);
269                                     v_store(dstMaskData + x0 + 4, max_idx1);
270                                 }
271                                 x0 += 7;
272                             }
273                             else
274                             {
275                                 v_float32x4 max_val0 = v_setall_f32(-FLT_MAX);
276                                 v_float32x4 max_val1 = max_val0;
277
278                                 if( yend - ystart == kernel_h )
279                                 {
280                                     const float* srcData1 = srcData + ystart*inp_width + xstart;
281                                     if( stride_w == 1 )
282                                         for (int k = 0; k < kernel_w*kernel_h; k++)
283                                         {
284                                             int index = ofsptr[k];
285                                             v_float32x4 v0 = v_load(srcData1 + index);
286                                             v_float32x4 v1 = v_load(srcData1 + index + 4);
287                                             max_val0 = v_max(max_val0, v0);
288                                             max_val1 = v_max(max_val1, v1);
289                                         }
290 #if CV_SSE2
291                                     else if( stride_w == 2 )
292                                         for (int k = 0; k < kernel_w*kernel_h; k++)
293                                         {
294                                             int index = ofsptr[k];
295                                             v_float32x4 v00 = v_load(srcData1 + index), v01 = v_load(srcData1 + index + 4);
296                                             v_float32x4 v0(_mm_shuffle_ps(v00.val, v01.val, _MM_SHUFFLE(2, 0, 2, 0)));
297                                             v_float32x4 v10 = v_load(srcData1 + index + 8), v11 = v_load(srcData1 + index + 12);
298                                             v_float32x4 v1(_mm_shuffle_ps(v10.val, v11.val, _MM_SHUFFLE(2, 0, 2, 0)));
299                                             max_val0 = v_max(max_val0, v0);
300                                             max_val1 = v_max(max_val1, v1);
301                                         }
302 #endif
303                                     else
304                                         for (int k = 0; k < kernel_w*kernel_h; k++)
305                                         {
306                                             int index = ofsptr[k];
307                                             v_float32x4 v0(srcData1[index], srcData1[index + stride_w],
308                                                            srcData1[index + stride_w*2], srcData1[index + stride_w*3]);
309                                             v_float32x4 v1(srcData1[index + stride_w*4], srcData1[index + stride_w*5],
310                                                            srcData1[index + stride_w*6], srcData1[index + stride_w*7]);
311                                             max_val0 = v_max(max_val0, v0);
312                                             max_val1 = v_max(max_val1, v1);
313                                         }
314                                 }
315                                 else
316                                 {
317                                     for (int y = ystart; y < yend; ++y)
318                                     {
319                                         for (int x = xstart; x < xend; ++x)
320                                         {
321                                             const int index = y * inp_width + x;
322                                             v_float32x4 v0(srcData[index], srcData[index + stride_w],
323                                                            srcData[index + stride_w*2], srcData[index + stride_w*3]);
324                                             v_float32x4 v1(srcData[index + stride_w*4], srcData[index + stride_w*5],
325                                                            srcData[index + stride_w*6], srcData[index + stride_w*7]);
326                                             max_val0 = v_max(max_val0, v0);
327                                             max_val1 = v_max(max_val1, v1);
328                                         }
329                                     }
330                                 }
331                                 v_store(dstData + x0, max_val0);
332                                 v_store(dstData + x0 + 4, max_val1);
333                                 x0 += 7;
334                             }
335                         }
336                         else
337 #endif
338                         {
339                             float max_val = -FLT_MAX;
340                             if( compMaxIdx )
341                             {
342                                 int max_index = -1;
343                                 for (int y = ystart; y < yend; ++y)
344                                     for (int x = xstart; x < xend; ++x)
345                                     {
346                                         const int index = y * inp_width + x;
347                                         float val = srcData[index];
348                                         if (val > max_val)
349                                         {
350                                             max_val = val;
351                                             max_index = index;
352                                         }
353                                     }
354
355                                 dstData[x0] = max_val;
356                                 if (dstMaskData)
357                                     dstMaskData[x0] = max_index;
358                             }
359                             else
360                             {
361                                 for (int y = ystart; y < yend; ++y)
362                                     for (int x = xstart; x < xend; ++x)
363                                     {
364                                         const int index = y * inp_width + x;
365                                         float val = srcData[index];
366                                         max_val = std::max(max_val, val);
367                                     }
368
369                                 dstData[x0] = max_val;
370                             }
371                         }
372                     }
373                 else
374                 {
375                     for( ; x0 < x1; x0++ )
376                     {
377                         int xstart = x0 * stride_w - pad_w;
378                         int xend = min(xstart + kernel_w, inp_width + pad_w);
379                         int xdelta = xend - xstart;
380                         xstart = max(xstart, 0);
381                         xend = min(xend, inp_width);
382                         float inv_kernel_area = 1.f/(ydelta*xdelta);
383
384 #if CV_SIMD128
385                         if( xstart > 0 && x0 + 7 < x1 && (x0 + 7) * stride_w - pad_w + kernel_w < inp_width )
386                         {
387                             v_float32x4 sum_val0 = v_setzero_f32(), sum_val1 = v_setzero_f32();
388                             v_float32x4 ikarea = v_setall_f32(inv_kernel_area);
389
390                             for (int y = ystart; y < yend; ++y)
391                             {
392                                 for (int x = xstart; x < xend; ++x)
393                                 {
394                                     const int index = y * inp_width + x;
395                                     v_float32x4 v0(srcData[index], srcData[index + stride_w],
396                                                    srcData[index + stride_w*2], srcData[index + stride_w*3]);
397                                     v_float32x4 v1(srcData[index + stride_w*4], srcData[index + stride_w*5],
398                                                    srcData[index + stride_w*6], srcData[index + stride_w*7]);
399                                     sum_val0 += v0;
400                                     sum_val1 += v1;
401                                 }
402                             }
403                             v_store(dstData + x0, sum_val0*ikarea);
404                             v_store(dstData + x0 + 4, sum_val1*ikarea);
405                             x0 += 7;
406                         }
407                         else
408 #endif
409                         {
410                             float sum_val = 0.f;
411                             for (int y = ystart; y < yend; ++y)
412                                 for (int x = xstart; x < xend; ++x)
413                                 {
414                                     const int index = y * inp_width + x;
415                                     float val = srcData[index];
416                                     sum_val += val;
417                                 }
418
419                             dstData[x0] = sum_val*inv_kernel_area;
420                         }
421                     }
422                 }
423             }
424         }
425     };
426
427     void maxPooling(Mat &src, Mat &dst, Mat &mask)
428     {
429         const int nstripes = getNumThreads();
430         PoolingInvoker::run(src, dst, mask, kernel, stride, pad, type, computeMaxIdx, nstripes);
431     }
432
433     void avePooling(Mat &src, Mat &dst)
434     {
435         const int nstripes = getNumThreads();
436         Mat mask;
437         PoolingInvoker::run(src, dst, mask, kernel, stride, pad, type, computeMaxIdx, nstripes);
438     }
439
440     virtual Ptr<BackendNode> initMaxPoolingHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
441     {
442 #ifdef HAVE_HALIDE
443         Halide::Buffer<float> inputBuffer = halideBuffer(inputs[0]);
444         const int inWidth = inputBuffer.width();
445         const int inHeight = inputBuffer.height();
446
447         Halide::Var x("x"), y("y"), c("c"), n("n");
448         Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
449         Halide::RDom r(0, kernel.width, 0, kernel.height);
450         Halide::Expr kx, ky;
451         if (pad.width || pad.height)
452         {
453             kx = clamp(x * stride.width + r.x - pad.width, 0, inWidth - 1);
454             ky = clamp(y * stride.height + r.y - pad.height, 0, inHeight - 1);
455         }
456         else
457         {
458             kx = min(x * stride.width + r.x, inWidth - 1);
459             ky = min(y * stride.height + r.y, inHeight - 1);
460         }
461
462         // Halide::argmax returns tuple (r.x, r.y, max).
463         Halide::Tuple res = argmax(inputBuffer(kx, ky, c, n));
464
465         // Compute offset from argmax in range [0, kernel_size).
466         Halide::Expr max_index;
467         if (pad.width || pad.height)
468         {
469             max_index = clamp(y * stride.height + res[1] - pad.height,
470                               0, inHeight - 1) * inWidth +
471                         clamp(x * stride.width + res[0] - pad.width,
472                               0, inWidth - 1);
473         }
474         else
475         {
476             max_index = min(y * stride.height + res[1], inHeight - 1) * inWidth +
477                         min(x * stride.width + res[0], inWidth - 1);
478         }
479         top(x, y, c, n) = { res[2], Halide::cast<float>(max_index) };
480         return Ptr<BackendNode>(new HalideBackendNode(top));
481 #endif  // HAVE_HALIDE
482         return Ptr<BackendNode>();
483     }
484
485     virtual Ptr<BackendNode> initAvePoolingHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
486     {
487 #ifdef HAVE_HALIDE
488         Halide::Buffer<float> inputBuffer = halideBuffer(inputs[0]);
489
490         const int inW = inputBuffer.width(), inH = inputBuffer.height();
491         if ((inW - kernel.width) % stride.width || (inH - kernel.height) % stride.height)
492         {
493             CV_Error(cv::Error::StsNotImplemented,
494                      "Halide backend for average pooling with partial "
495                      "kernels is not implemented");
496         }
497
498         const float norm = 1.0f / (kernel.width * kernel.height);
499
500         Halide::Var x("x"), y("y"), c("c"), n("n");
501         Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
502         Halide::RDom r(0, kernel.width, 0, kernel.height);
503         top(x, y, c, n) = sum(
504             inputBuffer(x * stride.width + r.x,
505                         y * stride.height + r.y, c, n)) * norm;
506         return Ptr<BackendNode>(new HalideBackendNode(top));
507 #endif  // HAVE_HALIDE
508         return Ptr<BackendNode>();
509     }
510
511     virtual void applyHalideScheduler(Ptr<BackendNode>& node,
512                                       const std::vector<Mat*> &inputs,
513                                       const std::vector<Mat> &outputs,
514                                       int targetId) const
515     {
516 #ifdef  HAVE_HALIDE
517         if (targetId != DNN_TARGET_CPU)
518         {
519             Layer::applyHalideScheduler(node, inputs, outputs, targetId);
520             return;
521         }
522         Halide::Var x("x"), y("y"), c("c"), n("n"), tile("tile"),
523                     xi("xi"), yi("yi"), ci("ci"), xo("xo"), yo("yo"), co("co");
524         Halide::Func& top = node.dynamicCast<HalideBackendNode>()->funcs.back();
525
526         int outW, outH, outC, outN;
527         getCanonicalSize(outputs[0].size, &outW, &outH, &outC, &outN);
528
529         if (outW < 8 || outH < 8)
530         {
531             if (outC > 8)
532                 top.split(c, co, ci, 8)
533                    .fuse(x, y, tile).fuse(co, tile, tile).fuse(n, tile, tile)
534                    .parallel(tile)
535                    .vectorize(ci);
536             else
537             {
538                 top.fuse(y, c, tile).fuse(n, tile, tile)
539                    .parallel(tile);
540                 if (outW > 1)
541                     top.vectorize(x);
542             }
543         }
544         else
545         {
546             if (outC > 8)
547                 top.split(x, xo, xi, 8).split(y, yo, yi, 8).split(c, co, ci, 8)
548                    .fuse(xo, yo, tile).fuse(co, tile, tile).fuse(n, tile, tile)
549                    .parallel(tile)
550                    .vectorize(xi);
551             else
552                 top.split(x, xo, xi, 8).split(y, yo, yi, 8)
553                    .fuse(xo, yo, tile).fuse(c, tile, tile).fuse(n, tile, tile)
554                    .parallel(tile)
555                    .vectorize(xi);
556         }
557 #endif  // HAVE_HALIDE
558     }
559
560     bool getMemoryShapes(const std::vector<MatShape> &inputs,
561                          const int requiredOutputs,
562                          std::vector<MatShape> &outputs,
563                          std::vector<MatShape> &internals) const
564     {
565         CV_Assert(inputs.size() != 0);
566         Size in(inputs[0][3], inputs[0][2]), out;
567
568         if (globalPooling)
569         {
570             out.height = 1;
571             out.width = 1;
572         }
573         else if (padMode.empty())
574         {
575             //Yeah, something strange Caffe scheme-)
576             out.height = static_cast<int>(ceil(static_cast<float>(in.height + 2 * pad.height -
577                                                                   kernel.height) / stride.height)) + 1;
578             out.width = static_cast<int>(ceil(static_cast<float>(in.width + 2 * pad.width -
579                                                                  kernel.width) / stride.width)) + 1;
580
581             if (pad.height || pad.width)
582             {
583                 // If we have padding, ensure that the last pooling starts strictly
584                 // inside the image (instead of at the padding); otherwise clip the last.
585                 if ((out.height - 1) * stride.height >= in.height + pad.height)
586                     --out.height;
587                 if ((out.width - 1) * stride.width >= in.width + pad.width)
588                     --out.width;
589                 CV_Assert((out.height - 1) * stride.height < in.height + pad.height);
590                 CV_Assert((out.width - 1) * stride.width < in.width + pad.width);
591             }
592         }
593         else
594         {
595             getConvPoolOutParams(in, kernel, stride,
596                                  padMode, out);
597         }
598
599         outputs.resize(type == MAX ? 2 * inputs.size() : inputs.size());
600         for (size_t i = 0; i < inputs.size(); i++)
601         {
602             size_t index = type == MAX ? 2*i : i;
603             int dims[] = {inputs[i][0], inputs[i][1], out.height, out.width};
604             outputs[index] = shape(dims);
605
606             if (type == MAX)
607                 outputs[index + 1] = shape(dims);
608         }
609
610         return false;
611     }
612
613     virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
614                            const std::vector<MatShape> &outputs) const
615     {
616         (void)inputs; // suppress unused variable warning
617         long flops = 0;
618
619         for(int i = 0; i < outputs.size(); i++)
620         {
621             if (type == MAX)
622             {
623                 if (i%2 == 0)
624                     flops += total(outputs[i])*kernel.area();
625             }
626             else
627             {
628                 flops += total(outputs[i])*(kernel.area() + 1);
629             }
630         }
631         return flops;
632     }
633 };
634
635 Ptr<PoolingLayer> PoolingLayer::create(const LayerParams& params)
636 {
637     return Ptr<PoolingLayer>(new PoolingLayerImpl(params));
638 }
639
640 }
641 }