support axis in concat layer ocl path
author Li Peng <peng.li@intel.com>
Tue, 28 Nov 2017 15:40:46 +0000 (23:40 +0800)
committer Li Peng <peng.li@intel.com>
Wed, 6 Dec 2017 18:26:46 +0000 (02:26 +0800)
Signed-off-by: Li Peng <peng.li@intel.com>
modules/dnn/src/layers/concat_layer.cpp

index e51e1f7..e49f22d 100644 (file)
@@ -185,12 +185,13 @@ public:
         outs.getUMatVector(outputs);
 
         int cAxis = clamp(axis, inputs[0].dims);
-        if (!(cAxis == 1 && outputs[0].dims == 4 && !padding))
+        if (padding)
             return false;
 
         int bottom_concat_axis;
-        int concat_size = inputs[0].size[2] * inputs[0].size[3];
-        int top_concat_axis = outputs[0].size[1];
+        int concat_size = total(shape(inputs[0]), cAxis + 1);
+        int top_concat_axis = outputs[0].size[cAxis];
+        int num_concats = total(shape(inputs[0]), 0, cAxis);
         int offset_concat_axis = 0;
         UMat& outMat = outputs[0];
         String buildopt = String("-DDtype=") + ocl::typeToStr(inputs[0].type()) + String(" ");
@@ -202,12 +203,12 @@ public:
                 return false;
 
             UMat& inpMat = inputs[i];
-            bottom_concat_axis = inputs[i].size[1];
+            bottom_concat_axis = inputs[i].size[cAxis];
             size_t nthreads = inputs[i].total();
 
             kernel.set(0, (int)nthreads);
             kernel.set(1, ocl::KernelArg::PtrReadOnly(inpMat));
-            kernel.set(2, (int)inputs[i].size[0]);
+            kernel.set(2, (int)num_concats);
             kernel.set(3, (int)concat_size);
             kernel.set(4, (int)top_concat_axis);
             kernel.set(5, (int)bottom_concat_axis);