permute layer ocl implementation
authorLi Peng <peng.li@intel.com>
Tue, 5 Dec 2017 13:58:57 +0000 (21:58 +0800)
committerLi Peng <peng.li@intel.com>
Tue, 5 Dec 2017 14:10:05 +0000 (22:10 +0800)
Signed-off-by: Li Peng <peng.li@intel.com>
modules/dnn/src/layers/permute_layer.cpp
modules/dnn/src/opencl/permute.cl [new file with mode: 0644]

index a21c5a6..664c24e 100644 (file)
@@ -44,6 +44,7 @@
 #include "layers_common.hpp"
 #include <float.h>
 #include <algorithm>
+#include "opencl_kernels_dnn.hpp"
 
 namespace cv
 {
@@ -173,6 +174,24 @@ public:
         CV_Assert((int)_numAxes == inp0.dims);
 
         computeStrides(shape(*inputs[0]), shape(outputs[0]));
+
+#ifdef HAVE_OPENCL
+        if (uorder.empty())
+        {
+            std::vector<int> orderVec(_order.begin(), _order.end());;
+            Mat morder(1, orderVec.size(), CV_32SC1, &orderVec[0]);
+
+            std::vector<int> oldStrideVec(_oldStride.begin(), _oldStride.end());
+            Mat mold_stride(1, _oldStride.size(), CV_32SC1, &oldStrideVec[0]);
+
+            std::vector<int> newStrideVec(_newStride.begin(), _newStride.end());
+            Mat mnew_stride(1, newStrideVec.size(), CV_32SC1, &newStrideVec[0]);
+
+            morder.copyTo(uorder);
+            mold_stride.copyTo(uold_stride);
+            mnew_stride.copyTo(unew_stride);
+        }
+#endif
     }
 
     class PermuteInvoker : public ParallelLoopBody
@@ -247,11 +266,47 @@ public:
         }
     };
 
+#ifdef HAVE_OPENCL
+    bool forward_ocl(InputArrayOfArrays inps, OutputArrayOfArrays outs, OutputArrayOfArrays internals)
+    {
+        std::vector<UMat> inputs;
+        std::vector<UMat> outputs;
+
+        inps.getUMatVector(inputs);
+        outs.getUMatVector(outputs);
+
+        if (!_needsPermute)
+            return false;
+
+        for (size_t i = 0; i < inputs.size(); i++)
+        {
+            ocl::Kernel kernel("permute", ocl::dnn::permute_oclsrc);
+
+            kernel.set(0, (int)_count);
+            kernel.set(1, ocl::KernelArg::PtrReadOnly(inputs[i]));
+            kernel.set(2, ocl::KernelArg::PtrReadOnly(uorder));
+            kernel.set(3, ocl::KernelArg::PtrReadOnly(uold_stride));
+            kernel.set(4, ocl::KernelArg::PtrReadOnly(unew_stride));
+            kernel.set(5, (int)_numAxes);
+            kernel.set(6, ocl::KernelArg::PtrWriteOnly(outputs[i]));
+
+            if (!kernel.run(1, &_count, NULL, false))
+                return false;
+        }
+
+        return true;
+    }
+#endif
+
     void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr)
     {
         CV_TRACE_FUNCTION();
         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
 
+        CV_OCL_RUN((preferableTarget == DNN_TARGET_OPENCL) &&
+                   OCL_PERFORMANCE_CHECK(ocl::Device::getDefault().isIntel()),
+                   forward_ocl(inputs_arr, outputs_arr, internals_arr))
+
         Layer::forward_fallback(inputs_arr, outputs_arr, internals_arr);
     }
 
@@ -325,6 +380,10 @@ public:
     std::vector<size_t> _newStride;
     bool _needsPermute;
 
+#ifdef HAVE_OPENCL
+    UMat uorder, uold_stride, unew_stride;
+#endif
+
     size_t _numAxes;
 };
 
diff --git a/modules/dnn/src/opencl/permute.cl b/modules/dnn/src/opencl/permute.cl
new file mode 100644 (file)
index 0000000..38aa799
--- /dev/null
@@ -0,0 +1,67 @@
+/*M///////////////////////////////////////////////////////////////////////////////////////
+//
+//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
+//
+//  By downloading, copying, installing or using the software you agree to this license.
+//  If you do not agree to this license, do not download, install,
+//  copy or use the software.
+//
+//
+//                           License Agreement
+//                For Open Source Computer Vision Library
+//
+// Copyright (C) 2017, Intel Corporation, all rights reserved.
+// Copyright (c) 2016-2017 Fabian David Tschopp, all rights reserved.
+// Third party copyrights are property of their respective owners.
+//
+// Redistribution and use in source and binary forms, with or without modification,
+// are permitted provided that the following conditions are met:
+//
+//   * Redistribution's of source code must retain the above copyright notice,
+//     this list of conditions and the following disclaimer.
+//
+//   * Redistribution's in binary form must reproduce the above copyright notice,
+//     this list of conditions and the following disclaimer in the documentation
+//     and/or other materials provided with the distribution.
+//
+//   * The name of the copyright holders may not be used to endorse or promote products
+//     derived from this software without specific prior written permission.
+//
+// This software is provided by the copyright holders and contributors "as is" and
+// any express or implied warranties, including, but not limited to, the implied
+// warranties of merchantability and fitness for a particular purpose are disclaimed.
+// In no event shall the Intel Corporation or contributors be liable for any direct,
+// indirect, incidental, special, exemplary, or consequential damages
+// (including, but not limited to, procurement of substitute goods or services;
+// loss of use, data, or profits; or business interruption) however caused
+// and on any theory of liability, whether in contract, strict liability,
+// or tort (including negligence or otherwise) arising in any way out of
+// the use of this software, even if advised of the possibility of such damage.
+//
+//M*/
+
+#define Dtype float
+
+__kernel void permute(const int nthreads,
+                      __global Dtype* bottom_data,
+                      global int* permute_order,
+                      global int* oldStride,
+                      global int* newStride,
+                      const int num_axes,
+                      __global Dtype* top_data)
+{
+    for (int i = get_global_id(0); i < nthreads; i += get_global_size(0))
+    {
+        int oldPosition = 0;
+        int newPosition = i;
+
+        for (int j = 0; j < num_axes; ++j)
+        {
+            int order = permute_order[j];
+            oldPosition += (newPosition / newStride[j]) * oldStride[order];
+            newPosition %= newStride[j];
+        }
+
+        top_data[i] = bottom_data[oldPosition];
+    }
+}