slice layer ocl update
authorLi Peng <peng.li@intel.com>
Thu, 1 Feb 2018 17:24:27 +0000 (01:24 +0800)
committerLi Peng <peng.li@intel.com>
Tue, 6 Feb 2018 14:59:47 +0000 (22:59 +0800)
Signed-off-by: Li Peng <peng.li@intel.com>
modules/dnn/src/layers/slice_layer.cpp
modules/dnn/src/opencl/slice.cl

index 7db14e2..2d8b4dc 100644 (file)
@@ -181,7 +181,8 @@ public:
         inputs_.getUMatVector(inputs);
         outputs_.getUMatVector(outputs);
 
-        if (inputs[0].dims < 4)
+        if (inputs[0].dims < 4 || (total(shape(outputs[0]), 0, 2) % 4 != 0) ||
+            (total(shape(outputs[0]), 2) % 4 != 0))
             return false;
 
         const UMat& inpMat = inputs[0];
@@ -192,22 +193,19 @@ public:
             int rows = outputs[i].size[2];
             int cols = outputs[i].size[3];
 
-            int number = (cols % 8 == 0) ? 8 : ((cols % 4 == 0) ? 4 : 1);
-            String buildopt = format("-DNUM=%d ", number);
-            String kname = format("slice%d", number);
-            ocl::Kernel kernel(kname.c_str(), ocl::dnn::slice_oclsrc, buildopt);
-            size_t global[] = { (size_t)groups * channels, (size_t)rows * cols / number };
+            ocl::Kernel kernel("slice", ocl::dnn::slice_oclsrc);
+            size_t local[] = { 128 };
+            size_t global[] = { (size_t)groups * channels / 4 * local[0] };
             int idx = 0;
             kernel.set(idx++, ocl::KernelArg::PtrReadOnly(inpMat));
             kernel.set(idx++, (int)(inpMat.size[2] * inpMat.size[3]));
-            kernel.set(idx++, (int)inpMat.size[3]);
-            kernel.set(idx++, (int)global[0]);
             kernel.set(idx++, (int)(rows * cols));
+            kernel.set(idx++, (int)inpMat.size[3]);
             kernel.set(idx++, (int)cols);
             kernel.set(idx++, (int)sliceRanges[i][2].start);
             kernel.set(idx++, (int)sliceRanges[i][3].start);
             kernel.set(idx++, ocl::KernelArg::PtrWriteOnly(outputs[i]));
-            bool ret = kernel.run(2, global, NULL, false);
+            bool ret = kernel.run(1, global, local, false);
             if (!ret)
                 return false;
         }
index 81a7148..37ba17c 100644 (file)
 #define Dtype4 float4
 #define Dtype8 float8
 
-#if NUM == 8
-    #define load(src, index) vload8(0, src + index)
-    #define store(vec, dst, index) vstore8(vec, 0, dst + index)
-    #define vec_type Dtype8
-    #define SLICE slice8
-#elif NUM == 4
-    #define load(src, index) vload4(0, src + index)
-    #define store(vec, dst, index) vstore4(vec, 0, dst + index)
-    #define vec_type Dtype4
-    #define SLICE slice4
-#elif NUM == 1
-    #define load(src, index) src[index]
-    #define store(vec, dst, index) dst[index] = vec
-    #define vec_type Dtype
-    #define SLICE slice1
-#endif
-
-__kernel void SLICE(__global const Dtype* src,
+__kernel void slice(__global const Dtype* src,
                     const int src_plane_size,
-                    const int src_cols,
-                    const int channels,
                     const int dst_plane_size,
+                    const int src_cols,
                     const int dst_cols,
                     const int row_offset,
                     const int col_offset,
                     __global Dtype* dst)
 {
-    int x = get_global_id(0);
-    int y = get_global_id(1) * NUM;
+    unsigned int row_gid = get_group_id(0);
+    unsigned int lid = get_local_id(0);
+    const __global Dtype *src_read = src + row_gid * 4 * src_plane_size;
+    __global Dtype *dst_read = dst + row_gid * 4 * dst_plane_size;
+    Dtype4 a0, a1, a2, a3;
+
+    int i = lid;
+    while( i < dst_plane_size / 4)
+    {
+        int row = (4 * i) / dst_cols + row_offset;
+        int col = (4 * i) % dst_cols + col_offset;
+        int src_index = row * src_cols + col;
 
-    if ((x >= channels) || (y >= dst_plane_size))
-        return;
+        a0 = vload4(0, src_read + src_index);
+        a1 = vload4(0, src_read + src_index + src_plane_size);
+        a2 = vload4(0, src_read + src_index + 2 * src_plane_size);
+        a3 = vload4(0, src_read + src_index + 3 * src_plane_size);
 
-    int row = y / dst_cols + row_offset;
-    int col = y % dst_cols + col_offset;
+        vstore4(a0, i, dst_read);
+        vstore4(a1, i, dst_read + dst_plane_size);
+        vstore4(a2, i, dst_read + 2 * dst_plane_size);
+        vstore4(a3, i, dst_read + 3 * dst_plane_size);
 
-    int src_index = x * src_plane_size + row * src_cols + col;
-    int dst_index = x * dst_plane_size + y;
-    vec_type val = load(src, src_index);
-    store(val, dst, dst_index);
+        i += get_local_size(0);
+    }
 }