From bb64db98d89ae9806eb792cca42dd43c06ebd5e8 Mon Sep 17 00:00:00 2001 From: Zihao Mu Date: Fri, 26 Aug 2022 17:57:25 +0800 Subject: [PATCH] Further optimization of Conv2D, fused Conv_Add_Activation, bring latest code from ficus OpConv.fx. (#22401) --- modules/dnn/include/opencv2/dnn/all_layers.hpp | 3 + modules/dnn/src/dnn_common.hpp | 1 + modules/dnn/src/layers/convolution_layer.cpp | 42 +- .../fast_convolution/fast_convolution.avx2.cpp | 91 ++- .../layers/fast_convolution/fast_convolution.cpp | 843 +++++++++------------ .../layers/fast_convolution/fast_convolution.hpp | 39 +- .../fast_convolution/fast_convolution.simd.hpp | 508 ++++++------- .../layers/fast_convolution/winograd_3x3s1_f63.cpp | 406 ++++++++-- modules/dnn/src/net_impl_fuse.cpp | 174 ++++- modules/dnn/test/test_caffe_importer.cpp | 2 +- modules/dnn/test/test_int8_layers.cpp | 2 +- modules/dnn/test/test_torch_importer.cpp | 4 +- 12 files changed, 1202 insertions(+), 913 deletions(-) diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp index 263e48c..66ba087 100644 --- a/modules/dnn/include/opencv2/dnn/all_layers.hpp +++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp @@ -256,6 +256,9 @@ CV__DNN_INLINE_NS_BEGIN { public: static Ptr create(const LayerParams& params); + bool fusedActivation = false; + bool fusedAdd = false; + bool isConv2D = false; // Should be deleted after fastconv branch support Conv1D and Conv3D. }; class CV_EXPORTS ConvolutionLayerInt8 : public BaseConvolutionLayer diff --git a/modules/dnn/src/dnn_common.hpp b/modules/dnn/src/dnn_common.hpp index ae4d9c2..b580b9f 100644 --- a/modules/dnn/src/dnn_common.hpp +++ b/modules/dnn/src/dnn_common.hpp @@ -13,6 +13,7 @@ namespace cv { namespace dnn { CV__DNN_INLINE_NS_BEGIN #define IS_DNN_OPENCL_TARGET(id) (id == DNN_TARGET_OPENCL || id == DNN_TARGET_OPENCL_FP16) +#define IS_DNN_CPU_TARGET(id) (id == DNN_TARGET_CPU) // TODO: add DNN_TARGET_CPU_FP16 Mutex& getInitializationMutex(); void initializeLayerFactory(); diff --git a/modules/dnn/src/layers/convolution_layer.cpp b/modules/dnn/src/layers/convolution_layer.cpp index c2960d5..e6ef9f1 100644 --- a/modules/dnn/src/layers/convolution_layer.cpp +++ b/modules/dnn/src/layers/convolution_layer.cpp @@ -118,6 +118,9 @@ public: fusedWeights = false; fusedBias = false; + + if (kernel_size.size() == 2) + isConv2D = true; } virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE @@ -188,6 +191,9 @@ public: virtual bool tryFuse(Ptr& top) CV_OVERRIDE { + if (fusedAdd) // If the Conv layer has fused Add layer, it cannot fuse other layers. + return false; + Ptr blank_layer = top.dynamicCast(); if (blank_layer) return true; @@ -260,7 +266,6 @@ public: std::vector reluslope; Ptr activ; - Mat fastWeights; // Used to store weight params. It will be used for layer fusion and without memory alignment. Ptr fastConv2dImpl; #ifdef HAVE_OPENCL @@ -438,7 +443,6 @@ public: wm.copyTo(wm_aligned); wm = wm_aligned; } - fastWeights = blobs[0].reshape(1, numOutput); weightsMat = wm; } else @@ -584,11 +588,15 @@ public: } } #endif - return !activ.empty(); + fusedActivation = !activ.empty(); + return fusedActivation; } virtual bool tryFuse(Ptr& top) CV_OVERRIDE { + if (fusedAdd) // If the Conv layer has fused Add layer, it cannot fuse other layers. 
+ return false; + #ifdef HAVE_CUDA if(IS_DNN_CUDA_TARGET(preferableTarget)) { @@ -634,26 +642,14 @@ public: if (weightsMat.data == blobs[0].data) weightsMat = weightsMat.clone(); - // If fastWeights is the same as weightsMat, we don't need to allocate more space for fastWeights. - bool sameFastWeights = false; - if (fastWeights.step1() == weightsMat.step1()) // If weightsMat is realigned, it is not the same as fastWeights. - sameFastWeights = true; - - if (!sameFastWeights && fastWeights.data == blobs[0].data) - fastWeights = fastWeights.clone(); - Mat originWeights = blobs[0].reshape(1, outCn); for (int i = 0; i < outCn; ++i) { double wi = w.at(i); weightsMultipliers[i] *= wi; cv::multiply(originWeights.row(i), weightsMultipliers[i], weightsMat.row(i)); - if (!sameFastWeights) - cv::multiply(originWeights.row(i), weightsMultipliers[i], fastWeights.row(i)); biasvec[i] *= wi; } - if (sameFastWeights) - fastWeights = weightsMat; } if (!b.empty()) @@ -1970,9 +1966,6 @@ public: if (blobs.empty()) { variableWeight = true; - if (fastWeights.data != inputs[1].data) - fastWeights = inputs[1].clone(); - Mat wm = inputs[1].reshape(1, outCn); if (wm.data != weightsMat.data) { @@ -2089,7 +2082,7 @@ public: { int nstripes = std::max(getNumThreads(), 1); - // Initialization of FastCovn2d + // Initialization of FastCovn2d, pack weight. if ((!fastConv2dImpl || variableWeight) && inputs[0].dims == 4) { int K = outputs[0].size[1]; @@ -2103,23 +2096,22 @@ public: int dilation_h = dilations[dilations.size() - 2]; int dilation_w = dilations.back(); - float* weightsPtr = fastWeights.ptr(); - CV_Assert(weightsPtr); - fastConv2dImpl = initFastConv2d(ngroups, K, C, Hk, Wk, stride_w, stride_h, - dilation_w, dilation_h, pads_begin, pads_end, weightsPtr, &biasvec[0]); + fastConv2dImpl = initFastConv2d(ngroups, K, C, Hk, Wk, stride_w, stride_h, dilation_w, + dilation_h, pads_begin, pads_end, weightsMat, &biasvec[0]); } if (fastConv2dImpl) { - runFastConv2d(inputs[0], outputs[0], fastConv2dImpl, nstripes, activ); + runFastConv2d(inputs[0], outputs[0], fastConv2dImpl, nstripes, activ, fusedAdd); return; } + //TODO: Add support of Conv1D and Conv3D to fastConv, and remove the old Conv branch. // Use only for Conv1D and Conv3D. 
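For readers following the fuseWeights() changes above: dropping fastWeights means the per-channel rescaling is now applied to weightsMat alone. Conceptually it folds a following per-channel scale (e.g. a Scale/BatchNorm layer) into the convolution itself. A minimal standalone sketch of that folding, in plain C++ with illustrative names (not the OpenCV implementation; the optional shift term is an assumption for the general case):

// Fold a per-output-channel scale s[k] (and optional shift) into the
// convolution weights and bias, which is what the cv::multiply calls in
// fuseWeights() above do row by row on weightsMat.
#include <vector>
#include <cstddef>

static void foldChannelScale(std::vector<float>& W,            // [K x weightsPerChannel], row-major
                             std::vector<float>& bias,         // [K]
                             const std::vector<float>& s,      // per-channel scale, [K]
                             const std::vector<float>& shift,  // per-channel shift, [K] (may be empty)
                             size_t K, size_t weightsPerChannel)
{
    for (size_t k = 0; k < K; k++)
    {
        for (size_t j = 0; j < weightsPerChannel; j++)
            W[k*weightsPerChannel + j] *= s[k];                            // W'[k] = s[k] * W[k]
        bias[k] = bias[k]*s[k] + (shift.empty() ? 0.f : shift[k]);         // b'[k] = s[k]*b[k] + shift[k]
    }
}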
+ CV_Assert(!fusedAdd); ParallelConv::run(inputs[0], outputs[0], weightsMat, biasvec, reluslope, kernel_size, strides, pads_begin, pads_end, dilations, activ.get(), ngroups, nstripes); - } } diff --git a/modules/dnn/src/layers/fast_convolution/fast_convolution.avx2.cpp b/modules/dnn/src/layers/fast_convolution/fast_convolution.avx2.cpp index 22580c5..de1a2ef 100644 --- a/modules/dnn/src/layers/fast_convolution/fast_convolution.avx2.cpp +++ b/modules/dnn/src/layers/fast_convolution/fast_convolution.avx2.cpp @@ -9,67 +9,67 @@ namespace cv { namespace opt_AVX2 { #if CV_TRY_AVX2 -void convBlock_AVX2(int k, const float *a, const float *b, - float *c, int ldc, const float *bias, - float minval, float maxval, bool ifActiv) +void convBlock_AVX2(int np, const float* a, const float* b, float* c, int ldc, bool init_c) { -#if FAST_CONV_MR == 4 && FAST_CONV_NR == 24 - __m256 vminval = _mm256_set1_ps(minval), vmaxval = _mm256_set1_ps(maxval); - __m256 c0 = _mm256_set1_ps(bias[0]), c1 = c0, c2 = c0; - __m256 c3 = _mm256_set1_ps(bias[1]), c4 = c3, c5 = c3; - __m256 c6 = _mm256_set1_ps(bias[2]), c7 = c6, c8 = c6; - __m256 c9 = _mm256_set1_ps(bias[3]), c10 = c9, c11 = c9; +#if CONV_MR == 4 && CONV_NR == 24 + __m256 c00 = _mm256_set1_ps(0.f), c01 = c00, c02 = c00; + __m256 c10 = c00, c11 = c00, c12 = c00; + __m256 c20 = c00, c21 = c00, c22 = c00; + __m256 c30 = c00, c31 = c00, c32 = c00; __m256 a0 = _mm256_setzero_ps(), a1 = _mm256_setzero_ps(); __m256 b0 = _mm256_setzero_ps(), b1 = _mm256_setzero_ps(), b2 = _mm256_setzero_ps(); - for (int p = 0; p < k; p++, a += FAST_CONV_MR, b += FAST_CONV_NR) + for (int p = 0; p < np; p++, a += CONV_MR, b += CONV_NR) { a0 = _mm256_set1_ps(a[0]), a1 = _mm256_set1_ps(a[1]); b0 = _mm256_load_ps(b), b1 = _mm256_load_ps(b + 8), b2 = _mm256_load_ps(b + 16); - c0 = _mm256_fmadd_ps(b0, a0, c0); - c1 = _mm256_fmadd_ps(b1, a0, c1); - c2 = _mm256_fmadd_ps(b2, a0, c2); + c00 = _mm256_fmadd_ps(b0, a0, c00); + c01 = _mm256_fmadd_ps(b1, a0, c01); + c02 = _mm256_fmadd_ps(b2, a0, c02); - c3 = _mm256_fmadd_ps(b0, a1, c3); - a0 = _mm256_set1_ps(a[2]); - c4 = _mm256_fmadd_ps(b1, a1, c4); - c5 = _mm256_fmadd_ps(b2, a1, c5); + c10 = _mm256_fmadd_ps(b0, a1, c10); + c11 = _mm256_fmadd_ps(b1, a1, c11); + c12 = _mm256_fmadd_ps(b2, a1, c12); - c6 = _mm256_fmadd_ps(b0, a0, c6); - a1 = _mm256_set1_ps(a[3]); - c7 = _mm256_fmadd_ps(b1, a0, c7); - c8 = _mm256_fmadd_ps(b2, a0, c8); + a0 = _mm256_set1_ps(a[2]), a1 = _mm256_set1_ps(a[3]); - c9 = _mm256_fmadd_ps(b0, a1, c9); - c10 = _mm256_fmadd_ps(b1, a1, c10); - c11 = _mm256_fmadd_ps(b2, a1, c11); + c20 = _mm256_fmadd_ps(b0, a0, c20); + c21 = _mm256_fmadd_ps(b1, a0, c21); + c22 = _mm256_fmadd_ps(b2, a0, c22); + + c30 = _mm256_fmadd_ps(b0, a1, c30); + c31 = _mm256_fmadd_ps(b1, a1, c31); + c32 = _mm256_fmadd_ps(b2, a1, c32); } - if (ifActiv) + if (!init_c) { - c0 = _mm256_min_ps(_mm256_max_ps(c0, vminval), vmaxval); - c1 = _mm256_min_ps(_mm256_max_ps(c1, vminval), vmaxval); - c2 = _mm256_min_ps(_mm256_max_ps(c2, vminval), vmaxval); - c3 = _mm256_min_ps(_mm256_max_ps(c3, vminval), vmaxval); - c4 = _mm256_min_ps(_mm256_max_ps(c4, vminval), vmaxval); - c5 = _mm256_min_ps(_mm256_max_ps(c5, vminval), vmaxval); - c6 = _mm256_min_ps(_mm256_max_ps(c6, vminval), vmaxval); - c7 = _mm256_min_ps(_mm256_max_ps(c7, vminval), vmaxval); - c8 = _mm256_min_ps(_mm256_max_ps(c8, vminval), vmaxval); - c9 = _mm256_min_ps(_mm256_max_ps(c9, vminval), vmaxval); - c10 = _mm256_min_ps(_mm256_max_ps(c10, vminval), vmaxval); - c11 = _mm256_min_ps(_mm256_max_ps(c11, vminval), vmaxval); 
+ c00 = _mm256_add_ps(c00, _mm256_load_ps(c)); + c01 = _mm256_add_ps(c01, _mm256_load_ps(c + 8)); + c02 = _mm256_add_ps(c02, _mm256_load_ps(c + 16)); + + c10 = _mm256_add_ps(c10, _mm256_load_ps(c + ldc)); + c11 = _mm256_add_ps(c11, _mm256_load_ps(c + ldc + 8)); + c12 = _mm256_add_ps(c12, _mm256_load_ps(c + ldc + 16)); + + c20 = _mm256_add_ps(c20, _mm256_load_ps(c + ldc*2)); + c21 = _mm256_add_ps(c21, _mm256_load_ps(c + ldc*2 + 8)); + c22 = _mm256_add_ps(c22, _mm256_load_ps(c + ldc*2 + 16)); + + c30 = _mm256_add_ps(c30, _mm256_load_ps(c + ldc*3)); + c31 = _mm256_add_ps(c31, _mm256_load_ps(c + ldc*3 + 8)); + c32 = _mm256_add_ps(c32, _mm256_load_ps(c + ldc*3 + 16)); } - _mm256_storeu_ps(c, c0); _mm256_storeu_ps(c+8, c1); _mm256_storeu_ps(c+16, c2); - _mm256_storeu_ps(c + ldc, c3); _mm256_storeu_ps(c + ldc + 8, c4); _mm256_storeu_ps(c + ldc + 16, c5); - _mm256_storeu_ps(c + ldc*2, c6); _mm256_storeu_ps(c + ldc*2 + 8, c7); _mm256_storeu_ps(c + ldc*2 + 16, c8); - _mm256_storeu_ps(c + ldc*3, c9); _mm256_storeu_ps(c + ldc*3 + 8, c10); _mm256_storeu_ps(c + ldc*3 + 16, c11); + _mm256_storeu_ps(c, c00), _mm256_storeu_ps(c+8, c01), _mm256_storeu_ps(c+16, c02); + _mm256_storeu_ps(c + ldc, c10), _mm256_storeu_ps(c + ldc + 8, c11), _mm256_storeu_ps(c + ldc + 16, c12); + _mm256_storeu_ps(c + ldc*2, c20), _mm256_storeu_ps(c + ldc*2 + 8, c21), _mm256_storeu_ps(c + ldc*2 + 16, c22); + _mm256_storeu_ps(c + ldc*3, c30), _mm256_storeu_ps(c + ldc*3 + 8, c31), _mm256_storeu_ps(c + ldc*3 + 16, c32); _mm256_zeroupper(); #else -#error "unsupported FAST_CONV_MR and/or FAST_CONV_NR in convBlock_AVX2." +#error "unsupported CONV_MR and/or CONV_NR in convBlock_AVX2." #endif } @@ -78,7 +78,6 @@ void depthWiseBlock_AVX2(const float *inptr, float *outptr, const float *weights int dilation_y, int stride_x, int stride_y, int inner_xleft, int inner_xright, int inner_ytop, int inner_ybottom, bool ifMinMaxAct, bool useSIMD, bool is3x3) { - const int VECSZ = 8; __m256 vminval = _mm256_set1_ps(minval); __m256 vmaxval = _mm256_set1_ps(maxval); @@ -175,7 +174,7 @@ void depthWiseBlock_AVX2(const float *inptr, float *outptr, const float *weights { if (dy0 == 3) { - for (; x0 <= x1 - VECSZ; x0 += VECSZ) + for (; x0 <= x1 - FAST_VEC_NLANES; x0 += FAST_VEC_NLANES) { int xi_ = x0 * stride_x - pad_left; const float *inptr_xi = inptr + Wi * yi_ + xi_; @@ -251,7 +250,7 @@ void depthWiseBlock_AVX2(const float *inptr, float *outptr, const float *weights } else { - for (; x0 <= x1 - VECSZ; x0 += VECSZ) + for (; x0 <= x1 - FAST_VEC_NLANES; x0 += FAST_VEC_NLANES) { int xi_ = x0 * stride_x - pad_left; const float *inptr_xi = inptr + Wi * yi_ + xi_; @@ -277,7 +276,7 @@ void depthWiseBlock_AVX2(const float *inptr, float *outptr, const float *weights } else { - for (; x0 <= x1 - VECSZ; x0 += VECSZ) + for (; x0 <= x1 - FAST_VEC_NLANES; x0 += FAST_VEC_NLANES) { int xi_ = x0 * stride_x - pad_left, k = 0; const float *inptr_xi = inptr + Wi * yi_ + xi_; diff --git a/modules/dnn/src/layers/fast_convolution/fast_convolution.cpp b/modules/dnn/src/layers/fast_convolution/fast_convolution.cpp index 139ea7f..2af8363 100644 --- a/modules/dnn/src/layers/fast_convolution/fast_convolution.cpp +++ b/modules/dnn/src/layers/fast_convolution/fast_convolution.cpp @@ -22,7 +22,7 @@ Ptr initFastConv2d( int dilation_x, int dilation_y, const std::vector& pads_begin, const std::vector& pads_end, - float* srcWeights, + InputArray _weightsMat, float* srcBias) { Ptr conv = makePtr(); @@ -43,33 +43,27 @@ Ptr initFastConv2d( conv->pad_bottom = pads_end[0]; conv->pad_left = 
pads_begin[1]; conv->pad_right = pads_end[1]; - - // store bias; append some zero's to make sure that - // we can always read FAST_CONV_MR elements starting from any valid index - { - int k = 0, nbias = K + FAST_CONV_MR-1; - conv->biasBuf.reserve(nbias); - float* biasBufPtr = conv->biasBuf.data(); - for(; k < K; k++) - biasBufPtr[k] = srcBias ? srcBias[k] : 0.f; - for(; k < nbias; k++) - biasBufPtr[k] = 0.f; - } + Mat weightsMat = _weightsMat.getMat(); + auto wShape = shape(weightsMat); + const size_t wstep = weightsMat.step1(); #if CV_NEON // For now, winograd is ARM platform only. - if (ngroups == 1 && Hk ==3 && Wk == 3 && stride_x == 1 && stride_y == 1 && dilation_x == 1 && dilation_y ==1 - && K >= 16 && C >= 16 ) + if (ngroups == 1 && Hk ==3 && Wk == 3 && stride_x == 1 && stride_y == 1 && + dilation_x == 1 && dilation_y ==1 && K >= 16 && C >= 16) conv->ifWinograd63 = true; #else conv->ifWinograd63 = false; #endif + float *srcWeights = (float *)weightsMat.data; if (ngroups > 1 && ngroups == K && ngroups == C) { // for depth-wise convolutions on NCHW data we just preserve the weights in KCHW layout, // but add some padding to make the weights array layout more SIMD-friendly int ksize = Hk*Wk; - int padded_ksize = ((ksize + FAST_VEC_NLANES-1)/FAST_VEC_NLANES)*FAST_VEC_NLANES; // this code aims to let memory fit with vector size. + + // this code aims to let memory fit with vector size. + int padded_ksize = ((ksize + FAST_VEC_NLANES-1) / FAST_VEC_NLANES) * FAST_VEC_NLANES; int nweights = C*padded_ksize; conv->weightsBuf.reserve(nweights); float* weightsBufPtr = conv->weightsBuf.data(); @@ -77,340 +71,80 @@ Ptr initFastConv2d( for(int c = 0; c < C; c++) { for (int k = 0; k < ksize; k++) - weightsBufPtr[c*padded_ksize + k] = srcWeights[c*ksize + k]; + weightsBufPtr[c*padded_ksize + k] = srcWeights[c*wstep + k]; } } else { // The weights are packed as - // ngroups x (ceil((K/ngroups)/FAST_CONV_MR)*FAST_CONV_MR) x (Cg*Hk*Wk) x FAST_CONV_MR tensor + // ngroups x (ceil((K/ngroups)/CONV_MR)*CONV_MR) x (Cg*Hk*Wk) x CONV_MR tensor int Kg = K/ngroups, Cg = max(C/ngroups, 1); - int Kg_aligned = ((Kg + FAST_CONV_MR - 1)/FAST_CONV_MR)*FAST_CONV_MR; - size_t nweights = ngroups*Kg_aligned*Cg*Hk*Wk; + int numStripsMR = (Kg + CONV_MR - 1) / CONV_MR; + int Kg_aligned = numStripsMR * CONV_MR; + int HkWkCg = Hk*Wk*Cg; + size_t nweights = ngroups*Kg_aligned*HkWkCg; conv->weightsBuf.reserve(nweights); float* weightsBufPtr = conv->weightsBuf.data(); memset(weightsBufPtr, 0, nweights*sizeof(weightsBufPtr[0])); - float* packed_wptr = weightsBufPtr; - // pack the weight. - for(int g = 0; g < ngroups; g++) + // Pack the weight. + parallel_for_(Range(0, ngroups * numStripsMR), [&](const Range& r0){ + for (int gsi = r0.start; gsi < r0.end; gsi++) { - for(int k0 = 0; k0 < Kg_aligned; k0 += FAST_CONV_MR) - { - int dk = Kg - k0 < FAST_CONV_MR ? 
Kg - k0 : FAST_CONV_MR; - for(int c = 0; c < Cg; c++) - { - for(int yx = 0; yx < Hk*Wk; yx++, packed_wptr += FAST_CONV_MR) - { - const float* wptr = srcWeights + ((g*Kg + k0)*Cg + c)*Hk*Wk + yx; - int k = 0; - for(; k < dk; k++, wptr += Cg*Hk*Wk) - packed_wptr[k] = *wptr; - for(; k < FAST_CONV_MR; k++) - packed_wptr[k] = 0.f; - } - } - } - } - - // Prepare Weight for Winograd F(6x6, 3x3) - if (conv->ifWinograd63) - { - initWinograd63(conv, srcWeights, K, C); - } - } - return conv; -} - -static void packInput(float* inpbuf, const float* inptr, int* yxtab, int ksize, int Cg, int Hi, int Wi, int W0, - int pad_top, int pad_left, int stride_x, int stride_y, int yx0, int slice_len, - bool fast_1x1, bool partial0, bool s1d1p0, bool s1d1) -{ - const size_t inp_planesize = (size_t)Hi*Wi; + int g = gsi / numStripsMR; + int si = gsi - g * numStripsMR; - if (fast_1x1) - { - /* - super-fast branch for 1x1 convolutions with sy=sx=1. - in this case each feature plane can be safely treated - as 1D array and we just extract next portion - of FAST_CONV_NR elements from each feature plane and - put it together. - */ - inptr += yx0; - if (!partial0) - { - // Make special branch where memcpy() is called with a constant buffer size. - // Compilers will likely unroll this loop properly. - for (int c = 0; c < Cg; c++, inptr += inp_planesize, inpbuf += FAST_CONV_NR) - memcpy(inpbuf, inptr, FAST_CONV_NR * sizeof(inpbuf[0])); - } - else - { - for (int c = 0; c < Cg; c++, inptr += inp_planesize, inpbuf += FAST_CONV_NR) - { - memcpy(inpbuf, inptr, slice_len * sizeof(inpbuf[0])); - memset(inpbuf + slice_len, 0, (FAST_CONV_NR - slice_len) * sizeof(inpbuf[0])); - } - } - } - else if (s1d1p0) - { - /* - slower, but still fast branch for sy=sx=1, dy=dx=1 and without padding, - in this case we copy data from input tensors by chunks. - */ - for (int c = 0; c < Cg; c++) - { - float *inpbuf_c = inpbuf + c * (FAST_CONV_NR * ksize); - const float *inptr_c = inptr + c * inp_planesize; + int startK = si * CONV_MR; + CV_Assert(startK < Kg_aligned); - for (int k = 0; k < ksize; k++) - { - int y0 = yx0 / W0, x0 = yx0 % W0; - int yi = y0 + yxtab[k * 2], xi = x0 + yxtab[k * 2 + 1]; - float *inpbuf_k = inpbuf_c + k * FAST_CONV_NR; - int xi_0 = yxtab[k * 2 + 1]; + float* packed_wptr = weightsBufPtr + HkWkCg * (startK + g * Kg_aligned); + int dk = Kg - startK < CONV_MR ? Kg - startK : CONV_MR; // check if we need zero padding. - int i = 0; - for (; i < slice_len;) + int k_idx = g*Kg + startK; + for(int yx = 0; yx < Hk*Wk; yx++) { + for(int c = 0; c < Cg; c++, packed_wptr += CONV_MR) { - const float *inptr_k = inptr_c + yi * Wi + xi; - int copy_len = std::min(slice_len - i, W0 - x0); - int di_z = (slice_len == i + copy_len) ? FAST_CONV_NR - slice_len : 0; - - memcpy(inpbuf_k + i, - inptr_k, - copy_len * sizeof(inpbuf_k[0])); - - memset(inpbuf_k + i + copy_len, - 0, di_z * sizeof(inpbuf_k[0])); - - i += copy_len; - x0 = 0; - xi = xi_0; - yi++; + const float* wptr = srcWeights + wstep * k_idx + c*Hk*Wk + yx; + int k = 0; + for(; k < dk; k++, wptr += wstep) + packed_wptr[k] = *wptr; + for(; k < CONV_MR; k++) + packed_wptr[k] = 0.f; } } - } - } - else if (s1d1) - { - /* - slower, but still fast branch for sy=sx=1, dy=dx=1. 
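The packing loop above rearranges the ngroups x Kg x Cg x Hk x Wk weights into strips of CONV_MR output channels, ordered (yx, c, k) within each strip and zero-padded at the tail so the micro-kernel never reads past the last real channel. A scalar sketch of that layout for a single group, with illustrative names and a fixed kCONV_MR (not the parallelized OpenCV code):

// Pack Kg x (Cg*Hk*Wk) weights (row stride 'wstep') into the strip-major
// layout [ceil(Kg/kCONV_MR)] x [Hk*Wk] x [Cg] x [kCONV_MR].
#include <vector>
#include <algorithm>
#include <cstddef>

enum { kCONV_MR = 4 };   // illustrative; the real value comes from fast_convolution.hpp

static std::vector<float> packWeights(const float* src, size_t wstep,
                                      int Kg, int Cg, int Hk, int Wk)
{
    const int HkWk = Hk*Wk, HkWkCg = HkWk*Cg;
    const int numStrips = (Kg + kCONV_MR - 1) / kCONV_MR;
    std::vector<float> packed((size_t)numStrips*HkWkCg*kCONV_MR, 0.f);

    for (int strip = 0; strip < numStrips; strip++)
    {
        int k0 = strip*kCONV_MR;
        int dk = std::min(Kg - k0, (int)kCONV_MR);       // the tail strip may be partial
        float* dst = packed.data() + (size_t)strip*HkWkCg*kCONV_MR;
        for (int yx = 0; yx < HkWk; yx++)
            for (int c = 0; c < Cg; c++, dst += kCONV_MR)
                for (int k = 0; k < dk; k++)             // remaining dst[k] stay zero
                    dst[k] = src[(size_t)(k0 + k)*wstep + c*HkWk + yx];
    }
    return packed;
}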
- in this case we copy data from input tensors by chunks and - interleave the data in inpbuf with 0's - (that correspond to the padding elements) when necessary - */ - int y0 = yx0 / W0, x0 = yx0 % W0; - for (int c = 0; c < Cg; c++) - { - float *inpbuf_c = inpbuf + c * (FAST_CONV_NR * ksize); - const float *inptr_c = inptr + c * inp_planesize; - - for (int k = 0; k < ksize; k++) - { - int x0_tmp = x0; - - int xi_0 = yxtab[k * 2 + 1] - pad_left; - - int yi = y0 + yxtab[k * 2] - pad_top, xi = x0_tmp + xi_0; - float *inpbuf_k = inpbuf_c + k * FAST_CONV_NR; + }}); - int i = 0; - for (; i < slice_len;) { - int copyLen = std::min(slice_len - i, W0 - x0_tmp); - - int di_z = (i + copyLen == slice_len) ? FAST_CONV_NR - slice_len - : 0; // The final padding. - // pad_top or pad bottom - if (yi < 0 || yi > Hi - 1) - { - memset(inpbuf_k + i, - 0, (copyLen + di_z) * sizeof(inpbuf_k[0])); - i += copyLen + di_z; - } - else - { - int x_pad_left = 0, x_pad_right = 0; - - // pad_left - if (xi < 0) - { - x_pad_left = std::min(-xi, copyLen); - xi = 0; - copyLen -= x_pad_left; - } - - memset(inpbuf_k + i, - 0, x_pad_left * sizeof(inpbuf_k[0])); - i += x_pad_left; - - // pad right - if (xi + copyLen > Wi) - { - if (xi > Wi) - { - x_pad_right = copyLen; - copyLen = 0; - } - else - { - x_pad_right = std::min(xi + copyLen - Wi, copyLen); - copyLen -= x_pad_right; - } - } - - CV_Assert(copyLen >= 0); - - const float *inptr_k = inptr_c + yi * Wi + xi; - memcpy(inpbuf_k + i, - inptr_k, - copyLen * sizeof(inpbuf_k[0])); - - i += copyLen; - - // pad_right and the final padding. - memset(inpbuf_k + i, - 0, (di_z + x_pad_right) * sizeof(inpbuf_k[0])); - i += x_pad_right + di_z; - } - - x0_tmp = 0; - xi = xi_0; - yi++; - } - } - } - } - else - { - int y0_ = yx0 / W0, x0_ = yx0 - y0_ * W0; - for (int k = 0; k < ksize; k++) + // Prepare Weight for Winograd F(6x6, 3x3) + if (conv->ifWinograd63) { - int dy = yxtab[k * 2], dx = yxtab[k * 2 + 1]; - int i = 0, y0 = y0_, x0 = x0_; - for (; i < FAST_CONV_NR;) - { - float *inpbuf_ki = inpbuf + k * FAST_CONV_NR + i; - int yi = y0 * stride_y + dy - pad_top; - int xi = x0 * stride_x + dx - pad_left; - - if ((unsigned) yi < (unsigned) Hi && - (unsigned) xi < (unsigned) Wi) - { - const float *inptr_ki = inptr + yi * Wi + xi; - if (i + 4 <= FAST_CONV_NR && x0 + 4 <= W0 && xi + stride_x * 4 <= Wi) - { - if (stride_x == 2) { - for (int c = 0; c < Cg; c++, inpbuf_ki += FAST_CONV_NR * - ksize, inptr_ki += inp_planesize) - { - float t0 = inptr_ki[0], t1 = inptr_ki[2]; - float t2 = inptr_ki[4], t3 = inptr_ki[6]; - inpbuf_ki[0] = t0; - inpbuf_ki[1] = t1; - inpbuf_ki[2] = t2; - inpbuf_ki[3] = t3; - } - } - else - { - for (int c = 0; c < Cg; c++, inpbuf_ki += FAST_CONV_NR * - ksize, inptr_ki += inp_planesize) - { - float t0 = inptr_ki[0], t1 = inptr_ki[stride_x]; - float t2 = inptr_ki[stride_x * 2], t3 = inptr_ki[stride_x * 3]; - inpbuf_ki[0] = t0; - inpbuf_ki[1] = t1; - inpbuf_ki[2] = t2; - inpbuf_ki[3] = t3; - } - } - i += 4; - x0 += 4; - } - else - { - for (int c = 0; c < Cg; c++, inpbuf_ki += FAST_CONV_NR * - ksize, inptr_ki += inp_planesize) - *inpbuf_ki = *inptr_ki; - i++; - x0++; - } - } - else - { - for (int c = 0; c < Cg; c++, inpbuf_ki += FAST_CONV_NR * ksize) - inpbuf_ki[0] = 0.f; - i++; - x0++; - } - int mask = x0 >= W0; - y0 += mask; - x0 &= mask - 1; - } + initWinograd63(conv, weightsMat, K, C); } } -} - -static void matMulCompute(float* outptr0, float* inpbuf_task, float* cbuf, const Ptr& conv, int HkWkCg, - int k0, int k1, int yx0, int yx1, size_t out_planesize, int g, int Kg, int 
Kg_aligned, - bool partial0, ActivationLayer*& activ, float minval, float maxval, bool ifMinMaxAct) -{ - int outstep0 = out_planesize; - for (int k = k0; k < k1; k += FAST_CONV_MR, outptr0 += outstep0 * FAST_CONV_MR) + // store bias; append some zero's to make sure that + // we can always read MR elements starting from any valid index { - int dk = Kg - k < FAST_CONV_MR ? Kg - k : FAST_CONV_MR; - bool partial = partial0 || dk < FAST_CONV_MR; - float *outptr = outptr0; - - int outstep = outstep0; - if (partial) - { - outptr = cbuf; - outstep = FAST_CONV_NR; - } - - -#if CV_TRY_AVX2 - if (conv->useAVX2) - opt_AVX2::convBlock_AVX2( HkWkCg, conv->weightsBuf.data() + (g * Kg_aligned + k) * HkWkCg, - inpbuf_task, outptr, outstep, conv->biasBuf.data() + Kg * g + k, - minval, maxval, ifMinMaxAct); - else -#endif -#if CV_TRY_NEON - if (conv->useNEON) - opt_NEON::convBlock_NEON(HkWkCg, conv->weightsBuf.data() + (g * Kg_aligned + k) * HkWkCg, - inpbuf_task, outptr, outstep, conv->biasBuf.data() + Kg * g + k, - minval, maxval, ifMinMaxAct); - else -#endif - convBlock(HkWkCg, conv->weightsBuf.data() + (g * Kg_aligned + k) * HkWkCg, - inpbuf_task, outptr, outstep, conv->biasBuf.data() + Kg * g + k, - minval, maxval, ifMinMaxAct); - - // activation - if (activ) - activ->forwardSlice(outptr, outptr, yx1 - yx0, outstep, Kg * g + k, - Kg * g + k + dk); - - if (partial) - { - for (int i = 0; i < dk; i++) - memcpy(outptr0 + i * outstep0, cbuf + i * FAST_CONV_NR, - (yx1 - yx0) * sizeof(cbuf[0])); - } + int k = 0, nbias = K + CONV_MR - 1; + conv->biasBuf.reserve(nbias); + float* biasBufPtr = conv->biasBuf.data(); + for(; k < K; k++) + biasBufPtr[k] = srcBias ? srcBias[k] : 0.f; + for(; k < nbias; k++) + biasBufPtr[k] = 0.f; } + return conv; } -void runFastConv2d(InputArray _input, OutputArray _output, - const Ptr& conv, int ntasks, const Ptr& actLayer) +void runFastConv2d(InputArray _input, OutputArray _output, const Ptr& conv, int ntasks, + const Ptr& actLayer, bool fusedAdd) { Mat input = _input.getMat(); Mat output = _output.getMat(); + + Mat fusedAddMat; + if (fusedAdd) + fusedAddMat = _output.getMat(); + MatShape inputShape = shape(input); MatShape outputShape = shape(output); CV_Assert(inputShape.size() == 4 && outputShape.size() == 4); @@ -452,93 +186,69 @@ void runFastConv2d(InputArray _input, OutputArray _output, if (conv->ngroups > 1 && conv->ngroups == conv->K && conv->ngroups == conv->C) { + CV_Assert(fusedAddMat.empty()); // Depthwise-Convolution layer should not be followed by Add layer. return runDepthwise(input, output, conv, minval, maxval, activ, ifMinMaxAct); } #if CV_NEON - if ( conv->ifWinograd63 + if (conv->ifWinograd63 && inputShape[2] > 12 && inputShape[3] > 12 - && inputShape[2] < 120 && inputShape[3] < 120 ) + && inputShape[2] < 120 && inputShape[3] < 120 + ) { - // In general, for winograd branch, more cores will give better performance. 
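The bias buffer built above is deliberately over-allocated: K real values plus CONV_MR-1 trailing zeros, so any strip of CONV_MR consecutive output channels (including the final, partial one) can be read without bounds checks. A tiny sketch of the same idea with illustrative names:

#include <vector>

static std::vector<float> makePaddedBias(const float* srcBias, int K, int CONV_MR)
{
    std::vector<float> biasBuf(K + CONV_MR - 1, 0.f);   // tail stays zero
    for (int k = 0; k < K; k++)
        biasBuf[k] = srcBias ? srcBias[k] : 0.f;
    return biasBuf;
}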
- int maxNumThread = std::max(getNumThreads(), 1); - if (runWinograd63(input, output, conv, maxNumThread, minval, maxval, activ, ifMinMaxAct)) + if (runWinograd63(input, fusedAddMat, output, conv, ntasks, minval, maxval, activ, ifMinMaxAct)) return; } #endif - float* inp = input.ptr(); - float* out = output.ptr(); - int N = inputShape[0], C = inputShape[1], Hi = inputShape[2], Wi = inputShape[3]; // [N, C, H, W] int K = conv->K, Hk = conv->Hk, Wk = conv->Wk; - int H0 = outputShape[2], W0 = outputShape[3], ngroups = conv->ngroups; // ngroups + int H0 = outputShape[2], W0 = outputShape[3], ngroups = conv->ngroups; int Cg = C/ngroups, Kg = K/ngroups; - int Kg_nblocks = (Kg + FAST_CONV_MR-1)/FAST_CONV_MR, Kg_aligned = Kg_nblocks*FAST_CONV_MR; // align to MR const size_t inp_planesize = (size_t)Hi*Wi; const size_t out_planesize = (size_t)H0*W0; - int pad_top = conv->pad_top, pad_bottom = conv->pad_bottom; + int pad_top = conv->pad_top; int pad_left = conv->pad_left; - int pad_right = conv->pad_right; int stride_y = conv->stride_y, stride_x = conv->stride_x; int dilation_y = conv->dilation_y, dilation_x = conv->dilation_x; int ksize = Hk * Wk; - bool s1d1 = stride_x == 1 && stride_y == 1 && dilation_x == 1 && dilation_y == 1; - bool s1d1p0 = s1d1 && pad_top == 0 && pad_left ==0 && pad_bottom == 0 && pad_right == 0; bool fast_1x1 = stride_x == 1 && stride_y == 1 && ksize == 1; int HkWkCg = Hk*Wk*Cg; - enum { VEC_ALIGN = 8, DFT_TYPE = CV_32F }; - size_t taskbufsize = FAST_CONV_NR*HkWkCg; // input buffer - size_t taskbufsizeOutput = FAST_CONV_NR * FAST_CONV_MR; - size_t inputbufsize = 0; - size_t outbufsize = ntasks * taskbufsizeOutput; + enum { VEC_ALIGN = 8, DFT_TYPE = CV_32F }; // Memory alignment. + int MAX_STRIPES = 2; // (56 + CONV_NR - 1)/CONV_NR; + + // Friendly to L1 cache + const int K_BLOCK_SIZE = 32; + const int C_BLOCK_SIZE = 256; - int stripes_per_sample = (out_planesize + FAST_CONV_NR - 1)/FAST_CONV_NR; // align to NR - size_t hw_task = stripes_per_sample; - size_t hw_aligned = stripes_per_sample * FAST_CONV_NR; + int Kg_nblocks = (Kg + CONV_MR-1)/CONV_MR, Kg_aligned = Kg_nblocks * CONV_MR; - bool separatedLoop = false; + int stripes_per_sample = (out_planesize + CONV_NR - 1) / CONV_NR; - if (stripes_per_sample < 4 * ntasks) + if (stripes_per_sample < ntasks * 4) { - // If stripes_per_sample is small, we parallelize on K (output channel). + MAX_STRIPES = 1; stripes_per_sample = 1; - - // Separated Parallelloop could save much time in packing input data. But it may cost more memory, we use it when batch size is 1. - if (N == 1) - { - separatedLoop = true; - inputbufsize = ngroups * hw_aligned * HkWkCg; - } - - if (!separatedLoop) - { - inputbufsize = taskbufsize * ntasks; - } } else - { - // If stripes_per_sample is big, we parallelize on H0*W0. 
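The constants above decide between two parallelization modes and size the per-thread scratch buffers. A small plain-C++ sketch that mirrors that planning logic (illustrative structure and names; default tile sizes copied from the patch's generic SIMD128 configuration):

// When the output plane yields few CONV_NR-wide stripes relative to the
// thread count, split the work over output channels (strips of CONV_MR);
// otherwise split it over the H0*W0 plane. Per-task scratch holds the packed
// input stripes plus a CONV_NR x K_BLOCK_SIZE accumulator tile.
#include <cstddef>

struct ConvSchedule
{
    int maxStripes, stripesPerSample, kgNBlocks, kStripes, nSubtasks;
    size_t stripeSize, taskBufSize;
};

static ConvSchedule planConvTasks(int N, int ngroups, int Kg, int Cg, int ksize,
                                  size_t outPlaneSize, int ntasks,
                                  int CONV_MR = 4, int CONV_NR = 24, int K_BLOCK_SIZE = 32)
{
    ConvSchedule s;
    s.maxStripes = 2;
    s.kgNBlocks = (Kg + CONV_MR - 1) / CONV_MR;
    s.stripesPerSample = (int)((outPlaneSize + CONV_NR - 1) / CONV_NR);
    if (s.stripesPerSample < ntasks * 4)
    {   // few stripes: parallelize over output channels instead of the plane
        s.maxStripes = 1;
        s.stripesPerSample = 1;
    }
    else
        s.kgNBlocks = 1;      // many stripes: parallelize over H0*W0
    s.kStripes = s.kgNBlocks * s.stripesPerSample;
    s.nSubtasks = N * ngroups * s.kStripes;
    s.stripeSize = (size_t)CONV_NR * ksize * Cg;                            // packed input per stripe
    s.taskBufSize = (s.stripeSize + (size_t)CONV_NR * K_BLOCK_SIZE) * s.maxStripes;
    return s;
}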
Kg_nblocks = 1; - inputbufsize = taskbufsize * ntasks; - } int Kstripes = Kg_nblocks*stripes_per_sample; int nsubtasks = N*ngroups*Kstripes; - AutoBuffer inpbuf_all_, outputbuf_; - inputbufsize = alignSize(inputbufsize, VEC_ALIGN); - inpbuf_all_.allocate(inputbufsize + VEC_ALIGN); - float* inpbuf_all = alignPtr(inpbuf_all_.data(), (int)(VEC_ALIGN*sizeof(float))); + size_t stripesize = CONV_NR * ksize * Cg; + size_t taskbufsize = (stripesize + CONV_NR * K_BLOCK_SIZE) * MAX_STRIPES; + size_t totalbufsize = taskbufsize * ntasks; - outbufsize = alignSize(outbufsize, VEC_ALIGN); - outputbuf_.allocate(outbufsize + VEC_ALIGN); - float* output_buf = alignPtr(outputbuf_.data(), (int)(VEC_ALIGN*sizeof(float))); + AutoBuffer inpbuf_all_; + totalbufsize = alignSize(totalbufsize, VEC_ALIGN); + inpbuf_all_.allocate(totalbufsize + VEC_ALIGN); + float* inpbuf_all = alignPtr(inpbuf_all_.data(), (int)(VEC_ALIGN*sizeof(inpbuf_all_[0]))); std::vector ofstab_(Hk*Wk*3, 0); int* ofstab = ofstab_.data(); @@ -554,141 +264,306 @@ void runFastConv2d(InputArray _input, OutputArray _output, ofstab[k] = dy*Wi + dx; } - if (ksize == 1) - { - CV_Assert(pad_left == 0 && pad_right == 0 && pad_top == 0 && pad_bottom == 0); - CV_Assert(stride_x != 1 || stride_y != 1 || (H0 == Hi && W0 == Wi)); - } + float* inp = input.ptr(); + float* out = output.ptr(); + float* fusedAddPtr0 = fusedAddMat.empty() ? 0 : fusedAddMat.ptr(); - if (separatedLoop) + parallel_for_(Range(0, ntasks), [&](const Range& r0) { + for (int task_id = r0.start; task_id < r0.end; task_id++) { - // For now this branch only handles batch size = 1. Maybe we could support batch size < 10 in the future. - // Pack Input data - parallel_for_(Range(0, ngroups * hw_task), [&](const Range& r0) + float* inpbuf_task = &inpbuf_all[taskbufsize * task_id]; + float* cbuf_task = inpbuf_task + stripesize * MAX_STRIPES; + + int ngs0 = (int)((size_t)nsubtasks * task_id / ntasks); + int ngs1 = (int)((size_t)nsubtasks * (task_id+1) / ntasks); + for (int subtask = ngs0; subtask < ngs1; ) { - for (int nhwi = r0.start; nhwi < r0.end; nhwi++) + int ng = subtask / Kstripes; + int kyx0 = subtask - ng * Kstripes; + int kyx1 = kyx0 + (ngs1 - subtask); + int n = ng / ngroups, g = ng % ngroups; // ng - n * ngroups; + size_t inp_plane_ofs = (size_t)(n * ngroups + g) * Cg * inp_planesize; + kyx1 = kyx1 <= Kstripes ? kyx1 : Kstripes; + subtask += kyx1 - kyx0; + int k0, k1; + int yx0, yx_limit, yx_block_limit = 0; + + if (stripes_per_sample == 1) { - int g = nhwi/hw_task; - int hw_i = nhwi % hw_task; - int hw0 = hw_i * FAST_CONV_NR; - float* inpbuf = inpbuf_all + g * hw_aligned * HkWkCg + hw0 * HkWkCg; - const float* inptr = inp + g * Cg * inp_planesize; - bool partial0 = hw0 + FAST_CONV_NR > out_planesize? true: false; - int slice_len = FAST_CONV_NR; - - if (partial0) - slice_len = out_planesize - hw0; - - packInput(inpbuf, inptr, yxtab, ksize, Cg, Hi, Wi, W0, pad_top, pad_left, stride_x, stride_y, - hw0, slice_len, fast_1x1, partial0, s1d1p0, s1d1); + k0 = kyx0 * CONV_MR; + k1 = kyx1 * CONV_MR; + k1 = k1 <= Kg ? 
k1 : Kg; + yx0 = 0; + yx_limit = out_planesize; } - }); - - // Compute - parallel_for_(Range(0, ntasks), [&](const Range& r0) - { - for (int task_id = r0.start; task_id < r0.end; task_id++) + else { - float *cbuf = output_buf + task_id * taskbufsizeOutput; - int ngs0 = (int) ((size_t) nsubtasks * task_id / ntasks); - int ngs1 = (int) ((size_t) nsubtasks * (task_id + 1) / ntasks); - for (int subtask = ngs0; subtask < ngs1;) - { - int ng = subtask / Kstripes; - int kyx0 = subtask - ng * Kstripes; - int kyx1 = kyx0 + (ngs1 - subtask); - int n = ng / ngroups, g = ng - n * ngroups; - - CV_Assert(n <= 1); + k0 = 0; + k1 = Kg; + yx0 = kyx0 * CONV_NR; + yx_limit = kyx1 * CONV_NR; + yx_limit = yx_limit < out_planesize ? yx_limit : out_planesize; + } - kyx1 = kyx1 <= Kstripes ? kyx1 : Kstripes; // Guarantee that maximum kyx1 is Kstripes. - subtask += kyx1 - kyx0; + for (; yx0 < yx_limit; yx0 = yx_block_limit) + { + // step 1. extract part of input tensor and represent it in zigzag form + yx_block_limit = yx0 + CONV_NR * MAX_STRIPES; + yx_block_limit = yx_block_limit < yx_limit ? yx_block_limit : yx_limit; - int k0 = kyx0 * FAST_CONV_MR; - int k1 = kyx1 * FAST_CONV_MR; - k1 = k1 <= Kg ? k1 : Kg; + int nstripes = (yx_block_limit - yx0 + CONV_NR - 1) / CONV_NR; + int yx0_saved = yx0; + CV_Assert(nstripes <= MAX_STRIPES); - for (int yx0 = 0; yx0 < out_planesize; yx0 += FAST_CONV_NR) - { - float* inpbuf_task = inpbuf_all + g * hw_aligned * HkWkCg + yx0 * HkWkCg; - int yx1 = yx0 + FAST_CONV_NR; - yx1 = yx1 <= out_planesize ? yx1 : out_planesize; - int slice_len = yx1 - yx0; - bool partial0 = slice_len < FAST_CONV_NR; - - int outstep0 = out_planesize; - size_t outofs = ((n * ngroups + g) * Kg + k0) * outstep0 + yx0; - float *outptr0 = out + outofs; - - matMulCompute(outptr0, inpbuf_task, cbuf, conv, HkWkCg, k0, k1, yx0, yx1, out_planesize, g, - Kg, Kg_aligned, partial0, activ, minval, maxval, ifMinMaxAct); - } - } - } - }); - } - else - { - parallel_for_(Range(0, ntasks), [&](const Range &r0) { - for (int task_id = r0.start; task_id < r0.end; task_id++) { - float *inpbuf_task = &inpbuf_all[taskbufsize * task_id]; - float *cbuf = output_buf + task_id * taskbufsizeOutput; - int ngs0 = (int) ((size_t) nsubtasks * task_id / ntasks); - int ngs1 = (int) ((size_t) nsubtasks * (task_id + 1) / ntasks); - - for (int subtask = ngs0; subtask < ngs1;) + for (int stripe = 0; yx0 < yx_block_limit; stripe++, yx0 += CONV_NR) { - int ng = subtask / Kstripes; - int kyx0 = subtask - ng * Kstripes; - int kyx1 = kyx0 + (ngs1 - subtask); - int n = ng / ngroups, g = ng - n * ngroups; - size_t inp_plane_ofs = (size_t) (n * ngroups + g) * Cg * inp_planesize; - kyx1 = kyx1 <= Kstripes ? kyx1 : Kstripes; // Guarantee that maximum kyx1 is Kstripes. - subtask += kyx1 - kyx0; - int k0, k1; - int yx0, yx_limit; - - if (stripes_per_sample == 1) + float* inpbuf = inpbuf_task + stripe * stripesize; + float* inptr = inp + inp_plane_ofs; + + /* + 1. pack the data. Copy the HkxWk CONV_NR-wide slices from + each feature plane of the input tensor to the input buffer. + */ + if (fast_1x1) { - k0 = kyx0 * FAST_CONV_MR; - k1 = kyx1 * FAST_CONV_MR; - k1 = k1 <= Kg ? k1 : Kg; - yx0 = 0; - yx_limit = out_planesize; + int slice_len = yx_block_limit - yx0; + bool partial = slice_len < CONV_NR; + // Superfast branch for 1x1 convolutions with sy=sx=1. + // in this case each feature plane can be safely treated + // as 1D array, and we just extract next portion + // of CONV_NR elements from each feature plane and + // put it together. 
+ inptr += yx0; + if (!partial) + { + // Make special branch where memcpy() is called with a constant buffer size. + // Compilers will likely unroll this loop properly. + for (int c = 0; c < Cg; c++, inptr += inp_planesize, inpbuf += CONV_NR) + memcpy(inpbuf, inptr, CONV_NR*sizeof(inpbuf[0])); + } + else + { + for (int c = 0; c < Cg; c++, inptr += inp_planesize, inpbuf += CONV_NR) + { + memcpy(inpbuf, inptr, slice_len * sizeof(inpbuf[0])); + memset(inpbuf + slice_len, 0, (CONV_NR - slice_len) * sizeof(inpbuf[0])); + } + } } else { - k0 = 0; - k1 = Kg; - yx0 = kyx0 * FAST_CONV_NR; - yx_limit = kyx1 * FAST_CONV_NR; - yx_limit = yx_limit < out_planesize ? yx_limit : out_planesize; + int y0_ = yx0 / W0, x0_ = yx0 - y0_ * W0; + for (int k = 0; k < ksize; k++) + { + int dy = yxtab[k * 2], dx = yxtab[k * 2 + 1]; + int i = 0, y0 = y0_, x0 = x0_; + for (; i < CONV_NR;) + { + float *inpbuf_ki = inpbuf + k * CONV_NR * Cg + i; + int yi = y0 * stride_y + dy - pad_top; + int xi = x0 * stride_x + dx - pad_left; + + if ((unsigned) yi < (unsigned) Hi && (unsigned) xi < (unsigned) Wi) + { + const float *inptr_ki = inptr + yi * Wi + xi; + if (i + 8 <= CONV_NR && x0 + 8 <= W0 && xi + stride_x * 8 <= Wi) + { + if (stride_x == 1) + { + for (int c = 0; c < Cg; c++, inpbuf_ki += CONV_NR, inptr_ki += inp_planesize) + { + float t0 = inptr_ki[0], t1 = inptr_ki[1]; + float t2 = inptr_ki[2], t3 = inptr_ki[3]; + float t4 = inptr_ki[4], t5 = inptr_ki[5]; + float t6 = inptr_ki[6], t7 = inptr_ki[7]; + inpbuf_ki[0] = t0; inpbuf_ki[1] = t1; + inpbuf_ki[2] = t2; inpbuf_ki[3] = t3; + inpbuf_ki[4] = t4; inpbuf_ki[5] = t5; + inpbuf_ki[6] = t6; inpbuf_ki[7] = t7; + } + } + else + { + for (int c = 0; c < Cg; c++, inpbuf_ki += CONV_NR, inptr_ki += inp_planesize) + { + float t0 = inptr_ki[0], t1 = inptr_ki[stride_x]; + float t2 = inptr_ki[stride_x*2], t3 = inptr_ki[stride_x*3]; + float t4 = inptr_ki[stride_x*4], t5 = inptr_ki[stride_x*5]; + float t6 = inptr_ki[stride_x*6], t7 = inptr_ki[stride_x*7]; + inpbuf_ki[0] = t0; inpbuf_ki[1] = t1; + inpbuf_ki[2] = t2; inpbuf_ki[3] = t3; + inpbuf_ki[4] = t4; inpbuf_ki[5] = t5; + inpbuf_ki[6] = t6; inpbuf_ki[7] = t7; + } + } + i += 8; + x0 += 8; + } + else if (i + 4 <= CONV_NR && x0 + 4 <= W0 && xi + stride_x * 4 <= Wi) + { + if (stride_x == 1) + { + for (int c = 0; c < Cg; c++, inpbuf_ki += CONV_NR, inptr_ki += inp_planesize) + { + float t0 = inptr_ki[0], t1 = inptr_ki[1]; + float t2 = inptr_ki[2], t3 = inptr_ki[3]; + inpbuf_ki[0] = t0; inpbuf_ki[1] = t1; + inpbuf_ki[2] = t2; inpbuf_ki[3] = t3; + } + } + else + { + for (int c = 0; c < Cg; c++, inpbuf_ki += CONV_NR, inptr_ki += inp_planesize) + { + float t0 = inptr_ki[0], t1 = inptr_ki[stride_x]; + float t2 = inptr_ki[stride_x*2], t3 = inptr_ki[stride_x*3]; + inpbuf_ki[0] = t0; inpbuf_ki[1] = t1; + inpbuf_ki[2] = t2; inpbuf_ki[3] = t3; + } + } + i += 4; + x0 += 4; + } + else + { + for (int c = 0; c < Cg; c++, inpbuf_ki += CONV_NR, inptr_ki += inp_planesize) + *inpbuf_ki = *inptr_ki; + i++; + x0++; + } + } + else + { + for (int c = 0; c < Cg; c++, inpbuf_ki += CONV_NR) + inpbuf_ki[0] = 0.f; + i++; + x0++; + } + int mask = x0 >= W0; + y0 += mask; + x0 &= mask - 1; + } + } } + } - for (; yx0 < yx_limit; yx0 += FAST_CONV_NR) + yx0 = yx0_saved; + float* weights = conv->weightsBuf.data() + g * Kg_aligned * HkWkCg; + const float* biasptr = conv->biasBuf.data() + Kg * g; + int ldc = nstripes * CONV_NR; + + // 2. 
do convolution, compute Kg x (yx_block_limit - yx0) part of the output tensor + for (int k0_block = k0; k0_block < k1; k0_block += K_BLOCK_SIZE) + { + int k1_block = k0_block + K_BLOCK_SIZE < k1 ? k0_block + K_BLOCK_SIZE : k1; + for (int c0 = 0; c0 < HkWkCg; c0 += C_BLOCK_SIZE) { - float *inpbuf = inpbuf_task; - const float *inptr = inp + inp_plane_ofs; - int yx1 = yx0 + FAST_CONV_NR; - yx1 = yx1 <= yx_limit ? yx1 : yx_limit; - int slice_len = yx1 - yx0; - bool partial0 = slice_len < FAST_CONV_NR; - packInput(inpbuf, inptr, yxtab, ksize, Cg, Hi, Wi, W0, pad_top, pad_left, stride_x, stride_y, - yx0, slice_len, fast_1x1, partial0, s1d1p0, s1d1); - - // 2. do convolution, compute Kg x (yx1 - yx0) part of the output tensor - int outstep0 = out_planesize; - size_t outofs = ((n * ngroups + g) * Kg + k0) * outstep0 + yx0; - float *outptr0 = out + outofs; - - matMulCompute(outptr0, inpbuf_task, cbuf, conv, HkWkCg, k0, k1, yx0, yx1, out_planesize, g, - Kg, Kg_aligned, partial0, activ, minval, maxval, ifMinMaxAct); + int c1 = c0 + C_BLOCK_SIZE < HkWkCg ? c0 + C_BLOCK_SIZE : HkWkCg; + for (int stripe = 0; stripe < nstripes; stripe++) + { + float* wptr = weights + k0_block*HkWkCg + c0*CONV_MR; + const float* inptr = inpbuf_task + stripe*stripesize + c0 * CONV_NR; + float* cptr = cbuf_task + stripe * CONV_NR; + for (int k = k0_block; k < k1_block; k += CONV_MR, + wptr += HkWkCg * CONV_MR, cptr += CONV_MR * ldc) + { +#if CV_TRY_AVX2 + if (conv->useAVX2) + opt_AVX2::convBlock_AVX2(c1 - c0, wptr, inptr, cptr, ldc, c0 == 0); + else +#endif +#if CV_TRY_NEON + if (conv->useNEON) + opt_NEON::convBlock_NEON(c1 - c0, wptr, inptr, cptr, ldc, c0 == 0); + else +#endif + convBlock(c1 - c0, wptr, inptr, cptr, ldc, c0 == 0); + } + } + } + + size_t outofs = ((n*ngroups + g) * Kg + k0_block) * out_planesize + yx0; + int out_width = yx_block_limit - yx0; + const float* cptr = cbuf_task; + + float* outptr = out + outofs; + const float* pbptr = fusedAddPtr0 ? fusedAddPtr0 + outofs : 0; + + for (int k = k0_block; k < k1_block; k++, + cptr += ldc, outptr += out_planesize, + pbptr += (pbptr ? 
out_planesize : 0)) + { + float biasval = biasptr[k]; + int j = 0; +#if CV_SIMD128 + v_float32x4 vbias = v_setall_f32(biasval), vmax = v_setall_f32(maxval), vmin = v_setall_f32(minval); + if (pbptr) + { + for (; j + 7 < out_width; j += 8) + { + v_float32x4 v0 = v_add(v_load(cptr + j), vbias); + v_float32x4 v1 = v_add(v_load(cptr + j + 4), vbias); + v0 = v_add(v0, v_load(pbptr + j)); + v1 = v_add(v1, v_load(pbptr + j + 4)); + + if (ifMinMaxAct) + { + v0 = v_min(v_max(v0, vmin), vmax); + v1 = v_min(v_max(v1, vmin), vmax); + } + + v_store(outptr + j, v0); + v_store(outptr + j + 4, v1); + } + } + else + { + for (; j + 7 < out_width; j += 8) + { + v_float32x4 v0 = v_add(v_load(cptr + j), vbias); + v_float32x4 v1 = v_add(v_load(cptr + j + 4), vbias); + + if (ifMinMaxAct) + { + v0 = v_min(v_max(v0, vmin), vmax); + v1 = v_min(v_max(v1, vmin), vmax); + } + + v_store(outptr + j, v0); + v_store(outptr + j + 4, v1); + } + } +#endif + if (pbptr) { + for (; j < out_width; j++) + { + float v = cptr[j] + biasval; + v += pbptr[j]; + if (ifMinMaxAct) + v = std::min(std::max(v, minval), maxval); + outptr[j] = v; + } + } + else + { + for (; j < out_width; j++) + { + float v = cptr[j] + biasval; + + if (ifMinMaxAct) + v = std::min(std::max(v, minval), maxval); + outptr[j] = v; + } + } + + if (activ) + activ->forwardSlice(outptr, outptr, out_width, out_planesize, Kg * g + k, Kg * g + k + 1); } } } - }); + } } + }); } - -}} // namespace cv::dnn \ No newline at end of file +}} // namespace cv::dnn diff --git a/modules/dnn/src/layers/fast_convolution/fast_convolution.hpp b/modules/dnn/src/layers/fast_convolution/fast_convolution.hpp index c993781..eda5c7c 100644 --- a/modules/dnn/src/layers/fast_convolution/fast_convolution.hpp +++ b/modules/dnn/src/layers/fast_convolution/fast_convolution.hpp @@ -7,20 +7,26 @@ #include "opencv2/core/hal/intrin.hpp" -#ifndef FAST_CONV_PRAM -#define FAST_CONV_PRAM +#ifndef CONV_PRAM +#define CONV_PRAM #if CV_NEON && CV_NEON_AARCH64 // 32 registers. -#define FAST_CONV_MR 4 -#define FAST_CONV_NR 28 +#define CONV_MR 4 +#define CONV_NR 28 enum { FAST_VEC_NLANES=4 }; #elif CV_NEON // 16 registers. -#define FAST_CONV_MR 4 -#define FAST_CONV_NR 12 +#define CONV_MR 4 +#define CONV_NR 12 enum { FAST_VEC_NLANES=4 }; #else // SIMD 128, AVX or AVX2 -#define FAST_CONV_MR 4 -#define FAST_CONV_NR 24 -enum { FAST_VEC_NLANES=4 }; +#define CONV_MR 4 +#define CONV_NR 24 + +#ifdef CV_AVX2 +enum { FAST_VEC_NLANES=8 }; // AVX2 +#else +enum { FAST_VEC_NLANES=4 }; // SIMD 128 +#endif + #endif #endif @@ -37,7 +43,6 @@ struct FastConv2d std::vector weightsBuf; // For generic Conv 2D std::vector weightsWino63Buf; // For Winograd F(6x6, 3x3). - std::vector biasBuf; bool ifWinograd63 = false; bool useAVX2 = checkHardwareSupport(CPU_AVX2); @@ -52,20 +57,20 @@ Ptr initFastConv2d( int dilation_x, int dilation_y, const std::vector& pads_begin, const std::vector& pads_end, - float* srcWeights, + InputArray weightsMat, float* srcBias); // It contains different computing branches, like winograd, 1x1 conv. 
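As the output-writing loop above shows, the channel bias, the optional fused addend and the min/max clamp are now applied in one epilogue pass after the micro-kernels. When fusedAdd is set, the addend is read from the output Mat itself (fusedAddMat = _output.getMat()), so the fusion pass is expected to have routed the Add operand into the convolution's output blob beforehand. A scalar sketch of that epilogue (illustrative names, not the vectorized OpenCV code):

#include <algorithm>

static void convEpilogueRow(const float* cbuf, float* outptr,
                            const float* addend,          // may be null when there is no fused Add
                            int width, float biasval,
                            bool ifMinMaxAct, float minval, float maxval)
{
    for (int j = 0; j < width; j++)
    {
        float v = cbuf[j] + biasval;     // raw accumulator + per-channel bias
        if (addend)
            v += addend[j];              // fused Conv + Add
        if (ifMinMaxAct)
            v = std::min(std::max(v, minval), maxval);   // ReLU/ReLU6-style clamp
        outptr[j] = v;
    }
}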
-void runFastConv2d(InputArray _input, OutputArray _output, - const Ptr& conv, int ntasks, const Ptr& actLayer); +void runFastConv2d(InputArray _input, OutputArray _output, const Ptr& conv, int ntasks, + const Ptr& actLayer, bool fusedAdd); void runDepthwise(InputArray _input, OutputArray _output, const Ptr& conv, float minval, float maxval, ActivationLayer* activ, bool ifMinMaxAct); // winograd init -void initWinograd63(Ptr& conv, float* src_weight, int K, int C); +void initWinograd63(Ptr& conv, InputArray weightsMat, int K, int C); -int runWinograd63(InputArray _input, OutputArray _output, const Ptr& conv, int ntasks, +int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _output, const Ptr& conv, int ntasks, float minval, float maxval, ActivationLayer* activ, bool ifMinMaxAct); } // namespace dnn @@ -73,9 +78,7 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr& namespace opt_AVX2 { #if CV_TRY_AVX2 -void convBlock_AVX2(int k, const float *a, const float *b, - float *c, int ldc, const float *bias, - float minval, float maxval, bool ifActiv); +void convBlock_AVX2(int np, const float* a, const float* b, float* c, int ldc, bool init_c); void depthWiseBlock_AVX2(const float *inptr, float *outptr, const float *weights, float biasval, int *ofstab, int *yxtab, float minval, float maxval, int Hi, int Wi, int H0, int W0, int ksize, int pad_top, int pad_left, diff --git a/modules/dnn/src/layers/fast_convolution/fast_convolution.simd.hpp b/modules/dnn/src/layers/fast_convolution/fast_convolution.simd.hpp index 94b0814..2e088a6 100644 --- a/modules/dnn/src/layers/fast_convolution/fast_convolution.simd.hpp +++ b/modules/dnn/src/layers/fast_convolution/fast_convolution.simd.hpp @@ -11,140 +11,131 @@ namespace cv { namespace dnn { -void convBlock(int k, const float *a, const float *b, - float *c, int ldc, const float *bias, - float minval, float maxval, bool ifActiv) +void convBlock(int np, const float* a, const float* b, float* c, int ldc, bool init_c) { -#if CV_SIMD128 -#if FAST_CONV_MR == 4 && FAST_CONV_NR == 24 - { - v_float32x4 c0 = v_setall_f32(bias[0]), c1 = c0, c2 = c0, c3 = c0, c4 = c0, c5 = c0; - v_float32x4 c6 = v_setall_f32(bias[1]), c7 = c6, c8 = c6, c9 = c6, c10 = c6, c11 = c6; - v_float32x4 c12 = v_setall_f32(bias[2]), c13 = c12, c14 = c12, c15 = c12, c16 = c12, c17 = c12; - v_float32x4 c18 = v_setall_f32(bias[3]), c19 = c18, c20 = c18, c21 = c18, c22 = c18, c23 = c18; +#if 0 // CV_SIMD128 && CONV_MR == 4 && CONV_NR == 24 + v_float32x4 c0 = v_setzero_f32(), c1 = c0, c2 = c0, c3 = c0, c4 = c0, c5 = c0; + v_float32x4 c6 = v_setzero_f32(), c7 = c6, c8 = c6, c9 = c6, c10 = c6, c11 = c6; + v_float32x4 c12 = v_setzero_f32(), c13 = c12, c14 = c12, c15 = c12, c16 = c12, c17 = c12; + v_float32x4 c18 = v_setzero_f32(), c19 = c18, c20 = c18, c21 = c18, c22 = c18, c23 = c18; - for (int p = 0; p < k; p++, a += FAST_CONV_MR, b += FAST_CONV_NR) - { - v_float32x4 a0 = v_setall_f32(a[0]); - v_float32x4 b0 = v_load(b), b1 = v_load(b + 4), b2 = v_load(b + 8); - v_float32x4 b3 = v_load(b + 12), b4 = v_load(b + 16), b5 = v_load(b + 20); - - c0 = v_fma(b0, a0, c0); - c1 = v_fma(b1, a0, c1); - c2 = v_fma(b2, a0, c2); - c3 = v_fma(b3, a0, c3); - c4 = v_fma(b4, a0, c4); - c5 = v_fma(b5, a0, c5); - - a0 = v_setall_f32(a[1]); - c6 = v_fma(b0, a0, c6); - c7 = v_fma(b1, a0, c7); - c8 = v_fma(b2, a0, c8); - c9 = v_fma(b3, a0, c9); - c10 = v_fma(b4, a0, c10); - c11 = v_fma(b5, a0, c11); - - a0 = v_setall_f32(a[2]); - c12 = v_fma(b0, a0, c12); - c13 = v_fma(b1, a0, c13); - c14 = 
v_fma(b2, a0, c14); - c15 = v_fma(b3, a0, c15); - c16 = v_fma(b4, a0, c16); - c17 = v_fma(b5, a0, c17); - - a0 = v_setall_f32(a[3]); - c18 = v_fma(b0, a0, c18); - c19 = v_fma(b1, a0, c19); - c20 = v_fma(b2, a0, c20); - c21 = v_fma(b3, a0, c21); - c22 = v_fma(b4, a0, c22); - c23 = v_fma(b5, a0, c23); - } - - if (ifActiv) { - v_float32x4 vmin = v_setall_f32(minval), vmax = v_setall_f32(maxval); - c0 = v_min(v_max(c0, vmin), vmax); - c1 = v_min(v_max(c1, vmin), vmax); - c2 = v_min(v_max(c2, vmin), vmax); - c3 = v_min(v_max(c3, vmin), vmax); - c4 = v_min(v_max(c4, vmin), vmax); - c5 = v_min(v_max(c5, vmin), vmax); - c6 = v_min(v_max(c6, vmin), vmax); - c7 = v_min(v_max(c7, vmin), vmax); - c8 = v_min(v_max(c8, vmin), vmax); - c9 = v_min(v_max(c9, vmin), vmax); - c10 = v_min(v_max(c10, vmin), vmax); - c11 = v_min(v_max(c11, vmin), vmax); - c12 = v_min(v_max(c12, vmin), vmax); - c13 = v_min(v_max(c13, vmin), vmax); - c14 = v_min(v_max(c14, vmin), vmax); - c15 = v_min(v_max(c15, vmin), vmax); - c16 = v_min(v_max(c16, vmin), vmax); - c17 = v_min(v_max(c17, vmin), vmax); - c18 = v_min(v_max(c18, vmin), vmax); - c19 = v_min(v_max(c19, vmin), vmax); - c20 = v_min(v_max(c20, vmin), vmax); - c21 = v_min(v_max(c21, vmin), vmax); - c22 = v_min(v_max(c22, vmin), vmax); - c23 = v_min(v_max(c23, vmin), vmax); - } - v_store(c, c0); - v_store(c + 4, c1); - v_store(c + 8, c2); - v_store(c + 12, c3); - v_store(c + 16, c4); - v_store(c + 20, c5); - - v_store(c + ldc, c6); - v_store(c + ldc + 4, c7); - v_store(c + ldc + 8, c8); - v_store(c + ldc + 12, c9); - v_store(c + ldc + 16, c10); - v_store(c + ldc + 20, c11); - - v_store(c + ldc * 2, c12); - v_store(c + ldc * 2 + 4, c13); - v_store(c + ldc * 2 + 8, c14); - v_store(c + ldc * 2 + 12, c15); - v_store(c + ldc * 2 + 16, c16); - v_store(c + ldc * 2 + 20, c17); - - v_store(c + ldc * 3, c18); - v_store(c + ldc * 3 + 4, c19); - v_store(c + ldc * 3 + 8, c20); - v_store(c + ldc * 3 + 12, c21); - v_store(c + ldc * 3 + 16, c22); - v_store(c + ldc * 3 + 20, c23); + for (int p = 0; p < np; p++, a += CONV_MR, b += CONV_NR) + { + v_float32x4 a0 = v_setall_f32(a[0]); + v_float32x4 b0 = v_load(b), b1 = v_load(b + 4), b2 = v_load(b + 8); + v_float32x4 b3 = v_load(b + 12), b4 = v_load(b + 16), b5 = v_load(b + 20); + + c0 = v_fma(b0, a0, c0); + c1 = v_fma(b1, a0, c1); + c2 = v_fma(b2, a0, c2); + c3 = v_fma(b3, a0, c3); + c4 = v_fma(b4, a0, c4); + c5 = v_fma(b5, a0, c5); + + a0 = v_setall_f32(a[1]); + c6 = v_fma(b0, a0, c6); + c7 = v_fma(b1, a0, c7); + c8 = v_fma(b2, a0, c8); + c9 = v_fma(b3, a0, c9); + c10 = v_fma(b4, a0, c10); + c11 = v_fma(b5, a0, c11); + + a0 = v_setall_f32(a[2]); + c12 = v_fma(b0, a0, c12); + c13 = v_fma(b1, a0, c13); + c14 = v_fma(b2, a0, c14); + c15 = v_fma(b3, a0, c15); + c16 = v_fma(b4, a0, c16); + c17 = v_fma(b5, a0, c17); + + a0 = v_setall_f32(a[3]); + c18 = v_fma(b0, a0, c18); + c19 = v_fma(b1, a0, c19); + c20 = v_fma(b2, a0, c20); + c21 = v_fma(b3, a0, c21); + c22 = v_fma(b4, a0, c22); + c23 = v_fma(b5, a0, c23); } -#endif -#else - for (int i = 0; i < FAST_CONV_MR; i++) + + if (!init_c) { - float beta = bias[i]; - for (int j = 0; j < FAST_CONV_NR; j++) - c[i*ldc + j] = beta; + c0 = v_add(c0, v_load(c)); + c1 = v_add(c1, v_load(c + 4)); + c2 = v_add(c2, v_load(c + 8)); + c3 = v_add(c3, v_load(c + 12)); + c4 = v_add(c4, v_load(c + 16)); + c5 = v_add(c5, v_load(c + 20)); + + c6 = v_add(c6 , v_load(c + ldc)); + c7 = v_add(c7 , v_load(c + ldc + 4)); + c8 = v_add(c8 , v_load(c + ldc + 8)); + c9 = v_add(c9 , v_load(c + ldc + 12)); + c10 = v_add(c10, v_load(c 
+ ldc + 16)); + c11 = v_add(c11, v_load(c + ldc + 20)); + + c12 = v_add(c12, v_load(c + ldc*2)); + c13 = v_add(c13, v_load(c + ldc*2 + 4)); + c14 = v_add(c14, v_load(c + ldc*2 + 8)); + c15 = v_add(c15, v_load(c + ldc*2 + 12)); + c16 = v_add(c16, v_load(c + ldc*2 + 16)); + c17 = v_add(c17, v_load(c + ldc*2 + 20)); + + c18 = v_add(c18, v_load(c + ldc*3)); + c19 = v_add(c19, v_load(c + ldc*3 + 4)); + c20 = v_add(c20, v_load(c + ldc*3 + 8)); + c21 = v_add(c21, v_load(c + ldc*3 + 12)); + c22 = v_add(c22, v_load(c + ldc*3 + 16)); + c23 = v_add(c23, v_load(c + ldc*3 + 20)); } - for (int p = 0; p < k; p++) + + v_store(c, c0); + v_store(c + 4, c1); + v_store(c + 8, c2); + v_store(c + 12, c3); + v_store(c + 16, c4); + v_store(c + 20, c5); + + v_store(c + ldc, c6); + v_store(c + ldc + 4, c7); + v_store(c + ldc + 8, c8); + v_store(c + ldc + 12, c9); + v_store(c + ldc + 16, c10); + v_store(c + ldc + 20, c11); + + v_store(c + ldc * 2, c12); + v_store(c + ldc * 2 + 4, c13); + v_store(c + ldc * 2 + 8, c14); + v_store(c + ldc * 2 + 12, c15); + v_store(c + ldc * 2 + 16, c16); + v_store(c + ldc * 2 + 20, c17); + + v_store(c + ldc * 3, c18); + v_store(c + ldc * 3 + 4, c19); + v_store(c + ldc * 3 + 8, c20); + v_store(c + ldc * 3 + 12, c21); + v_store(c + ldc * 3 + 16, c22); + v_store(c + ldc * 3 + 20, c23); +#else + float cbuf[CONV_MR * CONV_NR]; + memset(cbuf, 0, sizeof(cbuf)); + for( int p = 0; p < np; p++ ) { - for (int i = 0; i < FAST_CONV_MR; i++) + for( int i = 0; i < CONV_MR; i++ ) { - float alpha = a[FAST_CONV_MR*p + i]; - for (int j = 0; j < FAST_CONV_NR; j++) - { - c[i*ldc+j] += b[FAST_CONV_NR*p + j]*alpha; - } + float ai = a[CONV_MR*p + i]; + for( int j = 0; j < CONV_NR; j++ ) + cbuf[i * CONV_NR+j] += b[CONV_NR*p + j] * ai; } } - if (ifActiv) - { - for (int i = 0; i < FAST_CONV_MR; i++) - { - for (int j = 0; j < FAST_CONV_NR; j++) - { - float v = c[i*ldc + j]; - v = std::min(std::max(v, minval), maxval); - c[i*ldc + j] = v; - } + if (!init_c) { + for(int i = 0; i < CONV_MR; i++) { + for(int j = 0; j < CONV_NR; j++) + c[i*ldc + j] += cbuf[i*CONV_NR + j]; + } + } else { + for(int i = 0; i < CONV_MR; i++) { + for(int j = 0; j < CONV_NR; j++) + c[i*ldc + j] = cbuf[i*CONV_NR + j]; } } #endif @@ -154,142 +145,122 @@ void convBlock(int k, const float *a, const float *b, namespace opt_NEON { #if CV_TRY_NEON -void convBlock_NEON(int k, const float *a, const float *b, - float *c, int ldc, const float *bias, - float minval, float maxval, bool ifActiv) +void convBlock_NEON(int np, const float* a, const float* b, float* c, int ldc, bool init_c) { -#if CV_NEON_AARCH64 && FAST_CONV_MR == 4 && FAST_CONV_NR == 28 // AARCH64 +#if CONV_MR == 4 && CONV_NR == 28 // AARCH64 { - float32x4_t c0 = vdupq_n_f32(bias[0]), c1 = c0, c2 = c0, c3 = c0, c4 = c0, c5 = c0, c24 = c0; - float32x4_t c6 = vdupq_n_f32(bias[1]), c7 = c6, c8 = c6, c9 = c6, c10 = c6, c11 = c6, c25 = c6; - float32x4_t c12 = vdupq_n_f32(bias[2]), c13 = c12, c14 = c12, c15 = c12, c16 = c12, c17 = c12, c26 = c12; - float32x4_t c18 = vdupq_n_f32(bias[3]), c19 = c18, c20 = c18, c21 = c18, c22 = c18, c23 = c18, c27 = c18; - - float32x4_t a0 = vdupq_n_f32(0.0f); - float32x4_t b0 = vdupq_n_f32(0.0f), b1 = vdupq_n_f32(0.0f), b2 = vdupq_n_f32(0.0f); + float32x4_t c00 = vdupq_n_f32(0.f), c01 = c00, c02 = c00, c03 = c00, c04 = c00, c05 = c00, c06 = c00; + float32x4_t c10 = vdupq_n_f32(0.f), c11 = c10, c12 = c10, c13 = c10, c14 = c10, c15 = c10, c16 = c10; + float32x4_t c20 = vdupq_n_f32(0.f), c21 = c20, c22 = c20, c23 = c20, c24 = c20, c25 = c20, c26 = c20; + float32x4_t 
c30 = vdupq_n_f32(0.f), c31 = c30, c32 = c30, c33 = c30, c34 = c30, c35 = c30, c36 = c30; - for (int p = 0; p < k; p++, a += FAST_CONV_MR) + for( int p = 0; p < np; p++, a += CONV_MR, b += CONV_NR ) { - a0 = vld1q_f32(a); - b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4), b2 = vld1q_f32(b + 8); - b += 12; - - c0 = vfmaq_laneq_f32(c0, b0, a0, 0); - c1 = vfmaq_laneq_f32(c1, b1, a0, 0); - c2 = vfmaq_laneq_f32(c2, b2, a0, 0); - c6 = vfmaq_laneq_f32(c6, b0, a0, 1); - c7 = vfmaq_laneq_f32(c7, b1, a0, 1); - c8 = vfmaq_laneq_f32(c8, b2, a0, 1); - c12 = vfmaq_laneq_f32(c12, b0, a0, 2); - c13 = vfmaq_laneq_f32(c13, b1, a0, 2); - c14 = vfmaq_laneq_f32(c14, b2, a0, 2); - c18 = vfmaq_laneq_f32(c18, b0, a0, 3); - c19 = vfmaq_laneq_f32(c19, b1, a0, 3); - c20 = vfmaq_laneq_f32(c20, b2, a0, 3); - - b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4), b2 = vld1q_f32(b + 8); - b += 12; - - c3 = vfmaq_laneq_f32(c3, b0, a0, 0); - c4 = vfmaq_laneq_f32(c4, b1, a0, 0); - c5 = vfmaq_laneq_f32(c5, b2, a0, 0); - - c9 = vfmaq_laneq_f32(c9, b0, a0, 1); - c10 = vfmaq_laneq_f32(c10, b1, a0, 1); - c11 = vfmaq_laneq_f32(c11, b2, a0, 1); - - c15 = vfmaq_laneq_f32(c15, b0, a0, 2); - c16 = vfmaq_laneq_f32(c16, b1, a0, 2); - c17 = vfmaq_laneq_f32(c17, b2, a0, 2); - - c21 = vfmaq_laneq_f32(c21, b0, a0, 3); - - b0 = vld1q_f32(b); - b += 4; - - c22 = vfmaq_laneq_f32(c22, b1, a0, 3); - c23 = vfmaq_laneq_f32(c23, b2, a0, 3); - - c24 = vfmaq_laneq_f32(c24, b0, a0, 0); - c25 = vfmaq_laneq_f32(c25, b0, a0, 1); + float32x4_t a0 = vld1q_f32(a), b0, b1, b2; + b0 = vld1q_f32(b); b1 = vld1q_f32(b + 4); b2 = vld1q_f32(b + 8); + + c00 = vfmaq_laneq_f32(c00, b0, a0, 0); + c01 = vfmaq_laneq_f32(c01, b1, a0, 0); + c02 = vfmaq_laneq_f32(c02, b2, a0, 0); + c10 = vfmaq_laneq_f32(c10, b0, a0, 1); + c11 = vfmaq_laneq_f32(c11, b1, a0, 1); + c12 = vfmaq_laneq_f32(c12, b2, a0, 1); + c20 = vfmaq_laneq_f32(c20, b0, a0, 2); + c21 = vfmaq_laneq_f32(c21, b1, a0, 2); + c22 = vfmaq_laneq_f32(c22, b2, a0, 2); + c30 = vfmaq_laneq_f32(c30, b0, a0, 3); + c31 = vfmaq_laneq_f32(c31, b1, a0, 3); + c32 = vfmaq_laneq_f32(c32, b2, a0, 3); + + b0 = vld1q_f32(b + 12); b1 = vld1q_f32(b + 16); b2 = vld1q_f32(b + 20); + + c03 = vfmaq_laneq_f32(c03, b0, a0, 0); + c04 = vfmaq_laneq_f32(c04, b1, a0, 0); + c05 = vfmaq_laneq_f32(c05, b2, a0, 0); + c13 = vfmaq_laneq_f32(c13, b0, a0, 1); + c14 = vfmaq_laneq_f32(c14, b1, a0, 1); + c15 = vfmaq_laneq_f32(c15, b2, a0, 1); + c23 = vfmaq_laneq_f32(c23, b0, a0, 2); + c24 = vfmaq_laneq_f32(c24, b1, a0, 2); + c25 = vfmaq_laneq_f32(c25, b2, a0, 2); + c33 = vfmaq_laneq_f32(c33, b0, a0, 3); + c34 = vfmaq_laneq_f32(c34, b1, a0, 3); + c35 = vfmaq_laneq_f32(c35, b2, a0, 3); + + b0 = vld1q_f32(b + 24); + c06 = vfmaq_laneq_f32(c06, b0, a0, 0); + c16 = vfmaq_laneq_f32(c16, b0, a0, 1); c26 = vfmaq_laneq_f32(c26, b0, a0, 2); - c27 = vfmaq_laneq_f32(c27, b0, a0, 3); + c36 = vfmaq_laneq_f32(c36, b0, a0, 3); } - if (ifActiv) { - b0 = vdupq_n_f32(minval), b1 = vdupq_n_f32(maxval); - c0 = vminq_f32(vmaxq_f32(c0, b0), b1); - c1 = vminq_f32(vmaxq_f32(c1, b0), b1); - c2 = vminq_f32(vmaxq_f32(c2, b0), b1); - c3 = vminq_f32(vmaxq_f32(c3, b0), b1); - c4 = vminq_f32(vmaxq_f32(c4, b0), b1); - c5 = vminq_f32(vmaxq_f32(c5, b0), b1); - c6 = vminq_f32(vmaxq_f32(c6, b0), b1); - c7 = vminq_f32(vmaxq_f32(c7, b0), b1); - c8 = vminq_f32(vmaxq_f32(c8, b0), b1); - c9 = vminq_f32(vmaxq_f32(c9, b0), b1); - c10 = vminq_f32(vmaxq_f32(c10, b0), b1); - c11 = vminq_f32(vmaxq_f32(c11, b0), b1); - c12 = vminq_f32(vmaxq_f32(c12, b0), b1); - c13 = vminq_f32(vmaxq_f32(c13, b0), b1); - c14 = 
vminq_f32(vmaxq_f32(c14, b0), b1); - c15 = vminq_f32(vmaxq_f32(c15, b0), b1); - c16 = vminq_f32(vmaxq_f32(c16, b0), b1); - c17 = vminq_f32(vmaxq_f32(c17, b0), b1); - c18 = vminq_f32(vmaxq_f32(c18, b0), b1); - c19 = vminq_f32(vmaxq_f32(c19, b0), b1); - c20 = vminq_f32(vmaxq_f32(c20, b0), b1); - c21 = vminq_f32(vmaxq_f32(c21, b0), b1); - c22 = vminq_f32(vmaxq_f32(c22, b0), b1); - c23 = vminq_f32(vmaxq_f32(c23, b0), b1); - c24 = vminq_f32(vmaxq_f32(c24, b0), b1); - c25 = vminq_f32(vmaxq_f32(c25, b0), b1); - c26 = vminq_f32(vmaxq_f32(c26, b0), b1); - c27 = vminq_f32(vmaxq_f32(c27, b0), b1); + if (!init_c) + { + c00 = vaddq_f32(c00, vld1q_f32(c)); + c01 = vaddq_f32(c01, vld1q_f32(c + 4)); + c02 = vaddq_f32(c02, vld1q_f32(c + 8)); + c03 = vaddq_f32(c03, vld1q_f32(c + 12)); + c04 = vaddq_f32(c04, vld1q_f32(c + 16)); + c05 = vaddq_f32(c05, vld1q_f32(c + 20)); + c06 = vaddq_f32(c06, vld1q_f32(c + 24)); + + c10 = vaddq_f32(c10, vld1q_f32(c + ldc)); + c11 = vaddq_f32(c11, vld1q_f32(c + ldc + 4)); + c12 = vaddq_f32(c12, vld1q_f32(c + ldc + 8)); + c13 = vaddq_f32(c13, vld1q_f32(c + ldc + 12)); + c14 = vaddq_f32(c14, vld1q_f32(c + ldc + 16)); + c15 = vaddq_f32(c15, vld1q_f32(c + ldc + 20)); + c16 = vaddq_f32(c16, vld1q_f32(c + ldc + 24)); + + c20 = vaddq_f32(c20, vld1q_f32(c + ldc*2)); + c21 = vaddq_f32(c21, vld1q_f32(c + ldc*2 + 4)); + c22 = vaddq_f32(c22, vld1q_f32(c + ldc*2 + 8)); + c23 = vaddq_f32(c23, vld1q_f32(c + ldc*2 + 12)); + c24 = vaddq_f32(c24, vld1q_f32(c + ldc*2 + 16)); + c25 = vaddq_f32(c25, vld1q_f32(c + ldc*2 + 20)); + c26 = vaddq_f32(c26, vld1q_f32(c + ldc*2 + 24)); + + c30 = vaddq_f32(c30, vld1q_f32(c + ldc*3)); + c31 = vaddq_f32(c31, vld1q_f32(c + ldc*3 + 4)); + c32 = vaddq_f32(c32, vld1q_f32(c + ldc*3 + 8)); + c33 = vaddq_f32(c33, vld1q_f32(c + ldc*3 + 12)); + c34 = vaddq_f32(c34, vld1q_f32(c + ldc*3 + 16)); + c35 = vaddq_f32(c35, vld1q_f32(c + ldc*3 + 20)); + c36 = vaddq_f32(c36, vld1q_f32(c + ldc*3 + 24)); } - vst1q_f32(c, c0); - vst1q_f32(c + 4, c1); - vst1q_f32(c + 8, c2); - vst1q_f32(c + 12, c3); - vst1q_f32(c + 16, c4); - vst1q_f32(c + 20, c5); - vst1q_f32(c + 24, c24); - - vst1q_f32(c + ldc, c6); - vst1q_f32(c + ldc + 4, c7); - vst1q_f32(c + ldc + 8, c8); - vst1q_f32(c + ldc + 12, c9); - vst1q_f32(c + ldc + 16, c10); - vst1q_f32(c + ldc + 20, c11); - vst1q_f32(c + ldc + 24, c25); - - vst1q_f32(c + ldc * 2, c12); - vst1q_f32(c + ldc * 2 + 4, c13); - vst1q_f32(c + ldc * 2 + 8, c14); - vst1q_f32(c + ldc * 2 + 12, c15); - vst1q_f32(c + ldc * 2 + 16, c16); - vst1q_f32(c + ldc * 2 + 20, c17); - vst1q_f32(c + ldc * 2 + 24, c26); - - vst1q_f32(c + ldc * 3, c18); - vst1q_f32(c + ldc * 3 + 4, c19); - vst1q_f32(c + ldc * 3 + 8, c20); - vst1q_f32(c + ldc * 3 + 12, c21); - vst1q_f32(c + ldc * 3 + 16, c22); - vst1q_f32(c + ldc * 3 + 20, c23); - vst1q_f32(c + ldc * 3 + 24, c27); + + vst1q_f32(c, c00); vst1q_f32(c+4, c01); + vst1q_f32(c+8, c02); vst1q_f32(c+12, c03); + vst1q_f32(c+16, c04); vst1q_f32(c+20, c05); + vst1q_f32(c+24, c06); + + vst1q_f32(c+ldc, c10); vst1q_f32(c+ldc+4, c11); + vst1q_f32(c+ldc+8, c12); vst1q_f32(c+ldc+12, c13); + vst1q_f32(c+ldc+16, c14); vst1q_f32(c+ldc+20, c15); + vst1q_f32(c+ldc+24, c16); + + vst1q_f32(c+ldc*2, c20); vst1q_f32(c+ldc*2+4, c21); + vst1q_f32(c+ldc*2+8, c22); vst1q_f32(c+ldc*2+12, c23); + vst1q_f32(c+ldc*2+16, c24); vst1q_f32(c+ldc*2+20, c25); + vst1q_f32(c+ldc*2+24, c26); + + vst1q_f32(c+ldc*3, c30); vst1q_f32(c+ldc*3+4, c31); + vst1q_f32(c+ldc*3+8, c32); vst1q_f32(c+ldc*3+12, c33); + vst1q_f32(c+ldc*3+16, c34); vst1q_f32(c+ldc*3+20, c35); + 
vst1q_f32(c+ldc*3+24, c36); } -#elif (!defined(CV_NEON_AARCH64) || !CV_NEON_AARCH64) && FAST_CONV_MR == 4 && FAST_CONV_NR == 12 // ARMv7 +#elif CONV_MR == 4 && CONV_NR == 12 // ARMv7 { - float32x4_t c0 = vdupq_n_f32(bias[0]), c1 = c0, c2 = c0; - float32x4_t c3 = vdupq_n_f32(bias[1]), c4 = c3, c5 = c3; - float32x4_t c6 = vdupq_n_f32(bias[2]), c7 = c6, c8 = c6; - float32x4_t c9 = vdupq_n_f32(bias[3]), c10 = c9, c11 = c9; + float32x4_t c0 = vdupq_n_f32(0.f), c1 = c0, c2 = c0; + float32x4_t c3 = vdupq_n_f32(0.f), c4 = c3, c5 = c3; + float32x4_t c6 = vdupq_n_f32(0.f), c7 = c6, c8 = c6; + float32x4_t c9 = vdupq_n_f32(0.f), c10 = c9, c11 = c9; + float32x2_t a0 = vdup_n_f32(0.0f), a1 = a0; float32x4_t b0 = vdupq_n_f32(0.0f), b1 = vdupq_n_f32(0.0f), b2 = vdupq_n_f32(0.0f); - for (int p = 0; p < k; p++, a += FAST_CONV_MR, b += FAST_CONV_NR) + for (int p = 0; p < np; p++, a += CONV_MR, b += CONV_NR) { a0 = vld1_f32(a), a1 = vld1_f32(a+2); b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4), b2 = vld1q_f32(b + 8); @@ -311,29 +282,32 @@ void convBlock_NEON(int k, const float *a, const float *b, c11 = vmlaq_lane_f32(c11, b2, a1, 1); } - if (ifActiv) + if (!init_c) { - b0 = vdupq_n_f32(minval), b1 = vdupq_n_f32(maxval); - c0 = vminq_f32(vmaxq_f32(c0, b0), b1); - c1 = vminq_f32(vmaxq_f32(c1, b0), b1); - c2 = vminq_f32(vmaxq_f32(c2, b0), b1); - c3 = vminq_f32(vmaxq_f32(c3, b0), b1); - c4 = vminq_f32(vmaxq_f32(c4, b0), b1); - c5 = vminq_f32(vmaxq_f32(c5, b0), b1); - c6 = vminq_f32(vmaxq_f32(c6, b0), b1); - c7 = vminq_f32(vmaxq_f32(c7, b0), b1); - c8 = vminq_f32(vmaxq_f32(c8, b0), b1); - c9 = vminq_f32(vmaxq_f32(c9, b0), b1); - c10 = vminq_f32(vmaxq_f32(c10, b0), b1); - c11 = vminq_f32(vmaxq_f32(c11, b0), b1); + c0 = vaddq_f32(c0, vld1q_f32(c)); + c1 = vaddq_f32(c1, vld1q_f32(c + 4)); + c2 = vaddq_f32(c2, vld1q_f32(c + 8)); + + c3 = vaddq_f32(c3, vld1q_f32(c + ldc)); + c4 = vaddq_f32(c4, vld1q_f32(c + ldc + 4)); + c5 = vaddq_f32(c5, vld1q_f32(c + ldc + 8)); + + c6 = vaddq_f32(c6, vld1q_f32(c + ldc * 2)); + c7 = vaddq_f32(c7, vld1q_f32(c + ldc * 2 + 4)); + c8 = vaddq_f32(c8, vld1q_f32(c + ldc * 2 + 8)); + + c9 = vaddq_f32(c9 , vld1q_f32(c + ldc * 3)); + c10 = vaddq_f32(c10, vld1q_f32(c + ldc * 3 + 4)); + c11 = vaddq_f32(c11, vld1q_f32(c + ldc * 3 + 8)); } - vst1q_f32(c, c0); vst1q_f32(c+4, c1); vst1q_f32(c+8, c2); - vst1q_f32(c + ldc, c3); vst1q_f32(c + ldc + 4, c4); vst1q_f32(c + ldc + 8, c5); - vst1q_f32(c + ldc*2, c6); vst1q_f32(c + ldc*2 + 4, c7); vst1q_f32(c + ldc*2 + 8, c8); - vst1q_f32(c + ldc*3, c9); vst1q_f32(c + ldc*3 + 4, c10); vst1q_f32(c + ldc*3 + 8, c11); + + vst1q_f32(c, c0), vst1q_f32(c+4, c1), vst1q_f32(c+8, c2); + vst1q_f32(c + ldc, c3), vst1q_f32(c + ldc + 4, c4), vst1q_f32(c + ldc + 8, c5); + vst1q_f32(c + ldc*2, c6), vst1q_f32(c + ldc*2 + 4, c7), vst1q_f32(c + ldc*2 + 8, c8); + vst1q_f32(c + ldc*3, c9), vst1q_f32(c + ldc*3 + 4, c10), vst1q_f32(c + ldc*3 + 8, c11); } -#else -#error "unsupported FAST_CONV_MR and/or FAST_CONV_NR in convBlock_NEON." +//#else +//#error "unsupported CONV_MR and/or CONV_NR in convBlock_NEON." 
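// Illustrative sketch (not part of the patch): the refactored convBlock is now a
// plain GEMM micro-kernel. It no longer seeds its accumulators with the bias and
// no longer clamps the result; instead, the init_c flag selects overwrite vs.
// accumulate for the output tile, and bias/activation are applied by the caller
// after the last K-block. A simplified scalar reference of that contract, with
// mr/nr standing in for CONV_MR/CONV_NR and a hypothetical epilogue helper:
#include <algorithm>

static void convBlockRef(int np, const float* a, const float* b, float* c,
                         int ldc, bool init_c, int mr, int nr)
{
    // a and b are packed p-major panels, as in the patch: a[p*mr + i], b[p*nr + j]
    for (int i = 0; i < mr; i++)
        for (int j = 0; j < nr; j++)
        {
            float s = init_c ? 0.f : c[i*ldc + j];   // overwrite vs. accumulate
            for (int p = 0; p < np; p++)
                s += a[p*mr + i] * b[p*nr + j];
            c[i*ldc + j] = s;
        }
}

// Hypothetical caller-side epilogue: per-row bias plus the min/max clamp that the
// old convBlock used to apply internally.
static void finalizeTile(float* c, int ldc, const float* bias, int mr, int nr,
                         bool ifMinMaxAct, float minval, float maxval)
{
    for (int i = 0; i < mr; i++)
        for (int j = 0; j < nr; j++)
        {
            float v = c[i*ldc + j] + bias[i];
            if (ifMinMaxAct)
                v = std::min(std::max(v, minval), maxval);
            c[i*ldc + j] = v;
        }
}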
#endif } #endif diff --git a/modules/dnn/src/layers/fast_convolution/winograd_3x3s1_f63.cpp b/modules/dnn/src/layers/fast_convolution/winograd_3x3s1_f63.cpp index e841889..fb668d6 100644 --- a/modules/dnn/src/layers/fast_convolution/winograd_3x3s1_f63.cpp +++ b/modules/dnn/src/layers/fast_convolution/winograd_3x3s1_f63.cpp @@ -37,6 +37,47 @@ enum }; #if CV_NEON + +#undef _FAST_CONV_T4x4 +#define _FAST_CONV_T4x4(a, b, c, d, tr0, tr1) \ + tr0 = vtrnq_f32(a, b); \ + tr1 = vtrnq_f32(c, d); \ + a = vcombine_f32(vget_low_f32(tr0.val[0]), vget_low_f32(tr1.val[0])); \ + b = vcombine_f32(vget_low_f32(tr0.val[1]), vget_low_f32(tr1.val[1])); \ + c = vcombine_f32(vget_high_f32(tr0.val[0]), vget_high_f32(tr1.val[0])); \ + d = vcombine_f32(vget_high_f32(tr0.val[1]), vget_high_f32(tr1.val[1])) + +// The input is the pack4 data, and the output is unpack4 data. +static void transpose12x4(float* src, float* dst, const int cn) +{ + float32x4_t r00, r01, r02, r03, r04, r05, r06, r07, r08, r09, r10, r11; + float32x4x2_t tr0, tr1; + for (int i = 0; i < cn; i++, src += 48, dst += 48) + { + r00 = vld1q_f32(src); + r01 = vld1q_f32(src + 4); + r02 = vld1q_f32(src + 8); + r03 = vld1q_f32(src + 12); + r04 = vld1q_f32(src + 16); + r05 = vld1q_f32(src + 20); + r06 = vld1q_f32(src + 24); + r07 = vld1q_f32(src + 28); + r08 = vld1q_f32(src + 32); + r09 = vld1q_f32(src + 36); + r10 = vld1q_f32(src + 40); + r11 = vld1q_f32(src + 44); + + _FAST_CONV_T4x4(r00, r01, r02, r03, tr0, tr1); + _FAST_CONV_T4x4(r04, r05, r06, r07, tr0, tr1); + _FAST_CONV_T4x4(r08, r09, r10, r11, tr0, tr1); + + vst1q_f32(dst, r00), vst1q_f32(dst + 4, r04), vst1q_f32(dst + 8, r08); + vst1q_f32(dst + 12, r01), vst1q_f32(dst + 16, r05), vst1q_f32(dst + 20, r09); + vst1q_f32(dst + 24, r02), vst1q_f32(dst + 28, r06), vst1q_f32(dst + 32, r10); + vst1q_f32(dst + 36, r03), vst1q_f32(dst + 40, r07), vst1q_f32(dst + 44, r11); + } +} + static void winograd_trans_input_F63(float* src, float* dst, int Channle_div4, const int tiles, const int big_step, const int line_step, const int* ofstab0) { // const float itm[8][8] = { @@ -192,15 +233,14 @@ static void winograd_trans_input_F63(float* src, float* dst, int Channle_div4, c float* input0 = input_buf0 + 4 * tiles * r; // TODO! support tiles > 12 -//#if CV_NEON_AARCH64 -// for (; ti + 11 < tiles; ti += 12) -// { -// float* out1 = out0 + line_step * ofstab0[ti * 2] + Channle_div4 * ofstab0[ti * 2 + 1] * 4; -//// std::cout<<"ofstab0[ti * 2] = "< convLayer = ld.layerInstance.dynamicCast(); + + // Only Conv2D without fusion Activation supports this fusion, other-wise, we skip. + if (!convLayer->isConv2D || convLayer->fusedActivation) + break; + + // For now, there are currently two layers in OpenCV that run the Add operator. + Ptr nextNaryEltwiseLayer = nextData->layerInstance.dynamicCast(); + Ptr nextEltwiseLayer = nextData->layerInstance.dynamicCast(); + if (nextNaryEltwiseLayer.empty() && nextEltwiseLayer.empty()) + break; + + if (nextData->inputBlobsId.size() != 2) + break; + + if (!nextData->params.has("operation") || toLowerCase(nextData->params.get("operation")) != "add") + { + CV_LOG_DEBUG(NULL, "DNN/CPU: fusion with NaryEltwise or Eltwise Layer operation is not supported: " + << nextData->params.get("operation")); + break; + } + + // This optimization is for cases like + // some_layer conv + // | | + // +-- eltwise or (naryEltwise) --+ + // | + // activ + // This way all the element-wise computations + // (i.e. some_layer+conv) would be done at [conv] layer. 
+ // So we need to replace [conv]'s output blob to [eltwise]'s one + // considering that [activ] is an in-place layer. + // Also we need to move all the consumers' references. + // To prevent memory collisions (i.e. when input of + // [conv] and output of [eltwise or naryEltwise] is the same blob) + // we allocate a new blob. + { + LayerData *naryOrEltwiseData = nextData; + + // Eltwise or NaryEltwise layer has two inputs. We need to determine which + // is a base convolution layer and which could be used as it's bias. + LayerData* biasLayerData = 0; + for (int i = 0; i < 2; ++i) + { + LayerData *downLayerData = &layers[naryOrEltwiseData->inputBlobsId[i].lid]; + CV_Assert(downLayerData); + // If the current downLayerData is skip, it means it is fused into the parent node. + while (downLayerData->skip) + { + if (downLayerData->inputBlobsId.size() == 1) + downLayerData = &layers[downLayerData->inputBlobsId[0].lid]; + else + { + downLayerData = 0; + break; + } + } + + if (downLayerData && ld.id == downLayerData->id) + { + biasLayerData = &layers[naryOrEltwiseData->inputBlobsId[1 - i].lid]; + break; + } + } + + // We check if biasLayerData is expected layer. + if (!biasLayerData) + break; + + // We check if the bias output shape and the ld output shape are the same. + MatShape biasOutShape = shape(biasLayerData->outputBlobs[0]); + MatShape ldOutShape = shape(ld.outputBlobs[0]); + if (biasOutShape != ldOutShape) + break; + + CV_Assert(biasLayerData); + { + // fuse naryEltwise layer + // bias must already be computed to fuse => bias layer must appear before convolution + if (biasLayerData->id < ld.id) + { + // conv + naryEltwise. + CV_Assert_N(biasLayerData->outputBlobs.size() == 1, ld.inputBlobs.size() == 1); + CV_Assert_N(biasLayerData->outputBlobsWrappers.size() == 1, ld.inputBlobsWrappers.size() == 1); + + printf_(("\tfused with %s\n", nextNaryEltwiseLayer->name.c_str())); + naryOrEltwiseData->skip = true; + + + CV_Assert_N(ld.outputBlobs.size() == 1, ld.outputBlobsWrappers.size() == 1); + // Note: Here's a trick. We set the output of conv as the output of biasLayer. + ld.outputBlobs[0] = ld.outputBlobs[0].clone(); + ld.outputBlobsWrappers[0] = wrap(ld.outputBlobs[0]); + + // Recursively modifies the output data of biasLayerData and its parent. 
+ std::vector skipDataList; + skipDataList.push_back(biasLayerData); + + while (!skipDataList.empty()) + { + LayerData* skipData = skipDataList.back(); + skipDataList.pop_back(); + + CV_Assert(skipData->outputBlobs.size() == 1); + skipData->outputBlobs[0] = ld.outputBlobs[0]; + skipData->outputBlobsWrappers[0] = ld.outputBlobsWrappers[0]; + if (skipData->skip) + { + for (auto& inputLayerId : skipData->inputLayersId) + { + LayerData* inputld = &layers[inputLayerId]; + + if (inputld && inputld->outputBlobs.size() == 1) + skipDataList.push_back(inputld); + } + } + } + + naryOrEltwiseData->outputBlobs = ld.outputBlobs; + naryOrEltwiseData->outputBlobsWrappers = ld.outputBlobsWrappers; + + // set the fusedAdd flag in [Conv]; + convLayer->fusedAdd = true; + LayerData* finalData = naryOrEltwiseData; + /* After fused Conv + naryEltwise or eltwise, we can fuse activation if: + * => activation layer that follows is the only consumer of eltwise output + * => activation layer does not process multiple inputs + * => we do not require to keep the output of eltwise + */ + if (naryOrEltwiseData->consumers.size() == 1) + { + Ptr nextFusabeleActivLayer; + LayerData* nextAct = &layers[naryOrEltwiseData->consumers[0].lid]; + + if (nextData->outputBlobs.size() == 1) + nextFusabeleActivLayer = nextAct->layerInstance.dynamicCast(); + + if (!nextFusabeleActivLayer.empty()) + { + convLayer->setActivation(nextFusabeleActivLayer); + nextAct->skip = true; + + nextAct->outputBlobs = ld.outputBlobs; + nextAct->outputBlobsWrappers = ld.outputBlobsWrappers; + } + } + + // Move references of finalData (eltwise or activation) layer consumers to the newly allocated blob. + for (int i = 0; i < finalData->consumers.size(); ++i) + { + LayerData& consumer = layers[finalData->consumers[i].lid]; + for (int j = 0; j < consumer.inputBlobsId.size(); ++j) + { + if (consumer.inputBlobsId[j].lid == finalData->id) + { + consumer.inputBlobs[j] = &ld.outputBlobs[0]; + consumer.inputBlobsWrappers[j] = ld.outputBlobsWrappers[0]; + break; + } + } + } + } + } + } + break; + } + // OpenCL: fuse convolution layer followed by eltwise + relu // CUDA: fuse convolution layer followed by eltwise (and optional activation) while (nextData && @@ -398,7 +570,7 @@ void Net::Impl::fuseLayers(const std::vector& blobsToKeep_) // (i.e. some_layer+conv or some_layer*conv) // would be done at [conv] layer. So we need to // replace [conv]'s output blob to [eltwise]'s one. - // Also we need to move all the consumers' references. + // Also, we need to move all the consumers' references. // To prevent memory collisions (i.e. when input of // [conv] and output of [eltwise] is the same blob) // we allocate a new blob. 
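Taken together, the new fusion path folds the pattern conv -> (Nary)Eltwise Add -> activation into the convolution itself: the convolution layer gets fusedAdd (and possibly fusedActivation) set, all three nodes share one output blob, and the skipped layers only forward references. A minimal scalar sketch of what the fused node then computes at run time; the buffer names and the ReLU choice are illustrative, not taken from the patch:

#include <algorithm>
#include <vector>

// convOut already holds conv(x) for the fused layer; residual is the second
// input of the Add node that was folded into it.
static void fusedConvAddActSketch(std::vector<float>& convOut,
                                  const std::vector<float>& residual)
{
    for (size_t i = 0; i < convOut.size(); i++)
    {
        float v = convOut[i] + residual[i];   // fusedAdd
        convOut[i] = std::max(v, 0.f);        // fusedActivation (ReLU as an example)
    }
}

This also motivates the guards above: the Add node must have exactly two inputs, both branches must produce the same output shape, and the residual branch has to be computed before the convolution runs (biasLayerData->id < ld.id); otherwise the fusion is skipped.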
diff --git a/modules/dnn/test/test_caffe_importer.cpp b/modules/dnn/test/test_caffe_importer.cpp index ddaa51d..3ea8d17 100644 --- a/modules/dnn/test/test_caffe_importer.cpp +++ b/modules/dnn/test/test_caffe_importer.cpp @@ -284,7 +284,7 @@ TEST(Reproducibility_SSD, Accuracy) Mat out = net.forward("detection_out"); Mat ref = blobFromNPY(_tf("ssd_out.npy")); - normAssertDetections(ref, out, "", FLT_MIN); + normAssertDetections(ref, out, "", 0.06); } typedef testing::TestWithParam > Reproducibility_MobileNet_SSD; diff --git a/modules/dnn/test/test_int8_layers.cpp b/modules/dnn/test/test_int8_layers.cpp index 54db5a3..1cafa22 100644 --- a/modules/dnn/test/test_int8_layers.cpp +++ b/modules/dnn/test/test_int8_layers.cpp @@ -1029,7 +1029,7 @@ TEST_P(Test_Int8_nets, FasterRCNN_resnet50) Mat blob = blobFromImage(inp, 1.0, Size(800, 600), Scalar(), true, false); Mat ref = blobFromNPY(_tf("tensorflow/faster_rcnn_resnet50_coco_2018_01_28.detection_out.npy")); - float confThreshold = 0.5, scoreDiff = 0.05, iouDiff = 0.15; + float confThreshold = 0.8, scoreDiff = 0.05, iouDiff = 0.15; testDetectionNet(net, blob, ref, confThreshold, scoreDiff, iouDiff); } diff --git a/modules/dnn/test/test_torch_importer.cpp b/modules/dnn/test/test_torch_importer.cpp index 5208874..bd95727 100644 --- a/modules/dnn/test/test_torch_importer.cpp +++ b/modules/dnn/test/test_torch_importer.cpp @@ -488,7 +488,7 @@ TEST_P(Test_Torch_nets, ENet_accuracy) // Due to numerical instability in Pooling-Unpooling layers (indexes jittering) // thresholds for ENet must be changed. Accuracy of results was checked on // Cityscapes dataset and difference in mIOU with Torch is 10E-4% - normAssert(ref, out, "", 0.00044, /*target == DNN_TARGET_CPU ? 0.453 : */0.552); + normAssert(ref, out, "", 0.0005, /*target == DNN_TARGET_CPU ? 0.453 : */0.552); normAssertSegmentation(ref, out); const int N = 3; @@ -496,7 +496,7 @@ TEST_P(Test_Torch_nets, ENet_accuracy) { net.setInput(inputBlob, ""); Mat out = net.forward(); - normAssert(ref, out, "", 0.00044, /*target == DNN_TARGET_CPU ? 0.453 : */0.552); + normAssert(ref, out, "", 0.0005, /*target == DNN_TARGET_CPU ? 0.453 : */0.552); normAssertSegmentation(ref, out); } } -- 2.7.4
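From the user's side the fusion is transparent and is selected internally when running on the CPU target. A minimal example of running a network that contains such Conv->Add->ReLU blocks; the model and image file names below are placeholders:

#include <opencv2/dnn.hpp>
#include <opencv2/imgcodecs.hpp>

int main()
{
    // "model.onnx" stands for any net with residual Conv->Add->ReLU blocks.
    cv::dnn::Net net = cv::dnn::readNetFromONNX("model.onnx");
    net.setPreferableBackend(cv::dnn::DNN_BACKEND_OPENCV);
    net.setPreferableTarget(cv::dnn::DNN_TARGET_CPU);   // fusion in this patch targets the CPU path

    cv::Mat img = cv::imread("input.jpg");
    CV_Assert(!img.empty());
    cv::Mat blob = cv::dnn::blobFromImage(img, 1.0 / 255.0, cv::Size(224, 224));
    net.setInput(blob);
    cv::Mat out = net.forward();   // Conv, Add and the activation run as one fused layer
    return 0;
}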