From 4b3d2c88349ae6dff5741c312927fcd0243f612d Mon Sep 17 00:00:00 2001 From: Alexander Alekhin Date: Tue, 15 Dec 2020 00:41:35 +0000 Subject: [PATCH] dnn(ocl): fix gemm kernels with beta=0 - dst is not initialized, may include NaN values - 0*NaN produces NaN --- modules/core/src/ocl.cpp | 11 +++- modules/dnn/src/ocl4dnn/src/math_functions.cpp | 78 +++++++++++++++----------- modules/dnn/src/opencl/gemm_buffer.cl | 76 ++++++++++++++++++++----- modules/dnn/test/test_onnx_importer.cpp | 3 - 4 files changed, 117 insertions(+), 51 deletions(-) diff --git a/modules/core/src/ocl.cpp b/modules/core/src/ocl.cpp index 24b18dc..781dad7 100644 --- a/modules/core/src/ocl.cpp +++ b/modules/core/src/ocl.cpp @@ -2949,6 +2949,15 @@ bool Kernel::empty() const return ptr() == 0; } +static cv::String dumpValue(size_t sz, const void* p) +{ + if (sz == 4) + return cv::format("%d / %uu / 0x%08x / %g", *(int*)p, *(int*)p, *(int*)p, *(float*)p); + if (sz == 8) + return cv::format("%lld / %lluu / 0x%16llx / %g", *(long long*)p, *(long long*)p, *(long long*)p, *(double*)p); + return cv::format("%p", p); +} + int Kernel::set(int i, const void* value, size_t sz) { if (!p || !p->handle) @@ -2959,7 +2968,7 @@ int Kernel::set(int i, const void* value, size_t sz) p->cleanupUMats(); cl_int retval = clSetKernelArg(p->handle, (cl_uint)i, sz, value); - CV_OCL_DBG_CHECK_RESULT(retval, cv::format("clSetKernelArg('%s', arg_index=%d, size=%d, value=%p)", p->name.c_str(), (int)i, (int)sz, (void*)value).c_str()); + CV_OCL_DBG_CHECK_RESULT(retval, cv::format("clSetKernelArg('%s', arg_index=%d, size=%d, value=%s)", p->name.c_str(), (int)i, (int)sz, dumpValue(sz, value).c_str()).c_str()); if (retval != CL_SUCCESS) return -1; return i+1; diff --git a/modules/dnn/src/ocl4dnn/src/math_functions.cpp b/modules/dnn/src/ocl4dnn/src/math_functions.cpp index 47224c3..e26a3c3 100644 --- a/modules/dnn/src/ocl4dnn/src/math_functions.cpp +++ b/modules/dnn/src/ocl4dnn/src/math_functions.cpp @@ -88,13 +88,13 @@ ocl::Image2D ocl4dnnGEMMCopyBufferToImage(UMat buffer, int offset, size_t global_copy[2]; global_copy[0] = width; global_copy[1] = height; - oclk_gemm_copy.set(0, ocl::KernelArg::PtrReadOnly(buffer)); - oclk_gemm_copy.set(1, image); - oclk_gemm_copy.set(2, offset); - oclk_gemm_copy.set(3, width); - oclk_gemm_copy.set(4, height); - oclk_gemm_copy.set(5, ld); - oclk_gemm_copy.run(2, global_copy, NULL, false); + oclk_gemm_copy + .args( + ocl::KernelArg::PtrReadOnly(buffer), + image, offset, + width, height, + ld) + .run(2, global_copy, NULL, false); } } else { if (!padding) @@ -112,13 +112,13 @@ ocl::Image2D ocl4dnnGEMMCopyBufferToImage(UMat buffer, int offset, global_copy[0] = padded_width; global_copy[1] = padded_height; - oclk_gemm_copy.set(0, ocl::KernelArg::PtrReadOnly(buffer)); - oclk_gemm_copy.set(1, image); - oclk_gemm_copy.set(2, offset); - oclk_gemm_copy.set(3, width); - oclk_gemm_copy.set(4, height); - oclk_gemm_copy.set(5, ld); - + oclk_gemm_copy + .args( + ocl::KernelArg::PtrReadOnly(buffer), + image, offset, + width, height, + ld) + .run(2, global_copy, NULL, false); oclk_gemm_copy.run(2, global_copy, NULL, false); } } @@ -465,8 +465,12 @@ static bool ocl4dnnFastBufferGEMM(const CBLAS_TRANSPOSE TransA, kernel_name += "_float"; } + bool isBetaZero = beta == 0; + String opts = format("-DTYPE=%d", halfPrecisionMode ? TYPE_HALF : TYPE_FLOAT); - ocl::Kernel oclk_gemm_float(kernel_name.c_str(), ocl::dnn::gemm_buffer_oclsrc, opts); + if (isBetaZero) + opts += " -DZERO_BETA=1"; + size_t local[2] = {}; size_t global[2] = {}; if (TransA == CblasNoTrans && TransB != CblasNoTrans && is_small_batch) { @@ -496,27 +500,37 @@ static bool ocl4dnnFastBufferGEMM(const CBLAS_TRANSPOSE TransA, local[1] = ly; } - int arg_idx = 0; - oclk_gemm_float.set(arg_idx++, ocl::KernelArg::PtrReadOnly(A)); - oclk_gemm_float.set(arg_idx++, offA); - oclk_gemm_float.set(arg_idx++, ocl::KernelArg::PtrReadOnly(B)); - oclk_gemm_float.set(arg_idx++, offB); - oclk_gemm_float.set(arg_idx++, ocl::KernelArg::PtrWriteOnly(C)); - oclk_gemm_float.set(arg_idx++, offC); - oclk_gemm_float.set(arg_idx++, M); - oclk_gemm_float.set(arg_idx++, N); - oclk_gemm_float.set(arg_idx++, K); - oclk_gemm_float.set(arg_idx++, (float)alpha); - oclk_gemm_float.set(arg_idx++, (float)beta); - bool ret = true; - if (TransB == CblasNoTrans || TransA != CblasNoTrans) { + if (TransB == CblasNoTrans || TransA != CblasNoTrans) + { + // _NN_ int stride = 256; for (int start_index = 0; start_index < K; start_index += stride) { - oclk_gemm_float.set(arg_idx, start_index); - ret = oclk_gemm_float.run(2, global, local, false); + ocl::Kernel oclk_gemm_float(kernel_name.c_str(), ocl::dnn::gemm_buffer_oclsrc, opts); + oclk_gemm_float.args( + ocl::KernelArg::PtrReadOnly(A), offA, + ocl::KernelArg::PtrReadOnly(B), offB, + isBetaZero ? ocl::KernelArg::PtrWriteOnly(C) : ocl::KernelArg::PtrReadWrite(C), offC, + M, N, K, + (float)alpha, (float)beta, + start_index + ); + ret &= oclk_gemm_float.run(2, global, local, false); } - } else { + } + else + { + // _NT_ + //C.reshape(1,1).setTo(0xfe00 /*FP16 NAN*/); // stable one-line reproducer for https://github.com/opencv/opencv/issues/18937 + //C.reshape(1,1).setTo(0); // non-optimal fixup (and not accurate) + ocl::Kernel oclk_gemm_float(kernel_name.c_str(), ocl::dnn::gemm_buffer_oclsrc, opts); + oclk_gemm_float.args( + ocl::KernelArg::PtrReadOnly(A), offA, + ocl::KernelArg::PtrReadOnly(B), offB, + isBetaZero ? ocl::KernelArg::PtrWriteOnly(C) : ocl::KernelArg::PtrReadWrite(C), offC, + M, N, K, + (float)alpha, (float)beta + ); ret = oclk_gemm_float.run(2, global, local, false); } return ret; diff --git a/modules/dnn/src/opencl/gemm_buffer.cl b/modules/dnn/src/opencl/gemm_buffer.cl index 8cbc34d..b345983 100644 --- a/modules/dnn/src/opencl/gemm_buffer.cl +++ b/modules/dnn/src/opencl/gemm_buffer.cl @@ -90,6 +90,12 @@ #pragma OPENCL EXTENSION cl_intel_subgroups : enable #endif +#ifdef ZERO_BETA +#define BETA_ZERO_CHECK(b0, v) (b0) +#else +#define BETA_ZERO_CHECK(b0, v) (v) +#endif + #define VEC_SIZE 4 #define LWG_HEIGHT 4 #define TILE_M 8 @@ -143,14 +149,14 @@ __kernel void TEMPLATE(gemm_buffer_NN, Dtype)( int row6 = mad24(global_y, TILE_M, 6) < M ? 6 : border; int row7 = mad24(global_y, TILE_M, 7) < M ? 7 : border; - Dtype4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0); - Dtype4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + 1 * N) : beta * vload4(0, dst_write0 + 1 * N); - Dtype4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N); - Dtype4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N); - Dtype4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N); - Dtype4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N); - Dtype4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N); - Dtype4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N); + Dtype4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0)); + Dtype4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + 1 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 1 * N)); + Dtype4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 2 * N)); + Dtype4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 3 * N)); + Dtype4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 4 * N)); + Dtype4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 5 * N)); + Dtype4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 6 * N)); + Dtype4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 7 * N)); int end_index = min(start_index + 256, K); int w = start_index; @@ -579,7 +585,7 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( output = (local_x == 5) ? _dot.s5 : output; \ output = (local_x == 6) ? _dot.s6 : output; \ output = (local_x == 7) ? _dot.s7 : output; \ - dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \ + dst_write0[0] = BETA_ZERO_CHECK(alpha * output, mad(output, alpha, beta * dst_write0[0])); \ dst_write0 += N; if(global_x < N && global_y * 8 < M) { @@ -765,7 +771,7 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( output = (local_x == 5) ? _dot.s5 : output; \ output = (local_x == 6) ? _dot.s6 : output; \ output = (local_x == 7) ? _dot.s7 : output; \ - dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \ + dst_write0[0] = BETA_ZERO_CHECK(alpha * output, mad(output, alpha, beta * dst_write0[0])); \ dst_write0 += N; if(global_x < N && global_y * 8 < M) { @@ -819,8 +825,9 @@ void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)( const Dtype4 b1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]}; #pragma unroll for(int j = 0; j < rows; ++j) { - dot0[j] += b0 * vload4(i, srcb_read + j * K); - dot1[j] += b1 * vload4(i, srcb_read + j * K); + Dtype4 a = vload4(i, srcb_read + j * K); + dot0[j] += b0 * a; + dot1[j] += b1 * a; } i += get_local_size(0); @@ -859,11 +866,19 @@ void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)( } } + barrier(CLK_LOCAL_MEM_FENCE); if(lid == 0) { #pragma unroll for(int j = 0; j < rows; ++j) { - dstc0[(x_gid * 4 + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)]; - dstc1[(x_gid * 4 + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)]; +#ifdef ZERO_BETA + Dtype a0 = alpha * work_each0[j]; + Dtype a1 = alpha * work_each1[j]; +#else + Dtype a0 = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)]; + Dtype a1 = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)]; +#endif + dstc0[(x_gid * 4 + j)] = a0; + dstc1[(x_gid * 4 + j)] = a1; } } } @@ -952,9 +967,15 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)( } } - if(lid == 0) { + if(lid == 0) + { +#ifdef ZERO_BETA + dstc0[x_gid] = alpha * work0[0]; + dstc1[x_gid] = alpha * work1[0]; +#else dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid]; dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid]; +#endif } } } @@ -1058,10 +1079,17 @@ void TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)( if(lid == 0) { #pragma unroll for(int j = 0; j < rows; ++j) { +#ifdef ZERO_BETA + dstc0[(x_gid * 4 + j)] = alpha * work_each0[j]; + dstc1[(x_gid * 4 + j)] = alpha * work_each1[j]; + dstc2[(x_gid * 4 + j)] = alpha * work_each2[j]; + dstc3[(x_gid * 4 + j)] = alpha * work_each3[j]; +#else dstc0[(x_gid * 4 + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)]; dstc1[(x_gid * 4 + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)]; dstc2[(x_gid * 4 + j)] = alpha * work_each2[j] + beta * dstc2[(x_gid * 4 + j)]; dstc3[(x_gid * 4 + j)] = alpha * work_each3[j] + beta * dstc3[(x_gid * 4 + j)]; +#endif } } } @@ -1179,10 +1207,17 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)( } if(lid == 0) { +#ifdef ZERO_BETA + dstc0[x_gid] = alpha * work0[0]; + dstc1[x_gid] = alpha * work1[0]; + dstc2[x_gid] = alpha * work2[0]; + dstc3[x_gid] = alpha * work3[0]; +#else dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid]; dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid]; dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid]; dstc3[x_gid] = alpha * work3[0] + beta * dstc3[x_gid]; +#endif } } } @@ -1320,6 +1355,16 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)( } if(lid == 0) { +#ifdef ZERO_BETA + dstc0[x_gid] = alpha * work0[0]; + dstc1[x_gid] = alpha * work1[0]; + dstc2[x_gid] = alpha * work2[0]; + dstc3[x_gid] = alpha * work3[0]; + dstc4[x_gid] = alpha * work4[0]; + dstc5[x_gid] = alpha * work5[0]; + dstc6[x_gid] = alpha * work6[0]; + dstc7[x_gid] = alpha * work7[0]; +#else dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid]; dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid]; dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid]; @@ -1328,6 +1373,7 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)( dstc5[x_gid] = alpha * work5[0] + beta * dstc5[x_gid]; dstc6[x_gid] = alpha * work6[0] + beta * dstc6[x_gid]; dstc7[x_gid] = alpha * work7[0] + beta * dstc7[x_gid]; +#endif } } #undef SLM_SIZE diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 9ba10d4..f38ca67 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -718,9 +718,6 @@ TEST_P(Test_ONNX_layers, Conv1d_variable_weight_bias) TEST_P(Test_ONNX_layers, GatherMultiOutput) { - if (cvtest::skipUnstableTests && backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16) - throw SkipTestException("Skip unstable test: https://github.com/opencv/opencv/issues/18937"); - #if defined(INF_ENGINE_RELEASE) if (target == DNN_TARGET_MYRIAD) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE); -- 2.7.4