From fe494297e4f6e91bc4d3c7372adaeb3053a93bbc Mon Sep 17 00:00:00 2001 From: Li Peng Date: Fri, 19 Jan 2018 18:23:02 +0800 Subject: [PATCH] more update on MVN layer ocl implementation cut one ocl kernel if normVariance is disabled, also use native_powr for performance reason. Signed-off-by: Li Peng --- modules/dnn/src/layers/mvn_layer.cpp | 31 ++++++++++++++++--------------- modules/dnn/src/opencl/mvn.cl | 2 +- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/modules/dnn/src/layers/mvn_layer.cpp b/modules/dnn/src/layers/mvn_layer.cpp index 46ffcc5..d5daa76 100644 --- a/modules/dnn/src/layers/mvn_layer.cpp +++ b/modules/dnn/src/layers/mvn_layer.cpp @@ -96,30 +96,31 @@ public: return false; int number = (s[1] % 8 == 0) ? 8 : ((s[1] % 4 == 0) ? 4 : 1); - String buildopt = format("-DNUM=%d ", number); - String kname = format("calc_mean%d", number); - ocl::Kernel kernel(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt); - if (kernel.empty()) - return false; size_t global[] = { (size_t)s[0], (size_t)(s[1] / number) }; - kernel.set(0, ocl::KernelArg::PtrReadOnly(inpMat)); - kernel.set(1, (int)s[0]); - kernel.set(2, (int)s[1]); - kernel.set(3, ocl::KernelArg::PtrReadOnly(meanMat)); - kernel.set(4, ocl::KernelArg::PtrWriteOnly(tmpMat)); - ret = kernel.run(2, global, NULL, false); - if (!ret) - return false; - + String buildopt = format("-DNUM=%d ", number); if (normVariance) { + String kname = format("calc_mean%d", number); + ocl::Kernel kernel(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt); + if (kernel.empty()) + return false; + + kernel.set(0, ocl::KernelArg::PtrReadOnly(inpMat)); + kernel.set(1, (int)s[0]); + kernel.set(2, (int)s[1]); + kernel.set(3, ocl::KernelArg::PtrReadOnly(meanMat)); + kernel.set(4, ocl::KernelArg::PtrWriteOnly(tmpMat)); + ret = kernel.run(2, global, NULL, false); + if (!ret) + return false; + ret = ocl4dnn::ocl4dnnGEMV(ocl4dnn::CblasNoTrans, s[0], s[1], alpha, tmpMat, 0, oneMat, 0, 0.0f, devMat, 0); if (!ret) return false; } - kname = format("mvn%d", number); + String kname = format("mvn%d", number); if (normVariance) buildopt += "-DNORM_VARIANCE"; ocl::Kernel kernel1(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt); diff --git a/modules/dnn/src/opencl/mvn.cl b/modules/dnn/src/opencl/mvn.cl index c87667d..c1bf1f0 100644 --- a/modules/dnn/src/opencl/mvn.cl +++ b/modules/dnn/src/opencl/mvn.cl @@ -79,7 +79,7 @@ __kernel void CALC_MEAN(__global const Dtype* src, Dtype mean_val = mean[x]; vec_type src_vec = load(src, index); - vec_type dst_vec = pow(src_vec - (vec_type)mean_val, 2); + vec_type dst_vec = native_powr(src_vec - (vec_type)mean_val, 2); store(dst_vec, dst, index); } -- 2.7.4