Further optimization of Conv2D, fused Conv_Add_Activation, bring latest code from...
author     Zihao Mu <zihaomu@outlook.com>
           Fri, 26 Aug 2022 09:57:25 +0000 (17:57 +0800)
committer  GitHub <noreply@github.com>
           Fri, 26 Aug 2022 09:57:25 +0000 (12:57 +0300)
12 files changed:
modules/dnn/include/opencv2/dnn/all_layers.hpp
modules/dnn/src/dnn_common.hpp
modules/dnn/src/layers/convolution_layer.cpp
modules/dnn/src/layers/fast_convolution/fast_convolution.avx2.cpp
modules/dnn/src/layers/fast_convolution/fast_convolution.cpp
modules/dnn/src/layers/fast_convolution/fast_convolution.hpp
modules/dnn/src/layers/fast_convolution/fast_convolution.simd.hpp
modules/dnn/src/layers/fast_convolution/winograd_3x3s1_f63.cpp
modules/dnn/src/net_impl_fuse.cpp
modules/dnn/test/test_caffe_importer.cpp
modules/dnn/test/test_int8_layers.cpp
modules/dnn/test/test_torch_importer.cpp
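
The headline change is that an element-wise Add (typically a residual connection) and a min/max activation can now be folded into the convolution's output loop instead of running as separate layers. A minimal scalar sketch of that fused epilogue, matching the bias/fusedAdd/clamp tail added to runFastConv2d further below (the function name and signature here are illustrative, not part of the patch):

#include <algorithm>

// Sketch only: per-channel fused epilogue of Conv + eltwise Add + clamp-style activation.
// convAcc: raw convolution accumulators for one output channel; addInp: the fused Add operand
// (null when no Add is fused); bias: per-channel bias; [minval, maxval]: activation range.
static void fusedEpilogue(const float* convAcc, const float* addInp, float bias,
                          float minval, float maxval, float* out, int n)
{
    for (int j = 0; j < n; j++)
    {
        float v = convAcc[j] + bias;                    // add the per-channel bias
        if (addInp)
            v += addInp[j];                             // fused element-wise Add (residual branch)
        out[j] = std::min(std::max(v, minval), maxval); // fused min/max activation (ReLU/ReLU6/clip)
    }
}

Activations that are not simple clamps are still applied afterwards via activ->forwardSlice(), as the real epilogue in fast_convolution.cpp shows.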

modules/dnn/include/opencv2/dnn/all_layers.hpp
index 263e48c..66ba087 100644
@@ -256,6 +256,9 @@ CV__DNN_INLINE_NS_BEGIN
     {
     public:
         static Ptr<BaseConvolutionLayer> create(const LayerParams& params);
+        bool fusedActivation = false;
+        bool fusedAdd = false;
+        bool isConv2D = false; // Should be deleted once the fastconv branch supports Conv1D and Conv3D.
     };
 
     class CV_EXPORTS ConvolutionLayerInt8 : public BaseConvolutionLayer
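
These new public flags are bookkeeping for the graph-level fusion pass: fusedActivation and fusedAdd record what has already been folded into the layer, while isConv2D marks the only case the fast-convolution branch currently handles. A hypothetical sketch of how a fusion pass might gate on them, together with the IS_DNN_CPU_TARGET macro added in dnn_common.hpp below (the helper name is illustrative; the real gating lives in net_impl_fuse.cpp, whose diff is not shown in this excerpt):

// Hypothetical helper (not an OpenCV API) illustrating the intended use of the new flags.
static bool canFoldEltwiseAdd(const cv::Ptr<cv::dnn::BaseConvolutionLayer>& conv, int preferableTarget)
{
    // Only the CPU 2D fast-convolution path implements the fused Add epilogue, and a layer
    // that has already fused an Add must not fuse anything else (see tryFuse() below).
    return !conv.empty() && IS_DNN_CPU_TARGET(preferableTarget) && conv->isConv2D && !conv->fusedAdd;
}
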
modules/dnn/src/dnn_common.hpp
index ae4d9c2..b580b9f 100644
@@ -13,6 +13,7 @@
 namespace cv { namespace dnn {
 CV__DNN_INLINE_NS_BEGIN
 #define IS_DNN_OPENCL_TARGET(id) (id == DNN_TARGET_OPENCL || id == DNN_TARGET_OPENCL_FP16)
+#define IS_DNN_CPU_TARGET(id) (id == DNN_TARGET_CPU) // TODO: add DNN_TARGET_CPU_FP16
 Mutex& getInitializationMutex();
 void initializeLayerFactory();
 
modules/dnn/src/layers/convolution_layer.cpp
index c2960d5..e6ef9f1 100644
@@ -118,6 +118,9 @@ public:
 
         fusedWeights = false;
         fusedBias = false;
+
+        if (kernel_size.size() == 2)
+            isConv2D = true;
     }
 
     virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
@@ -188,6 +191,9 @@ public:
 
     virtual bool tryFuse(Ptr<Layer>& top) CV_OVERRIDE
     {
+        if (fusedAdd)   // If the Conv layer already has a fused Add layer, it cannot fuse any other layers.
+            return false;
+
         Ptr<BlankLayer> blank_layer = top.dynamicCast<BlankLayer>();
         if (blank_layer)
             return true;
@@ -260,7 +266,6 @@ public:
     std::vector<float> reluslope;
     Ptr<ActivationLayer> activ;
 
-    Mat fastWeights; // Used to store weight params. It will be used for layer fusion and without memory alignment.
     Ptr<FastConv2d> fastConv2dImpl;
 
 #ifdef HAVE_OPENCL
@@ -438,7 +443,6 @@ public:
                 wm.copyTo(wm_aligned);
                 wm = wm_aligned;
             }
-            fastWeights = blobs[0].reshape(1, numOutput);
             weightsMat = wm;
         }
         else
@@ -584,11 +588,15 @@ public:
             }
         }
 #endif
-        return !activ.empty();
+        fusedActivation = !activ.empty();
+        return fusedActivation;
     }
 
     virtual bool tryFuse(Ptr<Layer>& top) CV_OVERRIDE
     {
+        if (fusedAdd)   // If the Conv layer already has a fused Add layer, it cannot fuse any other layers.
+            return false;
+
 #ifdef HAVE_CUDA
         if(IS_DNN_CUDA_TARGET(preferableTarget))
         {
@@ -634,26 +642,14 @@ public:
             if (weightsMat.data == blobs[0].data)
                 weightsMat = weightsMat.clone();
 
-            // If fastWeights is the same as weightsMat, we don't need to allocate more space for fastWeights.
-            bool sameFastWeights = false;
-            if (fastWeights.step1() == weightsMat.step1()) // If weightsMat is realigned, it is not the same as fastWeights.
-                sameFastWeights = true;
-
-            if (!sameFastWeights && fastWeights.data == blobs[0].data)
-                fastWeights = fastWeights.clone();
-
             Mat originWeights = blobs[0].reshape(1, outCn);
             for (int i = 0; i < outCn; ++i)
             {
                 double wi = w.at<float>(i);
                 weightsMultipliers[i] *= wi;
                 cv::multiply(originWeights.row(i), weightsMultipliers[i], weightsMat.row(i));
-                if (!sameFastWeights)
-                    cv::multiply(originWeights.row(i), weightsMultipliers[i], fastWeights.row(i));
                 biasvec[i] *= wi;
             }
-            if (sameFastWeights)
-                fastWeights = weightsMat;
         }
 
         if (!b.empty())
@@ -1970,9 +1966,6 @@ public:
         if (blobs.empty())
         {
             variableWeight = true;
-            if (fastWeights.data != inputs[1].data)
-                fastWeights = inputs[1].clone();
-
             Mat wm = inputs[1].reshape(1, outCn);
             if (wm.data != weightsMat.data)
             {
@@ -2089,7 +2082,7 @@ public:
         {
             int nstripes = std::max(getNumThreads(), 1);
 
-            // Initialization of FastCovn2d
+            // Initialization of FastConv2d: pack the weights.
             if ((!fastConv2dImpl || variableWeight) && inputs[0].dims == 4)
             {
                 int K = outputs[0].size[1];
@@ -2103,23 +2096,22 @@ public:
 
                 int dilation_h = dilations[dilations.size() - 2];
                 int dilation_w = dilations.back();
-                float* weightsPtr = fastWeights.ptr<float>();
-                CV_Assert(weightsPtr);
 
-                fastConv2dImpl = initFastConv2d(ngroups, K, C, Hk, Wk, stride_w, stride_h,
-                                              dilation_w, dilation_h, pads_begin, pads_end, weightsPtr, &biasvec[0]);
+                fastConv2dImpl = initFastConv2d(ngroups, K, C, Hk, Wk, stride_w, stride_h, dilation_w,
+                                                dilation_h, pads_begin, pads_end, weightsMat, &biasvec[0]);
             }
 
             if (fastConv2dImpl)
             {
-                runFastConv2d(inputs[0], outputs[0], fastConv2dImpl, nstripes, activ);
+                runFastConv2d(inputs[0], outputs[0], fastConv2dImpl, nstripes, activ, fusedAdd);
                 return;
             }
 
+            // TODO: Add support for Conv1D and Conv3D to fastConv, then remove the old Conv branch.
             // Use only for Conv1D and Conv3D.
+            CV_Assert(!fusedAdd);
             ParallelConv::run(inputs[0], outputs[0], weightsMat, biasvec, reluslope,
                             kernel_size, strides, pads_begin, pads_end, dilations, activ.get(), ngroups, nstripes);
-
         }
     }
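
Taken together, the hunks above mean the convolution layer's forward() now packs weights directly from weightsMat (the separate fastWeights copy is gone) and hands the fusedAdd flag to the fast path; only Conv1D/Conv3D still take the legacy branch. A condensed paraphrase of the new CPU dispatch, not literal code from the file:

// Condensed paraphrase of the new CPU dispatch in the convolution layer's forward().
if ((!fastConv2dImpl || variableWeight) && inputs[0].dims == 4)
    fastConv2dImpl = initFastConv2d(ngroups, K, C, Hk, Wk, stride_w, stride_h, dilation_w,
                                    dilation_h, pads_begin, pads_end, weightsMat, &biasvec[0]);

if (fastConv2dImpl)
{
    runFastConv2d(inputs[0], outputs[0], fastConv2dImpl, nstripes, activ, fusedAdd);
    return;
}

CV_Assert(!fusedAdd);   // the legacy Conv1D/Conv3D branch cannot consume a fused Add
ParallelConv::run(inputs[0], outputs[0], weightsMat, biasvec, reluslope, kernel_size,
                  strides, pads_begin, pads_end, dilations, activ.get(), ngroups, nstripes);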
 
modules/dnn/src/layers/fast_convolution/fast_convolution.avx2.cpp
index 22580c5..de1a2ef 100644
@@ -9,67 +9,67 @@ namespace cv {
 namespace opt_AVX2
 {
 #if CV_TRY_AVX2
-void convBlock_AVX2(int k, const float *a, const float *b,
-                float *c, int ldc, const float *bias,
-                float minval, float maxval, bool ifActiv)
+void convBlock_AVX2(int np, const float* a, const float* b, float* c, int ldc, bool init_c)
 {
-#if FAST_CONV_MR == 4 && FAST_CONV_NR == 24
-    __m256 vminval = _mm256_set1_ps(minval), vmaxval = _mm256_set1_ps(maxval);
-    __m256 c0 = _mm256_set1_ps(bias[0]), c1 = c0, c2 = c0;
-    __m256 c3 = _mm256_set1_ps(bias[1]), c4 = c3, c5 = c3;
-    __m256 c6 = _mm256_set1_ps(bias[2]), c7 = c6, c8 = c6;
-    __m256 c9 = _mm256_set1_ps(bias[3]), c10 = c9, c11 = c9;
+#if CONV_MR == 4 && CONV_NR == 24
+    __m256 c00 = _mm256_set1_ps(0.f), c01 = c00, c02 = c00;
+    __m256 c10 = c00, c11 = c00, c12 = c00;
+    __m256 c20 = c00, c21 = c00, c22 = c00;
+    __m256 c30 = c00, c31 = c00, c32 = c00;
 
     __m256 a0 = _mm256_setzero_ps(), a1 = _mm256_setzero_ps();
     __m256 b0 = _mm256_setzero_ps(), b1 = _mm256_setzero_ps(), b2 = _mm256_setzero_ps();
 
-    for (int p = 0; p < k; p++, a += FAST_CONV_MR, b += FAST_CONV_NR)
+    for (int p = 0; p < np; p++, a += CONV_MR, b += CONV_NR)
     {
         a0 = _mm256_set1_ps(a[0]), a1 = _mm256_set1_ps(a[1]);
         b0 = _mm256_load_ps(b), b1 = _mm256_load_ps(b + 8), b2 = _mm256_load_ps(b + 16);
 
-        c0 = _mm256_fmadd_ps(b0, a0, c0);
-        c1 = _mm256_fmadd_ps(b1, a0, c1);
-        c2 = _mm256_fmadd_ps(b2, a0, c2);
+        c00 = _mm256_fmadd_ps(b0, a0, c00);
+        c01 = _mm256_fmadd_ps(b1, a0, c01);
+        c02 = _mm256_fmadd_ps(b2, a0, c02);
 
-        c3 = _mm256_fmadd_ps(b0, a1, c3);
-        a0 = _mm256_set1_ps(a[2]);
-        c4 = _mm256_fmadd_ps(b1, a1, c4);
-        c5 = _mm256_fmadd_ps(b2, a1, c5);
+        c10 = _mm256_fmadd_ps(b0, a1, c10);
+        c11 = _mm256_fmadd_ps(b1, a1, c11);
+        c12 = _mm256_fmadd_ps(b2, a1, c12);
 
-        c6 = _mm256_fmadd_ps(b0, a0, c6);
-        a1 = _mm256_set1_ps(a[3]);
-        c7 = _mm256_fmadd_ps(b1, a0, c7);
-        c8 = _mm256_fmadd_ps(b2, a0, c8);
+        a0 = _mm256_set1_ps(a[2]), a1 = _mm256_set1_ps(a[3]);
 
-        c9 = _mm256_fmadd_ps(b0, a1, c9);
-        c10 = _mm256_fmadd_ps(b1, a1, c10);
-        c11 = _mm256_fmadd_ps(b2, a1, c11);
+        c20 = _mm256_fmadd_ps(b0, a0, c20);
+        c21 = _mm256_fmadd_ps(b1, a0, c21);
+        c22 = _mm256_fmadd_ps(b2, a0, c22);
+
+        c30 = _mm256_fmadd_ps(b0, a1, c30);
+        c31 = _mm256_fmadd_ps(b1, a1, c31);
+        c32 = _mm256_fmadd_ps(b2, a1, c32);
     }
 
-    if (ifActiv)
+    if (!init_c)
     {
-        c0 = _mm256_min_ps(_mm256_max_ps(c0, vminval), vmaxval);
-        c1 = _mm256_min_ps(_mm256_max_ps(c1, vminval), vmaxval);
-        c2 = _mm256_min_ps(_mm256_max_ps(c2, vminval), vmaxval);
-        c3 = _mm256_min_ps(_mm256_max_ps(c3, vminval), vmaxval);
-        c4 = _mm256_min_ps(_mm256_max_ps(c4, vminval), vmaxval);
-        c5 = _mm256_min_ps(_mm256_max_ps(c5, vminval), vmaxval);
-        c6 = _mm256_min_ps(_mm256_max_ps(c6, vminval), vmaxval);
-        c7 = _mm256_min_ps(_mm256_max_ps(c7, vminval), vmaxval);
-        c8 = _mm256_min_ps(_mm256_max_ps(c8, vminval), vmaxval);
-        c9 = _mm256_min_ps(_mm256_max_ps(c9, vminval), vmaxval);
-        c10 = _mm256_min_ps(_mm256_max_ps(c10, vminval), vmaxval);
-        c11 = _mm256_min_ps(_mm256_max_ps(c11, vminval), vmaxval);
+        c00 = _mm256_add_ps(c00, _mm256_load_ps(c));
+        c01 = _mm256_add_ps(c01, _mm256_load_ps(c + 8));
+        c02 = _mm256_add_ps(c02, _mm256_load_ps(c + 16));
+
+        c10 = _mm256_add_ps(c10, _mm256_load_ps(c + ldc));
+        c11 = _mm256_add_ps(c11, _mm256_load_ps(c + ldc + 8));
+        c12 = _mm256_add_ps(c12, _mm256_load_ps(c + ldc + 16));
+
+        c20 = _mm256_add_ps(c20, _mm256_load_ps(c + ldc*2));
+        c21 = _mm256_add_ps(c21, _mm256_load_ps(c + ldc*2 + 8));
+        c22 = _mm256_add_ps(c22, _mm256_load_ps(c + ldc*2 + 16));
+
+        c30 = _mm256_add_ps(c30, _mm256_load_ps(c + ldc*3));
+        c31 = _mm256_add_ps(c31, _mm256_load_ps(c + ldc*3 + 8));
+        c32 = _mm256_add_ps(c32, _mm256_load_ps(c + ldc*3 + 16));
     }
 
-    _mm256_storeu_ps(c, c0); _mm256_storeu_ps(c+8, c1); _mm256_storeu_ps(c+16, c2);
-    _mm256_storeu_ps(c + ldc, c3); _mm256_storeu_ps(c + ldc + 8, c4); _mm256_storeu_ps(c + ldc + 16, c5);
-    _mm256_storeu_ps(c + ldc*2, c6); _mm256_storeu_ps(c + ldc*2 + 8, c7); _mm256_storeu_ps(c + ldc*2 + 16, c8);
-    _mm256_storeu_ps(c + ldc*3, c9); _mm256_storeu_ps(c + ldc*3 + 8, c10); _mm256_storeu_ps(c + ldc*3 + 16, c11);
+    _mm256_storeu_ps(c, c00), _mm256_storeu_ps(c+8, c01), _mm256_storeu_ps(c+16, c02);
+    _mm256_storeu_ps(c + ldc, c10), _mm256_storeu_ps(c + ldc + 8, c11), _mm256_storeu_ps(c + ldc + 16, c12);
+    _mm256_storeu_ps(c + ldc*2, c20), _mm256_storeu_ps(c + ldc*2 + 8, c21), _mm256_storeu_ps(c + ldc*2 + 16, c22);
+    _mm256_storeu_ps(c + ldc*3, c30), _mm256_storeu_ps(c + ldc*3 + 8, c31), _mm256_storeu_ps(c + ldc*3 + 16, c32);
     _mm256_zeroupper();
 #else
-#error "unsupported FAST_CONV_MR and/or FAST_CONV_NR in convBlock_AVX2."
+#error "unsupported CONV_MR and/or CONV_NR in convBlock_AVX2."
 #endif
 }
 
@@ -78,7 +78,6 @@ void depthWiseBlock_AVX2(const float *inptr, float *outptr, const float *weights
                     int dilation_y, int stride_x, int stride_y, int inner_xleft, int inner_xright, int inner_ytop,
                     int inner_ybottom, bool ifMinMaxAct, bool useSIMD, bool is3x3)
 {
-    const int VECSZ = 8;
     __m256 vminval = _mm256_set1_ps(minval);
     __m256 vmaxval = _mm256_set1_ps(maxval);
 
@@ -175,7 +174,7 @@ void depthWiseBlock_AVX2(const float *inptr, float *outptr, const float *weights
                 {
                     if (dy0 == 3)
                     {
-                        for (; x0 <= x1 - VECSZ; x0 += VECSZ)
+                        for (; x0 <= x1 - FAST_VEC_NLANES; x0 += FAST_VEC_NLANES)
                         {
                             int xi_ = x0 * stride_x - pad_left;
                             const float *inptr_xi = inptr + Wi * yi_ + xi_;
@@ -251,7 +250,7 @@ void depthWiseBlock_AVX2(const float *inptr, float *outptr, const float *weights
                     }
                     else
                     {
-                        for (; x0 <= x1 - VECSZ; x0 += VECSZ)
+                        for (; x0 <= x1 - FAST_VEC_NLANES; x0 += FAST_VEC_NLANES)
                         {
                             int xi_ = x0 * stride_x - pad_left;
                             const float *inptr_xi = inptr + Wi * yi_ + xi_;
@@ -277,7 +276,7 @@ void depthWiseBlock_AVX2(const float *inptr, float *outptr, const float *weights
                 }
                 else
                 {
-                    for (; x0 <= x1 - VECSZ; x0 += VECSZ)
+                    for (; x0 <= x1 - FAST_VEC_NLANES; x0 += FAST_VEC_NLANES)
                     {
                         int xi_ = x0 * stride_x - pad_left, k = 0;
                         const float *inptr_xi = inptr + Wi * yi_ + xi_;
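
After this rewrite, convBlock_AVX2 is a pure GEMM micro-kernel: it computes a CONV_MR x CONV_NR (4 x 24) output tile, leaves bias and activation to the caller's epilogue, and uses init_c to either start a fresh accumulation or add onto partial sums from a previous C_BLOCK_SIZE slice. A portable scalar reference of that contract, written as a sketch under the CONV_MR = 4, CONV_NR = 24 assumption (not code from the patch):

// Scalar reference for the new convBlock() contract.
// a: packed weights, CONV_MR floats per step; b: packed input, CONV_NR floats per step;
// c: CONV_MR output rows with row stride ldc; init_c: overwrite c instead of accumulating into it.
static void convBlockRef(int np, const float* a, const float* b, float* c, int ldc, bool init_c)
{
    float acc[4][24] = {};  // local accumulators for the 4 x 24 tile
    for (int p = 0; p < np; p++, a += 4, b += 24)
        for (int i = 0; i < 4; i++)
            for (int j = 0; j < 24; j++)
                acc[i][j] += a[i] * b[j];

    for (int i = 0; i < 4; i++)
        for (int j = 0; j < 24; j++)
            c[i*ldc + j] = init_c ? acc[i][j] : (c[i*ldc + j] + acc[i][j]);
}
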
modules/dnn/src/layers/fast_convolution/fast_convolution.cpp
index 139ea7f..2af8363 100644
@@ -22,7 +22,7 @@ Ptr<FastConv2d> initFastConv2d(
         int dilation_x, int dilation_y,
         const std::vector<size_t>& pads_begin,
         const std::vector<size_t>& pads_end,
-        float* srcWeights,
+        InputArray _weightsMat,
         float* srcBias)
 {
     Ptr<FastConv2d> conv = makePtr<FastConv2d>();
@@ -43,33 +43,27 @@ Ptr<FastConv2d> initFastConv2d(
     conv->pad_bottom = pads_end[0];
     conv->pad_left = pads_begin[1];
     conv->pad_right = pads_end[1];
-
-    // store bias; append some zero's to make sure that
-    // we can always read FAST_CONV_MR elements starting from any valid index
-    {
-        int k = 0, nbias = K + FAST_CONV_MR-1;
-        conv->biasBuf.reserve(nbias);
-        float* biasBufPtr = conv->biasBuf.data();
-        for(; k < K; k++)
-            biasBufPtr[k] = srcBias ? srcBias[k] : 0.f;
-        for(; k < nbias; k++)
-            biasBufPtr[k] = 0.f;
-    }
+    Mat weightsMat = _weightsMat.getMat();
+    auto wShape = shape(weightsMat);
+    const size_t wstep = weightsMat.step1();
 
 #if CV_NEON // For now, winograd is ARM platform only.
-    if (ngroups == 1 && Hk ==3 && Wk == 3 && stride_x == 1 && stride_y == 1 && dilation_x == 1 && dilation_y ==1
-        && K >= 16 && C >= 16 )
+    if (ngroups == 1 && Hk ==3 && Wk == 3 && stride_x == 1 && stride_y == 1 &&
+        dilation_x == 1 && dilation_y ==1 && K >= 16 && C >= 16)
         conv->ifWinograd63 = true;
 #else
     conv->ifWinograd63 = false;
 #endif
 
+    float *srcWeights = (float *)weightsMat.data;
     if (ngroups > 1 && ngroups == K && ngroups == C)
     {
         // for depth-wise convolutions on NCHW data we just preserve the weights in KCHW layout,
         // but add some padding to make the weights array layout more SIMD-friendly
         int ksize = Hk*Wk;
-        int padded_ksize = ((ksize + FAST_VEC_NLANES-1)/FAST_VEC_NLANES)*FAST_VEC_NLANES;  // this code aims to let memory fit with vector size.
+
+        // Round the per-channel kernel size up to a multiple of the vector width (FAST_VEC_NLANES).
+        int padded_ksize = ((ksize + FAST_VEC_NLANES-1) / FAST_VEC_NLANES) * FAST_VEC_NLANES;
         int nweights = C*padded_ksize;
         conv->weightsBuf.reserve(nweights);
         float* weightsBufPtr = conv->weightsBuf.data();
@@ -77,340 +71,80 @@ Ptr<FastConv2d> initFastConv2d(
         for(int c = 0; c < C; c++)
         {
             for (int k = 0; k < ksize; k++)
-                weightsBufPtr[c*padded_ksize + k] = srcWeights[c*ksize + k];
+                weightsBufPtr[c*padded_ksize + k] = srcWeights[c*wstep + k];
         }
     }
     else
     {
         // The weights are packed as
-        // ngroups x (ceil((K/ngroups)/FAST_CONV_MR)*FAST_CONV_MR) x (Cg*Hk*Wk) x FAST_CONV_MR tensor
+        // ngroups x (ceil((K/ngroups)/CONV_MR)*CONV_MR) x (Cg*Hk*Wk) x CONV_MR tensor
         int Kg = K/ngroups, Cg = max(C/ngroups, 1);
-        int Kg_aligned = ((Kg + FAST_CONV_MR - 1)/FAST_CONV_MR)*FAST_CONV_MR;
-        size_t nweights = ngroups*Kg_aligned*Cg*Hk*Wk;
+        int numStripsMR = (Kg + CONV_MR - 1) / CONV_MR;
+        int Kg_aligned = numStripsMR * CONV_MR;
+        int HkWkCg = Hk*Wk*Cg;
+        size_t nweights = ngroups*Kg_aligned*HkWkCg;
         conv->weightsBuf.reserve(nweights);
         float* weightsBufPtr = conv->weightsBuf.data();
         memset(weightsBufPtr, 0, nweights*sizeof(weightsBufPtr[0]));
-        float* packed_wptr = weightsBufPtr;
 
-        // pack the weight.
-        for(int g = 0; g < ngroups; g++)
+        // Pack the weights.
+        parallel_for_(Range(0, ngroups * numStripsMR), [&](const Range& r0){
+        for (int gsi = r0.start; gsi < r0.end; gsi++)
         {
-            for(int k0 = 0; k0 < Kg_aligned; k0 += FAST_CONV_MR)
-            {
-                int dk = Kg - k0 < FAST_CONV_MR ? Kg - k0 : FAST_CONV_MR;
-                for(int c = 0; c < Cg; c++)
-                {
-                    for(int yx = 0; yx < Hk*Wk; yx++, packed_wptr += FAST_CONV_MR)
-                    {
-                        const float* wptr = srcWeights + ((g*Kg + k0)*Cg + c)*Hk*Wk + yx;
-                        int k = 0;
-                        for(; k < dk; k++, wptr += Cg*Hk*Wk)
-                            packed_wptr[k] = *wptr;
-                        for(; k < FAST_CONV_MR; k++)
-                            packed_wptr[k] = 0.f;
-                    }
-                }
-            }
-        }
-
-        // Prepare Weight for Winograd F(6x6, 3x3)
-        if (conv->ifWinograd63)
-        {
-            initWinograd63(conv, srcWeights, K, C);
-        }
-    }
-    return conv;
-}
-
-static void packInput(float* inpbuf, const float* inptr, int* yxtab, int ksize, int Cg, int Hi, int Wi, int W0,
-                         int pad_top, int pad_left, int stride_x, int stride_y, int yx0, int slice_len,
-                         bool fast_1x1, bool partial0, bool s1d1p0, bool s1d1)
-{
-    const size_t inp_planesize = (size_t)Hi*Wi;
+            int g = gsi / numStripsMR;
+            int si = gsi - g * numStripsMR;
 
-    if (fast_1x1)
-    {
-        /*
-           super-fast branch for 1x1 convolutions with sy=sx=1.
-           in this case each feature plane can be safely treated
-           as 1D array and we just extract next portion
-           of FAST_CONV_NR elements from each feature plane and
-           put it together.
-        */
-        inptr += yx0;
-        if (!partial0)
-        {
-            // Make special branch where memcpy() is called with a constant buffer size.
-            // Compilers will likely unroll this loop properly.
-            for (int c = 0; c < Cg; c++, inptr += inp_planesize, inpbuf += FAST_CONV_NR)
-                memcpy(inpbuf, inptr, FAST_CONV_NR * sizeof(inpbuf[0]));
-        }
-        else
-        {
-            for (int c = 0; c < Cg; c++, inptr += inp_planesize, inpbuf += FAST_CONV_NR)
-            {
-                memcpy(inpbuf, inptr, slice_len * sizeof(inpbuf[0]));
-                memset(inpbuf + slice_len, 0, (FAST_CONV_NR - slice_len) * sizeof(inpbuf[0]));
-            }
-        }
-    }
-    else if (s1d1p0)
-    {
-        /*
-         slower, but still fast branch for sy=sx=1, dy=dx=1 and without padding,
-         in this case we copy data from input tensors by chunks.
-         */
-        for (int c = 0; c < Cg; c++)
-        {
-            float *inpbuf_c = inpbuf + c * (FAST_CONV_NR * ksize);
-            const float *inptr_c = inptr + c * inp_planesize;
+            int startK = si * CONV_MR;
+            CV_Assert(startK < Kg_aligned);
 
-            for (int k = 0; k < ksize; k++)
-            {
-                int y0 = yx0 / W0, x0 = yx0 % W0;
-                int yi = y0 + yxtab[k * 2], xi = x0 + yxtab[k * 2 + 1];
-                float *inpbuf_k = inpbuf_c + k * FAST_CONV_NR;
-                int xi_0 = yxtab[k * 2 + 1];
+            float* packed_wptr = weightsBufPtr + HkWkCg * (startK + g * Kg_aligned);
+            int dk = Kg - startK < CONV_MR ? Kg - startK : CONV_MR; // number of valid output channels in this strip; the rest is zero-padded.
 
-                int i = 0;
-                for (; i < slice_len;)
+            int k_idx = g*Kg + startK;
+            for(int yx = 0; yx < Hk*Wk; yx++) {
+                for(int c = 0; c < Cg; c++, packed_wptr += CONV_MR)
                 {
-                    const float *inptr_k = inptr_c + yi * Wi + xi;
-                    int copy_len = std::min(slice_len - i, W0 - x0);
-                    int di_z = (slice_len == i + copy_len) ? FAST_CONV_NR - slice_len : 0;
-
-                    memcpy(inpbuf_k + i,
-                           inptr_k,
-                           copy_len * sizeof(inpbuf_k[0]));
-
-                    memset(inpbuf_k + i + copy_len,
-                           0, di_z * sizeof(inpbuf_k[0]));
-
-                    i += copy_len;
-                    x0 = 0;
-                    xi = xi_0;
-                    yi++;
+                    const float* wptr = srcWeights + wstep * k_idx + c*Hk*Wk + yx;
+                    int k = 0;
+                    for(; k < dk; k++, wptr += wstep)
+                        packed_wptr[k] = *wptr;
+                    for(; k < CONV_MR; k++)
+                        packed_wptr[k] = 0.f;
                 }
             }
-        }
-    }
-    else if (s1d1)
-    {
-        /*
-         slower, but still fast branch for sy=sx=1, dy=dx=1.
-         in this case we copy data from input tensors by chunks and
-         interleave the data in inpbuf with 0's
-         (that correspond to the padding elements) when necessary
-         */
-        int y0 = yx0 / W0, x0 = yx0 % W0;
-        for (int c = 0; c < Cg; c++)
-        {
-            float *inpbuf_c = inpbuf + c * (FAST_CONV_NR * ksize);
-            const float *inptr_c = inptr + c * inp_planesize;
-
-            for (int k = 0; k < ksize; k++)
-            {
-                int x0_tmp = x0;
-
-                int xi_0 = yxtab[k * 2 + 1] - pad_left;
-
-                int yi = y0 + yxtab[k * 2] - pad_top, xi = x0_tmp + xi_0;
-                float *inpbuf_k = inpbuf_c + k * FAST_CONV_NR;
+        }});
 
-                int i = 0;
-                for (; i < slice_len;) {
-                    int copyLen = std::min(slice_len - i, W0 - x0_tmp);
-
-                    int di_z = (i + copyLen == slice_len) ? FAST_CONV_NR - slice_len
-                                                          : 0; // The final padding.
-                    // pad_top or pad bottom
-                    if (yi < 0 || yi > Hi - 1)
-                    {
-                        memset(inpbuf_k + i,
-                               0, (copyLen + di_z) * sizeof(inpbuf_k[0]));
-                        i += copyLen + di_z;
-                    }
-                    else
-                    {
-                        int x_pad_left = 0, x_pad_right = 0;
-
-                        // pad_left
-                        if (xi < 0)
-                        {
-                            x_pad_left = std::min(-xi, copyLen);
-                            xi = 0;
-                            copyLen -= x_pad_left;
-                        }
-
-                        memset(inpbuf_k + i,
-                               0, x_pad_left * sizeof(inpbuf_k[0]));
-                        i += x_pad_left;
-
-                        // pad right
-                        if (xi + copyLen > Wi)
-                        {
-                            if (xi > Wi)
-                            {
-                                x_pad_right = copyLen;
-                                copyLen = 0;
-                            }
-                            else
-                            {
-                                x_pad_right = std::min(xi + copyLen - Wi, copyLen);
-                                copyLen -= x_pad_right;
-                            }
-                        }
-
-                        CV_Assert(copyLen >= 0);
-
-                        const float *inptr_k = inptr_c + yi * Wi + xi;
-                        memcpy(inpbuf_k + i,
-                               inptr_k,
-                               copyLen * sizeof(inpbuf_k[0]));
-
-                        i += copyLen;
-
-                        // pad_right and the final padding.
-                        memset(inpbuf_k + i,
-                               0, (di_z + x_pad_right) * sizeof(inpbuf_k[0]));
-                        i += x_pad_right + di_z;
-                    }
-
-                    x0_tmp = 0;
-                    xi = xi_0;
-                    yi++;
-                }
-            }
-        }
-    }
-    else
-    {
-        int y0_ = yx0 / W0, x0_ = yx0 - y0_ * W0;
-        for (int k = 0; k < ksize; k++)
+        // Prepare Weight for Winograd F(6x6, 3x3)
+        if (conv->ifWinograd63)
         {
-            int dy = yxtab[k * 2], dx = yxtab[k * 2 + 1];
-            int i = 0, y0 = y0_, x0 = x0_;
-            for (; i < FAST_CONV_NR;)
-            {
-                float *inpbuf_ki = inpbuf + k * FAST_CONV_NR + i;
-                int yi = y0 * stride_y + dy - pad_top;
-                int xi = x0 * stride_x + dx - pad_left;
-
-                if ((unsigned) yi < (unsigned) Hi &&
-                    (unsigned) xi < (unsigned) Wi)
-                {
-                    const float *inptr_ki = inptr + yi * Wi + xi;
-                    if (i + 4 <= FAST_CONV_NR && x0 + 4 <= W0 && xi + stride_x * 4 <= Wi)
-                    {
-                        if (stride_x == 2) {
-                            for (int c = 0; c < Cg; c++, inpbuf_ki += FAST_CONV_NR *
-                                                                      ksize, inptr_ki += inp_planesize)
-                            {
-                                float t0 = inptr_ki[0], t1 = inptr_ki[2];
-                                float t2 = inptr_ki[4], t3 = inptr_ki[6];
-                                inpbuf_ki[0] = t0;
-                                inpbuf_ki[1] = t1;
-                                inpbuf_ki[2] = t2;
-                                inpbuf_ki[3] = t3;
-                            }
-                        }
-                        else
-                        {
-                            for (int c = 0; c < Cg; c++, inpbuf_ki += FAST_CONV_NR *
-                                                                      ksize, inptr_ki += inp_planesize)
-                            {
-                                float t0 = inptr_ki[0], t1 = inptr_ki[stride_x];
-                                float t2 = inptr_ki[stride_x * 2], t3 = inptr_ki[stride_x * 3];
-                                inpbuf_ki[0] = t0;
-                                inpbuf_ki[1] = t1;
-                                inpbuf_ki[2] = t2;
-                                inpbuf_ki[3] = t3;
-                            }
-                        }
-                        i += 4;
-                        x0 += 4;
-                    }
-                    else
-                    {
-                        for (int c = 0; c < Cg; c++, inpbuf_ki += FAST_CONV_NR *
-                                                                  ksize, inptr_ki += inp_planesize)
-                            *inpbuf_ki = *inptr_ki;
-                        i++;
-                        x0++;
-                    }
-                }
-                else
-                {
-                    for (int c = 0; c < Cg; c++, inpbuf_ki += FAST_CONV_NR * ksize)
-                        inpbuf_ki[0] = 0.f;
-                    i++;
-                    x0++;
-                }
-                int mask = x0 >= W0;
-                y0 += mask;
-                x0 &= mask - 1;
-            }
+            initWinograd63(conv, weightsMat, K, C);
         }
     }
-}
-
-static void matMulCompute(float* outptr0, float* inpbuf_task, float* cbuf, const Ptr<FastConv2d>& conv, int HkWkCg,
-                          int k0, int k1, int yx0, int yx1, size_t out_planesize, int g, int Kg, int Kg_aligned,
-                          bool partial0, ActivationLayer*& activ, float minval, float maxval, bool ifMinMaxAct)
-{
-    int outstep0 = out_planesize;
 
-    for (int k = k0; k < k1; k += FAST_CONV_MR, outptr0 += outstep0 * FAST_CONV_MR)
+    // Store the bias; append some zeros to make sure that
+    // we can always read CONV_MR elements starting from any valid index.
     {
-        int dk = Kg - k < FAST_CONV_MR ? Kg - k : FAST_CONV_MR;
-        bool partial = partial0 || dk < FAST_CONV_MR;
-        float *outptr = outptr0;
-
-        int outstep = outstep0;
-        if (partial)
-        {
-            outptr = cbuf;
-            outstep = FAST_CONV_NR;
-        }
-
-
-#if CV_TRY_AVX2
-        if (conv->useAVX2)
-            opt_AVX2::convBlock_AVX2( HkWkCg, conv->weightsBuf.data() + (g * Kg_aligned + k) * HkWkCg,
-                                  inpbuf_task, outptr, outstep, conv->biasBuf.data() + Kg * g + k,
-                                  minval, maxval, ifMinMaxAct);
-        else
-#endif
-#if CV_TRY_NEON
-        if (conv->useNEON)
-            opt_NEON::convBlock_NEON(HkWkCg, conv->weightsBuf.data() + (g * Kg_aligned + k) * HkWkCg,
-                                 inpbuf_task, outptr, outstep, conv->biasBuf.data() + Kg * g + k,
-                                 minval, maxval, ifMinMaxAct);
-        else
-#endif
-            convBlock(HkWkCg, conv->weightsBuf.data() + (g * Kg_aligned + k) * HkWkCg,
-                            inpbuf_task, outptr, outstep, conv->biasBuf.data() + Kg * g + k,
-                            minval, maxval, ifMinMaxAct);
-
-        // activation
-        if (activ)
-            activ->forwardSlice(outptr, outptr, yx1 - yx0, outstep, Kg * g + k,
-                                Kg * g + k + dk);
-
-        if (partial)
-        {
-            for (int i = 0; i < dk; i++)
-                memcpy(outptr0 + i * outstep0, cbuf + i * FAST_CONV_NR,
-                       (yx1 - yx0) * sizeof(cbuf[0]));
-        }
+        int k = 0, nbias = K + CONV_MR - 1;
+        conv->biasBuf.reserve(nbias);
+        float* biasBufPtr = conv->biasBuf.data();
+        for(; k < K; k++)
+            biasBufPtr[k] = srcBias ? srcBias[k] : 0.f;
+        for(; k < nbias; k++)
+            biasBufPtr[k] = 0.f;
     }
+    return conv;
 }
 
-void runFastConv2d(InputArray _input, OutputArray _output,
-                   const Ptr<FastConv2d>& conv, int ntasks, const Ptr<ActivationLayer>& actLayer)
+void runFastConv2d(InputArray _input, OutputArray _output, const Ptr<FastConv2d>& conv, int ntasks,
+                   const Ptr<ActivationLayer>& actLayer, bool fusedAdd)
 {
     Mat input = _input.getMat();
     Mat output = _output.getMat();
+
+    Mat fusedAddMat;
+    if (fusedAdd)
+        fusedAddMat = _output.getMat();
+
     MatShape inputShape = shape(input);
     MatShape outputShape = shape(output);
     CV_Assert(inputShape.size() == 4 && outputShape.size() == 4);
@@ -452,93 +186,69 @@ void runFastConv2d(InputArray _input, OutputArray _output,
 
     if (conv->ngroups  > 1 && conv->ngroups == conv->K && conv->ngroups == conv->C)
     {
+        CV_Assert(fusedAddMat.empty()); // The depthwise-convolution branch does not support a fused Add layer.
         return runDepthwise(input, output, conv, minval, maxval, activ, ifMinMaxAct);
     }
 
 #if CV_NEON
-    if ( conv->ifWinograd63
+    if (conv->ifWinograd63
          && inputShape[2] > 12 && inputShape[3] > 12
-         && inputShape[2] < 120 && inputShape[3] < 120 )
+         && inputShape[2] < 120 && inputShape[3] < 120
+         )
     {
-        // In general, for winograd branch, more cores will give better performance.
-        int maxNumThread = std::max(getNumThreads(), 1);
-        if (runWinograd63(input, output, conv, maxNumThread, minval, maxval, activ, ifMinMaxAct))
+        if (runWinograd63(input, fusedAddMat, output, conv, ntasks, minval, maxval, activ, ifMinMaxAct))
             return;
     }
 #endif
 
-    float* inp = input.ptr<float>();
-    float* out = output.ptr<float>();
-
     int N = inputShape[0], C = inputShape[1], Hi = inputShape[2], Wi = inputShape[3];  // [N, C, H, W]
     int K = conv->K, Hk = conv->Hk, Wk = conv->Wk;
-    int H0 = outputShape[2], W0 = outputShape[3], ngroups = conv->ngroups;         // ngroups
+    int H0 = outputShape[2], W0 = outputShape[3], ngroups = conv->ngroups;
     int Cg = C/ngroups, Kg = K/ngroups;
-    int Kg_nblocks = (Kg + FAST_CONV_MR-1)/FAST_CONV_MR, Kg_aligned = Kg_nblocks*FAST_CONV_MR; // align to MR
 
     const size_t inp_planesize = (size_t)Hi*Wi;
     const size_t out_planesize = (size_t)H0*W0;
 
-    int pad_top = conv->pad_top, pad_bottom = conv->pad_bottom;
+    int pad_top = conv->pad_top;
     int pad_left = conv->pad_left;
-    int pad_right = conv->pad_right;
 
     int stride_y = conv->stride_y, stride_x = conv->stride_x;
     int dilation_y = conv->dilation_y, dilation_x = conv->dilation_x;
 
     int ksize = Hk * Wk;
-    bool s1d1 = stride_x == 1 && stride_y == 1 && dilation_x == 1 && dilation_y == 1;
-    bool s1d1p0 = s1d1 && pad_top == 0 && pad_left ==0 && pad_bottom == 0 && pad_right == 0;
     bool fast_1x1 = stride_x == 1 && stride_y == 1 && ksize == 1;
     int HkWkCg = Hk*Wk*Cg;
 
-    enum { VEC_ALIGN = 8, DFT_TYPE = CV_32F };
-    size_t taskbufsize = FAST_CONV_NR*HkWkCg; // input buffer
-    size_t taskbufsizeOutput = FAST_CONV_NR * FAST_CONV_MR;
-    size_t inputbufsize = 0;
-    size_t outbufsize = ntasks * taskbufsizeOutput;
+    enum { VEC_ALIGN = 8, DFT_TYPE = CV_32F }; // Memory alignment.
+    int MAX_STRIPES = 2; // (56 + CONV_NR - 1)/CONV_NR;
+
+    // Block sizes chosen to be friendly to the L1 cache.
+    const int K_BLOCK_SIZE = 32;
+    const int C_BLOCK_SIZE = 256;
 
-    int stripes_per_sample = (out_planesize + FAST_CONV_NR - 1)/FAST_CONV_NR; // align to NR
-    size_t hw_task = stripes_per_sample;
-    size_t hw_aligned = stripes_per_sample * FAST_CONV_NR;
+    int Kg_nblocks = (Kg + CONV_MR-1)/CONV_MR, Kg_aligned = Kg_nblocks * CONV_MR;
 
-    bool separatedLoop = false;
+    int stripes_per_sample = (out_planesize + CONV_NR - 1) / CONV_NR;
 
-    if (stripes_per_sample < 4 * ntasks)
+    if (stripes_per_sample < ntasks * 4)
     {
-        // If stripes_per_sample is small, we parallelize on K (output channel).
+        MAX_STRIPES = 1;
         stripes_per_sample = 1;
-
-        // Separated Parallelloop could save much time in packing input data. But it may cost more memory, we use it when batch size is 1.
-        if (N == 1)
-        {
-            separatedLoop = true;
-            inputbufsize = ngroups * hw_aligned * HkWkCg;
-        }
-
-        if (!separatedLoop)
-        {
-            inputbufsize = taskbufsize * ntasks;
-        }
     }
     else
-    {
-        // If stripes_per_sample is big, we parallelize on H0*W0.
         Kg_nblocks = 1;
-        inputbufsize = taskbufsize * ntasks;
-    }
 
     int Kstripes = Kg_nblocks*stripes_per_sample;
     int nsubtasks = N*ngroups*Kstripes;
 
-    AutoBuffer<float> inpbuf_all_, outputbuf_;
-    inputbufsize = alignSize(inputbufsize, VEC_ALIGN);
-    inpbuf_all_.allocate(inputbufsize + VEC_ALIGN);
-    float* inpbuf_all = alignPtr(inpbuf_all_.data(), (int)(VEC_ALIGN*sizeof(float)));
+    size_t stripesize = CONV_NR * ksize * Cg;
+    size_t taskbufsize = (stripesize + CONV_NR * K_BLOCK_SIZE) * MAX_STRIPES;
+    size_t totalbufsize = taskbufsize * ntasks;
 
-    outbufsize = alignSize(outbufsize, VEC_ALIGN);
-    outputbuf_.allocate(outbufsize + VEC_ALIGN);
-    float* output_buf = alignPtr(outputbuf_.data(), (int)(VEC_ALIGN*sizeof(float)));
+    AutoBuffer<float> inpbuf_all_;
+    totalbufsize = alignSize(totalbufsize, VEC_ALIGN);
+    inpbuf_all_.allocate(totalbufsize + VEC_ALIGN);
+    float* inpbuf_all = alignPtr(inpbuf_all_.data(), (int)(VEC_ALIGN*sizeof(inpbuf_all_[0])));
 
     std::vector<int> ofstab_(Hk*Wk*3, 0);
     int* ofstab = ofstab_.data();
@@ -554,141 +264,306 @@ void runFastConv2d(InputArray _input, OutputArray _output,
             ofstab[k] = dy*Wi + dx;
         }
 
-    if (ksize == 1)
-    {
-        CV_Assert(pad_left == 0 && pad_right == 0 && pad_top == 0 && pad_bottom == 0);
-        CV_Assert(stride_x != 1 || stride_y != 1 || (H0 == Hi && W0 == Wi));
-    }
+    float* inp = input.ptr<float>();
+    float* out = output.ptr<float>();
+    float* fusedAddPtr0 = fusedAddMat.empty() ? 0 : fusedAddMat.ptr<float>();
 
-    if (separatedLoop)
+    parallel_for_(Range(0, ntasks), [&](const Range& r0) {
+    for (int task_id = r0.start; task_id < r0.end; task_id++)
     {
-        // For now this branch only handles batch size = 1. Maybe we could support batch size < 10 in the future.
-        // Pack Input data
-        parallel_for_(Range(0, ngroups * hw_task), [&](const Range& r0)
+        float* inpbuf_task = &inpbuf_all[taskbufsize * task_id];
+        float* cbuf_task = inpbuf_task + stripesize * MAX_STRIPES;
+
+        int ngs0 = (int)((size_t)nsubtasks * task_id / ntasks);
+        int ngs1 = (int)((size_t)nsubtasks * (task_id+1) / ntasks);
+        for (int subtask = ngs0; subtask < ngs1; )
         {
-            for (int nhwi = r0.start; nhwi < r0.end; nhwi++)
+            int ng = subtask / Kstripes;
+            int kyx0 = subtask - ng * Kstripes;
+            int kyx1 = kyx0 + (ngs1 - subtask);
+            int n = ng / ngroups, g = ng % ngroups; // ng - n * ngroups;
+            size_t inp_plane_ofs = (size_t)(n * ngroups + g) * Cg * inp_planesize;
+            kyx1 = kyx1 <= Kstripes ? kyx1 : Kstripes;
+            subtask += kyx1 - kyx0;
+            int k0, k1;
+            int yx0, yx_limit, yx_block_limit = 0;
+
+            if (stripes_per_sample == 1)
             {
-                int g = nhwi/hw_task;
-                int hw_i = nhwi % hw_task;
-                int hw0 = hw_i * FAST_CONV_NR;
-                float* inpbuf = inpbuf_all + g * hw_aligned * HkWkCg + hw0 * HkWkCg;
-                const float* inptr = inp + g * Cg * inp_planesize;
-                bool partial0 = hw0 + FAST_CONV_NR > out_planesize? true: false;
-                int slice_len = FAST_CONV_NR;
-
-                if (partial0)
-                    slice_len = out_planesize - hw0;
-
-                packInput(inpbuf, inptr, yxtab, ksize, Cg, Hi, Wi, W0, pad_top, pad_left, stride_x, stride_y,
-                          hw0, slice_len, fast_1x1, partial0, s1d1p0, s1d1);
+                k0 = kyx0 * CONV_MR;
+                k1 = kyx1 * CONV_MR;
+                k1 = k1 <= Kg ? k1 : Kg;
+                yx0 = 0;
+                yx_limit = out_planesize;
             }
-        });
-
-        // Compute
-        parallel_for_(Range(0, ntasks), [&](const Range& r0)
-        {
-            for (int task_id = r0.start; task_id < r0.end; task_id++)
+            else
             {
-                float *cbuf = output_buf + task_id * taskbufsizeOutput;
-                int ngs0 = (int) ((size_t) nsubtasks * task_id / ntasks);
-                int ngs1 = (int) ((size_t) nsubtasks * (task_id + 1) / ntasks);
-                for (int subtask = ngs0; subtask < ngs1;)
-                {
-                    int ng = subtask / Kstripes;
-                    int kyx0 = subtask - ng * Kstripes;
-                    int kyx1 = kyx0 + (ngs1 - subtask);
-                    int n = ng / ngroups, g = ng - n * ngroups;
-
-                    CV_Assert(n <= 1);
+                k0 = 0;
+                k1 = Kg;
+                yx0 = kyx0 * CONV_NR;
+                yx_limit = kyx1 * CONV_NR;
+                yx_limit = yx_limit < out_planesize ? yx_limit : out_planesize;
+            }
 
-                    kyx1 = kyx1 <= Kstripes ? kyx1 : Kstripes; // Guarantee that maximum kyx1 is Kstripes.
-                    subtask += kyx1 - kyx0;
+            for (; yx0 < yx_limit; yx0 = yx_block_limit)
+            {
+                // Step 1. Extract a part of the input tensor and represent it in zigzag (packed) form.
+                yx_block_limit = yx0 + CONV_NR * MAX_STRIPES;
+                yx_block_limit = yx_block_limit < yx_limit ? yx_block_limit : yx_limit;
 
-                    int k0 = kyx0 * FAST_CONV_MR;
-                    int k1 = kyx1 * FAST_CONV_MR;
-                    k1 = k1 <= Kg ? k1 : Kg;
+                int nstripes = (yx_block_limit - yx0 + CONV_NR - 1) / CONV_NR;
+                int yx0_saved = yx0;
 
+                CV_Assert(nstripes <= MAX_STRIPES);
 
-                    for (int yx0 = 0; yx0 < out_planesize; yx0 += FAST_CONV_NR)
-                    {
-                        float* inpbuf_task = inpbuf_all + g * hw_aligned * HkWkCg + yx0 * HkWkCg;
-                        int yx1 = yx0 + FAST_CONV_NR;
-                        yx1 = yx1 <= out_planesize ? yx1 : out_planesize;
-                        int slice_len = yx1 - yx0;
-                        bool partial0 = slice_len < FAST_CONV_NR;
-
-                        int outstep0 = out_planesize;
-                        size_t outofs = ((n * ngroups + g) * Kg + k0) * outstep0 + yx0;
-                        float *outptr0 = out + outofs;
-
-                        matMulCompute(outptr0, inpbuf_task, cbuf, conv, HkWkCg, k0, k1, yx0, yx1, out_planesize, g,
-                                      Kg, Kg_aligned, partial0, activ, minval, maxval, ifMinMaxAct);
-                    }
-                }
-            }
-        });
-    }
-    else
-    {
-        parallel_for_(Range(0, ntasks), [&](const Range &r0) {
-            for (int task_id = r0.start; task_id < r0.end; task_id++) {
-                float *inpbuf_task = &inpbuf_all[taskbufsize * task_id];
-                float *cbuf = output_buf + task_id * taskbufsizeOutput;
-                int ngs0 = (int) ((size_t) nsubtasks * task_id / ntasks);
-                int ngs1 = (int) ((size_t) nsubtasks * (task_id + 1) / ntasks);
-
-                for (int subtask = ngs0; subtask < ngs1;)
+                for (int stripe = 0; yx0 < yx_block_limit; stripe++, yx0 += CONV_NR)
                 {
-                    int ng = subtask / Kstripes;
-                    int kyx0 = subtask - ng * Kstripes;
-                    int kyx1 = kyx0 + (ngs1 - subtask);
-                    int n = ng / ngroups, g = ng - n * ngroups;
-                    size_t inp_plane_ofs = (size_t) (n * ngroups + g) * Cg * inp_planesize;
-                    kyx1 = kyx1 <= Kstripes ? kyx1 : Kstripes; // Guarantee that maximum kyx1 is Kstripes.
-                    subtask += kyx1 - kyx0;
-                    int k0, k1;
-                    int yx0, yx_limit;
-
-                    if (stripes_per_sample == 1)
+                    float* inpbuf = inpbuf_task + stripe * stripesize;
+                    float* inptr = inp + inp_plane_ofs;
+
+                    /*
+                        1. pack the data. Copy the HkxWk CONV_NR-wide slices from
+                           each feature plane of the input tensor to the input buffer.
+                    */
+                    if (fast_1x1)
                     {
-                        k0 = kyx0 * FAST_CONV_MR;
-                        k1 = kyx1 * FAST_CONV_MR;
-                        k1 = k1 <= Kg ? k1 : Kg;
-                        yx0 = 0;
-                        yx_limit = out_planesize;
+                        int slice_len = yx_block_limit - yx0;
+                        bool partial = slice_len < CONV_NR;
+                        // Superfast branch for 1x1 convolutions with sy=sx=1.
+                        // In this case each feature plane can be safely treated
+                        // as a 1D array, and we just extract the next portion
+                        // of CONV_NR elements from each feature plane and
+                        // put them together.
+                        inptr += yx0;
+                        if (!partial)
+                        {
+                            // Make special branch where memcpy() is called with a constant buffer size.
+                            // Compilers will likely unroll this loop properly.
+                            for (int c = 0; c < Cg; c++, inptr += inp_planesize, inpbuf += CONV_NR)
+                                memcpy(inpbuf, inptr, CONV_NR*sizeof(inpbuf[0]));
+                        }
+                        else
+                        {
+                            for (int c = 0; c < Cg; c++, inptr += inp_planesize, inpbuf += CONV_NR)
+                            {
+                                memcpy(inpbuf, inptr, slice_len * sizeof(inpbuf[0]));
+                                memset(inpbuf + slice_len, 0, (CONV_NR - slice_len) * sizeof(inpbuf[0]));
+                            }
+                        }
                     }
                     else
                     {
-                        k0 = 0;
-                        k1 = Kg;
-                        yx0 = kyx0 * FAST_CONV_NR;
-                        yx_limit = kyx1 * FAST_CONV_NR;
-                        yx_limit = yx_limit < out_planesize ? yx_limit : out_planesize;
+                        int y0_ = yx0 / W0, x0_ = yx0 - y0_ * W0;
+                        for (int k = 0; k < ksize; k++)
+                        {
+                            int dy = yxtab[k * 2], dx = yxtab[k * 2 + 1];
+                            int i = 0, y0 = y0_, x0 = x0_;
+                            for (; i < CONV_NR;)
+                            {
+                                float *inpbuf_ki = inpbuf + k * CONV_NR * Cg + i;
+                                int yi = y0 * stride_y + dy - pad_top;
+                                int xi = x0 * stride_x + dx - pad_left;
+
+                                if ((unsigned) yi < (unsigned) Hi && (unsigned) xi < (unsigned) Wi)
+                                {
+                                    const float *inptr_ki = inptr + yi * Wi + xi;
+                                    if (i + 8 <= CONV_NR && x0 + 8 <= W0 && xi + stride_x * 8 <= Wi)
+                                    {
+                                        if (stride_x == 1)
+                                        {
+                                            for (int c = 0; c < Cg; c++, inpbuf_ki += CONV_NR, inptr_ki += inp_planesize)
+                                            {
+                                                float t0 = inptr_ki[0], t1 = inptr_ki[1];
+                                                float t2 = inptr_ki[2], t3 = inptr_ki[3];
+                                                float t4 = inptr_ki[4], t5 = inptr_ki[5];
+                                                float t6 = inptr_ki[6], t7 = inptr_ki[7];
+                                                inpbuf_ki[0] = t0; inpbuf_ki[1] = t1;
+                                                inpbuf_ki[2] = t2; inpbuf_ki[3] = t3;
+                                                inpbuf_ki[4] = t4; inpbuf_ki[5] = t5;
+                                                inpbuf_ki[6] = t6; inpbuf_ki[7] = t7;
+                                            }
+                                        }
+                                        else
+                                        {
+                                            for (int c = 0; c < Cg; c++, inpbuf_ki += CONV_NR, inptr_ki += inp_planesize)
+                                            {
+                                                float t0 = inptr_ki[0], t1 = inptr_ki[stride_x];
+                                                float t2 = inptr_ki[stride_x*2], t3 = inptr_ki[stride_x*3];
+                                                float t4 = inptr_ki[stride_x*4], t5 = inptr_ki[stride_x*5];
+                                                float t6 = inptr_ki[stride_x*6], t7 = inptr_ki[stride_x*7];
+                                                inpbuf_ki[0] = t0; inpbuf_ki[1] = t1;
+                                                inpbuf_ki[2] = t2; inpbuf_ki[3] = t3;
+                                                inpbuf_ki[4] = t4; inpbuf_ki[5] = t5;
+                                                inpbuf_ki[6] = t6; inpbuf_ki[7] = t7;
+                                            }
+                                        }
+                                        i += 8;
+                                        x0 += 8;
+                                    }
+                                    else if (i + 4 <= CONV_NR && x0 + 4 <= W0 && xi + stride_x * 4 <= Wi)
+                                    {
+                                        if (stride_x == 1)
+                                        {
+                                            for (int c = 0; c < Cg; c++, inpbuf_ki += CONV_NR, inptr_ki += inp_planesize)
+                                            {
+                                                float t0 = inptr_ki[0], t1 = inptr_ki[1];
+                                                float t2 = inptr_ki[2], t3 = inptr_ki[3];
+                                                inpbuf_ki[0] = t0; inpbuf_ki[1] = t1;
+                                                inpbuf_ki[2] = t2; inpbuf_ki[3] = t3;
+                                            }
+                                        }
+                                        else
+                                        {
+                                            for (int c = 0; c < Cg; c++, inpbuf_ki += CONV_NR, inptr_ki += inp_planesize)
+                                            {
+                                                float t0 = inptr_ki[0], t1 = inptr_ki[stride_x];
+                                                float t2 = inptr_ki[stride_x*2], t3 = inptr_ki[stride_x*3];
+                                                inpbuf_ki[0] = t0; inpbuf_ki[1] = t1;
+                                                inpbuf_ki[2] = t2; inpbuf_ki[3] = t3;
+                                            }
+                                        }
+                                        i += 4;
+                                        x0 += 4;
+                                    }
+                                    else
+                                    {
+                                        for (int c = 0; c < Cg; c++, inpbuf_ki += CONV_NR, inptr_ki += inp_planesize)
+                                            *inpbuf_ki = *inptr_ki;
+                                        i++;
+                                        x0++;
+                                    }
+                                }
+                                else
+                                {
+                                    for (int c = 0; c < Cg; c++, inpbuf_ki += CONV_NR)
+                                        inpbuf_ki[0] = 0.f;
+                                    i++;
+                                    x0++;
+                                }
+                                int mask = x0 >= W0;
+                                y0 += mask;
+                                x0 &= mask - 1;
+                            }
+                        }
                     }
+                }
 
-                    for (; yx0 < yx_limit; yx0 += FAST_CONV_NR)
+                yx0 = yx0_saved;
+                float* weights = conv->weightsBuf.data() + g * Kg_aligned * HkWkCg;
+                const float* biasptr = conv->biasBuf.data() + Kg * g;
+                int ldc = nstripes * CONV_NR;
+
+                // Step 2. Do the convolution: compute the Kg x (yx_block_limit - yx0) part of the output tensor.
+                for (int k0_block = k0; k0_block < k1; k0_block += K_BLOCK_SIZE)
+                {
+                    int k1_block = k0_block + K_BLOCK_SIZE < k1 ? k0_block + K_BLOCK_SIZE : k1;
+                    for (int c0 = 0; c0 < HkWkCg; c0 += C_BLOCK_SIZE)
                     {
-                        float *inpbuf = inpbuf_task;
-                        const float *inptr = inp + inp_plane_ofs;
-                        int yx1 = yx0 + FAST_CONV_NR;
-                        yx1 = yx1 <= yx_limit ? yx1 : yx_limit;
-                        int slice_len = yx1 - yx0;
-                        bool partial0 = slice_len < FAST_CONV_NR;
-                        packInput(inpbuf, inptr, yxtab, ksize, Cg, Hi, Wi, W0, pad_top, pad_left, stride_x, stride_y,
-                                     yx0, slice_len, fast_1x1, partial0, s1d1p0, s1d1);
-
-                        // 2. do convolution, compute Kg x (yx1 - yx0) part of the output tensor
-                        int outstep0 = out_planesize;
-                        size_t outofs = ((n * ngroups + g) * Kg + k0) * outstep0 + yx0;
-                        float *outptr0 = out + outofs;
-
-                        matMulCompute(outptr0, inpbuf_task, cbuf, conv, HkWkCg, k0, k1, yx0, yx1, out_planesize, g,
-                                      Kg, Kg_aligned, partial0, activ, minval, maxval, ifMinMaxAct);
+                        int c1 = c0 + C_BLOCK_SIZE < HkWkCg ? c0 + C_BLOCK_SIZE : HkWkCg;
+                        for (int stripe = 0; stripe < nstripes; stripe++)
+                        {
+                            float* wptr = weights + k0_block*HkWkCg + c0*CONV_MR;
+                            const float* inptr = inpbuf_task + stripe*stripesize + c0 * CONV_NR;
+                            float* cptr = cbuf_task + stripe * CONV_NR;
+                            for (int k = k0_block; k < k1_block; k += CONV_MR,
+                                    wptr += HkWkCg * CONV_MR, cptr += CONV_MR * ldc)
+                            {
+#if CV_TRY_AVX2
+                                if (conv->useAVX2)
+                                    opt_AVX2::convBlock_AVX2(c1 - c0, wptr, inptr, cptr, ldc, c0 == 0);
+                                else
+#endif
+#if CV_TRY_NEON
+                                if (conv->useNEON)
+                                    opt_NEON::convBlock_NEON(c1 - c0, wptr, inptr, cptr, ldc, c0 == 0);
+                                else
+#endif
+                                    convBlock(c1 - c0, wptr, inptr, cptr, ldc, c0 == 0);
+                            }
+                        }
+                    }
+
+                    size_t outofs = ((n*ngroups + g) * Kg + k0_block) * out_planesize + yx0;
+                    int out_width = yx_block_limit - yx0;
+                    const float* cptr = cbuf_task;
+
+                    float* outptr = out + outofs;
+                    const float* pbptr = fusedAddPtr0 ? fusedAddPtr0 + outofs : 0;
+
+                    for (int k = k0_block; k < k1_block; k++,
+                            cptr += ldc, outptr += out_planesize,
+                            pbptr += (pbptr ? out_planesize : 0))
+                    {
+                        float biasval = biasptr[k];
+                        int j = 0;
+#if CV_SIMD128
+                        v_float32x4 vbias = v_setall_f32(biasval), vmax = v_setall_f32(maxval), vmin = v_setall_f32(minval);
+                        if (pbptr)
+                        {
+                            for (; j + 7 < out_width; j += 8)
+                            {
+                                v_float32x4 v0 = v_add(v_load(cptr + j), vbias);
+                                v_float32x4 v1 = v_add(v_load(cptr + j + 4), vbias);
+                                v0 = v_add(v0, v_load(pbptr + j));
+                                v1 = v_add(v1, v_load(pbptr + j + 4));
+
+                                if (ifMinMaxAct)
+                                {
+                                    v0 = v_min(v_max(v0, vmin), vmax);
+                                    v1 = v_min(v_max(v1, vmin), vmax);
+                                }
+
+                                v_store(outptr + j, v0);
+                                v_store(outptr + j + 4, v1);
+                            }
+                        }
+                        else
+                        {
+                            for (; j + 7 < out_width; j += 8)
+                            {
+                                v_float32x4 v0 = v_add(v_load(cptr + j), vbias);
+                                v_float32x4 v1 = v_add(v_load(cptr + j + 4), vbias);
+
+                                if (ifMinMaxAct)
+                                {
+                                    v0 = v_min(v_max(v0, vmin), vmax);
+                                    v1 = v_min(v_max(v1, vmin), vmax);
+                                }
+
+                                v_store(outptr + j, v0);
+                                v_store(outptr + j + 4, v1);
+                            }
+                        }
+#endif
+                        if (pbptr) {
+                            for (; j < out_width; j++)
+                            {
+                                float v = cptr[j] + biasval;
+                                v += pbptr[j];
+                                if (ifMinMaxAct)
+                                    v = std::min(std::max(v, minval), maxval);
+                                outptr[j] = v;
+                            }
+                        }
+                        else
+                        {
+                            for (; j < out_width; j++)
+                            {
+                                float v = cptr[j] + biasval;
+
+                                if (ifMinMaxAct)
+                                    v = std::min(std::max(v, minval), maxval);
+                                outptr[j] = v;
+                            }
+                        }
+
+                        if (activ)
+                            activ->forwardSlice(outptr, outptr, out_width, out_planesize, Kg * g + k, Kg * g + k + 1);
                     }
                 }
             }
-        });
+        }
     }
+    });
 }
-
-}} // namespace cv::dnn
\ No newline at end of file
+}} // namespace cv::dnn
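The per-row epilogue added above applies the channel bias, the optional fused-Add operand and the optional min/max clamp in a single pass over the finished GEMM row, with any remaining activation layer run afterwards via forwardSlice. A scalar sketch of the same steps, for reference only (the function name and standalone form are illustrative; the patch does this with universal intrinsics):

    #include <algorithm>

    // Scalar reference for the output epilogue of the generic conv branch:
    // 'cbuf' holds the accumulated GEMM row, 'fusedAdd' is the operand of a
    // fused Add (may be null), and ifMinMaxAct applies a clamp-style activation.
    static void epilogueRow(float* out, const float* cbuf, const float* fusedAdd,
                            float biasval, float minval, float maxval,
                            bool ifMinMaxAct, int width)
    {
        for (int j = 0; j < width; j++)
        {
            float v = cbuf[j] + biasval;
            if (fusedAdd)
                v += fusedAdd[j];
            if (ifMinMaxAct)
                v = std::min(std::max(v, minval), maxval);
            out[j] = v;
        }
    }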
index c993781..eda5c7c 100644 (file)
@@ -7,20 +7,26 @@
 
 #include "opencv2/core/hal/intrin.hpp"
 
-#ifndef FAST_CONV_PRAM
-#define FAST_CONV_PRAM
+#ifndef CONV_PRAM
+#define CONV_PRAM
 #if CV_NEON && CV_NEON_AARCH64  // 32 registers.
-#define FAST_CONV_MR 4
-#define FAST_CONV_NR 28
+#define CONV_MR 4
+#define CONV_NR 28
 enum { FAST_VEC_NLANES=4 };
 #elif CV_NEON              // 16 registers.
-#define FAST_CONV_MR 4
-#define FAST_CONV_NR 12
+#define CONV_MR 4
+#define CONV_NR 12
 enum { FAST_VEC_NLANES=4 };
 #else // SIMD 128, AVX or AVX2
-#define FAST_CONV_MR 4
-#define FAST_CONV_NR 24
-enum { FAST_VEC_NLANES=4 };
+#define CONV_MR 4
+#define CONV_NR 24
+
+#ifdef CV_AVX2
+enum { FAST_VEC_NLANES=8 }; // AVX2
+#else
+enum { FAST_VEC_NLANES=4 }; // SIMD 128
+#endif
+
 #endif
 #endif
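A note on the tile sizes above: CONV_MR x CONV_NR is chosen so that the accumulator tile fits the register file named in the comments (32 NEON registers on AArch64, 16 on 32-bit ARM, 16 ymm registers with AVX2). The throwaway sketch below only restates that arithmetic; nothing in it comes from the patch itself.

    #include <cstdio>

    // Accumulator-register count for a CONV_MR x CONV_NR micro-kernel with the
    // given SIMD lane width (4 for NEON/SIMD128, 8 for AVX2). Illustrative only.
    constexpr int accRegs(int mr, int nr, int lanes) { return mr * nr / lanes; }

    int main()
    {
        std::printf("AArch64 NEON (32 regs): %d accumulators\n", accRegs(4, 28, 4)); // 28
        std::printf("ARMv7 NEON   (16 regs): %d accumulators\n", accRegs(4, 12, 4)); // 12
        std::printf("AVX2         (16 regs): %d accumulators\n", accRegs(4, 24, 8)); // 12
        return 0;
    }

With plain 4-lane SIMD128 the same shape would need 24 accumulators, which may be why the universal-intrinsics kernel in fast_convolution.simd.hpp is currently compiled out in favor of a small stack buffer.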
 
@@ -37,7 +43,6 @@ struct FastConv2d
 
     std::vector<float> weightsBuf;        // For generic Conv 2D
     std::vector<float> weightsWino63Buf;  // For Winograd F(6x6, 3x3).
-
     std::vector<float> biasBuf;
     bool ifWinograd63 = false;
     bool useAVX2 = checkHardwareSupport(CPU_AVX2);
@@ -52,20 +57,20 @@ Ptr<FastConv2d> initFastConv2d(
         int dilation_x, int dilation_y,
         const std::vector<size_t>& pads_begin,
         const std::vector<size_t>& pads_end,
-        float* srcWeights,
+        InputArray weightsMat,
         float* srcBias);
 
 // It contains different computing branches, like winograd, 1x1 conv.
-void runFastConv2d(InputArray _input, OutputArray _output,
-                 const Ptr<FastConv2d>& conv, int ntasks, const Ptr<ActivationLayer>& actLayer);
+void runFastConv2d(InputArray _input, OutputArray _output, const Ptr<FastConv2d>& conv, int ntasks,
+                   const Ptr<ActivationLayer>& actLayer, bool fusedAdd);
 
 void runDepthwise(InputArray _input, OutputArray _output, const Ptr<FastConv2d>& conv, float minval, float maxval,
         ActivationLayer* activ, bool ifMinMaxAct);
 
 // winograd init
-void initWinograd63(Ptr<FastConv2d>& conv, float* src_weight, int K, int C);
+void initWinograd63(Ptr<FastConv2d>& conv, InputArray weightsMat, int K, int C);
 
-int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>& conv, int ntasks,
+int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _output, const Ptr<FastConv2d>& conv, int ntasks,
                   float minval, float maxval, ActivationLayer* activ, bool ifMinMaxAct);
 
 } // namespace dnn
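Given the declarations above, the convolution layer is expected to invoke the branch roughly as sketched below; only the runFastConv2d signature is taken from the patch, while the surrounding function, the variable names, and the assumption that the internal fast_convolution.hpp header is on the include path are illustrative.

    #include "fast_convolution/fast_convolution.hpp"  // internal OpenCV DNN header (assumed path)

    // Hypothetical call site, not the patch's actual forward() code.
    void forwardSketch(cv::InputArray input, cv::OutputArray output,
                       const cv::Ptr<cv::dnn::FastConv2d>& conv,
                       const cv::Ptr<cv::dnn::ActivationLayer>& activ,
                       int ntasks, bool fusedAdd)
    {
        // fusedAdd asks the branch to accumulate the second operand of a fused
        // Add into the output; activ may carry a fused activation layer.
        cv::dnn::runFastConv2d(input, output, conv, ntasks, activ, fusedAdd);
    }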
@@ -73,9 +78,7 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
 namespace opt_AVX2
 {
 #if CV_TRY_AVX2
-void convBlock_AVX2(int k, const float *a, const float *b,
-        float *c, int ldc, const float *bias,
-        float minval, float maxval, bool ifActiv);
+void convBlock_AVX2(int np, const float* a, const float* b, float* c, int ldc, bool init_c);
 
 void depthWiseBlock_AVX2(const float *inptr, float *outptr, const float *weights, float biasval, int *ofstab, int *yxtab,
                 float minval, float maxval, int Hi, int Wi, int H0, int W0, int ksize, int pad_top, int pad_left,
index 94b0814..2e088a6 100644 (file)
 namespace cv {
 namespace dnn {
 
-void convBlock(int k, const float *a, const float *b,
-        float *c, int ldc, const float *bias,
-        float minval, float maxval, bool ifActiv)
+void convBlock(int np, const float* a, const float* b, float* c, int ldc, bool init_c)
 {
-#if CV_SIMD128
-#if FAST_CONV_MR == 4 && FAST_CONV_NR == 24
-    {
-        v_float32x4 c0 = v_setall_f32(bias[0]), c1 = c0, c2 = c0, c3 = c0, c4 = c0, c5 = c0;
-        v_float32x4 c6 = v_setall_f32(bias[1]), c7 = c6, c8 = c6, c9 = c6, c10 = c6, c11 = c6;
-        v_float32x4 c12 = v_setall_f32(bias[2]), c13 = c12, c14 = c12, c15 = c12, c16 = c12, c17 = c12;
-        v_float32x4 c18 = v_setall_f32(bias[3]), c19 = c18, c20 = c18, c21 = c18, c22 = c18, c23 = c18;
+#if 0 // CV_SIMD128 && CONV_MR == 4 && CONV_NR == 24
+    v_float32x4 c0  = v_setzero_f32(), c1 = c0, c2 = c0, c3 = c0, c4 = c0, c5 = c0;
+    v_float32x4 c6  = v_setzero_f32(), c7 = c6, c8 = c6, c9 = c6, c10 = c6, c11 = c6;
+    v_float32x4 c12 = v_setzero_f32(), c13 = c12, c14 = c12, c15 = c12, c16 = c12, c17 = c12;
+    v_float32x4 c18 = v_setzero_f32(), c19 = c18, c20 = c18, c21 = c18, c22 = c18, c23 = c18;
 
-        for (int p = 0; p < k; p++, a += FAST_CONV_MR, b += FAST_CONV_NR)
-        {
-            v_float32x4 a0 = v_setall_f32(a[0]);
-            v_float32x4 b0 = v_load(b), b1 = v_load(b + 4), b2 = v_load(b + 8);
-            v_float32x4 b3 = v_load(b + 12), b4 = v_load(b + 16), b5 = v_load(b + 20);
-
-            c0 = v_fma(b0, a0, c0);
-            c1 = v_fma(b1, a0, c1);
-            c2 = v_fma(b2, a0, c2);
-            c3 = v_fma(b3, a0, c3);
-            c4 = v_fma(b4, a0, c4);
-            c5 = v_fma(b5, a0, c5);
-
-            a0  = v_setall_f32(a[1]);
-            c6  = v_fma(b0, a0, c6);
-            c7  = v_fma(b1, a0, c7);
-            c8  = v_fma(b2, a0, c8);
-            c9  = v_fma(b3, a0, c9);
-            c10 = v_fma(b4, a0, c10);
-            c11 = v_fma(b5, a0, c11);
-
-            a0 = v_setall_f32(a[2]);
-            c12 = v_fma(b0, a0, c12);
-            c13 = v_fma(b1, a0, c13);
-            c14 = v_fma(b2, a0, c14);
-            c15 = v_fma(b3, a0, c15);
-            c16 = v_fma(b4, a0, c16);
-            c17 = v_fma(b5, a0, c17);
-
-            a0 = v_setall_f32(a[3]);
-            c18 = v_fma(b0, a0, c18);
-            c19 = v_fma(b1, a0, c19);
-            c20 = v_fma(b2, a0, c20);
-            c21 = v_fma(b3, a0, c21);
-            c22 = v_fma(b4, a0, c22);
-            c23 = v_fma(b5, a0, c23);
-        }
-
-        if (ifActiv) {
-            v_float32x4 vmin = v_setall_f32(minval), vmax = v_setall_f32(maxval);
-            c0 = v_min(v_max(c0, vmin), vmax);
-            c1 = v_min(v_max(c1, vmin), vmax);
-            c2 = v_min(v_max(c2, vmin), vmax);
-            c3 = v_min(v_max(c3, vmin), vmax);
-            c4 = v_min(v_max(c4, vmin), vmax);
-            c5 = v_min(v_max(c5, vmin), vmax);
-            c6 = v_min(v_max(c6, vmin), vmax);
-            c7 = v_min(v_max(c7, vmin), vmax);
-            c8 = v_min(v_max(c8, vmin), vmax);
-            c9 = v_min(v_max(c9, vmin), vmax);
-            c10 = v_min(v_max(c10, vmin), vmax);
-            c11 = v_min(v_max(c11, vmin), vmax);
-            c12 = v_min(v_max(c12, vmin), vmax);
-            c13 = v_min(v_max(c13, vmin), vmax);
-            c14 = v_min(v_max(c14, vmin), vmax);
-            c15 = v_min(v_max(c15, vmin), vmax);
-            c16 = v_min(v_max(c16, vmin), vmax);
-            c17 = v_min(v_max(c17, vmin), vmax);
-            c18 = v_min(v_max(c18, vmin), vmax);
-            c19 = v_min(v_max(c19, vmin), vmax);
-            c20 = v_min(v_max(c20, vmin), vmax);
-            c21 = v_min(v_max(c21, vmin), vmax);
-            c22 = v_min(v_max(c22, vmin), vmax);
-            c23 = v_min(v_max(c23, vmin), vmax);
-        }
-        v_store(c, c0);
-        v_store(c + 4, c1);
-        v_store(c + 8, c2);
-        v_store(c + 12, c3);
-        v_store(c + 16, c4);
-        v_store(c + 20, c5);
-
-        v_store(c + ldc, c6);
-        v_store(c + ldc + 4, c7);
-        v_store(c + ldc + 8, c8);
-        v_store(c + ldc + 12, c9);
-        v_store(c + ldc + 16, c10);
-        v_store(c + ldc + 20, c11);
-
-        v_store(c + ldc * 2, c12);
-        v_store(c + ldc * 2 + 4, c13);
-        v_store(c + ldc * 2 + 8, c14);
-        v_store(c + ldc * 2 + 12, c15);
-        v_store(c + ldc * 2 + 16, c16);
-        v_store(c + ldc * 2 + 20, c17);
-
-        v_store(c + ldc * 3, c18);
-        v_store(c + ldc * 3 + 4, c19);
-        v_store(c + ldc * 3 + 8, c20);
-        v_store(c + ldc * 3 + 12, c21);
-        v_store(c + ldc * 3 + 16, c22);
-        v_store(c + ldc * 3 + 20, c23);
+    for (int p = 0; p < np; p++, a += CONV_MR, b += CONV_NR)
+    {
+        v_float32x4 a0 = v_setall_f32(a[0]);
+        v_float32x4 b0 = v_load(b), b1 = v_load(b + 4), b2 = v_load(b + 8);
+        v_float32x4 b3 = v_load(b + 12), b4 = v_load(b + 16), b5 = v_load(b + 20);
+
+        c0 = v_fma(b0, a0, c0);
+        c1 = v_fma(b1, a0, c1);
+        c2 = v_fma(b2, a0, c2);
+        c3 = v_fma(b3, a0, c3);
+        c4 = v_fma(b4, a0, c4);
+        c5 = v_fma(b5, a0, c5);
+
+        a0  = v_setall_f32(a[1]);
+        c6  = v_fma(b0, a0, c6);
+        c7  = v_fma(b1, a0, c7);
+        c8  = v_fma(b2, a0, c8);
+        c9  = v_fma(b3, a0, c9);
+        c10 = v_fma(b4, a0, c10);
+        c11 = v_fma(b5, a0, c11);
+
+        a0 = v_setall_f32(a[2]);
+        c12 = v_fma(b0, a0, c12);
+        c13 = v_fma(b1, a0, c13);
+        c14 = v_fma(b2, a0, c14);
+        c15 = v_fma(b3, a0, c15);
+        c16 = v_fma(b4, a0, c16);
+        c17 = v_fma(b5, a0, c17);
+
+        a0 = v_setall_f32(a[3]);
+        c18 = v_fma(b0, a0, c18);
+        c19 = v_fma(b1, a0, c19);
+        c20 = v_fma(b2, a0, c20);
+        c21 = v_fma(b3, a0, c21);
+        c22 = v_fma(b4, a0, c22);
+        c23 = v_fma(b5, a0, c23);
     }
-#endif
-#else
-    for (int i = 0; i < FAST_CONV_MR; i++)
+
+    if (!init_c)
     {
-        float beta = bias[i];
-        for (int j = 0; j < FAST_CONV_NR; j++)
-            c[i*ldc + j] = beta;
+        c0 = v_add(c0, v_load(c));
+        c1 = v_add(c1, v_load(c + 4));
+        c2 = v_add(c2, v_load(c + 8));
+        c3 = v_add(c3, v_load(c + 12));
+        c4 = v_add(c4, v_load(c + 16));
+        c5 = v_add(c5, v_load(c + 20));
+
+        c6  = v_add(c6 , v_load(c + ldc));
+        c7  = v_add(c7 , v_load(c + ldc + 4));
+        c8  = v_add(c8 , v_load(c + ldc + 8));
+        c9  = v_add(c9 , v_load(c + ldc + 12));
+        c10 = v_add(c10, v_load(c + ldc + 16));
+        c11 = v_add(c11, v_load(c + ldc + 20));
+
+        c12 = v_add(c12, v_load(c + ldc*2));
+        c13 = v_add(c13, v_load(c + ldc*2 + 4));
+        c14 = v_add(c14, v_load(c + ldc*2 + 8));
+        c15 = v_add(c15, v_load(c + ldc*2 + 12));
+        c16 = v_add(c16, v_load(c + ldc*2 + 16));
+        c17 = v_add(c17, v_load(c + ldc*2 + 20));
+
+        c18 = v_add(c18, v_load(c + ldc*3));
+        c19 = v_add(c19, v_load(c + ldc*3 + 4));
+        c20 = v_add(c20, v_load(c + ldc*3 + 8));
+        c21 = v_add(c21, v_load(c + ldc*3 + 12));
+        c22 = v_add(c22, v_load(c + ldc*3 + 16));
+        c23 = v_add(c23, v_load(c + ldc*3 + 20));
     }
-    for (int p = 0; p < k; p++)
+
+    v_store(c, c0);
+    v_store(c + 4, c1);
+    v_store(c + 8, c2);
+    v_store(c + 12, c3);
+    v_store(c + 16, c4);
+    v_store(c + 20, c5);
+
+    v_store(c + ldc, c6);
+    v_store(c + ldc + 4, c7);
+    v_store(c + ldc + 8, c8);
+    v_store(c + ldc + 12, c9);
+    v_store(c + ldc + 16, c10);
+    v_store(c + ldc + 20, c11);
+
+    v_store(c + ldc * 2, c12);
+    v_store(c + ldc * 2 + 4, c13);
+    v_store(c + ldc * 2 + 8, c14);
+    v_store(c + ldc * 2 + 12, c15);
+    v_store(c + ldc * 2 + 16, c16);
+    v_store(c + ldc * 2 + 20, c17);
+
+    v_store(c + ldc * 3, c18);
+    v_store(c + ldc * 3 + 4, c19);
+    v_store(c + ldc * 3 + 8, c20);
+    v_store(c + ldc * 3 + 12, c21);
+    v_store(c + ldc * 3 + 16, c22);
+    v_store(c + ldc * 3 + 20, c23);
+#else
+    float cbuf[CONV_MR * CONV_NR];
+    memset(cbuf, 0, sizeof(cbuf));
+    for( int p = 0; p < np; p++ )
     {
-        for (int i = 0; i < FAST_CONV_MR; i++)
+        for( int i = 0; i < CONV_MR; i++ )
         {
-            float alpha = a[FAST_CONV_MR*p + i];
-            for (int j = 0; j < FAST_CONV_NR; j++)
-            {
-                c[i*ldc+j] += b[FAST_CONV_NR*p + j]*alpha;
-            }
+            float ai = a[CONV_MR*p + i];
+            for( int j = 0; j < CONV_NR; j++ )
+                cbuf[i * CONV_NR+j] += b[CONV_NR*p + j] * ai;
         }
     }
-    if (ifActiv)
-    {
-        for (int i = 0; i < FAST_CONV_MR; i++)
-        {
-            for (int j = 0; j < FAST_CONV_NR; j++)
-            {
-                float v = c[i*ldc + j];
-                v = std::min(std::max(v, minval), maxval);
-                c[i*ldc + j] = v;
-            }
+    if (!init_c) {
+        for(int i = 0; i < CONV_MR; i++) {
+            for(int j = 0; j < CONV_NR; j++)
+                c[i*ldc + j] += cbuf[i*CONV_NR + j];
+        }
+    } else {
+        for(int i = 0; i < CONV_MR; i++) {
+            for(int j = 0; j < CONV_NR; j++)
+                c[i*ldc + j] = cbuf[i*CONV_NR + j];
         }
     }
 #endif
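With bias and activation removed from the micro-kernel, convBlock() is now a plain overwrite-or-accumulate GEMM tile, and the driver walks the reduced dimension in chunks exactly as in the c0/c1 loop shown earlier in this patch. The wrapper below is a simplified sketch of that calling pattern; it assumes the macros and convBlock() above are in scope and drops the k-block and stripe offsets of the real code.

    #include <algorithm>

    // Accumulate one CONV_MR x CONV_NR output tile over the reduced dimension in
    // chunks: the first chunk overwrites the tile (init_c = true), later chunks
    // accumulate into it.
    static void convTileSketch(const float* weights, const float* inpbuf, float* cbuf,
                               int HkWkCg, int C_BLOCK_SIZE, int ldc)
    {
        for (int c0 = 0; c0 < HkWkCg; c0 += C_BLOCK_SIZE)
        {
            int c1 = std::min(c0 + C_BLOCK_SIZE, HkWkCg);
            const float* wptr  = weights + c0 * CONV_MR;  // packed A panel
            const float* inptr = inpbuf  + c0 * CONV_NR;  // packed B panel
            convBlock(c1 - c0, wptr, inptr, cbuf, ldc, /*init_c=*/c0 == 0);
        }
    }

Because bias and clamping no longer happen inside the kernel, partial sums over channel chunks compose correctly and the fused-Add epilogue can be applied once per finished row, which appears to be the motivation for the signature change.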
@@ -154,142 +145,122 @@ void convBlock(int k, const float *a, const float *b,
 namespace opt_NEON
 {
 #if CV_TRY_NEON
-void convBlock_NEON(int k, const float *a, const float *b,
-                float *c, int ldc, const float *bias,
-                float minval, float maxval, bool ifActiv)
+void convBlock_NEON(int np, const float* a, const float* b, float* c, int ldc, bool init_c)
 {
-#if CV_NEON_AARCH64 && FAST_CONV_MR == 4 && FAST_CONV_NR == 28  // AARCH64
+#if CONV_MR == 4 && CONV_NR == 28  // AARCH64
     {
-        float32x4_t c0 = vdupq_n_f32(bias[0]), c1 = c0, c2 = c0, c3 = c0, c4 = c0, c5 = c0, c24 = c0;
-        float32x4_t c6 = vdupq_n_f32(bias[1]), c7 = c6, c8 = c6, c9 = c6, c10 = c6, c11 = c6, c25 = c6;
-        float32x4_t c12 = vdupq_n_f32(bias[2]), c13 = c12, c14 = c12, c15 = c12, c16 = c12, c17 = c12, c26 = c12;
-        float32x4_t c18 = vdupq_n_f32(bias[3]), c19 = c18, c20 = c18, c21 = c18, c22 = c18, c23 = c18, c27 = c18;
-
-        float32x4_t a0 = vdupq_n_f32(0.0f);
-        float32x4_t b0 = vdupq_n_f32(0.0f), b1 = vdupq_n_f32(0.0f), b2 = vdupq_n_f32(0.0f);
+        float32x4_t c00 = vdupq_n_f32(0.f), c01 = c00, c02 = c00, c03 = c00, c04 = c00, c05 = c00, c06 = c00;
+        float32x4_t c10 = vdupq_n_f32(0.f), c11 = c10, c12 = c10, c13 = c10, c14 = c10, c15 = c10, c16 = c10;
+        float32x4_t c20 = vdupq_n_f32(0.f), c21 = c20, c22 = c20, c23 = c20, c24 = c20, c25 = c20, c26 = c20;
+        float32x4_t c30 = vdupq_n_f32(0.f), c31 = c30, c32 = c30, c33 = c30, c34 = c30, c35 = c30, c36 = c30;
 
-        for (int p = 0; p < k; p++, a += FAST_CONV_MR)
+        for( int p = 0; p < np; p++, a += CONV_MR, b += CONV_NR )
         {
-            a0 = vld1q_f32(a);
-            b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4), b2 = vld1q_f32(b + 8);
-            b += 12;
-
-            c0 = vfmaq_laneq_f32(c0, b0, a0, 0);
-            c1 = vfmaq_laneq_f32(c1, b1, a0, 0);
-            c2 = vfmaq_laneq_f32(c2, b2, a0, 0);
-            c6 = vfmaq_laneq_f32(c6, b0, a0, 1);
-            c7 = vfmaq_laneq_f32(c7, b1, a0, 1);
-            c8 = vfmaq_laneq_f32(c8, b2, a0, 1);
-            c12 = vfmaq_laneq_f32(c12, b0, a0, 2);
-            c13 = vfmaq_laneq_f32(c13, b1, a0, 2);
-            c14 = vfmaq_laneq_f32(c14, b2, a0, 2);
-            c18 = vfmaq_laneq_f32(c18, b0, a0, 3);
-            c19 = vfmaq_laneq_f32(c19, b1, a0, 3);
-            c20 = vfmaq_laneq_f32(c20, b2, a0, 3);
-
-            b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4), b2 = vld1q_f32(b + 8);
-            b += 12;
-
-            c3 = vfmaq_laneq_f32(c3, b0, a0, 0);
-            c4 = vfmaq_laneq_f32(c4, b1, a0, 0);
-            c5 = vfmaq_laneq_f32(c5, b2, a0, 0);
-
-            c9 = vfmaq_laneq_f32(c9, b0, a0, 1);
-            c10 = vfmaq_laneq_f32(c10, b1, a0, 1);
-            c11 = vfmaq_laneq_f32(c11, b2, a0, 1);
-
-            c15 = vfmaq_laneq_f32(c15, b0, a0, 2);
-            c16 = vfmaq_laneq_f32(c16, b1, a0, 2);
-            c17 = vfmaq_laneq_f32(c17, b2, a0, 2);
-
-            c21 = vfmaq_laneq_f32(c21, b0, a0, 3);
-
-            b0 = vld1q_f32(b);
-            b += 4;
-
-            c22 = vfmaq_laneq_f32(c22, b1, a0, 3);
-            c23 = vfmaq_laneq_f32(c23, b2, a0, 3);
-
-            c24 = vfmaq_laneq_f32(c24, b0, a0, 0);
-            c25 = vfmaq_laneq_f32(c25, b0, a0, 1);
+            float32x4_t a0 = vld1q_f32(a), b0, b1, b2;
+            b0 = vld1q_f32(b); b1 = vld1q_f32(b + 4); b2 = vld1q_f32(b + 8);
+
+            c00 = vfmaq_laneq_f32(c00, b0, a0, 0);
+            c01 = vfmaq_laneq_f32(c01, b1, a0, 0);
+            c02 = vfmaq_laneq_f32(c02, b2, a0, 0);
+            c10 = vfmaq_laneq_f32(c10, b0, a0, 1);
+            c11 = vfmaq_laneq_f32(c11, b1, a0, 1);
+            c12 = vfmaq_laneq_f32(c12, b2, a0, 1);
+            c20 = vfmaq_laneq_f32(c20, b0, a0, 2);
+            c21 = vfmaq_laneq_f32(c21, b1, a0, 2);
+            c22 = vfmaq_laneq_f32(c22, b2, a0, 2);
+            c30 = vfmaq_laneq_f32(c30, b0, a0, 3);
+            c31 = vfmaq_laneq_f32(c31, b1, a0, 3);
+            c32 = vfmaq_laneq_f32(c32, b2, a0, 3);
+
+            b0 = vld1q_f32(b + 12); b1 = vld1q_f32(b + 16); b2 = vld1q_f32(b + 20);
+
+            c03 = vfmaq_laneq_f32(c03, b0, a0, 0);
+            c04 = vfmaq_laneq_f32(c04, b1, a0, 0);
+            c05 = vfmaq_laneq_f32(c05, b2, a0, 0);
+            c13 = vfmaq_laneq_f32(c13, b0, a0, 1);
+            c14 = vfmaq_laneq_f32(c14, b1, a0, 1);
+            c15 = vfmaq_laneq_f32(c15, b2, a0, 1);
+            c23 = vfmaq_laneq_f32(c23, b0, a0, 2);
+            c24 = vfmaq_laneq_f32(c24, b1, a0, 2);
+            c25 = vfmaq_laneq_f32(c25, b2, a0, 2);
+            c33 = vfmaq_laneq_f32(c33, b0, a0, 3);
+            c34 = vfmaq_laneq_f32(c34, b1, a0, 3);
+            c35 = vfmaq_laneq_f32(c35, b2, a0, 3);
+
+            b0 = vld1q_f32(b + 24);
+            c06 = vfmaq_laneq_f32(c06, b0, a0, 0);
+            c16 = vfmaq_laneq_f32(c16, b0, a0, 1);
             c26 = vfmaq_laneq_f32(c26, b0, a0, 2);
-            c27 = vfmaq_laneq_f32(c27, b0, a0, 3);
+            c36 = vfmaq_laneq_f32(c36, b0, a0, 3);
         }
 
-        if (ifActiv) {
-            b0 = vdupq_n_f32(minval), b1 = vdupq_n_f32(maxval);
-            c0 = vminq_f32(vmaxq_f32(c0, b0), b1);
-            c1 = vminq_f32(vmaxq_f32(c1, b0), b1);
-            c2 = vminq_f32(vmaxq_f32(c2, b0), b1);
-            c3 = vminq_f32(vmaxq_f32(c3, b0), b1);
-            c4 = vminq_f32(vmaxq_f32(c4, b0), b1);
-            c5 = vminq_f32(vmaxq_f32(c5, b0), b1);
-            c6 = vminq_f32(vmaxq_f32(c6, b0), b1);
-            c7 = vminq_f32(vmaxq_f32(c7, b0), b1);
-            c8 = vminq_f32(vmaxq_f32(c8, b0), b1);
-            c9 = vminq_f32(vmaxq_f32(c9, b0), b1);
-            c10 = vminq_f32(vmaxq_f32(c10, b0), b1);
-            c11 = vminq_f32(vmaxq_f32(c11, b0), b1);
-            c12 = vminq_f32(vmaxq_f32(c12, b0), b1);
-            c13 = vminq_f32(vmaxq_f32(c13, b0), b1);
-            c14 = vminq_f32(vmaxq_f32(c14, b0), b1);
-            c15 = vminq_f32(vmaxq_f32(c15, b0), b1);
-            c16 = vminq_f32(vmaxq_f32(c16, b0), b1);
-            c17 = vminq_f32(vmaxq_f32(c17, b0), b1);
-            c18 = vminq_f32(vmaxq_f32(c18, b0), b1);
-            c19 = vminq_f32(vmaxq_f32(c19, b0), b1);
-            c20 = vminq_f32(vmaxq_f32(c20, b0), b1);
-            c21 = vminq_f32(vmaxq_f32(c21, b0), b1);
-            c22 = vminq_f32(vmaxq_f32(c22, b0), b1);
-            c23 = vminq_f32(vmaxq_f32(c23, b0), b1);
-            c24 = vminq_f32(vmaxq_f32(c24, b0), b1);
-            c25 = vminq_f32(vmaxq_f32(c25, b0), b1);
-            c26 = vminq_f32(vmaxq_f32(c26, b0), b1);
-            c27 = vminq_f32(vmaxq_f32(c27, b0), b1);
+        if (!init_c)
+        {
+            c00 = vaddq_f32(c00, vld1q_f32(c));
+            c01 = vaddq_f32(c01, vld1q_f32(c + 4));
+            c02 = vaddq_f32(c02, vld1q_f32(c + 8));
+            c03 = vaddq_f32(c03, vld1q_f32(c + 12));
+            c04 = vaddq_f32(c04, vld1q_f32(c + 16));
+            c05 = vaddq_f32(c05, vld1q_f32(c + 20));
+            c06 = vaddq_f32(c06, vld1q_f32(c + 24));
+
+            c10 = vaddq_f32(c10, vld1q_f32(c + ldc));
+            c11 = vaddq_f32(c11, vld1q_f32(c + ldc + 4));
+            c12 = vaddq_f32(c12, vld1q_f32(c + ldc + 8));
+            c13 = vaddq_f32(c13, vld1q_f32(c + ldc + 12));
+            c14 = vaddq_f32(c14, vld1q_f32(c + ldc + 16));
+            c15 = vaddq_f32(c15, vld1q_f32(c + ldc + 20));
+            c16 = vaddq_f32(c16, vld1q_f32(c + ldc + 24));
+
+            c20 = vaddq_f32(c20, vld1q_f32(c + ldc*2));
+            c21 = vaddq_f32(c21, vld1q_f32(c + ldc*2 + 4));
+            c22 = vaddq_f32(c22, vld1q_f32(c + ldc*2 + 8));
+            c23 = vaddq_f32(c23, vld1q_f32(c + ldc*2 + 12));
+            c24 = vaddq_f32(c24, vld1q_f32(c + ldc*2 + 16));
+            c25 = vaddq_f32(c25, vld1q_f32(c + ldc*2 + 20));
+            c26 = vaddq_f32(c26, vld1q_f32(c + ldc*2 + 24));
+
+            c30 = vaddq_f32(c30, vld1q_f32(c + ldc*3));
+            c31 = vaddq_f32(c31, vld1q_f32(c + ldc*3 + 4));
+            c32 = vaddq_f32(c32, vld1q_f32(c + ldc*3 + 8));
+            c33 = vaddq_f32(c33, vld1q_f32(c + ldc*3 + 12));
+            c34 = vaddq_f32(c34, vld1q_f32(c + ldc*3 + 16));
+            c35 = vaddq_f32(c35, vld1q_f32(c + ldc*3 + 20));
+            c36 = vaddq_f32(c36, vld1q_f32(c + ldc*3 + 24));
         }
-        vst1q_f32(c, c0);
-        vst1q_f32(c + 4, c1);
-        vst1q_f32(c + 8, c2);
-        vst1q_f32(c + 12, c3);
-        vst1q_f32(c + 16, c4);
-        vst1q_f32(c + 20, c5);
-        vst1q_f32(c + 24, c24);
-
-        vst1q_f32(c + ldc, c6);
-        vst1q_f32(c + ldc + 4, c7);
-        vst1q_f32(c + ldc + 8, c8);
-        vst1q_f32(c + ldc + 12, c9);
-        vst1q_f32(c + ldc + 16, c10);
-        vst1q_f32(c + ldc + 20, c11);
-        vst1q_f32(c + ldc + 24, c25);
-
-        vst1q_f32(c + ldc * 2, c12);
-        vst1q_f32(c + ldc * 2 + 4, c13);
-        vst1q_f32(c + ldc * 2 + 8, c14);
-        vst1q_f32(c + ldc * 2 + 12, c15);
-        vst1q_f32(c + ldc * 2 + 16, c16);
-        vst1q_f32(c + ldc * 2 + 20, c17);
-        vst1q_f32(c + ldc * 2 + 24, c26);
-
-        vst1q_f32(c + ldc * 3, c18);
-        vst1q_f32(c + ldc * 3 + 4, c19);
-        vst1q_f32(c + ldc * 3 + 8, c20);
-        vst1q_f32(c + ldc * 3 + 12, c21);
-        vst1q_f32(c + ldc * 3 + 16, c22);
-        vst1q_f32(c + ldc * 3 + 20, c23);
-        vst1q_f32(c + ldc * 3 + 24, c27);
+
+        vst1q_f32(c, c00); vst1q_f32(c+4, c01);
+        vst1q_f32(c+8, c02); vst1q_f32(c+12, c03);
+        vst1q_f32(c+16, c04); vst1q_f32(c+20, c05);
+        vst1q_f32(c+24, c06);
+
+        vst1q_f32(c+ldc, c10); vst1q_f32(c+ldc+4, c11);
+        vst1q_f32(c+ldc+8, c12); vst1q_f32(c+ldc+12, c13);
+        vst1q_f32(c+ldc+16, c14); vst1q_f32(c+ldc+20, c15);
+        vst1q_f32(c+ldc+24, c16);
+
+        vst1q_f32(c+ldc*2, c20); vst1q_f32(c+ldc*2+4, c21);
+        vst1q_f32(c+ldc*2+8, c22); vst1q_f32(c+ldc*2+12, c23);
+        vst1q_f32(c+ldc*2+16, c24); vst1q_f32(c+ldc*2+20, c25);
+        vst1q_f32(c+ldc*2+24, c26);
+
+        vst1q_f32(c+ldc*3, c30); vst1q_f32(c+ldc*3+4, c31);
+        vst1q_f32(c+ldc*3+8, c32); vst1q_f32(c+ldc*3+12, c33);
+        vst1q_f32(c+ldc*3+16, c34); vst1q_f32(c+ldc*3+20, c35);
+        vst1q_f32(c+ldc*3+24, c36);
     }
-#elif (!defined(CV_NEON_AARCH64) || !CV_NEON_AARCH64) && FAST_CONV_MR == 4 && FAST_CONV_NR == 12 // ARMv7
+#elif CONV_MR == 4 && CONV_NR == 12 // ARMv7
     {
-        float32x4_t c0 = vdupq_n_f32(bias[0]), c1 = c0, c2 = c0;
-        float32x4_t c3 = vdupq_n_f32(bias[1]), c4 = c3, c5 = c3;
-        float32x4_t c6 = vdupq_n_f32(bias[2]), c7 = c6, c8 = c6;
-        float32x4_t c9 = vdupq_n_f32(bias[3]), c10 = c9, c11 = c9;
+        float32x4_t c0 = vdupq_n_f32(0.f), c1 = c0, c2 = c0;
+        float32x4_t c3 = vdupq_n_f32(0.f), c4 = c3, c5 = c3;
+        float32x4_t c6 = vdupq_n_f32(0.f), c7 = c6, c8 = c6;
+        float32x4_t c9 = vdupq_n_f32(0.f), c10 = c9, c11 = c9;
+
 
         float32x2_t a0 = vdup_n_f32(0.0f), a1 = a0;
         float32x4_t b0 = vdupq_n_f32(0.0f), b1 = vdupq_n_f32(0.0f), b2 = vdupq_n_f32(0.0f);
 
-        for (int p = 0; p < k; p++, a += FAST_CONV_MR, b += FAST_CONV_NR)
+        for (int p = 0; p < np; p++, a += CONV_MR, b += CONV_NR)
         {
             a0 = vld1_f32(a), a1 = vld1_f32(a+2);
             b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4), b2 = vld1q_f32(b + 8);
@@ -311,29 +282,32 @@ void convBlock_NEON(int k, const float *a, const float *b,
             c11 = vmlaq_lane_f32(c11, b2, a1, 1);
         }
 
-        if (ifActiv)
+        if (!init_c)
         {
-            b0 = vdupq_n_f32(minval), b1 = vdupq_n_f32(maxval);
-            c0 = vminq_f32(vmaxq_f32(c0, b0), b1);
-            c1 = vminq_f32(vmaxq_f32(c1, b0), b1);
-            c2 = vminq_f32(vmaxq_f32(c2, b0), b1);
-            c3 = vminq_f32(vmaxq_f32(c3, b0), b1);
-            c4 = vminq_f32(vmaxq_f32(c4, b0), b1);
-            c5 = vminq_f32(vmaxq_f32(c5, b0), b1);
-            c6 = vminq_f32(vmaxq_f32(c6, b0), b1);
-            c7 = vminq_f32(vmaxq_f32(c7, b0), b1);
-            c8 = vminq_f32(vmaxq_f32(c8, b0), b1);
-            c9 = vminq_f32(vmaxq_f32(c9, b0), b1);
-            c10 = vminq_f32(vmaxq_f32(c10, b0), b1);
-            c11 = vminq_f32(vmaxq_f32(c11, b0), b1);
+            c0 = vaddq_f32(c0, vld1q_f32(c));
+            c1 = vaddq_f32(c1, vld1q_f32(c + 4));
+            c2 = vaddq_f32(c2, vld1q_f32(c + 8));
+
+            c3 = vaddq_f32(c3, vld1q_f32(c + ldc));
+            c4 = vaddq_f32(c4, vld1q_f32(c + ldc + 4));
+            c5 = vaddq_f32(c5, vld1q_f32(c + ldc + 8));
+
+            c6 = vaddq_f32(c6, vld1q_f32(c + ldc * 2));
+            c7 = vaddq_f32(c7, vld1q_f32(c + ldc * 2 + 4));
+            c8 = vaddq_f32(c8, vld1q_f32(c + ldc * 2 + 8));
+
+            c9  = vaddq_f32(c9 , vld1q_f32(c + ldc * 3));
+            c10 = vaddq_f32(c10, vld1q_f32(c + ldc * 3 + 4));
+            c11 = vaddq_f32(c11, vld1q_f32(c + ldc * 3 + 8));
         }
-        vst1q_f32(c, c0); vst1q_f32(c+4, c1); vst1q_f32(c+8, c2);
-        vst1q_f32(c + ldc, c3); vst1q_f32(c + ldc + 4, c4); vst1q_f32(c + ldc + 8, c5);
-        vst1q_f32(c + ldc*2, c6); vst1q_f32(c + ldc*2 + 4, c7); vst1q_f32(c + ldc*2 + 8, c8);
-        vst1q_f32(c + ldc*3, c9); vst1q_f32(c + ldc*3 + 4, c10); vst1q_f32(c + ldc*3 + 8, c11);
+
+        vst1q_f32(c, c0), vst1q_f32(c+4, c1), vst1q_f32(c+8, c2);
+        vst1q_f32(c + ldc, c3), vst1q_f32(c + ldc + 4, c4), vst1q_f32(c + ldc + 8, c5);
+        vst1q_f32(c + ldc*2, c6), vst1q_f32(c + ldc*2 + 4, c7), vst1q_f32(c + ldc*2 + 8, c8);
+        vst1q_f32(c + ldc*3, c9), vst1q_f32(c + ldc*3 + 4, c10), vst1q_f32(c + ldc*3 + 8, c11);
     }
-#else
-#error "unsupported FAST_CONV_MR and/or FAST_CONV_NR in convBlock_NEON."
+//#else
+//#error "unsupported CONV_MR and/or CONV_NR in convBlock_NEON."
 #endif
 }
 #endif
index e841889..fb668d6 100644 (file)
@@ -37,6 +37,47 @@ enum
 };
 
 #if CV_NEON
+
+#undef _FAST_CONV_T4x4
+#define _FAST_CONV_T4x4(a, b, c, d, tr0, tr1) \
+    tr0 = vtrnq_f32(a, b); \
+    tr1 = vtrnq_f32(c, d); \
+    a = vcombine_f32(vget_low_f32(tr0.val[0]), vget_low_f32(tr1.val[0])); \
+    b = vcombine_f32(vget_low_f32(tr0.val[1]), vget_low_f32(tr1.val[1])); \
+    c = vcombine_f32(vget_high_f32(tr0.val[0]), vget_high_f32(tr1.val[0])); \
+    d = vcombine_f32(vget_high_f32(tr0.val[1]), vget_high_f32(tr1.val[1]))
+
+// The input is the pack4 data, and the output is unpack4 data.
+static void transpose12x4(float* src, float* dst, const int cn)
+{
+    float32x4_t r00, r01, r02, r03, r04, r05, r06, r07, r08, r09, r10, r11;
+    float32x4x2_t tr0, tr1;
+    for (int i = 0; i < cn; i++, src += 48, dst += 48)
+    {
+        r00 = vld1q_f32(src);
+        r01 = vld1q_f32(src + 4);
+        r02 = vld1q_f32(src + 8);
+        r03 = vld1q_f32(src + 12);
+        r04 = vld1q_f32(src + 16);
+        r05 = vld1q_f32(src + 20);
+        r06 = vld1q_f32(src + 24);
+        r07 = vld1q_f32(src + 28);
+        r08 = vld1q_f32(src + 32);
+        r09 = vld1q_f32(src + 36);
+        r10 = vld1q_f32(src + 40);
+        r11 = vld1q_f32(src + 44);
+
+        _FAST_CONV_T4x4(r00, r01, r02, r03, tr0, tr1);
+        _FAST_CONV_T4x4(r04, r05, r06, r07, tr0, tr1);
+        _FAST_CONV_T4x4(r08, r09, r10, r11, tr0, tr1);
+
+        vst1q_f32(dst, r00), vst1q_f32(dst + 4, r04), vst1q_f32(dst + 8, r08);
+        vst1q_f32(dst + 12, r01), vst1q_f32(dst + 16, r05), vst1q_f32(dst + 20, r09);
+        vst1q_f32(dst + 24, r02), vst1q_f32(dst + 28, r06), vst1q_f32(dst + 32, r10);
+        vst1q_f32(dst + 36, r03), vst1q_f32(dst + 40, r07), vst1q_f32(dst + 44, r11);
+    }
+}
+
 static void winograd_trans_input_F63(float* src, float* dst, int Channle_div4, const int tiles, const int big_step, const int line_step, const int* ofstab0)
 {
     // const float itm[8][8] = {
@@ -192,15 +233,14 @@ static void winograd_trans_input_F63(float* src, float* dst, int Channle_div4, c
         float* input0 = input_buf0 + 4 * tiles * r;
 
         // TODO! support tiles > 12
-//#if CV_NEON_AARCH64
-//        for (; ti + 11 < tiles; ti += 12)
-//        {
-//            float* out1 = out0 + line_step * ofstab0[ti * 2] + Channle_div4 * ofstab0[ti * 2 + 1] * 4;
-////            std::cout<<"ofstab0[ti * 2] = "<<ofstab0[ti * 2]<<", ofstab0[ti * 2 + 1] = "<<ofstab0[ti * 2 + 1]<<std::endl;
-//            float* input1 = input0 + ti * 4;
-//            memcpy(out1, input1, 12 * 4 * sizeof(float ));
-//        }
-//#endif
+#if CV_NEON_AARCH64
+        for (; ti + 11 < tiles; ti += 12)
+        {
+            float* out1 = out0 + line_step * ofstab0[ti * 2] + Channle_div4 * ofstab0[ti * 2 + 1] * 4;
+            float* input1 = input0 + ti * 4;
+            memcpy(out1, input1, 12 * 4 * sizeof(float ));
+        }
+#endif
         for (; ti + 7 < tiles; ti += 8)
         {
             float* out1 = out0 + line_step * ofstab0[ti * 2] + Channle_div4 * ofstab0[ti * 2 + 1] * 4;
@@ -231,7 +271,7 @@ static void winograd_trans_input_F63(float* src, float* dst, int Channle_div4, c
     }
 }
 
-static void winograd_trans_output_F63(float* src_, float* bias_, float minval, float maxval, bool ifMinMaxAct)
+static void winograd_trans_output_F63(float* src_, float* bias_, float* fAbuf0, float minval, float maxval, bool ifMinMaxAct)
 {
     // const float otm[6][8] = {
     //     {1.0f,  1.0f,   1.0f,   1.0f,   1.0f,  32.0f, 32.0f, 0.0f},
@@ -292,6 +332,7 @@ static void winograd_trans_output_F63(float* src_, float* bias_, float minval, f
     for (int m = 0; m < 6; m++)
     {
         float* output0 = src_ + 6 * m * FAST_VEC_NLANES;
+        float* fAbuf = fAbuf0 ? fAbuf0 + 6 * m * FAST_VEC_NLANES : 0;
 
         float32x4_t _tmp00 = vld1q_f32(tmp[m][0]);
         float32x4_t _tmp01 = vld1q_f32(tmp[m][1]);
@@ -319,6 +360,16 @@ static void winograd_trans_output_F63(float* src_, float* bias_, float minval, f
         float32x4_t _out03 = vaddq_f32(bias0, vmlaq_n_f32(vmlaq_n_f32(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f));
         float32x4_t _out05 = vaddq_f32(bias0, vaddq_f32(vaddq_f32(_tmp07, _tmp135a), vmlaq_n_f32(_tmp135c, _tmp135b, 32.f)));
 
+        if (fAbuf)
+        {
+            _out00 = vaddq_f32(_out00, vld1q_f32(fAbuf));
+            _out01 = vaddq_f32(_out01, vld1q_f32(fAbuf + 4));
+            _out02 = vaddq_f32(_out02, vld1q_f32(fAbuf + 8));
+            _out03 = vaddq_f32(_out03, vld1q_f32(fAbuf + 12));
+            _out04 = vaddq_f32(_out04, vld1q_f32(fAbuf + 16));
+            _out05 = vaddq_f32(_out05, vld1q_f32(fAbuf + 20));
+        }
+
         if (ifMinMaxAct)
         {
             float32x4_t vmin = vdupq_n_f32(minval), vmax = vdupq_n_f32(maxval);
@@ -339,7 +390,7 @@ static void winograd_trans_output_F63(float* src_, float* bias_, float minval, f
     }
 }
 
-void initWinograd63(Ptr<FastConv2d>& conv, float* srcWeight, int K, int C)
+void initWinograd63(Ptr<FastConv2d>& conv, InputArray _weightsMat, int K, int C)
 {
     static const float ktm[8][3] = {
             {1.0f,      0.0f,      0.0f},
@@ -352,11 +403,14 @@ void initWinograd63(Ptr<FastConv2d>& conv, float* srcWeight, int K, int C)
             {0.0f, 0.0f, 1.0f}
     };
 
+    Mat weightsMat = _weightsMat.getMat();
+    float* srcWeight = weightsMat.ptr<float>();
+    size_t wstep = weightsMat.step1();
+
     int K_aligned = ((K + FAST_VEC_NLANES - 1)/FAST_VEC_NLANES) * FAST_VEC_NLANES;
     int C_aligned = ((C + FAST_VEC_NLANES - 1)/FAST_VEC_NLANES) * FAST_VEC_NLANES;
     const int winoSize = C * WINO_AREA;
     const int kArea = WINO_KSIZE * WINO_KSIZE;
-    const int kSize = C * kArea;
 
     // Allocate memory for winograd.
     int nweights = K_aligned * C_aligned * WINO_AREA;
@@ -379,7 +433,7 @@ void initWinograd63(Ptr<FastConv2d>& conv, float* srcWeight, int K, int C)
             for (int inc = 0; inc < C; inc++)
             {
                 float *kernel_tm0 = kernelTm + outc * winoSize + inc * WINO_AREA;
-                const float *kernel0 = srcWeight + outc * kSize + inc * kArea;
+                const float *kernel0 = srcWeight + outc * wstep + inc * kArea;
 
                 // transform kernel, transposed
                 const float *k0 = kernel0;
@@ -472,16 +526,16 @@ void initWinograd63(Ptr<FastConv2d>& conv, float* srcWeight, int K, int C)
                     out1[inc * 4] = tmp1[inc * 64];
                 }
             }
-
         }
     }
 }
 
-int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>& conv, int ntasks, float minval,
+int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _output, const Ptr<FastConv2d>& conv, int ntasks, float minval,
         float maxval, ActivationLayer* activ, bool ifMinMaxAct)
 {
     Mat input = _input.getMat();
     Mat output = _output.getMat();
+    Mat fusedAddMat = _fusedAddMat.getMat();
 
     MatShape inputShape = shape(input);
     MatShape outputShape = shape(output);
@@ -517,15 +571,14 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
     int inpPack = 0;
     int lineNum =0;
 
-    // TODO! tiles > 12
-//#if CV_NEON_AARCH64
-//    if (tiles >= 12)
-//    {
-//        inpPack = 12;
-//        lineNum = tiles / 12 + (tiles % 12) / 8 + (tiles % 12 % 8) / 4 + (tiles % 12 % 4) / 2 + tiles % 12 % 2;
-//    }
-//    else
-//#endif
+#if CV_NEON_AARCH64
+    if (tiles >= 12)
+    {
+        inpPack = 12;
+        lineNum = tiles / 12 + (tiles % 12) / 8 + (tiles % 12 % 8) / 4 + (tiles % 12 % 4) / 2 + tiles % 12 % 2;
+    }
+    else
+#endif
     if (tiles >= 8)
     {
         inpPack = 8;
@@ -586,6 +639,7 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
         }
     }
 
+    const size_t inp_planesize = (size_t)Hi*Wi;
     const size_t out_planesize = (size_t)H0*W0;
 
     size_t inputbuf_size = inpPack * C_aligned * lineNum * 64;
@@ -594,36 +648,33 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
     size_t outputbuf_size = tiles * K_aligned * 8 * 8;
     size_t outputCnbuf_size = ntasks * 8 * 8 * 4;
 
-    AutoBuffer<float> inputbuf0_, inputCnbuf0_, outputbuf0_, outputCnbuf0_;
+    size_t part0_size = std::max(inputbuf_size, outputCnbuf_size);
+    size_t allbuf_size = part0_size + std::max(inputbufCn_size, outputbuf_size);
 
-    inputbuf0_.allocate(inputbuf_size);
-    float* inputbuf0 = alignPtr(inputbuf0_.data(), (int)(sizeof(float)));
-    memset(inputbuf0, 0, inputbuf_size * sizeof(float ));
-
-    inputCnbuf0_.allocate(inputbufCn_size);
-    float* inputCnbuf0 = inputCnbuf0_.data();
-
-    outputbuf0_.allocate(outputbuf_size);
-    float* outputbuf0 = outputbuf0_.data();
-
-    outputCnbuf0_.allocate(outputCnbuf_size);
-    float* outputCnbuf0 = outputCnbuf0_.data();
+    AutoBuffer<float> allbuf_;
+    allbuf_.allocate(allbuf_size);
+    float* inputbuf0 = alignPtr(allbuf_.data(), (int)(sizeof(float)));
+    float* inputCnbuf0 = inputbuf0 + inputbuf_size;
+    float* outputbuf0 = inputCnbuf0;
+    float* outputCnbuf0 = inputbuf0;
 
     // Input Parallel For
     float* weight_ptr0 = conv->weightsWino63Buf.data();
+
     for (int bn = 0; bn < N; bn++)
     {
-        float* input_ptr0 = input.ptr<float>() + bn * Hi * Wi * C;
+        float* input_ptr0 = input.ptr<float>() + bn * inp_planesize * C;
         float* output_ptr0 = output.ptr<float>() + bn * out_planesize * K;
+        float* fusedAddPtr0 = fusedAddMat.empty() ? 0 : fusedAddMat.ptr<float>() + bn * out_planesize * K;
 
         // Transform Input
         int C_aligned_div4 = C_aligned/4;
+        const int tiStep = 8 * 8 * FAST_VEC_NLANES;
 
-        parallel_for_(Range(0, ntasks), [&](const Range& range)
-        {
-            for (int task_i = range.start; task_i < range.end; task_i++)
+        parallel_for_(Range(0, ntasks), [&](const Range& range){
+        for (int task_i = range.start; task_i < range.end; task_i++)
             {
-                float *inpCnbuf = inputCnbuf0 + tiles * 256 * task_i;
+                float *inpCnbuf = inputCnbuf0 + tiles * tiStep * task_i;
                 for (int inc4 = task_i; inc4 < C_aligned_div4; inc4 += ntasks)
                 {
                     for (int cn = 0; cn < 4; cn++)
@@ -699,31 +750,225 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
                         }
                     }
 
-                    // Transfor Compute BdB^T
+                    // Input transform: compute BdB^T
                     winograd_trans_input_F63(inpCnbuf, inputbuf0, inc4, tiles, big_step, line_step, ofstab0);
                 }
             }
         });
-
         // Matrix multiplication 8 channel
         int K_div8 = 0;
-
 #if CV_NEON_AARCH64
         K_div8 = K_aligned/8;
-
-        parallel_for_(Range(0, K_div8), [&](const Range &range){
-        for (int outcn = range.start; outcn < range.end; outcn ++)
+        // Transpose 12
+        if (inpPack == 12)
         {
-            float* output_tmp = outputbuf0 + tiles * outcn * 8;
-            float* kernel_tmp = weight_ptr0 + outcn * 8 * C_aligned;
-            for (int r = 0; r < 64; r++)
+            int C_div4 = C_aligned/4;
+            parallel_for_(Range(0, 64), [&](const Range &range){
+            for (int r = range.start; r < range.end; r++)
             {
                 float* input_tm = inputbuf0 + r * big_step;
-                float* output0_tm = output_tmp + tiles * K_aligned * r;
+
+                for (int ti = 0; ti + 11 < tiles; ti += 12)
+                {
+                    float* r0 = input_tm + ofstab0[ti * 2] * line_step;
+                    transpose12x4(r0, r0, C_div4);
+                }
+            }
+            });
+        }
+
+        parallel_for_(Range(0, 64), [&](const Range &range){
+        for (int r = range.start; r < range.end; r++)
+        {
+            float* input_tm = inputbuf0 + r * big_step;
+            float* output_tmp = outputbuf0 + tiles * K_aligned * r;
+            float* kernel_tmp = weight_ptr0 + r * C_aligned * K_aligned;
+
+            for (int out_div8 = 0; out_div8 < K_div8; out_div8 ++)
+            {
+                float* output0_tm = output_tmp + tiles * out_div8 * 8;
                 float* output1_tm = output0_tm + tiles * 4;
-                float* kernel_tm_i = kernel_tmp + r * C_aligned * K_aligned;
+                float* kernel_tm_i = kernel_tmp + out_div8 * 8 * C_aligned;
 
                 int ti = 0;
+                for (; ti + 11 < tiles; ti += 12)
+                {
+                    float* r0 = input_tm + ofstab0[ti * 2] * line_step;
+                    const float* k01 = kernel_tm_i;
+
+                    int nn = C_aligned/4;
+                    r0 = input_tm + ofstab0[ti * 2] * line_step;
+
+                    // init 32 registers. FMA/load ratio = 96/20
+                    float32x4_t r00 = vdupq_n_f32(0.0f), r01 = r00, r02 = r00, r03 = r00;
+                    float32x4_t r04 = r00, r05 = r00, r06 = r00, r07 = r00;
+                    float32x4_t r08 = r00, r09 = r00, r10 = r00, r11 = r00;
+                    float32x4_t r12 = r00, r13 = r00, r14 = r00, r15 = r00;
+                    float32x4_t r16 = r00, r17 = r00, r18 = r00, r19 = r00;
+                    float32x4_t r20 = r00, r21 = r00, r22 = r00, r23 = r00;
+                    float32x4_t r24 = r00, r25 = r00, r26 = r00, r27 = r00;
+                    float32x4_t r28 = r00, r29 = r00, r30 = r00, r31 = r00;
+
+                    for(;nn > 0; nn--)
+                    {
+                        r00 = vld1q_f32(r0), r01 = vld1q_f32(r0+4), r02 = vld1q_f32(r0+8), r03 = vld1q_f32(r0+12);
+                        r04 = vld1q_f32(k01), r05 = vld1q_f32(k01+4), r06 = vld1q_f32(k01+8), r07 = vld1q_f32(k01+12);
+                        r0 += 16, k01 += 16;
+
+                        // Cn0
+                        // 8 ~ 19
+                        r08 = vfmaq_laneq_f32(r08, r04, r00, 0);
+                        r09 = vfmaq_laneq_f32(r09, r04, r00, 1);
+                        r10 = vfmaq_laneq_f32(r10, r04, r00, 2);
+                        r11 = vfmaq_laneq_f32(r11, r04, r00, 3);
+
+                        r12 = vfmaq_laneq_f32(r12, r04, r01, 0);
+                        r13 = vfmaq_laneq_f32(r13, r04, r01, 1);
+                        r14 = vfmaq_laneq_f32(r14, r04, r01, 2);
+                        r15 = vfmaq_laneq_f32(r15, r04, r01, 3);
+
+                        r16 = vfmaq_laneq_f32(r16, r04, r02, 0);
+                        r17 = vfmaq_laneq_f32(r17, r04, r02, 1);
+                        r18 = vfmaq_laneq_f32(r18, r04, r02, 2);
+                        r19 = vfmaq_laneq_f32(r19, r04, r02, 3);
+
+                        // 20 ~ 31
+                        r20 = vfmaq_laneq_f32(r20, r05, r00, 0);
+                        r21 = vfmaq_laneq_f32(r21, r05, r00, 1);
+                        r22 = vfmaq_laneq_f32(r22, r05, r00, 2);
+                        r23 = vfmaq_laneq_f32(r23, r05, r00, 3);
+
+                        r24 = vfmaq_laneq_f32(r24, r05, r01, 0);
+                        r25 = vfmaq_laneq_f32(r25, r05, r01, 1);
+                        r26 = vfmaq_laneq_f32(r26, r05, r01, 2);
+                        r27 = vfmaq_laneq_f32(r27, r05, r01, 3);
+
+                        r28 = vfmaq_laneq_f32(r28, r05, r02, 0);
+                        r29 = vfmaq_laneq_f32(r29, r05, r02, 1);
+                        r30 = vfmaq_laneq_f32(r30, r05, r02, 2);
+                        r31 = vfmaq_laneq_f32(r31, r05, r02, 3);
+
+                        // Cn1
+                        r08 = vfmaq_laneq_f32(r08, r06, r03, 0);
+                        r09 = vfmaq_laneq_f32(r09, r06, r03, 1);
+                        r10 = vfmaq_laneq_f32(r10, r06, r03, 2);
+                        r11 = vfmaq_laneq_f32(r11, r06, r03, 3);
+
+                        r20 = vfmaq_laneq_f32(r20, r07, r03, 0);
+                        r21 = vfmaq_laneq_f32(r21, r07, r03, 1);
+                        r22 = vfmaq_laneq_f32(r22, r07, r03, 2);
+                        r23 = vfmaq_laneq_f32(r23, r07, r03, 3);
+
+                        r00 = vld1q_f32(r0), r01 = vld1q_f32(r0+4), r02 = vld1q_f32(r0+8), r03 = vld1q_f32(r0+12);
+                        r0 += 16;
+
+                        r12 = vfmaq_laneq_f32(r12, r06, r00, 0);
+                        r13 = vfmaq_laneq_f32(r13, r06, r00, 1);
+                        r14 = vfmaq_laneq_f32(r14, r06, r00, 2);
+                        r15 = vfmaq_laneq_f32(r15, r06, r00, 3);
+
+                        r16 = vfmaq_laneq_f32(r16, r06, r01, 0);
+                        r17 = vfmaq_laneq_f32(r17, r06, r01, 1);
+                        r18 = vfmaq_laneq_f32(r18, r06, r01, 2);
+                        r19 = vfmaq_laneq_f32(r19, r06, r01, 3);
+
+                        r24 = vfmaq_laneq_f32(r24, r07, r00, 0);
+                        r25 = vfmaq_laneq_f32(r25, r07, r00, 1);
+                        r26 = vfmaq_laneq_f32(r26, r07, r00, 2);
+                        r27 = vfmaq_laneq_f32(r27, r07, r00, 3);
+
+                        r28 = vfmaq_laneq_f32(r28, r07, r01, 0);
+                        r29 = vfmaq_laneq_f32(r29, r07, r01, 1);
+                        r30 = vfmaq_laneq_f32(r30, r07, r01, 2);
+                        r31 = vfmaq_laneq_f32(r31, r07, r01, 3);
+
+                        r04 = vld1q_f32(k01), r05 = vld1q_f32(k01+4), r06 = vld1q_f32(k01+8), r07 = vld1q_f32(k01+12);
+                        k01 += 16;
+
+                        // Cn2
+                        r08 = vfmaq_laneq_f32(r08, r04, r02, 0);
+                        r09 = vfmaq_laneq_f32(r09, r04, r02, 1);
+                        r10 = vfmaq_laneq_f32(r10, r04, r02, 2);
+                        r11 = vfmaq_laneq_f32(r11, r04, r02, 3);
+
+                        r12 = vfmaq_laneq_f32(r12, r04, r03, 0);
+                        r13 = vfmaq_laneq_f32(r13, r04, r03, 1);
+                        r14 = vfmaq_laneq_f32(r14, r04, r03, 2);
+                        r15 = vfmaq_laneq_f32(r15, r04, r03, 3);
+
+                        r20 = vfmaq_laneq_f32(r20, r05, r02, 0);
+                        r21 = vfmaq_laneq_f32(r21, r05, r02, 1);
+                        r22 = vfmaq_laneq_f32(r22, r05, r02, 2);
+                        r23 = vfmaq_laneq_f32(r23, r05, r02, 3);
+
+                        r24 = vfmaq_laneq_f32(r24, r05, r03, 0);
+                        r25 = vfmaq_laneq_f32(r25, r05, r03, 1);
+                        r26 = vfmaq_laneq_f32(r26, r05, r03, 2);
+                        r27 = vfmaq_laneq_f32(r27, r05, r03, 3);
+
+                        r00 = vld1q_f32(r0), r01 = vld1q_f32(r0+4), r02 = vld1q_f32(r0+8), r03 = vld1q_f32(r0+12);
+                        r0 += 16;
+
+                        r16 = vfmaq_laneq_f32(r16, r04, r00, 0);
+                        r17 = vfmaq_laneq_f32(r17, r04, r00, 1);
+                        r18 = vfmaq_laneq_f32(r18, r04, r00, 2);
+                        r19 = vfmaq_laneq_f32(r19, r04, r00, 3);
+
+                        r28 = vfmaq_laneq_f32(r28, r05, r00, 0);
+                        r29 = vfmaq_laneq_f32(r29, r05, r00, 1);
+                        r30 = vfmaq_laneq_f32(r30, r05, r00, 2);
+                        r31 = vfmaq_laneq_f32(r31, r05, r00, 3);
+
+                        // Cn3
+                        // 8 ~ 19
+                        r08 = vfmaq_laneq_f32(r08, r06, r01, 0);
+                        r09 = vfmaq_laneq_f32(r09, r06, r01, 1);
+                        r10 = vfmaq_laneq_f32(r10, r06, r01, 2);
+                        r11 = vfmaq_laneq_f32(r11, r06, r01, 3);
+
+                        r12 = vfmaq_laneq_f32(r12, r06, r02, 0);
+                        r13 = vfmaq_laneq_f32(r13, r06, r02, 1);
+                        r14 = vfmaq_laneq_f32(r14, r06, r02, 2);
+                        r15 = vfmaq_laneq_f32(r15, r06, r02, 3);
+
+                        r16 = vfmaq_laneq_f32(r16, r06, r03, 0);
+                        r17 = vfmaq_laneq_f32(r17, r06, r03, 1);
+                        r18 = vfmaq_laneq_f32(r18, r06, r03, 2);
+                        r19 = vfmaq_laneq_f32(r19, r06, r03, 3);
+
+                        // 20 ~ 31
+                        r20 = vfmaq_laneq_f32(r20, r07, r01, 0);
+                        r21 = vfmaq_laneq_f32(r21, r07, r01, 1);
+                        r22 = vfmaq_laneq_f32(r22, r07, r01, 2);
+                        r23 = vfmaq_laneq_f32(r23, r07, r01, 3);
+
+                        r24 = vfmaq_laneq_f32(r24, r07, r02, 0);
+                        r25 = vfmaq_laneq_f32(r25, r07, r02, 1);
+                        r26 = vfmaq_laneq_f32(r26, r07, r02, 2);
+                        r27 = vfmaq_laneq_f32(r27, r07, r02, 3);
+
+                        r28 = vfmaq_laneq_f32(r28, r07, r03, 0);
+                        r29 = vfmaq_laneq_f32(r29, r07, r03, 1);
+                        r30 = vfmaq_laneq_f32(r30, r07, r03, 2);
+                        r31 = vfmaq_laneq_f32(r31, r07, r03, 3);
+                    }
+
+                    vst1q_f32(output0_tm, r08), vst1q_f32(output0_tm + 4, r09), vst1q_f32(output0_tm + 8, r10), vst1q_f32(output0_tm + 12, r11);
+                    output0_tm += 16;
+                    vst1q_f32(output1_tm, r20), vst1q_f32(output1_tm + 4, r21), vst1q_f32(output1_tm + 8, r22), vst1q_f32(output1_tm + 12, r23);
+                    output1_tm += 16;
+
+                    vst1q_f32(output0_tm, r12), vst1q_f32(output0_tm + 4, r13), vst1q_f32(output0_tm + 8, r14), vst1q_f32(output0_tm + 12, r15);
+                    output0_tm += 16;
+                    vst1q_f32(output1_tm, r24), vst1q_f32(output1_tm + 4, r25), vst1q_f32(output1_tm + 8, r26), vst1q_f32(output1_tm + 12, r27);
+                    output1_tm += 16;
+
+                    vst1q_f32(output0_tm, r16), vst1q_f32(output0_tm + 4, r17), vst1q_f32(output0_tm + 8, r18), vst1q_f32(output0_tm + 12, r19);
+                    output0_tm += 16;
+                    vst1q_f32(output1_tm, r28), vst1q_f32(output1_tm + 4, r29), vst1q_f32(output1_tm + 8, r30), vst1q_f32(output1_tm + 12, r31);
+                    output1_tm += 16;
+                }
+
                 for (; ti + 7 < tiles; ti += 8)
                 {
                     const float* r0 = input_tm + ofstab0[ti * 2] * line_step;
@@ -1009,17 +1254,17 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
 
         // Matrix multiplication, 4 output channel.
         int Ock_div4 = (K_aligned - K_div8 * 8) / 4;
-        parallel_for_(Range(0, Ock_div4), [&](const Range &range){
-            for (int outcn = range.start; outcn < range.end; outcn++)
+        parallel_for_(Range(0, 64), [&](const Range &range){
+            for (int r = range.start; r < range.end; r++)
             {
-                float* output_tmp = outputbuf0 + tiles * (outcn + K_div8 * 2)* 4;
-                float* kernel_tmp = weight_ptr0 + (outcn + K_div8 * 2) * 4 * C_aligned;
+                float* input_tm = inputbuf0 + r * big_step;
+                float* output_tmp = outputbuf0 + tiles * K_aligned * r;
+                float* kernel_tmp = weight_ptr0 + r * C_aligned * K_aligned;
 
-                for (int r = 0; r < 64; r++)
+                for (int out_div4 = 0; out_div4 < Ock_div4; out_div4 ++)
                 {
-                    float *input_tm = inputbuf0 + r * big_step;
-                    float *output0_tm = output_tmp + tiles * K_aligned * r;
-                    float *kernel_tm_i = kernel_tmp + r * C_aligned * K_aligned;
+                    float* output0_tm = output_tmp + tiles * (out_div4 + K_div8 * 2) * 4 ;
+                    float* kernel_tm_i = kernel_tmp + (out_div4 + K_div8 * 2) * 4 * C_aligned;
 
                     int ti = 0;
                     for (; ti + 7 < tiles; ti += 8)
@@ -1345,12 +1590,20 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
         });
 
         int bigStepOut = tiles * K_aligned;
+        AutoBuffer<float> _fAbuf;
+        float* fAbuf0 = 0;
+        if (fusedAddPtr0)
+        {
+            _fAbuf.allocate(6 * 6 * 4 * ntasks);
+            fAbuf0 = _fAbuf.data();
+        }
 
         // Transform Output
         parallel_for_(Range(0, ntasks), [&](const Range& range)
         {
             for (int task_i = range.start; task_i < range.end; task_i++)
             {
+                float* fAbuf = fAbuf0 ? fAbuf0 + task_i * 6 * 6 * 4 : 0;
                 float* outputCnbuf = outputCnbuf0 + task_i * 8 * 8 * 4;
                 for (int outCn4 = task_i; outCn4 < K_aligned / 4; outCn4 += ntasks)
                 {
@@ -1358,6 +1611,7 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
                     int outCn = outCn4 * 4;
                     float* output_buf = outputbuf0 + outCn * tiles;
                     float* output_ptr = output_ptr0 + outCn * W0 * H0;
+                    float* fusedAddPtr = fusedAddPtr0 + outCn * W0 * H0;
 
                     for (int ti = 0; ti < tiles; ti++)
                     {
@@ -1366,6 +1620,9 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
                         int hi = ti / W_tiles;
                         int wi = ti % W_tiles;
 
+                        int wEnd = (wi + 1) * 6 > W0 ? W0 - (wi * 6) : 6;
+                        int hEnd = (hi + 1) * 6 > H0 ? H0 - (hi * 6) : 6;
+
                         // construct the output tile.
                         for (int r = 0; r < 64; r++)
                         {
@@ -1374,11 +1631,26 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
                             outputCnbuf_i += FAST_VEC_NLANES;
                         }
 
-                        winograd_trans_output_F63(outputCnbuf, conv->biasBuf.data() + outCn,
-                                                  minval, maxval, ifMinMaxAct);
+                        // construct the fusedAdd buffer.
+                        if (fAbuf && fusedAddPtr0)
+                        {
+                            memset(fAbuf, 0, sizeof(fAbuf[0]) * 6 * 6 * 4);
+                            float* fAPtr = fusedAddPtr + (hi * W0 + wi) * 6;
+                            for (int outCni = 0; outCni < FAST_VEC_NLANES; outCni++)
+                            {
+                                float* fAbufCnPtr = fAPtr + outCni * out_planesize; // skip channel
+                                for (int i = 0; i < hEnd; i++)
+                                {
+                                    for (int j = 0; j < wEnd; j++)
+                                    {
+                                        fAbuf[(i * 6 + j) * FAST_VEC_NLANES + outCni] = fAbufCnPtr[i * W0 + j];
+                                    }
+                                }
+                            }
+                        }
 
-                        int wEnd = (wi + 1) * 6 > W0 ? W0 - (wi * 6) : 6;
-                        int hEnd = (hi + 1) * 6 > H0 ? H0 - (hi * 6) : 6;
+                        winograd_trans_output_F63(outputCnbuf, conv->biasBuf.data() + outCn, fAbuf,
+                                                  minval, maxval, ifMinMaxAct);
 
                         float* output_ptr_i = output_ptr + (hi * W0 + wi) * 6;
 
@@ -1411,13 +1683,11 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
             }
         });
     }
-
     return 1;
 }
-
 #else
 
-void initWinograd63(Ptr<FastConv2d>& conv, float* src_weight, int K, int C)
+void initWinograd63(Ptr<FastConv2d>& conv, InputArray _weightsMat, int K, int C)
 {
     conv->ifWinograd63 = false;
 }
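For reference, the NEON transpose12x4() helper introduced above converts each 48-float block from pack4 layout (12 tiles with 4 interleaved channels) into a layout where the 12 tile values of each channel are contiguous, so the 12-tile micro-kernel can load them with consecutive vld1q_f32 calls. A scalar equivalent, written only to document the intended data movement (the temporary makes it safe for the in-place call used in the patch):

    #include <cstring>

    // Scalar reference for transpose12x4(): dst[c*12 + t] = src[t*4 + c] for each
    // 12x4 block, repeated 'cn' times. Safe when src == dst.
    static void transpose12x4_ref(const float* src, float* dst, int cn)
    {
        for (int i = 0; i < cn; i++, src += 48, dst += 48)
        {
            float tmp[48];
            for (int t = 0; t < 12; t++)        // tile index
                for (int c = 0; c < 4; c++)     // channel within the pack4 group
                    tmp[c * 12 + t] = src[t * 4 + c];
            std::memcpy(dst, tmp, sizeof(tmp));
        }
    }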
index 753c00d..deb356f 100644 (file)
@@ -162,6 +162,178 @@ void Net::Impl::fuseLayers(const std::vector<LayerPin>& blobsToKeep_)
                     break;
             }
 
+            // CPU: fuse Convolution 2D layer followed by Add + activation.
+            while (nextData && (IS_DNN_CPU_TARGET(preferableTarget)) && ld.layerInstance->type == "Convolution")
+            {
+                // Note that we can only deal with conv + Add + activ here.
+                // To avoid patterns like conv + activ + add: if the conv has already been fused with an activation, we break.
+                Ptr<ConvolutionLayer> convLayer = ld.layerInstance.dynamicCast<ConvolutionLayer>();
+
+                // Only Conv2D without a fused activation supports this fusion; otherwise, we skip.
+                if (!convLayer->isConv2D || convLayer->fusedActivation)
+                    break;
+
+                // Currently, two layers in OpenCV implement the Add operator.
+                Ptr<NaryEltwiseLayer> nextNaryEltwiseLayer = nextData->layerInstance.dynamicCast<NaryEltwiseLayer>();
+                Ptr<EltwiseLayer> nextEltwiseLayer = nextData->layerInstance.dynamicCast<EltwiseLayer>();
+                if (nextNaryEltwiseLayer.empty() && nextEltwiseLayer.empty())
+                    break;
+
+                if (nextData->inputBlobsId.size() != 2)
+                    break;
+
+                if (!nextData->params.has("operation") || toLowerCase(nextData->params.get<String>("operation")) != "add")
+                {
+                    CV_LOG_DEBUG(NULL, "DNN/CPU: fusion with NaryEltwise or Eltwise Layer operation is not supported: "
+                        << nextData->params.get<String>("operation"));
+                    break;
+                }
+
+                // This optimization is for cases like
+                // some_layer                      conv
+                //   |                              |
+                //   +-- eltwise or (naryEltwise) --+
+                //               |
+                //             activ
+                // This way all the element-wise computations
+                // (i.e. some_layer + conv) are done at the [conv] layer.
+                // So we need to replace [conv]'s output blob with [eltwise]'s one,
+                // considering that [activ] is an in-place layer.
+                // Also, we need to move all the consumers' references.
+                // To prevent memory collisions (i.e. when the input of
+                // [conv] and the output of [eltwise or naryEltwise] is the same blob)
+                // we allocate a new blob (a standalone sketch of this pattern check follows this file's diff).
+                {
+                    LayerData *naryOrEltwiseData = nextData;
+
+                    // The Eltwise or NaryEltwise layer has two inputs. We need to determine which
+                    // one is the base convolution layer and which one can be used as its bias.
+                    LayerData* biasLayerData = 0;
+                    for (int i = 0; i < 2; ++i)
+                    {
+                        LayerData *downLayerData = &layers[naryOrEltwiseData->inputBlobsId[i].lid];
+                        CV_Assert(downLayerData);
+                        // If the current downLayerData is marked as skip, it has already been fused into its parent node.
+                        while (downLayerData->skip)
+                        {
+                            if (downLayerData->inputBlobsId.size() == 1)
+                                downLayerData = &layers[downLayerData->inputBlobsId[0].lid];
+                            else
+                            {
+                                downLayerData = 0;
+                                break;
+                            }
+                        }
+
+                        if (downLayerData && ld.id == downLayerData->id)
+                        {
+                            biasLayerData = &layers[naryOrEltwiseData->inputBlobsId[1 - i].lid];
+                            break;
+                        }
+                    }
+
+                    // Check that biasLayerData is the expected layer.
+                    if (!biasLayerData)
+                        break;
+
+                    // Check that the bias output shape matches the conv (ld) output shape.
+                    MatShape biasOutShape = shape(biasLayerData->outputBlobs[0]);
+                    MatShape ldOutShape = shape(ld.outputBlobs[0]);
+                    if (biasOutShape != ldOutShape)
+                        break;
+
+                    CV_Assert(biasLayerData);
+                    {
+                        // fuse the naryEltwise (or eltwise) layer
+                        // bias must already be computed to fuse => bias layer must appear before convolution
+                        if (biasLayerData->id < ld.id)
+                        {
+                            // conv + naryEltwise.
+                            CV_Assert_N(biasLayerData->outputBlobs.size() == 1, ld.inputBlobs.size() == 1);
+                            CV_Assert_N(biasLayerData->outputBlobsWrappers.size() == 1, ld.inputBlobsWrappers.size() == 1);
+
+                            printf_(("\tfused with %s\n", nextData->layerInstance->name.c_str()));  // nextNaryEltwiseLayer may be empty when the follower is an Eltwise layer
+                            naryOrEltwiseData->skip = true;
+
+                            CV_Assert_N(ld.outputBlobs.size() == 1, ld.outputBlobsWrappers.size() == 1);
+                            // Note: here is the trick. We allocate a new output blob for [conv] and make the bias
+                            // layer (and its skipped ancestors) write into that same blob, so [conv] can add onto it in place.
+                            ld.outputBlobs[0] = ld.outputBlobs[0].clone();
+                            ld.outputBlobsWrappers[0] = wrap(ld.outputBlobs[0]);
+
+                            // Iteratively redirect the output blobs of biasLayerData and its skipped ancestors (a standalone sketch follows this file's diff).
+                            std::vector<LayerData*> skipDataList;
+                            skipDataList.push_back(biasLayerData);
+
+                            while (!skipDataList.empty())
+                            {
+                                LayerData* skipData = skipDataList.back();
+                                skipDataList.pop_back();
+
+                                CV_Assert(skipData->outputBlobs.size() == 1);
+                                skipData->outputBlobs[0] = ld.outputBlobs[0];
+                                skipData->outputBlobsWrappers[0] = ld.outputBlobsWrappers[0];
+                                if (skipData->skip)
+                                {
+                                    for (auto& inputLayerId : skipData->inputLayersId)
+                                    {
+                                        LayerData* inputld = &layers[inputLayerId];
+
+                                        if (inputld && inputld->outputBlobs.size() == 1)
+                                            skipDataList.push_back(inputld);
+                                    }
+                                }
+                            }
+
+                            naryOrEltwiseData->outputBlobs = ld.outputBlobs;
+                            naryOrEltwiseData->outputBlobsWrappers = ld.outputBlobsWrappers;
+
+                            // Set the fusedAdd flag on [conv].
+                            convLayer->fusedAdd = true;
+                            LayerData* finalData = naryOrEltwiseData;
+                            /* After fusing Conv + naryEltwise (or eltwise), we can also fuse the activation if:
+                             * => the activation layer that follows is the only consumer of the eltwise output
+                             * => the activation layer does not process multiple inputs
+                             * => we are not required to keep the output of the eltwise
+                             */
+                            if (naryOrEltwiseData->consumers.size() == 1)
+                            {
+                                Ptr<ActivationLayer> nextFusabeleActivLayer;
+                                LayerData* nextAct = &layers[naryOrEltwiseData->consumers[0].lid];
+
+                                if (nextData->outputBlobs.size() == 1)
+                                    nextFusabeleActivLayer = nextAct->layerInstance.dynamicCast<ActivationLayer>();
+
+                                if (!nextFusabeleActivLayer.empty())
+                                {
+                                    convLayer->setActivation(nextFusabeleActivLayer);
+                                    nextAct->skip = true;
+
+                                    nextAct->outputBlobs = ld.outputBlobs;
+                                    nextAct->outputBlobsWrappers = ld.outputBlobsWrappers;
+                                }
+                            }
+
+                            // Move references of finalData (eltwise or activation) layer consumers to the newly allocated blob.
+                            for (int i = 0; i < finalData->consumers.size(); ++i)
+                            {
+                                LayerData& consumer = layers[finalData->consumers[i].lid];
+                                for (int j = 0; j < consumer.inputBlobsId.size(); ++j)
+                                {
+                                    if (consumer.inputBlobsId[j].lid == finalData->id)
+                                    {
+                                        consumer.inputBlobs[j] = &ld.outputBlobs[0];
+                                        consumer.inputBlobsWrappers[j] = ld.outputBlobsWrappers[0];
+                                        break;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                break;
+            }
+
             // OpenCL: fuse convolution layer followed by eltwise + relu
             // CUDA: fuse convolution layer followed by eltwise (and optional activation)
             while (nextData &&
@@ -398,7 +570,7 @@ void Net::Impl::fuseLayers(const std::vector<LayerPin>& blobsToKeep_)
                                 // (i.e. some_layer+conv or some_layer*conv)
                                 // would be done at [conv] layer. So we need to
                                 // replace [conv]'s output blob to [eltwise]'s one.
-                                // Also we need to move all the consumers' references.
+                                // Also, we need to move all the consumers' references.
                                 // To prevent memory collisions (i.e. when input of
                                 // [conv] and output of [eltwise] is the same blob)
                                 // we allocate a new blob.
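
The new CPU fusion branch above boils down to a pattern check followed by graph rewiring. As a hedged illustration of the check alone (the ToyLayer struct and canFuseConvAdd below are stand-ins invented for this sketch, not OpenCV API):

#include <string>
#include <vector>

struct ToyLayer
{
    int id = 0;
    std::string operation;          // for eltwise-like layers: "add", "mul", ...
    std::vector<int> inputIds;      // ids of the producing layers
    std::vector<int> outShape;      // output blob shape
    bool isConv2D = false;
    bool fusedActivation = false;
};

// True when conv -> addLayer matches the Conv2D + Add pattern the CPU branch may fuse:
// the conv is 2D and has no fused activation yet, the follower performs an element-wise
// "add" of exactly two inputs, the other operand has the same shape as conv's output,
// and its producer runs before the conv so its result is ready to be accumulated.
static bool canFuseConvAdd(const ToyLayer& conv, const ToyLayer& addLayer,
                           const ToyLayer& biasProducer)
{
    if (!conv.isConv2D || conv.fusedActivation)
        return false;
    if (addLayer.operation != "add" || addLayer.inputIds.size() != 2)
        return false;
    if (biasProducer.outShape != conv.outShape)
        return false;
    return biasProducer.id < conv.id;
}
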
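The rewiring step (the "trick" noted in the hunk) makes the Add operand's producer write directly into conv's freshly allocated output blob, walking back through already-skipped layers with a small worklist. A minimal sketch of that traversal, under the same illustrative-struct assumption as above:

#include <vector>

struct ToyBlobLayer
{
    int outputBlobId = -1;                 // stand-in for outputBlobs[0]
    bool skip = false;                     // true if already fused into a parent
    std::vector<ToyBlobLayer*> inputs;     // producing layers
};

// Redirect the bias producer (and any layers already fused into it) so they all
// share conv's new output blob; conv with fusedAdd then accumulates on top of it in place.
static void redirectToConvOutput(ToyBlobLayer* biasProducer, int convOutputBlobId)
{
    std::vector<ToyBlobLayer*> worklist{biasProducer};
    while (!worklist.empty())
    {
        ToyBlobLayer* ld = worklist.back();
        worklist.pop_back();
        ld->outputBlobId = convOutputBlobId;       // write into conv's output buffer
        if (ld->skip)                              // already fused: propagate to its parents
            for (ToyBlobLayer* parent : ld->inputs)
                worklist.push_back(parent);
    }
}
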
index ddaa51d..3ea8d17 100644 (file)
@@ -284,7 +284,7 @@ TEST(Reproducibility_SSD, Accuracy)
     Mat out = net.forward("detection_out");
 
     Mat ref = blobFromNPY(_tf("ssd_out.npy"));
-    normAssertDetections(ref, out, "", FLT_MIN);
+    normAssertDetections(ref, out, "", 0.06);
 }
 
 typedef testing::TestWithParam<tuple<Backend, Target> > Reproducibility_MobileNet_SSD;
index 54db5a3..1cafa22 100644 (file)
@@ -1029,7 +1029,7 @@ TEST_P(Test_Int8_nets, FasterRCNN_resnet50)
     Mat blob = blobFromImage(inp, 1.0, Size(800, 600), Scalar(), true, false);
     Mat ref = blobFromNPY(_tf("tensorflow/faster_rcnn_resnet50_coco_2018_01_28.detection_out.npy"));
 
-    float confThreshold = 0.5, scoreDiff = 0.05, iouDiff = 0.15;
+    float confThreshold = 0.8, scoreDiff = 0.05, iouDiff = 0.15;
     testDetectionNet(net, blob, ref, confThreshold, scoreDiff, iouDiff);
 }
 
index 5208874..bd95727 100644 (file)
@@ -488,7 +488,7 @@ TEST_P(Test_Torch_nets, ENet_accuracy)
     // Due to numerical instability in Pooling-Unpooling layers (indexes jittering)
     // thresholds for ENet must be changed. Accuracy of results was checked on
     // Cityscapes dataset and difference in mIOU with Torch is 10E-4%
-    normAssert(ref, out, "", 0.00044, /*target == DNN_TARGET_CPU ? 0.453 : */0.552);
+    normAssert(ref, out, "", 0.0005, /*target == DNN_TARGET_CPU ? 0.453 : */0.552);
     normAssertSegmentation(ref, out);
 
     const int N = 3;
@@ -496,7 +496,7 @@ TEST_P(Test_Torch_nets, ENet_accuracy)
     {
         net.setInput(inputBlob, "");
         Mat out = net.forward();
-        normAssert(ref, out, "", 0.00044, /*target == DNN_TARGET_CPU ? 0.453 : */0.552);
+        normAssert(ref, out, "", 0.0005, /*target == DNN_TARGET_CPU ? 0.453 : */0.552);
         normAssertSegmentation(ref, out);
     }
 }