mvn, batch_norm and relu layer fusion
authorLi Peng <peng.li@intel.com>
Tue, 23 Jan 2018 15:52:41 +0000 (23:52 +0800)
committerLi Peng <peng.li@intel.com>
Thu, 25 Jan 2018 10:57:05 +0000 (18:57 +0800)
Signed-off-by: Li Peng <peng.li@intel.com>
modules/dnn/src/dnn.cpp
modules/dnn/src/layers/batch_norm_layer.cpp
modules/dnn/src/layers/mvn_layer.cpp
modules/dnn/src/opencl/mvn.cl

index 84dc8af..26ff469 100644 (file)
@@ -1190,7 +1190,8 @@ struct Net::Impl
 
             // TODO: OpenCL target support more fusion styles.
             if ( preferableTarget == DNN_TARGET_OPENCL &&
-                 (!cv::ocl::useOpenCL() || ld.layerInstance->type.compare("Convolution")) )
+                 (!cv::ocl::useOpenCL() || (ld.layerInstance->type != "Convolution" &&
+                 ld.layerInstance->type != "MVN")) )
                 continue;
 
             Ptr<Layer>& currLayer = ld.layerInstance;
index eca30f4..8acf8b2 100644 (file)
@@ -81,9 +81,6 @@ public:
             dstWeightsData[i] = w;
             dstBiasData[i] = (hasBias ? biasData[i] : 0.0f) - w * meanData[i] * varMeanScale;
         }
-
-        umat_weight = weights_.getUMat(ACCESS_READ);
-        umat_bias = bias_.getUMat(ACCESS_READ);
     }
 
     void getScaleShift(Mat& scale, Mat& shift) const
@@ -119,6 +116,12 @@ public:
         CV_Assert(blobs.size() >= 2);
         CV_Assert(inputs.size() == 1);
 
+        if (umat_weight.empty())
+        {
+            umat_weight = weights_.getUMat(ACCESS_READ);
+            umat_bias = bias_.getUMat(ACCESS_READ);
+        }
+
         UMat &inpBlob = inputs[0];
         CV_Assert(inpBlob.dims == 2 || inpBlob.dims == 4);
         int groups = inpBlob.size[0];
index d5daa76..1d5e12b 100644 (file)
@@ -60,6 +60,36 @@ public:
         normVariance = params.get<bool>("normalize_variance", true);
         acrossChannels = params.get<bool>("across_channels", false);
         eps = params.get<double>("eps", 1e-9);
+        fuse_batch_norm = false;
+        fuse_relu = false;
+        relu_slope = 0.f;
+    }
+
+    Ptr<BatchNormLayer> bnorm;
+    Mat scale, shift;
+    UMat bnorm_weight, bnorm_bias;
+    bool fuse_batch_norm;
+
+    bool setBatchNorm(const Ptr<BatchNormLayer>& layer )
+    {
+        bnorm = layer;
+        fuse_batch_norm = !bnorm.empty() && (preferableTarget == DNN_TARGET_OPENCL);
+        return fuse_batch_norm;
+    }
+
+    Ptr<ReLULayer> activ_relu;
+    float relu_slope;
+    bool fuse_relu;
+    bool setActivation(const Ptr<ActivationLayer>& layer)
+    {
+        if (!layer.empty() && preferableTarget == DNN_TARGET_OPENCL)
+        {
+            activ_relu = layer.dynamicCast<ReLULayer>();
+            if( !activ_relu.empty() )
+                relu_slope = activ_relu->negativeSlope;
+        }
+        fuse_relu = !activ_relu.empty();
+        return fuse_relu;
     }
 
 #ifdef HAVE_OPENCL
@@ -71,19 +101,24 @@ public:
         inputs_.getUMatVector(inputs);
         outputs_.getUMatVector(outputs);
 
+        if( fuse_batch_norm && scale.empty())
+        {
+            bnorm->getScaleShift(scale, shift);
+            bnorm_weight = scale.getUMat(ACCESS_READ);
+            bnorm_bias = shift.getUMat(ACCESS_READ);
+        }
+
         for (size_t inpIdx = 0; inpIdx < inputs.size(); inpIdx++)
         {
-            UMat &inpBlob = inputs[inpIdx];
-            UMat &outBlob = outputs[inpIdx];
+            UMat &inpMat = inputs[inpIdx];
+            UMat &outMat = outputs[inpIdx];
 
             int splitDim = (acrossChannels) ? 1 : 2;
             int i, newRows = 1;
             for( i = 0; i < splitDim; i++ )
-                newRows *= inpBlob.size[i];
+                newRows *= inpMat.size[i];
 
-            MatShape s = shape(newRows, inpBlob.total() / newRows);
-            UMat& inpMat = inpBlob;
-            UMat& outMat = outBlob;
+            MatShape s = shape(newRows, inpMat.total() / newRows);
             UMat oneMat = UMat::ones(s[1], 1, CV_32F);
             UMat meanMat = UMat(s[0], 1, CV_32F);
             UMat devMat  = UMat(s[0], 1, CV_32F);
@@ -121,8 +156,9 @@ public:
             }
 
             String kname = format("mvn%d", number);
-            if (normVariance)
-                buildopt += "-DNORM_VARIANCE";
+            buildopt += format("%s %s %s ", (normVariance) ? "-DNORM_VARIANCE" : "",
+                               (fuse_batch_norm) ? "-DFUSE_BATCH_NORM" : "",
+                               (fuse_relu) ? "-DFUSE_RELU" : "");
             ocl::Kernel kernel1(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt);
             if (kernel1.empty())
                 return false;
@@ -132,7 +168,11 @@ public:
             kernel1.set(3, (float)eps);
             kernel1.set(4, ocl::KernelArg::PtrReadOnly(meanMat));
             kernel1.set(5, ocl::KernelArg::PtrReadOnly(devMat));
-            kernel1.set(6, ocl::KernelArg::PtrWriteOnly(outMat));
+            kernel1.set(6, ocl::KernelArg::PtrReadOnly(bnorm_weight));
+            kernel1.set(7, ocl::KernelArg::PtrReadOnly(bnorm_bias));
+            kernel1.set(8, (int)inpMat.size[1]);
+            kernel1.set(9, (float)relu_slope);
+            kernel1.set(10, ocl::KernelArg::PtrWriteOnly(outMat));
             ret = kernel1.run(2, global, NULL, false);
             if (!ret)
                 return false;
index c1bf1f0..cc059ee 100644 (file)
@@ -89,6 +89,10 @@ __kernel void MVN(__global const Dtype* src,
                   const Dtype eps,
                   __global const Dtype* mean,
                   __global const Dtype* dev,
+                  __global const Dtype* bnorm_weight,
+                  __global const Dtype* bnorm_bias,
+                  const int channels,
+                  const float relu_slope,
                   __global Dtype* dst)
 {
     int x = get_global_id(0);
@@ -106,7 +110,21 @@ __kernel void MVN(__global const Dtype* src,
 #else
     alpha = 1;
 #endif
+
+    Dtype w = 1.f, b = 0.f;
+#ifdef FUSE_BATCH_NORM
+    w = bnorm_weight[x % channels];
+    b = bnorm_bias[x % channels];
+#endif
+
     vec_type src_vec = load(src, index) - (vec_type)mean_val;
     vec_type dst_vec = src_vec * alpha;
+    dst_vec = dst_vec * w + (vec_type)b;
+
+#ifdef FUSE_RELU
+    vec_type new_val = dst_vec * relu_slope;
+    dst_vec = select(new_val, dst_vec, dst_vec > (vec_type)0.f);
+#endif
+
     store(dst_vec, dst, index);
 }