dnn(ocl): fix gemm kernels with beta=0
authorAlexander Alekhin <alexander.a.alekhin@gmail.com>
Tue, 15 Dec 2020 00:41:35 +0000 (00:41 +0000)
committerAlexander Alekhin <alexander.a.alekhin@gmail.com>
Tue, 15 Dec 2020 00:58:43 +0000 (00:58 +0000)
- dst is not initialized, may include NaN values
- 0*NaN produces NaN

modules/core/src/ocl.cpp
modules/dnn/src/ocl4dnn/src/math_functions.cpp
modules/dnn/src/opencl/gemm_buffer.cl
modules/dnn/test/test_onnx_importer.cpp

index 24b18dc..781dad7 100644 (file)
@@ -2949,6 +2949,15 @@ bool Kernel::empty() const
     return ptr() == 0;
 }
 
+static cv::String dumpValue(size_t sz, const void* p)
+{
+    if (sz == 4)
+        return cv::format("%d / %uu / 0x%08x / %g", *(int*)p, *(int*)p, *(int*)p, *(float*)p);
+    if (sz == 8)
+        return cv::format("%lld / %lluu / 0x%16llx / %g", *(long long*)p, *(long long*)p, *(long long*)p, *(double*)p);
+    return cv::format("%p", p);
+}
+
 int Kernel::set(int i, const void* value, size_t sz)
 {
     if (!p || !p->handle)
@@ -2959,7 +2968,7 @@ int Kernel::set(int i, const void* value, size_t sz)
         p->cleanupUMats();
 
     cl_int retval = clSetKernelArg(p->handle, (cl_uint)i, sz, value);
-    CV_OCL_DBG_CHECK_RESULT(retval, cv::format("clSetKernelArg('%s', arg_index=%d, size=%d, value=%p)", p->name.c_str(), (int)i, (int)sz, (void*)value).c_str());
+    CV_OCL_DBG_CHECK_RESULT(retval, cv::format("clSetKernelArg('%s', arg_index=%d, size=%d, value=%s)", p->name.c_str(), (int)i, (int)sz, dumpValue(sz, value).c_str()).c_str());
     if (retval != CL_SUCCESS)
         return -1;
     return i+1;
index 47224c3..e26a3c3 100644 (file)
@@ -88,13 +88,13 @@ ocl::Image2D ocl4dnnGEMMCopyBufferToImage(UMat buffer, int offset,
             size_t global_copy[2];
             global_copy[0] = width;
             global_copy[1] = height;
-            oclk_gemm_copy.set(0, ocl::KernelArg::PtrReadOnly(buffer));
-            oclk_gemm_copy.set(1, image);
-            oclk_gemm_copy.set(2, offset);
-            oclk_gemm_copy.set(3, width);
-            oclk_gemm_copy.set(4, height);
-            oclk_gemm_copy.set(5, ld);
-            oclk_gemm_copy.run(2, global_copy, NULL, false);
+            oclk_gemm_copy
+                .args(
+                    ocl::KernelArg::PtrReadOnly(buffer),
+                    image, offset,
+                    width, height,
+                    ld)
+                .run(2, global_copy, NULL, false);
         }
     } else {
         if (!padding)
@@ -112,13 +112,13 @@ ocl::Image2D ocl4dnnGEMMCopyBufferToImage(UMat buffer, int offset,
             global_copy[0] = padded_width;
             global_copy[1] = padded_height;
 
-            oclk_gemm_copy.set(0, ocl::KernelArg::PtrReadOnly(buffer));
-            oclk_gemm_copy.set(1, image);
-            oclk_gemm_copy.set(2, offset);
-            oclk_gemm_copy.set(3, width);
-            oclk_gemm_copy.set(4, height);
-            oclk_gemm_copy.set(5, ld);
-
+            oclk_gemm_copy
+                .args(
+                    ocl::KernelArg::PtrReadOnly(buffer),
+                    image, offset,
+                    width, height,
+                    ld)
+                .run(2, global_copy, NULL, false);
             oclk_gemm_copy.run(2, global_copy, NULL, false);
         }
     }
@@ -465,8 +465,12 @@ static bool ocl4dnnFastBufferGEMM(const CBLAS_TRANSPOSE TransA,
         kernel_name += "_float";
     }
 
+    bool isBetaZero = beta == 0;
+
     String opts = format("-DTYPE=%d", halfPrecisionMode ? TYPE_HALF : TYPE_FLOAT);
-    ocl::Kernel oclk_gemm_float(kernel_name.c_str(), ocl::dnn::gemm_buffer_oclsrc, opts);
+    if (isBetaZero)
+        opts += " -DZERO_BETA=1";
+
     size_t local[2] = {};
     size_t global[2] = {};
     if (TransA == CblasNoTrans && TransB != CblasNoTrans && is_small_batch) {
@@ -496,27 +500,37 @@ static bool ocl4dnnFastBufferGEMM(const CBLAS_TRANSPOSE TransA,
         local[1] = ly;
     }
 
-    int arg_idx = 0;
-    oclk_gemm_float.set(arg_idx++, ocl::KernelArg::PtrReadOnly(A));
-    oclk_gemm_float.set(arg_idx++, offA);
-    oclk_gemm_float.set(arg_idx++, ocl::KernelArg::PtrReadOnly(B));
-    oclk_gemm_float.set(arg_idx++, offB);
-    oclk_gemm_float.set(arg_idx++, ocl::KernelArg::PtrWriteOnly(C));
-    oclk_gemm_float.set(arg_idx++, offC);
-    oclk_gemm_float.set(arg_idx++, M);
-    oclk_gemm_float.set(arg_idx++, N);
-    oclk_gemm_float.set(arg_idx++, K);
-    oclk_gemm_float.set(arg_idx++, (float)alpha);
-    oclk_gemm_float.set(arg_idx++, (float)beta);
-
     bool ret = true;
-    if (TransB == CblasNoTrans || TransA != CblasNoTrans) {
+    if (TransB == CblasNoTrans || TransA != CblasNoTrans)
+    {
+        // _NN_
         int stride = 256;
         for (int start_index = 0; start_index < K; start_index += stride) {
-            oclk_gemm_float.set(arg_idx, start_index);
-            ret = oclk_gemm_float.run(2, global, local, false);
+            ocl::Kernel oclk_gemm_float(kernel_name.c_str(), ocl::dnn::gemm_buffer_oclsrc, opts);
+            oclk_gemm_float.args(
+                ocl::KernelArg::PtrReadOnly(A), offA,
+                ocl::KernelArg::PtrReadOnly(B), offB,
+                isBetaZero ? ocl::KernelArg::PtrWriteOnly(C) : ocl::KernelArg::PtrReadWrite(C), offC,
+                M, N, K,
+                (float)alpha, (float)beta,
+                start_index
+            );
+            ret &= oclk_gemm_float.run(2, global, local, false);
         }
-    } else {
+    }
+    else
+    {
+        // _NT_
+        //C.reshape(1,1).setTo(0xfe00 /*FP16 NAN*/);  // stable one-line reproducer for https://github.com/opencv/opencv/issues/18937
+        //C.reshape(1,1).setTo(0);  // non-optimal fixup (and not accurate)
+        ocl::Kernel oclk_gemm_float(kernel_name.c_str(), ocl::dnn::gemm_buffer_oclsrc, opts);
+        oclk_gemm_float.args(
+            ocl::KernelArg::PtrReadOnly(A), offA,
+            ocl::KernelArg::PtrReadOnly(B), offB,
+            isBetaZero ? ocl::KernelArg::PtrWriteOnly(C) : ocl::KernelArg::PtrReadWrite(C), offC,
+            M, N, K,
+            (float)alpha, (float)beta
+        );
         ret = oclk_gemm_float.run(2, global, local, false);
     }
     return ret;
index 8cbc34d..b345983 100644 (file)
 #pragma OPENCL EXTENSION  cl_intel_subgroups : enable
 #endif
 
+#ifdef ZERO_BETA
+#define BETA_ZERO_CHECK(b0, v)  (b0)
+#else
+#define BETA_ZERO_CHECK(b0, v)  (v)
+#endif
+
 #define VEC_SIZE        4
 #define LWG_HEIGHT      4
 #define TILE_M          8
@@ -143,14 +149,14 @@ __kernel void TEMPLATE(gemm_buffer_NN, Dtype)(
     int row6 = mad24(global_y, TILE_M, 6) < M ? 6 : border;
     int row7 = mad24(global_y, TILE_M, 7) < M ? 7 : border;
 
-    Dtype4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0);
-    Dtype4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + 1 * N) : beta * vload4(0, dst_write0 + 1 * N);
-    Dtype4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N);
-    Dtype4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N);
-    Dtype4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N);
-    Dtype4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N);
-    Dtype4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N);
-    Dtype4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N);
+    Dtype4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0));
+    Dtype4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + 1 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 1 * N));
+    Dtype4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 2 * N));
+    Dtype4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 3 * N));
+    Dtype4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 4 * N));
+    Dtype4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 5 * N));
+    Dtype4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 6 * N));
+    Dtype4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 7 * N));
 
     int end_index = min(start_index + 256, K);
     int w = start_index;
@@ -579,7 +585,7 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
     output = (local_x == 5) ? _dot.s5 : output; \
     output = (local_x == 6) ? _dot.s6 : output; \
     output = (local_x == 7) ? _dot.s7 : output; \
-    dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \
+    dst_write0[0] = BETA_ZERO_CHECK(alpha * output, mad(output, alpha, beta * dst_write0[0])); \
     dst_write0 += N;
 
     if(global_x < N && global_y * 8 < M) {
@@ -765,7 +771,7 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
     output = (local_x == 5) ? _dot.s5 : output; \
     output = (local_x == 6) ? _dot.s6 : output; \
     output = (local_x == 7) ? _dot.s7 : output; \
-    dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \
+    dst_write0[0] = BETA_ZERO_CHECK(alpha * output, mad(output, alpha, beta * dst_write0[0])); \
     dst_write0 += N;
 
     if(global_x < N && global_y * 8 < M) {
@@ -819,8 +825,9 @@ void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)(
     const Dtype4 b1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};
 #pragma unroll
     for(int j = 0; j < rows; ++j) {
-      dot0[j] += b0 * vload4(i, srcb_read + j * K);
-      dot1[j] += b1 * vload4(i, srcb_read + j * K);
+      Dtype4 a = vload4(i, srcb_read + j * K);
+      dot0[j] += b0 * a;
+      dot1[j] += b1 * a;
     }
 
     i += get_local_size(0);
@@ -859,11 +866,19 @@ void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)(
     }
   }
 
+  barrier(CLK_LOCAL_MEM_FENCE);
   if(lid == 0) {
 #pragma unroll
     for(int j = 0; j < rows; ++j) {
-      dstc0[(x_gid * 4  + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];
-      dstc1[(x_gid * 4  + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];
+#ifdef ZERO_BETA
+      Dtype a0 = alpha * work_each0[j];
+      Dtype a1 = alpha * work_each1[j];
+#else
+      Dtype a0 = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];
+      Dtype a1 = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];
+#endif
+      dstc0[(x_gid * 4  + j)] = a0;
+      dstc1[(x_gid * 4  + j)] = a1;
     }
   }
 }
@@ -952,9 +967,15 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)(
       }
     }
 
-    if(lid == 0) {
+    if(lid == 0)
+    {
+#ifdef ZERO_BETA
+      dstc0[x_gid] = alpha * work0[0];
+      dstc1[x_gid] = alpha * work1[0];
+#else
       dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];
       dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];
+#endif
     }
   }
 }
@@ -1058,10 +1079,17 @@ void TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)(
   if(lid == 0) {
 #pragma unroll
     for(int j = 0; j < rows; ++j) {
+#ifdef ZERO_BETA
+      dstc0[(x_gid * 4  + j)] = alpha * work_each0[j];
+      dstc1[(x_gid * 4  + j)] = alpha * work_each1[j];
+      dstc2[(x_gid * 4  + j)] = alpha * work_each2[j];
+      dstc3[(x_gid * 4  + j)] = alpha * work_each3[j];
+#else
       dstc0[(x_gid * 4  + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];
       dstc1[(x_gid * 4  + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];
       dstc2[(x_gid * 4  + j)] = alpha * work_each2[j] + beta * dstc2[(x_gid * 4 + j)];
       dstc3[(x_gid * 4  + j)] = alpha * work_each3[j] + beta * dstc3[(x_gid * 4 + j)];
+#endif
     }
   }
 }
@@ -1179,10 +1207,17 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)(
     }
 
     if(lid == 0) {
+#ifdef ZERO_BETA
+      dstc0[x_gid] = alpha * work0[0];
+      dstc1[x_gid] = alpha * work1[0];
+      dstc2[x_gid] = alpha * work2[0];
+      dstc3[x_gid] = alpha * work3[0];
+#else
       dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];
       dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];
       dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid];
       dstc3[x_gid] = alpha * work3[0] + beta * dstc3[x_gid];
+#endif
     }
   }
 }
@@ -1320,6 +1355,16 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)(
   }
 
   if(lid == 0) {
+#ifdef ZERO_BETA
+    dstc0[x_gid] = alpha * work0[0];
+    dstc1[x_gid] = alpha * work1[0];
+    dstc2[x_gid] = alpha * work2[0];
+    dstc3[x_gid] = alpha * work3[0];
+    dstc4[x_gid] = alpha * work4[0];
+    dstc5[x_gid] = alpha * work5[0];
+    dstc6[x_gid] = alpha * work6[0];
+    dstc7[x_gid] = alpha * work7[0];
+#else
     dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];
     dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];
     dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid];
@@ -1328,6 +1373,7 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)(
     dstc5[x_gid] = alpha * work5[0] + beta * dstc5[x_gid];
     dstc6[x_gid] = alpha * work6[0] + beta * dstc6[x_gid];
     dstc7[x_gid] = alpha * work7[0] + beta * dstc7[x_gid];
+#endif
   }
 }
 #undef SLM_SIZE
index 9ba10d4..f38ca67 100644 (file)
@@ -718,9 +718,6 @@ TEST_P(Test_ONNX_layers, Conv1d_variable_weight_bias)
 
 TEST_P(Test_ONNX_layers, GatherMultiOutput)
 {
-    if (cvtest::skipUnstableTests && backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16)
-        throw SkipTestException("Skip unstable test: https://github.com/opencv/opencv/issues/18937");
-
 #if defined(INF_ENGINE_RELEASE)
     if (target == DNN_TARGET_MYRIAD)
         applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE);