From: Dmitry Kurtaev
Date: Tue, 13 Feb 2018 09:07:56 +0000 (+0300)
Subject: Refactored deep learning layers fusion
X-Git-Tag: accepted/tizen/6.0/unified/20201030.111113~57^2
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=514e6df460badfaf9634691ea07c356c84583e4b;p=platform%2Fupstream%2Fopencv.git

Refactored deep learning layers fusion
---

diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp
index 061d184..0704f6b 100644
--- a/modules/dnn/include/opencv2/dnn/all_layers.hpp
+++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp
@@ -472,7 +472,6 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
         bool hasWeights, hasBias;
         float epsilon;
 
-        virtual void getScaleShift(Mat& scale, Mat& shift) const = 0;
         static Ptr<BatchNormLayer> create(const LayerParams &params);
     };
 
diff --git a/modules/dnn/include/opencv2/dnn/dnn.hpp b/modules/dnn/include/opencv2/dnn/dnn.hpp
index 4ad3035..e94b30c 100644
--- a/modules/dnn/include/opencv2/dnn/dnn.hpp
+++ b/modules/dnn/include/opencv2/dnn/dnn.hpp
@@ -281,20 +281,26 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
     virtual bool setActivation(const Ptr<ActivationLayer>& layer);
 
     /**
-     * @brief Tries to attach to the layer the subsequent batch normalization layer, i.e. do the layer fusion in a partial case.
-     * @param[in] layer The subsequent batch normalization layer.
-     *
-     * Returns true if the batch normalization layer has been attached successfully.
+     * @brief Tries to fuse the current layer with the next one.
+     * @param[in] top The next layer to be fused.
+     * @returns True if fusion was performed.
      */
-    virtual bool setBatchNorm(const Ptr<BatchNormLayer>& layer);
+    virtual bool tryFuse(Ptr<Layer>& top);
 
    /**
-     * @brief Tries to attach to the layer the subsequent scaling layer, i.e. do the layer fusion in a partial case.
-     * @param[in] layer The subsequent scaling layer.
+     * @brief Returns parameters of layers with channel-wise multiplication and addition.
+     * @param[out] scale Channel-wise multipliers. Total number of values should
+     *                   be equal to number of channels.
+     * @param[out] shift Channel-wise offsets. Total number of values should
+     *                   be equal to number of channels.
      *
-     * Returns true if the scaling layer has been attached successfully.
+     * Some layers can fuse their transformations with further layers,
+     * for example convolution + batch normalization. In that case the base
+     * layer absorbs the weights of the layer that follows it, and the fused
+     * layer is skipped. By default, @p scale and @p shift are empty, which
+     * means the layer performs no element-wise multiplications or additions.
      */
-    virtual bool setScale(const Ptr<ScaleLayer>& layer);
+    virtual void getScaleShift(Mat& scale, Mat& shift) const;
 
     /**
      * @brief "Deattaches" all the layers, attached to particular layer.
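The header change above replaces the specialized setBatchNorm()/setScale() hooks with a generic pair: a layer that applies a per-channel affine transform reports it through getScaleShift(), and the preceding layer decides in tryFuse() whether it can absorb that transform. The following is a minimal, self-contained sketch of that handshake; the classes are hypothetical stand-ins, not part of the patch (the real cv::dnn::Layer has further virtual methods to implement).

#include <opencv2/core.hpp>

using cv::Mat;

// Stand-in for the relevant part of the cv::dnn::Layer interface (illustrative only).
struct FusableLayer
{
    // Empty Mats mean "this layer has no channel-wise multiplication/addition".
    virtual void getScaleShift(Mat& scale, Mat& shift) const { scale = Mat(); shift = Mat(); }
    // Return true if `top` was absorbed into this layer; the caller then skips `top`.
    virtual bool tryFuse(FusableLayer* top) { (void)top; return false; }
    virtual ~FusableLayer() {}
};

// A BatchNorm/Scale/Shift-like layer: it only advertises its per-channel transform.
struct AffineChannelLayer : FusableLayer
{
    Mat scale, shift;  // 1 x numChannels, CV_32F (either may be empty)
    void getScaleShift(Mat& s, Mat& b) const { s = scale; b = shift; }
};

// A convolution-like layer: it folds the reported transform into its own parameters.
struct ConvLikeLayer : FusableLayer
{
    Mat weights;  // one row per output channel, CV_32F
    Mat bias;     // 1 x outCn, CV_32F
    bool tryFuse(FusableLayer* top)
    {
        Mat w, b;
        top->getScaleShift(w, b);
        if (w.empty() && b.empty())
            return false;  // nothing to fuse, keep both layers
        for (int i = 0; i < weights.rows; ++i)
        {
            float wi = w.empty() ? 1.f : w.at<float>(i);
            float bi = b.empty() ? 0.f : b.at<float>(i);
            float* row = weights.ptr<float>(i);
            for (int j = 0; j < weights.cols; ++j)
                row[j] *= wi;                                   // scale the kernel
            bias.at<float>(i) = bias.at<float>(i) * wi + bi;    // (conv + b1) * w + b2
        }
        return true;
    }
};

In dnn.cpp below, the graph optimizer repeats exactly this handshake in a loop while each fused layer has a single consumer, so a convolution followed by batch normalization and scaling can collapse into the convolution in one pass.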
diff --git a/modules/dnn/src/dnn.cpp b/modules/dnn/src/dnn.cpp
index b66fb42..31ae173 100644
--- a/modules/dnn/src/dnn.cpp
+++ b/modules/dnn/src/dnn.cpp
@@ -1407,46 +1407,30 @@ struct Net::Impl
             if( ld.consumers.size() == 1 && pinsToKeep.count(LayerPin(lid, 0)) == 0 )
             {
                 LayerData* nextData = &layers[ld.consumers[0].lid];
-                Ptr<BatchNormLayer> nextBNormLayer =
-                    nextData->layerInstance.dynamicCast<BatchNormLayer>();
                 LayerPin lpNext(ld.consumers[0].lid, 0);
-                if( !nextBNormLayer.empty() && pinsToKeep.count(lpNext) == 0 )
+                while (nextData)
                 {
-                    LayerData* bnormData = nextData;
-                    nextData = 0;
-                    if( currLayer->setBatchNorm(nextBNormLayer) )
+                    Ptr<Layer> nextLayer = nextData->layerInstance;
+                    if (currLayer->tryFuse(nextLayer))
                     {
-                        printf_(("\tfused with %s\n", nextBNormLayer->name.c_str()));
-                        bnormData->skip = true;
+                        printf_(("\tfused with %s\n", nextLayer->name.c_str()));
+                        nextData->skip = true;
                         ld.outputBlobs = layers[lpNext.lid].outputBlobs;
                         ld.outputBlobsWrappers = layers[lpNext.lid].outputBlobsWrappers;
-                        if( bnormData->consumers.size() == 1 )
+                        if (nextData->consumers.size() == 1)
                         {
-                            nextData = &layers[bnormData->consumers[0].lid];
-                            lpNext = LayerPin(bnormData->consumers[0].lid, 0);
+                            int nextLayerId = nextData->consumers[0].lid;
+                            nextData = &layers[nextLayerId];
+                            lpNext = LayerPin(nextLayerId, 0);
                         }
-                    }
-                }
-
-                Ptr<ScaleLayer> nextScaleLayer;
-                if( nextData )
-                    nextScaleLayer = nextData->layerInstance.dynamicCast<ScaleLayer>();
-                if( !nextScaleLayer.empty() && pinsToKeep.count(lpNext) == 0 )
-                {
-                    LayerData* scaleData = nextData;
-                    nextData = 0;
-                    if( currLayer->setScale(nextScaleLayer) )
-                    {
-                        printf_(("\tfused with %s\n", nextScaleLayer->name.c_str()));
-                        scaleData->skip = true;
-                        ld.outputBlobs = layers[lpNext.lid].outputBlobs;
-                        ld.outputBlobsWrappers = layers[lpNext.lid].outputBlobsWrappers;
-                        if( scaleData->consumers.size() == 1 )
+                        else
                         {
-                            nextData = &layers[scaleData->consumers[0].lid];
-                            lpNext = LayerPin(scaleData->consumers[0].lid, 0);
+                            nextData = 0;
+                            break;
                         }
                     }
+                    else
+                        break;
                 }
 
                 // For now, OpenCL target support fusion with activation of ReLU/ChannelsPReLU/Power/Tanh
@@ -2627,13 +2611,16 @@ Ptr<BackendNode> Layer::tryAttach(const Ptr<BackendNode>& node)
 }
 
 bool Layer::setActivation(const Ptr<ActivationLayer>&) { return false; }
-bool Layer::setBatchNorm(const Ptr<BatchNormLayer>&) { return false; }
-bool Layer::setScale(const Ptr<ScaleLayer>&) { return false; }
+bool Layer::tryFuse(Ptr<Layer>&) { return false; }
+
+void Layer::getScaleShift(Mat& scale, Mat& shift) const
+{
+    scale = Mat();
+    shift = Mat();
+}
+
 void Layer::unsetAttached()
 {
     setActivation(Ptr<ActivationLayer>());
-    setBatchNorm(Ptr<BatchNormLayer>());
-    setScale(Ptr<ScaleLayer>());
 }
 
 template <typename T>
diff --git a/modules/dnn/src/layers/convolution_layer.cpp b/modules/dnn/src/layers/convolution_layer.cpp
index 64c2212..3a0bc1b 100644
--- a/modules/dnn/src/layers/convolution_layer.cpp
+++ b/modules/dnn/src/layers/convolution_layer.cpp
@@ -61,7 +61,23 @@ namespace dnn
 class BaseConvolutionLayerImpl : public ConvolutionLayer
 {
 public:
-    BaseConvolutionLayerImpl() {}
+    BaseConvolutionLayerImpl(const LayerParams &params)
+    {
+        setParamsFrom(params);
+        getConvolutionKernelParams(params, kernel.height, kernel.width, pad.height,
+                                   pad.width, stride.height, stride.width, dilation.height,
+                                   dilation.width, padMode);
+
+        numOutput = params.get<int>("num_output");
+        int ngroups = params.get<int>("group", 1);
+
+        adjustPad.height = params.get<int>("adj_h", 0);
+        adjustPad.width = params.get<int>("adj_w", 0);
+
+        CV_Assert(numOutput % ngroups == 0);
+        CV_Assert(adjustPad.width < stride.width &&
+                  adjustPad.height < stride.height);
+    }
 
     virtual bool supportBackend(int backendId)
     {
@@ -153,12 +169,10 @@ class ConvolutionLayerImpl : public BaseConvolutionLayerImpl
 {
 public:
     enum { VEC_ALIGN = 8, DFT_TYPE = CV_32F };
-    Mat weightsMat;
+    Mat weightsMat, weightsMat_doubles;
     std::vector<float> biasvec;
     std::vector<float> reluslope;
     Ptr<ActivationLayer> activ;
-    Ptr<BatchNormLayer> bnorm;
-    Ptr<ScaleLayer> scaleLayer;
 
 #ifdef HAVE_OPENCL
     Ptr<OCL4DNNConvSpatial<float> > convolutionOp;
@@ -169,7 +183,7 @@ public:
     ocl4dnnFusedActiv_t activType;
     float power;
 #endif
-    ConvolutionLayerImpl()
+    ConvolutionLayerImpl(const LayerParams &params) : BaseConvolutionLayerImpl(params)
     {
 #ifdef HAVE_OPENCL
         fusedBias = false;
@@ -225,6 +239,42 @@ public:
         return false;
     }
 
+    virtual void finalize(const std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
+    {
+        BaseConvolutionLayerImpl::finalize(inputs, outputs);
+
+        CV_Assert(!blobs.empty());
+        const int outCn = blobs[0].size[0];
+        // prepare weightsMat where each row is aligned and has enough zero padding on the right to
+        // use vectorized (i.e. with intrinsics) loops without tail processing
+        Mat wm = blobs[0].reshape(1, outCn).clone();
+        if( wm.step1() % VEC_ALIGN != 0 )
+        {
+            int newcols = (int)alignSize(wm.step1(), VEC_ALIGN);
+            Mat wm_buffer = Mat(outCn, newcols, wm.type());
+            Mat wm_padding = wm_buffer.colRange(wm.cols, newcols);
+            wm_padding.setTo(Scalar::all(0.));
+            Mat wm_aligned = wm_buffer.colRange(0, wm.cols);
+            wm.copyTo(wm_aligned);
+            wm = wm_aligned;
+        }
+        weightsMat = wm;
+        weightsMat.convertTo(weightsMat_doubles, CV_64F);
+
+        Mat biasMat = hasBias() ? blobs[1].reshape(1, outCn) : Mat();
+        biasvec.resize(outCn+2);
+        if( biasMat.empty() )
+        {
+            for(int i = 0; i < outCn; i++ )
+                biasvec[i] = 0.f;
+        }
+        else
+        {
+            for(int i = 0; i < outCn; i++ )
+                biasvec[i] = biasMat.at<float>(i);
+        }
+    }
+
     bool setActivation(const Ptr<ActivationLayer>& layer)
     {
         activ = layer;
@@ -240,10 +290,11 @@ public:
             if (!activ_power.empty())
             {
                 if (activ_power->scale != 1.f || activ_power->shift != 0.f)
-                    newWeightAndBias = true;
-
-                if (activ_power->scale != 1.f)
-                    weightsMat.release();
+                {
+                    const int outCh = blobs[0].size[0];
+                    fuseWeights(Mat(1, outCh, CV_32F, Scalar(activ_power->scale)),
+                                Mat(1, outCh, CV_32F, Scalar(activ_power->shift)));
+                }
 
                 power = activ_power->power;
                 activType = OCL4DNN_CONV_FUSED_ACTIV_POWER;
@@ -258,35 +309,49 @@ public:
         return !activ.empty();
     }
 
-    bool setBatchNorm(const Ptr<BatchNormLayer>& layer )
+    virtual bool tryFuse(Ptr<Layer>& top)
     {
-        // for now the scale layer followed by the batch norm cannot be fused, only vice versa.
-        if( !scaleLayer.empty() )
-            return false;
-        bnorm = layer;
-        // we will need to re-compute the weights with the batch
-        // norm coefficients taken into account
-        weightsMat.release();
-#ifdef HAVE_OPENCL
-        newWeightAndBias = true;
-        fusedBias = false;
-#endif
-        return !bnorm.empty();
+        Mat w, b;
+        top->getScaleShift(w, b);
+        if (!w.empty() || !b.empty())
+        {
+            fuseWeights(w, b);
+            return true;
+        }
+        return false;
     }
 
-    bool setScale(const Ptr<ScaleLayer>& layer)
+    void fuseWeights(const Mat& w, const Mat& b)
     {
-        if (layer.empty() || layer->blobs.empty())
-            return false;
-        scaleLayer = layer;
-        // we will need to re-compute the weights with the scaling
-        // coefficients taken into account
-        weightsMat.release();
+        // Convolution weights have OIHW data layout.
+        // Fusing the parameters of
+        //     (conv(I) + b1) * w + b2
+        // means scaling the convolution weights by w and replacing the bias
+        // with [b1 * w + b2].
+        const int outCn = weightsMat.size[0];
+        CV_Assert(!weightsMat.empty(), biasvec.size() == outCn + 2,
+                  w.empty() || outCn == w.total(), b.empty() || outCn == b.total());
+
+        if (!w.empty())
+        {
+            for (int i = 0; i < outCn; ++i)
+            {
+                double wi = w.at<float>(i);
+                cv::multiply(slice(weightsMat_doubles, i), wi, slice(weightsMat_doubles, i));
+                biasvec[i] *= wi;
+            }
+            weightsMat_doubles.convertTo(weightsMat, weightsMat.type());
+        }
+
+        if (!b.empty())
+        {
+            for (int i = 0; i < outCn; ++i)
+                biasvec[i] += b.at<float>(i);
+        }
+
 #ifdef HAVE_OPENCL
-        newWeightAndBias = true;
-        fusedBias = false;
+        newWeightAndBias = !w.empty() || !b.empty();
+        fusedBias = hasBias() || !b.empty();
 #endif
-        return true;
+        biasvec[outCn] = biasvec[outCn+1] = biasvec[outCn-1];
     }
 
     virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
@@ -776,97 +841,7 @@ public:
             convolutionOp = Ptr<OCL4DNNConvSpatial<float> >(new OCL4DNNConvSpatial<float>(config));
         }
 
-        int k, outCn = umat_blobs[0].size[0];
-        if( weightsMat.empty() )
-        {
-            // prepare weightsMat where each row is aligned and has enough zero padding on the right to
-            // use vectorized (i.e. with intrinsics) loops without tail processing
-            Mat wm = blobs[0].reshape(1, outCn).clone();
-            if( wm.step1() % VEC_ALIGN != 0 )
-            {
-                int newcols = (int)alignSize(wm.step1(), VEC_ALIGN);
-                Mat wm_buffer = Mat(outCn, newcols, wm.type());
-                Mat wm_padding = wm_buffer.colRange(wm.cols, newcols);
-                wm_padding.setTo(Scalar::all(0.));
-                Mat wm_aligned = wm_buffer.colRange(0, wm.cols);
-                wm.copyTo(wm_aligned);
-                wm = wm_aligned;
-            }
-            weightsMat = wm;
-
-            Mat biasMat = hasBias() ? blobs[1].reshape(1, outCn) : Mat();
-            biasvec.resize(outCn+2);
-            if( biasMat.empty() )
-            {
-                for( k = 0; k < outCn; k++ )
-                    biasvec[k] = 0.f;
-            }
-            else
-            {
-                for( k = 0; k < outCn; k++ )
-                    biasvec[k] = biasMat.at<float>(k);
-            }
-
-            if( !bnorm.empty() || !scaleLayer.empty() || IS_POWER_LAYER(activ))
-            {
-                Mat scale, shift, scale2, shift2;
-                const float *scaleptr = 0, *shiftptr = 0;
-                const float *scaleptr2 = 0, *shiftptr2 = 0;
-                float a = 1.f, b = 0.f;
-
-                if( !bnorm.empty() )
-                {
-                    bnorm->getScaleShift(scale, shift);
-                    CV_Assert( scale.isContinuous() && shift.isContinuous() &&
-                               scale.type() == CV_32F && shift.type() == CV_32F &&
-                               scale.total() == (size_t)outCn &&
-                               shift.total() == (size_t)outCn );
-                    scaleptr = scale.ptr<float>();
-                    shiftptr = shift.ptr<float>();
-                }
-                if( !scaleLayer.empty() )
-                {
-                    scale2 = scaleLayer->blobs[0];
-                    CV_Assert( scale2.isContinuous() && scale2.type() == CV_32F &&
-                               scale2.total() == (size_t)outCn );
-                    scaleptr2 = scale2.ptr<float>();
-                    if( scaleLayer->hasBias )
-                    {
-                        shift2 = scaleLayer->blobs[1];
-                        CV_Assert( shift2.isContinuous() && shift2.type() == CV_32F &&
-                                   shift2.total() == (size_t)outCn );
-                        shiftptr2 = shift2.ptr<float>();
-                    }
-                }
-
-                if( IS_POWER_LAYER(activ) )
-                {
-                    Ptr<PowerLayer> activ_power = activ.dynamicCast<PowerLayer>();
-                    CV_Assert(activ_power);
-                    a = activ_power->scale;
-                    b = activ_power->shift;
-                }
-
-                if (shiftptr || shiftptr2 || b != 0.f)
-                    fusedBias = true;
-
-                for( int i = 0; i < outCn; i++ )
-                {
-                    float s1 = scaleptr ? scaleptr[i] : 1.f;
-                    float delta1 = shiftptr ? shiftptr[i] : 0.f;
-                    float s2 = scaleptr2 ? scaleptr2[i] : 1.f;
-                    float delta2 = shiftptr2 ? shiftptr2[i] : 0.f;
-                    float* w_i = weightsMat.ptr<float>(i);
-                    int j, wcols = weightsMat.cols;
-
-                    for( j = 0; j < wcols; j++ )
-                        w_i[j] *= (s1*s2*a);
-
-                    biasvec[i] = biasvec[i]*(s1*s2*a) + (delta1*s2*a + delta2*a + b);
-                }
-            }
-            biasvec[outCn] = biasvec[outCn+1] = biasvec[outCn-1];
-        }
+        int outCn = umat_blobs[0].size[0];
 
         reluslope.clear();
         if( activ )
@@ -973,86 +948,7 @@ public:
         int ngroups = inputs[0]->size[1]/blobs[0].size[1];
         CV_Assert(outputs[0].size[1] % ngroups == 0);
 
-        int k, outCn = blobs[0].size[0];
-
-        if( weightsMat.empty() )
-        {
-            // prepare weightsMat where each row is aligned and has enough zero padding on the right to
-            // use vectorized (i.e. with intrinsics) loops without tail processing
-            Mat wm = blobs[0].reshape(1, outCn).clone();
-            if( wm.step1() % VEC_ALIGN != 0 )
-            {
-                int newcols = (int)alignSize(wm.step1(), VEC_ALIGN);
-                Mat wm_buffer = Mat(outCn, newcols, wm.type());
-                Mat wm_padding = wm_buffer.colRange(wm.cols, newcols);
-                wm_padding.setTo(Scalar::all(0.));
-                Mat wm_aligned = wm_buffer.colRange(0, wm.cols);
-                wm.copyTo(wm_aligned);
-                wm = wm_aligned;
-            }
-            weightsMat = wm;
-
-            Mat biasMat = hasBias() ? blobs[1].reshape(1, outCn) : Mat();
-            biasvec.resize(outCn+2);
-            if( biasMat.empty() )
-            {
-                for( k = 0; k < outCn; k++ )
-                    biasvec[k] = 0.f;
-            }
-            else
-            {
-                for( k = 0; k < outCn; k++ )
-                    biasvec[k] = biasMat.at<float>(k);
-            }
-
-            if( !bnorm.empty() || !scaleLayer.empty() )
-            {
-                Mat scale, shift, scale2, shift2;
-                const float *scaleptr = 0, *shiftptr = 0;
-                const float *scaleptr2 = 0, *shiftptr2 = 0;
-
-                if( !bnorm.empty() )
-                {
-                    bnorm->getScaleShift(scale, shift);
-                    CV_Assert( scale.isContinuous() && shift.isContinuous() &&
-                               scale.type() == CV_32F && shift.type() == CV_32F &&
-                               scale.total() == (size_t)outCn &&
-                               shift.total() == (size_t)outCn );
-                    scaleptr = scale.ptr<float>();
-                    shiftptr = shift.ptr<float>();
-                }
-                if( !scaleLayer.empty() )
-                {
-                    scale2 = scaleLayer->blobs[0];
-                    CV_Assert( scale2.isContinuous() && scale2.type() == CV_32F &&
-                               scale2.total() == (size_t)outCn );
-                    scaleptr2 = scale2.ptr<float>();
-                    if( scaleLayer->hasBias )
-                    {
-                        shift2 = scaleLayer->blobs[1];
-                        CV_Assert( shift2.isContinuous() && shift2.type() == CV_32F &&
-                                   shift2.total() == (size_t)outCn );
-                        shiftptr2 = shift2.ptr<float>();
-                    }
-                }
-
-                for( int i = 0; i < outCn; i++ )
-                {
-                    float s1 = scaleptr ? scaleptr[i] : 1.f;
-                    float delta1 = shiftptr ? shiftptr[i] : 0.f;
-                    float s2 = scaleptr2 ? scaleptr2[i] : 1.f;
-                    float delta2 = shiftptr2 ? shiftptr2[i] : 0.f;
-                    float* w_i = weightsMat.ptr<float>(i);
-                    int j, wcols = weightsMat.cols;
-
-                    for( j = 0; j < wcols; j++ )
-                        w_i[j] *= (s1*s2);
-
-                    biasvec[i] = biasvec[i]*(s1*s2) + (delta1*s2 + delta2);
-                }
-            }
-            biasvec[outCn] = biasvec[outCn+1] = biasvec[outCn-1];
-        }
+        int outCn = blobs[0].size[0];
 
         reluslope.clear();
         if( activ )
@@ -1103,6 +999,8 @@ public:
     UMat umat_weights;
     UMat umat_biases;
 
+    DeConvolutionLayerImpl(const LayerParams& params) : BaseConvolutionLayerImpl(params) {}
+
     MatShape computeColRowShape(const MatShape &inpShape, const MatShape &outShape) const
     {
         int inpCn = inpShape[1];
@@ -1619,36 +1517,15 @@ public:
     }
 };
 
-//Convolution and Deconvolution
-static void initConvDeconvLayerFromCaffe(Ptr<BaseConvolutionLayer> l, const LayerParams &params)
-{
-    l->setParamsFrom(params);
-    getConvolutionKernelParams(params, l->kernel.height, l->kernel.width, l->pad.height,
-                               l->pad.width, l->stride.height, l->stride.width, l->dilation.height,
-                               l->dilation.width, l->padMode);
-
-    l->numOutput = params.get<int>("num_output");
-    int ngroups = params.get<int>("group", 1);
-
-    l->adjustPad.height = params.get<int>("adj_h", 0);
-    l->adjustPad.width = params.get<int>("adj_w", 0);
-
-    CV_Assert(l->numOutput % ngroups == 0);
-    CV_Assert(l->adjustPad.width < l->stride.width &&
-              l->adjustPad.height < l->stride.height);
-}
-
 Ptr<BaseConvolutionLayer> ConvolutionLayer::create(const LayerParams &params)
 {
-    ConvolutionLayerImpl* conv_ptr = new ConvolutionLayerImpl;
-    Ptr<BaseConvolutionLayer> l(conv_ptr);
-    initConvDeconvLayerFromCaffe(l, params);
+    Ptr<ConvolutionLayerImpl> l(new ConvolutionLayerImpl(params));
 
 #ifdef HAVE_OPENCL
     size_t n = params.blobs.size();
-    conv_ptr->umat_blobs.resize(n);
+    l->umat_blobs.resize(n);
     for (int i = 0; i < n; i++)
-        conv_ptr->umat_blobs[i] = params.blobs[i].getUMat(ACCESS_READ);
+        l->umat_blobs[i] = params.blobs[i].getUMat(ACCESS_READ);
 #endif
 
     return l;
@@ -1656,10 +1533,7 @@ Ptr<BaseConvolutionLayer> ConvolutionLayer::create(const LayerParams &params)
 
 Ptr<BaseConvolutionLayer> DeconvolutionLayer::create(const LayerParams &params)
 {
-    Ptr<BaseConvolutionLayer> l(new DeConvolutionLayerImpl);
-    initConvDeconvLayerFromCaffe(l, params);
-
-    return l;
+    return Ptr<BaseConvolutionLayer>(new DeConvolutionLayerImpl(params));
 }
 
 }
diff --git a/modules/dnn/src/layers/mvn_layer.cpp b/modules/dnn/src/layers/mvn_layer.cpp
index c911b74..e0976cd 100644
--- a/modules/dnn/src/layers/mvn_layer.cpp
+++ b/modules/dnn/src/layers/mvn_layer.cpp
@@ -65,16 +65,18 @@ public:
         relu_slope = 0.f;
     }
 
-    Ptr<BatchNormLayer> bnorm;
     Mat scale, shift;
-    UMat bnorm_weight, bnorm_bias;
     bool fuse_batch_norm;
 
-    bool setBatchNorm(const Ptr<BatchNormLayer>& layer )
+    virtual bool tryFuse(Ptr<Layer>& top)
     {
-        bnorm = layer;
-        fuse_batch_norm = !bnorm.empty() && (preferableTarget == DNN_TARGET_OPENCL);
-        return fuse_batch_norm;
+        if (preferableTarget == DNN_TARGET_OPENCL && !fuse_batch_norm)
+        {
+            top->getScaleShift(scale, shift);
+            fuse_batch_norm = !scale.empty() || !shift.empty();
+            return fuse_batch_norm;
+        }
+        return false;
     }
 
     Ptr<ReLULayer> activ_relu;
@@ -95,12 +97,8 @@ public:
 #ifdef HAVE_OPENCL
     bool fast_forward_ocl(std::vector<UMat> &inputs, std::vector<UMat> &outputs)
     {
-        if( fuse_batch_norm && scale.empty())
-        {
-            bnorm->getScaleShift(scale, shift);
-            bnorm_weight = scale.getUMat(ACCESS_READ);
-            bnorm_bias = shift.getUMat(ACCESS_READ);
-        }
+        UMat bnorm_weight = scale.empty() ? UMat() : scale.getUMat(ACCESS_READ);
+        UMat bnorm_bias = shift.empty() ? UMat() : shift.getUMat(ACCESS_READ);
 
         int splitDim = (acrossChannels) ? 1 : 2;
         for (size_t inpIdx = 0; inpIdx < inputs.size(); inpIdx++)
@@ -171,12 +169,8 @@ public:
             return ret;
         }
 
-        if( fuse_batch_norm && scale.empty())
-        {
-            bnorm->getScaleShift(scale, shift);
-            bnorm_weight = scale.getUMat(ACCESS_READ);
-            bnorm_bias = shift.getUMat(ACCESS_READ);
-        }
+        UMat bnorm_weight = scale.empty() ? UMat() : scale.getUMat(ACCESS_READ);
+        UMat bnorm_bias = shift.empty() ? UMat() : shift.getUMat(ACCESS_READ);
 
         for (size_t inpIdx = 0; inpIdx < inputs.size(); inpIdx++)
         {
diff --git a/modules/dnn/src/layers/scale_layer.cpp b/modules/dnn/src/layers/scale_layer.cpp
index 66faa15..34f503e 100644
--- a/modules/dnn/src/layers/scale_layer.cpp
+++ b/modules/dnn/src/layers/scale_layer.cpp
@@ -201,6 +201,12 @@ public:
         return Ptr<BackendNode>();
     }
 
+    void getScaleShift(Mat& scale, Mat& shift) const
+    {
+        scale = !blobs.empty() ? blobs[0] : Mat();
+        shift = hasBias ? blobs[1] : Mat();
+    }
+
     virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
                            const std::vector<MatShape> &outputs) const
     {
diff --git a/modules/dnn/src/layers/shift_layer.cpp b/modules/dnn/src/layers/shift_layer.cpp
index 48b547c..4a75624 100644
--- a/modules/dnn/src/layers/shift_layer.cpp
+++ b/modules/dnn/src/layers/shift_layer.cpp
@@ -136,6 +136,12 @@ public:
         return Ptr<BackendNode>();
     }
 
+    void getScaleShift(Mat& scale, Mat& shift) const
+    {
+        scale = Mat();
+        shift = blobs[0];
+    }
+
     virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
                            const std::vector<MatShape> &outputs) const
     {
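The folding implemented by ConvolutionLayerImpl::fuseWeights() above (and previously done inline per backend with the s1*s2 products) rests on the identity stated in its comment: (conv(I) + b1) * w + b2 is the same convolution with its weights scaled by w and its bias replaced by b1 * w + b2. Below is a small self-contained check of that identity on a single output patch; the values are illustrative only, not taken from the patch.

#include <cassert>
#include <cmath>

int main()
{
    const int channels = 2, ksize = 3;
    double W[channels][ksize] = {{0.5, -1.0, 2.0}, {1.5, 0.25, -0.75}};  // conv weights
    double b1[channels] = {0.1, -0.2};          // convolution bias
    double w[channels]  = {2.0, 0.5};           // per-channel scale (e.g. batch norm)
    double b2[channels] = {-1.0, 3.0};          // per-channel shift
    double x[ksize]     = {1.0, 2.0, 3.0};      // one input patch

    for (int c = 0; c < channels; ++c)
    {
        double convOut  = b1[c];                 // conv(I) + b1, accumulated below
        double fusedOut = b1[c] * w[c] + b2[c];  // fused bias
        for (int j = 0; j < ksize; ++j)
        {
            convOut  += W[c][j] * x[j];
            fusedOut += (W[c][j] * w[c]) * x[j]; // fused (pre-scaled) weights
        }
        double separate = convOut * w[c] + b2[c];   // conv followed by scale/shift
        assert(std::fabs(separate - fusedOut) < 1e-9);
    }
    return 0;
}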