From 389fa5d38eac015f9de03808b14272bc2b2265ac Mon Sep 17 00:00:00 2001 From: Li Peng Date: Fri, 2 Feb 2018 01:24:27 +0800 Subject: [PATCH] slice layer ocl update Signed-off-by: Li Peng --- modules/dnn/src/layers/slice_layer.cpp | 16 +++++----- modules/dnn/src/opencl/slice.cl | 54 +++++++++++++++------------------- 2 files changed, 31 insertions(+), 39 deletions(-) diff --git a/modules/dnn/src/layers/slice_layer.cpp b/modules/dnn/src/layers/slice_layer.cpp index 7db14e2..2d8b4dc 100644 --- a/modules/dnn/src/layers/slice_layer.cpp +++ b/modules/dnn/src/layers/slice_layer.cpp @@ -181,7 +181,8 @@ public: inputs_.getUMatVector(inputs); outputs_.getUMatVector(outputs); - if (inputs[0].dims < 4) + if (inputs[0].dims < 4 || (total(shape(outputs[0]), 0, 2) % 4 != 0) || + (total(shape(outputs[0]), 2) % 4 != 0)) return false; const UMat& inpMat = inputs[0]; @@ -192,22 +193,19 @@ public: int rows = outputs[i].size[2]; int cols = outputs[i].size[3]; - int number = (cols % 8 == 0) ? 8 : ((cols % 4 == 0) ? 4 : 1); - String buildopt = format("-DNUM=%d ", number); - String kname = format("slice%d", number); - ocl::Kernel kernel(kname.c_str(), ocl::dnn::slice_oclsrc, buildopt); - size_t global[] = { (size_t)groups * channels, (size_t)rows * cols / number }; + ocl::Kernel kernel("slice", ocl::dnn::slice_oclsrc); + size_t local[] = { 128 }; + size_t global[] = { (size_t)groups * channels / 4 * local[0] }; int idx = 0; kernel.set(idx++, ocl::KernelArg::PtrReadOnly(inpMat)); kernel.set(idx++, (int)(inpMat.size[2] * inpMat.size[3])); - kernel.set(idx++, (int)inpMat.size[3]); - kernel.set(idx++, (int)global[0]); kernel.set(idx++, (int)(rows * cols)); + kernel.set(idx++, (int)inpMat.size[3]); kernel.set(idx++, (int)cols); kernel.set(idx++, (int)sliceRanges[i][2].start); kernel.set(idx++, (int)sliceRanges[i][3].start); kernel.set(idx++, ocl::KernelArg::PtrWriteOnly(outputs[i])); - bool ret = kernel.run(2, global, NULL, false); + bool ret = kernel.run(1, global, local, false); if (!ret) return false; } diff --git a/modules/dnn/src/opencl/slice.cl b/modules/dnn/src/opencl/slice.cl index 81a7148..37ba17c 100644 --- a/modules/dnn/src/opencl/slice.cl +++ b/modules/dnn/src/opencl/slice.cl @@ -44,44 +44,38 @@ #define Dtype4 float4 #define Dtype8 float8 -#if NUM == 8 - #define load(src, index) vload8(0, src + index) - #define store(vec, dst, index) vstore8(vec, 0, dst + index) - #define vec_type Dtype8 - #define SLICE slice8 -#elif NUM == 4 - #define load(src, index) vload4(0, src + index) - #define store(vec, dst, index) vstore4(vec, 0, dst + index) - #define vec_type Dtype4 - #define SLICE slice4 -#elif NUM == 1 - #define load(src, index) src[index] - #define store(vec, dst, index) dst[index] = vec - #define vec_type Dtype - #define SLICE slice1 -#endif - -__kernel void SLICE(__global const Dtype* src, +__kernel void slice(__global const Dtype* src, const int src_plane_size, - const int src_cols, - const int channels, const int dst_plane_size, + const int src_cols, const int dst_cols, const int row_offset, const int col_offset, __global Dtype* dst) { - int x = get_global_id(0); - int y = get_global_id(1) * NUM; + unsigned int row_gid = get_group_id(0); + unsigned int lid = get_local_id(0); + const __global Dtype *src_read = src + row_gid * 4 * src_plane_size; + __global Dtype *dst_read = dst + row_gid * 4 * dst_plane_size; + Dtype4 a0, a1, a2, a3; + + int i = lid; + while( i < dst_plane_size / 4) + { + int row = (4 * i) / dst_cols + row_offset; + int col = (4 * i) % dst_cols + col_offset; + int src_index = row * src_cols + col; - if ((x >= channels) || (y >= dst_plane_size)) - return; + a0 = vload4(0, src_read + src_index); + a1 = vload4(0, src_read + src_index + src_plane_size); + a2 = vload4(0, src_read + src_index + 2 * src_plane_size); + a3 = vload4(0, src_read + src_index + 3 * src_plane_size); - int row = y / dst_cols + row_offset; - int col = y % dst_cols + col_offset; + vstore4(a0, i, dst_read); + vstore4(a1, i, dst_read + dst_plane_size); + vstore4(a2, i, dst_read + 2 * dst_plane_size); + vstore4(a3, i, dst_read + 3 * dst_plane_size); - int src_index = x * src_plane_size + row * src_cols + col; - int dst_index = x * dst_plane_size + y; - vec_type val = load(src, src_index); - store(val, dst, dst_index); + i += get_local_size(0); + } } -- 2.7.4