inputs_.getUMatVector(inputs);
outputs_.getUMatVector(outputs);
- if (inputs[0].dims < 4)
+ if (inputs[0].dims < 4 || (total(shape(outputs[0]), 0, 2) % 4 != 0) ||
+ (total(shape(outputs[0]), 2) % 4 != 0))
return false;
const UMat& inpMat = inputs[0];
int rows = outputs[i].size[2];
int cols = outputs[i].size[3];
- int number = (cols % 8 == 0) ? 8 : ((cols % 4 == 0) ? 4 : 1);
- String buildopt = format("-DNUM=%d ", number);
- String kname = format("slice%d", number);
- ocl::Kernel kernel(kname.c_str(), ocl::dnn::slice_oclsrc, buildopt);
- size_t global[] = { (size_t)groups * channels, (size_t)rows * cols / number };
+ ocl::Kernel kernel("slice", ocl::dnn::slice_oclsrc);
+ size_t local[] = { 128 };
+ size_t global[] = { (size_t)groups * channels / 4 * local[0] };
int idx = 0;
kernel.set(idx++, ocl::KernelArg::PtrReadOnly(inpMat));
kernel.set(idx++, (int)(inpMat.size[2] * inpMat.size[3]));
- kernel.set(idx++, (int)inpMat.size[3]);
- kernel.set(idx++, (int)global[0]);
kernel.set(idx++, (int)(rows * cols));
+ kernel.set(idx++, (int)inpMat.size[3]);
kernel.set(idx++, (int)cols);
kernel.set(idx++, (int)sliceRanges[i][2].start);
kernel.set(idx++, (int)sliceRanges[i][3].start);
kernel.set(idx++, ocl::KernelArg::PtrWriteOnly(outputs[i]));
- bool ret = kernel.run(2, global, NULL, false);
+ bool ret = kernel.run(1, global, local, false);
if (!ret)
return false;
}
#define Dtype4 float4
#define Dtype8 float8
-#if NUM == 8
- #define load(src, index) vload8(0, src + index)
- #define store(vec, dst, index) vstore8(vec, 0, dst + index)
- #define vec_type Dtype8
- #define SLICE slice8
-#elif NUM == 4
- #define load(src, index) vload4(0, src + index)
- #define store(vec, dst, index) vstore4(vec, 0, dst + index)
- #define vec_type Dtype4
- #define SLICE slice4
-#elif NUM == 1
- #define load(src, index) src[index]
- #define store(vec, dst, index) dst[index] = vec
- #define vec_type Dtype
- #define SLICE slice1
-#endif
-
-__kernel void SLICE(__global const Dtype* src,
+__kernel void slice(__global const Dtype* src,
const int src_plane_size,
- const int src_cols,
- const int channels,
const int dst_plane_size,
+ const int src_cols,
const int dst_cols,
const int row_offset,
const int col_offset,
__global Dtype* dst)
{
- int x = get_global_id(0);
- int y = get_global_id(1) * NUM;
+ unsigned int row_gid = get_group_id(0);
+ unsigned int lid = get_local_id(0);
+ const __global Dtype *src_read = src + row_gid * 4 * src_plane_size;
+ __global Dtype *dst_read = dst + row_gid * 4 * dst_plane_size;
+ Dtype4 a0, a1, a2, a3;
+
+ int i = lid;
+ while( i < dst_plane_size / 4)
+ {
+ int row = (4 * i) / dst_cols + row_offset;
+ int col = (4 * i) % dst_cols + col_offset;
+ int src_index = row * src_cols + col;
- if ((x >= channels) || (y >= dst_plane_size))
- return;
+ a0 = vload4(0, src_read + src_index);
+ a1 = vload4(0, src_read + src_index + src_plane_size);
+ a2 = vload4(0, src_read + src_index + 2 * src_plane_size);
+ a3 = vload4(0, src_read + src_index + 3 * src_plane_size);
- int row = y / dst_cols + row_offset;
- int col = y % dst_cols + col_offset;
+ vstore4(a0, i, dst_read);
+ vstore4(a1, i, dst_read + dst_plane_size);
+ vstore4(a2, i, dst_read + 2 * dst_plane_size);
+ vstore4(a3, i, dst_read + 3 * dst_plane_size);
- int src_index = x * src_plane_size + row * src_cols + col;
- int dst_index = x * dst_plane_size + y;
- vec_type val = load(src, src_index);
- store(val, dst, dst_index);
+ i += get_local_size(0);
+ }
}