ocl4dnnGEMV in case of row_size < 4
author Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Thu, 1 Feb 2018 08:35:35 +0000 (11:35 +0300)
committer Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Thu, 1 Feb 2018 11:06:47 +0000 (14:06 +0300)
modules/dnn/src/ocl4dnn/src/math_functions.cpp

index 5fe52ac..c52a8a9 100644 (file)
@@ -451,23 +451,27 @@ bool ocl4dnnGEMV<float>(const CBLAS_TRANSPOSE TransA,
 
         uint row_size = M;
         uint col_size = N;
-        size_t localsize[] = { 128 };
-        size_t globalsize[] = { row_size / 4 * localsize[0] };
-
-        uint argId = 0;
-        k.set(argId++, ocl::KernelArg::PtrReadOnly(A));
-        k.set(argId++, offA);
-        k.set(argId++, cl_uint(col_size));
-        k.set(argId++, cl_uint(col_size%4));
-        k.set(argId++, ocl::KernelArg::PtrReadOnly(x));
-        k.set(argId++, offx);
-        k.set(argId++, alpha);
-        k.set(argId++, beta);
-        k.set(argId++, ocl::KernelArg::PtrWriteOnly(y));
-        k.set(argId++, offy);
-        k.set(argId++, NULL, localsize[0] * sizeof(cl_float4));
-
-        ret = k.run(1, globalsize, localsize, false);
+
+        if (row_size >= 4)
+        {
+            size_t localsize[] = { 128 };
+            size_t globalsize[] = { row_size / 4 * localsize[0] };
+
+            uint argId = 0;
+            k.set(argId++, ocl::KernelArg::PtrReadOnly(A));
+            k.set(argId++, offA);
+            k.set(argId++, cl_uint(col_size));
+            k.set(argId++, cl_uint(col_size%4));
+            k.set(argId++, ocl::KernelArg::PtrReadOnly(x));
+            k.set(argId++, offx);
+            k.set(argId++, alpha);
+            k.set(argId++, beta);
+            k.set(argId++, ocl::KernelArg::PtrWriteOnly(y));
+            k.set(argId++, offy);
+            k.set(argId++, NULL, localsize[0] * sizeof(cl_float4));
+
+            ret = k.run(1, globalsize, localsize, false);
+        }
 
         if ((row_size % 4) != 0 && ret)
         {