Refactored deep learning layers fusion
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Tue, 13 Feb 2018 09:07:56 +0000 (12:07 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Tue, 13 Feb 2018 11:35:58 +0000 (14:35 +0300)
modules/dnn/include/opencv2/dnn/all_layers.hpp
modules/dnn/include/opencv2/dnn/dnn.hpp
modules/dnn/src/dnn.cpp
modules/dnn/src/layers/convolution_layer.cpp
modules/dnn/src/layers/mvn_layer.cpp
modules/dnn/src/layers/scale_layer.cpp
modules/dnn/src/layers/shift_layer.cpp

index 061d184..0704f6b 100644 (file)
@@ -472,7 +472,6 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
         bool hasWeights, hasBias;
         float epsilon;
 
-        virtual void getScaleShift(Mat& scale, Mat& shift) const = 0;
         static Ptr<BatchNormLayer> create(const LayerParams &params);
     };
 
index 4ad3035..e94b30c 100644 (file)
@@ -281,20 +281,26 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
         virtual bool setActivation(const Ptr<ActivationLayer>& layer);
 
         /**
-         * @brief Tries to attach to the layer the subsequent batch normalization layer, i.e. do the layer fusion in a partial case.
-         * @param[in] layer The subsequent batch normalization layer.
-         *
-         * Returns true if the batch normalization layer has been attached successfully.
+         * @brief Try to fuse current layer with a next one
+         * @param[in] top Next layer to be fused.
+         * @returns True if fusion was performed.
          */
-        virtual bool setBatchNorm(const Ptr<BatchNormLayer>& layer);
+        virtual bool tryFuse(Ptr<Layer>& top);
 
         /**
-         * @brief Tries to attach to the layer the subsequent scaling layer, i.e. do the layer fusion in a partial case.
-         * @param[in] layer The subsequent scaling layer.
+         * @brief Returns parameters of layers with channel-wise multiplication and addition.
+         * @param[out] scale Channel-wise multipliers. Total number of values should
+         *                   be equal to number of channels.
+         * @param[out] shift Channel-wise offsets. Total number of values should
+         *                   be equal to number of channels.
          *
-         * Returns true if the scaling layer has been attached successfully.
+         * Some layers can fuse their transformations with further layers.
+         * In example, convolution + batch normalization. This way base layer
+         * use weights from layer after it. Fused layer is skipped.
+         * By default, @p scale and @p shift are empty that means layer has no
+         * element-wise multiplications or additions.
          */
-        virtual bool setScale(const Ptr<ScaleLayer>& layer);
+        virtual void getScaleShift(Mat& scale, Mat& shift) const;
 
         /**
          * @brief "Deattaches" all the layers, attached to particular layer.
index b66fb42..31ae173 100644 (file)
@@ -1407,46 +1407,30 @@ struct Net::Impl
             if( ld.consumers.size() == 1 && pinsToKeep.count(LayerPin(lid, 0)) == 0 )
             {
                 LayerData* nextData = &layers[ld.consumers[0].lid];
-                Ptr<BatchNormLayer> nextBNormLayer =
-                    nextData->layerInstance.dynamicCast<BatchNormLayer>();
                 LayerPin lpNext(ld.consumers[0].lid, 0);
-                if( !nextBNormLayer.empty() && pinsToKeep.count(lpNext) == 0 )
+                while (nextData)
                 {
-                    LayerData* bnormData = nextData;
-                    nextData = 0;
-                    if( currLayer->setBatchNorm(nextBNormLayer) )
+                    Ptr<Layer> nextLayer = nextData->layerInstance;
+                    if (currLayer->tryFuse(nextLayer))
                     {
-                        printf_(("\tfused with %s\n", nextBNormLayer->name.c_str()));
-                        bnormData->skip = true;
+                        printf_(("\tfused with %s\n", nextLayer->name.c_str()));
+                        nextData->skip = true;
                         ld.outputBlobs = layers[lpNext.lid].outputBlobs;
                         ld.outputBlobsWrappers = layers[lpNext.lid].outputBlobsWrappers;
-                        if( bnormData->consumers.size() == 1 )
+                        if (nextData->consumers.size() == 1)
                         {
-                            nextData = &layers[bnormData->consumers[0].lid];
-                            lpNext = LayerPin(bnormData->consumers[0].lid, 0);
+                            int nextLayerId = nextData->consumers[0].lid;
+                            nextData = &layers[nextLayerId];
+                            lpNext = LayerPin(nextLayerId, 0);
                         }
-                    }
-                }
-
-                Ptr<ScaleLayer> nextScaleLayer;
-                if( nextData )
-                    nextScaleLayer = nextData->layerInstance.dynamicCast<ScaleLayer>();
-                if( !nextScaleLayer.empty() && pinsToKeep.count(lpNext) == 0 )
-                {
-                    LayerData* scaleData = nextData;
-                    nextData = 0;
-                    if( currLayer->setScale(nextScaleLayer) )
-                    {
-                        printf_(("\tfused with %s\n", nextScaleLayer->name.c_str()));
-                        scaleData->skip = true;
-                        ld.outputBlobs = layers[lpNext.lid].outputBlobs;
-                        ld.outputBlobsWrappers = layers[lpNext.lid].outputBlobsWrappers;
-                        if( scaleData->consumers.size() == 1 )
+                        else
                         {
-                            nextData = &layers[scaleData->consumers[0].lid];
-                            lpNext = LayerPin(scaleData->consumers[0].lid, 0);
+                            nextData = 0;
+                            break;
                         }
                     }
+                    else
+                        break;
                 }
 
                 // For now, OpenCL target support fusion with activation of ReLU/ChannelsPReLU/Power/Tanh
@@ -2627,13 +2611,16 @@ Ptr<BackendNode> Layer::tryAttach(const Ptr<BackendNode>& node)
 }
 
 bool Layer::setActivation(const Ptr<ActivationLayer>&) { return false; }
-bool Layer::setBatchNorm(const Ptr<BatchNormLayer>&) { return false; }
-bool Layer::setScale(const Ptr<ScaleLayer>&) { return false; }
+bool Layer::tryFuse(Ptr<Layer>&) { return false; }
+void Layer::getScaleShift(Mat& scale, Mat& shift) const
+{
+    scale = Mat();
+    shift = Mat();
+}
+
 void Layer::unsetAttached()
 {
     setActivation(Ptr<ActivationLayer>());
-    setBatchNorm(Ptr<BatchNormLayer>());
-    setScale(Ptr<ScaleLayer>());
 }
 
 template <typename T>
index 64c2212..3a0bc1b 100644 (file)
@@ -61,7 +61,23 @@ namespace dnn
 class BaseConvolutionLayerImpl : public ConvolutionLayer
 {
 public:
-    BaseConvolutionLayerImpl() {}
+    BaseConvolutionLayerImpl(const LayerParams &params)
+    {
+        setParamsFrom(params);
+        getConvolutionKernelParams(params, kernel.height, kernel.width, pad.height,
+                                   pad.width, stride.height, stride.width, dilation.height,
+                                   dilation.width, padMode);
+
+        numOutput = params.get<int>("num_output");
+        int ngroups = params.get<int>("group", 1);
+
+        adjustPad.height = params.get<int>("adj_h", 0);
+        adjustPad.width = params.get<int>("adj_w", 0);
+
+        CV_Assert(numOutput % ngroups == 0);
+        CV_Assert(adjustPad.width < stride.width &&
+                  adjustPad.height < stride.height);
+    }
 
     virtual bool supportBackend(int backendId)
     {
@@ -153,12 +169,10 @@ class ConvolutionLayerImpl : public BaseConvolutionLayerImpl
 {
 public:
     enum { VEC_ALIGN = 8, DFT_TYPE = CV_32F };
-    Mat weightsMat;
+    Mat weightsMat, weightsMat_doubles;
     std::vector<float> biasvec;
     std::vector<float> reluslope;
     Ptr<ActivationLayer> activ;
-    Ptr<BatchNormLayer> bnorm;
-    Ptr<ScaleLayer> scaleLayer;
 
 #ifdef HAVE_OPENCL
     Ptr<OCL4DNNConvSpatial<float> > convolutionOp;
@@ -169,7 +183,7 @@ public:
     ocl4dnnFusedActiv_t activType;
     float power;
 #endif
-    ConvolutionLayerImpl()
+    ConvolutionLayerImpl(const LayerParams &params) : BaseConvolutionLayerImpl(params)
     {
 #ifdef HAVE_OPENCL
         fusedBias = false;
@@ -225,6 +239,42 @@ public:
         return false;
     }
 
+    virtual void finalize(const std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
+    {
+        BaseConvolutionLayerImpl::finalize(inputs, outputs);
+
+        CV_Assert(!blobs.empty());
+        const int outCn = blobs[0].size[0];
+        // prepare weightsMat where each row is aligned and has enough zero padding on the right to
+        // use vectorized (i.e. with intrinsics) loops without tail processing
+        Mat wm = blobs[0].reshape(1, outCn).clone();
+        if( wm.step1() % VEC_ALIGN != 0 )
+        {
+            int newcols = (int)alignSize(wm.step1(), VEC_ALIGN);
+            Mat wm_buffer = Mat(outCn, newcols, wm.type());
+            Mat wm_padding = wm_buffer.colRange(wm.cols, newcols);
+            wm_padding.setTo(Scalar::all(0.));
+            Mat wm_aligned = wm_buffer.colRange(0, wm.cols);
+            wm.copyTo(wm_aligned);
+            wm = wm_aligned;
+        }
+        weightsMat = wm;
+        weightsMat.convertTo(weightsMat_doubles, CV_64F);
+
+        Mat biasMat = hasBias() ? blobs[1].reshape(1, outCn) : Mat();
+        biasvec.resize(outCn+2);
+        if( biasMat.empty() )
+        {
+            for(int i = 0; i < outCn; i++ )
+                biasvec[i] = 0.f;
+        }
+        else
+        {
+            for(int i = 0; i < outCn; i++ )
+                biasvec[i] = biasMat.at<float>(i);
+        }
+    }
+
     bool setActivation(const Ptr<ActivationLayer>& layer)
     {
         activ = layer;
@@ -240,10 +290,11 @@ public:
             if (!activ_power.empty())
             {
                 if (activ_power->scale != 1.f || activ_power->shift != 0.f)
-                    newWeightAndBias = true;
-
-                if (activ_power->scale != 1.f)
-                    weightsMat.release();
+                {
+                    const int outCh = blobs[0].size[0];
+                    fuseWeights(Mat(1, outCh, CV_32F, Scalar(activ_power->scale)),
+                                Mat(1, outCh, CV_32F, Scalar(activ_power->shift)));
+                }
 
                 power = activ_power->power;
                 activType = OCL4DNN_CONV_FUSED_ACTIV_POWER;
@@ -258,35 +309,49 @@ public:
         return !activ.empty();
     }
 
-    bool setBatchNorm(const Ptr<BatchNormLayer>& layer )
+    virtual bool tryFuse(Ptr<Layer>& top)
     {
-        // for now the scale layer followed by the batch norm cannot be fused, only vice versa.
-        if( !scaleLayer.empty() )
-            return false;
-        bnorm = layer;
-        // we will need to re-compute the weights with the batch
-        // norm coefficients taken into account
-        weightsMat.release();
-#ifdef HAVE_OPENCL
-        newWeightAndBias = true;
-        fusedBias = false;
-#endif
-        return !bnorm.empty();
+        Mat w, b;
+        top->getScaleShift(w, b);
+        if (!w.empty() || !b.empty())
+        {
+            fuseWeights(w, b);
+            return true;
+        }
+        return false;
     }
 
-    bool setScale(const Ptr<ScaleLayer>& layer)
+    void fuseWeights(const Mat& w, const Mat& b)
     {
-        if (layer.empty() || layer->blobs.empty())
-            return false;
-        scaleLayer = layer;
-        // we will need to re-compute the weights with the scaling
-        // coefficients taken into account
-        weightsMat.release();
+        // Convolution weights have OIHW data layout. Parameters fusion in case of
+        // (conv(I) + b1 ) * w + b2
+        // means to replace convolution's weights to [w*conv(I)] and bias to [b1 * w + b2]
+        const int outCn = weightsMat.size[0];
+        CV_Assert(!weightsMat.empty(), biasvec.size() == outCn + 2,
+                  w.empty() || outCn == w.total(), b.empty() || outCn == b.total());
+
+        if (!w.empty())
+        {
+            for (int i = 0; i < outCn; ++i)
+            {
+                double wi = w.at<float>(i);
+                cv::multiply(slice(weightsMat_doubles, i), wi, slice(weightsMat_doubles, i));
+                biasvec[i] *= wi;
+            }
+            weightsMat_doubles.convertTo(weightsMat, weightsMat.type());
+        }
+
+        if (!b.empty())
+        {
+            for (int i = 0; i < outCn; ++i)
+                biasvec[i] += b.at<float>(i);
+        }
+
 #ifdef HAVE_OPENCL
-        newWeightAndBias = true;
-        fusedBias = false;
+        newWeightAndBias = !w.empty() || !b.empty();
+        fusedBias = hasBias() || !b.empty();
 #endif
-        return true;
+        biasvec[outCn] = biasvec[outCn+1] = biasvec[outCn-1];
     }
 
     virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
@@ -776,97 +841,7 @@ public:
             convolutionOp = Ptr<OCL4DNNConvSpatial<float> >(new OCL4DNNConvSpatial<float>(config));
         }
 
-        int k, outCn = umat_blobs[0].size[0];
-        if( weightsMat.empty() )
-        {
-            // prepare weightsMat where each row is aligned and has enough zero padding on the right to
-            // use vectorized (i.e. with intrinsics) loops without tail processing
-            Mat wm = blobs[0].reshape(1, outCn).clone();
-            if( wm.step1() % VEC_ALIGN != 0 )
-            {
-                int newcols = (int)alignSize(wm.step1(), VEC_ALIGN);
-                Mat wm_buffer = Mat(outCn, newcols, wm.type());
-                Mat wm_padding = wm_buffer.colRange(wm.cols, newcols);
-                wm_padding.setTo(Scalar::all(0.));
-                Mat wm_aligned = wm_buffer.colRange(0, wm.cols);
-                wm.copyTo(wm_aligned);
-                wm = wm_aligned;
-            }
-            weightsMat = wm;
-
-            Mat biasMat = hasBias() ? blobs[1].reshape(1, outCn) : Mat();
-            biasvec.resize(outCn+2);
-            if( biasMat.empty() )
-            {
-                for( k = 0; k < outCn; k++ )
-                    biasvec[k] = 0.f;
-            }
-            else
-            {
-                for( k = 0; k < outCn; k++ )
-                    biasvec[k] = biasMat.at<float>(k);
-            }
-
-            if( !bnorm.empty() || !scaleLayer.empty() || IS_POWER_LAYER(activ))
-            {
-                Mat scale, shift, scale2, shift2;
-                const float *scaleptr = 0, *shiftptr = 0;
-                const float *scaleptr2 = 0, *shiftptr2 = 0;
-                float a = 1.f, b = 0.f;
-
-                if( !bnorm.empty() )
-                {
-                    bnorm->getScaleShift(scale, shift);
-                    CV_Assert( scale.isContinuous() && shift.isContinuous() &&
-                               scale.type() == CV_32F && shift.type() == CV_32F &&
-                               scale.total() == (size_t)outCn &&
-                               shift.total() == (size_t)outCn );
-                    scaleptr = scale.ptr<float>();
-                    shiftptr = shift.ptr<float>();
-                }
-                if( !scaleLayer.empty() )
-                {
-                    scale2 = scaleLayer->blobs[0];
-                    CV_Assert( scale2.isContinuous() && scale2.type() == CV_32F &&
-                               scale2.total() == (size_t)outCn );
-                    scaleptr2 = scale2.ptr<float>();
-                    if( scaleLayer->hasBias )
-                    {
-                        shift2 = scaleLayer->blobs[1];
-                        CV_Assert( shift2.isContinuous() && shift2.type() == CV_32F &&
-                                   shift2.total() == (size_t)outCn );
-                        shiftptr2 = shift2.ptr<float>();
-                    }
-                }
-
-                if( IS_POWER_LAYER(activ) )
-                {
-                    Ptr<PowerLayer> activ_power = activ.dynamicCast<PowerLayer>();
-                    CV_Assert(activ_power);
-                    a = activ_power->scale;
-                    b = activ_power->shift;
-                }
-
-                if (shiftptr || shiftptr2 || b != 0.f)
-                    fusedBias = true;
-
-                for( int i = 0; i < outCn; i++ )
-                {
-                    float s1 = scaleptr ? scaleptr[i] : 1.f;
-                    float delta1 = shiftptr ? shiftptr[i] : 0.f;
-                    float s2 = scaleptr2 ? scaleptr2[i] : 1.f;
-                    float delta2 = shiftptr2 ? shiftptr2[i] : 0.f;
-                    float* w_i = weightsMat.ptr<float>(i);
-                    int j, wcols = weightsMat.cols;
-
-                    for( j = 0; j < wcols; j++ )
-                        w_i[j] *= (s1*s2*a);
-
-                    biasvec[i] = biasvec[i]*(s1*s2*a) + (delta1*s2*a + delta2*a + b);
-                }
-            }
-            biasvec[outCn] = biasvec[outCn+1] = biasvec[outCn-1];
-        }
+        int outCn = umat_blobs[0].size[0];
 
         reluslope.clear();
         if( activ )
@@ -973,86 +948,7 @@ public:
 
         int ngroups = inputs[0]->size[1]/blobs[0].size[1];
         CV_Assert(outputs[0].size[1] % ngroups == 0);
-        int k, outCn = blobs[0].size[0];
-
-        if( weightsMat.empty() )
-        {
-            // prepare weightsMat where each row is aligned and has enough zero padding on the right to
-            // use vectorized (i.e. with intrinsics) loops without tail processing
-            Mat wm = blobs[0].reshape(1, outCn).clone();
-            if( wm.step1() % VEC_ALIGN != 0 )
-            {
-                int newcols = (int)alignSize(wm.step1(), VEC_ALIGN);
-                Mat wm_buffer = Mat(outCn, newcols, wm.type());
-                Mat wm_padding = wm_buffer.colRange(wm.cols, newcols);
-                wm_padding.setTo(Scalar::all(0.));
-                Mat wm_aligned = wm_buffer.colRange(0, wm.cols);
-                wm.copyTo(wm_aligned);
-                wm = wm_aligned;
-            }
-            weightsMat = wm;
-
-            Mat biasMat = hasBias() ? blobs[1].reshape(1, outCn) : Mat();
-            biasvec.resize(outCn+2);
-            if( biasMat.empty() )
-            {
-                for( k = 0; k < outCn; k++ )
-                    biasvec[k] = 0.f;
-            }
-            else
-            {
-                for( k = 0; k < outCn; k++ )
-                    biasvec[k] = biasMat.at<float>(k);
-            }
-
-            if( !bnorm.empty() || !scaleLayer.empty() )
-            {
-                Mat scale, shift, scale2, shift2;
-                const float *scaleptr = 0, *shiftptr = 0;
-                const float *scaleptr2 = 0, *shiftptr2 = 0;
-
-                if( !bnorm.empty() )
-                {
-                    bnorm->getScaleShift(scale, shift);
-                    CV_Assert( scale.isContinuous() && shift.isContinuous() &&
-                               scale.type() == CV_32F && shift.type() == CV_32F &&
-                               scale.total() == (size_t)outCn &&
-                               shift.total() == (size_t)outCn );
-                    scaleptr = scale.ptr<float>();
-                    shiftptr = shift.ptr<float>();
-                }
-                if( !scaleLayer.empty() )
-                {
-                    scale2 = scaleLayer->blobs[0];
-                    CV_Assert( scale2.isContinuous() && scale2.type() == CV_32F &&
-                               scale2.total() == (size_t)outCn );
-                    scaleptr2 = scale2.ptr<float>();
-                    if( scaleLayer->hasBias )
-                    {
-                        shift2 = scaleLayer->blobs[1];
-                        CV_Assert( shift2.isContinuous() && shift2.type() == CV_32F &&
-                                   shift2.total() == (size_t)outCn );
-                        shiftptr2 = shift2.ptr<float>();
-                    }
-                }
-
-                for( int i = 0; i < outCn; i++ )
-                {
-                    float s1 = scaleptr ? scaleptr[i] : 1.f;
-                    float delta1 = shiftptr ? shiftptr[i] : 0.f;
-                    float s2 = scaleptr2 ? scaleptr2[i] : 1.f;
-                    float delta2 = shiftptr2 ? shiftptr2[i] : 0.f;
-                    float* w_i = weightsMat.ptr<float>(i);
-                    int j, wcols = weightsMat.cols;
-
-                    for( j = 0; j < wcols; j++ )
-                        w_i[j] *= (s1*s2);
-
-                    biasvec[i] = biasvec[i]*(s1*s2) + (delta1*s2 + delta2);
-                }
-            }
-            biasvec[outCn] = biasvec[outCn+1] = biasvec[outCn-1];
-        }
+        int outCn = blobs[0].size[0];
 
         reluslope.clear();
         if( activ )
@@ -1103,6 +999,8 @@ public:
     UMat umat_weights;
     UMat umat_biases;
 
+    DeConvolutionLayerImpl(const LayerParams& params) : BaseConvolutionLayerImpl(params) {}
+
     MatShape computeColRowShape(const MatShape &inpShape, const MatShape &outShape) const
     {
         int inpCn = inpShape[1];
@@ -1619,36 +1517,15 @@ public:
     }
 };
 
-//Convolution and Deconvolution
-static void initConvDeconvLayerFromCaffe(Ptr<BaseConvolutionLayer> l, const LayerParams &params)
-{
-    l->setParamsFrom(params);
-    getConvolutionKernelParams(params, l->kernel.height, l->kernel.width, l->pad.height,
-                               l->pad.width, l->stride.height, l->stride.width, l->dilation.height,
-                               l->dilation.width, l->padMode);
-
-    l->numOutput = params.get<int>("num_output");
-    int ngroups = params.get<int>("group", 1);
-
-    l->adjustPad.height = params.get<int>("adj_h", 0);
-    l->adjustPad.width = params.get<int>("adj_w", 0);
-
-    CV_Assert(l->numOutput % ngroups == 0);
-    CV_Assert(l->adjustPad.width < l->stride.width &&
-              l->adjustPad.height < l->stride.height);
-}
-
 Ptr<BaseConvolutionLayer> ConvolutionLayer::create(const LayerParams &params)
 {
-    ConvolutionLayerImpl* conv_ptr = new ConvolutionLayerImpl;
-    Ptr<BaseConvolutionLayer> l(conv_ptr);
-    initConvDeconvLayerFromCaffe(l, params);
+    Ptr<ConvolutionLayerImpl> l(new ConvolutionLayerImpl(params));
 
 #ifdef HAVE_OPENCL
     size_t n = params.blobs.size();
-    conv_ptr->umat_blobs.resize(n);
+    l->umat_blobs.resize(n);
     for (int i = 0; i < n; i++)
-        conv_ptr->umat_blobs[i] = params.blobs[i].getUMat(ACCESS_READ);
+        l->umat_blobs[i] = params.blobs[i].getUMat(ACCESS_READ);
 #endif
 
     return l;
@@ -1656,10 +1533,7 @@ Ptr<BaseConvolutionLayer> ConvolutionLayer::create(const LayerParams &params)
 
 Ptr<BaseConvolutionLayer> DeconvolutionLayer::create(const LayerParams &params)
 {
-    Ptr<BaseConvolutionLayer> l(new DeConvolutionLayerImpl);
-    initConvDeconvLayerFromCaffe(l, params);
-
-    return l;
+    return Ptr<BaseConvolutionLayer>(new DeConvolutionLayerImpl(params));
 }
 
 }
index c911b74..e0976cd 100644 (file)
@@ -65,16 +65,18 @@ public:
         relu_slope = 0.f;
     }
 
-    Ptr<BatchNormLayer> bnorm;
     Mat scale, shift;
-    UMat bnorm_weight, bnorm_bias;
     bool fuse_batch_norm;
 
-    bool setBatchNorm(const Ptr<BatchNormLayer>& layer )
+    virtual bool tryFuse(Ptr<Layer>& top)
     {
-        bnorm = layer;
-        fuse_batch_norm = !bnorm.empty() && (preferableTarget == DNN_TARGET_OPENCL);
-        return fuse_batch_norm;
+        if (preferableTarget == DNN_TARGET_OPENCL && !fuse_batch_norm)
+        {
+            top->getScaleShift(scale, shift);
+            fuse_batch_norm = !scale.empty() || !shift.empty();
+            return fuse_batch_norm;
+        }
+        return false;
     }
 
     Ptr<ReLULayer> activ_relu;
@@ -95,12 +97,8 @@ public:
 #ifdef HAVE_OPENCL
     bool fast_forward_ocl(std::vector<UMat> &inputs, std::vector<UMat> &outputs)
     {
-        if( fuse_batch_norm && scale.empty())
-        {
-            bnorm->getScaleShift(scale, shift);
-            bnorm_weight = scale.getUMat(ACCESS_READ);
-            bnorm_bias = shift.getUMat(ACCESS_READ);
-        }
+        UMat bnorm_weight = scale.empty() ? UMat() : scale.getUMat(ACCESS_READ);
+        UMat bnorm_bias = shift.empty() ? UMat() : shift.getUMat(ACCESS_READ);
 
         int splitDim = (acrossChannels) ? 1 : 2;
         for (size_t inpIdx = 0; inpIdx < inputs.size(); inpIdx++)
@@ -171,12 +169,8 @@ public:
             return ret;
         }
 
-        if( fuse_batch_norm && scale.empty())
-        {
-            bnorm->getScaleShift(scale, shift);
-            bnorm_weight = scale.getUMat(ACCESS_READ);
-            bnorm_bias = shift.getUMat(ACCESS_READ);
-        }
+        UMat bnorm_weight = scale.empty() ? UMat() : scale.getUMat(ACCESS_READ);
+        UMat bnorm_bias = shift.empty() ? UMat() : shift.getUMat(ACCESS_READ);
 
         for (size_t inpIdx = 0; inpIdx < inputs.size(); inpIdx++)
         {
index 66faa15..34f503e 100644 (file)
@@ -201,6 +201,12 @@ public:
         return Ptr<BackendNode>();
     }
 
+    void getScaleShift(Mat& scale, Mat& shift) const
+    {
+        scale = !blobs.empty() ? blobs[0] : Mat();
+        shift = hasBias ? blobs[1] : Mat();
+    }
+
     virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
                            const std::vector<MatShape> &outputs) const
     {
index 48b547c..4a75624 100644 (file)
@@ -136,6 +136,12 @@ public:
         return Ptr<BackendNode>();
     }
 
+    void getScaleShift(Mat& scale, Mat& shift) const
+    {
+        scale = Mat();
+        shift = blobs[0];
+    }
+
     virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
                            const std::vector<MatShape> &outputs) const
     {