From 59b870a87ad22d89027e28eee3b8be1f3db1c0bf Mon Sep 17 00:00:00 2001 From: Zihao Mu Date: Fri, 1 Jul 2022 18:03:15 +0800 Subject: [PATCH] Merge pull request #21910 from zihaomu:fast_conv_ARM DNN: Accelerating convolution * Fast Conv of ARM, X86 and universal intrinsics. * improve code style. * error fixed. * improve the License * optimize memory allocated and Adjust the threshold. * change FasterRCNN_vgg16 to 2GB memory. --- modules/dnn/src/layers/convolution_layer.cpp | 54 +- .../fast_convolution/depthwise_convolution.cpp | 385 ++++++ .../fast_convolution/fast_convolution.avx2.cpp | 361 ++++++ .../layers/fast_convolution/fast_convolution.cpp | 694 ++++++++++ .../layers/fast_convolution/fast_convolution.hpp | 89 ++ .../fast_convolution/fast_convolution.simd.hpp | 342 +++++ .../layers/fast_convolution/winograd_3x3s1_f63.cpp | 1351 ++++++++++++++++++++ modules/dnn/test/test_backends.cpp | 2 +- modules/dnn/test/test_caffe_importer.cpp | 4 +- modules/dnn/test/test_int8_layers.cpp | 14 +- modules/dnn/test/test_tf_importer.cpp | 2 +- 11 files changed, 3286 insertions(+), 12 deletions(-) create mode 100644 modules/dnn/src/layers/fast_convolution/depthwise_convolution.cpp create mode 100644 modules/dnn/src/layers/fast_convolution/fast_convolution.avx2.cpp create mode 100644 modules/dnn/src/layers/fast_convolution/fast_convolution.cpp create mode 100644 modules/dnn/src/layers/fast_convolution/fast_convolution.hpp create mode 100644 modules/dnn/src/layers/fast_convolution/fast_convolution.simd.hpp create mode 100644 modules/dnn/src/layers/fast_convolution/winograd_3x3s1_f63.cpp diff --git a/modules/dnn/src/layers/convolution_layer.cpp b/modules/dnn/src/layers/convolution_layer.cpp index 0bf39f9..1244433 100644 --- a/modules/dnn/src/layers/convolution_layer.cpp +++ b/modules/dnn/src/layers/convolution_layer.cpp @@ -71,6 +71,8 @@ using namespace cv::dnn::ocl4dnn; using namespace cv::dnn::cuda4dnn; #endif +#include "fast_convolution/fast_convolution.hpp" + namespace cv { namespace dnn @@ -253,11 +255,14 @@ class ConvolutionLayerImpl CV_FINAL : public BaseConvolutionLayerImpl { public: enum { VEC_ALIGN = 8, DFT_TYPE = CV_32F }; - Mat weightsMat; + Mat weightsMat; // Used to store weight params. It will be used for layer fusion and memory alignment. std::vector biasvec; std::vector reluslope; Ptr activ; + Mat fastWeights; // Used to store weight params. It will be used for layer fusion and without memory alignment. + Ptr fastConv2dImpl; + #ifdef HAVE_OPENCL Ptr > convolutionOp; std::vector umat_blobs; @@ -433,6 +438,7 @@ public: wm.copyTo(wm_aligned); wm = wm_aligned; } + fastWeights = blobs[0].reshape(1, numOutput); weightsMat = wm; } else @@ -628,14 +634,26 @@ public: if (weightsMat.data == blobs[0].data) weightsMat = weightsMat.clone(); + // If fastWeights is the same as weightsMat, we don't need to allocate more space for fastWeights. + bool sameFastWeights = false; + if (fastWeights.step1() == weightsMat.step1()) // If weightsMat is realigned, it is not the same as fastWeights. + sameFastWeights = true; + + if (!sameFastWeights && fastWeights.data == blobs[0].data) + fastWeights = fastWeights.clone(); + Mat originWeights = blobs[0].reshape(1, outCn); for (int i = 0; i < outCn; ++i) { double wi = w.at(i); weightsMultipliers[i] *= wi; cv::multiply(originWeights.row(i), weightsMultipliers[i], weightsMat.row(i)); + if (!sameFastWeights) + cv::multiply(originWeights.row(i), weightsMultipliers[i], fastWeights.row(i)); biasvec[i] *= wi; } + if (sameFastWeights) + fastWeights = weightsMat; } if (!b.empty()) @@ -1948,8 +1966,13 @@ public: int outCn = blobs.empty() ? inputs[1].size[0] : blobs[0].size[0]; // Need to align non-const blobs + bool variableWeight = false; if (blobs.empty()) { + variableWeight = true; + if (fastWeights.data != inputs[1].data) + fastWeights = inputs[1].clone(); + Mat wm = inputs[1].reshape(1, outCn); if (wm.data != weightsMat.data) { @@ -2066,8 +2089,37 @@ public: { int nstripes = std::max(getNumThreads(), 1); + // Initialization of FastCovn2d + if ((!fastConv2dImpl || variableWeight) && inputs[0].dims == 4) + { + int K = outputs[0].size[1]; + int C = inputs[0].size[1]; + int Hk = kernel_size[kernel_size.size() - 2]; + int Wk = kernel_size.back(); + + CV_Assert(outputs[0].size[1] % ngroups == 0); + int stride_h = strides[strides.size() - 2]; + int stride_w = strides.back(); + + int dilation_h = dilations[dilations.size() - 2]; + int dilation_w = dilations.back(); + float* weightsPtr = fastWeights.ptr(); + CV_Assert(weightsPtr); + + fastConv2dImpl = initFastConv2d(ngroups, K, C, Hk, Wk, stride_w, stride_h, + dilation_w, dilation_h, pads_begin, pads_end, weightsPtr, &biasvec[0]); + } + + if (fastConv2dImpl) + { + runFastConv2d(inputs[0], outputs[0], fastConv2dImpl, nstripes, activ); + return; + } + + // Use only for Conv1D and Conv3D. ParallelConv::run(inputs[0], outputs[0], weightsMat, biasvec, reluslope, kernel_size, strides, pads_begin, pads_end, dilations, activ.get(), ngroups, nstripes); + } } diff --git a/modules/dnn/src/layers/fast_convolution/depthwise_convolution.cpp b/modules/dnn/src/layers/fast_convolution/depthwise_convolution.cpp new file mode 100644 index 0000000..c98c3d6 --- /dev/null +++ b/modules/dnn/src/layers/fast_convolution/depthwise_convolution.cpp @@ -0,0 +1,385 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +// This file is modified from the ficus (https://github.com/vpisarev/ficus/blob/master/lib/NN/OpConv.fx). +// Here is the original license: +/* + This file is a part of ficus language project. + See ficus/LICENSE for the licensing terms +*/ + +#include "../../precomp.hpp" +#include "fast_convolution.hpp" + +namespace cv { namespace dnn { + +static void depthWiseBlock(const float *inptr, float *outptr, const float *weights, float biasval, int *ofstab, int *yxtab, + float minval, float maxval, int Hi, int Wi, int H0, int W0, int ksize, int pad_top, int pad_left, + int dilation_y, int stride_x, int stride_y, int inner_xleft, int inner_xright, int inner_ytop, + int inner_ybottom, bool ifMinMaxAct, bool useSIMD, bool is3x3) +{ +#ifdef CV_SIMD128 + v_float32x4 vminval = v_setall_f32(minval), vmaxval = v_setall_f32(maxval); + + v_float32x4 w0 = v_setall_f32( + 0.f), w1 = w0, w2 = w0, w3 = w0, w4 = w0, w5 = w0, w6 = w0, w7 = w0, w8 = w0, vbias = w0; + if (useSIMD) + { + vbias = v_setall_f32(biasval); + if (is3x3) + { + w0 = v_setall_f32(weights[0]); + w1 = v_setall_f32(weights[1]); + w2 = v_setall_f32(weights[2]); + w3 = v_setall_f32(weights[3]); + w4 = v_setall_f32(weights[4]); + w5 = v_setall_f32(weights[5]); + w6 = v_setall_f32(weights[6]); + w7 = v_setall_f32(weights[7]); + w8 = v_setall_f32(weights[8]); + } + } +#endif + int dy0 = 1; + for (int y0 = 0; y0 < H0; y0 += dy0, outptr += W0 * dy0) + { +#ifdef CV_SIMD128 + dy0 = inner_ytop <= y0 && y0 + 3 < inner_ybottom && is3x3 && stride_y == 1 && dilation_y == 1 + ? 3 : 1; +#endif + int x0 = 0, x1 = y0 >= inner_ytop && y0 < inner_ybottom ? inner_xleft : W0; + int yi_ = y0 * stride_y - pad_top; + + for (;;) + { + float s_0, s_1, s_2; + if (dy0 == 3) + { + for (; x0 < x1; x0++) + { + int xi_ = x0 * stride_x - pad_left; + s_0 = s_1 = s_2 = biasval; + for (int k = 0; k < ksize; k++) + { + int dy = yxtab[k * 2]; + int yi = yi_ + dy; + int xi = xi_ + yxtab[k * 2 + 1]; + float w = weights[k]; + + if ((unsigned) xi < (unsigned) Wi) + { + s_0 += inptr[yi * Wi + xi] * w; + s_1 += inptr[(yi + 1) * Wi + xi] * w; + s_2 += inptr[(yi + 2) * Wi + xi] * w; + } + } + s_0 = std::min(std::max(s_0, minval), maxval); + s_1 = std::min(std::max(s_1, minval), maxval); + s_2 = std::min(std::max(s_2, minval), maxval); + outptr[x0] = s_0; + outptr[x0 + W0] = s_1; + outptr[x0 + W0 * 2] = s_2; + } + } + else + { + for (; x0 < x1; x0++) + { + int xi_ = x0 * stride_x - pad_left; + s_0 = biasval; + for (int k = 0; k < ksize; k++) { + int dy = yxtab[k * 2]; + int yi = yi_ + dy; + int xi = xi_ + yxtab[k * 2 + 1]; + float w = weights[k]; + if (((unsigned) yi < (unsigned) Hi) & ((unsigned) xi < (unsigned) Wi)) + s_0 += inptr[yi * Wi + xi] * w; + } + s_0 = std::min(std::max(s_0, minval), maxval); + outptr[x0] = s_0; + } + } + if (x0 == W0) + break; + x1 = inner_xright; +#ifdef CV_SIMD128 + if (useSIMD) + { + if (is3x3) + { + if (dy0 == 3) + { + for (; x0 <= x1 - FAST_VEC_NLANES; x0 += FAST_VEC_NLANES) + { + int xi_ = x0 * stride_x - pad_left; + const float *inptr_xi = inptr + Wi * yi_ + xi_; + + v_float32x4 s0, s1, s2; + v_float32x4 x00 = v_load(inptr_xi); + v_float32x4 x01 = v_load(inptr_xi + 1); + v_float32x4 x02 = v_load(inptr_xi + 2); + + v_float32x4 x10 = v_load(inptr_xi + Wi); + v_float32x4 x11 = v_load(inptr_xi + Wi + 1); + v_float32x4 x12 = v_load(inptr_xi + Wi + 2); + + v_float32x4 x20 = v_load(inptr_xi + Wi * 2); + v_float32x4 x21 = v_load(inptr_xi + Wi * 2 + 1); + v_float32x4 x22 = v_load(inptr_xi + Wi * 2 + 2); + + v_float32x4 x30 = v_load(inptr_xi + Wi * 3); + v_float32x4 x31 = v_load(inptr_xi + Wi * 3 + 1); + v_float32x4 x32 = v_load(inptr_xi + Wi * 3 + 2); + + v_float32x4 x40 = v_load(inptr_xi + Wi * 4); + v_float32x4 x41 = v_load(inptr_xi + Wi * 4 + 1); + v_float32x4 x42 = v_load(inptr_xi + Wi * 4 + 2); + + s0 = v_fma(x00, w0, vbias); + s1 = v_fma(x10, w0, vbias); + s2 = v_fma(x20, w0, vbias); + + s0 = v_fma(x01, w1, s0); + s1 = v_fma(x11, w1, s1); + s2 = v_fma(x21, w1, s2); + + s0 = v_fma(x02, w2, s0); + s1 = v_fma(x12, w2, s1); + s2 = v_fma(x22, w2, s2); + + s0 = v_fma(x10, w3, s0); + s1 = v_fma(x20, w3, s1); + s2 = v_fma(x30, w3, s2); + + s0 = v_fma(x11, w4, s0); + s1 = v_fma(x21, w4, s1); + s2 = v_fma(x31, w4, s2); + + s0 = v_fma(x12, w5, s0); + s1 = v_fma(x22, w5, s1); + s2 = v_fma(x32, w5, s2); + + s0 = v_fma(x20, w6, s0); + s1 = v_fma(x30, w6, s1); + s2 = v_fma(x40, w6, s2); + + s0 = v_fma(x21, w7, s0); + s1 = v_fma(x31, w7, s1); + s2 = v_fma(x41, w7, s2); + + s0 = v_fma(x22, w8, s0); + s1 = v_fma(x32, w8, s1); + s2 = v_fma(x42, w8, s2); + + if (ifMinMaxAct) + { + s0 = v_min(v_max(s0, vminval), vmaxval); + s1 = v_min(v_max(s1, vminval), vmaxval); + s2 = v_min(v_max(s2, vminval), vmaxval); + } + + v_store(outptr + x0, s0); + v_store(outptr + W0 + x0, s1); + v_store(outptr + W0 * 2 + x0, s2); + } + } + else + { + for (; x0 <= x1 - FAST_VEC_NLANES; x0 += FAST_VEC_NLANES) + { + int xi_ = x0 * stride_x - pad_left; + const float *inptr_xi = inptr + Wi * yi_ + xi_; + v_float32x4 s0 = v_fma(v_load(inptr_xi + ofstab[0]), w0, vbias); + v_float32x4 s1 = v_load(inptr_xi + ofstab[1]) * w1; + v_float32x4 s2 = v_load(inptr_xi + ofstab[2]) * w2; + + s0 = v_fma(v_load(inptr_xi + ofstab[3]), w3, s0); + s1 = v_fma(v_load(inptr_xi + ofstab[4]), w4, s1); + s2 = v_fma(v_load(inptr_xi + ofstab[5]), w5, s2); + + s0 = v_fma(v_load(inptr_xi + ofstab[6]), w6, s0); + s1 = v_fma(v_load(inptr_xi + ofstab[7]), w7, s1); + s2 = v_fma(v_load(inptr_xi + ofstab[8]), w8, s2); + + s0 = s0 + s1 + s2; + if (ifMinMaxAct) + s0 = v_min(v_max(s0, vminval), vmaxval); + v_store(outptr + x0, s0); + } + } + } + else + { + for (; x0 <= x1 - FAST_VEC_NLANES; x0 += FAST_VEC_NLANES) + { + int xi_ = x0 * stride_x - pad_left, k = 0; + const float *inptr_xi = inptr + Wi * yi_ + xi_; + v_float32x4 s0 = vbias; + for (; k <= ksize - 4; k += 4) + { + v_float32x4 v0 = v_load(inptr_xi + ofstab[k]); + v_float32x4 v1 = v_load(inptr_xi + ofstab[k + 1]); + v_float32x4 v2 = v_load(inptr_xi + ofstab[k + 2]); + v_float32x4 v3 = v_load(inptr_xi + ofstab[k + 3]); + + v_float32x4 ww0 = v_setall_f32(weights[k]); + v_float32x4 ww1 = v_setall_f32(weights[k+1]); + v_float32x4 ww2 = v_setall_f32(weights[k+2]); + v_float32x4 ww3 = v_setall_f32(weights[k+3]); + + s0 = v_fma(v0, ww0, s0); + s0 = v_fma(v1, ww1, s0); + s0 = v_fma(v2, ww2, s0); + s0 = v_fma(v3, ww3, s0); + } + for (; k < ksize; k++) + s0 = v_fma(v_load(inptr_xi + ofstab[k]), + v_setall_f32(weights[k]), s0); + if (ifMinMaxAct) + s0 = v_min(v_max(s0, vminval), vmaxval); + v_store(outptr + x0, s0); + } + } + } +#endif + if (dy0 == 3) + { + for (; x0 < x1; x0++) + { + int xi_ = x0 * stride_x - pad_left; + const float *inptr_xi = inptr + W0 * yi_ + xi_; + s_0 = s_1 = s_2 = biasval; + for (int k = 0; k < ksize; k++) + { + int inp_ofs = ofstab[k]; + float w = weights[k]; + s_0 += inptr_xi[inp_ofs] * w; + s_1 += inptr_xi[inp_ofs + Wi] * w; + s_2 += inptr_xi[inp_ofs + Wi * 2] * w; + } + if (ifMinMaxAct) + { + s_0 = std::min(std::max(s_0, minval), maxval); + s_1 = std::min(std::max(s_1, minval), maxval); + s_2 = std::min(std::max(s_2, minval), maxval); + } + + outptr[x0] = s_0; + outptr[x0 + W0] = s_1; + outptr[x0 + W0 * 2] = s_2; + } + } + else + { + for (; x0 < x1; x0++) + { + int xi_ = x0 * stride_x - pad_left; + const float *inptr_xi = inptr + Wi * yi_ + xi_; + s_0 = biasval; + for (int k = 0; k < ksize; k++) + { + s_0 += inptr_xi[ofstab[k]] * weights[k]; + } + + if (ifMinMaxAct) + s_0 = std::min(std::max(s_0, minval), maxval); + outptr[x0] = s_0; + } + } + x1 = W0; + } + } +} + +void runDepthwise(InputArray _input, OutputArray _output, const Ptr& conv, float minval, float maxval, ActivationLayer* activ, bool ifMinMaxAct) { + Mat input = _input.getMat(); + Mat output = _output.getMat(); + MatShape inputShape = shape(input); + MatShape outputShape = shape(output); + CV_Assert(inputShape.size() == 4 && outputShape.size() == 4); + + int N = inputShape[0], C = inputShape[1], Hi = inputShape[2], Wi = inputShape[3]; // [N, C, H, W] + int K = conv->K, Hk = conv->Hk, Wk = conv->Wk; + int H0 = outputShape[2], W0 = outputShape[3], ngroups = conv->ngroups; + + const size_t inp_planesize = (size_t) Hi * Wi; + const size_t out_planesize = (size_t) H0 * W0; + + CV_Assert(ngroups > 1 && ngroups == K && ngroups == C); + + int stride_y = conv->stride_y, stride_x = conv->stride_x; + int dilation_y = conv->dilation_y, dilation_x = conv->dilation_x; + + int pad_top = conv->pad_top, pad_bottom = conv->pad_bottom; + int pad_left = conv->pad_left, pad_right = conv->pad_right; + + int ksize = Hk * Wk, padded_ksize = ((ksize + FAST_VEC_NLANES - 1) / FAST_VEC_NLANES) * FAST_VEC_NLANES; + + const float *inp = input.ptr(); + float *out = output.ptr(); + + std::vector ofstab_(3 * padded_ksize, 0); + int *ofstab = ofstab_.data(); + int *yxtab = ofstab + padded_ksize; + + for (int k = 0; k < padded_ksize; k++) + { + int y = k < ksize ? k / Wk : 0; + int x = k < ksize ? k % Wk : 0; + int dy = y * dilation_y, dx = x * dilation_x; + yxtab[k * 2] = dy; + yxtab[k * 2 + 1] = dx; + ofstab[k] = dy * Wi + dx; + } + + const float *weights0 = conv->weightsBuf.data(), *bias = conv->biasBuf.data(); + int inner_ytop = (pad_bottom + stride_y - 1) / stride_y, inner_ybottom = 3; + int inner_xleft = (pad_left + stride_x - 1) / stride_x, inner_xright = 4; + + CV_Assert(ksize > 1 || (pad_left == 0 && pad_right == 0 && pad_top == 0 && pad_bottom == 0)); + + inner_xright = (Wi - (Wk - 1) * dilation_x + pad_left) / stride_x; + inner_xright += inner_xright * stride_x - pad_left + (Wk - 1) * dilation_x < Wi; + inner_ybottom = (Hi - (Hk - 1) * dilation_y + pad_top) / stride_y; + inner_ybottom += inner_ybottom * stride_y - pad_top + (Hk - 1) * dilation_y < Hi; + + if (inner_xleft >= inner_xright || inner_ytop >= inner_ybottom) + { + inner_xleft = W0; + inner_ytop = H0; + } + + inner_ybottom = inner_ybottom < H0 ? inner_ybottom : H0; + + bool useSIMD = stride_x == 1 && inner_xleft < W0; + bool is3x3 = Hk == 3 && Wk == 3; + + parallel_for_(Range(0, N * C), [&](const Range &r0) { + for (int nc = r0.start; nc < r0.end; nc++) + { + int c = nc % C; + const float *inptr = inp + inp_planesize * nc; + float *outptr0 = out + out_planesize * nc; + + float biasval = bias[c]; + const float *weights = weights0 + c * padded_ksize; + +#if CV_TRY_AVX2 + if (conv->useAVX2) + opt_AVX2::depthWiseBlock_AVX2(inptr, outptr0, weights, biasval, ofstab, yxtab, minval, maxval, Hi, Wi, H0, W0, ksize, + pad_top, pad_left, dilation_y, stride_x, stride_y, inner_xleft, inner_xright, inner_ytop, + inner_ybottom, ifMinMaxAct, useSIMD, is3x3); + else +#endif + depthWiseBlock(inptr, outptr0, weights, biasval, ofstab, yxtab, minval, maxval, Hi, Wi, H0, W0, ksize, + pad_top, pad_left, dilation_y, stride_x, stride_y, inner_xleft, inner_xright, inner_ytop, + inner_ybottom, ifMinMaxAct, useSIMD, is3x3); + + if (activ) + activ->forwardSlice(outptr0, outptr0, (int) out_planesize, out_planesize, c, c+1); + } + }); +} + +}} // namespace cv::dnn \ No newline at end of file diff --git a/modules/dnn/src/layers/fast_convolution/fast_convolution.avx2.cpp b/modules/dnn/src/layers/fast_convolution/fast_convolution.avx2.cpp new file mode 100644 index 0000000..22580c5 --- /dev/null +++ b/modules/dnn/src/layers/fast_convolution/fast_convolution.avx2.cpp @@ -0,0 +1,361 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#include "../../precomp.hpp" +#include "fast_convolution.hpp" + +namespace cv { +namespace opt_AVX2 +{ +#if CV_TRY_AVX2 +void convBlock_AVX2(int k, const float *a, const float *b, + float *c, int ldc, const float *bias, + float minval, float maxval, bool ifActiv) +{ +#if FAST_CONV_MR == 4 && FAST_CONV_NR == 24 + __m256 vminval = _mm256_set1_ps(minval), vmaxval = _mm256_set1_ps(maxval); + __m256 c0 = _mm256_set1_ps(bias[0]), c1 = c0, c2 = c0; + __m256 c3 = _mm256_set1_ps(bias[1]), c4 = c3, c5 = c3; + __m256 c6 = _mm256_set1_ps(bias[2]), c7 = c6, c8 = c6; + __m256 c9 = _mm256_set1_ps(bias[3]), c10 = c9, c11 = c9; + + __m256 a0 = _mm256_setzero_ps(), a1 = _mm256_setzero_ps(); + __m256 b0 = _mm256_setzero_ps(), b1 = _mm256_setzero_ps(), b2 = _mm256_setzero_ps(); + + for (int p = 0; p < k; p++, a += FAST_CONV_MR, b += FAST_CONV_NR) + { + a0 = _mm256_set1_ps(a[0]), a1 = _mm256_set1_ps(a[1]); + b0 = _mm256_load_ps(b), b1 = _mm256_load_ps(b + 8), b2 = _mm256_load_ps(b + 16); + + c0 = _mm256_fmadd_ps(b0, a0, c0); + c1 = _mm256_fmadd_ps(b1, a0, c1); + c2 = _mm256_fmadd_ps(b2, a0, c2); + + c3 = _mm256_fmadd_ps(b0, a1, c3); + a0 = _mm256_set1_ps(a[2]); + c4 = _mm256_fmadd_ps(b1, a1, c4); + c5 = _mm256_fmadd_ps(b2, a1, c5); + + c6 = _mm256_fmadd_ps(b0, a0, c6); + a1 = _mm256_set1_ps(a[3]); + c7 = _mm256_fmadd_ps(b1, a0, c7); + c8 = _mm256_fmadd_ps(b2, a0, c8); + + c9 = _mm256_fmadd_ps(b0, a1, c9); + c10 = _mm256_fmadd_ps(b1, a1, c10); + c11 = _mm256_fmadd_ps(b2, a1, c11); + } + + if (ifActiv) + { + c0 = _mm256_min_ps(_mm256_max_ps(c0, vminval), vmaxval); + c1 = _mm256_min_ps(_mm256_max_ps(c1, vminval), vmaxval); + c2 = _mm256_min_ps(_mm256_max_ps(c2, vminval), vmaxval); + c3 = _mm256_min_ps(_mm256_max_ps(c3, vminval), vmaxval); + c4 = _mm256_min_ps(_mm256_max_ps(c4, vminval), vmaxval); + c5 = _mm256_min_ps(_mm256_max_ps(c5, vminval), vmaxval); + c6 = _mm256_min_ps(_mm256_max_ps(c6, vminval), vmaxval); + c7 = _mm256_min_ps(_mm256_max_ps(c7, vminval), vmaxval); + c8 = _mm256_min_ps(_mm256_max_ps(c8, vminval), vmaxval); + c9 = _mm256_min_ps(_mm256_max_ps(c9, vminval), vmaxval); + c10 = _mm256_min_ps(_mm256_max_ps(c10, vminval), vmaxval); + c11 = _mm256_min_ps(_mm256_max_ps(c11, vminval), vmaxval); + } + + _mm256_storeu_ps(c, c0); _mm256_storeu_ps(c+8, c1); _mm256_storeu_ps(c+16, c2); + _mm256_storeu_ps(c + ldc, c3); _mm256_storeu_ps(c + ldc + 8, c4); _mm256_storeu_ps(c + ldc + 16, c5); + _mm256_storeu_ps(c + ldc*2, c6); _mm256_storeu_ps(c + ldc*2 + 8, c7); _mm256_storeu_ps(c + ldc*2 + 16, c8); + _mm256_storeu_ps(c + ldc*3, c9); _mm256_storeu_ps(c + ldc*3 + 8, c10); _mm256_storeu_ps(c + ldc*3 + 16, c11); + _mm256_zeroupper(); +#else +#error "unsupported FAST_CONV_MR and/or FAST_CONV_NR in convBlock_AVX2." +#endif +} + +void depthWiseBlock_AVX2(const float *inptr, float *outptr, const float *weights, float biasval, int *ofstab, int *yxtab, + float minval, float maxval, int Hi, int Wi, int H0, int W0, int ksize, int pad_top, int pad_left, + int dilation_y, int stride_x, int stride_y, int inner_xleft, int inner_xright, int inner_ytop, + int inner_ybottom, bool ifMinMaxAct, bool useSIMD, bool is3x3) +{ + const int VECSZ = 8; + __m256 vminval = _mm256_set1_ps(minval); + __m256 vmaxval = _mm256_set1_ps(maxval); + + __m256 w0 = _mm256_setzero_ps(), + w1 = w0, w2 = w0, w3 = w0, w4 = w0, w5 = w0, w6 = w0, w7 = w0, w8 = w0, vbias = w0; + + if (useSIMD) + { + vbias = _mm256_set1_ps(biasval); + if (is3x3) + { + w0 = _mm256_set1_ps(weights[0]); + w1 = _mm256_set1_ps(weights[1]); + w2 = _mm256_set1_ps(weights[2]); + w3 = _mm256_set1_ps(weights[3]); + w4 = _mm256_set1_ps(weights[4]); + w5 = _mm256_set1_ps(weights[5]); + w6 = _mm256_set1_ps(weights[6]); + w7 = _mm256_set1_ps(weights[7]); + w8 = _mm256_set1_ps(weights[8]); + } + } + + int dy0 = 1; + for (int y0 = 0; y0 < H0; y0 += dy0, outptr += W0 * dy0) + { + dy0 = inner_ytop <= y0 && y0 + 3 < inner_ybottom && is3x3 && stride_y == 1 && dilation_y == 1 + ? 3 : 1; + + int x0 = 0, x1 = y0 >= inner_ytop && y0 < inner_ybottom ? inner_xleft : W0; + int yi_ = y0 * stride_y - pad_top; + + for (;;) + { + float s_0, s_1, s_2; + if (dy0 == 3) + { + for (; x0 < x1; x0++) + { + int xi_ = x0 * stride_x - pad_left; + s_0 = s_1 = s_2 = biasval; + for (int k = 0; k < ksize; k++) + { + int dy = yxtab[k * 2]; + int yi = yi_ + dy; + int xi = xi_ + yxtab[k * 2 + 1]; + float w = weights[k]; + + if ((unsigned) xi < (unsigned) Wi) + { + s_0 += inptr[yi * Wi + xi] * w; + s_1 += inptr[(yi + 1) * Wi + xi] * w; + s_2 += inptr[(yi + 2) * Wi + xi] * w; + } + } + if (ifMinMaxAct) + { + s_0 = std::min(std::max(s_0, minval), maxval); + s_1 = std::min(std::max(s_1, minval), maxval); + s_2 = std::min(std::max(s_2, minval), maxval); + } + + outptr[x0] = s_0; + outptr[x0 + W0] = s_1; + outptr[x0 + W0 * 2] = s_2; + } + } + else + { + for (; x0 < x1; x0++) + { + int xi_ = x0 * stride_x - pad_left; + s_0 = biasval; + for (int k = 0; k < ksize; k++) { + int dy = yxtab[k * 2]; + int yi = yi_ + dy; + int xi = xi_ + yxtab[k * 2 + 1]; + float w = weights[k]; + if (((unsigned) yi < (unsigned) Hi) & ((unsigned) xi < (unsigned) Wi)) + s_0 += inptr[yi * Wi + xi] * w; + } + if (ifMinMaxAct) + s_0 = std::min(std::max(s_0, minval), maxval); + outptr[x0] = s_0; + } + } + if (x0 == W0) + break; + x1 = inner_xright; + + if (useSIMD) + { + if (is3x3) + { + if (dy0 == 3) + { + for (; x0 <= x1 - VECSZ; x0 += VECSZ) + { + int xi_ = x0 * stride_x - pad_left; + const float *inptr_xi = inptr + Wi * yi_ + xi_; + + __m256 s0, s1, s2; + __m256 x00 = _mm256_loadu_ps(inptr_xi); + __m256 x01 = _mm256_loadu_ps(inptr_xi + 1); + __m256 x02 = _mm256_loadu_ps(inptr_xi + 2); + + __m256 x10 = _mm256_loadu_ps(inptr_xi + Wi); + __m256 x11 = _mm256_loadu_ps(inptr_xi + Wi + 1); + __m256 x12 = _mm256_loadu_ps(inptr_xi + Wi + 2); + + __m256 x20 = _mm256_loadu_ps(inptr_xi + Wi * 2); + __m256 x21 = _mm256_loadu_ps(inptr_xi + Wi * 2 + 1); + __m256 x22 = _mm256_loadu_ps(inptr_xi + Wi * 2 + 2); + + __m256 x30 = _mm256_loadu_ps(inptr_xi + Wi * 3); + __m256 x31 = _mm256_loadu_ps(inptr_xi + Wi * 3 + 1); + __m256 x32 = _mm256_loadu_ps(inptr_xi + Wi * 3 + 2); + + __m256 x40 = _mm256_loadu_ps(inptr_xi + Wi * 4); + __m256 x41 = _mm256_loadu_ps(inptr_xi + Wi * 4 + 1); + __m256 x42 = _mm256_loadu_ps(inptr_xi + Wi * 4 + 2); + + s0 = _mm256_fmadd_ps(x00, w0, vbias); + s1 = _mm256_fmadd_ps(x10, w0, vbias); + s2 = _mm256_fmadd_ps(x20, w0, vbias); + + s0 = _mm256_fmadd_ps(x01, w1, s0); + s1 = _mm256_fmadd_ps(x11, w1, s1); + s2 = _mm256_fmadd_ps(x21, w1, s2); + + s0 = _mm256_fmadd_ps(x02, w2, s0); + s1 = _mm256_fmadd_ps(x12, w2, s1); + s2 = _mm256_fmadd_ps(x22, w2, s2); + + s0 = _mm256_fmadd_ps(x10, w3, s0); + s1 = _mm256_fmadd_ps(x20, w3, s1); + s2 = _mm256_fmadd_ps(x30, w3, s2); + + s0 = _mm256_fmadd_ps(x11, w4, s0); + s1 = _mm256_fmadd_ps(x21, w4, s1); + s2 = _mm256_fmadd_ps(x31, w4, s2); + + s0 = _mm256_fmadd_ps(x12, w5, s0); + s1 = _mm256_fmadd_ps(x22, w5, s1); + s2 = _mm256_fmadd_ps(x32, w5, s2); + + s0 = _mm256_fmadd_ps(x20, w6, s0); + s1 = _mm256_fmadd_ps(x30, w6, s1); + s2 = _mm256_fmadd_ps(x40, w6, s2); + + s0 = _mm256_fmadd_ps(x21, w7, s0); + s1 = _mm256_fmadd_ps(x31, w7, s1); + s2 = _mm256_fmadd_ps(x41, w7, s2); + + s0 = _mm256_fmadd_ps(x22, w8, s0); + s1 = _mm256_fmadd_ps(x32, w8, s1); + s2 = _mm256_fmadd_ps(x42, w8, s2); + + if (ifMinMaxAct) + { + s0 = _mm256_min_ps(_mm256_max_ps(s0, vminval), vmaxval); + s1 = _mm256_min_ps(_mm256_max_ps(s1, vminval), vmaxval); + s2 = _mm256_min_ps(_mm256_max_ps(s2, vminval), vmaxval); + } + + _mm256_storeu_ps(outptr + x0, s0); + _mm256_storeu_ps(outptr + W0 + x0, s1); + _mm256_storeu_ps(outptr + W0 * 2 + x0, s2); + } + } + else + { + for (; x0 <= x1 - VECSZ; x0 += VECSZ) + { + int xi_ = x0 * stride_x - pad_left; + const float *inptr_xi = inptr + Wi * yi_ + xi_; + __m256 s0 = _mm256_fmadd_ps(_mm256_loadu_ps(inptr_xi + ofstab[0]), w0, vbias); + __m256 s1 = _mm256_mul_ps(_mm256_loadu_ps(inptr_xi + ofstab[1]), w1); + __m256 s2 = _mm256_mul_ps(_mm256_loadu_ps(inptr_xi + ofstab[2]), w2); + + s0 = _mm256_fmadd_ps(_mm256_loadu_ps(inptr_xi + ofstab[3]), w3, s0); + s1 = _mm256_fmadd_ps(_mm256_loadu_ps(inptr_xi + ofstab[4]), w4, s1); + s2 = _mm256_fmadd_ps(_mm256_loadu_ps(inptr_xi + ofstab[5]), w5, s2); + + s0 = _mm256_fmadd_ps(_mm256_loadu_ps(inptr_xi + ofstab[6]), w6, s0); + s1 = _mm256_fmadd_ps(_mm256_loadu_ps(inptr_xi + ofstab[7]), w7, s1); + s2 = _mm256_fmadd_ps(_mm256_loadu_ps(inptr_xi + ofstab[8]), w8, s2); + + s0 = _mm256_add_ps(_mm256_add_ps(s0, s1), s2); + + if (ifMinMaxAct) + s0 = _mm256_min_ps(_mm256_max_ps(s0, vminval), vmaxval); + _mm256_storeu_ps(outptr + x0, s0); + } + } + } + else + { + for (; x0 <= x1 - VECSZ; x0 += VECSZ) + { + int xi_ = x0 * stride_x - pad_left, k = 0; + const float *inptr_xi = inptr + Wi * yi_ + xi_; + __m256 s0 = vbias; + for (; k <= ksize - 4; k += 4) + { + __m256 v0 = _mm256_loadu_ps(inptr_xi + ofstab[k]); + __m256 v1 = _mm256_loadu_ps(inptr_xi + ofstab[k + 1]); + __m256 v2 = _mm256_loadu_ps(inptr_xi + ofstab[k + 2]); + __m256 v3 = _mm256_loadu_ps(inptr_xi + ofstab[k + 3]); + + __m256 ww0 = _mm256_set1_ps(weights[k]); + __m256 ww1 = _mm256_set1_ps(weights[k+1]); + __m256 ww2 = _mm256_set1_ps(weights[k+2]); + __m256 ww3 = _mm256_set1_ps(weights[k+3]); + + s0 = _mm256_fmadd_ps(v0, ww0, s0); + s0 = _mm256_fmadd_ps(v1, ww1, s0); + s0 = _mm256_fmadd_ps(v2, ww2, s0); + s0 = _mm256_fmadd_ps(v3, ww3, s0); + } + for (; k < ksize; k++) + s0 = _mm256_fmadd_ps(_mm256_loadu_ps(inptr_xi + ofstab[k]), + _mm256_set1_ps(weights[k]), s0); + + if (ifMinMaxAct) + s0 = _mm256_min_ps(_mm256_max_ps(s0, vminval), vmaxval); + _mm256_storeu_ps(outptr + x0, s0); + } + } + } + + if (dy0 == 3) + { + for (; x0 < x1; x0++) + { + int xi_ = x0 * stride_x - pad_left; + const float *inptr_xi = inptr + W0 * yi_ + xi_; + s_0 = s_1 = s_2 = biasval; + for (int k = 0; k < ksize; k++) { + int inp_ofs = ofstab[k]; + float w = weights[k]; + s_0 += inptr_xi[inp_ofs] * w; + s_1 += inptr_xi[inp_ofs + Wi] * w; + s_2 += inptr_xi[inp_ofs + Wi * 2] * w; + } + if (ifMinMaxAct) + { + s_0 = std::min(std::max(s_0, minval), maxval); + s_1 = std::min(std::max(s_1, minval), maxval); + s_2 = std::min(std::max(s_2, minval), maxval); + } + + outptr[x0] = s_0; + outptr[x0 + W0] = s_1; + outptr[x0 + W0 * 2] = s_2; + } + } + else + { + for (; x0 < x1; x0++) + { + int xi_ = x0 * stride_x - pad_left; + const float *inptr_xi = inptr + Wi * yi_ + xi_; + s_0 = biasval; + for (int k = 0; k < ksize; k++) + { + s_0 += inptr_xi[ofstab[k]] * weights[k]; + } + if (ifMinMaxAct) + s_0 = std::min(std::max(s_0, minval), maxval); + outptr[x0] = s_0; + } + } + x1 = W0; + } + } +} +#endif +} // namespace opt_AVX2 +} // namespace cv \ No newline at end of file diff --git a/modules/dnn/src/layers/fast_convolution/fast_convolution.cpp b/modules/dnn/src/layers/fast_convolution/fast_convolution.cpp new file mode 100644 index 0000000..139ea7f --- /dev/null +++ b/modules/dnn/src/layers/fast_convolution/fast_convolution.cpp @@ -0,0 +1,694 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +// This file is modified from the ficus (https://github.com/vpisarev/ficus/blob/master/lib/NN/OpConv.fx). +// Here is the original license: +/* + This file is a part of ficus language project. + See ficus/LICENSE for the licensing terms +*/ + +#include "../../precomp.hpp" +#include "fast_convolution.hpp" +#include "fast_convolution.simd.hpp" + +namespace cv { namespace dnn { + +Ptr initFastConv2d( + int ngroups, + int K, int C, int Hk, int Wk, + int stride_x, int stride_y, + int dilation_x, int dilation_y, + const std::vector& pads_begin, + const std::vector& pads_end, + float* srcWeights, + float* srcBias) +{ + Ptr conv = makePtr(); + + CV_Assert(ngroups > 0 && K > 0 && C > 0 && K % ngroups == 0); + CV_Assert(Hk > 0 && Wk > 0); + CV_Assert(stride_y > 0 && stride_x > 0); + CV_Assert(dilation_y > 0 && dilation_x > 0); + + conv->K = K; conv->C = C; conv->Hk = Hk; conv->Wk = Wk; // [K, iC, kH, kW] + conv->stride_y = stride_y; + conv->stride_x = stride_x; + conv->dilation_y = dilation_y; + conv->dilation_x = dilation_x; + + conv->ngroups = ngroups; + conv->pad_top = pads_begin[0]; + conv->pad_bottom = pads_end[0]; + conv->pad_left = pads_begin[1]; + conv->pad_right = pads_end[1]; + + // store bias; append some zero's to make sure that + // we can always read FAST_CONV_MR elements starting from any valid index + { + int k = 0, nbias = K + FAST_CONV_MR-1; + conv->biasBuf.reserve(nbias); + float* biasBufPtr = conv->biasBuf.data(); + for(; k < K; k++) + biasBufPtr[k] = srcBias ? srcBias[k] : 0.f; + for(; k < nbias; k++) + biasBufPtr[k] = 0.f; + } + +#if CV_NEON // For now, winograd is ARM platform only. + if (ngroups == 1 && Hk ==3 && Wk == 3 && stride_x == 1 && stride_y == 1 && dilation_x == 1 && dilation_y ==1 + && K >= 16 && C >= 16 ) + conv->ifWinograd63 = true; +#else + conv->ifWinograd63 = false; +#endif + + if (ngroups > 1 && ngroups == K && ngroups == C) + { + // for depth-wise convolutions on NCHW data we just preserve the weights in KCHW layout, + // but add some padding to make the weights array layout more SIMD-friendly + int ksize = Hk*Wk; + int padded_ksize = ((ksize + FAST_VEC_NLANES-1)/FAST_VEC_NLANES)*FAST_VEC_NLANES; // this code aims to let memory fit with vector size. + int nweights = C*padded_ksize; + conv->weightsBuf.reserve(nweights); + float* weightsBufPtr = conv->weightsBuf.data(); + memset(weightsBufPtr, 0, nweights*sizeof(weightsBufPtr[0])); + for(int c = 0; c < C; c++) + { + for (int k = 0; k < ksize; k++) + weightsBufPtr[c*padded_ksize + k] = srcWeights[c*ksize + k]; + } + } + else + { + // The weights are packed as + // ngroups x (ceil((K/ngroups)/FAST_CONV_MR)*FAST_CONV_MR) x (Cg*Hk*Wk) x FAST_CONV_MR tensor + int Kg = K/ngroups, Cg = max(C/ngroups, 1); + int Kg_aligned = ((Kg + FAST_CONV_MR - 1)/FAST_CONV_MR)*FAST_CONV_MR; + size_t nweights = ngroups*Kg_aligned*Cg*Hk*Wk; + conv->weightsBuf.reserve(nweights); + float* weightsBufPtr = conv->weightsBuf.data(); + memset(weightsBufPtr, 0, nweights*sizeof(weightsBufPtr[0])); + float* packed_wptr = weightsBufPtr; + + // pack the weight. + for(int g = 0; g < ngroups; g++) + { + for(int k0 = 0; k0 < Kg_aligned; k0 += FAST_CONV_MR) + { + int dk = Kg - k0 < FAST_CONV_MR ? Kg - k0 : FAST_CONV_MR; + for(int c = 0; c < Cg; c++) + { + for(int yx = 0; yx < Hk*Wk; yx++, packed_wptr += FAST_CONV_MR) + { + const float* wptr = srcWeights + ((g*Kg + k0)*Cg + c)*Hk*Wk + yx; + int k = 0; + for(; k < dk; k++, wptr += Cg*Hk*Wk) + packed_wptr[k] = *wptr; + for(; k < FAST_CONV_MR; k++) + packed_wptr[k] = 0.f; + } + } + } + } + + // Prepare Weight for Winograd F(6x6, 3x3) + if (conv->ifWinograd63) + { + initWinograd63(conv, srcWeights, K, C); + } + } + return conv; +} + +static void packInput(float* inpbuf, const float* inptr, int* yxtab, int ksize, int Cg, int Hi, int Wi, int W0, + int pad_top, int pad_left, int stride_x, int stride_y, int yx0, int slice_len, + bool fast_1x1, bool partial0, bool s1d1p0, bool s1d1) +{ + const size_t inp_planesize = (size_t)Hi*Wi; + + if (fast_1x1) + { + /* + super-fast branch for 1x1 convolutions with sy=sx=1. + in this case each feature plane can be safely treated + as 1D array and we just extract next portion + of FAST_CONV_NR elements from each feature plane and + put it together. + */ + inptr += yx0; + if (!partial0) + { + // Make special branch where memcpy() is called with a constant buffer size. + // Compilers will likely unroll this loop properly. + for (int c = 0; c < Cg; c++, inptr += inp_planesize, inpbuf += FAST_CONV_NR) + memcpy(inpbuf, inptr, FAST_CONV_NR * sizeof(inpbuf[0])); + } + else + { + for (int c = 0; c < Cg; c++, inptr += inp_planesize, inpbuf += FAST_CONV_NR) + { + memcpy(inpbuf, inptr, slice_len * sizeof(inpbuf[0])); + memset(inpbuf + slice_len, 0, (FAST_CONV_NR - slice_len) * sizeof(inpbuf[0])); + } + } + } + else if (s1d1p0) + { + /* + slower, but still fast branch for sy=sx=1, dy=dx=1 and without padding, + in this case we copy data from input tensors by chunks. + */ + for (int c = 0; c < Cg; c++) + { + float *inpbuf_c = inpbuf + c * (FAST_CONV_NR * ksize); + const float *inptr_c = inptr + c * inp_planesize; + + for (int k = 0; k < ksize; k++) + { + int y0 = yx0 / W0, x0 = yx0 % W0; + int yi = y0 + yxtab[k * 2], xi = x0 + yxtab[k * 2 + 1]; + float *inpbuf_k = inpbuf_c + k * FAST_CONV_NR; + int xi_0 = yxtab[k * 2 + 1]; + + int i = 0; + for (; i < slice_len;) + { + const float *inptr_k = inptr_c + yi * Wi + xi; + int copy_len = std::min(slice_len - i, W0 - x0); + int di_z = (slice_len == i + copy_len) ? FAST_CONV_NR - slice_len : 0; + + memcpy(inpbuf_k + i, + inptr_k, + copy_len * sizeof(inpbuf_k[0])); + + memset(inpbuf_k + i + copy_len, + 0, di_z * sizeof(inpbuf_k[0])); + + i += copy_len; + x0 = 0; + xi = xi_0; + yi++; + } + } + } + } + else if (s1d1) + { + /* + slower, but still fast branch for sy=sx=1, dy=dx=1. + in this case we copy data from input tensors by chunks and + interleave the data in inpbuf with 0's + (that correspond to the padding elements) when necessary + */ + int y0 = yx0 / W0, x0 = yx0 % W0; + for (int c = 0; c < Cg; c++) + { + float *inpbuf_c = inpbuf + c * (FAST_CONV_NR * ksize); + const float *inptr_c = inptr + c * inp_planesize; + + for (int k = 0; k < ksize; k++) + { + int x0_tmp = x0; + + int xi_0 = yxtab[k * 2 + 1] - pad_left; + + int yi = y0 + yxtab[k * 2] - pad_top, xi = x0_tmp + xi_0; + float *inpbuf_k = inpbuf_c + k * FAST_CONV_NR; + + int i = 0; + for (; i < slice_len;) { + int copyLen = std::min(slice_len - i, W0 - x0_tmp); + + int di_z = (i + copyLen == slice_len) ? FAST_CONV_NR - slice_len + : 0; // The final padding. + // pad_top or pad bottom + if (yi < 0 || yi > Hi - 1) + { + memset(inpbuf_k + i, + 0, (copyLen + di_z) * sizeof(inpbuf_k[0])); + i += copyLen + di_z; + } + else + { + int x_pad_left = 0, x_pad_right = 0; + + // pad_left + if (xi < 0) + { + x_pad_left = std::min(-xi, copyLen); + xi = 0; + copyLen -= x_pad_left; + } + + memset(inpbuf_k + i, + 0, x_pad_left * sizeof(inpbuf_k[0])); + i += x_pad_left; + + // pad right + if (xi + copyLen > Wi) + { + if (xi > Wi) + { + x_pad_right = copyLen; + copyLen = 0; + } + else + { + x_pad_right = std::min(xi + copyLen - Wi, copyLen); + copyLen -= x_pad_right; + } + } + + CV_Assert(copyLen >= 0); + + const float *inptr_k = inptr_c + yi * Wi + xi; + memcpy(inpbuf_k + i, + inptr_k, + copyLen * sizeof(inpbuf_k[0])); + + i += copyLen; + + // pad_right and the final padding. + memset(inpbuf_k + i, + 0, (di_z + x_pad_right) * sizeof(inpbuf_k[0])); + i += x_pad_right + di_z; + } + + x0_tmp = 0; + xi = xi_0; + yi++; + } + } + } + } + else + { + int y0_ = yx0 / W0, x0_ = yx0 - y0_ * W0; + for (int k = 0; k < ksize; k++) + { + int dy = yxtab[k * 2], dx = yxtab[k * 2 + 1]; + int i = 0, y0 = y0_, x0 = x0_; + for (; i < FAST_CONV_NR;) + { + float *inpbuf_ki = inpbuf + k * FAST_CONV_NR + i; + int yi = y0 * stride_y + dy - pad_top; + int xi = x0 * stride_x + dx - pad_left; + + if ((unsigned) yi < (unsigned) Hi && + (unsigned) xi < (unsigned) Wi) + { + const float *inptr_ki = inptr + yi * Wi + xi; + if (i + 4 <= FAST_CONV_NR && x0 + 4 <= W0 && xi + stride_x * 4 <= Wi) + { + if (stride_x == 2) { + for (int c = 0; c < Cg; c++, inpbuf_ki += FAST_CONV_NR * + ksize, inptr_ki += inp_planesize) + { + float t0 = inptr_ki[0], t1 = inptr_ki[2]; + float t2 = inptr_ki[4], t3 = inptr_ki[6]; + inpbuf_ki[0] = t0; + inpbuf_ki[1] = t1; + inpbuf_ki[2] = t2; + inpbuf_ki[3] = t3; + } + } + else + { + for (int c = 0; c < Cg; c++, inpbuf_ki += FAST_CONV_NR * + ksize, inptr_ki += inp_planesize) + { + float t0 = inptr_ki[0], t1 = inptr_ki[stride_x]; + float t2 = inptr_ki[stride_x * 2], t3 = inptr_ki[stride_x * 3]; + inpbuf_ki[0] = t0; + inpbuf_ki[1] = t1; + inpbuf_ki[2] = t2; + inpbuf_ki[3] = t3; + } + } + i += 4; + x0 += 4; + } + else + { + for (int c = 0; c < Cg; c++, inpbuf_ki += FAST_CONV_NR * + ksize, inptr_ki += inp_planesize) + *inpbuf_ki = *inptr_ki; + i++; + x0++; + } + } + else + { + for (int c = 0; c < Cg; c++, inpbuf_ki += FAST_CONV_NR * ksize) + inpbuf_ki[0] = 0.f; + i++; + x0++; + } + int mask = x0 >= W0; + y0 += mask; + x0 &= mask - 1; + } + } + } +} + +static void matMulCompute(float* outptr0, float* inpbuf_task, float* cbuf, const Ptr& conv, int HkWkCg, + int k0, int k1, int yx0, int yx1, size_t out_planesize, int g, int Kg, int Kg_aligned, + bool partial0, ActivationLayer*& activ, float minval, float maxval, bool ifMinMaxAct) +{ + int outstep0 = out_planesize; + + for (int k = k0; k < k1; k += FAST_CONV_MR, outptr0 += outstep0 * FAST_CONV_MR) + { + int dk = Kg - k < FAST_CONV_MR ? Kg - k : FAST_CONV_MR; + bool partial = partial0 || dk < FAST_CONV_MR; + float *outptr = outptr0; + + int outstep = outstep0; + if (partial) + { + outptr = cbuf; + outstep = FAST_CONV_NR; + } + + +#if CV_TRY_AVX2 + if (conv->useAVX2) + opt_AVX2::convBlock_AVX2( HkWkCg, conv->weightsBuf.data() + (g * Kg_aligned + k) * HkWkCg, + inpbuf_task, outptr, outstep, conv->biasBuf.data() + Kg * g + k, + minval, maxval, ifMinMaxAct); + else +#endif +#if CV_TRY_NEON + if (conv->useNEON) + opt_NEON::convBlock_NEON(HkWkCg, conv->weightsBuf.data() + (g * Kg_aligned + k) * HkWkCg, + inpbuf_task, outptr, outstep, conv->biasBuf.data() + Kg * g + k, + minval, maxval, ifMinMaxAct); + else +#endif + convBlock(HkWkCg, conv->weightsBuf.data() + (g * Kg_aligned + k) * HkWkCg, + inpbuf_task, outptr, outstep, conv->biasBuf.data() + Kg * g + k, + minval, maxval, ifMinMaxAct); + + // activation + if (activ) + activ->forwardSlice(outptr, outptr, yx1 - yx0, outstep, Kg * g + k, + Kg * g + k + dk); + + if (partial) + { + for (int i = 0; i < dk; i++) + memcpy(outptr0 + i * outstep0, cbuf + i * FAST_CONV_NR, + (yx1 - yx0) * sizeof(cbuf[0])); + } + } +} + +void runFastConv2d(InputArray _input, OutputArray _output, + const Ptr& conv, int ntasks, const Ptr& actLayer) +{ + Mat input = _input.getMat(); + Mat output = _output.getMat(); + MatShape inputShape = shape(input); + MatShape outputShape = shape(output); + CV_Assert(inputShape.size() == 4 && outputShape.size() == 4); + + ActivationLayer* activ = 0; + float minval = -FLT_MAX, maxval = FLT_MAX; + bool ifMinMaxAct = false; + if (actLayer) + { + Ptr activ_relu = actLayer.dynamicCast(); + Ptr activ_relu6 = actLayer.dynamicCast(); + + if (!activ_relu.empty()) + { + if (activ_relu->negativeSlope == 0.0f) + { + minval = 0.0f; + ifMinMaxAct = true; + activ = nullptr; + } + else // Leaky ReLU + { + activ = actLayer.get(); + } + } + else if (!activ_relu6.empty()) + { + minval = activ_relu6->minValue; + maxval = activ_relu6->maxValue; + + ifMinMaxAct = true; + activ = nullptr; + } + else + activ = actLayer.get(); + } + else + activ = nullptr; + + if (conv->ngroups > 1 && conv->ngroups == conv->K && conv->ngroups == conv->C) + { + return runDepthwise(input, output, conv, minval, maxval, activ, ifMinMaxAct); + } + +#if CV_NEON + if ( conv->ifWinograd63 + && inputShape[2] > 12 && inputShape[3] > 12 + && inputShape[2] < 120 && inputShape[3] < 120 ) + { + // In general, for winograd branch, more cores will give better performance. + int maxNumThread = std::max(getNumThreads(), 1); + if (runWinograd63(input, output, conv, maxNumThread, minval, maxval, activ, ifMinMaxAct)) + return; + } +#endif + + float* inp = input.ptr(); + float* out = output.ptr(); + + int N = inputShape[0], C = inputShape[1], Hi = inputShape[2], Wi = inputShape[3]; // [N, C, H, W] + int K = conv->K, Hk = conv->Hk, Wk = conv->Wk; + int H0 = outputShape[2], W0 = outputShape[3], ngroups = conv->ngroups; // ngroups + int Cg = C/ngroups, Kg = K/ngroups; + int Kg_nblocks = (Kg + FAST_CONV_MR-1)/FAST_CONV_MR, Kg_aligned = Kg_nblocks*FAST_CONV_MR; // align to MR + + const size_t inp_planesize = (size_t)Hi*Wi; + const size_t out_planesize = (size_t)H0*W0; + + int pad_top = conv->pad_top, pad_bottom = conv->pad_bottom; + int pad_left = conv->pad_left; + int pad_right = conv->pad_right; + + int stride_y = conv->stride_y, stride_x = conv->stride_x; + int dilation_y = conv->dilation_y, dilation_x = conv->dilation_x; + + int ksize = Hk * Wk; + bool s1d1 = stride_x == 1 && stride_y == 1 && dilation_x == 1 && dilation_y == 1; + bool s1d1p0 = s1d1 && pad_top == 0 && pad_left ==0 && pad_bottom == 0 && pad_right == 0; + bool fast_1x1 = stride_x == 1 && stride_y == 1 && ksize == 1; + int HkWkCg = Hk*Wk*Cg; + + enum { VEC_ALIGN = 8, DFT_TYPE = CV_32F }; + size_t taskbufsize = FAST_CONV_NR*HkWkCg; // input buffer + size_t taskbufsizeOutput = FAST_CONV_NR * FAST_CONV_MR; + size_t inputbufsize = 0; + size_t outbufsize = ntasks * taskbufsizeOutput; + + int stripes_per_sample = (out_planesize + FAST_CONV_NR - 1)/FAST_CONV_NR; // align to NR + size_t hw_task = stripes_per_sample; + size_t hw_aligned = stripes_per_sample * FAST_CONV_NR; + + bool separatedLoop = false; + + if (stripes_per_sample < 4 * ntasks) + { + // If stripes_per_sample is small, we parallelize on K (output channel). + stripes_per_sample = 1; + + // Separated Parallelloop could save much time in packing input data. But it may cost more memory, we use it when batch size is 1. + if (N == 1) + { + separatedLoop = true; + inputbufsize = ngroups * hw_aligned * HkWkCg; + } + + if (!separatedLoop) + { + inputbufsize = taskbufsize * ntasks; + } + } + else + { + // If stripes_per_sample is big, we parallelize on H0*W0. + Kg_nblocks = 1; + inputbufsize = taskbufsize * ntasks; + } + + int Kstripes = Kg_nblocks*stripes_per_sample; + int nsubtasks = N*ngroups*Kstripes; + + AutoBuffer inpbuf_all_, outputbuf_; + inputbufsize = alignSize(inputbufsize, VEC_ALIGN); + inpbuf_all_.allocate(inputbufsize + VEC_ALIGN); + float* inpbuf_all = alignPtr(inpbuf_all_.data(), (int)(VEC_ALIGN*sizeof(float))); + + outbufsize = alignSize(outbufsize, VEC_ALIGN); + outputbuf_.allocate(outbufsize + VEC_ALIGN); + float* output_buf = alignPtr(outputbuf_.data(), (int)(VEC_ALIGN*sizeof(float))); + + std::vector ofstab_(Hk*Wk*3, 0); + int* ofstab = ofstab_.data(); + int* yxtab = ofstab + Hk*Wk; + + for (int y = 0; y < Hk; y++) + for( int x = 0; x < Wk; x++) + { + int k = y*Wk + x; + int dy = y*dilation_y, dx = x*dilation_x; + yxtab[k*2] = dy; + yxtab[k*2+1] = dx; + ofstab[k] = dy*Wi + dx; + } + + if (ksize == 1) + { + CV_Assert(pad_left == 0 && pad_right == 0 && pad_top == 0 && pad_bottom == 0); + CV_Assert(stride_x != 1 || stride_y != 1 || (H0 == Hi && W0 == Wi)); + } + + if (separatedLoop) + { + // For now this branch only handles batch size = 1. Maybe we could support batch size < 10 in the future. + // Pack Input data + parallel_for_(Range(0, ngroups * hw_task), [&](const Range& r0) + { + for (int nhwi = r0.start; nhwi < r0.end; nhwi++) + { + int g = nhwi/hw_task; + int hw_i = nhwi % hw_task; + int hw0 = hw_i * FAST_CONV_NR; + float* inpbuf = inpbuf_all + g * hw_aligned * HkWkCg + hw0 * HkWkCg; + const float* inptr = inp + g * Cg * inp_planesize; + bool partial0 = hw0 + FAST_CONV_NR > out_planesize? true: false; + int slice_len = FAST_CONV_NR; + + if (partial0) + slice_len = out_planesize - hw0; + + packInput(inpbuf, inptr, yxtab, ksize, Cg, Hi, Wi, W0, pad_top, pad_left, stride_x, stride_y, + hw0, slice_len, fast_1x1, partial0, s1d1p0, s1d1); + } + }); + + // Compute + parallel_for_(Range(0, ntasks), [&](const Range& r0) + { + for (int task_id = r0.start; task_id < r0.end; task_id++) + { + float *cbuf = output_buf + task_id * taskbufsizeOutput; + int ngs0 = (int) ((size_t) nsubtasks * task_id / ntasks); + int ngs1 = (int) ((size_t) nsubtasks * (task_id + 1) / ntasks); + for (int subtask = ngs0; subtask < ngs1;) + { + int ng = subtask / Kstripes; + int kyx0 = subtask - ng * Kstripes; + int kyx1 = kyx0 + (ngs1 - subtask); + int n = ng / ngroups, g = ng - n * ngroups; + + CV_Assert(n <= 1); + + kyx1 = kyx1 <= Kstripes ? kyx1 : Kstripes; // Guarantee that maximum kyx1 is Kstripes. + subtask += kyx1 - kyx0; + + int k0 = kyx0 * FAST_CONV_MR; + int k1 = kyx1 * FAST_CONV_MR; + k1 = k1 <= Kg ? k1 : Kg; + + + for (int yx0 = 0; yx0 < out_planesize; yx0 += FAST_CONV_NR) + { + float* inpbuf_task = inpbuf_all + g * hw_aligned * HkWkCg + yx0 * HkWkCg; + int yx1 = yx0 + FAST_CONV_NR; + yx1 = yx1 <= out_planesize ? yx1 : out_planesize; + int slice_len = yx1 - yx0; + bool partial0 = slice_len < FAST_CONV_NR; + + int outstep0 = out_planesize; + size_t outofs = ((n * ngroups + g) * Kg + k0) * outstep0 + yx0; + float *outptr0 = out + outofs; + + matMulCompute(outptr0, inpbuf_task, cbuf, conv, HkWkCg, k0, k1, yx0, yx1, out_planesize, g, + Kg, Kg_aligned, partial0, activ, minval, maxval, ifMinMaxAct); + } + } + } + }); + } + else + { + parallel_for_(Range(0, ntasks), [&](const Range &r0) { + for (int task_id = r0.start; task_id < r0.end; task_id++) { + float *inpbuf_task = &inpbuf_all[taskbufsize * task_id]; + float *cbuf = output_buf + task_id * taskbufsizeOutput; + int ngs0 = (int) ((size_t) nsubtasks * task_id / ntasks); + int ngs1 = (int) ((size_t) nsubtasks * (task_id + 1) / ntasks); + + for (int subtask = ngs0; subtask < ngs1;) + { + int ng = subtask / Kstripes; + int kyx0 = subtask - ng * Kstripes; + int kyx1 = kyx0 + (ngs1 - subtask); + int n = ng / ngroups, g = ng - n * ngroups; + size_t inp_plane_ofs = (size_t) (n * ngroups + g) * Cg * inp_planesize; + kyx1 = kyx1 <= Kstripes ? kyx1 : Kstripes; // Guarantee that maximum kyx1 is Kstripes. + subtask += kyx1 - kyx0; + int k0, k1; + int yx0, yx_limit; + + if (stripes_per_sample == 1) + { + k0 = kyx0 * FAST_CONV_MR; + k1 = kyx1 * FAST_CONV_MR; + k1 = k1 <= Kg ? k1 : Kg; + yx0 = 0; + yx_limit = out_planesize; + } + else + { + k0 = 0; + k1 = Kg; + yx0 = kyx0 * FAST_CONV_NR; + yx_limit = kyx1 * FAST_CONV_NR; + yx_limit = yx_limit < out_planesize ? yx_limit : out_planesize; + } + + for (; yx0 < yx_limit; yx0 += FAST_CONV_NR) + { + float *inpbuf = inpbuf_task; + const float *inptr = inp + inp_plane_ofs; + int yx1 = yx0 + FAST_CONV_NR; + yx1 = yx1 <= yx_limit ? yx1 : yx_limit; + int slice_len = yx1 - yx0; + bool partial0 = slice_len < FAST_CONV_NR; + packInput(inpbuf, inptr, yxtab, ksize, Cg, Hi, Wi, W0, pad_top, pad_left, stride_x, stride_y, + yx0, slice_len, fast_1x1, partial0, s1d1p0, s1d1); + + // 2. do convolution, compute Kg x (yx1 - yx0) part of the output tensor + int outstep0 = out_planesize; + size_t outofs = ((n * ngroups + g) * Kg + k0) * outstep0 + yx0; + float *outptr0 = out + outofs; + + matMulCompute(outptr0, inpbuf_task, cbuf, conv, HkWkCg, k0, k1, yx0, yx1, out_planesize, g, + Kg, Kg_aligned, partial0, activ, minval, maxval, ifMinMaxAct); + } + } + } + }); + } +} + +}} // namespace cv::dnn \ No newline at end of file diff --git a/modules/dnn/src/layers/fast_convolution/fast_convolution.hpp b/modules/dnn/src/layers/fast_convolution/fast_convolution.hpp new file mode 100644 index 0000000..30c5ea2 --- /dev/null +++ b/modules/dnn/src/layers/fast_convolution/fast_convolution.hpp @@ -0,0 +1,89 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_FAST_CONVOLUTION_HPP +#define OPENCV_FAST_CONVOLUTION_HPP + +#include "opencv2/core/hal/intrin.hpp" + +#ifndef FAST_CONV_PRAM +#define FAST_CONV_PRAM +#if CV_NEON && __aarch64__ // 32 registers. +#define FAST_CONV_MR 4 +#define FAST_CONV_NR 28 +enum { FAST_VEC_NLANES=4 }; +#elif CV_NEON // 16 registers. +#define FAST_CONV_MR 4 +#define FAST_CONV_NR 12 +enum { FAST_VEC_NLANES=4 }; +#else // SIMD 128, AVX or AVX2 +#define FAST_CONV_MR 4 +#define FAST_CONV_NR 24 +enum { FAST_VEC_NLANES=4 }; +#endif +#endif + +namespace cv { +namespace dnn { + +struct FastConv2d +{ + int ngroups; + int K, C, Hk, Wk; + int stride_y, stride_x; + int dilation_y, dilation_x; + int pad_top, pad_bottom, pad_left, pad_right; + + std::vector weightsBuf; // For generic Conv 2D + std::vector weightsWino63Buf; // For Winograd F(6x6, 3x3). + + std::vector biasBuf; + bool ifWinograd63 = false; + bool useAVX2 = checkHardwareSupport(CPU_AVX2); + bool useNEON = checkHardwareSupport(CPU_NEON); +}; + +// return a FastConv2d instance. +Ptr initFastConv2d( + int ngroups, + int K, int C, int Hk, int Wk, + int stride_x, int stride_y, + int dilation_x, int dilation_y, + const std::vector& pads_begin, + const std::vector& pads_end, + float* srcWeights, + float* srcBias); + +// It contains different computing branches, like winograd, 1x1 conv. +void runFastConv2d(InputArray _input, OutputArray _output, + const Ptr& conv, int ntasks, const Ptr& actLayer); + +void runDepthwise(InputArray _input, OutputArray _output, const Ptr& conv, float minval, float maxval, + ActivationLayer* activ, bool ifMinMaxAct); + +// winograd init +void initWinograd63(Ptr& conv, float* src_weight, int K, int C); + +int runWinograd63(InputArray _input, OutputArray _output, const Ptr& conv, int ntasks, + float minval, float maxval, ActivationLayer* activ, bool ifMinMaxAct); + +} // namespace dnn + +namespace opt_AVX2 +{ +#if CV_TRY_AVX2 +void convBlock_AVX2(int k, const float *a, const float *b, + float *c, int ldc, const float *bias, + float minval, float maxval, bool ifActiv); + +void depthWiseBlock_AVX2(const float *inptr, float *outptr, const float *weights, float biasval, int *ofstab, int *yxtab, + float minval, float maxval, int Hi, int Wi, int H0, int W0, int ksize, int pad_top, int pad_left, + int dilation_y, int stride_x, int stride_y, int inner_xleft, int inner_xright, int inner_ytop, + int inner_ybottom, bool ifMinMaxAct, bool useSIMD, bool is3x3); +#endif +} // namespace opt_AVX2 + +} // namespace cv + +#endif //OPENCV_FAST_CONVOLUTION_HPP diff --git a/modules/dnn/src/layers/fast_convolution/fast_convolution.simd.hpp b/modules/dnn/src/layers/fast_convolution/fast_convolution.simd.hpp new file mode 100644 index 0000000..f154290 --- /dev/null +++ b/modules/dnn/src/layers/fast_convolution/fast_convolution.simd.hpp @@ -0,0 +1,342 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_FAST_CONVOLUTION_SIMD_HPP +#define OPENCV_FAST_CONVOLUTION_SIMD_HPP + +#include "opencv2/core/hal/intrin.hpp" +#include + +namespace cv { +namespace dnn { + +void convBlock(int k, const float *a, const float *b, + float *c, int ldc, const float *bias, + float minval, float maxval, bool ifActiv) +{ +#if CV_SIMD128 +#if FAST_CONV_MR == 4 && FAST_CONV_NR == 24 + { + v_float32x4 c0 = v_setall_f32(bias[0]), c1 = c0, c2 = c0, c3 = c0, c4 = c0, c5 = c0; + v_float32x4 c6 = v_setall_f32(bias[1]), c7 = c6, c8 = c6, c9 = c6, c10 = c6, c11 = c6; + v_float32x4 c12 = v_setall_f32(bias[2]), c13 = c12, c14 = c12, c15 = c12, c16 = c12, c17 = c12; + v_float32x4 c18 = v_setall_f32(bias[3]), c19 = c18, c20 = c18, c21 = c18, c22 = c18, c23 = c18; + + for (int p = 0; p < k; p++, a += FAST_CONV_MR, b += FAST_CONV_NR) + { + v_float32x4 a0 = v_setall_f32(a[0]); + v_float32x4 b0 = v_load(b), b1 = v_load(b + 4), b2 = v_load(b + 8); + v_float32x4 b3 = v_load(b + 12), b4 = v_load(b + 16), b5 = v_load(b + 20); + + c0 = v_fma(b0, a0, c0); + c1 = v_fma(b1, a0, c1); + c2 = v_fma(b2, a0, c2); + c3 = v_fma(b3, a0, c3); + c4 = v_fma(b4, a0, c4); + c5 = v_fma(b5, a0, c5); + + a0 = v_setall_f32(a[1]); + c6 = v_fma(b0, a0, c6); + c7 = v_fma(b1, a0, c7); + c8 = v_fma(b2, a0, c8); + c9 = v_fma(b3, a0, c9); + c10 = v_fma(b4, a0, c10); + c11 = v_fma(b5, a0, c11); + + a0 = v_setall_f32(a[2]); + c12 = v_fma(b0, a0, c12); + c13 = v_fma(b1, a0, c13); + c14 = v_fma(b2, a0, c14); + c15 = v_fma(b3, a0, c15); + c16 = v_fma(b4, a0, c16); + c17 = v_fma(b5, a0, c17); + + a0 = v_setall_f32(a[3]); + c18 = v_fma(b0, a0, c18); + c19 = v_fma(b1, a0, c19); + c20 = v_fma(b2, a0, c20); + c21 = v_fma(b3, a0, c21); + c22 = v_fma(b4, a0, c22); + c23 = v_fma(b5, a0, c23); + } + + if (ifActiv) { + v_float32x4 vmin = v_setall_f32(minval), vmax = v_setall_f32(maxval); + c0 = v_min(v_max(c0, vmin), vmax); + c1 = v_min(v_max(c1, vmin), vmax); + c2 = v_min(v_max(c2, vmin), vmax); + c3 = v_min(v_max(c3, vmin), vmax); + c4 = v_min(v_max(c4, vmin), vmax); + c5 = v_min(v_max(c5, vmin), vmax); + c6 = v_min(v_max(c6, vmin), vmax); + c7 = v_min(v_max(c7, vmin), vmax); + c8 = v_min(v_max(c8, vmin), vmax); + c9 = v_min(v_max(c9, vmin), vmax); + c10 = v_min(v_max(c10, vmin), vmax); + c11 = v_min(v_max(c11, vmin), vmax); + c12 = v_min(v_max(c12, vmin), vmax); + c13 = v_min(v_max(c13, vmin), vmax); + c14 = v_min(v_max(c14, vmin), vmax); + c15 = v_min(v_max(c15, vmin), vmax); + c16 = v_min(v_max(c16, vmin), vmax); + c17 = v_min(v_max(c17, vmin), vmax); + c18 = v_min(v_max(c18, vmin), vmax); + c19 = v_min(v_max(c19, vmin), vmax); + c20 = v_min(v_max(c20, vmin), vmax); + c21 = v_min(v_max(c21, vmin), vmax); + c22 = v_min(v_max(c22, vmin), vmax); + c23 = v_min(v_max(c23, vmin), vmax); + } + v_store(c, c0); + v_store(c + 4, c1); + v_store(c + 8, c2); + v_store(c + 12, c3); + v_store(c + 16, c4); + v_store(c + 20, c5); + + v_store(c + ldc, c6); + v_store(c + ldc + 4, c7); + v_store(c + ldc + 8, c8); + v_store(c + ldc + 12, c9); + v_store(c + ldc + 16, c10); + v_store(c + ldc + 20, c11); + + v_store(c + ldc * 2, c12); + v_store(c + ldc * 2 + 4, c13); + v_store(c + ldc * 2 + 8, c14); + v_store(c + ldc * 2 + 12, c15); + v_store(c + ldc * 2 + 16, c16); + v_store(c + ldc * 2 + 20, c17); + + v_store(c + ldc * 3, c18); + v_store(c + ldc * 3 + 4, c19); + v_store(c + ldc * 3 + 8, c20); + v_store(c + ldc * 3 + 12, c21); + v_store(c + ldc * 3 + 16, c22); + v_store(c + ldc * 3 + 20, c23); + } +#endif +#else + for (int i = 0; i < FAST_CONV_MR; i++) + { + float beta = bias[i]; + for (int j = 0; j < FAST_CONV_NR; j++) + c[i*ldc + j] = beta; + } + for (int p = 0; p < k; p++) + { + for (int i = 0; i < FAST_CONV_MR; i++) + { + float alpha = a[FAST_CONV_MR*p + i]; + for (int j = 0; j < FAST_CONV_NR; j++) + { + c[i*ldc+j] += b[FAST_CONV_NR*p + j]*alpha; + } + } + } + if (ifActiv) + { + for (int i = 0; i < FAST_CONV_MR; i++) + { + for (int j = 0; j < FAST_CONV_NR; j++) + { + float v = c[i*ldc + j]; + v = std::min(std::max(v, minval), maxval); + c[i*ldc + j] = v; + } + } + } +#endif +} +} // namespace dnn + +namespace opt_NEON +{ +#if CV_TRY_NEON +void convBlock_NEON(int k, const float *a, const float *b, + float *c, int ldc, const float *bias, + float minval, float maxval, bool ifActiv) +{ +#if FAST_CONV_MR == 4 && FAST_CONV_NR == 12 + { + float32x4_t c0 = vdupq_n_f32(bias[0]), c1 = c0, c2 = c0; + float32x4_t c3 = vdupq_n_f32(bias[1]), c4 = c3, c5 = c3; + float32x4_t c6 = vdupq_n_f32(bias[2]), c7 = c6, c8 = c6; + float32x4_t c9 = vdupq_n_f32(bias[3]), c10 = c9, c11 = c9; + + float32x4_t a0 = vdupq_n_f32(0.0f); + float32x4_t b0 = vdupq_n_f32(0.0f), b1 = vdupq_n_f32(0.0f), b2 = vdupq_n_f32(0.0f); + + for (int p = 0; p < k; p++, a += FAST_CONV_MR, b += FAST_CONV_NR) + { + a0 = vld1q_f32(a); + b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4), b2 = vld1q_f32(b + 8); + + c0 = vfmaq_laneq_f32(c0, b0, a0, 0); + c1 = vfmaq_laneq_f32(c1, b1, a0, 0); + c2 = vfmaq_laneq_f32(c2, b2, a0, 0); + c3 = vfmaq_laneq_f32(c3, b0, a0, 1); + c4 = vfmaq_laneq_f32(c4, b1, a0, 1); + c5 = vfmaq_laneq_f32(c5, b2, a0, 1); + + c6 = vfmaq_laneq_f32(c6, b0, a0, 2); + c7 = vfmaq_laneq_f32(c7, b1, a0, 2); + c8 = vfmaq_laneq_f32(c8, b2, a0, 2); + + c9 = vfmaq_laneq_f32(c9, b0, a0, 3); + c10 = vfmaq_laneq_f32(c10, b1, a0, 3); + c11 = vfmaq_laneq_f32(c11, b2, a0, 3); + } + + if (ifActiv) + { + b0 = vdupq_n_f32(minval), b1 = vdupq_n_f32(maxval); + c0 = vminq_f32(vmaxq_f32(c0, b0), b1); + c1 = vminq_f32(vmaxq_f32(c1, b0), b1); + c2 = vminq_f32(vmaxq_f32(c2, b0), b1); + c3 = vminq_f32(vmaxq_f32(c3, b0), b1); + c4 = vminq_f32(vmaxq_f32(c4, b0), b1); + c5 = vminq_f32(vmaxq_f32(c5, b0), b1); + c6 = vminq_f32(vmaxq_f32(c6, b0), b1); + c7 = vminq_f32(vmaxq_f32(c7, b0), b1); + c8 = vminq_f32(vmaxq_f32(c8, b0), b1); + c9 = vminq_f32(vmaxq_f32(c9, b0), b1); + c10 = vminq_f32(vmaxq_f32(c10, b0), b1); + c11 = vminq_f32(vmaxq_f32(c11, b0), b1); + } + vst1q_f32(c, c0); vst1q_f32(c+4, c1); vst1q_f32(c+8, c2); + vst1q_f32(c + ldc, c3); vst1q_f32(c + ldc + 4, c4); vst1q_f32(c + ldc + 8, c5); + vst1q_f32(c + ldc*2, c6); vst1q_f32(c + ldc*2 + 4, c7); vst1q_f32(c + ldc*2 + 8, c8); + vst1q_f32(c + ldc*3, c9); vst1q_f32(c + ldc*3 + 4, c10); vst1q_f32(c + ldc*3 + 8, c11); + } +#elif FAST_CONV_MR == 4 && FAST_CONV_NR == 28 + { + float32x4_t c0 = vdupq_n_f32(bias[0]), c1 = c0, c2 = c0, c3 = c0, c4 = c0, c5 = c0, c24 = c0; + float32x4_t c6 = vdupq_n_f32(bias[1]), c7 = c6, c8 = c6, c9 = c6, c10 = c6, c11 = c6, c25 = c6; + float32x4_t c12 = vdupq_n_f32(bias[2]), c13 = c12, c14 = c12, c15 = c12, c16 = c12, c17 = c12, c26 = c12; + float32x4_t c18 = vdupq_n_f32(bias[3]), c19 = c18, c20 = c18, c21 = c18, c22 = c18, c23 = c18, c27 = c18; + + float32x4_t a0 = vdupq_n_f32(0.0f); + float32x4_t b0 = vdupq_n_f32(0.0f), b1 = vdupq_n_f32(0.0f), b2 = vdupq_n_f32(0.0f); + + for (int p = 0; p < k; p++, a += FAST_CONV_MR) { + a0 = vld1q_f32(a); + b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4), b2 = vld1q_f32(b + 8); + b += 12; + + c0 = vfmaq_laneq_f32(c0, b0, a0, 0); + c1 = vfmaq_laneq_f32(c1, b1, a0, 0); + c2 = vfmaq_laneq_f32(c2, b2, a0, 0); + c6 = vfmaq_laneq_f32(c6, b0, a0, 1); + c7 = vfmaq_laneq_f32(c7, b1, a0, 1); + c8 = vfmaq_laneq_f32(c8, b2, a0, 1); + c12 = vfmaq_laneq_f32(c12, b0, a0, 2); + c13 = vfmaq_laneq_f32(c13, b1, a0, 2); + c14 = vfmaq_laneq_f32(c14, b2, a0, 2); + c18 = vfmaq_laneq_f32(c18, b0, a0, 3); + c19 = vfmaq_laneq_f32(c19, b1, a0, 3); + c20 = vfmaq_laneq_f32(c20, b2, a0, 3); + + b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4), b2 = vld1q_f32(b + 8); + b += 12; + + c3 = vfmaq_laneq_f32(c3, b0, a0, 0); + c4 = vfmaq_laneq_f32(c4, b1, a0, 0); + c5 = vfmaq_laneq_f32(c5, b2, a0, 0); + + c9 = vfmaq_laneq_f32(c9, b0, a0, 1); + c10 = vfmaq_laneq_f32(c10, b1, a0, 1); + c11 = vfmaq_laneq_f32(c11, b2, a0, 1); + + c15 = vfmaq_laneq_f32(c15, b0, a0, 2); + c16 = vfmaq_laneq_f32(c16, b1, a0, 2); + c17 = vfmaq_laneq_f32(c17, b2, a0, 2); + + c21 = vfmaq_laneq_f32(c21, b0, a0, 3); + + b0 = vld1q_f32(b); + b += 4; + + c22 = vfmaq_laneq_f32(c22, b1, a0, 3); + c23 = vfmaq_laneq_f32(c23, b2, a0, 3); + + c24 = vfmaq_laneq_f32(c24, b0, a0, 0); + c25 = vfmaq_laneq_f32(c25, b0, a0, 1); + c26 = vfmaq_laneq_f32(c26, b0, a0, 2); + c27 = vfmaq_laneq_f32(c27, b0, a0, 3); + } + + if (ifActiv) { + b0 = vdupq_n_f32(minval), b1 = vdupq_n_f32(maxval); + c0 = vminq_f32(vmaxq_f32(c0, b0), b1); + c1 = vminq_f32(vmaxq_f32(c1, b0), b1); + c2 = vminq_f32(vmaxq_f32(c2, b0), b1); + c3 = vminq_f32(vmaxq_f32(c3, b0), b1); + c4 = vminq_f32(vmaxq_f32(c4, b0), b1); + c5 = vminq_f32(vmaxq_f32(c5, b0), b1); + c6 = vminq_f32(vmaxq_f32(c6, b0), b1); + c7 = vminq_f32(vmaxq_f32(c7, b0), b1); + c8 = vminq_f32(vmaxq_f32(c8, b0), b1); + c9 = vminq_f32(vmaxq_f32(c9, b0), b1); + c10 = vminq_f32(vmaxq_f32(c10, b0), b1); + c11 = vminq_f32(vmaxq_f32(c11, b0), b1); + c12 = vminq_f32(vmaxq_f32(c12, b0), b1); + c13 = vminq_f32(vmaxq_f32(c13, b0), b1); + c14 = vminq_f32(vmaxq_f32(c14, b0), b1); + c15 = vminq_f32(vmaxq_f32(c15, b0), b1); + c16 = vminq_f32(vmaxq_f32(c16, b0), b1); + c17 = vminq_f32(vmaxq_f32(c17, b0), b1); + c18 = vminq_f32(vmaxq_f32(c18, b0), b1); + c19 = vminq_f32(vmaxq_f32(c19, b0), b1); + c20 = vminq_f32(vmaxq_f32(c20, b0), b1); + c21 = vminq_f32(vmaxq_f32(c21, b0), b1); + c22 = vminq_f32(vmaxq_f32(c22, b0), b1); + c23 = vminq_f32(vmaxq_f32(c23, b0), b1); + c24 = vminq_f32(vmaxq_f32(c24, b0), b1); + c25 = vminq_f32(vmaxq_f32(c25, b0), b1); + c26 = vminq_f32(vmaxq_f32(c26, b0), b1); + c27 = vminq_f32(vmaxq_f32(c27, b0), b1); + } + vst1q_f32(c, c0); + vst1q_f32(c + 4, c1); + vst1q_f32(c + 8, c2); + vst1q_f32(c + 12, c3); + vst1q_f32(c + 16, c4); + vst1q_f32(c + 20, c5); + vst1q_f32(c + 24, c24); + + vst1q_f32(c + ldc, c6); + vst1q_f32(c + ldc + 4, c7); + vst1q_f32(c + ldc + 8, c8); + vst1q_f32(c + ldc + 12, c9); + vst1q_f32(c + ldc + 16, c10); + vst1q_f32(c + ldc + 20, c11); + vst1q_f32(c + ldc + 24, c25); + + vst1q_f32(c + ldc * 2, c12); + vst1q_f32(c + ldc * 2 + 4, c13); + vst1q_f32(c + ldc * 2 + 8, c14); + vst1q_f32(c + ldc * 2 + 12, c15); + vst1q_f32(c + ldc * 2 + 16, c16); + vst1q_f32(c + ldc * 2 + 20, c17); + vst1q_f32(c + ldc * 2 + 24, c26); + + vst1q_f32(c + ldc * 3, c18); + vst1q_f32(c + ldc * 3 + 4, c19); + vst1q_f32(c + ldc * 3 + 8, c20); + vst1q_f32(c + ldc * 3 + 12, c21); + vst1q_f32(c + ldc * 3 + 16, c22); + vst1q_f32(c + ldc * 3 + 20, c23); + vst1q_f32(c + ldc * 3 + 24, c27); + } +#else +#error "unsupported FAST_CONV_MR and/or FAST_CONV_NR in convBlock_NEON." +#endif +} + +#endif +} // namespace opt_NEON + +} // namespace cv +#endif //OPENCV_FAST_CONVOLUTION_SIMD_HPP diff --git a/modules/dnn/src/layers/fast_convolution/winograd_3x3s1_f63.cpp b/modules/dnn/src/layers/fast_convolution/winograd_3x3s1_f63.cpp new file mode 100644 index 0000000..7a0720f --- /dev/null +++ b/modules/dnn/src/layers/fast_convolution/winograd_3x3s1_f63.cpp @@ -0,0 +1,1351 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +/* +Winograd-based convolution F(6x6, 3x3). +The code has been borrowed from ncnn inference engine (https://github.com/Tencent/ncnn) +and adapted for OpenCV by Zihao Mu. + +Below is the original copyright +*/ + +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "../../precomp.hpp" +#include "fast_convolution.hpp" + +namespace cv { namespace dnn { +enum +{ + WINO_STEP=6, + WINO_KSIZE=3, + WINO_SIZE= WINO_STEP + WINO_KSIZE - 1, + WINO_AREA= WINO_SIZE * WINO_SIZE +}; + +#if CV_NEON +static void winograd_trans_input_F63(float* src, float* dst, int Channle_div4, const int tiles, const int big_step, const int line_step, const int* ofstab0) +{ + // const float itm[8][8] = { + // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, + // + // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, + // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, + // + // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, + // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, + // + // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, + // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, + // + // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} + // }; + + // 0 = r00 - r06 + (r04 - r02) * 5.25 + // 7 = r07 - r01 + (r03 - r05) * 5.25 + + // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) + // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) + + // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) + // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) + + // reuse r04 * 1.25 + // reuse r03 * 2.5 + // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) + // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) + + float tmp[8][8][FAST_VEC_NLANES]; + AutoBuffer input_buf0_; + input_buf0_.allocate(64 * tiles * FAST_VEC_NLANES); + + float* input_buf0 = input_buf0_.data(); + memset(input_buf0, 0, 64 * tiles * FAST_VEC_NLANES * sizeof(float )); + + for (int ti = 0; ti < tiles; ti++) + { + float* input0 = src + ti * 64 * 4; + float* input = input0; + for (int m = 0; m < 8; m++) + { + float32x4_t _r00 = vld1q_f32(input); + float32x4_t _r01 = vld1q_f32(input + 4); + float32x4_t _r02 = vld1q_f32(input + 8); + float32x4_t _r03 = vld1q_f32(input + 12); + float32x4_t _r04 = vld1q_f32(input + 16); + float32x4_t _r05 = vld1q_f32(input + 20); + float32x4_t _r06 = vld1q_f32(input + 24); + float32x4_t _r07 = vld1q_f32(input + 28); + + float32x4_t _tmp0m = vmlaq_n_f32(vsubq_f32(_r00, _r06), vsubq_f32(_r04, _r02), 5.25f); + float32x4_t _tmp7m = vmlaq_n_f32(vsubq_f32(_r07, _r01), vsubq_f32(_r03, _r05), 5.25f); + vst1q_f32(tmp[0][m], _tmp0m); + vst1q_f32(tmp[7][m], _tmp7m); + + float32x4_t _tmp12a = vmlsq_n_f32(vaddq_f32(_r02, _r06), _r04, 4.25f); + float32x4_t _tmp12b = vmlsq_n_f32(vaddq_f32(_r01, _r05), _r03, 4.25f); + + float32x4_t _tmp1m = vaddq_f32(_tmp12a, _tmp12b); + float32x4_t _tmp2m = vsubq_f32(_tmp12a, _tmp12b); + vst1q_f32(tmp[1][m], _tmp1m); + vst1q_f32(tmp[2][m], _tmp2m); + + float32x4_t _tmp34a = vmlsq_n_f32(vmlaq_n_f32(_r06, _r02, 0.25f), _r04, 1.25f); + float32x4_t _tmp34b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_r01, 0.5f), _r03, 2.5f), _r05, 2.f); + + float32x4_t _tmp3m = vaddq_f32(_tmp34a, _tmp34b); + float32x4_t _tmp4m = vsubq_f32(_tmp34a, _tmp34b); + vst1q_f32(tmp[3][m], _tmp3m); + vst1q_f32(tmp[4][m], _tmp4m); + + float32x4_t _tmp56a = vmlaq_n_f32(_r06, vmlsq_n_f32(_r02, _r04, 1.25f), 4.f); + float32x4_t _tmp56b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_r01, 2.f), _r03, 2.5f), _r05, 0.5f); + + float32x4_t _tmp5m = vaddq_f32(_tmp56a, _tmp56b); + float32x4_t _tmp6m = vsubq_f32(_tmp56a, _tmp56b); + vst1q_f32(tmp[5][m], _tmp5m); + vst1q_f32(tmp[6][m], _tmp6m); + + input += 8 * FAST_VEC_NLANES; + } + + float* input_buf00 = input_buf0 + ti * 4; + float* input_buf01 = input_buf00 + tiles * 4; + float* input_buf02 = input_buf00 + tiles * 8; + float* input_buf03 = input_buf00 + tiles * 12; + float* input_buf04 = input_buf00 + tiles * 16; + float* input_buf05 = input_buf00 + tiles * 20; + float* input_buf06 = input_buf00 + tiles * 24; + float* input_buf07 = input_buf00 + tiles * 28; + + for (int m = 0; m < 8; m++) + { + float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); + float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); + float32x4_t _tmp02 = vld1q_f32(tmp[m][2]); + float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); + float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); + float32x4_t _tmp05 = vld1q_f32(tmp[m][5]); + float32x4_t _tmp06 = vld1q_f32(tmp[m][6]); + float32x4_t _tmp07 = vld1q_f32(tmp[m][7]); + + float32x4_t _r0tm0 = vmlaq_n_f32(vsubq_f32(_tmp00, _tmp06), vsubq_f32(_tmp04, _tmp02), 5.25f); + float32x4_t _r0tm7 = vmlaq_n_f32(vsubq_f32(_tmp07, _tmp01), vsubq_f32(_tmp03, _tmp05), 5.25f); + + float32x4_t _tmp12a = vmlsq_n_f32(vaddq_f32(_tmp02, _tmp06), _tmp04, 4.25f); + float32x4_t _tmp12b = vmlsq_n_f32(vaddq_f32(_tmp01, _tmp05), _tmp03, 4.25f); + + float32x4_t _r0tm1 = vaddq_f32(_tmp12a, _tmp12b); + float32x4_t _r0tm2 = vsubq_f32(_tmp12a, _tmp12b); + + float32x4_t _tmp34a = vmlsq_n_f32(vmlaq_n_f32(_tmp06, _tmp02, 0.25f), _tmp04, 1.25f); + float32x4_t _tmp34b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_tmp01, 0.5f), _tmp03, 2.5f), _tmp05, 2.f); + + float32x4_t _r0tm3 = vaddq_f32(_tmp34a, _tmp34b); + float32x4_t _r0tm4 = vsubq_f32(_tmp34a, _tmp34b); + + float32x4_t _tmp56a = vmlaq_n_f32(_tmp06, vmlsq_n_f32(_tmp02, _tmp04, 1.25f), 4.f); + float32x4_t _tmp56b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_tmp01, 2.f), _tmp03, 2.5f), _tmp05, 0.5f); + + float32x4_t _r0tm5 = vaddq_f32(_tmp56a, _tmp56b); + float32x4_t _r0tm6 = vsubq_f32(_tmp56a, _tmp56b); + + vst1q_f32(input_buf00, _r0tm0); + vst1q_f32(input_buf01, _r0tm1); + vst1q_f32(input_buf02, _r0tm2); + vst1q_f32(input_buf03, _r0tm3); + vst1q_f32(input_buf04, _r0tm4); + vst1q_f32(input_buf05, _r0tm5); + vst1q_f32(input_buf06, _r0tm6); + vst1q_f32(input_buf07, _r0tm7); + + input_buf00 += tiles * 32; + input_buf01 += tiles * 32; + input_buf02 += tiles * 32; + input_buf03 += tiles * 32; + input_buf04 += tiles * 32; + input_buf05 += tiles * 32; + input_buf06 += tiles * 32; + input_buf07 += tiles * 32; + } + } + + // [line Number, input pack] + // if InpPack == 8; + for (int r = 0; r < 64; r++) + { + int ti = 0; + float* out0 = dst + r * big_step; + float* input0 = input_buf0 + 4 * tiles * r; + + // TODO! support tiles > 12 +//#if (ARMV8) +// for (; ti + 11 < tiles; ti += 12) +// { +// float* out1 = out0 + line_step * ofstab0[ti * 2] + Channle_div4 * ofstab0[ti * 2 + 1] * 4; +//// std::cout<<"ofstab0[ti * 2] = "<(2, 7) << 0, 7, 0.991359, 491.822, 81.1668, 702.573, 178.234, 0, 12, 0.94786, 132.093, 223.903, 338.077, 566.16); - float confThreshold = 0.8, scoreDiff = 0.017, iouDiff = 0.11; + float confThreshold = 0.8, scoreDiff = 0.15, iouDiff = 0.11; testFaster(net, ref, confThreshold, scoreDiff, iouDiff); } @@ -1114,7 +1114,7 @@ TEST_P(Test_Int8_nets, YoloVoc) std::string config_file = "yolo-voc.cfg"; std::string weights_file = "yolo-voc.weights"; - double scoreDiff = 0.1, iouDiff = 0.3; + double scoreDiff = 0.12, iouDiff = 0.3; { SCOPED_TRACE("batch size 1"); testDarknetModel(config_file, weights_file, ref.rowRange(0, 3), scoreDiff, iouDiff); diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index 582d8b0..72a8989 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -1336,7 +1336,7 @@ TEST_P(Test_TensorFlow_nets, EAST_text_detection) } else { - l1_geometry = 1e-4, lInf_geometry = 3e-3; + l1_geometry = 1e-4, lInf_geometry = 4.3e-3; } normAssert(scores, blobFromNPY(refScoresPath), "scores", l1_scores, lInf_scores); normAssert(geometry, blobFromNPY(refGeometryPath), "geometry", l1_geometry, lInf_geometry); -- 2.7.4