return false;
int number = (s[1] % 8 == 0) ? 8 : ((s[1] % 4 == 0) ? 4 : 1);
- String buildopt = format("-DNUM=%d ", number);
- String kname = format("calc_mean%d", number);
- ocl::Kernel kernel(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt);
- if (kernel.empty())
- return false;
size_t global[] = { (size_t)s[0], (size_t)(s[1] / number) };
- kernel.set(0, ocl::KernelArg::PtrReadOnly(inpMat));
- kernel.set(1, (int)s[0]);
- kernel.set(2, (int)s[1]);
- kernel.set(3, ocl::KernelArg::PtrReadOnly(meanMat));
- kernel.set(4, ocl::KernelArg::PtrWriteOnly(tmpMat));
- ret = kernel.run(2, global, NULL, false);
- if (!ret)
- return false;
-
+ String buildopt = format("-DNUM=%d ", number);
if (normVariance)
{
+ String kname = format("calc_mean%d", number);
+ ocl::Kernel kernel(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt);
+ if (kernel.empty())
+ return false;
+
+ kernel.set(0, ocl::KernelArg::PtrReadOnly(inpMat));
+ kernel.set(1, (int)s[0]);
+ kernel.set(2, (int)s[1]);
+ kernel.set(3, ocl::KernelArg::PtrReadOnly(meanMat));
+ kernel.set(4, ocl::KernelArg::PtrWriteOnly(tmpMat));
+ ret = kernel.run(2, global, NULL, false);
+ if (!ret)
+ return false;
+
ret = ocl4dnn::ocl4dnnGEMV<float>(ocl4dnn::CblasNoTrans, s[0], s[1], alpha,
tmpMat, 0, oneMat, 0, 0.0f, devMat, 0);
if (!ret)
return false;
}
- kname = format("mvn%d", number);
+ String kname = format("mvn%d", number);
if (normVariance)
buildopt += "-DNORM_VARIANCE";
ocl::Kernel kernel1(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt);