more update on MVN layer ocl implementation
authorLi Peng <peng.li@intel.com>
Fri, 19 Jan 2018 10:23:02 +0000 (18:23 +0800)
committerLi Peng <peng.li@intel.com>
Fri, 19 Jan 2018 14:54:04 +0000 (22:54 +0800)
cut one ocl kernel if normVariance is disabled,
also use native_powr for performance reason.

Signed-off-by: Li Peng <peng.li@intel.com>
modules/dnn/src/layers/mvn_layer.cpp
modules/dnn/src/opencl/mvn.cl

index 46ffcc5..d5daa76 100644 (file)
@@ -96,30 +96,31 @@ public:
                 return false;
 
             int number = (s[1] % 8 == 0) ? 8 : ((s[1] % 4 == 0) ? 4 : 1);
-            String buildopt = format("-DNUM=%d ", number);
-            String kname = format("calc_mean%d", number);
-            ocl::Kernel kernel(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt);
-            if (kernel.empty())
-                return false;
             size_t global[] = { (size_t)s[0], (size_t)(s[1] / number) };
-            kernel.set(0, ocl::KernelArg::PtrReadOnly(inpMat));
-            kernel.set(1, (int)s[0]);
-            kernel.set(2, (int)s[1]);
-            kernel.set(3, ocl::KernelArg::PtrReadOnly(meanMat));
-            kernel.set(4, ocl::KernelArg::PtrWriteOnly(tmpMat));
-            ret = kernel.run(2, global, NULL, false);
-            if (!ret)
-                return false;
-
+            String buildopt = format("-DNUM=%d ", number);
             if (normVariance)
             {
+                String kname = format("calc_mean%d", number);
+                ocl::Kernel kernel(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt);
+                if (kernel.empty())
+                    return false;
+
+                kernel.set(0, ocl::KernelArg::PtrReadOnly(inpMat));
+                kernel.set(1, (int)s[0]);
+                kernel.set(2, (int)s[1]);
+                kernel.set(3, ocl::KernelArg::PtrReadOnly(meanMat));
+                kernel.set(4, ocl::KernelArg::PtrWriteOnly(tmpMat));
+                ret = kernel.run(2, global, NULL, false);
+                if (!ret)
+                    return false;
+
                 ret = ocl4dnn::ocl4dnnGEMV<float>(ocl4dnn::CblasNoTrans, s[0], s[1], alpha,
                                                   tmpMat, 0, oneMat, 0, 0.0f, devMat, 0);
                 if (!ret)
                     return false;
             }
 
-            kname = format("mvn%d", number);
+            String kname = format("mvn%d", number);
             if (normVariance)
                 buildopt += "-DNORM_VARIANCE";
             ocl::Kernel kernel1(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt);
index c87667d..c1bf1f0 100644 (file)
@@ -79,7 +79,7 @@ __kernel void CALC_MEAN(__global const Dtype* src,
 
     Dtype mean_val = mean[x];
     vec_type src_vec = load(src, index);
-    vec_type dst_vec = pow(src_vec - (vec_type)mean_val, 2);
+    vec_type dst_vec = native_powr(src_vec - (vec_type)mean_val, 2);
     store(dst_vec, dst, index);
 }